mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Add confirmation for destructive slash commands
This commit is contained in:
parent
d1d425e9d0
commit
064a2882f8
7 changed files with 289 additions and 16 deletions
17
cli.py
17
cli.py
|
|
@ -5277,10 +5277,16 @@ class HermesCLI:
|
|||
|
||||
# Resolve aliases via central registry so adding an alias is a one-line
|
||||
# change in hermes_cli/commands.py instead of touching every dispatch site.
|
||||
from hermes_cli.commands import resolve_command as _resolve_cmd
|
||||
from hermes_cli.commands import (
|
||||
destructive_command_confirmation_message as _destructive_command_msg,
|
||||
destructive_command_is_confirmed as _destructive_command_is_confirmed,
|
||||
resolve_command as _resolve_cmd,
|
||||
)
|
||||
_base_word = cmd_lower.split()[0].lstrip("/")
|
||||
_cmd_def = _resolve_cmd(_base_word)
|
||||
canonical = _cmd_def.name if _cmd_def else _base_word
|
||||
_typed_command = cmd_original.split()[0].lstrip("/").lower()
|
||||
_args = cmd_original.split(maxsplit=1)[1] if " " in cmd_original else ""
|
||||
|
||||
if canonical in ("quit", "exit", "q"):
|
||||
return False
|
||||
|
|
@ -5295,6 +5301,9 @@ class HermesCLI:
|
|||
elif canonical == "config":
|
||||
self.show_config()
|
||||
elif canonical == "clear":
|
||||
if not _destructive_command_is_confirmed(_args):
|
||||
_cprint(_destructive_command_msg("clear", _typed_command))
|
||||
return True
|
||||
self.new_session(silent=True)
|
||||
# Clear terminal screen. Inside the TUI, Rich's console.clear()
|
||||
# goes through patch_stdout's StdoutProxy which swallows the
|
||||
|
|
@ -5412,6 +5421,9 @@ class HermesCLI:
|
|||
else:
|
||||
_cprint(" Session database not available.")
|
||||
elif canonical == "new":
|
||||
if not _destructive_command_is_confirmed(_args):
|
||||
_cprint(_destructive_command_msg("new", _typed_command))
|
||||
return True
|
||||
self.new_session()
|
||||
elif canonical == "resume":
|
||||
self._handle_resume_command(cmd_original)
|
||||
|
|
@ -5431,6 +5443,9 @@ class HermesCLI:
|
|||
# Re-queue the message so process_loop sends it to the agent
|
||||
self._pending_input.put(retry_msg)
|
||||
elif canonical == "undo":
|
||||
if not _destructive_command_is_confirmed(_args):
|
||||
_cprint(_destructive_command_msg("undo", _typed_command))
|
||||
return True
|
||||
self.undo_last()
|
||||
elif canonical == "branch":
|
||||
self._handle_branch_command(cmd_original)
|
||||
|
|
|
|||
|
|
@ -4259,6 +4259,15 @@ class GatewayRunner:
|
|||
|
||||
async def _handle_reset_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /new or /reset command."""
|
||||
from hermes_cli.commands import (
|
||||
destructive_command_confirmation_message,
|
||||
destructive_command_is_confirmed,
|
||||
)
|
||||
|
||||
typed_command = event.get_command() or "new"
|
||||
if not destructive_command_is_confirmed(event.get_command_args()):
|
||||
return destructive_command_confirmation_message("new", typed_command)
|
||||
|
||||
source = event.source
|
||||
|
||||
# Get existing session key
|
||||
|
|
@ -5079,6 +5088,11 @@ class GatewayRunner:
|
|||
|
||||
async def _handle_undo_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /undo command - remove the last user/assistant exchange."""
|
||||
from hermes_cli.commands import (
|
||||
destructive_command_confirmation_message,
|
||||
destructive_command_is_confirmed,
|
||||
)
|
||||
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
|
|
@ -5092,7 +5106,11 @@ class GatewayRunner:
|
|||
|
||||
if last_user_idx is None:
|
||||
return "Nothing to undo."
|
||||
|
||||
|
||||
if not destructive_command_is_confirmed(event.get_command_args()):
|
||||
typed_command = event.get_command() or "undo"
|
||||
return destructive_command_confirmation_message("undo", typed_command)
|
||||
|
||||
removed_msg = history[last_user_idx].get("content", "")
|
||||
removed_count = len(history) - last_user_idx
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
|
||||
|
|
|
|||
|
|
@ -59,15 +59,16 @@ class CommandDef:
|
|||
COMMAND_REGISTRY: list[CommandDef] = [
|
||||
# Session
|
||||
CommandDef("new", "Start a new session (fresh session ID + history)", "Session",
|
||||
aliases=("reset",)),
|
||||
aliases=("reset",), args_hint="[--yes]"),
|
||||
CommandDef("clear", "Clear screen and start a new session", "Session",
|
||||
cli_only=True),
|
||||
cli_only=True, args_hint="[--yes]"),
|
||||
CommandDef("history", "Show conversation history", "Session",
|
||||
cli_only=True),
|
||||
CommandDef("save", "Save the current conversation", "Session",
|
||||
cli_only=True),
|
||||
CommandDef("retry", "Retry the last message (resend to agent)", "Session"),
|
||||
CommandDef("undo", "Remove the last user/assistant exchange", "Session"),
|
||||
CommandDef("undo", "Remove the last user/assistant exchange", "Session",
|
||||
args_hint="[--yes]"),
|
||||
CommandDef("title", "Set a title for the current session", "Session",
|
||||
args_hint="[name]"),
|
||||
CommandDef("branch", "Branch the current session (explore a different path)", "Session",
|
||||
|
|
@ -184,6 +185,18 @@ def _build_command_lookup() -> dict[str, CommandDef]:
|
|||
|
||||
_COMMAND_LOOKUP: dict[str, CommandDef] = _build_command_lookup()
|
||||
|
||||
_DESTRUCTIVE_COMMAND_ACTIONS: dict[str, str] = {
|
||||
"new": "discard the current conversation history and start a fresh session",
|
||||
"clear": "clear the screen and discard the current conversation history",
|
||||
"undo": "remove the last user/assistant exchange from this session",
|
||||
}
|
||||
_DESTRUCTIVE_CONFIRM_TOKENS: frozenset[str] = frozenset({
|
||||
"--yes",
|
||||
"--confirm",
|
||||
"yes",
|
||||
"confirm",
|
||||
})
|
||||
|
||||
|
||||
def resolve_command(name: str) -> CommandDef | None:
|
||||
"""Resolve a command name or alias to its CommandDef.
|
||||
|
|
@ -193,6 +206,26 @@ def resolve_command(name: str) -> CommandDef | None:
|
|||
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
|
||||
|
||||
|
||||
def destructive_command_is_confirmed(args: str) -> bool:
|
||||
"""Return True when the command args include an explicit confirmation token."""
|
||||
return any(token.lower() in _DESTRUCTIVE_CONFIRM_TOKENS for token in args.split())
|
||||
|
||||
|
||||
def destructive_command_confirmation_message(
|
||||
canonical_name: str,
|
||||
typed_name: str | None = None,
|
||||
) -> str:
|
||||
"""Build a consistent warning for destructive session commands."""
|
||||
action = _DESTRUCTIVE_COMMAND_ACTIONS.get(canonical_name)
|
||||
if not action:
|
||||
raise KeyError(f"Unknown destructive command: {canonical_name}")
|
||||
command_name = (typed_name or canonical_name).lstrip("/").lower()
|
||||
return (
|
||||
f"Confirmation required: this will {action}. "
|
||||
f"Re-run `/{command_name} --yes` to continue."
|
||||
)
|
||||
|
||||
|
||||
def _build_description(cmd: CommandDef) -> str:
|
||||
"""Build a CLI-facing description string including usage hint."""
|
||||
if cmd.args_hint:
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class _FakeAgent:
|
|||
[{"id": "t1", "content": "unfinished task", "status": "in_progress"}]
|
||||
)
|
||||
self.flush_memories = MagicMock()
|
||||
self.commit_memory_session = MagicMock()
|
||||
self._invalidate_system_prompt = MagicMock()
|
||||
|
||||
# Token counters (non-zero to verify reset)
|
||||
|
|
@ -116,6 +117,7 @@ def _make_cli(env_overrides=None, config_overrides=None, **kwargs):
|
|||
with patch.object(_cli_mod, "get_tool_definitions", return_value=[]), patch.dict(
|
||||
_cli_mod.__dict__, {"CLI_CONFIG": _clean_config}
|
||||
):
|
||||
_cli_mod._cprint = MagicMock()
|
||||
return _cli_mod.HermesCLI(**kwargs)
|
||||
|
||||
|
||||
|
|
@ -138,7 +140,7 @@ def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path)
|
|||
old_session_id = cli.session_id
|
||||
old_session_start = cli.session_start
|
||||
|
||||
cli.process_command("/new")
|
||||
cli.process_command("/new --yes")
|
||||
|
||||
assert cli.session_id != old_session_id
|
||||
|
||||
|
|
@ -164,7 +166,7 @@ def test_reset_command_is_alias_for_new_session(tmp_path):
|
|||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
old_session_id = cli.session_id
|
||||
|
||||
cli.process_command("/reset")
|
||||
cli.process_command("/reset --yes")
|
||||
|
||||
assert cli.session_id != old_session_id
|
||||
assert cli._session_db.get_session(old_session_id)["end_reason"] == "new_session"
|
||||
|
|
@ -177,7 +179,7 @@ def test_clear_command_starts_new_session_before_redrawing(tmp_path):
|
|||
cli.show_banner = MagicMock()
|
||||
|
||||
old_session_id = cli.session_id
|
||||
cli.process_command("/clear")
|
||||
cli.process_command("/clear --yes")
|
||||
|
||||
assert cli.session_id != old_session_id
|
||||
assert cli._session_db.get_session(old_session_id)["end_reason"] == "new_session"
|
||||
|
|
@ -197,7 +199,7 @@ def test_new_session_resets_token_counters(tmp_path):
|
|||
assert agent.session_api_calls > 0
|
||||
assert agent.context_compressor.compression_count > 0
|
||||
|
||||
cli.process_command("/new")
|
||||
cli.process_command("/new --yes")
|
||||
|
||||
# All agent token counters must be zero
|
||||
assert agent.session_total_tokens == 0
|
||||
|
|
@ -220,3 +222,79 @@ def test_new_session_resets_token_counters(tmp_path):
|
|||
assert comp.last_total_tokens == 0
|
||||
assert comp.compression_count == 0
|
||||
assert comp._context_probed is False
|
||||
|
||||
|
||||
def test_new_command_requires_confirmation_before_resetting(tmp_path):
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
old_session_id = cli.session_id
|
||||
|
||||
cli.process_command("/new")
|
||||
|
||||
assert cli.session_id == old_session_id
|
||||
assert cli._session_db.get_session(old_session_id) is not None
|
||||
assert cli.conversation_history == [{"role": "user", "content": "hello"}]
|
||||
|
||||
|
||||
def test_reset_without_confirmation_does_not_reset_session(tmp_path):
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
old_session_id = cli.session_id
|
||||
|
||||
cli.process_command("/reset")
|
||||
|
||||
assert cli.session_id == old_session_id
|
||||
|
||||
|
||||
def test_reset_alias_warning_uses_typed_command():
|
||||
from hermes_cli.commands import destructive_command_confirmation_message
|
||||
|
||||
assert "/reset --yes" in destructive_command_confirmation_message("new", "reset")
|
||||
|
||||
|
||||
def test_clear_requires_confirmation_before_redrawing(tmp_path):
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
cli.console = MagicMock()
|
||||
cli.show_banner = MagicMock()
|
||||
old_session_id = cli.session_id
|
||||
|
||||
cli.process_command("/clear")
|
||||
|
||||
assert cli.session_id == old_session_id
|
||||
assert cli.conversation_history == [{"role": "user", "content": "hello"}]
|
||||
cli.console.clear.assert_not_called()
|
||||
cli.show_banner.assert_not_called()
|
||||
|
||||
|
||||
def test_undo_requires_confirmation_before_mutating_history():
|
||||
cli = _make_cli()
|
||||
cli.conversation_history = [
|
||||
{"role": "user", "content": "first prompt"},
|
||||
{"role": "assistant", "content": "first reply"},
|
||||
{"role": "user", "content": "second prompt"},
|
||||
{"role": "assistant", "content": "second reply"},
|
||||
]
|
||||
|
||||
cli.process_command("/undo")
|
||||
|
||||
assert [msg["content"] for msg in cli.conversation_history] == [
|
||||
"first prompt",
|
||||
"first reply",
|
||||
"second prompt",
|
||||
"second reply",
|
||||
]
|
||||
|
||||
|
||||
def test_undo_with_confirmation_removes_last_exchange():
|
||||
cli = _make_cli()
|
||||
cli.conversation_history = [
|
||||
{"role": "user", "content": "first prompt"},
|
||||
{"role": "assistant", "content": "first reply"},
|
||||
{"role": "user", "content": "second prompt"},
|
||||
{"role": "assistant", "content": "second reply"},
|
||||
]
|
||||
|
||||
cli.process_command("/undo --yes")
|
||||
|
||||
assert [msg["content"] for msg in cli.conversation_history] == [
|
||||
"first prompt",
|
||||
"first reply",
|
||||
]
|
||||
|
|
|
|||
129
tests/gateway/test_destructive_command_confirmation.py
Normal file
129
tests/gateway/test_destructive_command_confirmation.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Tests confirmation guards for destructive gateway session commands."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(text=text, source=_make_source(), message_id="m1")
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner._session_model_overrides = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._background_tasks = set()
|
||||
|
||||
session_key = build_session_key(_make_source())
|
||||
session_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id="sess-old",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
new_session_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id="sess-new",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.reset_session.return_value = new_session_entry
|
||||
runner.session_store._entries = {session_key: session_entry}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._agent_cache_lock = None
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._format_session_info = lambda: ""
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requires_confirmation_before_reset():
|
||||
runner = _make_runner()
|
||||
|
||||
result = await runner._handle_reset_command(_make_event("/new"))
|
||||
|
||||
assert "/new --yes" in result
|
||||
runner.session_store.reset_session.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_alias_confirmation_mentions_reset():
|
||||
runner = _make_runner()
|
||||
|
||||
result = await runner._handle_reset_command(_make_event("/reset"))
|
||||
|
||||
assert "/reset --yes" in result
|
||||
runner.session_store.reset_session.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_undo_requires_confirmation_before_rewriting_transcript():
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
|
||||
result = await runner._handle_undo_command(_make_event("/undo"))
|
||||
|
||||
assert "/undo --yes" in result
|
||||
runner.session_store.rewrite_transcript.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_undo_with_confirmation_rewrites_transcript():
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
{"role": "user", "content": "undo this"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
]
|
||||
|
||||
result = await runner._handle_undo_command(_make_event("/undo --yes"))
|
||||
|
||||
runner.session_store.rewrite_transcript.assert_called_once_with(
|
||||
"sess-old",
|
||||
[
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
],
|
||||
)
|
||||
assert "Undid 2 message(s)." in result
|
||||
|
|
@ -79,7 +79,7 @@ async def test_reset_fires_finalize_hook(mock_invoke_hook):
|
|||
"""/new must fire on_session_finalize with the OLD session id."""
|
||||
runner = _make_runner()
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_finalize", session_id="sess-old", platform="telegram"
|
||||
|
|
@ -92,7 +92,7 @@ async def test_reset_fires_reset_hook(mock_invoke_hook):
|
|||
"""/new must fire on_session_reset with the NEW session id."""
|
||||
runner = _make_runner()
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_reset", session_id="sess-new", platform="telegram"
|
||||
|
|
@ -105,7 +105,7 @@ async def test_finalize_before_reset(mock_invoke_hook):
|
|||
"""on_session_finalize must fire before on_session_reset."""
|
||||
runner = _make_runner()
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
calls = [c for c in mock_invoke_hook.call_args_list
|
||||
if c[0][0] in ("on_session_finalize", "on_session_reset")]
|
||||
|
|
@ -162,7 +162,7 @@ async def test_hook_error_does_not_break_reset(mock_invoke_hook):
|
|||
"""Plugin hook errors must not prevent /new from completing."""
|
||||
runner = _make_runner()
|
||||
|
||||
result = await runner._handle_reset_command(_make_event("/new"))
|
||||
result = await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
# Should still return a success message despite hook errors
|
||||
assert "Session reset" in result or "New session" in result
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ async def test_new_command_clears_session_model_override():
|
|||
"api_mode": "openai",
|
||||
}
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
assert session_key not in runner._session_model_overrides
|
||||
|
||||
|
|
@ -93,7 +93,7 @@ async def test_new_command_no_override_is_noop():
|
|||
|
||||
assert session_key not in runner._session_model_overrides
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
assert session_key not in runner._session_model_overrides
|
||||
|
||||
|
|
@ -120,7 +120,7 @@ async def test_new_command_only_clears_own_session():
|
|||
"api_mode": "anthropic",
|
||||
}
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
await runner._handle_reset_command(_make_event("/new --yes"))
|
||||
|
||||
assert session_key not in runner._session_model_overrides
|
||||
assert other_key in runner._session_model_overrides
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue