diff --git a/agent/checkpoint_injection.py b/agent/checkpoint_injection.py index c49a034e481..f9a7f021699 100644 --- a/agent/checkpoint_injection.py +++ b/agent/checkpoint_injection.py @@ -21,21 +21,14 @@ def build_checkpoint_system_prompt(store, parent_session_id: str = None) -> str: """Build the checkpoint injection block for the system prompt. Returns empty string if no applicable checkpoint exists. + The caller is responsible for walking the parent lineage chain to find + the nearest ancestor session that has a checkpoint file. """ if not store or not parent_session_id: return "" data = store.read(parent_session_id) - # Walk lineage if direct parent has no checkpoint - if data is None and store._dir.exists(): - all_sessions = store.list_sessions() - for sid in reversed(all_sessions): - attempt = store.read(sid) - if attempt and attempt.get("status") == "in_progress": - data = attempt - break - if not data or not isinstance(data, dict): return "" diff --git a/agent/checkpoint_store.py b/agent/checkpoint_store.py index ee89744aca7..d5ca40a84c1 100644 --- a/agent/checkpoint_store.py +++ b/agent/checkpoint_store.py @@ -6,6 +6,7 @@ left off after context compaction or session restart. Files live under """ import logging +import re import yaml from datetime import datetime, timedelta from pathlib import Path @@ -13,6 +14,9 @@ from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) +# Session IDs must be safe filesystem names -- no path separators or traversal +_SAFE_SESSION_ID = re.compile(r"^[A-Za-z0-9_-]+$") + DEFAULT_CHECKPOINTS_DIR = Path.home() / ".hermes" / "checkpoints" DEFAULT_GC_MAX_AGE_DAYS = 7 @@ -25,6 +29,8 @@ class CheckpointStore: self._dir.mkdir(parents=True, exist_ok=True) def _path_for(self, session_id: str) -> Path: + if not _SAFE_SESSION_ID.match(session_id): + raise ValueError(f"Invalid session_id: {session_id!r} (must match {_SAFE_SESSION_ID.pattern})") return self._dir / f"{session_id}.yaml" def write(self, session_id: str, data: Dict[str, Any]) -> None: diff --git a/tests/test_checkpoint_flush.py b/tests/test_checkpoint_flush.py index fa253c666e5..757e68bb4be 100644 --- a/tests/test_checkpoint_flush.py +++ b/tests/test_checkpoint_flush.py @@ -1,7 +1,7 @@ -"""Test that flush_checkpoint is called during context compression.""" +"""Test that flush_checkpoint writes a checkpoint via checkpoint_tool.""" import json import pytest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch def test_flush_checkpoint_method_exists(): @@ -13,27 +13,41 @@ def test_flush_checkpoint_method_exists(): def test_flush_checkpoint_writes_to_store(tmp_path): """flush_checkpoint should write a checkpoint with current session state.""" from agent.checkpoint_store import CheckpointStore - from tools.checkpoint_tool import checkpoint_tool + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + mock_todo = MagicMock() + mock_todo._items = [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "in_progress"}, + ] + mock_todo.format_for_injection.return_value = "- [x] Step 1\n- [~] Step 2" + + mock_session_db = MagicMock() + mock_session_db.get_session_title.return_value = "Test session title" + mock_agent = MagicMock() mock_agent.session_id = "flush_test_session" mock_agent._checkpoint_store = store - mock_agent._todo_store = MagicMock() - mock_agent._todo_store.format_for_injection.return_value = "- [x] Step 1" + mock_agent._todo_store = mock_todo + mock_agent._session_db = mock_session_db - # Verify the checkpoint tool writes successfully (same path flush_checkpoint uses) - result = checkpoint_tool( - action="write", - task="Auto-checkpoint before compression", - progress=[], - state={}, - decisions=[], - store=store, - agent=mock_agent, - ) - data = json.loads(result) - assert data["success"] is True + # Call the real flush_checkpoint method + from run_agent import AIAgent + AIAgent.flush_checkpoint(mock_agent) + + # Verify a checkpoint was written for this session saved = store.read("flush_test_session") assert saved is not None - assert saved["task"] == "Auto-checkpoint before compression" \ No newline at end of file + assert saved["task"] == "Test session title" + assert saved["status"] == "in_progress" + + +def test_flush_checkpoint_noops_without_store(tmp_path): + """flush_checkpoint should silently return if _checkpoint_store is None.""" + mock_agent = MagicMock() + mock_agent._checkpoint_store = None + + from run_agent import AIAgent + # Should not raise + AIAgent.flush_checkpoint(mock_agent) \ No newline at end of file diff --git a/tests/test_checkpoint_regression.py b/tests/test_checkpoint_regression.py new file mode 100644 index 00000000000..28705348842 --- /dev/null +++ b/tests/test_checkpoint_regression.py @@ -0,0 +1,144 @@ +"""Regression tests for bugs found during red-team QA review. + +Bug #1: get_parent_session_id() didn't exist on SessionDB +Bug #2: Path traversal via session_id +Bug #3: "checkpoint" not in any toolset definition +Bug #4: flush_checkpoint test didn't actually call flush_checkpoint +Bug #5: Injection fallback grabbed ANY in_progress checkpoint from unrelated sessions +""" +import pytest +from pathlib import Path + + +class TestBug1SessionDBLineageWalk: + """Bug #1: _build_system_prompt must use get_session()["parent_session_id"] + instead of the nonexistent get_parent_session_id() method. + """ + + def test_session_db_has_no_get_parent_session_id_method(self): + """SessionDB must NOT have a get_parent_session_id method (it never did).""" + from hermes_state import SessionDB + assert not hasattr(SessionDB, "get_parent_session_id"), ( + "SessionDB should not have get_parent_session_id -- " + "use get_session(sid)['parent_session_id'] instead" + ) + + def test_get_session_returns_parent_session_id_key(self, tmp_path): + """get_session() must return a dict with 'parent_session_id' key.""" + from hermes_state import SessionDB + from pathlib import Path + db = SessionDB(db_path=Path(tmp_path) / "test.db") + # Create a parent session + parent_id = db.create_session("parent_session", source="test", model="test") + # Create a child session + child_id = db.create_session("child_session", source="test", model="test", parent_session_id=parent_id) + sess = db.get_session(child_id) + assert sess is not None + assert "parent_session_id" in sess + assert sess["parent_session_id"] == parent_id + + +class TestBug2PathTraversal: + """Bug #2: session_id must be validated to prevent path traversal.""" + + def test_reject_dotdot_in_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.write("../../../etc/passwd", {"task": "evil"}) + + def test_reject_slash_in_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.write("sub/dir", {"task": "evil"}) + + def test_reject_backslash_in_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.write("sub\\dir", {"task": "evil"}) + + def test_accept_normal_session_id(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + # Should not raise + store.write("20260422_225800_abc123", {"task": "normal"}) + + def test_read_rejects_path_traversal(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.read("../../../etc/passwd") + + def test_delete_rejects_path_traversal(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + with pytest.raises(ValueError, match="Invalid session_id"): + store.delete("../../../etc/passwd") + + +class TestBug3CheckpointInToolset: + """Bug #3: 'checkpoint' must be in a toolset definition so the model gets the schema.""" + + def test_checkpoint_in_todo_toolset(self): + from toolsets import TOOLSETS + todo_tools = TOOLSETS.get("todo", {}).get("tools", []) + assert "checkpoint" in todo_tools, ( + f"'checkpoint' must be in the 'todo' toolset tools list. " + f"Current todo tools: {todo_tools}" + ) + + def test_resolve_toolset_todo_includes_checkpoint(self): + from toolsets import resolve_toolset + resolved = resolve_toolset("todo") + assert "checkpoint" in resolved, ( + f"resolve_toolset('todo') must return ['todo', 'checkpoint']. Got: {resolved}" + ) + + +class TestBug5InjectionNoCrossSession: + """Bug #5: build_checkpoint_system_prompt must NOT fall through to grabbing + any random in_progress checkpoint from an unrelated session. + """ + + def test_no_injection_when_parent_has_no_checkpoint(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + from agent.checkpoint_injection import build_checkpoint_system_prompt + + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + # Write a checkpoint for an unrelated session + store.write("unrelated_session", { + "task": "Unrelated task", + "status": "in_progress", + "progress": [], + }) + + # Ask for a parent_session_id that has no checkpoint + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="missing_session") + assert prompt == "", ( + "Must not inject a checkpoint from an unrelated session when " + "the requested parent_session_id has no checkpoint" + ) + + def test_only_injects_specific_parent(self, tmp_path): + from agent.checkpoint_store import CheckpointStore + from agent.checkpoint_injection import build_checkpoint_system_prompt + + store = CheckpointStore(checkpoints_dir=tmp_path / "checkpoints") + # Write checkpoints for two sessions + store.write("parent_A", { + "task": "Parent A task", + "status": "in_progress", + "progress": [], + }) + store.write("parent_B", { + "task": "Parent B task", + "status": "in_progress", + "progress": [], + }) + + # Request parent_A specifically + prompt = build_checkpoint_system_prompt(store=store, parent_session_id="parent_A") + assert "Parent A task" in prompt + assert "Parent B task" not in prompt \ No newline at end of file diff --git a/toolsets.py b/toolsets.py index 65f560bfe45..b6619d3ab75 100644 --- a/toolsets.py +++ b/toolsets.py @@ -159,7 +159,7 @@ TOOLSETS = { "todo": { "description": "Task planning and tracking for multi-step work", - "tools": ["todo"], + "tools": ["todo", "checkpoint"], "includes": [] },