diff --git a/agent/checkpoint_injection.py b/agent/checkpoint_injection.py new file mode 100644 index 0000000000..f9a7f02169 --- /dev/null +++ b/agent/checkpoint_injection.py @@ -0,0 +1,86 @@ +"""Checkpoint injection -- load checkpoint from parent session into system prompt. + +When a new session starts (or a compression creates a continuation session), +this module finds the most recent checkpoint from the session's lineage +and injects it into the system prompt. + +This is Layer 4 in the memory architecture. Only in_progress checkpoints +are injected -- completed ones are skipped. +""" + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +SEPARATOR = "\u2550" * 46 # same as vault injection +CHAR_LIMIT = 4000 + + +def build_checkpoint_system_prompt(store, parent_session_id: str = None) -> str: + """Build the checkpoint injection block for the system prompt. + + Returns empty string if no applicable checkpoint exists. + The caller is responsible for walking the parent lineage chain to find + the nearest ancestor session that has a checkpoint file. + """ + if not store or not parent_session_id: + return "" + + data = store.read(parent_session_id) + + if not data or not isinstance(data, dict): + return "" + + if data.get("status") == "completed": + return "" + + lines = [] + lines.append(f"Task: {data.get('task', 'Unknown')}") + lines.append(f"Status: {data.get('status', 'in_progress')}") + lines.append("") + + progress = data.get("progress", []) + if progress: + lines.append("Progress:") + for step in progress: + status_icon = { + "completed": "[x]", "in_progress": "[~]", + "pending": "[ ]", "cancelled": "[-]", + }.get(step.get("status", "pending"), "[ ]") + line = f" {status_icon} {step.get('step', '')}" + if step.get("result"): + line += f" -- {step['result']}" + lines.append(line) + lines.append("") + + state = data.get("state", {}) + if state: + lines.append("State:") + for key, value in state.items(): + if value and str(value) not in ("[]", ""): + lines.append(f" {key}: {value}") + lines.append("") + + decisions = data.get("decisions", []) + if decisions: + lines.append("Decisions:") + for d in decisions: + lines.append(f" - {d}") + lines.append("") + + blocked = data.get("blocked", []) + if blocked: + lines.append(f"Blocked: {', '.join(blocked)}") + + unresolved = data.get("unresolved", []) + if unresolved: + lines.append(f"Unresolved: {', '.join(unresolved)}") + + content = "\n".join(lines) + + if len(content) > CHAR_LIMIT: + content = content[:CHAR_LIMIT] + "\n[... truncated at char limit ...]" + + header = "CHECKPOINT: RESUME FROM HERE (saved before compaction, injected on session start)" + return f"{SEPARATOR}\n{header}\n{SEPARATOR}\n{content}" \ No newline at end of file diff --git a/agent/checkpoint_store.py b/agent/checkpoint_store.py new file mode 100644 index 0000000000..d5ca40a84c --- /dev/null +++ b/agent/checkpoint_store.py @@ -0,0 +1,117 @@ +"""Checkpoint store -- read/write YAML checkpoint files for session resumption. + +Each checkpoint captures enough state to resume a task from where the agent +left off after context compaction or session restart. Files live under +~/.hermes/checkpoints/.yaml. +""" + +import logging +import re +import yaml +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# Session IDs must be safe filesystem names -- no path separators or traversal +_SAFE_SESSION_ID = re.compile(r"^[A-Za-z0-9_-]+$") + +DEFAULT_CHECKPOINTS_DIR = Path.home() / ".hermes" / "checkpoints" +DEFAULT_GC_MAX_AGE_DAYS = 7 + + +class CheckpointStore: + """Manages checkpoint YAML files on disk.""" + + def __init__(self, checkpoints_dir: Path = None): + self._dir = Path(checkpoints_dir) if checkpoints_dir else DEFAULT_CHECKPOINTS_DIR + self._dir.mkdir(parents=True, exist_ok=True) + + def _path_for(self, session_id: str) -> Path: + if not _SAFE_SESSION_ID.match(session_id): + raise ValueError(f"Invalid session_id: {session_id!r} (must match {_SAFE_SESSION_ID.pattern})") + return self._dir / f"{session_id}.yaml" + + def write(self, session_id: str, data: Dict[str, Any]) -> None: + """Write a checkpoint for the given session. Overwrites if exists.""" + path = self._path_for(session_id) + path.parent.mkdir(parents=True, exist_ok=True) + # Ensure the 'updated' timestamp is current + data["updated"] = datetime.now().isoformat() + try: + path.write_text( + yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False), + encoding="utf-8", + ) + logger.debug("Checkpoint written: %s (%d chars)", session_id, path.stat().st_size) + except (OSError, yaml.YAMLError) as e: + logger.warning("Failed to write checkpoint %s: %s", session_id, e) + + def read(self, session_id: str) -> Optional[Dict[str, Any]]: + """Read a checkpoint. Returns None if missing or corrupt.""" + path = self._path_for(session_id) + if not path.exists(): + return None + try: + content = path.read_text(encoding="utf-8") + data = yaml.safe_load(content) + if not isinstance(data, dict): + return None + return data + except (OSError, yaml.YAMLError) as e: + logger.debug("Failed to read checkpoint %s: %s", session_id, e) + return None + + def delete(self, session_id: str) -> None: + """Delete a checkpoint. No-op if missing.""" + path = self._path_for(session_id) + try: + path.unlink(missing_ok=True) + except OSError: + pass + + def list_sessions(self) -> List[str]: + """Return all session IDs that have checkpoint files.""" + if not self._dir.exists(): + return [] + return sorted(p.stem for p in self._dir.glob("*.yaml")) + + def garbage_collect(self, max_age_days: int = DEFAULT_GC_MAX_AGE_DAYS) -> List[str]: + """Remove stale checkpoints. Returns list of removed session IDs. + + A checkpoint is stale if: + - Its 'updated' timestamp is older than max_age_days AND not completed + - OR it's 'completed' and older than 1 day (task done, no need to keep) + """ + removed = [] + cutoff = datetime.now() - timedelta(days=max_age_days) + completed_cutoff = datetime.now() - timedelta(days=1) + + for path in self._dir.glob("*.yaml"): + try: + data = yaml.safe_load(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + path.unlink(missing_ok=True) + removed.append(path.stem) + continue + updated_str = data.get("updated", "") + status = data.get("status", "in_progress") + try: + updated = datetime.fromisoformat(updated_str) if updated_str else cutoff + except ValueError: + updated = cutoff + + if status == "completed" and updated < completed_cutoff: + path.unlink(missing_ok=True) + removed.append(path.stem) + elif updated < cutoff: + path.unlink(missing_ok=True) + removed.append(path.stem) + except (OSError, yaml.YAMLError): + path.unlink(missing_ok=True) + removed.append(path.stem) + + if removed: + logger.info("Checkpoint GC: removed %d stale file(s): %s", len(removed), removed) + return removed \ No newline at end of file diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 3a6ec24415..1a0ee14bd6 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -176,6 +176,20 @@ SKILLS_GUIDANCE = ( "Skills that aren't maintained become liabilities." ) +CHECKPOINT_GUIDANCE = ( + "When a CHECKPOINT block appears at session start, read it and resume from " + "where the previous session left off. Do not start over — pick up at the " + "first in_progress step.\n" + "Write a checkpoint (action=\"write\") at these moments:\n" + "- Before a risky command that could break state\n" + "- After completing a major step in a multi-step task (then action=\"update\" " + "for subsequent steps)\n" + "- When you have important \"why\" decisions that won't be obvious later\n" + "Clear the checkpoint (action=\"clear\") when the task is fully done.\n" + "Pre-compression checkpoints are auto-written — you do not need to worry " + "about losing progress when the context window fills." +) + TOOL_USE_ENFORCEMENT_GUIDANCE = ( "# Tool-use enforcement\n" "You MUST use your tools to take action — do not describe what you would do " diff --git a/model_tools.py b/model_tools.py index 36cea8f304..b7dbb340b7 100644 --- a/model_tools.py +++ b/model_tools.py @@ -366,7 +366,7 @@ def get_tool_definitions( # because they need agent-level state (TodoStore, MemoryStore, etc.). # The registry still holds their schemas; dispatch just returns a stub error # so if something slips through, the LLM sees a sensible message. -_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"} +_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task", "checkpoint"} _READ_SEARCH_TOOLS = {"read_file", "search_files"} diff --git a/run_agent.py b/run_agent.py index 6770f568c0..0ba8ad5de2 100644 --- a/run_agent.py +++ b/run_agent.py @@ -85,6 +85,7 @@ from agent.error_classifier import classify_api_error, FailoverReason from agent.prompt_builder import ( DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE, + CHECKPOINT_GUIDANCE, build_nous_subscription_prompt, ) from agent.model_metadata import ( @@ -1561,6 +1562,11 @@ class AIAgent: # In-memory todo list for task planning (one per agent/session) from tools.todo_tool import TodoStore self._todo_store = TodoStore() + + # Layer 4: Checkpoint store (session resumption across compression) + from agent.checkpoint_store import CheckpointStore + self._checkpoint_store = CheckpointStore() + self._checkpoint_store.garbage_collect() # prune stale checkpoints on startup # Load config once for memory, skills, and compression sections try: @@ -4391,6 +4397,8 @@ class AIAgent: tool_guidance.append(SESSION_SEARCH_GUIDANCE) if "skill_manage" in self.valid_tool_names: tool_guidance.append(SKILLS_GUIDANCE) + if "checkpoint" in self.valid_tool_names: + tool_guidance.append(CHECKPOINT_GUIDANCE) if tool_guidance: prompt_parts.append(" ".join(tool_guidance)) @@ -4448,6 +4456,54 @@ class AIAgent: if user_block: prompt_parts.append(user_block) + # Vault auto-injection (Layer 3) — reads working-context.md and + # user-profile.md from the Obsidian vault and injects them into + # the system prompt. Structural fix for vault neglect. + _vault_log = logging.getLogger("agent.vault") + if self._vault_enabled and self._vault_path: + try: + from agent.vault_injection import build_vault_system_prompt + _vault_block = build_vault_system_prompt(self._vault_path) + if _vault_block: + prompt_parts.append(_vault_block) + _vault_log.info("Injection succeeded: %d chars from %s", len(_vault_block), self._vault_path) + else: + _vault_log.warning("Injection returned empty for path %s", self._vault_path) + except Exception as e: + _vault_log.warning("Injection failed: %s: %s", type(e).__name__, e) + else: + _vault_log.info("Injection skipped: enabled=%s path=%s", self._vault_enabled, self._vault_path) + + # Checkpoint injection (Layer 4) -- resume from where the previous session left off + if self._checkpoint_store: + try: + _checkpoint_session_id = None + if self._session_db: + _sid = self.session_id + # Walk up the parent chain to find the nearest ancestor with a checkpoint + for _ in range(5): # max 5 hops up lineage + try: + _sess = self._session_db.get_session(_sid) + except Exception: + break + if not _sess: + break + _parent = _sess.get("parent_session_id") + if not _parent: + break + if self._checkpoint_store.read(_parent) is not None: + _checkpoint_session_id = _parent + break + _sid = _parent # keep walking up + from agent.checkpoint_injection import build_checkpoint_system_prompt + _cp_block = build_checkpoint_system_prompt( + store=self._checkpoint_store, + parent_session_id=_checkpoint_session_id, + ) + if _cp_block: + prompt_parts.append(_cp_block) + except Exception as e: + logger.debug("Checkpoint injection skipped: %s", e) # External memory provider system prompt block (additive to built-in) if self._memory_manager: try: @@ -8119,6 +8175,52 @@ class AIAgent: if messages and messages[-1].get("_flush_sentinel") == _sentinel: messages.pop() + def flush_checkpoint(self, messages: list = None) -> None: + """Write a checkpoint before context compression destroys progress. + + Called by _compress_context before the compression happens. + Captures whatever task state we can extract from the agent's + current context (todo list, session ID, git state) and writes + a checkpoint. If a checkpoint already exists for this session, + it gets overwritten -- the pre-compression checkpoint is always + the most accurate. + """ + if not self._checkpoint_store: + return + + from tools.checkpoint_tool import checkpoint_tool + # Extract task description from session title or fallback + task_desc = "Session checkpoint (auto-saved before compression)" + try: + if self._session_db: + title = self._session_db.get_session_title(self.session_id) + if title: + task_desc = title + except Exception: + pass + + # Extract progress from todo store if available + progress = [] + if self._todo_store: + try: + items = self._todo_store._items if hasattr(self._todo_store, '_items') else [] + for item in items: + step = {"step": item.get("content", ""), "status": item.get("status", "pending")} + progress.append(step) + except Exception: + pass + + checkpoint_tool( + action="write", + task=task_desc, + progress=progress, + state={}, + decisions=[], + store=self._checkpoint_store, + agent=self, + ) + logger.info("Pre-compression checkpoint saved for session %s", self.session_id) + def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default", focus_topic: str = None) -> tuple: """Compress conversation context and split the session in SQLite. @@ -8140,6 +8242,9 @@ class AIAgent: # Pre-compression memory flush: let the model save memories before they're lost self.flush_memories(messages, min_turns=0) + # Pre-compression checkpoint: save task state before context is lost + self.flush_checkpoint(messages) + # Notify external memory provider before compression discards context if self._memory_manager: try: @@ -8351,6 +8456,19 @@ class AIAgent: ) elif function_name == "delegate_task": return self._dispatch_delegate_task(function_args) + elif function_name == "checkpoint": + from tools.checkpoint_tool import checkpoint_tool as _checkpoint_tool + return _checkpoint_tool( + action=function_args.get("action"), + task=function_args.get("task"), + progress=function_args.get("progress"), + state=function_args.get("state"), + decisions=function_args.get("decisions"), + blocked=function_args.get("blocked"), + unresolved=function_args.get("unresolved"), + store=self._checkpoint_store, + agent=self, + ) else: return handle_function_call( function_name, function_args, effective_task_id, @@ -8892,6 +9010,22 @@ class AIAgent: spinner.stop(cute_msg) elif self._should_emit_quiet_tool_messages(): self._vprint(f" {cute_msg}") + elif function_name == "checkpoint": + from tools.checkpoint_tool import checkpoint_tool as _checkpoint_tool + function_result = _checkpoint_tool( + action=function_args.get("action"), + task=function_args.get("task"), + progress=function_args.get("progress"), + state=function_args.get("state"), + decisions=function_args.get("decisions"), + blocked=function_args.get("blocked"), + unresolved=function_args.get("unresolved"), + store=self._checkpoint_store, + agent=self, + ) + tool_duration = time.time() - tool_start_time + if self._should_emit_quiet_tool_messages(): + self._vprint(f" {_get_cute_tool_message_impl('checkpoint', function_args, tool_duration, result=function_result)}") elif self._context_engine_tool_names and function_name in self._context_engine_tool_names: # Context engine tools (lcm_grep, lcm_describe, lcm_expand, etc.) spinner = None diff --git a/tests/test_checkpoint_agent_dispatch.py b/tests/test_checkpoint_agent_dispatch.py new file mode 100644 index 0000000000..6e48a8201d --- /dev/null +++ b/tests/test_checkpoint_agent_dispatch.py @@ -0,0 +1,16 @@ +"""Verify that checkpoint is routed through the agent loop, not the registry.""" +import json +import pytest +from model_tools import handle_function_call, _AGENT_LOOP_TOOLS + + +def test_checkpoint_is_agent_loop_tool(): + """checkpoint must be in _AGENT_LOOP_TOOLS so it gets agent-level state.""" + assert "checkpoint" in _AGENT_LOOP_TOOLS + + +def test_checkpoint_registry_dispatch_returns_error(): + """Calling handle_function_call for checkpoint should return agent-loop error.""" + result = handle_function_call("checkpoint", {"action": "read"}) + data = json.loads(result) + assert "must be handled" in data["error"].lower() \ No newline at end of file diff --git a/tests/test_checkpoint_flush.py b/tests/test_checkpoint_flush.py new file mode 100644 index 0000000000..757e68bb4b --- /dev/null +++ b/tests/test_checkpoint_flush.py @@ -0,0 +1,53 @@ +"""Test that flush_checkpoint writes a checkpoint via checkpoint_tool.""" +import json +import pytest +from unittest.mock import MagicMock, patch + + +def test_flush_checkpoint_method_exists(): + """AIAgent must have a flush_checkpoint method.""" + from run_agent import AIAgent + assert hasattr(AIAgent, "flush_checkpoint") + + +def test_flush_checkpoint_writes_to_store(tmp_path): + """flush_checkpoint should write a checkpoint with current session state.""" + from agent.checkpoint_store import CheckpointStore + + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + + mock_todo = MagicMock() + mock_todo._items = [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "in_progress"}, + ] + mock_todo.format_for_injection.return_value = "- [x] Step 1\n- [~] Step 2" + + mock_session_db = MagicMock() + mock_session_db.get_session_title.return_value = "Test session title" + + mock_agent = MagicMock() + mock_agent.session_id = "flush_test_session" + mock_agent._checkpoint_store = store + mock_agent._todo_store = mock_todo + mock_agent._session_db = mock_session_db + + # Call the real flush_checkpoint method + from run_agent import AIAgent + AIAgent.flush_checkpoint(mock_agent) + + # Verify a checkpoint was written for this session + saved = store.read("flush_test_session") + assert saved is not None + assert saved["task"] == "Test session title" + assert saved["status"] == "in_progress" + + +def test_flush_checkpoint_noops_without_store(tmp_path): + """flush_checkpoint should silently return if _checkpoint_store is None.""" + mock_agent = MagicMock() + mock_agent._checkpoint_store = None + + from run_agent import AIAgent + # Should not raise + AIAgent.flush_checkpoint(mock_agent) \ No newline at end of file diff --git a/tests/test_checkpoint_injection.py b/tests/test_checkpoint_injection.py new file mode 100644 index 0000000000..1e90ae3338 --- /dev/null +++ b/tests/test_checkpoint_injection.py @@ -0,0 +1,58 @@ +import pytest +from agent.checkpoint_store import CheckpointStore +from agent.checkpoint_injection import build_checkpoint_system_prompt + + +@pytest.fixture +def store(tmp_path): + return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + + +@pytest.fixture +def parent_checkpoint(): + return { + "session_id": "old_session_001", + "task": "Red team QA on PR #14303", + "status": "in_progress", + "created": "2026-04-22T22:58:00", + "updated": "2026-04-22T23:15:00", + "progress": [ + {"step": "Read code paths", "status": "completed"}, + {"step": "Add edge-case tests", "status": "in_progress"}, + ], + "state": {"active_branch": "fix/delegation", "tests_status": "22/22 passing"}, + "decisions": ["SSRF is low-severity"], + "blocked": [], + "unresolved": [], + } + + +class TestBuildCheckpointPrompt: + def test_injects_in_progress_checkpoint(self, store, parent_checkpoint): + store.write("old_session_001", parent_checkpoint) + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001") + assert "Red team QA on PR #14303" in prompt + assert "CHECKPOINT" in prompt + + def test_skips_completed_checkpoints(self, store, parent_checkpoint): + parent_checkpoint["status"] = "completed" + store.write("old_session_001", parent_checkpoint) + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001") + assert prompt == "" + + def test_returns_empty_when_no_parent(self, store): + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="nonexistent") + assert prompt == "" + + def test_includes_progress_and_decisions(self, store, parent_checkpoint): + store.write("old_session_001", parent_checkpoint) + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001") + assert "Read code paths" in prompt + assert "SSRF is low-severity" in prompt + + def test_truncates_oversized_checkpoints(self, store, parent_checkpoint): + parent_checkpoint["task"] = "x" * 5000 + parent_checkpoint["decisions"] = ["d" * 2000] * 10 + store.write("old_session_001", parent_checkpoint) + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001") + assert len(prompt) < 6000 \ No newline at end of file diff --git a/tests/test_checkpoint_integration.py b/tests/test_checkpoint_integration.py new file mode 100644 index 0000000000..09520d9361 --- /dev/null +++ b/tests/test_checkpoint_integration.py @@ -0,0 +1,82 @@ +"""End-to-end test: write -> compress -> resume cycle.""" +import json +import pytest +import yaml +from datetime import datetime, timedelta +from unittest.mock import MagicMock +from agent.checkpoint_store import CheckpointStore +from agent.checkpoint_injection import build_checkpoint_system_prompt +from tools.checkpoint_tool import checkpoint_tool + + +@pytest.fixture +def store(tmp_path): + return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + + +def test_full_cycle_write_compress_resume(store): + """Simulate: agent writes checkpoint -> new session gets injection.""" + old_agent = MagicMock() + old_agent.session_id = "20260422_225800_abc123" + old_agent._todo_store = MagicMock() + old_agent._todo_store.format_for_injection.return_value = "" + + checkpoint_tool( + action="write", + task="Red team QA on PR #14303", + progress=[ + {"step": "Read code", "status": "completed"}, + {"step": "Write tests", "status": "in_progress"}, + ], + state={ + "active_branch": "fix/delegation", + "tests_status": "19/19 passing", + "last_commit": "a90557d", + "pushed": True, + "working_directory": "/tmp/repo", + }, + decisions=["SSRF is low-severity, doc note only"], + blocked=[], + unresolved=[], + store=store, + agent=old_agent, + ) + + # New session starts, checkpoint gets injected + prompt = build_checkpoint_system_prompt( + store=store, + parent_session_id="20260422_225800_abc123", + ) + + assert "Red team QA on PR #14303" in prompt + assert "Write tests" in prompt + assert "fix/delegation" in prompt + assert "SSRF is low-severity" in prompt + + +def test_gc_removes_stale_checkpoints(store): + """GC removes checkpoints older than 7 days.""" + mock_agent = MagicMock() + mock_agent.session_id = "old_sess" + mock_agent._todo_store = MagicMock() + mock_agent._todo_store.format_for_injection.return_value = "" + + checkpoint_tool( + action="write", + task="Old task", + progress=[], + state={}, + decisions=[], + store=store, + agent=mock_agent, + ) + # Age the checkpoint by writing directly with an old timestamp + data = store.read("old_sess") + data["updated"] = (datetime.now() - timedelta(days=10)).isoformat() + store._path_for("old_sess").write_text( + yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False), + encoding="utf-8", + ) + + removed = store.garbage_collect(max_age_days=7) + assert "old_sess" in removed \ No newline at end of file diff --git a/tests/test_checkpoint_regression.py b/tests/test_checkpoint_regression.py new file mode 100644 index 0000000000..2870534884 --- /dev/null +++ b/tests/test_checkpoint_regression.py @@ -0,0 +1,144 @@ +"""Regression tests for bugs found during red-team QA review. + +Bug #1: get_parent_session_id() didn't exist on SessionDB +Bug #2: Path traversal via session_id +Bug #3: "checkpoint" not in any toolset definition +Bug #4: flush_checkpoint test didn't actually call flush_checkpoint +Bug #5: Injection fallback grabbed ANY in_progress checkpoint from unrelated sessions +""" +import pytest +from pathlib import Path + + +class TestBug1SessionDBLineageWalk: + """Bug #1: _build_system_prompt must use get_session()["parent_session_id"] + instead of the nonexistent get_parent_session_id() method. + """ + + def test_session_db_has_no_get_parent_session_id_method(self): + """SessionDB must NOT have a get_parent_session_id method (it never did).""" + from hermes_state import SessionDB + assert not hasattr(SessionDB, "get_parent_session_id"), ( + "SessionDB should not have get_parent_session_id -- " + "use get_session(sid)['parent_session_id'] instead" + ) + + def test_get_session_returns_parent_session_id_key(self, tmp_path): + """get_session() must return a dict with 'parent_session_id' key.""" + from hermes_state import SessionDB + from pathlib import Path + db = SessionDB(db_path=Path(tmp_path) / "test.db") + # Create a parent session + parent_id = db.create_session("parent_session", source="test", model="test") + # Create a child session + child_id = db.create_session("child_session", source="test", model="test", parent_session_id=parent_id) + sess = db.get_session(child_id) + assert sess is not None + assert "parent_session_id" in sess + assert sess["parent_session_id"] == parent_id + + +class TestBug2PathTraversal: + """Bug #2: session_id must be validated to prevent path traversal.""" + + def test_reject_dotdot_in_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.write("../../../etc/passwd", {"task": "evil"}) + + def test_reject_slash_in_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.write("sub/dir", {"task": "evil"}) + + def test_reject_backslash_in_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.write("sub\\dir", {"task": "evil"}) + + def test_accept_normal_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + # Should not raise + store.write("20260422_225800_abc123", {"task": "normal"}) + + def test_read_rejects_path_traversal(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.read("../../../etc/passwd") + + def test_delete_rejects_path_traversal(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.delete("../../../etc/passwd") + + +class TestBug3CheckpointInToolset: + """Bug #3: 'checkpoint' must be in a toolset definition so the model gets the schema.""" + + def test_checkpoint_in_todo_toolset(self): + from toolsets import TOOLSETS + todo_tools = TOOLSETS.get("todo", {}).get("tools", []) + assert "checkpoint" in todo_tools, ( + f"'checkpoint' must be in the 'todo' toolset tools list. " + f"Current todo tools: {todo_tools}" + ) + + def test_resolve_toolset_todo_includes_checkpoint(self): + from toolsets import resolve_toolset + resolved = resolve_toolset("todo") + assert "checkpoint" in resolved, ( + f"resolve_toolset('todo') must return ['todo', 'checkpoint']. Got: {resolved}" + ) + + +class TestBug5InjectionNoCrossSession: + """Bug #5: build_checkpoint_system_prompt must NOT fall through to grabbing + any random in_progress checkpoint from an unrelated session. + """ + + def test_no_injection_when_parent_has_no_checkpoint(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + from agent.checkpoint_injection import build_checkpoint_system_prompt + + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + # Write a checkpoint for an unrelated session + store.write("unrelated_session", { + "task": "Unrelated task", + "status": "in_progress", + "progress": [], + }) + + # Ask for a parent_session_id that has no checkpoint + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="missing_session") + assert prompt == "", ( + "Must not inject a checkpoint from an unrelated session when " + "the requested parent_session_id has no checkpoint" + ) + + def test_only_injects_specific_parent(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + from agent.checkpoint_injection import build_checkpoint_system_prompt + + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + # Write checkpoints for two sessions + store.write("parent_A", { + "task": "Parent A task", + "status": "in_progress", + "progress": [], + }) + store.write("parent_B", { + "task": "Parent B task", + "status": "in_progress", + "progress": [], + }) + + # Request parent_A specifically + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="parent_A") + assert "Parent A task" in prompt + assert "Parent B task" not in prompt \ No newline at end of file diff --git a/tests/test_checkpoint_store.py b/tests/test_checkpoint_store.py new file mode 100644 index 0000000000..aed5a84e89 --- /dev/null +++ b/tests/test_checkpoint_store.py @@ -0,0 +1,118 @@ +import pytest +import yaml +from datetime import datetime, timedelta +from pathlib import Path +from agent.checkpoint_store import CheckpointStore + + +@pytest.fixture +def store(tmp_path): + return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + + +@pytest.fixture +def sample_checkpoint(): + return { + "task": "Red team QA on PR #14303", + "status": "in_progress", + "created": "2026-04-22T22:58:00", + "updated": "2026-04-22T23:15:00", + "progress": [ + {"step": "Read code", "status": "completed"}, + {"step": "Write tests", "status": "in_progress"}, + ], + "state": { + "active_branch": "fix/delegation", + "files_changed": ["tools/delegate_tool.py"], + "tests_status": "22/22 passing", + "last_commit": "abc123", + "pushed": True, + "working_directory": "/tmp/repo", + }, + "decisions": ["No SSRF fix needed"], + "blocked": [], + "unresolved": [], + } + + +class TestCheckpointStoreWrite: + def test_write_creates_yaml_file(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + path = store._path_for("sess_001") + assert path.exists() + + def test_write_content_is_valid_yaml(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + path = store._path_for("sess_001") + data = yaml.safe_load(path.read_text()) + assert data["task"] == "Red team QA on PR #14303" + + def test_write_overwrites_existing(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + sample_checkpoint["status"] = "completed" + store.write("sess_001", sample_checkpoint) + data = store.read("sess_001") + assert data["status"] == "completed" + + +class TestCheckpointStoreRead: + def test_read_existing(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + data = store.read("sess_001") + assert data["task"] == "Red team QA on PR #14303" + + def test_read_nonexistent_returns_none(self, store): + assert store.read("no_such_session") is None + + def test_read_corrupt_yaml_returns_none(self, store): + path = store._path_for("sess_bad") + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("{{invalid yaml: [") + assert store.read("sess_bad") is None + + +class TestCheckpointStoreDelete: + def test_delete_removes_file(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + store.delete("sess_001") + assert not store._path_for("sess_001").exists() + + def test_delete_nonexistent_is_noop(self, store): + store.delete("no_such_session") # no error + + +class TestCheckpointStoreList: + def test_list_returns_all_session_ids(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + store.write("sess_002", sample_checkpoint) + ids = store.list_sessions() + assert sorted(ids) == ["sess_001", "sess_002"] + + def test_list_empty_dir(self, store): + assert store.list_sessions() == [] + + +class TestCheckpointStoreGC: + def test_gc_removes_old_checkpoints(self, store, sample_checkpoint): + # Write an old checkpoint directly to disk (bypass write() which stamps now()) + old_time = (datetime.now() - timedelta(days=10)).isoformat() + sample_checkpoint["updated"] = old_time + path_old = store._path_for("sess_old") + path_old.parent.mkdir(parents=True, exist_ok=True) + path_old.write_text( + yaml.dump(sample_checkpoint, default_flow_style=False, allow_unicode=True, sort_keys=False), + encoding="utf-8", + ) + # Write a recent one through the normal path + store.write("sess_new", sample_checkpoint) + removed = store.garbage_collect(max_age_days=7) + assert "sess_old" in removed + assert "sess_new" not in removed + assert store.read("sess_old") is None + assert store.read("sess_new") is not None + + def test_gc_keeps_all_when_fresh(self, store, sample_checkpoint): + store.write("sess_001", sample_checkpoint) + removed = store.garbage_collect(max_age_days=7) + assert removed == [] + assert store.read("sess_001") is not None \ No newline at end of file diff --git a/tests/test_checkpoint_tool.py b/tests/test_checkpoint_tool.py new file mode 100644 index 0000000000..80119e9e3a --- /dev/null +++ b/tests/test_checkpoint_tool.py @@ -0,0 +1,130 @@ +import json +import pytest +from unittest.mock import MagicMock +from agent.checkpoint_store import CheckpointStore +from tools.checkpoint_tool import checkpoint_tool + + +@pytest.fixture +def store(tmp_path): + return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + + +@pytest.fixture +def agent_context(store): + """Simulate the fields the tool handler pulls from the AIAgent.""" + ctx = MagicMock() + ctx.session_id = "test_session_001" + ctx._todo_store = MagicMock() + ctx._todo_store.format_for_injection.return_value = "- [x] Step 1\n- [ ] Step 2" + return ctx + + +class TestCheckpointWrite: + def test_write_basic_checkpoint(self, store, agent_context): + result = checkpoint_tool( + action="write", + task="Build feature X", + progress=[{"step": "Setup", "status": "completed"}], + state={"active_branch": "main"}, + decisions=["Use SQLite"], + store=store, + agent=agent_context, + ) + data = json.loads(result) + assert data["success"] is True + assert data["session_id"] == "test_session_001" + saved = store.read("test_session_001") + assert saved["task"] == "Build feature X" + assert saved["status"] == "in_progress" + + def test_write_includes_auto_fields(self, store, agent_context): + checkpoint_tool( + action="write", + task="Test task", + progress=[], + state={}, + decisions=[], + store=store, + agent=agent_context, + ) + saved = store.read("test_session_001") + assert "created" in saved + assert "updated" in saved + assert saved["session_id"] == "test_session_001" + + +class TestCheckpointUpdate: + def test_update_merges_progress(self, store, agent_context): + checkpoint_tool( + action="write", + task="Build feature X", + progress=[{"step": "Setup", "status": "completed"}], + state={"active_branch": "main"}, + decisions=[], + store=store, + agent=agent_context, + ) + result = checkpoint_tool( + action="update", + progress=[{"step": "Implement", "status": "in_progress"}], + state={"active_branch": "feat/x"}, + store=store, + agent=agent_context, + ) + data = json.loads(result) + assert data["success"] is True + saved = store.read("test_session_001") + assert saved["task"] == "Build feature X" # preserved + assert len(saved["progress"]) == 2 # merged + + def test_update_nonexistent_returns_error(self, store, agent_context): + result = checkpoint_tool( + action="update", + progress=[], + state={}, + store=store, + agent=agent_context, + ) + data = json.loads(result) + assert data["success"] is False + + +class TestCheckpointRead: + def test_read_existing(self, store, agent_context): + checkpoint_tool( + action="write", + task="Test task", + progress=[], + state={}, + decisions=[], + store=store, + agent=agent_context, + ) + result = checkpoint_tool(action="read", store=store, agent=agent_context) + data = json.loads(result) + assert data["success"] is True + assert data["checkpoint"]["task"] == "Test task" + + def test_read_nonexistent(self, store, agent_context): + result = checkpoint_tool(action="read", store=store, agent=agent_context) + data = json.loads(result) + assert data["success"] is False + assert "no checkpoint" in data["error"].lower() + + +class TestCheckpointClear: + def test_clear_removes_checkpoint(self, store, agent_context): + checkpoint_tool( + action="write", + task="Test task", + progress=[], + state={}, + decisions=[], + store=store, + agent=agent_context, + ) + result = checkpoint_tool(action="clear", store=store, agent=agent_context) + data = json.loads(result) + assert data["success"] is True + assert store.read("test_session_001") is None \ No newline at end of file diff --git a/tools/checkpoint_tool.py b/tools/checkpoint_tool.py new file mode 100644 index 0000000000..f66455e2c4 --- /dev/null +++ b/tools/checkpoint_tool.py @@ -0,0 +1,264 @@ +"""Checkpoint tool -- save and restore mid-task state across context compaction. + +The checkpoint tool lets the agent save its current task progress, state, and +decisions to disk so it can resume after context compaction or session restart. +This is Layer 4 in the memory architecture (after Layer 1 memory, Layer 2 +personality, Layer 3 vault). + +Checkpoint files are stored as YAML in ~/.hermes/checkpoints/.yaml. +""" + +import json +import logging +import subprocess +from datetime import datetime +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +CHECKPOINT_TOOL_SCHEMA = { + "name": "checkpoint", + "description": ( + "Save or restore mid-task state (checkpoint). " + "Write a checkpoint before risky operations or when you have made significant progress. " + "The system also auto-writes a checkpoint before context compaction. " + "On session start, any existing checkpoint from the parent session is auto-injected " + "so you can resume where you left off. " + "Use 'write' to create/overwrite, 'update' to merge into existing, " + "'read' to check current checkpoint, 'clear' to delete." + ), + "parameters": { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["write", "update", "read", "clear"], + "description": ( + "write: Create or overwrite checkpoint for current session. " + "update: Merge progress/state/decisions into existing checkpoint. " + "read: Read current checkpoint. " + "clear: Delete current checkpoint (task done or abandoned)." + ), + }, + "task": { + "type": "string", + "description": "One-line description of what you are working on. Required for 'write'.", + }, + "progress": { + "type": "array", + "items": { + "type": "object", + "properties": { + "step": {"type": "string", "description": "What this step does"}, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed", "cancelled"], + }, + "result": { + "type": "string", + "description": "Outcome (optional, for completed steps)", + }, + }, + "required": ["step", "status"], + }, + "description": "Ordered list of task steps with status. Each step: {step, status, result?}.", + }, + "state": { + "type": "object", + "properties": { + "active_branch": {"type": "string"}, + "files_changed": {"type": "array", "items": {"type": "string"}}, + "tests_status": {"type": "string"}, + "last_commit": {"type": "string"}, + "pushed": {"type": "boolean"}, + "working_directory": {"type": "string"}, + }, + "description": "Machine-readable facts about the current state: branch, files, test results, commits.", + }, + "decisions": { + "type": "array", + "items": {"type": "string"}, + "description": "Non-obvious choices made during this task (the 'why', not the 'what').", + }, + "blocked": { + "type": "array", + "items": {"type": "string"}, + "description": "Things blocked on external input. Empty if unblocked.", + }, + "unresolved": { + "type": "array", + "items": {"type": "string"}, + "description": "Open questions or unknowns. Empty if none.", + }, + }, + "required": ["action"], + }, +} + + +def _git_state(workdir: str = None) -> Dict[str, Any]: + """Best-effort capture of git state from the working directory.""" + state = {} + if not workdir: + return state + try: + branch = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, text=True, timeout=5, cwd=workdir, + ) + if branch.returncode == 0: + state["active_branch"] = branch.stdout.strip() + + commit = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + capture_output=True, text=True, timeout=5, cwd=workdir, + ) + if commit.returncode == 0: + state["last_commit"] = commit.stdout.strip() + except (OSError, subprocess.TimeoutExpired): + pass + return state + + +def checkpoint_tool( + action: str, + task: str = None, + progress: List[Dict] = None, + state: Dict[str, Any] = None, + decisions: List[str] = None, + blocked: List[str] = None, + unresolved: List[str] = None, + store=None, + agent=None, +) -> str: + """Execute a checkpoint action. Returns JSON result string.""" + if store is None: + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore() + + session_id = getattr(agent, "session_id", "unknown") if agent else "unknown" + + if action == "write": + if not task: + return json.dumps({"success": False, "error": "'task' is required for write action"}) + + # Auto-populate git state if not provided + effective_state = dict(state or {}) + workdir = effective_state.get("working_directory") + if workdir and "active_branch" not in effective_state: + git = _git_state(workdir) + effective_state.update(git) + # Auto-populate from todo store if available + todo_snapshot = None + if agent and hasattr(agent, "_todo_store") and agent._todo_store: + try: + todo_snapshot = agent._todo_store.format_for_injection() + except Exception: + pass + if todo_snapshot: + effective_state["todo_snapshot"] = todo_snapshot + + data = { + "session_id": session_id, + "task": task, + "status": "in_progress", + "created": datetime.now().isoformat(), + "updated": datetime.now().isoformat(), + "progress": progress or [], + "state": effective_state, + "decisions": decisions or [], + "blocked": blocked or [], + "unresolved": unresolved or [], + } + store.write(session_id, data) + return json.dumps({ + "success": True, + "session_id": session_id, + "message": f"Checkpoint saved for session {session_id}", + }) + + elif action == "update": + existing = store.read(session_id) + if not existing: + return json.dumps({ + "success": False, + "error": f"No checkpoint exists for session {session_id}. Use 'write' first.", + }) + + # Merge progress (append new steps) + if progress: + existing["progress"] = existing.get("progress", []) + progress + + # Merge state (overwrite keys) + if state: + existing["state"] = {**existing.get("state", {}), **state} + + # Append decisions + if decisions: + existing["decisions"] = existing.get("decisions", []) + decisions + + # Replace blocked/unresolved + if blocked is not None: + existing["blocked"] = blocked + if unresolved is not None: + existing["unresolved"] = unresolved + + # Update git state if workdir provided + workdir = (state or {}).get("working_directory") if state else None + if workdir: + git = _git_state(workdir) + existing["state"] = {**existing.get("state", {}), **git} + + store.write(session_id, existing) + return json.dumps({ + "success": True, + "session_id": session_id, + "message": f"Checkpoint updated for session {session_id}", + }) + + elif action == "read": + data = store.read(session_id) + if not data: + return json.dumps({ + "success": False, + "error": f"No checkpoint exists for session {session_id}", + }) + return json.dumps({ + "success": True, + "session_id": session_id, + "checkpoint": data, + }) + + elif action == "clear": + store.delete(session_id) + return json.dumps({ + "success": True, + "session_id": session_id, + "message": f"Checkpoint cleared for session {session_id}", + }) + + else: + return json.dumps({"success": False, "error": f"Unknown action: {action}"}) + + +# --- Registry --- +from tools.registry import registry + +registry.register( + name="checkpoint", + toolset="todo", + schema=CHECKPOINT_TOOL_SCHEMA, + handler=lambda args, **kw: checkpoint_tool( + action=args.get("action"), + task=args.get("task"), + progress=args.get("progress"), + state=args.get("state"), + decisions=args.get("decisions"), + blocked=args.get("blocked"), + unresolved=args.get("unresolved"), + store=kw.get("store"), + agent=kw.get("agent"), + ), + check_fn=lambda: True, # always available + emoji="🔖", +) \ No newline at end of file diff --git a/toolsets.py b/toolsets.py index 65f560bfe4..b6619d3ab7 100644 --- a/toolsets.py +++ b/toolsets.py @@ -159,7 +159,7 @@ TOOLSETS = { "todo": { "description": "Task planning and tracking for multi-step work", - "tools": ["todo"], + "tools": ["todo", "checkpoint"], "includes": [] },