mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge 748f1a0417 into 4fade39c90
This commit is contained in:
commit
9a56339ddc
14 changed files with 1218 additions and 2 deletions
86
agent/checkpoint_injection.py
Normal file
86
agent/checkpoint_injection.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""Checkpoint injection -- load checkpoint from parent session into system prompt.
|
||||
|
||||
When a new session starts (or a compression creates a continuation session),
|
||||
this module finds the most recent checkpoint from the session's lineage
|
||||
and injects it into the system prompt.
|
||||
|
||||
This is Layer 4 in the memory architecture. Only in_progress checkpoints
|
||||
are injected -- completed ones are skipped.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)

# Heavy box-drawing rule used to visually fence the injected block
# (same width as the vault injection separator).
SEPARATOR = "\u2550" * 46
# Hard cap on injected checkpoint text so an oversized checkpoint
# cannot bloat the system prompt.
CHAR_LIMIT = 4000


def build_checkpoint_system_prompt(store, parent_session_id: Optional[str] = None) -> str:
    """Build the checkpoint injection block for the system prompt.

    Args:
        store: Checkpoint store exposing ``read(session_id) -> dict | None``.
        parent_session_id: Session whose checkpoint should be injected.
            The caller is responsible for walking the parent lineage chain
            to find the nearest ancestor session that has a checkpoint file.

    Returns:
        The formatted injection block, or an empty string when there is no
        store/parent, the checkpoint is missing or not a mapping, or its
        status is "completed" (finished tasks are never re-injected).
    """
    if not store or not parent_session_id:
        return ""

    data = store.read(parent_session_id)
    if not data or not isinstance(data, dict):
        return ""

    # Completed checkpoints are skipped; any other status (in_progress,
    # or unknown) is treated as resumable.
    if data.get("status") == "completed":
        return ""

    lines = [
        f"Task: {data.get('task', 'Unknown')}",
        f"Status: {data.get('status', 'in_progress')}",
        "",
    ]

    progress = data.get("progress", [])
    if progress:
        lines.append("Progress:")
        for step in progress:
            # Map step status to a todo-style icon; unknown statuses render
            # as pending ("[ ]").
            status_icon = {
                "completed": "[x]", "in_progress": "[~]",
                "pending": "[ ]", "cancelled": "[-]",
            }.get(step.get("status", "pending"), "[ ]")
            line = f" {status_icon} {step.get('step', '')}"
            if step.get("result"):
                line += f" -- {step['result']}"
            lines.append(line)
        lines.append("")

    state = data.get("state", {})
    if state:
        lines.append("State:")
        for key, value in state.items():
            # Drop empty values (None, "", "[]") so the block stays dense.
            if value and str(value) not in ("[]", ""):
                lines.append(f" {key}: {value}")
        lines.append("")

    decisions = data.get("decisions", [])
    if decisions:
        lines.append("Decisions:")
        for d in decisions:
            lines.append(f" - {d}")
        lines.append("")

    blocked = data.get("blocked", [])
    if blocked:
        lines.append(f"Blocked: {', '.join(blocked)}")

    unresolved = data.get("unresolved", [])
    if unresolved:
        lines.append(f"Unresolved: {', '.join(unresolved)}")

    content = "\n".join(lines)

    # Truncate rather than refuse: a partial checkpoint is still useful.
    if len(content) > CHAR_LIMIT:
        content = content[:CHAR_LIMIT] + "\n[... truncated at char limit ...]"

    header = "CHECKPOINT: RESUME FROM HERE (saved before compaction, injected on session start)"
    return f"{SEPARATOR}\n{header}\n{SEPARATOR}\n{content}"
|
||||
117
agent/checkpoint_store.py
Normal file
117
agent/checkpoint_store.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
"""Checkpoint store -- read/write YAML checkpoint files for session resumption.
|
||||
|
||||
Each checkpoint captures enough state to resume a task from where the agent
|
||||
left off after context compaction or session restart. Files live under
|
||||
~/.hermes/checkpoints/<session_id>.yaml.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)

# Session IDs must be safe filesystem names -- no path separators or traversal.
# Anything outside [A-Za-z0-9_-]+ is rejected by CheckpointStore._path_for.
_SAFE_SESSION_ID = re.compile(r"^[A-Za-z0-9_-]+$")

# Default on-disk location: one YAML file per session under
# ~/.hermes/checkpoints/<session_id>.yaml.
DEFAULT_CHECKPOINTS_DIR = Path.home() / ".hermes" / "checkpoints"
# Non-completed checkpoints older than this are removed by garbage_collect().
DEFAULT_GC_MAX_AGE_DAYS = 7
|
||||
|
||||
|
||||
class CheckpointStore:
    """Manages checkpoint YAML files on disk.

    One YAML document per session, stored as ``<session_id>.yaml`` inside
    ``self._dir``. Every public method validates the session ID through
    :meth:`_path_for`, which is the single choke point against path traversal.
    """

    def __init__(self, checkpoints_dir: Optional[Path] = None):
        # Fall back to ~/.hermes/checkpoints when no directory is supplied;
        # create the directory eagerly so later writes cannot fail on it.
        self._dir = Path(checkpoints_dir) if checkpoints_dir else DEFAULT_CHECKPOINTS_DIR
        self._dir.mkdir(parents=True, exist_ok=True)

    def _path_for(self, session_id: str) -> Path:
        # Reject anything that is not a plain [A-Za-z0-9_-]+ name; this blocks
        # "../" traversal, slashes, backslashes, and other unsafe characters.
        if not _SAFE_SESSION_ID.match(session_id):
            raise ValueError(f"Invalid session_id: {session_id!r} (must match {_SAFE_SESSION_ID.pattern})")
        return self._dir / f"{session_id}.yaml"

    def write(self, session_id: str, data: Dict[str, Any]) -> None:
        """Write a checkpoint for the given session. Overwrites if exists.

        Note: mutates the caller's ``data`` dict in place by stamping
        ``data["updated"]`` with the current (naive, local-time) ISO
        timestamp. Write failures are logged as warnings, never raised.

        Raises:
            ValueError: if ``session_id`` is not a safe filesystem name.
        """
        path = self._path_for(session_id)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Ensure the 'updated' timestamp is current
        data["updated"] = datetime.now().isoformat()
        try:
            path.write_text(
                yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False),
                encoding="utf-8",
            )
            # st_size is bytes on disk -- close enough to "chars" for a debug log.
            logger.debug("Checkpoint written: %s (%d chars)", session_id, path.stat().st_size)
        except (OSError, yaml.YAMLError) as e:
            logger.warning("Failed to write checkpoint %s: %s", session_id, e)

    def read(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Read a checkpoint. Returns None if missing or corrupt."""
        path = self._path_for(session_id)
        if not path.exists():
            return None
        try:
            content = path.read_text(encoding="utf-8")
            data = yaml.safe_load(content)
            # A checkpoint must be a mapping; scalars or lists count as corrupt.
            if not isinstance(data, dict):
                return None
            return data
        except (OSError, yaml.YAMLError) as e:
            logger.debug("Failed to read checkpoint %s: %s", session_id, e)
            return None

    def delete(self, session_id: str) -> None:
        """Delete a checkpoint. No-op if missing."""
        path = self._path_for(session_id)
        try:
            path.unlink(missing_ok=True)
        except OSError:
            # Best-effort deletion: ignore filesystem errors (e.g. permissions).
            pass

    def list_sessions(self) -> List[str]:
        """Return all session IDs that have checkpoint files."""
        if not self._dir.exists():
            return []
        return sorted(p.stem for p in self._dir.glob("*.yaml"))

    def garbage_collect(self, max_age_days: int = DEFAULT_GC_MAX_AGE_DAYS) -> List[str]:
        """Remove stale checkpoints. Returns list of removed session IDs.

        A checkpoint is stale if:
        - Its 'updated' timestamp is older than max_age_days AND not completed
        - OR it's 'completed' and older than 1 day (task done, no need to keep)

        Unreadable or non-mapping files are deleted outright. A missing or
        unparseable 'updated' timestamp is treated as exactly ``cutoff``; with
        the strict ``<`` comparisons below, such files are kept unless their
        status is "completed".
        """
        removed = []
        cutoff = datetime.now() - timedelta(days=max_age_days)
        completed_cutoff = datetime.now() - timedelta(days=1)

        for path in self._dir.glob("*.yaml"):
            try:
                data = yaml.safe_load(path.read_text(encoding="utf-8"))
                if not isinstance(data, dict):
                    # Corrupt shape: remove rather than let it linger forever.
                    path.unlink(missing_ok=True)
                    removed.append(path.stem)
                    continue
                updated_str = data.get("updated", "")
                status = data.get("status", "in_progress")
                try:
                    updated = datetime.fromisoformat(updated_str) if updated_str else cutoff
                except ValueError:
                    updated = cutoff

                if status == "completed" and updated < completed_cutoff:
                    path.unlink(missing_ok=True)
                    removed.append(path.stem)
                elif updated < cutoff:
                    path.unlink(missing_ok=True)
                    removed.append(path.stem)
            except (OSError, yaml.YAMLError):
                # Unreadable/unparseable file: delete it and report it as removed.
                path.unlink(missing_ok=True)
                removed.append(path.stem)

        if removed:
            logger.info("Checkpoint GC: removed %d stale file(s): %s", len(removed), removed)
        return removed
|
||||
|
|
@ -176,6 +176,20 @@ SKILLS_GUIDANCE = (
|
|||
"Skills that aren't maintained become liabilities."
|
||||
)
|
||||
|
||||
# System-prompt guidance appended when the 'checkpoint' tool is available
# (run_agent adds it to tool_guidance). Runtime string -- injected verbatim
# into the model's prompt, so keep wording stable.
CHECKPOINT_GUIDANCE = (
    "When a CHECKPOINT block appears at session start, read it and resume from "
    "where the previous session left off. Do not start over — pick up at the "
    "first in_progress step.\n"
    "Write a checkpoint (action=\"write\") at these moments:\n"
    "- Before a risky command that could break state\n"
    "- After completing a major step in a multi-step task (then action=\"update\" "
    "for subsequent steps)\n"
    "- When you have important \"why\" decisions that won't be obvious later\n"
    "Clear the checkpoint (action=\"clear\") when the task is fully done.\n"
    "Pre-compression checkpoints are auto-written — you do not need to worry "
    "about losing progress when the context window fills."
)
|
||||
|
||||
TOOL_USE_ENFORCEMENT_GUIDANCE = (
|
||||
"# Tool-use enforcement\n"
|
||||
"You MUST use your tools to take action — do not describe what you would do "
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ def get_tool_definitions(
|
|||
# because they need agent-level state (TodoStore, MemoryStore, etc.).
|
||||
# The registry still holds their schemas; dispatch just returns a stub error
|
||||
# so if something slips through, the LLM sees a sensible message.
|
||||
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
||||
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task", "checkpoint"}
|
||||
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
|
||||
|
||||
|
||||
|
|
|
|||
134
run_agent.py
134
run_agent.py
|
|
@ -85,6 +85,7 @@ from agent.error_classifier import classify_api_error, FailoverReason
|
|||
from agent.prompt_builder import (
|
||||
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
|
||||
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
||||
CHECKPOINT_GUIDANCE,
|
||||
build_nous_subscription_prompt,
|
||||
)
|
||||
from agent.model_metadata import (
|
||||
|
|
@ -1561,6 +1562,11 @@ class AIAgent:
|
|||
# In-memory todo list for task planning (one per agent/session)
|
||||
from tools.todo_tool import TodoStore
|
||||
self._todo_store = TodoStore()
|
||||
|
||||
# Layer 4: Checkpoint store (session resumption across compression)
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
self._checkpoint_store = CheckpointStore()
|
||||
self._checkpoint_store.garbage_collect() # prune stale checkpoints on startup
|
||||
|
||||
# Load config once for memory, skills, and compression sections
|
||||
try:
|
||||
|
|
@ -4391,6 +4397,8 @@ class AIAgent:
|
|||
tool_guidance.append(SESSION_SEARCH_GUIDANCE)
|
||||
if "skill_manage" in self.valid_tool_names:
|
||||
tool_guidance.append(SKILLS_GUIDANCE)
|
||||
if "checkpoint" in self.valid_tool_names:
|
||||
tool_guidance.append(CHECKPOINT_GUIDANCE)
|
||||
if tool_guidance:
|
||||
prompt_parts.append(" ".join(tool_guidance))
|
||||
|
||||
|
|
@ -4448,6 +4456,54 @@ class AIAgent:
|
|||
if user_block:
|
||||
prompt_parts.append(user_block)
|
||||
|
||||
# Vault auto-injection (Layer 3) — reads working-context.md and
|
||||
# user-profile.md from the Obsidian vault and injects them into
|
||||
# the system prompt. Structural fix for vault neglect.
|
||||
_vault_log = logging.getLogger("agent.vault")
|
||||
if self._vault_enabled and self._vault_path:
|
||||
try:
|
||||
from agent.vault_injection import build_vault_system_prompt
|
||||
_vault_block = build_vault_system_prompt(self._vault_path)
|
||||
if _vault_block:
|
||||
prompt_parts.append(_vault_block)
|
||||
_vault_log.info("Injection succeeded: %d chars from %s", len(_vault_block), self._vault_path)
|
||||
else:
|
||||
_vault_log.warning("Injection returned empty for path %s", self._vault_path)
|
||||
except Exception as e:
|
||||
_vault_log.warning("Injection failed: %s: %s", type(e).__name__, e)
|
||||
else:
|
||||
_vault_log.info("Injection skipped: enabled=%s path=%s", self._vault_enabled, self._vault_path)
|
||||
|
||||
# Checkpoint injection (Layer 4) -- resume from where the previous session left off
|
||||
if self._checkpoint_store:
|
||||
try:
|
||||
_checkpoint_session_id = None
|
||||
if self._session_db:
|
||||
_sid = self.session_id
|
||||
# Walk up the parent chain to find the nearest ancestor with a checkpoint
|
||||
for _ in range(5): # max 5 hops up lineage
|
||||
try:
|
||||
_sess = self._session_db.get_session(_sid)
|
||||
except Exception:
|
||||
break
|
||||
if not _sess:
|
||||
break
|
||||
_parent = _sess.get("parent_session_id")
|
||||
if not _parent:
|
||||
break
|
||||
if self._checkpoint_store.read(_parent) is not None:
|
||||
_checkpoint_session_id = _parent
|
||||
break
|
||||
_sid = _parent # keep walking up
|
||||
from agent.checkpoint_injection import build_checkpoint_system_prompt
|
||||
_cp_block = build_checkpoint_system_prompt(
|
||||
store=self._checkpoint_store,
|
||||
parent_session_id=_checkpoint_session_id,
|
||||
)
|
||||
if _cp_block:
|
||||
prompt_parts.append(_cp_block)
|
||||
except Exception as e:
|
||||
logger.debug("Checkpoint injection skipped: %s", e)
|
||||
# External memory provider system prompt block (additive to built-in)
|
||||
if self._memory_manager:
|
||||
try:
|
||||
|
|
@ -8119,6 +8175,52 @@ class AIAgent:
|
|||
if messages and messages[-1].get("_flush_sentinel") == _sentinel:
|
||||
messages.pop()
|
||||
|
||||
    def flush_checkpoint(self, messages: list = None) -> None:
        """Write a checkpoint before context compression destroys progress.

        Called by _compress_context before the compression happens.
        Captures what can be extracted from agent-level state -- the
        session title (as the task description) and the todo list (as
        progress steps) -- and writes a checkpoint via checkpoint_tool.
        If a checkpoint already exists for this session, it gets
        overwritten -- the pre-compression checkpoint is always the
        most accurate.

        Note: ``messages`` is currently unused by this implementation;
        state and decisions are written empty.
        """
        if not self._checkpoint_store:
            return

        from tools.checkpoint_tool import checkpoint_tool
        # Extract task description from session title or fallback
        task_desc = "Session checkpoint (auto-saved before compression)"
        try:
            if self._session_db:
                title = self._session_db.get_session_title(self.session_id)
                if title:
                    task_desc = title
        except Exception:
            # Best-effort: keep the generic description on any DB error.
            pass

        # Extract progress from todo store if available
        progress = []
        if self._todo_store:
            try:
                # Reaches into TodoStore's private _items (guarded by hasattr).
                # NOTE(review): a public accessor on TodoStore would be cleaner.
                items = self._todo_store._items if hasattr(self._todo_store, '_items') else []
                for item in items:
                    step = {"step": item.get("content", ""), "status": item.get("status", "pending")}
                    progress.append(step)
            except Exception:
                pass

        checkpoint_tool(
            action="write",
            task=task_desc,
            progress=progress,
            state={},
            decisions=[],
            store=self._checkpoint_store,
            agent=self,
        )
        logger.info("Pre-compression checkpoint saved for session %s", self.session_id)
|
||||
|
||||
def _compress_context(self, messages: list, system_message: str, *, approx_tokens: int = None, task_id: str = "default", focus_topic: str = None) -> tuple:
|
||||
"""Compress conversation context and split the session in SQLite.
|
||||
|
||||
|
|
@ -8140,6 +8242,9 @@ class AIAgent:
|
|||
# Pre-compression memory flush: let the model save memories before they're lost
|
||||
self.flush_memories(messages, min_turns=0)
|
||||
|
||||
# Pre-compression checkpoint: save task state before context is lost
|
||||
self.flush_checkpoint(messages)
|
||||
|
||||
# Notify external memory provider before compression discards context
|
||||
if self._memory_manager:
|
||||
try:
|
||||
|
|
@ -8351,6 +8456,19 @@ class AIAgent:
|
|||
)
|
||||
elif function_name == "delegate_task":
|
||||
return self._dispatch_delegate_task(function_args)
|
||||
elif function_name == "checkpoint":
|
||||
from tools.checkpoint_tool import checkpoint_tool as _checkpoint_tool
|
||||
return _checkpoint_tool(
|
||||
action=function_args.get("action"),
|
||||
task=function_args.get("task"),
|
||||
progress=function_args.get("progress"),
|
||||
state=function_args.get("state"),
|
||||
decisions=function_args.get("decisions"),
|
||||
blocked=function_args.get("blocked"),
|
||||
unresolved=function_args.get("unresolved"),
|
||||
store=self._checkpoint_store,
|
||||
agent=self,
|
||||
)
|
||||
else:
|
||||
return handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
|
|
@ -8892,6 +9010,22 @@ class AIAgent:
|
|||
spinner.stop(cute_msg)
|
||||
elif self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {cute_msg}")
|
||||
elif function_name == "checkpoint":
|
||||
from tools.checkpoint_tool import checkpoint_tool as _checkpoint_tool
|
||||
function_result = _checkpoint_tool(
|
||||
action=function_args.get("action"),
|
||||
task=function_args.get("task"),
|
||||
progress=function_args.get("progress"),
|
||||
state=function_args.get("state"),
|
||||
decisions=function_args.get("decisions"),
|
||||
blocked=function_args.get("blocked"),
|
||||
unresolved=function_args.get("unresolved"),
|
||||
store=self._checkpoint_store,
|
||||
agent=self,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self._should_emit_quiet_tool_messages():
|
||||
self._vprint(f" {_get_cute_tool_message_impl('checkpoint', function_args, tool_duration, result=function_result)}")
|
||||
elif self._context_engine_tool_names and function_name in self._context_engine_tool_names:
|
||||
# Context engine tools (lcm_grep, lcm_describe, lcm_expand, etc.)
|
||||
spinner = None
|
||||
|
|
|
|||
16
tests/test_checkpoint_agent_dispatch.py
Normal file
16
tests/test_checkpoint_agent_dispatch.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""Verify that checkpoint is routed through the agent loop, not the registry."""
|
||||
import json
|
||||
import pytest
|
||||
from model_tools import handle_function_call, _AGENT_LOOP_TOOLS
|
||||
|
||||
|
||||
def test_checkpoint_is_agent_loop_tool():
    """checkpoint must be in _AGENT_LOOP_TOOLS so it gets agent-level state."""
    assert "checkpoint" in _AGENT_LOOP_TOOLS


def test_checkpoint_registry_dispatch_returns_error():
    """Calling handle_function_call for checkpoint should return agent-loop error."""
    # The registry only holds the schema; dispatching through it must return a
    # JSON stub error pointing callers at the agent loop.
    result = handle_function_call("checkpoint", {"action": "read"})
    data = json.loads(result)
    assert "must be handled" in data["error"].lower()
|
||||
53
tests/test_checkpoint_flush.py
Normal file
53
tests/test_checkpoint_flush.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Test that flush_checkpoint writes a checkpoint via checkpoint_tool."""
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def test_flush_checkpoint_method_exists():
    """AIAgent must have a flush_checkpoint method."""
    from run_agent import AIAgent
    assert hasattr(AIAgent, "flush_checkpoint")


def test_flush_checkpoint_writes_to_store(tmp_path):
    """flush_checkpoint should write a checkpoint with current session state."""
    from agent.checkpoint_store import CheckpointStore

    store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")

    # Fake todo store with the private _items list flush_checkpoint reads.
    mock_todo = MagicMock()
    mock_todo._items = [
        {"content": "Step 1", "status": "completed"},
        {"content": "Step 2", "status": "in_progress"},
    ]
    mock_todo.format_for_injection.return_value = "- [x] Step 1\n- [~] Step 2"

    mock_session_db = MagicMock()
    mock_session_db.get_session_title.return_value = "Test session title"

    mock_agent = MagicMock()
    mock_agent.session_id = "flush_test_session"
    mock_agent._checkpoint_store = store
    mock_agent._todo_store = mock_todo
    mock_agent._session_db = mock_session_db

    # Call the real flush_checkpoint method (unbound, mock standing in as self)
    from run_agent import AIAgent
    AIAgent.flush_checkpoint(mock_agent)

    # Verify a checkpoint was written for this session
    saved = store.read("flush_test_session")
    assert saved is not None
    assert saved["task"] == "Test session title"
    assert saved["status"] == "in_progress"


def test_flush_checkpoint_noops_without_store(tmp_path):
    """flush_checkpoint should silently return if _checkpoint_store is None."""
    mock_agent = MagicMock()
    mock_agent._checkpoint_store = None

    from run_agent import AIAgent
    # Should not raise
    AIAgent.flush_checkpoint(mock_agent)
|
||||
58
tests/test_checkpoint_injection.py
Normal file
58
tests/test_checkpoint_injection.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
import pytest
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
from agent.checkpoint_injection import build_checkpoint_system_prompt
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path):
    # Fresh on-disk store per test; tmp_path keeps it isolated from ~/.hermes.
    return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")


@pytest.fixture
def parent_checkpoint():
    # Representative in_progress checkpoint exercising every supported field.
    return {
        "session_id": "old_session_001",
        "task": "Red team QA on PR #14303",
        "status": "in_progress",
        "created": "2026-04-22T22:58:00",
        "updated": "2026-04-22T23:15:00",
        "progress": [
            {"step": "Read code paths", "status": "completed"},
            {"step": "Add edge-case tests", "status": "in_progress"},
        ],
        "state": {"active_branch": "fix/delegation", "tests_status": "22/22 passing"},
        "decisions": ["SSRF is low-severity"],
        "blocked": [],
        "unresolved": [],
    }


class TestBuildCheckpointPrompt:
    def test_injects_in_progress_checkpoint(self, store, parent_checkpoint):
        """An in_progress checkpoint for the parent session is injected."""
        store.write("old_session_001", parent_checkpoint)
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
        assert "Red team QA on PR #14303" in prompt
        assert "CHECKPOINT" in prompt

    def test_skips_completed_checkpoints(self, store, parent_checkpoint):
        """Completed checkpoints must never be re-injected."""
        parent_checkpoint["status"] = "completed"
        store.write("old_session_001", parent_checkpoint)
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
        assert prompt == ""

    def test_returns_empty_when_no_parent(self, store):
        """A parent session with no checkpoint file yields an empty block."""
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="nonexistent")
        assert prompt == ""

    def test_includes_progress_and_decisions(self, store, parent_checkpoint):
        """Progress steps and decisions appear in the injected text."""
        store.write("old_session_001", parent_checkpoint)
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
        assert "Read code paths" in prompt
        assert "SSRF is low-severity" in prompt

    def test_truncates_oversized_checkpoints(self, store, parent_checkpoint):
        """Content beyond CHAR_LIMIT (4000) is cut, keeping the prompt bounded."""
        parent_checkpoint["task"] = "x" * 5000
        parent_checkpoint["decisions"] = ["d" * 2000] * 10
        store.write("old_session_001", parent_checkpoint)
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="old_session_001")
        assert len(prompt) < 6000
|
||||
82
tests/test_checkpoint_integration.py
Normal file
82
tests/test_checkpoint_integration.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""End-to-end test: write -> compress -> resume cycle."""
|
||||
import json
|
||||
import pytest
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
from agent.checkpoint_injection import build_checkpoint_system_prompt
|
||||
from tools.checkpoint_tool import checkpoint_tool
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path):
    # Isolated on-disk store for each test.
    return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")


def test_full_cycle_write_compress_resume(store):
    """Simulate: agent writes checkpoint -> new session gets injection."""
    old_agent = MagicMock()
    old_agent.session_id = "20260422_225800_abc123"
    old_agent._todo_store = MagicMock()
    old_agent._todo_store.format_for_injection.return_value = ""

    checkpoint_tool(
        action="write",
        task="Red team QA on PR #14303",
        progress=[
            {"step": "Read code", "status": "completed"},
            {"step": "Write tests", "status": "in_progress"},
        ],
        state={
            "active_branch": "fix/delegation",
            "tests_status": "19/19 passing",
            "last_commit": "a90557d",
            "pushed": True,
            "working_directory": "/tmp/repo",
        },
        decisions=["SSRF is low-severity, doc note only"],
        blocked=[],
        unresolved=[],
        store=store,
        agent=old_agent,
    )

    # New session starts, checkpoint gets injected
    prompt = build_checkpoint_system_prompt(
        store=store,
        parent_session_id="20260422_225800_abc123",
    )

    assert "Red team QA on PR #14303" in prompt
    assert "Write tests" in prompt
    assert "fix/delegation" in prompt
    assert "SSRF is low-severity" in prompt


def test_gc_removes_stale_checkpoints(store):
    """GC removes checkpoints older than 7 days."""
    mock_agent = MagicMock()
    mock_agent.session_id = "old_sess"
    mock_agent._todo_store = MagicMock()
    mock_agent._todo_store.format_for_injection.return_value = ""

    checkpoint_tool(
        action="write",
        task="Old task",
        progress=[],
        state={},
        decisions=[],
        store=store,
        agent=mock_agent,
    )
    # Age the checkpoint by writing directly with an old timestamp
    # (store.write would re-stamp 'updated', so bypass it on purpose).
    data = store.read("old_sess")
    data["updated"] = (datetime.now() - timedelta(days=10)).isoformat()
    store._path_for("old_sess").write_text(
        yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )

    removed = store.garbage_collect(max_age_days=7)
    assert "old_sess" in removed
|
||||
144
tests/test_checkpoint_regression.py
Normal file
144
tests/test_checkpoint_regression.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""Regression tests for bugs found during red-team QA review.
|
||||
|
||||
Bug #1: get_parent_session_id() didn't exist on SessionDB
|
||||
Bug #2: Path traversal via session_id
|
||||
Bug #3: "checkpoint" not in any toolset definition
|
||||
Bug #4: flush_checkpoint test didn't actually call flush_checkpoint
|
||||
Bug #5: Injection fallback grabbed ANY in_progress checkpoint from unrelated sessions
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestBug1SessionDBLineageWalk:
    """Bug #1: _build_system_prompt must use get_session()["parent_session_id"]
    instead of the nonexistent get_parent_session_id() method.
    """

    def test_session_db_has_no_get_parent_session_id_method(self):
        """SessionDB must NOT have a get_parent_session_id method (it never did)."""
        from hermes_state import SessionDB
        assert not hasattr(SessionDB, "get_parent_session_id"), (
            "SessionDB should not have get_parent_session_id -- "
            "use get_session(sid)['parent_session_id'] instead"
        )

    def test_get_session_returns_parent_session_id_key(self, tmp_path):
        """get_session() must return a dict with 'parent_session_id' key."""
        from hermes_state import SessionDB
        from pathlib import Path
        db = SessionDB(db_path=Path(tmp_path) / "test.db")
        # Create a parent session
        parent_id = db.create_session("parent_session", source="test", model="test")
        # Create a child session
        child_id = db.create_session("child_session", source="test", model="test", parent_session_id=parent_id)
        sess = db.get_session(child_id)
        assert sess is not None
        assert "parent_session_id" in sess
        assert sess["parent_session_id"] == parent_id


class TestBug2PathTraversal:
    """Bug #2: session_id must be validated to prevent path traversal.

    CheckpointStore._path_for is the single validation choke point; these
    tests hit it through every public method.
    """

    def test_reject_dotdot_in_session_id(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        with pytest.raises(ValueError, match="Invalid session_id"):
            store.write("../../../etc/passwd", {"task": "evil"})

    def test_reject_slash_in_session_id(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        with pytest.raises(ValueError, match="Invalid session_id"):
            store.write("sub/dir", {"task": "evil"})

    def test_reject_backslash_in_session_id(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        with pytest.raises(ValueError, match="Invalid session_id"):
            store.write("sub\\dir", {"task": "evil"})

    def test_accept_normal_session_id(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        # Should not raise
        store.write("20260422_225800_abc123", {"task": "normal"})

    def test_read_rejects_path_traversal(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        with pytest.raises(ValueError, match="Invalid session_id"):
            store.read("../../../etc/passwd")

    def test_delete_rejects_path_traversal(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        with pytest.raises(ValueError, match="Invalid session_id"):
            store.delete("../../../etc/passwd")


class TestBug3CheckpointInToolset:
    """Bug #3: 'checkpoint' must be in a toolset definition so the model gets the schema."""

    def test_checkpoint_in_todo_toolset(self):
        from toolsets import TOOLSETS
        todo_tools = TOOLSETS.get("todo", {}).get("tools", [])
        assert "checkpoint" in todo_tools, (
            f"'checkpoint' must be in the 'todo' toolset tools list. "
            f"Current todo tools: {todo_tools}"
        )

    def test_resolve_toolset_todo_includes_checkpoint(self):
        from toolsets import resolve_toolset
        resolved = resolve_toolset("todo")
        assert "checkpoint" in resolved, (
            f"resolve_toolset('todo') must return ['todo', 'checkpoint']. Got: {resolved}"
        )


class TestBug5InjectionNoCrossSession:
    """Bug #5: build_checkpoint_system_prompt must NOT fall through to grabbing
    any random in_progress checkpoint from an unrelated session.
    """

    def test_no_injection_when_parent_has_no_checkpoint(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        from agent.checkpoint_injection import build_checkpoint_system_prompt

        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        # Write a checkpoint for an unrelated session
        store.write("unrelated_session", {
            "task": "Unrelated task",
            "status": "in_progress",
            "progress": [],
        })

        # Ask for a parent_session_id that has no checkpoint
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="missing_session")
        assert prompt == "", (
            "Must not inject a checkpoint from an unrelated session when "
            "the requested parent_session_id has no checkpoint"
        )

    def test_only_injects_specific_parent(self, tmp_path):
        from agent.checkpoint_store import CheckpointStore
        from agent.checkpoint_injection import build_checkpoint_system_prompt

        store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
        # Write checkpoints for two sessions
        store.write("parent_A", {
            "task": "Parent A task",
            "status": "in_progress",
            "progress": [],
        })
        store.write("parent_B", {
            "task": "Parent B task",
            "status": "in_progress",
            "progress": [],
        })

        # Request parent_A specifically
        prompt = build_checkpoint_system_prompt(store=store, parent_session_id="parent_A")
        assert "Parent A task" in prompt
        assert "Parent B task" not in prompt
|
||||
118
tests/test_checkpoint_store.py
Normal file
118
tests/test_checkpoint_store.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
import yaml
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path):
    """Provide a CheckpointStore rooted in a per-test temporary directory."""
    return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
|
||||
|
||||
|
||||
@pytest.fixture
def sample_checkpoint():
    """Return a representative in-progress checkpoint payload.

    Exercises every top-level key the store is expected to round-trip:
    task/status, ISO timestamps, ordered progress steps, machine-readable
    state, and the decisions/blocked/unresolved lists.
    """
    return {
        "task": "Red team QA on PR #14303",
        "status": "in_progress",
        # Fixed ISO-8601 timestamps; GC tests overwrite 'updated' to age the file.
        "created": "2026-04-22T22:58:00",
        "updated": "2026-04-22T23:15:00",
        "progress": [
            {"step": "Read code", "status": "completed"},
            {"step": "Write tests", "status": "in_progress"},
        ],
        "state": {
            "active_branch": "fix/delegation",
            "files_changed": ["tools/delegate_tool.py"],
            "tests_status": "22/22 passing",
            "last_commit": "abc123",
            "pushed": True,
            "working_directory": "/tmp/repo",
        },
        "decisions": ["No SSRF fix needed"],
        "blocked": [],
        "unresolved": [],
    }
|
||||
|
||||
|
||||
class TestCheckpointStoreWrite:
    """Persisting checkpoints to disk via CheckpointStore.write()."""

    def test_write_creates_yaml_file(self, store, sample_checkpoint):
        store.write("sess_001", sample_checkpoint)
        assert store._path_for("sess_001").exists()

    def test_write_content_is_valid_yaml(self, store, sample_checkpoint):
        store.write("sess_001", sample_checkpoint)
        on_disk = yaml.safe_load(store._path_for("sess_001").read_text())
        assert on_disk["task"] == "Red team QA on PR #14303"

    def test_write_overwrites_existing(self, store, sample_checkpoint):
        store.write("sess_001", sample_checkpoint)
        # A second write for the same session must replace the first.
        sample_checkpoint["status"] = "completed"
        store.write("sess_001", sample_checkpoint)
        assert store.read("sess_001")["status"] == "completed"
|
||||
|
||||
|
||||
class TestCheckpointStoreRead:
    """Loading checkpoints back via CheckpointStore.read()."""

    def test_read_existing(self, store, sample_checkpoint):
        store.write("sess_001", sample_checkpoint)
        assert store.read("sess_001")["task"] == "Red team QA on PR #14303"

    def test_read_nonexistent_returns_none(self, store):
        assert store.read("no_such_session") is None

    def test_read_corrupt_yaml_returns_none(self, store):
        # Hand-write a file that cannot be parsed as YAML.
        bad_path = store._path_for("sess_bad")
        bad_path.parent.mkdir(parents=True, exist_ok=True)
        bad_path.write_text("{{invalid yaml: [")
        assert store.read("sess_bad") is None
|
||||
|
||||
|
||||
class TestCheckpointStoreDelete:
    """Removing checkpoints via CheckpointStore.delete()."""

    def test_delete_removes_file(self, store, sample_checkpoint):
        store.write("sess_001", sample_checkpoint)
        store.delete("sess_001")
        assert store._path_for("sess_001").exists() is False

    def test_delete_nonexistent_is_noop(self, store):
        # Deleting a missing session must not raise.
        store.delete("no_such_session")
|
||||
|
||||
|
||||
class TestCheckpointStoreList:
    """Enumerating stored sessions via CheckpointStore.list_sessions()."""

    def test_list_returns_all_session_ids(self, store, sample_checkpoint):
        for session_id in ("sess_001", "sess_002"):
            store.write(session_id, sample_checkpoint)
        assert sorted(store.list_sessions()) == ["sess_001", "sess_002"]

    def test_list_empty_dir(self, store):
        assert store.list_sessions() == []
|
||||
|
||||
|
||||
class TestCheckpointStoreGC:
    """Age-based cleanup via CheckpointStore.garbage_collect()."""

    def test_gc_removes_old_checkpoints(self, store, sample_checkpoint):
        # Plant a stale checkpoint straight on disk, bypassing write()
        # (which would stamp 'updated' with the current time).
        sample_checkpoint["updated"] = (datetime.now() - timedelta(days=10)).isoformat()
        stale_path = store._path_for("sess_old")
        stale_path.parent.mkdir(parents=True, exist_ok=True)
        stale_yaml = yaml.dump(
            sample_checkpoint, default_flow_style=False, allow_unicode=True, sort_keys=False
        )
        stale_path.write_text(stale_yaml, encoding="utf-8")

        # A fresh checkpoint goes through the normal write path.
        store.write("sess_new", sample_checkpoint)

        removed = store.garbage_collect(max_age_days=7)
        assert "sess_old" in removed
        assert "sess_new" not in removed
        assert store.read("sess_old") is None
        assert store.read("sess_new") is not None

    def test_gc_keeps_all_when_fresh(self, store, sample_checkpoint):
        store.write("sess_001", sample_checkpoint)
        assert store.garbage_collect(max_age_days=7) == []
        assert store.read("sess_001") is not None
|
||||
130
tests/test_checkpoint_tool.py
Normal file
130
tests/test_checkpoint_tool.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
from tools.checkpoint_tool import checkpoint_tool
|
||||
|
||||
|
||||
@pytest.fixture
def store(tmp_path):
    """Provide a CheckpointStore rooted in a per-test temporary directory."""
    return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
|
||||
|
||||
|
||||
@pytest.fixture
def agent_context(store):
    """Simulate the fields the tool handler pulls from the AIAgent."""
    ctx = MagicMock()
    # checkpoint_tool reads agent.session_id to key the checkpoint file.
    ctx.session_id = "test_session_001"
    # The write path snapshots the todo list via _todo_store.format_for_injection().
    ctx._todo_store = MagicMock()
    ctx._todo_store.format_for_injection.return_value = "- [x] Step 1\n- [ ] Step 2"
    return ctx
|
||||
|
||||
|
||||
class TestCheckpointWrite:
    """checkpoint_tool(action='write') creates a new checkpoint."""

    def test_write_basic_checkpoint(self, store, agent_context):
        raw = checkpoint_tool(
            action="write",
            task="Build feature X",
            progress=[{"step": "Setup", "status": "completed"}],
            state={"active_branch": "main"},
            decisions=["Use SQLite"],
            store=store,
            agent=agent_context,
        )
        reply = json.loads(raw)
        assert reply["success"] is True
        assert reply["session_id"] == "test_session_001"
        persisted = store.read("test_session_001")
        assert persisted["task"] == "Build feature X"
        assert persisted["status"] == "in_progress"

    def test_write_includes_auto_fields(self, store, agent_context):
        checkpoint_tool(
            action="write",
            task="Test task",
            progress=[],
            state={},
            decisions=[],
            store=store,
            agent=agent_context,
        )
        persisted = store.read("test_session_001")
        # write must stamp timestamps and record the owning session.
        assert "created" in persisted
        assert "updated" in persisted
        assert persisted["session_id"] == "test_session_001"
|
||||
|
||||
|
||||
class TestCheckpointUpdate:
    """checkpoint_tool(action='update') merges into an existing checkpoint."""

    def test_update_merges_progress(self, store, agent_context):
        checkpoint_tool(
            action="write",
            task="Build feature X",
            progress=[{"step": "Setup", "status": "completed"}],
            state={"active_branch": "main"},
            decisions=[],
            store=store,
            agent=agent_context,
        )
        raw = checkpoint_tool(
            action="update",
            progress=[{"step": "Implement", "status": "in_progress"}],
            state={"active_branch": "feat/x"},
            store=store,
            agent=agent_context,
        )
        assert json.loads(raw)["success"] is True
        merged = checkpoint = store.read("test_session_001")
        # The task survives the update; progress entries accumulate.
        assert merged["task"] == "Build feature X"
        assert len(merged["progress"]) == 2

    def test_update_nonexistent_returns_error(self, store, agent_context):
        raw = checkpoint_tool(
            action="update",
            progress=[],
            state={},
            store=store,
            agent=agent_context,
        )
        assert json.loads(raw)["success"] is False
|
||||
|
||||
|
||||
class TestCheckpointRead:
    """checkpoint_tool(action='read') returns the stored checkpoint."""

    def test_read_existing(self, store, agent_context):
        checkpoint_tool(
            action="write",
            task="Test task",
            progress=[],
            state={},
            decisions=[],
            store=store,
            agent=agent_context,
        )
        reply = json.loads(checkpoint_tool(action="read", store=store, agent=agent_context))
        assert reply["success"] is True
        assert reply["checkpoint"]["task"] == "Test task"

    def test_read_nonexistent(self, store, agent_context):
        reply = json.loads(checkpoint_tool(action="read", store=store, agent=agent_context))
        assert reply["success"] is False
        assert "no checkpoint" in reply["error"].lower()
|
||||
|
||||
|
||||
class TestCheckpointClear:
    """checkpoint_tool(action='clear') deletes the session's checkpoint."""

    def test_clear_removes_checkpoint(self, store, agent_context):
        checkpoint_tool(
            action="write",
            task="Test task",
            progress=[],
            state={},
            decisions=[],
            store=store,
            agent=agent_context,
        )
        reply = json.loads(checkpoint_tool(action="clear", store=store, agent=agent_context))
        assert reply["success"] is True
        assert store.read("test_session_001") is None
|
||||
264
tools/checkpoint_tool.py
Normal file
264
tools/checkpoint_tool.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"""Checkpoint tool -- save and restore mid-task state across context compaction.
|
||||
|
||||
The checkpoint tool lets the agent save its current task progress, state, and
|
||||
decisions to disk so it can resume after context compaction or session restart.
|
||||
This is Layer 4 in the memory architecture (after Layer 1 memory, Layer 2
|
||||
personality, Layer 3 vault).
|
||||
|
||||
Checkpoint files are stored as YAML in ~/.hermes/checkpoints/<session_id>.yaml.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# JSON-schema style tool definition advertised to the model for 'checkpoint'.
# The 'description' strings double as model-facing usage guidance.
CHECKPOINT_TOOL_SCHEMA = {
    "name": "checkpoint",
    "description": (
        "Save or restore mid-task state (checkpoint). "
        "Write a checkpoint before risky operations or when you have made significant progress. "
        "The system also auto-writes a checkpoint before context compaction. "
        "On session start, any existing checkpoint from the parent session is auto-injected "
        "so you can resume where you left off. "
        "Use 'write' to create/overwrite, 'update' to merge into existing, "
        "'read' to check current checkpoint, 'clear' to delete."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            # Which operation to perform; the only required parameter.
            "action": {
                "type": "string",
                "enum": ["write", "update", "read", "clear"],
                "description": (
                    "write: Create or overwrite checkpoint for current session. "
                    "update: Merge progress/state/decisions into existing checkpoint. "
                    "read: Read current checkpoint. "
                    "clear: Delete current checkpoint (task done or abandoned)."
                ),
            },
            "task": {
                "type": "string",
                "description": "One-line description of what you are working on. Required for 'write'.",
            },
            # Ordered step list; 'update' appends entries rather than replacing.
            "progress": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "step": {"type": "string", "description": "What this step does"},
                        "status": {
                            "type": "string",
                            "enum": ["pending", "in_progress", "completed", "cancelled"],
                        },
                        "result": {
                            "type": "string",
                            "description": "Outcome (optional, for completed steps)",
                        },
                    },
                    "required": ["step", "status"],
                },
                "description": "Ordered list of task steps with status. Each step: {step, status, result?}.",
            },
            # Free-form facts; git fields may be auto-filled from working_directory.
            "state": {
                "type": "object",
                "properties": {
                    "active_branch": {"type": "string"},
                    "files_changed": {"type": "array", "items": {"type": "string"}},
                    "tests_status": {"type": "string"},
                    "last_commit": {"type": "string"},
                    "pushed": {"type": "boolean"},
                    "working_directory": {"type": "string"},
                },
                "description": "Machine-readable facts about the current state: branch, files, test results, commits.",
            },
            "decisions": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Non-obvious choices made during this task (the 'why', not the 'what').",
            },
            "blocked": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Things blocked on external input. Empty if unblocked.",
            },
            "unresolved": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Open questions or unknowns. Empty if none.",
            },
        },
        "required": ["action"],
    },
}
|
||||
|
||||
|
||||
def _git_state(workdir: str = None) -> Dict[str, Any]:
|
||||
"""Best-effort capture of git state from the working directory."""
|
||||
state = {}
|
||||
if not workdir:
|
||||
return state
|
||||
try:
|
||||
branch = subprocess.run(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
capture_output=True, text=True, timeout=5, cwd=workdir,
|
||||
)
|
||||
if branch.returncode == 0:
|
||||
state["active_branch"] = branch.stdout.strip()
|
||||
|
||||
commit = subprocess.run(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
capture_output=True, text=True, timeout=5, cwd=workdir,
|
||||
)
|
||||
if commit.returncode == 0:
|
||||
state["last_commit"] = commit.stdout.strip()
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return state
|
||||
|
||||
|
||||
def checkpoint_tool(
    action: str,
    task: str = None,
    progress: List[Dict] = None,
    state: Dict[str, Any] = None,
    decisions: List[str] = None,
    blocked: List[str] = None,
    unresolved: List[str] = None,
    store=None,
    agent=None,
) -> str:
    """Execute a checkpoint action ('write', 'update', 'read', 'clear').

    Returns a JSON string with at least a 'success' key. Errors are
    reported in-band (success=False + 'error') rather than raised, so the
    model always receives a tool result. The checkpoint is keyed by the
    agent's session_id ('unknown' when no agent is supplied).
    """
    if store is None:
        from agent.checkpoint_store import CheckpointStore
        store = CheckpointStore()

    session_id = getattr(agent, "session_id", "unknown") if agent else "unknown"

    def _reply(**fields) -> str:
        # All branches answer with a JSON object; kwarg order is preserved.
        return json.dumps(fields)

    if action == "write":
        if not task:
            return _reply(success=False, error="'task' is required for write action")

        # Start from caller-supplied state, then enrich it automatically
        # with git facts when a working directory is known.
        merged_state = dict(state or {})
        workdir = merged_state.get("working_directory")
        if workdir and "active_branch" not in merged_state:
            merged_state.update(_git_state(workdir))

        # Snapshot the todo list when the agent exposes one (best-effort).
        snapshot = None
        if agent and hasattr(agent, "_todo_store") and agent._todo_store:
            try:
                snapshot = agent._todo_store.format_for_injection()
            except Exception:
                snapshot = None
        if snapshot:
            merged_state["todo_snapshot"] = snapshot

        record = {
            "session_id": session_id,
            "task": task,
            "status": "in_progress",
            "created": datetime.now().isoformat(),
            "updated": datetime.now().isoformat(),
            "progress": progress or [],
            "state": merged_state,
            "decisions": decisions or [],
            "blocked": blocked or [],
            "unresolved": unresolved or [],
        }
        store.write(session_id, record)
        return _reply(
            success=True,
            session_id=session_id,
            message=f"Checkpoint saved for session {session_id}",
        )

    if action == "update":
        current = store.read(session_id)
        if not current:
            return _reply(
                success=False,
                error=f"No checkpoint exists for session {session_id}. Use 'write' first.",
            )

        # progress and decisions are append-only; state keys overwrite.
        if progress:
            current["progress"] = current.get("progress", []) + progress
        if state:
            current["state"] = {**current.get("state", {}), **state}
        if decisions:
            current["decisions"] = current.get("decisions", []) + decisions

        # blocked/unresolved are replaced wholesale when explicitly supplied.
        if blocked is not None:
            current["blocked"] = blocked
        if unresolved is not None:
            current["unresolved"] = unresolved

        # Refresh git facts when the caller told us where the repo lives.
        workdir = (state or {}).get("working_directory")
        if workdir:
            current["state"] = {**current.get("state", {}), **_git_state(workdir)}

        store.write(session_id, current)
        return _reply(
            success=True,
            session_id=session_id,
            message=f"Checkpoint updated for session {session_id}",
        )

    if action == "read":
        current = store.read(session_id)
        if not current:
            return _reply(
                success=False,
                error=f"No checkpoint exists for session {session_id}",
            )
        return _reply(success=True, session_id=session_id, checkpoint=current)

    if action == "clear":
        store.delete(session_id)
        return _reply(
            success=True,
            session_id=session_id,
            message=f"Checkpoint cleared for session {session_id}",
        )

    return _reply(success=False, error=f"Unknown action: {action}")
|
||||
|
||||
|
||||
# --- Registry ---
from tools.registry import registry

# Register the checkpoint tool under the "todo" toolset so the model
# receives CHECKPOINT_TOOL_SCHEMA and can invoke it.
registry.register(
    name="checkpoint",
    toolset="todo",
    schema=CHECKPOINT_TOOL_SCHEMA,
    # Adapter: unpack the model's JSON arguments into keyword parameters;
    # the framework injects 'store' and 'agent' through **kw.
    handler=lambda args, **kw: checkpoint_tool(
        action=args.get("action"),
        task=args.get("task"),
        progress=args.get("progress"),
        state=args.get("state"),
        decisions=args.get("decisions"),
        blocked=args.get("blocked"),
        unresolved=args.get("unresolved"),
        store=kw.get("store"),
        agent=kw.get("agent"),
    ),
    check_fn=lambda: True,  # always available
    emoji="🔖",
)
|
||||
|
|
@ -159,7 +159,7 @@ TOOLSETS = {
|
|||
|
||||
"todo": {
|
||||
"description": "Task planning and tracking for multi-step work",
|
||||
"tools": ["todo"],
|
||||
"tools": ["todo", "checkpoint"],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue