This commit is contained in:
aj-nt 2026-04-24 23:33:20 +00:00 committed by GitHub
commit 9a56339ddc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1218 additions and 2 deletions

View file

@ -0,0 +1,86 @@
"""Checkpoint injection -- load checkpoint from parent session into system prompt.
When a new session starts (or a compression creates a continuation session),
this module finds the most recent checkpoint from the session's lineage
and injects it into the system prompt.
This is Layer 4 in the memory architecture. Only in_progress checkpoints
are injected -- completed ones are skipped.
"""
import logging
from typing import Optional
logger = logging.getLogger(__name__)
SEPARATOR = "\u2550" * 46 # same as vault injection
CHAR_LIMIT = 4000
def build_checkpoint_system_prompt(store, parent_session_id: str = None) -> str:
"""Build the checkpoint injection block for the system prompt.
Returns empty string if no applicable checkpoint exists.
The caller is responsible for walking the parent lineage chain to find
the nearest ancestor session that has a checkpoint file.
"""
if not store or not parent_session_id:
return ""
data = store.read(parent_session_id)
if not data or not isinstance(data, dict):
return ""
if data.get("status") == "completed":
return ""
lines = []
lines.append(f"Task: {data.get('task', 'Unknown')}")
lines.append(f"Status: {data.get('status', 'in_progress')}")
lines.append("")
progress = data.get("progress", [])
if progress:
lines.append("Progress:")
for step in progress:
status_icon = {
"completed": "[x]", "in_progress": "[~]",
"pending": "[ ]", "cancelled": "[-]",
}.get(step.get("status", "pending"), "[ ]")
line = f" {status_icon} {step.get('step', '')}"
if step.get("result"):
line += f" -- {step['result']}"
lines.append(line)
lines.append("")
state = data.get("state", {})
if state:
lines.append("State:")
for key, value in state.items():
if value and str(value) not in ("[]", ""):
lines.append(f" {key}: {value}")
lines.append("")
decisions = data.get("decisions", [])
if decisions:
lines.append("Decisions:")
for d in decisions:
lines.append(f" - {d}")
lines.append("")
blocked = data.get("blocked", [])
if blocked:
lines.append(f"Blocked: {', '.join(blocked)}")
unresolved = data.get("unresolved", [])
if unresolved:
lines.append(f"Unresolved: {', '.join(unresolved)}")
content = "\n".join(lines)
if len(content) > CHAR_LIMIT:
content = content[:CHAR_LIMIT] + "\n[... truncated at char limit ...]"
header = "CHECKPOINT: RESUME FROM HERE (saved before compaction, injected on session start)"
return f"{SEPARATOR}\n{header}\n{SEPARATOR}\n{content}"

117
agent/checkpoint_store.py Normal file
View file

@ -0,0 +1,117 @@
"""Checkpoint store -- read/write YAML checkpoint files for session resumption.
Each checkpoint captures enough state to resume a task from where the agent
left off after context compaction or session restart. Files live under
~/.hermes/checkpoints/<session_id>.yaml.
"""
import logging
import re
import yaml
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# Session IDs must be safe filesystem names -- no path separators or traversal
_SAFE_SESSION_ID = re.compile(r"^[A-Za-z0-9_-]+$")
DEFAULT_CHECKPOINTS_DIR = Path.home() / ".hermes" / "checkpoints"
DEFAULT_GC_MAX_AGE_DAYS = 7
class CheckpointStore:
"""Manages checkpoint YAML files on disk."""
def __init__(self, checkpoints_dir: Path = None):
self._dir = Path(checkpoints_dir) if checkpoints_dir else DEFAULT_CHECKPOINTS_DIR
self._dir.mkdir(parents=True, exist_ok=True)
def _path_for(self, session_id: str) -> Path:
if not _SAFE_SESSION_ID.match(session_id):
raise ValueError(f"Invalid session_id: {session_id!r} (must match {_SAFE_SESSION_ID.pattern})")
return self._dir / f"{session_id}.yaml"
def write(self, session_id: str, data: Dict[str, Any]) -> None:
"""Write a checkpoint for the given session. Overwrites if exists."""
path = self._path_for(session_id)
path.parent.mkdir(parents=True, exist_ok=True)
# Ensure the 'updated' timestamp is current
data["updated"] = datetime.now().isoformat()
try:
path.write_text(
yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False),
encoding="utf-8",
)
logger.debug("Checkpoint written: %s (%d chars)", session_id, path.stat().st_size)
except (OSError, yaml.YAMLError) as e:
logger.warning("Failed to write checkpoint %s: %s", session_id, e)
def read(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Read a checkpoint. Returns None if missing or corrupt."""
path = self._path_for(session_id)
if not path.exists():
return None
try:
content = path.read_text(encoding="utf-8")
data = yaml.safe_load(content)
if not isinstance(data, dict):
return None
return data
except (OSError, yaml.YAMLError) as e:
logger.debug("Failed to read checkpoint %s: %s", session_id, e)
return None
def delete(self, session_id: str) -> None:
"""Delete a checkpoint. No-op if missing."""
path = self._path_for(session_id)
try:
path.unlink(missing_ok=True)
except OSError:
pass
def list_sessions(self) -> List[str]:
"""Return all session IDs that have checkpoint files."""
if not self._dir.exists():
return []
return sorted(p.stem for p in self._dir.glob("*.yaml"))
def garbage_collect(self, max_age_days: int = DEFAULT_GC_MAX_AGE_DAYS) -> List[str]:
"""Remove stale checkpoints. Returns list of removed session IDs.
A checkpoint is stale if:
- Its 'updated' timestamp is older than max_age_days AND not completed
- OR it's 'completed' and older than 1 day (task done, no need to keep)
"""
removed = []
cutoff = datetime.now() - timedelta(days=max_age_days)
completed_cutoff = datetime.now() - timedelta(days=1)
for path in self._dir.glob("*.yaml"):
try:
data = yaml.safe_load(path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
path.unlink(missing_ok=True)
removed.append(path.stem)
continue
updated_str = data.get("updated", "")
status = data.get("status", "in_progress")
try:
updated = datetime.fromisoformat(updated_str) if updated_str else cutoff
except ValueError:
updated = cutoff
if status == "completed" and updated < completed_cutoff:
path.unlink(missing_ok=True)
removed.append(path.stem)
elif updated < cutoff:
path.unlink(missing_ok=True)
removed.append(path.stem)
except (OSError, yaml.YAMLError):
path.unlink(missing_ok=True)
removed.append(path.stem)
if removed:
logger.info("Checkpoint GC: removed %d stale file(s): %s", len(removed), removed)
return removed

View file

@ -176,6 +176,20 @@ SKILLS_GUIDANCE = (
"Skills that aren't maintained become liabilities."
)
CHECKPOINT_GUIDANCE = (
"When a CHECKPOINT block appears at session start, read it and resume from "
"where the previous session left off. Do not start over — pick up at the "
"first in_progress step.\n"
"Write a checkpoint (action=\"write\") at these moments:\n"
"- Before a risky command that could break state\n"
"- After completing a major step in a multi-step task (then action=\"update\" "
"for subsequent steps)\n"
"- When you have important \"why\" decisions that won't be obvious later\n"
"Clear the checkpoint (action=\"clear\") when the task is fully done.\n"
"Pre-compression checkpoints are auto-written — you do not need to worry "
"about losing progress when the context window fills."
)
TOOL_USE_ENFORCEMENT_GUIDANCE = (
"# Tool-use enforcement\n"
"You MUST use your tools to take action — do not describe what you would do "

View file

@ -366,7 +366,7 @@ def get_tool_definitions(
# because they need agent-level state (TodoStore, MemoryStore, etc.).
# The registry still holds their schemas; dispatch just returns a stub error
# so if something slips through, the LLM sees a sensible message.
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task", "checkpoint"}
_READ_SEARCH_TOOLS = {"read_file", "search_files"}

View file

@ -85,6 +85,7 @@ from agent.error_classifier import classify_api_error, FailoverReason
from agent.prompt_builder import (
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
CHECKPOINT_GUIDANCE,
build_nous_subscription_prompt,
)
from agent.model_metadata import (
@ -1561,6 +1562,11 @@ class AIAgent:
# In-memory todo list for task planning (one per agent/session)
from tools.todo_tool import TodoStore
self._todo_store = TodoStore()
# Layer 4: Checkpoint store (session resumption across compression)
from agent.checkpoint_store import CheckpointStore
self._checkpoint_store = CheckpointStore()
self._checkpoint_store.garbage_collect() # prune stale checkpoints on startup
# Load config once for memory, skills, and compression sections
try:
@ -4391,6 +4397,8 @@ class AIAgent:
tool_guidance.append(SESSION_SEARCH_GUIDANCE)
if "skill_manage" in self.valid_tool_names:
tool_guidance.append(SKILLS_GUIDANCE)
if "checkpoint" in self.valid_tool_names:
tool_guidance.append(CHECKPOINT_GUIDANCE)
if tool_guidance:
prompt_parts.append(" ".join(tool_guidance))
@ -4448,6 +4456,54 @@ class AIAgent:
if user_block:
prompt_parts.append(user_block)
# Vault auto-injection (Layer 3) — reads working-context.md and
# user-profile.md from the Obsidian vault and injects them into
# the system prompt. Structural fix for vault neglect.
_vault_log = logging.getLogger("agent.vault")
if self._vault_enabled and self._vault_path:
try:
from agent.vault_injection import build_vault_system_prompt
_vault_block = build_vault_system_prompt(self._vault_path)
if _vault_block:
prompt_parts.append(_vault_block)
_vault_log.info("Injection succeeded: %d chars from %s", len(_vault_block), self._vault_path)
else:
_vault_log.warning("Injection returned empty for path %s", self._vault_path)
except Exception as e:
_vault_log.warning("Injection failed: %s: %s", type(e).__name__, e)
else:
_vault_log.info("Injection skipped: enabled=%s path=%s", self._vault_enabled, self._vault_path)
# Checkpoint injection (Layer 4) -- resume from where the previous session left off
if self._checkpoint_store:
try:
_checkpoint_session_id = None
if self._session_db:
_sid = self.session_id
# Walk up the parent chain to find the nearest ancestor with a checkpoint
for _ in range(5): # max 5 hops up lineage
try:
_sess = self._session_db.get_session(_sid)
except Exception:
break
if not _sess:
break
_parent = _sess.get("parent_session_id")
if not _parent:
break
if self._checkpoint_store.read(_parent) is not None:
_checkpoint_session_id = _parent
break
_sid = _parent # keep walking up
from agent.checkpoint_injection import build_checkpoint_system_prompt
_cp_block = build_checkpoint_system_prompt(
store=self._checkpoint_store,
parent_session_id=_checkpoint_session_id,
)
if _cp_block:
prompt_parts.append(_cp_block)
except Exception as e:
logger.debug("Checkpoint injection skipped: %s", e)
# External memory provider system prompt block (additive to built-in)
if self._memory_manager:
try:
@ -8119,6 +8175,52 @@ class AIAgent:
if messages and messages[-1].get("_flush_sentinel") == _sentinel:
messages.pop()
def flush_checkpoint(self, messages: list = None) -> None:
"""Write a checkpoint before context compression destroys progress.
Called by _compress_context before the compression happens.
Captures whatever task state we can extract from the agent's
current context (todo list, session ID, git state) and writes
a checkpoint. If a checkpoint already exists for this session,
it gets overwritten -- the pre-compression checkpoint is always
the most accurate.
"""
if not self._checkpoint_store:
return
from tools.checkpoint_tool import checkpoint_tool
# Extract task description from session title or fallback
task_desc = "Session checkpoint (auto-saved before compression)"
try:
if self._session_db:
title = self._session_db.get_session_title(self.session_id)
if title:
task_desc = title
except Exception:
pass
# Extract progress from todo store if available
progress = []
if self._todo_store:
try:
items = self._todo_store._items if hasattr(self._todo_store, '_items') else []
for item in items:
step = {"step": item.get("content", ""), "status": item.get("status", "pending")}
progress.append(step)
except Exception:
pass
checkpoint_tool(
action="write",
task=task_desc,
progress=progress,
state={},
decisions=[],
store=self._checkpoint_store,
agent=self,
)
logger.info("Pre-compression checkpoint saved for session %s", self.session_id)
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default", focus_topic: str = None) -> tuple:
"""Compress conversation context and split the session in SQLite.
@ -8140,6 +8242,9 @@ class AIAgent:
# Pre-compression memory flush: let the model save memories before they're lost
self.flush_memories(messages, min_turns=0)
# Pre-compression checkpoint: save task state before context is lost
self.flush_checkpoint(messages)
# Notify external memory provider before compression discards context
if self._memory_manager:
try:
@ -8351,6 +8456,19 @@ class AIAgent:
)
elif function_name == "delegate_task":
return self._dispatch_delegate_task(function_args)
elif function_name == "checkpoint":
from tools.checkpoint_tool import checkpoint_tool as _checkpoint_tool
return _checkpoint_tool(
action=function_args.get("action"),
task=function_args.get("task"),
progress=function_args.get("progress"),
state=function_args.get("state"),
decisions=function_args.get("decisions"),
blocked=function_args.get("blocked"),
unresolved=function_args.get("unresolved"),
store=self._checkpoint_store,
agent=self,
)
else:
return handle_function_call(
function_name, function_args, effective_task_id,
@ -8892,6 +9010,22 @@ class AIAgent:
spinner.stop(cute_msg)
elif self._should_emit_quiet_tool_messages():
self._vprint(f" {cute_msg}")
elif function_name == "checkpoint":
from tools.checkpoint_tool import checkpoint_tool as _checkpoint_tool
function_result = _checkpoint_tool(
action=function_args.get("action"),
task=function_args.get("task"),
progress=function_args.get("progress"),
state=function_args.get("state"),
decisions=function_args.get("decisions"),
blocked=function_args.get("blocked"),
unresolved=function_args.get("unresolved"),
store=self._checkpoint_store,
agent=self,
)
tool_duration = time.time() - tool_start_time
if self._should_emit_quiet_tool_messages():
self._vprint(f" {_get_cute_tool_message_impl('checkpoint', function_args, tool_duration, result=function_result)}")
elif self._context_engine_tool_names and function_name in self._context_engine_tool_names:
# Context engine tools (lcm_grep, lcm_describe, lcm_expand, etc.)
spinner = None

View file

@ -0,0 +1,16 @@
"""Verify that checkpoint is routed through the agent loop, not the registry."""
import json
import pytest
from model_tools import handle_function_call, _AGENT_LOOP_TOOLS
def test_checkpoint_is_agent_loop_tool():
"""checkpoint must be in _AGENT_LOOP_TOOLS so it gets agent-level state."""
assert "checkpoint" in _AGENT_LOOP_TOOLS
def test_checkpoint_registry_dispatch_returns_error():
"""Calling handle_function_call for checkpoint should return agent-loop error."""
result = handle_function_call("checkpoint", {"action": "read"})
data = json.loads(result)
assert "must be handled" in data["error"].lower()

View file

@ -0,0 +1,53 @@
"""Test that flush_checkpoint writes a checkpoint via checkpoint_tool."""
import json
import pytest
from unittest.mock import MagicMock, patch
def test_flush_checkpoint_method_exists():
"""AIAgent must have a flush_checkpoint method."""
from run_agent import AIAgent
assert hasattr(AIAgent, "flush_checkpoint")
def test_flush_checkpoint_writes_to_store(tmp_path):
"""flush_checkpoint should write a checkpoint with current session state."""
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
mock_todo = MagicMock()
mock_todo._items = [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "in_progress"},
]
mock_todo.format_for_injection.return_value = "- [x] Step 1\n- [~] Step 2"
mock_session_db = MagicMock()
mock_session_db.get_session_title.return_value = "Test session title"
mock_agent = MagicMock()
mock_agent.session_id = "flush_test_session"
mock_agent._checkpoint_store = store
mock_agent._todo_store = mock_todo
mock_agent._session_db = mock_session_db
# Call the real flush_checkpoint method
from run_agent import AIAgent
AIAgent.flush_checkpoint(mock_agent)
# Verify a checkpoint was written for this session
saved = store.read("flush_test_session")
assert saved is not None
assert saved["task"] == "Test session title"
assert saved["status"] == "in_progress"
def test_flush_checkpoint_noops_without_store(tmp_path):
"""flush_checkpoint should silently return if _checkpoint_store is None."""
mock_agent = MagicMock()
mock_agent._checkpoint_store = None
from run_agent import AIAgent
# Should not raise
AIAgent.flush_checkpoint(mock_agent)

View file

@ -0,0 +1,58 @@
import pytest
from agent.checkpoint_store import CheckpointStore
from agent.checkpoint_injection import build_checkpoint_system_prompt
@pytest.fixture
def store(tmp_path):
return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
@pytest.fixture
def parent_checkpoint():
return {
"session_id": "old_session_001",
"task": "Red team QA on PR #14303",
"status": "in_progress",
"created": "2026-04-22T22:58:00",
"updated": "2026-04-22T23:15:00",
"progress": [
{"step": "Read code paths", "status": "completed"},
{"step": "Add edge-case tests", "status": "in_progress"},
],
"state": {"active_branch": "fix/delegation", "tests_status": "22/22 passing"},
"decisions": ["SSRF is low-severity"],
"blocked": [],
"unresolved": [],
}
class TestBuildCheckpointPrompt:
def test_injects_in_progress_checkpoint(self, store, parent_checkpoint):
store.write("old_session_001", parent_checkpoint)
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
assert "Red team QA on PR #14303" in prompt
assert "CHECKPOINT" in prompt
def test_skips_completed_checkpoints(self, store, parent_checkpoint):
parent_checkpoint["status"] = "completed"
store.write("old_session_001", parent_checkpoint)
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
assert prompt == ""
def test_returns_empty_when_no_parent(self, store):
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="nonexistent")
assert prompt == ""
def test_includes_progress_and_decisions(self, store, parent_checkpoint):
store.write("old_session_001", parent_checkpoint)
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
assert "Read code paths" in prompt
assert "SSRF is low-severity" in prompt
def test_truncates_oversized_checkpoints(self, store, parent_checkpoint):
parent_checkpoint["task"] = "x" * 5000
parent_checkpoint["decisions"] = ["d" * 2000] * 10
store.write("old_session_001", parent_checkpoint)
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
assert len(prompt) < 6000

View file

@ -0,0 +1,82 @@
"""End-to-end test: write -> compress -> resume cycle."""
import json
import pytest
import yaml
from datetime import datetime, timedelta
from unittest.mock import MagicMock
from agent.checkpoint_store import CheckpointStore
from agent.checkpoint_injection import build_checkpoint_system_prompt
from tools.checkpoint_tool import checkpoint_tool
@pytest.fixture
def store(tmp_path):
return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
def test_full_cycle_write_compress_resume(store):
"""Simulate: agent writes checkpoint -> new session gets injection."""
old_agent = MagicMock()
old_agent.session_id = "20260422_225800_abc123"
old_agent._todo_store = MagicMock()
old_agent._todo_store.format_for_injection.return_value = ""
checkpoint_tool(
action="write",
task="Red team QA on PR #14303",
progress=[
{"step": "Read code", "status": "completed"},
{"step": "Write tests", "status": "in_progress"},
],
state={
"active_branch": "fix/delegation",
"tests_status": "19/19 passing",
"last_commit": "a90557d",
"pushed": True,
"working_directory": "/tmp/repo",
},
decisions=["SSRF is low-severity, doc note only"],
blocked=[],
unresolved=[],
store=store,
agent=old_agent,
)
# New session starts, checkpoint gets injected
prompt = build_checkpoint_system_prompt(
store=store,
parent_session_id="20260422_225800_abc123",
)
assert "Red team QA on PR #14303" in prompt
assert "Write tests" in prompt
assert "fix/delegation" in prompt
assert "SSRF is low-severity" in prompt
def test_gc_removes_stale_checkpoints(store):
"""GC removes checkpoints older than 7 days."""
mock_agent = MagicMock()
mock_agent.session_id = "old_sess"
mock_agent._todo_store = MagicMock()
mock_agent._todo_store.format_for_injection.return_value = ""
checkpoint_tool(
action="write",
task="Old task",
progress=[],
state={},
decisions=[],
store=store,
agent=mock_agent,
)
# Age the checkpoint by writing directly with an old timestamp
data = store.read("old_sess")
data["updated"] = (datetime.now() - timedelta(days=10)).isoformat()
store._path_for("old_sess").write_text(
yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False),
encoding="utf-8",
)
removed = store.garbage_collect(max_age_days=7)
assert "old_sess" in removed

View file

@ -0,0 +1,144 @@
"""Regression tests for bugs found during red-team QA review.
Bug #1: get_parent_session_id() didn't exist on SessionDB
Bug #2: Path traversal via session_id
Bug #3: "checkpoint" not in any toolset definition
Bug #4: flush_checkpoint test didn't actually call flush_checkpoint
Bug #5: Injection fallback grabbed ANY in_progress checkpoint from unrelated sessions
"""
import pytest
from pathlib import Path
class TestBug1SessionDBLineageWalk:
"""Bug #1: _build_system_prompt must use get_session()["parent_session_id"]
instead of the nonexistent get_parent_session_id() method.
"""
def test_session_db_has_no_get_parent_session_id_method(self):
"""SessionDB must NOT have a get_parent_session_id method (it never did)."""
from hermes_state import SessionDB
assert not hasattr(SessionDB, "get_parent_session_id"), (
"SessionDB should not have get_parent_session_id -- "
"use get_session(sid)['parent_session_id'] instead"
)
def test_get_session_returns_parent_session_id_key(self, tmp_path):
"""get_session() must return a dict with 'parent_session_id' key."""
from hermes_state import SessionDB
from pathlib import Path
db = SessionDB(db_path=Path(tmp_path) / "test.db")
# Create a parent session
parent_id = db.create_session("parent_session", source="test", model="test")
# Create a child session
child_id = db.create_session("child_session", source="test", model="test", parent_session_id=parent_id)
sess = db.get_session(child_id)
assert sess is not None
assert "parent_session_id" in sess
assert sess["parent_session_id"] == parent_id
class TestBug2PathTraversal:
"""Bug #2: session_id must be validated to prevent path traversal."""
def test_reject_dotdot_in_session_id(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
with pytest.raises(ValueError, match="Invalid session_id"):
store.write("../../../etc/passwd", {"task": "evil"})
def test_reject_slash_in_session_id(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
with pytest.raises(ValueError, match="Invalid session_id"):
store.write("sub/dir", {"task": "evil"})
def test_reject_backslash_in_session_id(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
with pytest.raises(ValueError, match="Invalid session_id"):
store.write("sub\\dir", {"task": "evil"})
def test_accept_normal_session_id(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
# Should not raise
store.write("20260422_225800_abc123", {"task": "normal"})
def test_read_rejects_path_traversal(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
with pytest.raises(ValueError, match="Invalid session_id"):
store.read("../../../etc/passwd")
def test_delete_rejects_path_traversal(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
with pytest.raises(ValueError, match="Invalid session_id"):
store.delete("../../../etc/passwd")
class TestBug3CheckpointInToolset:
"""Bug #3: 'checkpoint' must be in a toolset definition so the model gets the schema."""
def test_checkpoint_in_todo_toolset(self):
from toolsets import TOOLSETS
todo_tools = TOOLSETS.get("todo", {}).get("tools", [])
assert "checkpoint" in todo_tools, (
f"'checkpoint' must be in the 'todo' toolset tools list. "
f"Current todo tools: {todo_tools}"
)
def test_resolve_toolset_todo_includes_checkpoint(self):
from toolsets import resolve_toolset
resolved = resolve_toolset("todo")
assert "checkpoint" in resolved, (
f"resolve_toolset('todo') must return ['todo', 'checkpoint']. Got: {resolved}"
)
class TestBug5InjectionNoCrossSession:
"""Bug #5: build_checkpoint_system_prompt must NOT fall through to grabbing
any random in_progress checkpoint from an unrelated session.
"""
def test_no_injection_when_parent_has_no_checkpoint(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
from agent.checkpoint_injection import build_checkpoint_system_prompt
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
# Write a checkpoint for an unrelated session
store.write("unrelated_session", {
"task": "Unrelated task",
"status": "in_progress",
"progress": [],
})
# Ask for a parent_session_id that has no checkpoint
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="missing_session")
assert prompt == "", (
"Must not inject a checkpoint from an unrelated session when "
"the requested parent_session_id has no checkpoint"
)
def test_only_injects_specific_parent(self, tmp_path):
from agent.checkpoint_store import CheckpointStore
from agent.checkpoint_injection import build_checkpoint_system_prompt
store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
# Write checkpoints for two sessions
store.write("parent_A", {
"task": "Parent A task",
"status": "in_progress",
"progress": [],
})
store.write("parent_B", {
"task": "Parent B task",
"status": "in_progress",
"progress": [],
})
# Request parent_A specifically
prompt = build_checkpoint_system_prompt(store=store, parent_session_id="parent_A")
assert "Parent A task" in prompt
assert "Parent B task" not in prompt

View file

@ -0,0 +1,118 @@
import pytest
import yaml
from datetime import datetime, timedelta
from pathlib import Path
from agent.checkpoint_store import CheckpointStore
@pytest.fixture
def store(tmp_path):
return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
@pytest.fixture
def sample_checkpoint():
return {
"task": "Red team QA on PR #14303",
"status": "in_progress",
"created": "2026-04-22T22:58:00",
"updated": "2026-04-22T23:15:00",
"progress": [
{"step": "Read code", "status": "completed"},
{"step": "Write tests", "status": "in_progress"},
],
"state": {
"active_branch": "fix/delegation",
"files_changed": ["tools/delegate_tool.py"],
"tests_status": "22/22 passing",
"last_commit": "abc123",
"pushed": True,
"working_directory": "/tmp/repo",
},
"decisions": ["No SSRF fix needed"],
"blocked": [],
"unresolved": [],
}
class TestCheckpointStoreWrite:
def test_write_creates_yaml_file(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
path = store._path_for("sess_001")
assert path.exists()
def test_write_content_is_valid_yaml(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
path = store._path_for("sess_001")
data = yaml.safe_load(path.read_text())
assert data["task"] == "Red team QA on PR #14303"
def test_write_overwrites_existing(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
sample_checkpoint["status"] = "completed"
store.write("sess_001", sample_checkpoint)
data = store.read("sess_001")
assert data["status"] == "completed"
class TestCheckpointStoreRead:
def test_read_existing(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
data = store.read("sess_001")
assert data["task"] == "Red team QA on PR #14303"
def test_read_nonexistent_returns_none(self, store):
assert store.read("no_such_session") is None
def test_read_corrupt_yaml_returns_none(self, store):
path = store._path_for("sess_bad")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text("{{invalid yaml: [")
assert store.read("sess_bad") is None
class TestCheckpointStoreDelete:
def test_delete_removes_file(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
store.delete("sess_001")
assert not store._path_for("sess_001").exists()
def test_delete_nonexistent_is_noop(self, store):
store.delete("no_such_session") # no error
class TestCheckpointStoreList:
def test_list_returns_all_session_ids(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
store.write("sess_002", sample_checkpoint)
ids = store.list_sessions()
assert sorted(ids) == ["sess_001", "sess_002"]
def test_list_empty_dir(self, store):
assert store.list_sessions() == []
class TestCheckpointStoreGC:
def test_gc_removes_old_checkpoints(self, store, sample_checkpoint):
# Write an old checkpoint directly to disk (bypass write() which stamps now())
old_time = (datetime.now() - timedelta(days=10)).isoformat()
sample_checkpoint["updated"] = old_time
path_old = store._path_for("sess_old")
path_old.parent.mkdir(parents=True, exist_ok=True)
path_old.write_text(
yaml.dump(sample_checkpoint, default_flow_style=False, allow_unicode=True, sort_keys=False),
encoding="utf-8",
)
# Write a recent one through the normal path
store.write("sess_new", sample_checkpoint)
removed = store.garbage_collect(max_age_days=7)
assert "sess_old" in removed
assert "sess_new" not in removed
assert store.read("sess_old") is None
assert store.read("sess_new") is not None
def test_gc_keeps_all_when_fresh(self, store, sample_checkpoint):
store.write("sess_001", sample_checkpoint)
removed = store.garbage_collect(max_age_days=7)
assert removed == []
assert store.read("sess_001") is not None

View file

@ -0,0 +1,130 @@
import json
import pytest
from unittest.mock import MagicMock
from agent.checkpoint_store import CheckpointStore
from tools.checkpoint_tool import checkpoint_tool
@pytest.fixture
def store(tmp_path):
return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
@pytest.fixture
def agent_context(store):
"""Simulate the fields the tool handler pulls from the AIAgent."""
ctx = MagicMock()
ctx.session_id = "test_session_001"
ctx._todo_store = MagicMock()
ctx._todo_store.format_for_injection.return_value = "- [x] Step 1\n- [ ] Step 2"
return ctx
class TestCheckpointWrite:
def test_write_basic_checkpoint(self, store, agent_context):
result = checkpoint_tool(
action="write",
task="Build feature X",
progress=[{"step": "Setup", "status": "completed"}],
state={"active_branch": "main"},
decisions=["Use SQLite"],
store=store,
agent=agent_context,
)
data = json.loads(result)
assert data["success"] is True
assert data["session_id"] == "test_session_001"
saved = store.read("test_session_001")
assert saved["task"] == "Build feature X"
assert saved["status"] == "in_progress"
def test_write_includes_auto_fields(self, store, agent_context):
checkpoint_tool(
action="write",
task="Test task",
progress=[],
state={},
decisions=[],
store=store,
agent=agent_context,
)
saved = store.read("test_session_001")
assert "created" in saved
assert "updated" in saved
assert saved["session_id"] == "test_session_001"
class TestCheckpointUpdate:
def test_update_merges_progress(self, store, agent_context):
checkpoint_tool(
action="write",
task="Build feature X",
progress=[{"step": "Setup", "status": "completed"}],
state={"active_branch": "main"},
decisions=[],
store=store,
agent=agent_context,
)
result = checkpoint_tool(
action="update",
progress=[{"step": "Implement", "status": "in_progress"}],
state={"active_branch": "feat/x"},
store=store,
agent=agent_context,
)
data = json.loads(result)
assert data["success"] is True
saved = store.read("test_session_001")
assert saved["task"] == "Build feature X" # preserved
assert len(saved["progress"]) == 2 # merged
def test_update_nonexistent_returns_error(self, store, agent_context):
result = checkpoint_tool(
action="update",
progress=[],
state={},
store=store,
agent=agent_context,
)
data = json.loads(result)
assert data["success"] is False
class TestCheckpointRead:
def test_read_existing(self, store, agent_context):
checkpoint_tool(
action="write",
task="Test task",
progress=[],
state={},
decisions=[],
store=store,
agent=agent_context,
)
result = checkpoint_tool(action="read", store=store, agent=agent_context)
data = json.loads(result)
assert data["success"] is True
assert data["checkpoint"]["task"] == "Test task"
def test_read_nonexistent(self, store, agent_context):
result = checkpoint_tool(action="read", store=store, agent=agent_context)
data = json.loads(result)
assert data["success"] is False
assert "no checkpoint" in data["error"].lower()
class TestCheckpointClear:
def test_clear_removes_checkpoint(self, store, agent_context):
checkpoint_tool(
action="write",
task="Test task",
progress=[],
state={},
decisions=[],
store=store,
agent=agent_context,
)
result = checkpoint_tool(action="clear", store=store, agent=agent_context)
data = json.loads(result)
assert data["success"] is True
assert store.read("test_session_001") is None

264
tools/checkpoint_tool.py Normal file
View file

@ -0,0 +1,264 @@
"""Checkpoint tool -- save and restore mid-task state across context compaction.
The checkpoint tool lets the agent save its current task progress, state, and
decisions to disk so it can resume after context compaction or session restart.
This is Layer 4 in the memory architecture (after Layer 1 memory, Layer 2
personality, Layer 3 vault).
Checkpoint files are stored as YAML in ~/.hermes/checkpoints/<session_id>.yaml.
"""
import json
import logging
import subprocess
from datetime import datetime
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
CHECKPOINT_TOOL_SCHEMA = {
"name": "checkpoint",
"description": (
"Save or restore mid-task state (checkpoint). "
"Write a checkpoint before risky operations or when you have made significant progress. "
"The system also auto-writes a checkpoint before context compaction. "
"On session start, any existing checkpoint from the parent session is auto-injected "
"so you can resume where you left off. "
"Use 'write' to create/overwrite, 'update' to merge into existing, "
"'read' to check current checkpoint, 'clear' to delete."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["write", "update", "read", "clear"],
"description": (
"write: Create or overwrite checkpoint for current session. "
"update: Merge progress/state/decisions into existing checkpoint. "
"read: Read current checkpoint. "
"clear: Delete current checkpoint (task done or abandoned)."
),
},
"task": {
"type": "string",
"description": "One-line description of what you are working on. Required for 'write'.",
},
"progress": {
"type": "array",
"items": {
"type": "object",
"properties": {
"step": {"type": "string", "description": "What this step does"},
"status": {
"type": "string",
"enum": ["pending", "in_progress", "completed", "cancelled"],
},
"result": {
"type": "string",
"description": "Outcome (optional, for completed steps)",
},
},
"required": ["step", "status"],
},
"description": "Ordered list of task steps with status. Each step: {step, status, result?}.",
},
"state": {
"type": "object",
"properties": {
"active_branch": {"type": "string"},
"files_changed": {"type": "array", "items": {"type": "string"}},
"tests_status": {"type": "string"},
"last_commit": {"type": "string"},
"pushed": {"type": "boolean"},
"working_directory": {"type": "string"},
},
"description": "Machine-readable facts about the current state: branch, files, test results, commits.",
},
"decisions": {
"type": "array",
"items": {"type": "string"},
"description": "Non-obvious choices made during this task (the 'why', not the 'what').",
},
"blocked": {
"type": "array",
"items": {"type": "string"},
"description": "Things blocked on external input. Empty if unblocked.",
},
"unresolved": {
"type": "array",
"items": {"type": "string"},
"description": "Open questions or unknowns. Empty if none.",
},
},
"required": ["action"],
},
}
def _git_state(workdir: str = None) -> Dict[str, Any]:
"""Best-effort capture of git state from the working directory."""
state = {}
if not workdir:
return state
try:
branch = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
capture_output=True, text=True, timeout=5, cwd=workdir,
)
if branch.returncode == 0:
state["active_branch"] = branch.stdout.strip()
commit = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
capture_output=True, text=True, timeout=5, cwd=workdir,
)
if commit.returncode == 0:
state["last_commit"] = commit.stdout.strip()
except (OSError, subprocess.TimeoutExpired):
pass
return state
def checkpoint_tool(
action: str,
task: str = None,
progress: List[Dict] = None,
state: Dict[str, Any] = None,
decisions: List[str] = None,
blocked: List[str] = None,
unresolved: List[str] = None,
store=None,
agent=None,
) -> str:
"""Execute a checkpoint action. Returns JSON result string."""
if store is None:
from agent.checkpoint_store import CheckpointStore
store = CheckpointStore()
session_id = getattr(agent, "session_id", "unknown") if agent else "unknown"
if action == "write":
if not task:
return json.dumps({"success": False, "error": "'task' is required for write action"})
# Auto-populate git state if not provided
effective_state = dict(state or {})
workdir = effective_state.get("working_directory")
if workdir and "active_branch" not in effective_state:
git = _git_state(workdir)
effective_state.update(git)
# Auto-populate from todo store if available
todo_snapshot = None
if agent and hasattr(agent, "_todo_store") and agent._todo_store:
try:
todo_snapshot = agent._todo_store.format_for_injection()
except Exception:
pass
if todo_snapshot:
effective_state["todo_snapshot"] = todo_snapshot
data = {
"session_id": session_id,
"task": task,
"status": "in_progress",
"created": datetime.now().isoformat(),
"updated": datetime.now().isoformat(),
"progress": progress or [],
"state": effective_state,
"decisions": decisions or [],
"blocked": blocked or [],
"unresolved": unresolved or [],
}
store.write(session_id, data)
return json.dumps({
"success": True,
"session_id": session_id,
"message": f"Checkpoint saved for session {session_id}",
})
elif action == "update":
existing = store.read(session_id)
if not existing:
return json.dumps({
"success": False,
"error": f"No checkpoint exists for session {session_id}. Use 'write' first.",
})
# Merge progress (append new steps)
if progress:
existing["progress"] = existing.get("progress", []) + progress
# Merge state (overwrite keys)
if state:
existing["state"] = {**existing.get("state", {}), **state}
# Append decisions
if decisions:
existing["decisions"] = existing.get("decisions", []) + decisions
# Replace blocked/unresolved
if blocked is not None:
existing["blocked"] = blocked
if unresolved is not None:
existing["unresolved"] = unresolved
# Update git state if workdir provided
workdir = (state or {}).get("working_directory") if state else None
if workdir:
git = _git_state(workdir)
existing["state"] = {**existing.get("state", {}), **git}
store.write(session_id, existing)
return json.dumps({
"success": True,
"session_id": session_id,
"message": f"Checkpoint updated for session {session_id}",
})
elif action == "read":
data = store.read(session_id)
if not data:
return json.dumps({
"success": False,
"error": f"No checkpoint exists for session {session_id}",
})
return json.dumps({
"success": True,
"session_id": session_id,
"checkpoint": data,
})
elif action == "clear":
store.delete(session_id)
return json.dumps({
"success": True,
"session_id": session_id,
"message": f"Checkpoint cleared for session {session_id}",
})
else:
return json.dumps({"success": False, "error": f"Unknown action: {action}"})
# --- Registry ---
from tools.registry import registry
registry.register(
name="checkpoint",
toolset="todo",
schema=CHECKPOINT_TOOL_SCHEMA,
handler=lambda args, **kw: checkpoint_tool(
action=args.get("action"),
task=args.get("task"),
progress=args.get("progress"),
state=args.get("state"),
decisions=args.get("decisions"),
blocked=args.get("blocked"),
unresolved=args.get("unresolved"),
store=kw.get("store"),
agent=kw.get("agent"),
),
check_fn=lambda: True, # always available
emoji="🔖",
)

View file

@ -159,7 +159,7 @@ TOOLSETS = {
"todo": {
"description": "Task planning and tracking for multi-step work",
"tools": ["todo"],
"tools": ["todo", "checkpoint"],
"includes": []
},