Add confirmation for destructive slash commands

This commit is contained in:
Akshat 2026-04-16 02:26:25 +05:30
parent d1d425e9d0
commit 064a2882f8
7 changed files with 289 additions and 16 deletions

17
cli.py
View file

@ -5277,10 +5277,16 @@ class HermesCLI:
# Resolve aliases via central registry so adding an alias is a one-line
# change in hermes_cli/commands.py instead of touching every dispatch site.
from hermes_cli.commands import resolve_command as _resolve_cmd
from hermes_cli.commands import (
destructive_command_confirmation_message as _destructive_command_msg,
destructive_command_is_confirmed as _destructive_command_is_confirmed,
resolve_command as _resolve_cmd,
)
_base_word = cmd_lower.split()[0].lstrip("/")
_cmd_def = _resolve_cmd(_base_word)
canonical = _cmd_def.name if _cmd_def else _base_word
_typed_command = cmd_original.split()[0].lstrip("/").lower()
_args = cmd_original.split(maxsplit=1)[1] if " " in cmd_original else ""
if canonical in ("quit", "exit", "q"):
return False
@ -5295,6 +5301,9 @@ class HermesCLI:
elif canonical == "config":
self.show_config()
elif canonical == "clear":
if not _destructive_command_is_confirmed(_args):
_cprint(_destructive_command_msg("clear", _typed_command))
return True
self.new_session(silent=True)
# Clear terminal screen. Inside the TUI, Rich's console.clear()
# goes through patch_stdout's StdoutProxy which swallows the
@ -5412,6 +5421,9 @@ class HermesCLI:
else:
_cprint(" Session database not available.")
elif canonical == "new":
if not _destructive_command_is_confirmed(_args):
_cprint(_destructive_command_msg("new", _typed_command))
return True
self.new_session()
elif canonical == "resume":
self._handle_resume_command(cmd_original)
@ -5431,6 +5443,9 @@ class HermesCLI:
# Re-queue the message so process_loop sends it to the agent
self._pending_input.put(retry_msg)
elif canonical == "undo":
if not _destructive_command_is_confirmed(_args):
_cprint(_destructive_command_msg("undo", _typed_command))
return True
self.undo_last()
elif canonical == "branch":
self._handle_branch_command(cmd_original)

View file

@ -4259,6 +4259,15 @@ class GatewayRunner:
async def _handle_reset_command(self, event: MessageEvent) -> str:
"""Handle /new or /reset command."""
from hermes_cli.commands import (
destructive_command_confirmation_message,
destructive_command_is_confirmed,
)
typed_command = event.get_command() or "new"
if not destructive_command_is_confirmed(event.get_command_args()):
return destructive_command_confirmation_message("new", typed_command)
source = event.source
# Get existing session key
@ -5079,6 +5088,11 @@ class GatewayRunner:
async def _handle_undo_command(self, event: MessageEvent) -> str:
"""Handle /undo command - remove the last user/assistant exchange."""
from hermes_cli.commands import (
destructive_command_confirmation_message,
destructive_command_is_confirmed,
)
source = event.source
session_entry = self.session_store.get_or_create_session(source)
history = self.session_store.load_transcript(session_entry.session_id)
@ -5092,7 +5106,11 @@ class GatewayRunner:
if last_user_idx is None:
return "Nothing to undo."
if not destructive_command_is_confirmed(event.get_command_args()):
typed_command = event.get_command() or "undo"
return destructive_command_confirmation_message("undo", typed_command)
removed_msg = history[last_user_idx].get("content", "")
removed_count = len(history) - last_user_idx
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])

View file

@ -59,15 +59,16 @@ class CommandDef:
COMMAND_REGISTRY: list[CommandDef] = [
# Session
CommandDef("new", "Start a new session (fresh session ID + history)", "Session",
aliases=("reset",)),
aliases=("reset",), args_hint="[--yes]"),
CommandDef("clear", "Clear screen and start a new session", "Session",
cli_only=True),
cli_only=True, args_hint="[--yes]"),
CommandDef("history", "Show conversation history", "Session",
cli_only=True),
CommandDef("save", "Save the current conversation", "Session",
cli_only=True),
CommandDef("retry", "Retry the last message (resend to agent)", "Session"),
CommandDef("undo", "Remove the last user/assistant exchange", "Session"),
CommandDef("undo", "Remove the last user/assistant exchange", "Session",
args_hint="[--yes]"),
CommandDef("title", "Set a title for the current session", "Session",
args_hint="[name]"),
CommandDef("branch", "Branch the current session (explore a different path)", "Session",
@ -184,6 +185,18 @@ def _build_command_lookup() -> dict[str, CommandDef]:
_COMMAND_LOOKUP: dict[str, CommandDef] = _build_command_lookup()
# Canonical command name -> description of what running it destroys.
# The description is interpolated into the confirmation warning as
# "this will {action}", so each value is phrased as a verb clause.
_DESTRUCTIVE_COMMAND_ACTIONS: dict[str, str] = {
"new": "discard the current conversation history and start a fresh session",
"clear": "clear the screen and discard the current conversation history",
"undo": "remove the last user/assistant exchange from this session",
}
# Argument tokens accepted as an explicit "go ahead". Stored lowercase because
# destructive_command_is_confirmed() lowercases each token before the lookup.
_DESTRUCTIVE_CONFIRM_TOKENS: frozenset[str] = frozenset({
"--yes",
"--confirm",
"yes",
"confirm",
})
def resolve_command(name: str) -> CommandDef | None:
"""Resolve a command name or alias to its CommandDef.
@ -193,6 +206,26 @@ def resolve_command(name: str) -> CommandDef | None:
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
def destructive_command_is_confirmed(args: str | None) -> bool:
    """Return True when the command args include an explicit confirmation token.

    Args:
        args: Everything the user typed after the command word. ``None`` and
            empty/whitespace-only strings are treated as "no arguments".

    Returns:
        True if any whitespace-separated token matches, case-insensitively,
        an entry of ``_DESTRUCTIVE_CONFIRM_TOKENS`` (for example ``--yes``);
        False otherwise.
    """
    # A missing argument string must mean "not confirmed" rather than a crash:
    # this guard sits in front of destructive actions, so failing closed is the
    # safe default.
    if not args:
        return False
    return any(token.lower() in _DESTRUCTIVE_CONFIRM_TOKENS for token in args.split())
def destructive_command_confirmation_message(
    canonical_name: str,
    typed_name: str | None = None,
) -> str:
    """Build the warning shown when a destructive command lacks confirmation.

    Args:
        canonical_name: Registry name of the command; used to look up the
            description of what will be destroyed.
        typed_name: The word the user actually typed (possibly an alias, with
            or without a leading slash). Falls back to ``canonical_name`` when
            empty or ``None`` so the hint echoes the user's own spelling.

    Returns:
        A single-line message describing the effect and how to re-run the
        command with ``--yes``.

    Raises:
        KeyError: If ``canonical_name`` has no entry (or an empty entry) in
            ``_DESTRUCTIVE_COMMAND_ACTIONS``.
    """
    action = _DESTRUCTIVE_COMMAND_ACTIONS.get(canonical_name)
    if action:
        shown_name = typed_name if typed_name else canonical_name
        command_name = shown_name.lstrip("/").lower()
        warning = f"Confirmation required: this will {action}. "
        hint = f"Re-run `/{command_name} --yes` to continue."
        return warning + hint
    raise KeyError(f"Unknown destructive command: {canonical_name}")
def _build_description(cmd: CommandDef) -> str:
"""Build a CLI-facing description string including usage hint."""
if cmd.args_hint:

View file

@ -34,6 +34,7 @@ class _FakeAgent:
[{"id": "t1", "content": "unfinished task", "status": "in_progress"}]
)
self.flush_memories = MagicMock()
self.commit_memory_session = MagicMock()
self._invalidate_system_prompt = MagicMock()
# Token counters (non-zero to verify reset)
@ -116,6 +117,7 @@ def _make_cli(env_overrides=None, config_overrides=None, **kwargs):
with patch.object(_cli_mod, "get_tool_definitions", return_value=[]), patch.dict(
_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}
):
_cli_mod._cprint = MagicMock()
return _cli_mod.HermesCLI(**kwargs)
@ -138,7 +140,7 @@ def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path)
old_session_id = cli.session_id
old_session_start = cli.session_start
cli.process_command("/new")
cli.process_command("/new --yes")
assert cli.session_id != old_session_id
@ -164,7 +166,7 @@ def test_reset_command_is_alias_for_new_session(tmp_path):
cli = _prepare_cli_with_active_session(tmp_path)
old_session_id = cli.session_id
cli.process_command("/reset")
cli.process_command("/reset --yes")
assert cli.session_id != old_session_id
assert cli._session_db.get_session(old_session_id)["end_reason"] == "new_session"
@ -177,7 +179,7 @@ def test_clear_command_starts_new_session_before_redrawing(tmp_path):
cli.show_banner = MagicMock()
old_session_id = cli.session_id
cli.process_command("/clear")
cli.process_command("/clear --yes")
assert cli.session_id != old_session_id
assert cli._session_db.get_session(old_session_id)["end_reason"] == "new_session"
@ -197,7 +199,7 @@ def test_new_session_resets_token_counters(tmp_path):
assert agent.session_api_calls > 0
assert agent.context_compressor.compression_count > 0
cli.process_command("/new")
cli.process_command("/new --yes")
# All agent token counters must be zero
assert agent.session_total_tokens == 0
@ -220,3 +222,79 @@ def test_new_session_resets_token_counters(tmp_path):
assert comp.last_total_tokens == 0
assert comp.compression_count == 0
assert comp._context_probed is False
def test_new_command_requires_confirmation_before_resetting(tmp_path):
    """A bare /new keeps the session id, its DB record and the history."""
    cli = _prepare_cli_with_active_session(tmp_path)
    session_id_before = cli.session_id

    cli.process_command("/new")

    expected_history = [{"role": "user", "content": "hello"}]
    assert cli.session_id == session_id_before
    assert cli._session_db.get_session(session_id_before) is not None
    assert cli.conversation_history == expected_history
def test_reset_without_confirmation_does_not_reset_session(tmp_path):
    """The /reset alias without a confirmation token keeps the session id."""
    cli = _prepare_cli_with_active_session(tmp_path)
    session_id_before = cli.session_id

    cli.process_command("/reset")

    assert session_id_before == cli.session_id
def test_reset_alias_warning_uses_typed_command():
    """The re-run hint echoes the alias the user typed, not the canonical name."""
    from hermes_cli.commands import destructive_command_confirmation_message

    warning = destructive_command_confirmation_message("new", "reset")

    assert "/reset --yes" in warning
def test_clear_requires_confirmation_before_redrawing(tmp_path):
    """A bare /clear neither resets the session nor redraws the screen."""
    cli = _prepare_cli_with_active_session(tmp_path)
    cli.console, cli.show_banner = MagicMock(), MagicMock()
    session_id_before = cli.session_id

    cli.process_command("/clear")

    assert cli.session_id == session_id_before
    assert cli.conversation_history == [{"role": "user", "content": "hello"}]
    # Neither redraw hook may have been called.
    for redraw_hook in (cli.console.clear, cli.show_banner):
        redraw_hook.assert_not_called()
def test_undo_requires_confirmation_before_mutating_history():
    """A bare /undo leaves every message of the conversation in place."""
    exchanges = [
        ("user", "first prompt"),
        ("assistant", "first reply"),
        ("user", "second prompt"),
        ("assistant", "second reply"),
    ]
    cli = _make_cli()
    cli.conversation_history = [
        {"role": role, "content": content} for role, content in exchanges
    ]

    cli.process_command("/undo")

    remaining = [entry["content"] for entry in cli.conversation_history]
    assert remaining == [content for _role, content in exchanges]
def test_undo_with_confirmation_removes_last_exchange():
    """/undo --yes drops the final user/assistant pair and keeps the rest."""
    exchanges = [
        ("user", "first prompt"),
        ("assistant", "first reply"),
        ("user", "second prompt"),
        ("assistant", "second reply"),
    ]
    cli = _make_cli()
    cli.conversation_history = [
        {"role": role, "content": content} for role, content in exchanges
    ]

    cli.process_command("/undo --yes")

    remaining = [entry["content"] for entry in cli.conversation_history]
    assert remaining == ["first prompt", "first reply"]

View file

@ -0,0 +1,129 @@
"""Tests confirmation guards for destructive gateway session commands."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
def _make_source() -> SessionSource:
    """Return the fixed Telegram DM source shared by every test in this module."""
    fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**fields)
def _make_event(text: str) -> MessageEvent:
    """Wrap ``text`` in a MessageEvent coming from the standard test source."""
    origin = _make_source()
    return MessageEvent(text=text, source=origin, message_id="m1")
def _make_runner():
    """Build a GatewayRunner without running ``__init__``, wired with mocks.

    The session store is a MagicMock whose ``get_or_create_session`` returns an
    entry with id ``"sess-old"`` and whose ``reset_session`` returns one with id
    ``"sess-new"``, so tests can tell the two sessions apart.
    """
    from gateway.run import GatewayRunner

    def _entry(session_id, key):
        # Both entries differ only in their session id.
        return SessionEntry(
            session_key=key,
            session_id=session_id,
            created_at=datetime.now(),
            updated_at=datetime.now(),
            platform=Platform.TELEGRAM,
            chat_type="dm",
        )

    # Bypass __init__ and set the attributes by hand.
    gateway = object.__new__(GatewayRunner)
    gateway.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
    )

    telegram_adapter = MagicMock()
    telegram_adapter.send = AsyncMock()
    gateway.adapters = {Platform.TELEGRAM: telegram_adapter}

    gateway._voice_mode = {}
    gateway.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
    gateway._session_model_overrides = {}
    gateway._pending_model_notes = {}
    gateway._background_tasks = set()

    key = build_session_key(_make_source())
    current_entry = _entry("sess-old", key)
    fresh_entry = _entry("sess-new", key)

    store = MagicMock()
    store.get_or_create_session.return_value = current_entry
    store.reset_session.return_value = fresh_entry
    store._entries = {key: current_entry}
    gateway.session_store = store

    gateway._running_agents = {}
    gateway._pending_messages = {}
    gateway._pending_approvals = {}
    gateway._session_db = None
    gateway._agent_cache_lock = None
    gateway._is_user_authorized = lambda _source: True
    gateway._format_session_info = lambda: ""
    return gateway
@pytest.mark.asyncio
async def test_new_requires_confirmation_before_reset():
    """A bare /new replies with the re-run hint and never resets the session."""
    gateway = _make_runner()

    reply = await gateway._handle_reset_command(_make_event("/new"))

    gateway.session_store.reset_session.assert_not_called()
    assert "/new --yes" in reply
@pytest.mark.asyncio
async def test_reset_alias_confirmation_mentions_reset():
    """The hint for /reset names the alias the user typed."""
    gateway = _make_runner()

    reply = await gateway._handle_reset_command(_make_event("/reset"))

    gateway.session_store.reset_session.assert_not_called()
    assert "/reset --yes" in reply
@pytest.mark.asyncio
async def test_undo_requires_confirmation_before_rewriting_transcript():
    """A bare /undo replies with the re-run hint and leaves the transcript alone."""
    transcript = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi"},
    ]
    gateway = _make_runner()
    gateway.session_store.load_transcript.return_value = transcript

    reply = await gateway._handle_undo_command(_make_event("/undo"))

    gateway.session_store.rewrite_transcript.assert_not_called()
    assert "/undo --yes" in reply
@pytest.mark.asyncio
async def test_undo_with_confirmation_rewrites_transcript():
    """/undo --yes rewrites the transcript without the last exchange."""
    kept = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi"},
    ]
    undone = [
        {"role": "user", "content": "undo this"},
        {"role": "assistant", "content": "ok"},
    ]
    gateway = _make_runner()
    gateway.session_store.load_transcript.return_value = kept + undone

    reply = await gateway._handle_undo_command(_make_event("/undo --yes"))

    gateway.session_store.rewrite_transcript.assert_called_once_with("sess-old", kept)
    assert "Undid 2 message(s)." in reply

View file

@ -79,7 +79,7 @@ async def test_reset_fires_finalize_hook(mock_invoke_hook):
"""/new must fire on_session_finalize with the OLD session id."""
runner = _make_runner()
await runner._handle_reset_command(_make_event("/new"))
await runner._handle_reset_command(_make_event("/new --yes"))
mock_invoke_hook.assert_any_call(
"on_session_finalize", session_id="sess-old", platform="telegram"
@ -92,7 +92,7 @@ async def test_reset_fires_reset_hook(mock_invoke_hook):
"""/new must fire on_session_reset with the NEW session id."""
runner = _make_runner()
await runner._handle_reset_command(_make_event("/new"))
await runner._handle_reset_command(_make_event("/new --yes"))
mock_invoke_hook.assert_any_call(
"on_session_reset", session_id="sess-new", platform="telegram"
@ -105,7 +105,7 @@ async def test_finalize_before_reset(mock_invoke_hook):
"""on_session_finalize must fire before on_session_reset."""
runner = _make_runner()
await runner._handle_reset_command(_make_event("/new"))
await runner._handle_reset_command(_make_event("/new --yes"))
calls = [c for c in mock_invoke_hook.call_args_list
if c[0][0] in ("on_session_finalize", "on_session_reset")]
@ -162,7 +162,7 @@ async def test_hook_error_does_not_break_reset(mock_invoke_hook):
"""Plugin hook errors must not prevent /new from completing."""
runner = _make_runner()
result = await runner._handle_reset_command(_make_event("/new"))
result = await runner._handle_reset_command(_make_event("/new --yes"))
# Should still return a success message despite hook errors
assert "Session reset" in result or "New session" in result

View file

@ -80,7 +80,7 @@ async def test_new_command_clears_session_model_override():
"api_mode": "openai",
}
await runner._handle_reset_command(_make_event("/new"))
await runner._handle_reset_command(_make_event("/new --yes"))
assert session_key not in runner._session_model_overrides
@ -93,7 +93,7 @@ async def test_new_command_no_override_is_noop():
assert session_key not in runner._session_model_overrides
await runner._handle_reset_command(_make_event("/new"))
await runner._handle_reset_command(_make_event("/new --yes"))
assert session_key not in runner._session_model_overrides
@ -120,7 +120,7 @@ async def test_new_command_only_clears_own_session():
"api_mode": "anthropic",
}
await runner._handle_reset_command(_make_event("/new"))
await runner._handle_reset_command(_make_event("/new --yes"))
assert session_key not in runner._session_model_overrides
assert other_key in runner._session_model_overrides