diff --git a/agent/vault_injection.py b/agent/vault_injection.py new file mode 100644 index 0000000000..c46f8219d4 --- /dev/null +++ b/agent/vault_injection.py @@ -0,0 +1,117 @@ +"""Vault injection — auto-load Obsidian vault files into the system prompt. + +Reads working-context.md and user-profile.md from a configured vault path +at session start and injects them into the system prompt alongside Layer 1 +memory (MEMORY.md / USER.md). This is a structural fix for vault neglect: +the agent no longer needs to remember to read these files — they're injected +automatically, the same way Layer 1 memory is. + +The vault is Layer 3 in the memory architecture. Files injected here are +read-only in the system prompt (frozen at session start). Mid-session +writes to vault files require the read_file/write_file tools or the +memory-vault skill. + +Config (in config.yaml under 'vault'): + enabled: true # enable vault injection + path: /path/to/vault # absolute path to the Obsidian vault root + +Files read (relative to vault path): + Agent-Hermes/working-context.md — what the agent is actively doing + Agent-Shared/user-profile.md — who the user is (durable facts) + +If either file doesn't exist or is empty, it's silently skipped. +If the vault path doesn't exist or isn't configured, vault injection is +silently disabled. +""" + +import logging +import os +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + +# Character limits for vault injection blocks (to prevent prompt bloat) +WORKING_CONTEXT_CHAR_LIMIT = 4000 +USER_PROFILE_CHAR_LIMIT = 4000 + +SEPARATOR = "\u2550" * 46 # ═ same as memory_tool uses + + +def _read_vault_file(path: Path, char_limit: int) -> Optional[str]: + """Read a vault file and return its content, or None if missing/empty. + + Truncates with a notice if the file exceeds char_limit. + """ + if not path.exists(): + return None + try: + content = path.read_text(encoding="utf-8").strip() + except (OSError, IOError) as e: + logger.debug("Could not read vault file %s: %s", path, e) + return None + + if not content: + return None + + # Strip YAML frontmatter (same as prompt_builder does for context files) + content = _strip_yaml_frontmatter(content) + + if not content: + return None + + if len(content) > char_limit: + truncated = content[:char_limit] + # Find last newline to avoid cutting mid-line + last_nl = truncated.rfind("\n") + if last_nl > char_limit // 2: + truncated = truncated[:last_nl] + content = truncated + f"\n[... truncated at {char_limit} chars ...]" + + return content + + +def _strip_yaml_frontmatter(content: str) -> str: + """Remove optional YAML frontmatter (--- delimited) from content.""" + if content.startswith("---"): + end = content.find("\n---", 3) + if end != -1: + body = content[end + 4:].lstrip("\n") + return body if body else content + return content + + +def build_vault_system_prompt(vault_path: str) -> str: + """Build the vault injection block for the system prompt. + + Reads working-context.md and user-profile.md from the vault and formats + them with headers matching the style of Layer 1 memory blocks. + + Returns an empty string if vault is disabled, path is missing, or + all files are empty. 
+    """
+    if not vault_path:
+        return ""
+
+    vault_root = Path(vault_path)
+    if not vault_root.is_dir():
+        logger.debug("Vault path does not exist or is not a directory: %s", vault_path)
+        return ""
+
+    parts = []
+
+    # Read working-context.md (agent's current state)
+    wc_path = vault_root / "Agent-Hermes" / "working-context.md"
+    wc_content = _read_vault_file(wc_path, WORKING_CONTEXT_CHAR_LIMIT)
+    if wc_content:
+        header = "VAULT: WORKING CONTEXT (what you're doing right now)"
+        parts.append(f"{SEPARATOR}\n{header}\n{SEPARATOR}\n{wc_content}")
+
+    # Read user-profile.md (shared user profile)
+    up_path = vault_root / "Agent-Shared" / "user-profile.md"
+    up_content = _read_vault_file(up_path, USER_PROFILE_CHAR_LIMIT)
+    if up_content:
+        header = "VAULT: USER PROFILE (durable facts from Obsidian vault)"
+        parts.append(f"{SEPARATOR}\n{header}\n{SEPARATOR}\n{up_content}")
+
+    return "\n\n".join(parts)
\ No newline at end of file
diff --git a/hermes_cli/config.py b/hermes_cli/config.py
index 7678287a0e..71727913d7 100644
--- a/hermes_cli/config.py
+++ b/hermes_cli/config.py
@@ -754,6 +754,16 @@ DEFAULT_CONFIG = {
         "provider": "",
     },
 
+    # Obsidian vault auto-injection — Layer 3 persistent memory.
+    # When enabled, working-context.md and user-profile.md are read from
+    # the vault path at session start and injected into the system prompt
+    # alongside Layer 1 memory. This is a structural fix for vault neglect:
+    # the agent no longer needs to remember to read these files manually.
+    "vault": {
+        "enabled": False,  # set true to activate
+        "path": "",  # absolute path to the Obsidian vault root
+    },
+
     # Subagent delegation — override the provider:model used by delegate_task
     # so child agents can run on a different (cheaper/faster) provider and model.
     # Uses the same runtime provider resolution as CLI/gateway startup, so all
diff --git a/hermes_state.py b/hermes_state.py
index ed95d25f45..61a837febc 100644
--- a/hermes_state.py
+++ b/hermes_state.py
@@ -31,7 +31,7 @@ T = TypeVar("T")
 
 DEFAULT_DB_PATH = get_hermes_home() / "state.db"
 
-SCHEMA_VERSION = 8
+SCHEMA_VERSION = 9
 
 SCHEMA_SQL = """
 CREATE TABLE IF NOT EXISTS schema_version (
@@ -95,6 +95,13 @@
 CREATE INDEX IF NOT EXISTS idx_sessions_source ON sessions(source);
 CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
 CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
 CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp);
+
+CREATE TABLE IF NOT EXISTS term_index (
+    term TEXT NOT NULL,
+    message_id INTEGER NOT NULL REFERENCES messages(id),
+    session_id TEXT NOT NULL REFERENCES sessions(id),
+    PRIMARY KEY (term, message_id)
+) WITHOUT ROWID;
 """
 
 FTS_SQL = """
@@ -164,7 +171,14 @@ class SessionDB:
         self._conn.execute("PRAGMA journal_mode=WAL")
         self._conn.execute("PRAGMA foreign_keys=ON")
 
-        self._init_schema()
+        needs_reindex = self._init_schema()
+
+        # If we just migrated to v9, backfill the term index from existing
+        # messages. This runs outside _init_schema so we can use
+        # _execute_write (which manages its own transactions).
+        if needs_reindex:
+            logger.info("v9 migration detected — backfilling term index")
+            self.reindex_term_index()
 
     # ── Core write helper ──
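For context on the schema choice above: declaring term_index WITHOUT ROWID with (term, message_id) as the primary key makes the table its own clustered index, so all postings for a term sit adjacent in one B-tree and a term lookup is a single range scan with no secondary-index indirection. A self-contained sketch of that behavior, using only stdlib sqlite3 with illustrative data (not part of the patch):

import sqlite3

# Minimal sketch: the WITHOUT ROWID table below is clustered on
# (term, message_id), so every posting for a term is physically
# adjacent in the primary-key B-tree.
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE term_index (
        term TEXT NOT NULL,
        message_id INTEGER NOT NULL,
        session_id TEXT NOT NULL,
        PRIMARY KEY (term, message_id)
    ) WITHOUT ROWID
""")
conn.executemany(
    "INSERT INTO term_index VALUES (?, ?, ?)",
    [("docker", 1, "s1"), ("docker", 7, "s2"), ("compose", 1, "s1")],
)
# A term lookup needs no separate index: SQLite reports a primary-key
# search (typically "SEARCH term_index USING PRIMARY KEY (term=?)").
for row in conn.execute(
    "EXPLAIN QUERY PLAN SELECT message_id FROM term_index WHERE term = ?",
    ("docker",),
):
    print(row)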
@@ -257,11 +271,17 @@ class SessionDB:
             self._conn = None
 
     def _init_schema(self):
-        """Create tables and FTS if they don't exist, run migrations."""
+        """Create tables and FTS if they don't exist, run migrations.
+
+        Returns True if a v9 migration was performed (term_index created
+        and needs backfill), False otherwise.
+        """
         cursor = self._conn.cursor()
         cursor.executescript(SCHEMA_SQL)
 
+        needs_reindex = False
+
         # Check schema version and run migrations
         cursor.execute("SELECT version FROM schema_version LIMIT 1")
         row = cursor.fetchone()
@@ -356,6 +376,24 @@ class SessionDB:
                 except sqlite3.OperationalError:
                     pass  # Column already exists
                 cursor.execute("UPDATE schema_version SET version = 8")
+            if current_version < 9:
+                # v9: add term_index table for inverted-index session search.
+                # This is the clustered (term, message_id) WITHOUT ROWID table
+                # used by the session_search fast path. After creating the
+                # table, we backfill from existing messages.
+                try:
+                    cursor.execute("""
+                        CREATE TABLE IF NOT EXISTS term_index (
+                            term TEXT NOT NULL,
+                            message_id INTEGER NOT NULL REFERENCES messages(id),
+                            session_id TEXT NOT NULL REFERENCES sessions(id),
+                            PRIMARY KEY (term, message_id)
+                        ) WITHOUT ROWID
+                    """)
+                except sqlite3.OperationalError:
+                    pass  # Table already exists
+                cursor.execute("UPDATE schema_version SET version = 9")
+                needs_reindex = True
 
         # Unique title index — always ensure it exists (safe to run after migrations
         # since the title column is guaranteed to exist at this point)
@@ -375,6 +413,8 @@
 
         self._conn.commit()
 
+        return needs_reindex
+
     # =========================================================================
     # Session lifecycle
     # =========================================================================
@@ -569,6 +609,23 @@ class SessionDB:
         row = cursor.fetchone()
         return dict(row) if row else None
 
+    def get_child_session_ids(self, *parent_ids: str) -> List[str]:
+        """Return IDs of sessions whose parent_session_id is in *parent_ids*.
+
+        Useful for finding delegation/compression child sessions that
+        belong to a parent conversation. Does NOT recurse — only
+        direct children are returned.
+        """
+        if not parent_ids:
+            return []
+        placeholders = ",".join("?" for _ in parent_ids)
+        with self._lock:
+            cursor = self._conn.execute(
+                f"SELECT id FROM sessions WHERE parent_session_id IN ({placeholders})",
+                list(parent_ids),
+            )
+            return [row["id"] for row in cursor.fetchall()]
+
     def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
         """Resolve an exact or uniquely prefixed session ID to the full ID.
 
@@ -1015,6 +1072,24 @@ class SessionDB:
                 "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
                 (session_id,),
             )
+
+            # Insert terms into inverted index (delayed import avoids
+            # circular dependency — hermes_state is imported by nearly
+            # everything at startup, term_index must not be top-level)
+            # Skip tool-role messages: their structured JSON output produces
+            # noise terms (field names, numeric values) with no search value.
+ if content and role != "tool": + try: + from term_index import extract_terms + terms = extract_terms(content) + if terms: + conn.executemany( + "INSERT OR IGNORE INTO term_index (term, message_id, session_id) VALUES (?, ?, ?)", + [(t, msg_id, session_id) for t in terms], + ) + except Exception: + pass # Term indexing is best-effort; never block a message insert + return msg_id return self._execute_write(_do) @@ -1395,6 +1470,191 @@ class SessionDB: return matches + # ========================================================================= + # Term index search (inverted index fast path) + # ========================================================================= + + def search_by_terms( + self, + terms: List[str], + exclude_sources: List[str] = None, + limit: int = 10, + ) -> List[Dict[str, Any]]: + """ + Search sessions using the term_index inverted index. + + Takes a list of query terms, finds all message IDs for each term, + intersects them (AND logic), and returns matching sessions with + metadata and match counts. + + This is the fast path for session search -- no LLM calls needed. + """ + if not terms: + return [] + + # Filter out stop words from the query (belt and suspenders) + # Delayed import avoids circular dependency (same pattern as append_message) + from stop_words import is_stop_word + filtered = [t.lower() for t in terms if t and not is_stop_word(t.lower())] + if not filtered: + return [] + + # Build the query. For single terms, a simple WHERE suffices. + # For multiple terms, we use GROUP BY session_id + HAVING COUNT(DISTINCT term) = N + # to enforce AND semantics: only sessions containing ALL query terms match. + if len(filtered) == 1: + term = filtered[0] + params: list = [term] + where_sql = "ti.term = ?" + + if exclude_sources: + exclude_placeholders = ",".join("?" for _ in exclude_sources) + where_sql += f" AND s.source NOT IN ({exclude_placeholders})" + params.extend(exclude_sources) + + sql = f""" + SELECT ti.session_id, + s.source, + s.model, + s.started_at AS session_started, + s.title, + COUNT(DISTINCT ti.message_id) AS match_count + FROM term_index ti + JOIN sessions s ON s.id = ti.session_id + WHERE {where_sql} + GROUP BY ti.session_id + ORDER BY match_count DESC, s.started_at DESC + LIMIT ? + """ + params.append(limit) + else: + # Multi-term: GROUP BY + HAVING COUNT(DISTINCT term) = N enforces AND + term_placeholders = ",".join("?" for _ in filtered) + params = list(filtered) + exclude_sql = "" + if exclude_sources: + exclude_sql = f" AND s.source NOT IN ({','.join('?' for _ in exclude_sources)})" + params.extend(exclude_sources) + params.extend([len(filtered), limit]) + + sql = f""" + SELECT ti.session_id, + s.source, + s.model, + s.started_at AS session_started, + s.title, + COUNT(DISTINCT ti.term) AS term_count, + COUNT(DISTINCT ti.message_id) AS match_count + FROM term_index ti + JOIN sessions s ON s.id = ti.session_id + WHERE ti.term IN ({term_placeholders}) + {exclude_sql} + GROUP BY ti.session_id + HAVING COUNT(DISTINCT ti.term) = ? + ORDER BY match_count DESC, s.started_at DESC + LIMIT ? + """ + + with self._lock: + try: + cursor = self._conn.execute(sql, params) + results = [dict(row) for row in cursor.fetchall()] + except sqlite3.OperationalError: + logger.debug("term_index query failed", exc_info=True) + return [] + + return results + + def reindex_term_index(self, batch_size: int = 500) -> int: + """ + Rebuild the term_index from existing messages. + + Processes messages in batches to avoid holding the write lock too + long. 
Returns the total number of term entries inserted. + + Uses a swap strategy: builds a temporary table, then swaps it + into place in a single transaction. This avoids the empty-index + window that would occur with a simple clear-and-repopulate. + """ + from term_index import extract_terms + + # Count total messages to index + with self._lock: + total = self._conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0] + + if total == 0: + return 0 + + inserted = 0 + offset = 0 + + # Create a temporary table with the same schema as term_index. + # We'll populate this, then swap it into place in a single + # transaction — no empty-index window for concurrent readers. + def _create_temp(conn): + conn.execute(""" + CREATE TABLE IF NOT EXISTS _term_index_new ( + term TEXT NOT NULL, + message_id INTEGER NOT NULL, + session_id TEXT NOT NULL, + PRIMARY KEY (term, message_id) + ) WITHOUT ROWID + """) + conn.execute("DELETE FROM _term_index_new") + self._execute_write(_create_temp) + + while offset < total: + # Read batch outside write lock + with self._lock: + cursor = self._conn.execute( + "SELECT id, session_id, role, content FROM messages ORDER BY id LIMIT ? OFFSET ?", + (batch_size, offset), + ) + rows = cursor.fetchall() + + if not rows: + break + + # Extract terms for the batch, skipping tool-role messages + entries = [] + for row in rows: + msg_id = row["id"] + session_id = row["session_id"] + # Skip tool messages — structured JSON output produces noise terms + if row["role"] == "tool": + continue + content = row["content"] or "" + terms = extract_terms(content) + for term in terms: + entries.append((term, msg_id, session_id)) + + # Write batch to temp table + if entries: + def _insert(conn, _entries=entries): + conn.executemany( + "INSERT OR IGNORE INTO _term_index_new (term, message_id, session_id) VALUES (?, ?, ?)", + _entries, + ) + return len(_entries) + inserted += self._execute_write(_insert) + + offset += batch_size + + # Swap: replace term_index with the new table in one transaction. + # Concurrent readers see either the old index or the new one — + # never an empty table. + def _swap(conn): + conn.execute("DELETE FROM term_index") + conn.execute(""" + INSERT INTO term_index (term, message_id, session_id) + SELECT term, message_id, session_id FROM _term_index_new + """) + conn.execute("DROP TABLE _term_index_new") + self._execute_write(_swap) + + logger.info("Reindexed term_index: %d entries from %d messages", inserted, total) + return inserted + def search_sessions( self, source: str = None, @@ -1466,8 +1726,14 @@ class SessionDB: return results def clear_messages(self, session_id: str) -> None: - """Delete all messages for a session and reset its counters.""" + """Delete all messages for a session and reset its counters. + + Also removes stale term_index entries that reference the deleted messages. 
+ """ def _do(conn): + conn.execute( + "DELETE FROM term_index WHERE session_id = ?", (session_id,) + ) conn.execute( "DELETE FROM messages WHERE session_id = ?", (session_id,) ) @@ -1496,6 +1762,7 @@ class SessionDB: "WHERE parent_session_id = ?", (session_id,), ) + conn.execute("DELETE FROM term_index WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) return True @@ -1536,6 +1803,7 @@ class SessionDB: ) for sid in session_ids: + conn.execute("DELETE FROM term_index WHERE session_id = ?", (sid,)) conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) return len(session_ids) diff --git a/run_agent.py b/run_agent.py index 6770f568c0..c1b2d60658 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1566,7 +1566,8 @@ class AIAgent: try: from hermes_cli.config import load_config as _load_agent_config _agent_cfg = _load_agent_config() - except Exception: + except Exception as e: + logger.warning("Agent init: load_config() failed: %s: %s — using empty config", type(e).__name__, e) _agent_cfg = {} # Cache only the derived auxiliary compression context override that is # needed later by the startup feasibility check. Avoid exposing a @@ -1597,8 +1598,20 @@ class AIAgent: self._memory_store.load_from_disk() except Exception: pass # Memory is optional -- don't break agent init - + # Obsidian vault auto-injection (Layer 3) — structural fix for + # vault neglect. Reads working-context.md and user-profile.md + # from the configured vault path and injects them into the system + # prompt at session start, just like Layer 1 memory. + self._vault_enabled = False + self._vault_path = "" + try: + vault_config = _agent_cfg.get("vault", {}) + self._vault_enabled = vault_config.get("enabled", False) + self._vault_path = vault_config.get("path", "") + logging.getLogger("agent.vault").info("Vault config: enabled=%s path=%s", self._vault_enabled, self._vault_path) + except Exception as e: + logging.getLogger("agent.vault").warning("Vault config read failed: %s: %s", type(e).__name__, e) # Memory provider plugin (external — one at a time, alongside built-in) # Reads memory.provider from config to select which plugin to activate. @@ -4448,6 +4461,24 @@ class AIAgent: if user_block: prompt_parts.append(user_block) + # Vault auto-injection (Layer 3) — reads working-context.md and + # user-profile.md from the Obsidian vault and injects them into + # the system prompt. Structural fix for vault neglect. 
+        _vault_log = logging.getLogger("agent.vault")
+        if self._vault_enabled and self._vault_path:
+            try:
+                from agent.vault_injection import build_vault_system_prompt
+                _vault_block = build_vault_system_prompt(self._vault_path)
+                if _vault_block:
+                    prompt_parts.append(_vault_block)
+                    _vault_log.info("Injection succeeded: %d chars from %s", len(_vault_block), self._vault_path)
+                else:
+                    _vault_log.warning("Injection returned empty for path %s", self._vault_path)
+            except Exception as e:
+                _vault_log.warning("Injection failed: %s: %s", type(e).__name__, e)
+        else:
+            _vault_log.info("Injection skipped: enabled=%s path=%s", self._vault_enabled, self._vault_path)
+
         # External memory provider system prompt block (additive to built-in)
         if self._memory_manager:
             try:
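The stop_words module introduced next is the shared noise filter behind term extraction. A quick sanity sketch of the intended behavior (illustrative usage only, assuming the module as defined below is on the import path):

from stop_words import is_stop_word, is_noise_term

# NLTK function words and JSON schema keys are stopped; content words pass.
assert is_stop_word("the")
assert is_stop_word("exit_code")
assert not is_stop_word("docker")

# The numeric filter is digit-only: "42" is noise, but hex-looking or
# mixed tokens such as "3fa9c" survive extraction.
assert is_noise_term("42")
assert not is_noise_term("3fa9c")
assert not is_noise_term("kubernetes")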
diff --git a/stop_words.py b/stop_words.py
new file mode 100644
index 0000000000..d1346b4c25
--- /dev/null
+++ b/stop_words.py
@@ -0,0 +1,79 @@
+"""Stop word list for term index extraction.
+
+Uses the well-known NLTK English stop word list (179 words) as a baseline,
+plus common JSON schema keys from tool output and a digit-only numeric filter.
+
+This module is self-contained -- no external dependencies.
+"""
+
+import re
+
+# Standard English stop words (NLTK list, public domain)
+# Covers articles, conjunctions, prepositions, pronouns, auxiliary verbs,
+# and common function words. Short tech terms that are not function words
+# (e.g., "go") are deliberately left out of the stop list.
+_ENGLISH_STOP_WORDS = frozenset(
+    w.lower() for w in [
+        "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you",
+        "your", "yours", "yourself", "yourselves", "he", "him", "his",
+        "himself", "she", "her", "hers", "herself", "it", "its", "itself",
+        "they", "them", "their", "theirs", "themselves", "what", "which",
+        "who", "whom", "this", "that", "these", "those", "am", "is", "are",
+        "was", "were", "be", "been", "being", "have", "has", "had", "having",
+        "do", "does", "did", "doing", "a", "an", "the", "and", "but", "if",
+        "or", "because", "as", "until", "while", "of", "at", "by", "for",
+        "with", "about", "against", "between", "through", "during", "before",
+        "after", "above", "below", "to", "from", "up", "down", "in", "out",
+        "on", "off", "over", "under", "again", "further", "then", "once",
+        "here", "there", "when", "where", "why", "how", "all", "both", "each",
+        "few", "more", "most", "other", "some", "such", "no", "nor", "not",
+        "only", "own", "same", "so", "than", "too", "very", "s", "t", "can",
+        "will", "just", "don", "should", "now", "d", "ll", "m", "o", "re",
+        "ve", "y", "ain", "aren", "couldn", "didn", "doesn", "hadn", "hasn",
+        "haven", "isn", "ma", "mightn", "mustn", "needn", "shan", "shouldn",
+        "wasn", "weren", "won", "wouldn",
+    ]
+)
+
+# JSON schema keys that appear constantly in tool output.
+# These are field names from structured tool responses, not semantic content.
+# Nobody searches for "exit_code" to find a past session.
+_JSON_KEY_STOP_WORDS = frozenset([
+    "output",
+    "exit_code",
+    "error",
+    "null",
+    "true",
+    "false",
+    "status",
+    "content",
+    "message",
+    "cleared",
+    "success",
+])
+
+# Combined stop word set
+_STOP_WORDS = _ENGLISH_STOP_WORDS | _JSON_KEY_STOP_WORDS
+
+# Pattern to detect pure numeric tokens (digit-only strings; not floats or hex)
+_NUMERIC_RE = re.compile(r"^[0-9]+$")
+
+
+def is_stop_word(word: str) -> bool:
+    """Check if a word is a stop word. Case-insensitive."""
+    return word.lower() in _STOP_WORDS
+
+
+def is_noise_term(word: str) -> bool:
+    """Check if a term is noise that should be excluded from the index.
+
+    This covers stop words AND pure numeric tokens, which provide zero
+    search value. Nobody searches for '0', '1', or '42' to find a session.
+    """
+    lower = word.lower()
+    return lower in _STOP_WORDS or _NUMERIC_RE.match(lower) is not None
+
+
+def get_stop_words() -> frozenset:
+    """Return the full stop word set (for inspection/bulk use)."""
+    return _STOP_WORDS
\ No newline at end of file
diff --git a/term_index.py b/term_index.py
new file mode 100644
index 0000000000..ebaa11781e
--- /dev/null
+++ b/term_index.py
@@ -0,0 +1,47 @@
+"""Term index — inverted index extraction for session search fast path.
+
+Extracts non-stop-word terms from message content for insertion into the
+term_index table in SessionDB. Terms are lowercased, punctuation-stripped
+(with preservation of path-like strings), and deduplicated per message.
+
+Noise filtering:
+  - English stop words (NLTK list)
+  - JSON schema keys from tool output (output, exit_code, error, etc.)
+  - Pure numeric tokens (0, 1, 42, etc.)
+"""
+
+import re
+from stop_words import is_noise_term
+
+# Matches "words" including paths (foo/bar), filenames (file.py), and
+# hyphenated terms (self-hosted). Filters out most punctuation but
+# preserves dots in filenames and slashes in paths.
+# Strategy: regex-match runs that start and end alphanumeric, allowing [\w./-] inside.
+_TERM_RE = re.compile(r"[a-zA-Z0-9][\w./\-]*[a-zA-Z0-9]|[a-zA-Z0-9]")
+
+
+def extract_terms(content: str) -> list[str]:
+    """Extract non-noise terms from message content.
+
+    Returns a deduplicated, lowercased list of terms.
+    Stop words, JSON keys, pure numerics, and empty strings are excluded.
+ """ + if not content: + return [] + + # Find candidate tokens + raw_tokens = _TERM_RE.findall(content) + + seen = set() + terms = [] + for token in raw_tokens: + lower = token.lower() + # Skip noise: stop words, JSON keys, pure numerics + if is_noise_term(lower): + continue + # Deduplicate within this message + if lower not in seen: + seen.add(lower) + terms.append(lower) + + return terms \ No newline at end of file diff --git a/tests/agent/test_vault_injection.py b/tests/agent/test_vault_injection.py new file mode 100644 index 0000000000..0eae544e41 --- /dev/null +++ b/tests/agent/test_vault_injection.py @@ -0,0 +1,174 @@ +"""Tests for agent/vault_injection.py — Obsidian vault auto-injection into system prompt.""" + +import pytest +import os +from pathlib import Path + +from agent.vault_injection import ( + build_vault_system_prompt, + _read_vault_file, + _strip_yaml_frontmatter, + WORKING_CONTEXT_CHAR_LIMIT, + USER_PROFILE_CHAR_LIMIT, +) + + +# --------------------------------------------------------------------------- +# _strip_yaml_frontmatter +# --------------------------------------------------------------------------- + +class TestStripYamlFrontmatter: + def test_strips_simple_frontmatter(self): + content = "---\ndate: 2026-04-22\n---\nHello world" + assert _strip_yaml_frontmatter(content) == "Hello world" + + def test_no_frontmatter(self): + content = "Just some text" + assert _strip_yaml_frontmatter(content) == "Just some text" + + def test_frontmatter_with_blank_lines(self): + content = "---\ndate: 2026-04-22\nprojects: [X]\n---\n\nActual content here" + result = _strip_yaml_frontmatter(content) + assert result == "Actual content here" + + def test_unclosed_frontmatter_returns_original(self): + content = "---\ndate: 2026-04-22\nNo closing dashes" + assert _strip_yaml_frontmatter(content) == content + + +# --------------------------------------------------------------------------- +# _read_vault_file +# --------------------------------------------------------------------------- + +class TestReadVaultFile: + def test_reads_existing_file(self, tmp_path): + f = tmp_path / "test.md" + f.write_text("some content", encoding="utf-8") + result = _read_vault_file(f, 4000) + assert result == "some content" + + def test_returns_none_for_missing_file(self, tmp_path): + f = tmp_path / "nonexistent.md" + result = _read_vault_file(f, 4000) + assert result is None + + def test_returns_none_for_empty_file(self, tmp_path): + f = tmp_path / "empty.md" + f.write_text("", encoding="utf-8") + result = _read_vault_file(f, 4000) + assert result is None + + def test_returns_none_for_whitespace_only(self, tmp_path): + f = tmp_path / "ws.md" + f.write_text(" \n\n ", encoding="utf-8") + result = _read_vault_file(f, 4000) + assert result is None + + def test_strips_frontmatter(self, tmp_path): + f = tmp_path / "frontmatter.md" + f.write_text("---\ndate: 2026-04-22\n---\nReal content", encoding="utf-8") + result = _read_vault_file(f, 4000) + assert result == "Real content" + + def test_truncates_long_file(self, tmp_path): + f = tmp_path / "long.md" + long_content = "x" * 5000 + f.write_text(long_content, encoding="utf-8") + result = _read_vault_file(f, 100) + assert len(result) < 200 # truncation + notice + assert "truncated" in result + + def test_truncation_at_newline(self, tmp_path): + f = tmp_path / "multiline.md" + lines = ["line " + str(i) for i in range(100)] + content = "\n".join(lines) + f.write_text(content, encoding="utf-8") + # Small limit, should truncate at a newline boundary + result = 
_read_vault_file(f, 50) + assert "truncated" in result + # Should not cut mid-line + for line in result.split("\n"): + if line and "truncated" not in line: + assert line.startswith("line") + + +# --------------------------------------------------------------------------- +# build_vault_system_prompt +# --------------------------------------------------------------------------- + +class TestBuildVaultSystemPrompt: + def test_empty_path_returns_empty(self): + assert build_vault_system_prompt("") == "" + + def test_nonexistent_path_returns_empty(self, tmp_path): + assert build_vault_system_prompt(str(tmp_path / "nope")) == "" + + def test_empty_vault_dir_returns_empty(self, tmp_path): + assert build_vault_system_prompt(str(tmp_path)) == "" + + def test_injects_working_context(self, tmp_path): + vault = tmp_path / "vault" + agent_dir = vault / "Agent-Hermes" + agent_dir.mkdir(parents=True) + wc = agent_dir / "working-context.md" + wc.write_text("---\ndate: 2026-04-22\n---\n## Current Status\n- Status: Active", encoding="utf-8") + + result = build_vault_system_prompt(str(vault)) + assert "VAULT: WORKING CONTEXT" in result + assert "Status: Active" in result + + def test_injects_user_profile(self, tmp_path): + vault = tmp_path / "vault" + shared_dir = vault / "Agent-Shared" + shared_dir.mkdir(parents=True) + up = shared_dir / "user-profile.md" + up.write_text("# User Profile\n\nName: AJ", encoding="utf-8") + + result = build_vault_system_prompt(str(vault)) + assert "VAULT: USER PROFILE" in result + assert "Name: AJ" in result + + def test_injects_both_files(self, tmp_path): + vault = tmp_path / "vault" + agent_dir = vault / "Agent-Hermes" + shared_dir = vault / "Agent-Shared" + agent_dir.mkdir(parents=True) + shared_dir.mkdir(parents=True) + + (agent_dir / "working-context.md").write_text( + "---\ndate: 2026-04-22\n---\nWorking on X", encoding="utf-8" + ) + (shared_dir / "user-profile.md").write_text( + "# User Profile\n\nName: AJ", encoding="utf-8" + ) + + result = build_vault_system_prompt(str(vault)) + assert "VAULT: WORKING CONTEXT" in result + assert "VAULT: USER PROFILE" in result + assert "Working on X" in result + assert "Name: AJ" in result + + def test_skips_empty_working_context(self, tmp_path): + vault = tmp_path / "vault" + agent_dir = vault / "Agent-Hermes" + shared_dir = vault / "Agent-Shared" + agent_dir.mkdir(parents=True) + shared_dir.mkdir(parents=True) + + (agent_dir / "working-context.md").write_text("", encoding="utf-8") + (shared_dir / "user-profile.md").write_text("Name: AJ", encoding="utf-8") + + result = build_vault_system_prompt(str(vault)) + assert "VAULT: WORKING CONTEXT" not in result + assert "VAULT: USER PROFILE" in result + + def test_format_matches_memory_block_style(self, tmp_path): + vault = tmp_path / "vault" + agent_dir = vault / "Agent-Hermes" + agent_dir.mkdir(parents=True) + (agent_dir / "working-context.md").write_text("Active task", encoding="utf-8") + + result = build_vault_system_prompt(str(vault)) + # Should use the same separator as memory_tool (═══) + assert "\u2550" in result # ═ character + assert "VAULT: WORKING CONTEXT" in result \ No newline at end of file diff --git a/tests/run_agent/test_vault_injection.py b/tests/run_agent/test_vault_injection.py new file mode 100644 index 0000000000..76529c62ce --- /dev/null +++ b/tests/run_agent/test_vault_injection.py @@ -0,0 +1,161 @@ +"""Tests for vault auto-injection integration with _build_system_prompt. 
+
+Verifies that vault content appears in the system prompt when vault is
+configured, and is absent otherwise.
+"""
+
+import os
+import pytest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+
+def _make_minimal_agent(**overrides):
+    """Create a minimal AIAgent for testing, with vault attrs settable."""
+    from run_agent import AIAgent
+
+    with (
+        patch("run_agent.get_tool_definitions", return_value=[]),
+        patch("run_agent.check_toolset_requirements", return_value={}),
+        patch("run_agent.OpenAI"),
+    ):
+        a = AIAgent(
+            api_key="test-k...7890",
+            base_url="https://openrouter.ai/api/v1",
+            quiet_mode=True,
+            skip_context_files=True,
+            skip_memory=True,
+        )
+        a.client = MagicMock()
+
+    # Apply overrides after creation
+    for k, v in overrides.items():
+        setattr(a, k, v)
+
+    return a
+
+
+class TestVaultSystemPromptIntegration:
+    """Test that _build_system_prompt injects vault content when configured."""
+
+    def test_vault_not_injected_when_disabled(self, tmp_path):
+        """Vault content should not appear when vault_enabled=False."""
+        vault_dir = tmp_path / "vault"
+        agent_dir = vault_dir / "Agent-Hermes"
+        agent_dir.mkdir(parents=True)
+        (agent_dir / "working-context.md").write_text("Active task X", encoding="utf-8")
+
+        agent = _make_minimal_agent(
+            _vault_enabled=False,
+            _vault_path=str(vault_dir),
+        )
+
+        prompt = agent._build_system_prompt()
+        assert "VAULT: WORKING CONTEXT" not in prompt
+        assert "Active task X" not in prompt
+
+    def test_vault_injected_when_enabled(self, tmp_path):
+        """Vault content should appear in system prompt when vault_enabled=True."""
+        vault_dir = tmp_path / "vault"
+        agent_dir = vault_dir / "Agent-Hermes"
+        shared_dir = vault_dir / "Agent-Shared"
+        agent_dir.mkdir(parents=True)
+        shared_dir.mkdir(parents=True)
+        (agent_dir / "working-context.md").write_text(
+            "---\ndate: 2026-04-22\n---\n## Status\nActive: vault fix",
+            encoding="utf-8",
+        )
+        (shared_dir / "user-profile.md").write_text(
+            "# User Profile\n\nName: Test User",
+            encoding="utf-8",
+        )
+
+        agent = _make_minimal_agent(
+            _vault_enabled=True,
+            _vault_path=str(vault_dir),
+        )
+
+        prompt = agent._build_system_prompt()
+        assert "VAULT: WORKING CONTEXT" in prompt
+        assert "Active: vault fix" in prompt
+        assert "VAULT: USER PROFILE" in prompt
+        assert "Name: Test User" in prompt
+
+    def test_vault_after_memory_blocks(self, tmp_path):
+        """Vault injection should appear after Layer 1 memory blocks."""
+        # Set up memory files
+        mem_dir = tmp_path / "memories"
+        mem_dir.mkdir(parents=True)
+        (mem_dir / "MEMORY.md").write_text("Layer 1 memory note", encoding="utf-8")
+
+        # Set up vault files
+        vault_dir = tmp_path / "vault"
+        agent_dir = vault_dir / "Agent-Hermes"
+        agent_dir.mkdir(parents=True)
+        (agent_dir / "working-context.md").write_text("Vault content", encoding="utf-8")
+
+        # Create agent with memory enabled
+        from run_agent import AIAgent
+
+        with (
+            patch("run_agent.get_tool_definitions", return_value=[]),
+            patch("run_agent.check_toolset_requirements", return_value={}),
+            patch("run_agent.OpenAI"),
+            patch(
+                "hermes_cli.config.load_config",
+                return_value={
+                    "memory": {
+                        "memory_enabled": True,
+                        "user_profile_enabled": False,
+                        "memory_char_limit": 2200,
+                        "user_char_limit": 1375,
+                    },
+                },
+            ),
+        ):
+            # Set HERMES_HOME so MemoryStore reads from tmp_path
+            os.environ["HERMES_HOME"] = str(tmp_path)
+            try:
+                a = AIAgent(
+                    api_key="test-k...7890",
+                    base_url="https://openrouter.ai/api/v1",
+
quiet_mode=True, + skip_context_files=True, + skip_memory=False, + ) + a.client = MagicMock() + finally: + del os.environ["HERMES_HOME"] + + a._vault_enabled = True + a._vault_path = str(vault_dir) + + prompt = a._build_system_prompt() + mem_pos = prompt.find("MEMORY (your personal notes)") + vault_pos = prompt.find("VAULT: WORKING CONTEXT") + assert mem_pos > 0, "Layer 1 memory block not found in prompt" + assert vault_pos > 0, "Vault block not found in prompt" + assert mem_pos < vault_pos, "Vault should appear after Layer 1 memory" + + def test_missing_vault_path_graceful(self, tmp_path): + """Agent works fine even if vault path doesn't exist.""" + agent = _make_minimal_agent( + _vault_enabled=True, + _vault_path="/nonexistent/vault/path", + ) + + # Should not crash + prompt = agent._build_system_prompt() + assert "VAULT:" not in prompt + + def test_no_vault_config_graceful(self): + """Agent works fine with no vault set (defaults).""" + agent = _make_minimal_agent( + _vault_enabled=False, + _vault_path="", + ) + + prompt = agent._build_system_prompt() + assert "VAULT:" not in prompt \ No newline at end of file diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index f405cf8bd5..6df013da1a 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -1173,7 +1173,7 @@ class TestSchemaInit: def test_schema_version(self, db): cursor = db._conn.execute("SELECT version FROM schema_version") version = cursor.fetchone()[0] - assert version == 8 + assert version == 9 def test_title_column_exists(self, db): """Verify the title column was created in the sessions table.""" @@ -1229,12 +1229,12 @@ class TestSchemaInit: conn.commit() conn.close() - # Open with SessionDB — should migrate to v8 + # Open with SessionDB — should migrate to v9 migrated_db = SessionDB(db_path=db_path) # Verify migration cursor = migrated_db._conn.execute("SELECT version FROM schema_version") - assert cursor.fetchone()[0] == 8 + assert cursor.fetchone()[0] == 9 # Verify title column exists and is NULL for existing sessions session = migrated_db.get_session("existing") diff --git a/tests/test_term_index.py b/tests/test_term_index.py new file mode 100644 index 0000000000..c2e8944ffa --- /dev/null +++ b/tests/test_term_index.py @@ -0,0 +1,715 @@ +"""Tests for term_index — inverted index for session search fast path. + +Covers: stop word filtering, term extraction, term insertion at write time, +term-based search with session-level results, multi-term intersection. 
+""" + +import time +import pytest +from pathlib import Path + +from hermes_state import SessionDB + + +@pytest.fixture() +def db(tmp_path): + """Create a SessionDB with a temp database file.""" + db_path = tmp_path / "test_state.db" + session_db = SessionDB(db_path=db_path) + yield session_db + session_db.close() + + +# ========================================================================= +# Stop word filtering +# ========================================================================= + +class TestStopWords: + def test_common_english_words_are_stopped(self): + from stop_words import is_stop_word + for w in ["the", "and", "is", "in", "it", "of", "to", "a", "was", "for"]: + assert is_stop_word(w), f"'{w}' should be a stop word" + + def test_case_insensitive_stop_words(self): + from stop_words import is_stop_word + assert is_stop_word("The") + assert is_stop_word("AND") + assert is_stop_word("Is") + + def test_non_stop_words_pass(self): + from stop_words import is_stop_word + for w in ["docker", "kubernetes", "python", "hermes", "session"]: + assert not is_stop_word(w), f"'{w}' should NOT be a stop word" + + def test_short_words_not_auto_stopped(self): + """Single letters and 2-letter words that aren't in the list should pass.""" + from stop_words import is_stop_word + # 'go' is a real tech term, 'I' is a stop word + assert not is_stop_word("go") + assert is_stop_word("I") + + +# ========================================================================= +# Term extraction +# ========================================================================= + +class TestTermExtraction: + def test_extracts_words_from_content(self): + from term_index import extract_terms + terms = extract_terms("docker compose up -d") + assert "docker" in terms + assert "compose" in terms + + def test_strips_punctuation(self): + from term_index import extract_terms + terms = extract_terms("It's working! 
Check the file.py, okay?") + assert "working" in terms + assert "file.py" in terms # dots in filenames preserved + assert "okay" in terms + + def test_filters_stop_words(self): + from term_index import extract_terms + terms = extract_terms("the docker container is running in the background") + assert "the" not in terms + assert "is" not in terms + assert "in" not in terms + assert "docker" in terms + assert "container" in terms + assert "running" in terms + + def test_case_folded(self): + from term_index import extract_terms + terms = extract_terms("Docker DOCKER docker") + # Should be case-folded to single term + assert len(terms) == len(set(terms)), "Terms should be deduplicated after case folding" + + def test_empty_content(self): + from term_index import extract_terms + terms = extract_terms("") + assert terms == [] + + def test_none_content(self): + from term_index import extract_terms + terms = extract_terms(None) + assert terms == [] + + def test_preserves_paths_and_commands(self): + from term_index import extract_terms + terms = extract_terms("edited /etc/hosts and ran git push origin main") + assert "/etc/hosts" in terms or "etc/hosts" in terms # path fragment + assert "git" in terms + assert "push" in terms + + +# ========================================================================= +# Term index insertion +# ========================================================================= + +class TestTermIndexInsertion: + def test_terms_inserted_on_append_message(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message( + session_id="s1", + role="user", + content="I need to deploy the docker container", + ) + + # Should be able to find the message by term + results = db.search_by_terms(["docker"]) + assert len(results) >= 1 + assert any(r["session_id"] == "s1" for r in results) + + def test_stop_words_not_indexed(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message( + session_id="s1", + role="user", + content="the and is in of to a", + ) + + # All stop words — should find nothing + results = db.search_by_terms(["the", "and", "is"]) + assert len(results) == 0 + + def test_same_term_multiple_messages_same_session(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message(session_id="s1", role="user", content="docker is great") + db.append_message(session_id="s1", role="assistant", content="docker compose ready") + + results = db.search_by_terms(["docker"]) + # Should return session once, not twice + sids = [r["session_id"] for r in results] + assert sids.count("s1") == 1 + + def test_term_indexed_across_sessions(self, db): + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="telegram") + db.append_message(session_id="s1", role="user", content="fix the docker bug") + db.append_message(session_id="s2", role="user", content="docker pull failed") + + results = db.search_by_terms(["docker"]) + sids = [r["session_id"] for r in results] + assert "s1" in sids + assert "s2" in sids + + +# ========================================================================= +# Term-based search +# ========================================================================= + +class TestTermSearch: + def test_single_term_search(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message( + session_id="s1", + role="user", + content="I need to configure kubernetes", + ) + + results = db.search_by_terms(["kubernetes"]) + assert len(results) >= 1 + assert 
results[0]["session_id"] == "s1"
+        # Should include session metadata
+        assert "source" in results[0]
+        assert "started_at" in results[0] or "session_started" in results[0]
+
+    def test_multi_term_intersection(self, db):
+        db.create_session(session_id="s1", source="cli")
+        db.create_session(session_id="s2", source="cli")
+        db.create_session(session_id="s3", source="cli")
+
+        db.append_message(session_id="s1", role="user", content="docker networking issue")
+        db.append_message(session_id="s2", role="user", content="docker container running")
+        db.append_message(session_id="s3", role="user", content="kubernetes networking problem")
+
+        # Both "docker" AND "networking" should only match s1
+        results = db.search_by_terms(["docker", "networking"])
+        sids = [r["session_id"] for r in results]
+        assert "s1" in sids
+        assert "s2" not in sids
+        assert "s3" not in sids
+
+    def test_search_returns_empty_for_stop_words_only(self, db):
+        db.create_session(session_id="s1", source="cli")
+        db.append_message(session_id="s1", role="user", content="the and is")
+
+        results = db.search_by_terms(["the", "and"])
+        assert results == []
+
+    def test_search_excludes_hidden_sources(self, db):
+        db.create_session(session_id="s1", source="cli")
+        db.create_session(session_id="s2", source="tool")
+        db.append_message(session_id="s1", role="user", content="docker deployment")
+        db.append_message(session_id="s2", role="user", content="docker deployment tool")
+
+        results = db.search_by_terms(["docker"], exclude_sources=["tool"])
+        sids = [r["session_id"] for r in results]
+        assert "s1" in sids
+        assert "s2" not in sids
+
+    def test_search_with_limit(self, db):
+        for i in range(5):
+            sid = f"s{i}"
+            db.create_session(session_id=sid, source="cli")
+            db.append_message(session_id=sid, role="user", content="python script")
+
+        results = db.search_by_terms(["python"], limit=3)
+        assert len(results) <= 3
+
+    def test_nonexistent_term_returns_empty(self, db):
+        db.create_session(session_id="s1", source="cli")
+        db.append_message(session_id="s1", role="user", content="hello world")
+
+        results = db.search_by_terms(["nonexistent_xyzzy"])
+        assert results == []
+
+    def test_term_result_includes_match_count(self, db):
+        db.create_session(session_id="s1", source="cli")
+        db.append_message(session_id="s1", role="user", content="docker docker docker")
+        db.append_message(session_id="s1", role="assistant", content="docker ready")
+
+        results = db.search_by_terms(["docker"])
+        assert len(results) >= 1
+        # Should tell us how many messages matched in the session
+        assert "match_count" in results[0]
+
+
+# =========================================================================
+# Schema and migration
+# =========================================================================
+
+class TestTermIndexSchema:
+    def test_term_index_table_exists(self, db):
+        cursor = db._conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='term_index'"
+        )
+        assert cursor.fetchone() is not None
+
+    def test_term_index_is_without_rowid(self, db):
+        cursor = db._conn.execute(
+            "SELECT sql FROM sqlite_master WHERE type='table' AND name='term_index'"
+        )
+        row = cursor.fetchone()
+        assert row is not None
+        assert "WITHOUT ROWID" in row[0]
+
+    def test_schema_version_bumped(self, db):
+        cursor = db._conn.execute("SELECT version FROM schema_version LIMIT 1")
+        version = cursor.fetchone()[0]
+        assert version >= 9
+
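The AND semantics exercised by TestTermSearch above come from the GROUP BY plus HAVING COUNT(DISTINCT term) pattern in search_by_terms. A standalone illustration with a hypothetical mini-table (plain sqlite3, not part of the suite):

import sqlite3

# Hypothetical mini-table: s1 contains both query terms, s2 only one.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE term_index (term TEXT, message_id INT, session_id TEXT)")
conn.executemany(
    "INSERT INTO term_index VALUES (?, ?, ?)",
    [("docker", 1, "s1"), ("networking", 2, "s1"), ("docker", 3, "s2")],
)
# HAVING COUNT(DISTINCT term) = <number of query terms> enforces AND:
# a session qualifies only if every query term appears somewhere in it.
rows = conn.execute("""
    SELECT session_id
    FROM term_index
    WHERE term IN ('docker', 'networking')
    GROUP BY session_id
    HAVING COUNT(DISTINCT term) = 2
""").fetchall()
print(rows)  # [('s1',)]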
+    def test_existing_data_survives_migration(self, tmp_path):
+        """Create a DB with the current code, close it, then re-open -- data
+        should survive the schema bookkeeping on reopen."""
+        # Build a DB on the current schema
+        db_path = tmp_path / "migrate.db"
+        db = SessionDB(db_path=db_path)
+        db.create_session(session_id="s1", source="cli")
+        db.append_message(session_id="s1", role="user", content="hello world")
+        db.close()
+
+        # Re-open -- migration should run, data intact
+        db2 = SessionDB(db_path=db_path)
+        session = db2.get_session("s1")
+        assert session is not None
+        assert session["source"] == "cli"
+        # term_index should now exist
+        cursor = db2._conn.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='term_index'"
+        )
+        assert cursor.fetchone() is not None
+        db2.close()
+
+    def test_v9_migration_auto_reindexes(self, tmp_path):
+        """When a v6 DB with existing messages is opened, the v9 migration
+        should create the term_index and backfill it automatically."""
+        db_path = tmp_path / "migrate_v9.db"
+
+        # Step 1: Create a fresh DB, add messages, then manually downgrade
+        # to v6 so the next open triggers the migration path.
+        db = SessionDB(db_path=db_path)
+        db.create_session(session_id="s1", source="cli")
+        db.create_session(session_id="s2", source="cli")
+        db.append_message(session_id="s1", role="user", content="deploy the kubernetes cluster")
+        db.append_message(session_id="s2", role="user", content="debug docker networking issue")
+        db.close()
+
+        # Step 2: Re-open raw, manually set version to 6 and wipe term_index
+        # to simulate a pre-v9 DB.
+        import sqlite3
+        conn = sqlite3.connect(str(db_path))
+        conn.execute("UPDATE schema_version SET version = 6")
+        conn.execute("DROP TABLE IF EXISTS term_index")
+        conn.commit()
+        conn.close()
+
+        # Step 3: Open with SessionDB — should migrate to v9 and auto-reindex.
+        db2 = SessionDB(db_path=db_path)
+        # Verify version is now 9
+        cursor = db2._conn.execute("SELECT version FROM schema_version")
+        assert cursor.fetchone()[0] == 9
+
+        # Verify term_index is populated — search should find the terms
+        results = db2.search_by_terms(["kubernetes"])
+        assert len(results) >= 1
+        assert results[0]["session_id"] == "s1"
+
+        results2 = db2.search_by_terms(["docker"])
+        assert len(results2) >= 1
+        assert results2[0]["session_id"] == "s2"
+
+        db2.close()
+
+
+# =========================================================================
+# Regression tests for red-team QA bugs
+# =========================================================================
+
+class TestClearMessagesCleansTermIndex:
+    """BUG 3: clear_messages() left stale term_index entries.
+
+    After clearing messages, search_by_terms should return zero results
+    for that session, not ghost matches pointing to deleted message IDs.
+ """ + + def test_clear_messages_removes_term_entries(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message(session_id="s1", role="user", content="docker networking issue") + + # Confirm indexed + results = db.search_by_terms(["docker"]) + assert len(results) >= 1 + + # Clear messages + db.clear_messages(session_id="s1") + + # Term entries should be gone + results = db.search_by_terms(["docker"]) + assert results == [] + + def test_clear_messages_does_not_affect_other_sessions(self, db): + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="cli") + db.append_message(session_id="s1", role="user", content="docker test") + db.append_message(session_id="s2", role="user", content="docker prod") + + db.clear_messages(session_id="s1") + + # s2 should still be searchable + results = db.search_by_terms(["docker"]) + sids = [r["session_id"] for r in results] + assert "s2" in sids + assert "s1" not in sids + + def test_clear_messages_no_stray_term_rows(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message(session_id="s1", role="user", content="kubernetes deployment") + + db.clear_messages(session_id="s1") + + cursor = db._conn.execute( + "SELECT COUNT(*) FROM term_index WHERE session_id = 's1'" + ) + assert cursor.fetchone()[0] == 0 + + +class TestSearchByTermsParamBinding: + """BUG 1: search_by_terms() had dead code with wrong param binding. + + The multi-term GROUP BY + HAVING path is the one that actually runs. + These tests verify parameter binding is correct for both single and + multi-term queries, including with exclude_sources. + """ + + def test_single_term_with_exclude_sources(self, db): + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="tool") + db.append_message(session_id="s1", role="user", content="docker deploy") + db.append_message(session_id="s2", role="user", content="docker deploy") + + results = db.search_by_terms(["docker"], exclude_sources=["tool"]) + sids = [r["session_id"] for r in results] + assert "s1" in sids + assert "s2" not in sids + + def test_multi_term_and_semantics(self, db): + """Multi-term search should use AND: only sessions with ALL terms match.""" + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="cli") + db.create_session(session_id="s3", source="cli") + db.append_message(session_id="s1", role="user", content="docker networking issue") + db.append_message(session_id="s2", role="user", content="docker container only") + db.append_message(session_id="s3", role="user", content="networking problem only") + + results = db.search_by_terms(["docker", "networking"]) + sids = [r["session_id"] for r in results] + assert "s1" in sids + assert "s2" not in sids + assert "s3" not in sids + + def test_multi_term_with_exclude_sources(self, db): + """Multi-term + exclude_sources: param binding must be correct.""" + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="tool") + db.append_message(session_id="s1", role="user", content="docker networking setup") + db.append_message(session_id="s2", role="user", content="docker networking deploy") + + results = db.search_by_terms( + ["docker", "networking"], exclude_sources=["tool"] + ) + sids = [r["session_id"] for r in results] + assert "s1" in sids + assert "s2" not in sids + + def test_three_term_intersection(self, db): + """Three-term AND: all three must be present in the session.""" + 
db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="cli") + db.append_message(session_id="s1", role="user", content="docker kubernetes aws deployment") + db.append_message(session_id="s2", role="user", content="docker kubernetes only two terms") + + results = db.search_by_terms(["docker", "kubernetes", "aws"]) + sids = [r["session_id"] for r in results] + assert "s1" in sids + assert "s2" not in sids + + +class TestDeleteSessionCleansTermIndex: + """Verify delete_session() and prune_sessions() clean term_index.""" + + def test_delete_session_removes_term_entries(self, db): + db.create_session(session_id="s1", source="cli") + db.append_message(session_id="s1", role="user", content="docker deploy") + db.append_message(session_id="s1", role="assistant", content="docker is running") + + db.delete_session(session_id="s1") + + cursor = db._conn.execute( + "SELECT COUNT(*) FROM term_index WHERE session_id = 's1'" + ) + assert cursor.fetchone()[0] == 0 + + def test_delete_session_does_not_affect_other_sessions(self, db): + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="cli") + db.append_message(session_id="s1", role="user", content="docker one") + db.append_message(session_id="s2", role="user", content="docker two") + + db.delete_session(session_id="s1") + + results = db.search_by_terms(["docker"]) + sids = [r["session_id"] for r in results] + assert "s2" in sids + assert "s1" not in sids + + +class TestFastSearchSessionResolution: + """BUG 2: _fast_search didn't resolve child sessions to parent. + + A delegation child and its parent both containing "docker" would appear + as two separate results. They should be resolved to the parent session. + Also, current session lineage exclusion must cover the entire chain. 
+ """ + + def test_child_resolved_to_parent(self, db): + """Parent + child matching same term should return 1 result (parent).""" + import json + from tools.session_search_tool import _fast_search + + db.create_session(session_id="parent-1", source="cli") + db.create_session(session_id="child-1", source="cli", parent_session_id="parent-1") + db.append_message(session_id="parent-1", role="user", content="docker setup question") + db.append_message(session_id="child-1", role="assistant", content="docker setup done") + + result = json.loads(_fast_search(query="docker", db=db, limit=5, current_session_id=None)) + assert result["success"] + sids = [e["session_id"] for e in result["results"]] + # Should collapse to parent, not show both + assert "child-1" not in sids, "Child should be resolved to parent" + assert "parent-1" in sids + assert len(result["results"]) == 1 + + def test_match_count_accumulates_from_children(self, db): + """Match_count should sum parent + child matches.""" + import json + from tools.session_search_tool import _fast_search + + db.create_session(session_id="p", source="cli") + db.create_session(session_id="c", source="cli", parent_session_id="p") + db.append_message(session_id="p", role="user", content="docker question") + db.append_message(session_id="c", role="assistant", content="docker answer") + + result = json.loads(_fast_search(query="docker", db=db, limit=5, current_session_id=None)) + entry = result["results"][0] + assert entry["session_id"] == "p" + assert entry["match_count"] >= 2, f"Expected accumulated count >= 2, got {entry['match_count']}" + + def test_current_session_lineage_excludes_children(self, db): + """When current session is a child, parent should also be excluded.""" + import json + from tools.session_search_tool import _fast_search + + db.create_session(session_id="parent-2", source="cli") + db.create_session(session_id="child-2", source="cli", parent_session_id="parent-2") + db.create_session(session_id="unrelated", source="cli") + db.append_message(session_id="parent-2", role="user", content="docker deploy") + db.append_message(session_id="child-2", role="assistant", content="docker deployed") + db.append_message(session_id="unrelated", role="user", content="docker build") + + # Current session = child -> should exclude parent-2 AND child-2, keep unrelated + result = json.loads(_fast_search(query="docker", db=db, limit=5, current_session_id="child-2")) + sids = [e["session_id"] for e in result["results"]] + assert "parent-2" not in sids, "Parent of current should be excluded" + assert "child-2" not in sids, "Current child should be excluded" + assert "unrelated" in sids, "Unrelated session should appear" + + +class TestGetChildSessionIds: + """Tests for SessionDB.get_child_session_ids -- public API replacing + direct db._lock/db._conn access in _fast_search.""" + + def test_returns_child_ids(self, db): + db.create_session(session_id="parent", source="cli") + db.create_session(session_id="child-1", source="delegation", parent_session_id="parent") + db.create_session(session_id="child-2", source="compression", parent_session_id="parent") + db.create_session(session_id="orphan", source="cli") + + children = db.get_child_session_ids("parent") + assert set(children) == {"child-1", "child-2"} + + def test_returns_empty_for_leaf_session(self, db): + db.create_session(session_id="leaf", source="cli") + assert db.get_child_session_ids("leaf") == [] + + def test_returns_empty_for_no_args(self, db): + assert db.get_child_session_ids() == [] + + def 
test_multiple_parent_ids(self, db): + db.create_session(session_id="p1", source="cli") + db.create_session(session_id="p2", source="cli") + db.create_session(session_id="c1", source="delegation", parent_session_id="p1") + db.create_session(session_id="c2", source="delegation", parent_session_id="p2") + + children = db.get_child_session_ids("p1", "p2") + assert set(children) == {"c1", "c2"} + + def test_does_not_recurse(self, db): + """Only direct children, not grandchildren.""" + db.create_session(session_id="root", source="cli") + db.create_session(session_id="child", source="delegation", parent_session_id="root") + db.create_session(session_id="grandchild", source="delegation", parent_session_id="child") + + children = db.get_child_session_ids("root") + assert children == ["child"] + + +class TestNoiseReduction: + """Tests for noise reduction in term indexing. + + Tool-role messages (structured JSON output) produce junk terms like + 'output', 'exit_code', 'null', 'true', 'false'. Pure numeric tokens + ('0', '1', '2') are never useful search targets. JSON key names that + appear in tool output schemas should be treated as stop words. + """ + + def test_tool_role_messages_not_indexed(self, db): + """Tool-role messages should be skipped entirely during indexing.""" + db.create_session(session_id="s1", source="cli") + db.append_message( + session_id="s1", + role="tool", + content='{"output": "docker is running", "exit_code": 0}', + tool_name="terminal", + ) + + # Tool output should NOT index any terms from the JSON blob + # Even though 'docker' appears in the output string, it's inside + # structured JSON from a tool call, not natural language + cursor = db._conn.execute( + "SELECT COUNT(*) FROM term_index WHERE session_id = 's1'" + ) + assert cursor.fetchone()[0] == 0 + + def test_assistant_role_still_indexed(self, db): + """Non-tool messages should still be indexed normally.""" + db.create_session(session_id="s1", source="cli") + db.append_message(session_id="s1", role="user", content="docker deploy") + db.append_message( + session_id="s1", role="assistant", content="docker is now running" + ) + + results = db.search_by_terms(["docker"]) + assert len(results) >= 1 + + def test_pure_numeric_tokens_filtered(self): + """Pure numeric tokens should be excluded from term extraction.""" + from term_index import extract_terms + + terms = extract_terms("exit code 0 with 42 errors in 123 steps") + # These numeric tokens provide zero search value + for num in ["0", "42", "123"]: + assert num not in terms, f"Pure numeric '{num}' should be filtered" + + # But word tokens should survive + assert "exit" in terms + assert "code" in terms + assert "errors" in terms + assert "steps" in terms + + def test_json_key_stopwords_filtered(self): + """Common JSON schema keys from tool output should be stop words.""" + from stop_words import is_stop_word + + json_keys = [ + "output", + "exit_code", + "error", + "null", + "true", + "false", + "status", + "content", + "message", + "cleared", + "success", + ] + for key in json_keys: + assert is_stop_word(key), f"JSON key '{key}' should be a stop word" + + def test_json_key_stopwords_in_extract_terms(self): + """JSON key stop words should be filtered by extract_terms.""" + from term_index import extract_terms + + # Simulates typical tool output content + terms = extract_terms( + '{"output": "hello world", "exit_code": 0, "error": null, "success": true}' + ) + for junk in ["output", "exit_code", "error", "null", "success", "true", "false"]: + assert junk not in terms, 
f"JSON key '{junk}' should be filtered" + + # Actual content words should survive + assert "hello" in terms + assert "world" in terms + + def test_reindex_skips_tool_messages(self, db): + """reindex_term_index should not index tool-role messages.""" + db.create_session(session_id="s1", source="cli") + db.append_message(session_id="s1", role="user", content="deploy docker") + db.append_message( + session_id="s1", + role="tool", + content='{"output": "docker running", "exit_code": 0}', + ) + + # Clear and reindex + db._conn.execute("DELETE FROM term_index") + db._conn.commit() + db.reindex_term_index() + + # Tool message terms should not be in index + cursor = db._conn.execute( + "SELECT term FROM term_index WHERE session_id = 's1'" + ) + indexed_terms = [row[0] for row in cursor.fetchall()] + for junk in ["output", "exit_code", "0"]: + assert junk not in indexed_terms, f"'{junk}' should not be indexed from tool messages" + + +class TestCJKFallbackInFastSearch: + """CJK queries should fall through to the slow path even when fast=True. + + The term index can't handle CJK because extract_terms() splits on + whitespace, and CJK languages don't use spaces between words. + session_search should detect this and use the FTS5+LIKE fallback. + """ + + def test_cjk_query_bypasses_fast_path(self, db): + """A CJK query with fast=True should be downgraded to fast=False.""" + import json + from tools.session_search_tool import session_search + + db.create_session(session_id="cjk-1", source="cli") + db.append_message(session_id="cjk-1", role="user", content="测试中文搜索") + + # fast=True, but CJK query should fall through to full search + result = json.loads(session_search( + query="中文", db=db, limit=3, fast=True, current_session_id=None + )) + # The result should come from the slow path (mode="full") + # not the fast path (mode="fast") since CJK triggers fallback + assert result["success"] + # mode should be "full" (not "fast") because CJK forced the fallback + assert result.get("mode") != "fast" + + def test_english_query_stays_fast(self, db): + """Non-CJK queries should still use the fast path.""" + import json + from tools.session_search_tool import session_search + + db.create_session(session_id="eng-1", source="cli") + db.append_message(session_id="eng-1", role="user", content="deploy the server") + + result = json.loads(session_search( + query="deploy", db=db, limit=3, fast=True, current_session_id=None + )) + assert result["success"] + assert result.get("mode") == "fast" \ No newline at end of file diff --git a/tests/tools/test_delegate.py b/tests/tools/test_delegate.py index f3a1a2632d..a8f5511da0 100644 --- a/tests/tools/test_delegate.py +++ b/tests/tools/test_delegate.py @@ -656,10 +656,12 @@ class TestDelegationCredentialResolution(unittest.TestCase): self.assertEqual(creds["provider"], "custom") def test_direct_endpoint_does_not_fall_back_to_openrouter_api_key_env(self): + """Remote endpoint without OPENAI_API_KEY should raise ValueError, + even if OPENROUTER_API_KEY is set (only OPENAI_API_KEY is checked).""" parent = _make_mock_parent(depth=0) cfg = { "model": "qwen2.5-coder", - "base_url": "http://localhost:1234/v1", + "base_url": "https://api.example.com/v1", # remote, not localhost } with patch.dict( os.environ, diff --git a/tests/tools/test_delegation_local_provider.py b/tests/tools/test_delegation_local_provider.py new file mode 100644 index 0000000000..5d4a5797b5 --- /dev/null +++ b/tests/tools/test_delegation_local_provider.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +Tests for 
delegation with local/Ollama providers that don't require API keys.
+
+Ollama and other local model servers run on localhost and accept requests
+without authentication. The delegation credential resolver should allow these
+endpoints to work without requiring an API key.
+
+Run with: python -m pytest tests/tools/test_delegation_local_provider.py -v
+"""
+
+import os
+import threading
+import unittest
+from unittest.mock import MagicMock, patch
+
+from tools.delegate_tool import (
+    _is_local_base_url,
+    _resolve_delegation_credentials,
+)
+
+
+def _make_mock_parent(depth=0):
+    """Create a mock parent agent with the fields delegate_task expects."""
+    parent = MagicMock()
+    parent.base_url = "http://localhost:11434/v1"
+    parent.api_key = "ollama"
+    parent.provider = "custom"
+    parent.api_mode = "chat_completions"
+    parent.model = "glm-5.1:cloud"
+    parent.platform = "cli"
+    parent.providers_allowed = None
+    parent.providers_ignored = None
+    parent.providers_order = None
+    parent.provider_sort = None
+    parent._session_db = None
+    parent._delegate_depth = depth
+    parent._active_children = []
+    parent._active_children_lock = threading.Lock()
+    parent._print_fn = None
+    parent.tool_progress_callback = None
+    parent.thinking_callback = None
+    parent._credential_pool = None
+    parent.reasoning_config = None
+    parent.max_tokens = None
+    parent.prefill_messages = None
+    parent.acp_command = None
+    parent.acp_args = []
+    parent.valid_tool_names = ["terminal", "file", "web"]
+    parent.enabled_toolsets = None  # None = all tools
+    return parent
+
+
+class TestLocalProviderCredentials(unittest.TestCase):
+    """Tests for _resolve_delegation_credentials with local providers."""
+
+    # --- base_url path (localhost) ---
+
+    def test_localhost_base_url_no_api_key_allowed(self):
+        """localhost base_url should work without an API key (Ollama, LM Studio, etc.)."""
+        parent = _make_mock_parent()
+        cfg = {
+            "model": "devstral-small-2:24b-cloud",
+            "provider": "custom",
+            "base_url": "http://localhost:11434/v1",
+            "api_key": "",
+        }
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["base_url"], "http://localhost:11434/v1")
+        self.assertIsNotNone(creds["api_key"])
+        # API key should be a harmless placeholder, not None or empty
+        self.assertNotEqual(creds["api_key"], "")
+
+    def test_127_base_url_no_api_key_allowed(self):
+        """127.0.0.1 base_url should work without an API key."""
+        parent = _make_mock_parent()
+        cfg = {
+            "model": "devstral-small-2:24b-cloud",
+            "provider": "",
+            "base_url": "http://127.0.0.1:11434/v1",
+            "api_key": "",
+        }
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["base_url"], "http://127.0.0.1:11434/v1")
+        self.assertIsNotNone(creds["api_key"])
+
+    def test_dotlocal_base_url_no_api_key_allowed(self):
+        """.local mDNS hostnames (e.g.
studio.local) should work without an API key."""
+        parent = _make_mock_parent()
+        cfg = {
+            "model": "devstral-small-2:24b-cloud",
+            "provider": "",
+            "base_url": "http://studio.local:11434/v1",
+            "api_key": "",
+        }
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["base_url"], "http://studio.local:11434/v1")
+        self.assertIsNotNone(creds["api_key"])
+
+    def test_localhost_base_url_with_explicit_api_key_preserved(self):
+        """If the user provides an API key for localhost, it should be preserved as-is."""
+        parent = _make_mock_parent()
+        cfg = {
+            "model": "devstral-small-2:24b-cloud",
+            "provider": "custom",
+            "base_url": "http://localhost:11434/v1",
+            "api_key": "my-secret-key",
+        }
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["api_key"], "my-secret-key")
+
+    def test_localhost_base_url_whitespace_api_key_gets_placeholder(self):
+        """Whitespace-only api_key should be treated as absent and get the placeholder."""
+        parent = _make_mock_parent()
+        cfg = {
+            "model": "devstral-small-2:24b-cloud",
+            "provider": "custom",
+            "base_url": "http://localhost:11434/v1",
+            "api_key": " ",
+        }
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["api_key"], "ollama")
+
+    # --- base_url path (remote) should still require an API key ---
+
+    def test_remote_base_url_still_requires_api_key(self):
+        """Non-localhost base_url without an API key should still raise ValueError."""
+        parent = _make_mock_parent()
+        cfg = {
+            "model": "gpt-4o-mini",
+            "provider": "",
+            "base_url": "https://api.openai.com/v1",
+            "api_key": "",
+        }
+        with patch.dict(os.environ, {"OPENAI_API_KEY": ""}, clear=False):
+            with self.assertRaises(ValueError) as ctx:
+                _resolve_delegation_credentials(cfg, parent)
+            self.assertIn("API key", str(ctx.exception))
+
+    # --- provider path with custom/local ---
+
+    @patch("hermes_cli.runtime_provider.resolve_runtime_provider")
+    def test_custom_provider_resolving_to_localhost_no_api_key(self, mock_resolve):
+        """When delegation.provider='custom' resolves to localhost, an empty API key should be allowed."""
+        mock_resolve.return_value = {
+            "provider": "custom",
+            "base_url": "http://localhost:11434/v1",
+            "api_key": "",
+            "api_mode": "chat_completions",
+        }
+        parent = _make_mock_parent()
+        cfg = {"model": "devstral-small-2:24b-cloud", "provider": "custom"}
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["provider"], "custom")
+        self.assertEqual(creds["base_url"], "http://localhost:11434/v1")
+        # Should get a placeholder key, not raise ValueError
+        self.assertIsNotNone(creds["api_key"])
+        self.assertNotEqual(creds["api_key"], "")
+
+    @patch("hermes_cli.runtime_provider.resolve_runtime_provider")
+    def test_remote_provider_still_requires_api_key(self, mock_resolve):
+        """A provider resolving to a remote endpoint without an API key should still raise."""
+        mock_resolve.return_value = {
+            "provider": "openrouter",
+            "base_url": "https://openrouter.ai/api/v1",
+            "api_key": "",
+            "api_mode": "chat_completions",
+        }
+        parent = _make_mock_parent()
+        cfg = {"model": "some-model", "provider": "openrouter"}
+        with self.assertRaises(ValueError) as ctx:
+            _resolve_delegation_credentials(cfg, parent)
+        self.assertIn("no API key", str(ctx.exception))
+
+    # --- full localhost config resolves to the placeholder key ---
+
+    def test_local_delegation_uses_placeholder_key(self):
+        """Delegation with localhost base_url should get the 'ollama' placeholder API key."""
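+        # Config mirrors a typical Ollama setup: localhost endpoint, no
+        # API key, plus normal delegation settings. The resolver should
+        # substitute the "ollama" placeholder rather than raising.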
+        cfg = {
+            "model": "devstral-small-2:24b-cloud",
+            "provider": "custom",
+            "base_url": "http://localhost:11434/v1",
+            "api_key": "",
+            "max_iterations": 10,
+            "max_concurrent_children": 1,
+        }
+        parent = _make_mock_parent()
+        creds = _resolve_delegation_credentials(cfg, parent)
+        self.assertEqual(creds["base_url"], "http://localhost:11434/v1")
+        self.assertEqual(creds["api_key"], "ollama")
+
+
+class TestIsLocalBaseUrlHelper(unittest.TestCase):
+    """Tests for the _is_local_base_url helper function."""
+
+    def test_localhost_with_port(self):
+        self.assertTrue(_is_local_base_url("http://localhost:11434/v1"))
+
+    def test_localhost_no_port(self):
+        self.assertTrue(_is_local_base_url("http://localhost/v1"))
+
+    def test_127_ip(self):
+        self.assertTrue(_is_local_base_url("http://127.0.0.1:11434/v1"))
+
+    def test_dotlocal(self):
+        self.assertTrue(_is_local_base_url("http://studio.local:11434/v1"))
+
+    def test_remote_url(self):
+        self.assertFalse(_is_local_base_url("https://api.openai.com/v1"))
+
+    def test_openrouter(self):
+        self.assertFalse(_is_local_base_url("https://openrouter.ai/api/v1"))
+
+    def test_192_168_private(self):
+        self.assertTrue(_is_local_base_url("http://192.168.1.100:11434/v1"))
+
+    def test_10_private(self):
+        self.assertTrue(_is_local_base_url("http://10.0.0.5:11434/v1"))
+
+    def test_172_private(self):
+        self.assertTrue(_is_local_base_url("http://172.16.0.1:11434/v1"))
+
+    def test_empty_string(self):
+        self.assertFalse(_is_local_base_url(""))
+
+    def test_none(self):
+        self.assertFalse(_is_local_base_url(None))
+
+    def test_ipv6_loopback(self):
+        """IPv6 loopback [::1] should be recognized as local."""
+        self.assertTrue(_is_local_base_url("http://[::1]:11434/v1"))
+
+    def test_172_outside_private_range(self):
+        """172.32.x.x is NOT in 172.16/12 and should not be treated as local."""
+        self.assertFalse(_is_local_base_url("http://172.32.0.1:11434/v1"))
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/tools/delegate_tool.py b/tools/delegate_tool.py
index 2bbf354cf7..f064ad4fa9 100644
--- a/tools/delegate_tool.py
+++ b/tools/delegate_tool.py
@@ -2098,6 +2098,46 @@ def _resolve_child_credential_pool(effective_provider: Optional[str], parent_age
     return None
 
 
+def _is_local_base_url(base_url: Optional[str]) -> bool:
+    """Return True if base_url points to a local/private network address.
+
+    Local providers (Ollama, LM Studio, llama.cpp server, etc.) typically
+    don't require authentication. This check covers:
+    - localhost / loopback (127.0.0.1, ::1)
+    - .local mDNS hostnames (e.g. studio.local)
+    - RFC 1918 private networks (10/8, 172.16/12, 192.168/16)
+
+    .. note::
+        Any URL that resolves as "local" by this function will receive a
+        placeholder ``"ollama"`` API key.
If an internal service on a + private network actually requires authentication (e.g. a corporate + AI gateway at 192.168.x.x), the placeholder key will be rejected + by that server (401/403). This is intentional — local servers that + genuinely don't need auth work out-of-the-box, while misconfigured + endpoints fail loudly rather than silently. + """ + if not base_url: + return False + hostname = base_url_hostname(base_url) + if not hostname: + return False + # localhost variants + if hostname in ("localhost", "127.0.0.1", "::1"): + return True + # mDNS .local hostnames + if hostname.endswith(".local"): + return True + # RFC 1918 private subnets + import ipaddress + + try: + ip = ipaddress.ip_address(hostname) + return ip.is_private or ip.is_loopback + except ValueError: + pass # not an IP address, that's fine + return False + + def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict: """Resolve credentials for subagent delegation. @@ -2111,6 +2151,10 @@ def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict: If neither base_url nor provider is configured, returns None values so the child inherits everything from the parent agent. + Local endpoints (localhost, 127.0.0.1, .local, RFC 1918 private nets) + don't require API keys — a placeholder "ollama" key is used when none + is provided, since these servers accept any or no authentication. + Raises ValueError with a user-friendly message on credential failure. """ configured_model = str(cfg.get("model") or "").strip() or None @@ -2120,6 +2164,10 @@ def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict: if configured_base_url: api_key = configured_api_key or os.getenv("OPENAI_API_KEY", "").strip() + # Local endpoints (Ollama, LM Studio, etc.) don't require auth. + # Use a dummy key so the OpenAI client doesn't reject the request. + if not api_key and _is_local_base_url(configured_base_url): + api_key = "ollama" if not api_key: raise ValueError( "Delegation base_url is configured but no API key was found. " @@ -2175,10 +2223,15 @@ def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict: api_key = runtime.get("api_key", "") if not api_key: - raise ValueError( - f"Delegation provider '{configured_provider}' resolved but has no API key. " - f"Set the appropriate environment variable or run 'hermes auth'." - ) + # Local providers don't require real API keys. + resolved_base = runtime.get("base_url", "") + if _is_local_base_url(resolved_base): + api_key = "ollama" + else: + raise ValueError( + f"Delegation provider '{configured_provider}' resolved but has no API key. " + f"Set the appropriate environment variable or run 'hermes auth'." + ) return { "model": configured_model, diff --git a/tools/session_search_tool.py b/tools/session_search_tool.py index 16aaea109f..deaa20a285 100644 --- a/tools/session_search_tool.py +++ b/tools/session_search_tool.py @@ -321,11 +321,14 @@ def session_search( limit: int = 3, db=None, current_session_id: str = None, + fast: bool = True, ) -> str: """ Search past sessions and return focused summaries of matching conversations. - Uses FTS5 to find matches, then summarizes the top sessions with Gemini Flash. + By default (fast=True), uses the term_index inverted index for instant + results with session metadata and match counts — no LLM calls needed. + Set fast=False to use FTS5 + LLM summarization for detailed summaries. The current session is excluded from results since the agent already has that context. 
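+
+    A fast-mode response looks like the following (shape taken from
+    _fast_search below; the values are illustrative only):
+
+        {"success": true, "query": "docker", "mode": "fast",
+         "results": [{"session_id": "abc123", "when": "...", "source": "cli",
+                      "model": "...", "title": "...", "match_count": 4}],
+         "count": 1, "message": "Found 1 matching sessions ..."}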
""" if db is None: @@ -348,6 +351,120 @@ def session_search( query = query.strip() + # CJK queries can't be handled by the term index (no word boundaries + # for extract_terms to split on). Fall through to FTS5 + LIKE which + # has a CJK bigram/LIKE fallback. + if fast and db._contains_cjk(query): + fast = False + + # ── Fast path: term index (no LLM, ~1ms) ────────────────────────── + if fast: + return _fast_search(query, db, limit, current_session_id) + + # ── Slow path: FTS5 + LLM summarization (~5-15s) ─────────────────── + return _full_search(query, role_filter, limit, db, current_session_id) + + +def _fast_search(query: str, db, limit: int, current_session_id: str = None) -> str: + """Term index fast path: instant search, no LLM calls.""" + from term_index import extract_terms + + terms = extract_terms(query) + if not terms: + return json.dumps({ + "success": True, + "query": query, + "results": [], + "count": 0, + "message": "No searchable terms in query (all stop words or empty).", + }, ensure_ascii=False) + + # Fetch extra results so we have room after dedup/lineage exclusion + raw_results = db.search_by_terms( + terms=terms, + exclude_sources=list(_HIDDEN_SESSION_SOURCES), + limit=limit * 3, + ) + + # Resolve child sessions to their parent root, just like _full_search. + # Delegation stores detailed content in child sessions, but the user + # sees the parent conversation. Without this, parent + child both + # containing "docker" would appear as two separate results. + def _resolve_to_parent(session_id: str) -> str: + visited = set() + sid = session_id + while sid and sid not in visited: + visited.add(sid) + try: + session = db.get_session(sid) + if not session: + break + parent = session.get("parent_session_id") + if parent: + sid = parent + else: + break + except Exception: + break + return sid + + # Determine current session lineage for exclusion + current_lineage = set() + if current_session_id: + # Walk parent chain AND collect all children + root = _resolve_to_parent(current_session_id) + current_lineage.add(root) + current_lineage.add(current_session_id) + # Also find any child sessions of the current root + try: + children = db.get_child_session_ids(root, current_session_id) + current_lineage.update(children) + except Exception: + pass + + seen_sessions = {} + for r in raw_results: + raw_sid = r.get("session_id", "") + resolved_sid = _resolve_to_parent(raw_sid) + if resolved_sid in current_lineage or raw_sid in current_lineage: + continue + if resolved_sid not in seen_sessions: + # Sum match_count from child into parent + seen_sessions[resolved_sid] = dict(r) + seen_sessions[resolved_sid]["session_id"] = resolved_sid + else: + # Accumulate match_count from child sessions + seen_sessions[resolved_sid]["match_count"] = ( + seen_sessions[resolved_sid].get("match_count", 0) + + r.get("match_count", 0) + ) + if len(seen_sessions) >= limit: + break + + entries = [] + for sid, r in seen_sessions.items(): + entries.append({ + "session_id": sid, + "when": _format_timestamp(r.get("session_started")), + "source": r.get("source", "unknown"), + "model": r.get("model"), + "title": r.get("title"), + "match_count": r.get("match_count", 0), + }) + + return json.dumps({ + "success": True, + "query": query, + "mode": "fast", + "results": entries, + "count": len(entries), + "message": f"Found {len(entries)} matching sessions via term index (instant, no LLM)." 
+ f" Use fast=False for LLM-summarized results.", + }, ensure_ascii=False) + + +def _full_search(query: str, role_filter: str, limit: int, db, current_session_id: str = None) -> str: + """FTS5 + LLM summarization path (original behavior).""" try: # Parse role filter role_list = None @@ -367,6 +484,7 @@ def session_search( return json.dumps({ "success": True, "query": query, + "mode": "full", "results": [], "count": 0, "message": "No matching sessions found.", @@ -506,6 +624,7 @@ def session_search( return json.dumps({ "success": True, "query": query, + "mode": "full", "results": summaries, "count": len(summaries), "sessions_searched": len(seen_sessions), @@ -535,7 +654,8 @@ SESSION_SEARCH_SCHEMA = { "Returns titles, previews, and timestamps. Zero LLM cost, instant. " "Start here when the user asks what were we working on or what did we do recently.\n" "2. Keyword search (with query): Search for specific topics across all past sessions. " - "Returns LLM-generated summaries of matching sessions.\n\n" + "By default uses the term index for instant results (no LLM). " + "Set fast=False for detailed LLM-generated summaries.\n\n" "USE THIS PROACTIVELY when:\n" "- The user says 'we did this before', 'remember when', 'last time', 'as I mentioned'\n" "- The user asks about a topic you worked on before but don't have in current context\n" @@ -544,11 +664,15 @@ SESSION_SEARCH_SCHEMA = { "- The user asks 'what did we do about X?' or 'how did we fix Y?'\n\n" "Don't hesitate to search when it is actually cross-session -- it's fast and cheap. " "Better to search and confirm than to guess or ask the user to repeat themselves.\n\n" - "Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), " - "phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). " - "IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses " - "sessions that only mention some terms. If a broad OR query returns nothing, try individual " - "keyword searches in parallel. Returns summaries of the top matching sessions." + "Search syntax depends on the mode:\n" + "- fast=True (default): Simple keyword search with AND semantics. Multiple words " + "all must appear in a session. No boolean operators or phrase matching. " + "Instant, zero LLM cost.\n" + "- fast=False: FTS5 full syntax — OR for broad recall (elevenlabs OR baseten OR funding), " + "phrases for exact match (\\\"docker networking\\\"), boolean (python NOT java), " + "prefix (deploy*). IMPORTANT: Use OR between keywords for best FTS5 results — " + "it defaults to AND which misses sessions that only mention some terms. " + "Slower (5-15s) but returns LLM-summarized results." ), "parameters": { "type": "object", @@ -559,13 +683,18 @@ SESSION_SEARCH_SCHEMA = { }, "role_filter": { "type": "string", - "description": "Optional: only search messages from specific roles (comma-separated). E.g. 'user,assistant' to skip tool outputs.", + "description": "Optional: only search messages from specific roles (comma-separated). E.g. 'user,assistant' to skip tool outputs. Only used when fast=False.", }, "limit": { "type": "integer", "description": "Max sessions to summarize (default: 3, max: 5).", "default": 3, }, + "fast": { + "type": "boolean", + "description": "When true (default), use the term index for instant results with no LLM cost. 
When false, use FTS5 + LLM summarization for detailed summaries.", + "default": True, + }, }, "required": [], }, @@ -583,6 +712,7 @@ registry.register( query=args.get("query") or "", role_filter=args.get("role_filter"), limit=args.get("limit", 3), + fast=args.get("fast", True), db=kw.get("db"), current_session_id=kw.get("current_session_id")), check_fn=check_session_search_requirements,
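For reference, the noise-reduction and CJK tests above pin down a contract for term_index.extract_terms without showing its implementation, which lives outside this diff. A minimal sketch consistent with that contract; the tokenizer regex, lowercasing, and set return type are assumptions, not the shipped code:

    import re

    from stop_words import is_stop_word  # real module referenced by the tests

    # Yields nothing for CJK text, hence the explicit CJK downgrade upstream
    # in session_search. Underscores stay inside tokens so "exit_code" survives
    # tokenization and is then dropped as a stop word.
    _TOKEN_RE = re.compile(r"[a-z0-9_]+")

    def extract_terms(text: str) -> set:
        """Lowercased word-ish tokens, minus pure numerics and stop words."""
        terms = set()
        for tok in _TOKEN_RE.findall(text.lower()):
            if tok.isdigit():      # '0', '42', '123' carry no search value
                continue
            if is_stop_word(tok):  # covers JSON schema keys like 'exit_code'
                continue
            terms.add(tok)
        return terms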