mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(checkpoint): add checkpoint tool with write/update/read/clear actions
This commit is contained in:
parent
d588ce2b59
commit
c1b3536bb3
2 changed files with 394 additions and 0 deletions
130
tests/test_checkpoint_tool.py
Normal file
130
tests/test_checkpoint_tool.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
from tools.checkpoint_tool import checkpoint_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(tmp_path):
|
||||
return CheckpointStore(checkpoints_dir=tmp_path / "checkpoints")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_context(store):
|
||||
"""Simulate the fields the tool handler pulls from the AIAgent."""
|
||||
ctx = MagicMock()
|
||||
ctx.session_id = "test_session_001"
|
||||
ctx._todo_store = MagicMock()
|
||||
ctx._todo_store.format_for_injection.return_value = "- [x] Step 1\n- [ ] Step 2"
|
||||
return ctx
|
||||
|
||||
|
||||
class TestCheckpointWrite:
|
||||
def test_write_basic_checkpoint(self, store, agent_context):
|
||||
result = checkpoint_tool(
|
||||
action="write",
|
||||
task="Build feature X",
|
||||
progress=[{"step": "Setup", "status": "completed"}],
|
||||
state={"active_branch": "main"},
|
||||
decisions=["Use SQLite"],
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
data = json.loads(result)
|
||||
assert data["success"] is True
|
||||
assert data["session_id"] == "test_session_001"
|
||||
saved = store.read("test_session_001")
|
||||
assert saved["task"] == "Build feature X"
|
||||
assert saved["status"] == "in_progress"
|
||||
|
||||
def test_write_includes_auto_fields(self, store, agent_context):
|
||||
checkpoint_tool(
|
||||
action="write",
|
||||
task="Test task",
|
||||
progress=[],
|
||||
state={},
|
||||
decisions=[],
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
saved = store.read("test_session_001")
|
||||
assert "created" in saved
|
||||
assert "updated" in saved
|
||||
assert saved["session_id"] == "test_session_001"
|
||||
|
||||
|
||||
class TestCheckpointUpdate:
|
||||
def test_update_merges_progress(self, store, agent_context):
|
||||
checkpoint_tool(
|
||||
action="write",
|
||||
task="Build feature X",
|
||||
progress=[{"step": "Setup", "status": "completed"}],
|
||||
state={"active_branch": "main"},
|
||||
decisions=[],
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
result = checkpoint_tool(
|
||||
action="update",
|
||||
progress=[{"step": "Implement", "status": "in_progress"}],
|
||||
state={"active_branch": "feat/x"},
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
data = json.loads(result)
|
||||
assert data["success"] is True
|
||||
saved = store.read("test_session_001")
|
||||
assert saved["task"] == "Build feature X" # preserved
|
||||
assert len(saved["progress"]) == 2 # merged
|
||||
|
||||
def test_update_nonexistent_returns_error(self, store, agent_context):
|
||||
result = checkpoint_tool(
|
||||
action="update",
|
||||
progress=[],
|
||||
state={},
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
data = json.loads(result)
|
||||
assert data["success"] is False
|
||||
|
||||
|
||||
class TestCheckpointRead:
|
||||
def test_read_existing(self, store, agent_context):
|
||||
checkpoint_tool(
|
||||
action="write",
|
||||
task="Test task",
|
||||
progress=[],
|
||||
state={},
|
||||
decisions=[],
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
result = checkpoint_tool(action="read", store=store, agent=agent_context)
|
||||
data = json.loads(result)
|
||||
assert data["success"] is True
|
||||
assert data["checkpoint"]["task"] == "Test task"
|
||||
|
||||
def test_read_nonexistent(self, store, agent_context):
|
||||
result = checkpoint_tool(action="read", store=store, agent=agent_context)
|
||||
data = json.loads(result)
|
||||
assert data["success"] is False
|
||||
assert "no checkpoint" in data["error"].lower()
|
||||
|
||||
|
||||
class TestCheckpointClear:
|
||||
def test_clear_removes_checkpoint(self, store, agent_context):
|
||||
checkpoint_tool(
|
||||
action="write",
|
||||
task="Test task",
|
||||
progress=[],
|
||||
state={},
|
||||
decisions=[],
|
||||
store=store,
|
||||
agent=agent_context,
|
||||
)
|
||||
result = checkpoint_tool(action="clear", store=store, agent=agent_context)
|
||||
data = json.loads(result)
|
||||
assert data["success"] is True
|
||||
assert store.read("test_session_001") is None
|
||||
264
tools/checkpoint_tool.py
Normal file
264
tools/checkpoint_tool.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"""Checkpoint tool -- save and restore mid-task state across context compaction.
|
||||
|
||||
The checkpoint tool lets the agent save its current task progress, state, and
|
||||
decisions to disk so it can resume after context compaction or session restart.
|
||||
This is Layer 4 in the memory architecture (after Layer 1 memory, Layer 2
|
||||
personality, Layer 3 vault).
|
||||
|
||||
Checkpoint files are stored as YAML in ~/.hermes/checkpoints/<session_id>.yaml.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHECKPOINT_TOOL_SCHEMA = {
|
||||
"name": "checkpoint",
|
||||
"description": (
|
||||
"Save or restore mid-task state (checkpoint). "
|
||||
"Write a checkpoint before risky operations or when you have made significant progress. "
|
||||
"The system also auto-writes a checkpoint before context compaction. "
|
||||
"On session start, any existing checkpoint from the parent session is auto-injected "
|
||||
"so you can resume where you left off. "
|
||||
"Use 'write' to create/overwrite, 'update' to merge into existing, "
|
||||
"'read' to check current checkpoint, 'clear' to delete."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["write", "update", "read", "clear"],
|
||||
"description": (
|
||||
"write: Create or overwrite checkpoint for current session. "
|
||||
"update: Merge progress/state/decisions into existing checkpoint. "
|
||||
"read: Read current checkpoint. "
|
||||
"clear: Delete current checkpoint (task done or abandoned)."
|
||||
),
|
||||
},
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "One-line description of what you are working on. Required for 'write'.",
|
||||
},
|
||||
"progress": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"step": {"type": "string", "description": "What this step does"},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed", "cancelled"],
|
||||
},
|
||||
"result": {
|
||||
"type": "string",
|
||||
"description": "Outcome (optional, for completed steps)",
|
||||
},
|
||||
},
|
||||
"required": ["step", "status"],
|
||||
},
|
||||
"description": "Ordered list of task steps with status. Each step: {step, status, result?}.",
|
||||
},
|
||||
"state": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"active_branch": {"type": "string"},
|
||||
"files_changed": {"type": "array", "items": {"type": "string"}},
|
||||
"tests_status": {"type": "string"},
|
||||
"last_commit": {"type": "string"},
|
||||
"pushed": {"type": "boolean"},
|
||||
"working_directory": {"type": "string"},
|
||||
},
|
||||
"description": "Machine-readable facts about the current state: branch, files, test results, commits.",
|
||||
},
|
||||
"decisions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Non-obvious choices made during this task (the 'why', not the 'what').",
|
||||
},
|
||||
"blocked": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Things blocked on external input. Empty if unblocked.",
|
||||
},
|
||||
"unresolved": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Open questions or unknowns. Empty if none.",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _git_state(workdir: str = None) -> Dict[str, Any]:
|
||||
"""Best-effort capture of git state from the working directory."""
|
||||
state = {}
|
||||
if not workdir:
|
||||
return state
|
||||
try:
|
||||
branch = subprocess.run(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
capture_output=True, text=True, timeout=5, cwd=workdir,
|
||||
)
|
||||
if branch.returncode == 0:
|
||||
state["active_branch"] = branch.stdout.strip()
|
||||
|
||||
commit = subprocess.run(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
capture_output=True, text=True, timeout=5, cwd=workdir,
|
||||
)
|
||||
if commit.returncode == 0:
|
||||
state["last_commit"] = commit.stdout.strip()
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return state
|
||||
|
||||
|
||||
def checkpoint_tool(
|
||||
action: str,
|
||||
task: str = None,
|
||||
progress: List[Dict] = None,
|
||||
state: Dict[str, Any] = None,
|
||||
decisions: List[str] = None,
|
||||
blocked: List[str] = None,
|
||||
unresolved: List[str] = None,
|
||||
store=None,
|
||||
agent=None,
|
||||
) -> str:
|
||||
"""Execute a checkpoint action. Returns JSON result string."""
|
||||
if store is None:
|
||||
from agent.checkpoint_store import CheckpointStore
|
||||
store = CheckpointStore()
|
||||
|
||||
session_id = getattr(agent, "session_id", "unknown") if agent else "unknown"
|
||||
|
||||
if action == "write":
|
||||
if not task:
|
||||
return json.dumps({"success": False, "error": "'task' is required for write action"})
|
||||
|
||||
# Auto-populate git state if not provided
|
||||
effective_state = dict(state or {})
|
||||
workdir = effective_state.get("working_directory")
|
||||
if workdir and "active_branch" not in effective_state:
|
||||
git = _git_state(workdir)
|
||||
effective_state.update(git)
|
||||
# Auto-populate from todo store if available
|
||||
todo_snapshot = None
|
||||
if agent and hasattr(agent, "_todo_store") and agent._todo_store:
|
||||
try:
|
||||
todo_snapshot = agent._todo_store.format_for_injection()
|
||||
except Exception:
|
||||
pass
|
||||
if todo_snapshot:
|
||||
effective_state["todo_snapshot"] = todo_snapshot
|
||||
|
||||
data = {
|
||||
"session_id": session_id,
|
||||
"task": task,
|
||||
"status": "in_progress",
|
||||
"created": datetime.now().isoformat(),
|
||||
"updated": datetime.now().isoformat(),
|
||||
"progress": progress or [],
|
||||
"state": effective_state,
|
||||
"decisions": decisions or [],
|
||||
"blocked": blocked or [],
|
||||
"unresolved": unresolved or [],
|
||||
}
|
||||
store.write(session_id, data)
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"message": f"Checkpoint saved for session {session_id}",
|
||||
})
|
||||
|
||||
elif action == "update":
|
||||
existing = store.read(session_id)
|
||||
if not existing:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"No checkpoint exists for session {session_id}. Use 'write' first.",
|
||||
})
|
||||
|
||||
# Merge progress (append new steps)
|
||||
if progress:
|
||||
existing["progress"] = existing.get("progress", []) + progress
|
||||
|
||||
# Merge state (overwrite keys)
|
||||
if state:
|
||||
existing["state"] = {**existing.get("state", {}), **state}
|
||||
|
||||
# Append decisions
|
||||
if decisions:
|
||||
existing["decisions"] = existing.get("decisions", []) + decisions
|
||||
|
||||
# Replace blocked/unresolved
|
||||
if blocked is not None:
|
||||
existing["blocked"] = blocked
|
||||
if unresolved is not None:
|
||||
existing["unresolved"] = unresolved
|
||||
|
||||
# Update git state if workdir provided
|
||||
workdir = (state or {}).get("working_directory") if state else None
|
||||
if workdir:
|
||||
git = _git_state(workdir)
|
||||
existing["state"] = {**existing.get("state", {}), **git}
|
||||
|
||||
store.write(session_id, existing)
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"message": f"Checkpoint updated for session {session_id}",
|
||||
})
|
||||
|
||||
elif action == "read":
|
||||
data = store.read(session_id)
|
||||
if not data:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"No checkpoint exists for session {session_id}",
|
||||
})
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"checkpoint": data,
|
||||
})
|
||||
|
||||
elif action == "clear":
|
||||
store.delete(session_id)
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"session_id": session_id,
|
||||
"message": f"Checkpoint cleared for session {session_id}",
|
||||
})
|
||||
|
||||
else:
|
||||
return json.dumps({"success": False, "error": f"Unknown action: {action}"})
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="checkpoint",
|
||||
toolset="todo",
|
||||
schema=CHECKPOINT_TOOL_SCHEMA,
|
||||
handler=lambda args, **kw: checkpoint_tool(
|
||||
action=args.get("action"),
|
||||
task=args.get("task"),
|
||||
progress=args.get("progress"),
|
||||
state=args.get("state"),
|
||||
decisions=args.get("decisions"),
|
||||
blocked=args.get("blocked"),
|
||||
unresolved=args.get("unresolved"),
|
||||
store=kw.get("store"),
|
||||
agent=kw.get("agent"),
|
||||
),
|
||||
check_fn=lambda: True, # always available
|
||||
emoji="🔖",
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue