diff --git a/cli.py b/cli.py index fbc8f85250..f47427ff4f 100644 --- a/cli.py +++ b/cli.py @@ -5277,10 +5277,16 @@ class HermesCLI: # Resolve aliases via central registry so adding an alias is a one-line # change in hermes_cli/commands.py instead of touching every dispatch site. - from hermes_cli.commands import resolve_command as _resolve_cmd + from hermes_cli.commands import ( + destructive_command_confirmation_message as _destructive_command_msg, + destructive_command_is_confirmed as _destructive_command_is_confirmed, + resolve_command as _resolve_cmd, + ) _base_word = cmd_lower.split()[0].lstrip("/") _cmd_def = _resolve_cmd(_base_word) canonical = _cmd_def.name if _cmd_def else _base_word + _typed_command = cmd_original.split()[0].lstrip("/").lower() + _args = cmd_original.split(maxsplit=1)[1] if " " in cmd_original else "" if canonical in ("quit", "exit", "q"): return False @@ -5295,6 +5301,9 @@ class HermesCLI: elif canonical == "config": self.show_config() elif canonical == "clear": + if not _destructive_command_is_confirmed(_args): + _cprint(_destructive_command_msg("clear", _typed_command)) + return True self.new_session(silent=True) # Clear terminal screen. Inside the TUI, Rich's console.clear() # goes through patch_stdout's StdoutProxy which swallows the @@ -5412,6 +5421,9 @@ class HermesCLI: else: _cprint(" Session database not available.") elif canonical == "new": + if not _destructive_command_is_confirmed(_args): + _cprint(_destructive_command_msg("new", _typed_command)) + return True self.new_session() elif canonical == "resume": self._handle_resume_command(cmd_original) @@ -5431,6 +5443,9 @@ class HermesCLI: # Re-queue the message so process_loop sends it to the agent self._pending_input.put(retry_msg) elif canonical == "undo": + if not _destructive_command_is_confirmed(_args): + _cprint(_destructive_command_msg("undo", _typed_command)) + return True self.undo_last() elif canonical == "branch": self._handle_branch_command(cmd_original) diff --git a/gateway/run.py b/gateway/run.py index 327f8ae32a..fe67fd523f 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -4259,6 +4259,15 @@ class GatewayRunner: async def _handle_reset_command(self, event: MessageEvent) -> str: """Handle /new or /reset command.""" + from hermes_cli.commands import ( + destructive_command_confirmation_message, + destructive_command_is_confirmed, + ) + + typed_command = event.get_command() or "new" + if not destructive_command_is_confirmed(event.get_command_args()): + return destructive_command_confirmation_message("new", typed_command) + source = event.source # Get existing session key @@ -5079,6 +5088,11 @@ class GatewayRunner: async def _handle_undo_command(self, event: MessageEvent) -> str: """Handle /undo command - remove the last user/assistant exchange.""" + from hermes_cli.commands import ( + destructive_command_confirmation_message, + destructive_command_is_confirmed, + ) + source = event.source session_entry = self.session_store.get_or_create_session(source) history = self.session_store.load_transcript(session_entry.session_id) @@ -5092,7 +5106,11 @@ class GatewayRunner: if last_user_idx is None: return "Nothing to undo." - + + if not destructive_command_is_confirmed(event.get_command_args()): + typed_command = event.get_command() or "undo" + return destructive_command_confirmation_message("undo", typed_command) + removed_msg = history[last_user_idx].get("content", "") removed_count = len(history) - last_user_idx self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx]) diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 516392bd1d..90ffdb4609 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -59,15 +59,16 @@ class CommandDef: COMMAND_REGISTRY: list[CommandDef] = [ # Session CommandDef("new", "Start a new session (fresh session ID + history)", "Session", - aliases=("reset",)), + aliases=("reset",), args_hint="[--yes]"), CommandDef("clear", "Clear screen and start a new session", "Session", - cli_only=True), + cli_only=True, args_hint="[--yes]"), CommandDef("history", "Show conversation history", "Session", cli_only=True), CommandDef("save", "Save the current conversation", "Session", cli_only=True), CommandDef("retry", "Retry the last message (resend to agent)", "Session"), - CommandDef("undo", "Remove the last user/assistant exchange", "Session"), + CommandDef("undo", "Remove the last user/assistant exchange", "Session", + args_hint="[--yes]"), CommandDef("title", "Set a title for the current session", "Session", args_hint="[name]"), CommandDef("branch", "Branch the current session (explore a different path)", "Session", @@ -184,6 +185,18 @@ def _build_command_lookup() -> dict[str, CommandDef]: _COMMAND_LOOKUP: dict[str, CommandDef] = _build_command_lookup() +_DESTRUCTIVE_COMMAND_ACTIONS: dict[str, str] = { + "new": "discard the current conversation history and start a fresh session", + "clear": "clear the screen and discard the current conversation history", + "undo": "remove the last user/assistant exchange from this session", +} +_DESTRUCTIVE_CONFIRM_TOKENS: frozenset[str] = frozenset({ + "--yes", + "--confirm", + "yes", + "confirm", +}) + def resolve_command(name: str) -> CommandDef | None: """Resolve a command name or alias to its CommandDef. @@ -193,6 +206,26 @@ def resolve_command(name: str) -> CommandDef | None: return _COMMAND_LOOKUP.get(name.lower().lstrip("/")) +def destructive_command_is_confirmed(args: str) -> bool: + """Return True when the command args include an explicit confirmation token.""" + return any(token.lower() in _DESTRUCTIVE_CONFIRM_TOKENS for token in args.split()) + + +def destructive_command_confirmation_message( + canonical_name: str, + typed_name: str | None = None, +) -> str: + """Build a consistent warning for destructive session commands.""" + action = _DESTRUCTIVE_COMMAND_ACTIONS.get(canonical_name) + if not action: + raise KeyError(f"Unknown destructive command: {canonical_name}") + command_name = (typed_name or canonical_name).lstrip("/").lower() + return ( + f"Confirmation required: this will {action}. " + f"Re-run `/{command_name} --yes` to continue." + ) + + def _build_description(cmd: CommandDef) -> str: """Build a CLI-facing description string including usage hint.""" if cmd.args_hint: diff --git a/tests/cli/test_cli_new_session.py b/tests/cli/test_cli_new_session.py index 0490aad9ce..0948004bc5 100644 --- a/tests/cli/test_cli_new_session.py +++ b/tests/cli/test_cli_new_session.py @@ -34,6 +34,7 @@ class _FakeAgent: [{"id": "t1", "content": "unfinished task", "status": "in_progress"}] ) self.flush_memories = MagicMock() + self.commit_memory_session = MagicMock() self._invalidate_system_prompt = MagicMock() # Token counters (non-zero to verify reset) @@ -116,6 +117,7 @@ def _make_cli(env_overrides=None, config_overrides=None, **kwargs): with patch.object(_cli_mod, "get_tool_definitions", return_value=[]), patch.dict( _cli_mod.__dict__, {"CLI_CONFIG": _clean_config} ): + _cli_mod._cprint = MagicMock() return _cli_mod.HermesCLI(**kwargs) @@ -138,7 +140,7 @@ def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path) old_session_id = cli.session_id old_session_start = cli.session_start - cli.process_command("/new") + cli.process_command("/new --yes") assert cli.session_id != old_session_id @@ -164,7 +166,7 @@ def test_reset_command_is_alias_for_new_session(tmp_path): cli = _prepare_cli_with_active_session(tmp_path) old_session_id = cli.session_id - cli.process_command("/reset") + cli.process_command("/reset --yes") assert cli.session_id != old_session_id assert cli._session_db.get_session(old_session_id)["end_reason"] == "new_session" @@ -177,7 +179,7 @@ def test_clear_command_starts_new_session_before_redrawing(tmp_path): cli.show_banner = MagicMock() old_session_id = cli.session_id - cli.process_command("/clear") + cli.process_command("/clear --yes") assert cli.session_id != old_session_id assert cli._session_db.get_session(old_session_id)["end_reason"] == "new_session" @@ -197,7 +199,7 @@ def test_new_session_resets_token_counters(tmp_path): assert agent.session_api_calls > 0 assert agent.context_compressor.compression_count > 0 - cli.process_command("/new") + cli.process_command("/new --yes") # All agent token counters must be zero assert agent.session_total_tokens == 0 @@ -220,3 +222,79 @@ def test_new_session_resets_token_counters(tmp_path): assert comp.last_total_tokens == 0 assert comp.compression_count == 0 assert comp._context_probed is False + + +def test_new_command_requires_confirmation_before_resetting(tmp_path): + cli = _prepare_cli_with_active_session(tmp_path) + old_session_id = cli.session_id + + cli.process_command("/new") + + assert cli.session_id == old_session_id + assert cli._session_db.get_session(old_session_id) is not None + assert cli.conversation_history == [{"role": "user", "content": "hello"}] + + +def test_reset_without_confirmation_does_not_reset_session(tmp_path): + cli = _prepare_cli_with_active_session(tmp_path) + old_session_id = cli.session_id + + cli.process_command("/reset") + + assert cli.session_id == old_session_id + + +def test_reset_alias_warning_uses_typed_command(): + from hermes_cli.commands import destructive_command_confirmation_message + + assert "/reset --yes" in destructive_command_confirmation_message("new", "reset") + + +def test_clear_requires_confirmation_before_redrawing(tmp_path): + cli = _prepare_cli_with_active_session(tmp_path) + cli.console = MagicMock() + cli.show_banner = MagicMock() + old_session_id = cli.session_id + + cli.process_command("/clear") + + assert cli.session_id == old_session_id + assert cli.conversation_history == [{"role": "user", "content": "hello"}] + cli.console.clear.assert_not_called() + cli.show_banner.assert_not_called() + + +def test_undo_requires_confirmation_before_mutating_history(): + cli = _make_cli() + cli.conversation_history = [ + {"role": "user", "content": "first prompt"}, + {"role": "assistant", "content": "first reply"}, + {"role": "user", "content": "second prompt"}, + {"role": "assistant", "content": "second reply"}, + ] + + cli.process_command("/undo") + + assert [msg["content"] for msg in cli.conversation_history] == [ + "first prompt", + "first reply", + "second prompt", + "second reply", + ] + + +def test_undo_with_confirmation_removes_last_exchange(): + cli = _make_cli() + cli.conversation_history = [ + {"role": "user", "content": "first prompt"}, + {"role": "assistant", "content": "first reply"}, + {"role": "user", "content": "second prompt"}, + {"role": "assistant", "content": "second reply"}, + ] + + cli.process_command("/undo --yes") + + assert [msg["content"] for msg in cli.conversation_history] == [ + "first prompt", + "first reply", + ] diff --git a/tests/gateway/test_destructive_command_confirmation.py b/tests/gateway/test_destructive_command_confirmation.py new file mode 100644 index 0000000000..01803f362a --- /dev/null +++ b/tests/gateway/test_destructive_command_confirmation.py @@ -0,0 +1,129 @@ +"""Tests confirmation guards for destructive gateway session commands.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import MessageEvent +from gateway.session import SessionEntry, SessionSource, build_session_key + + +def _make_source() -> SessionSource: + return SessionSource( + platform=Platform.TELEGRAM, + user_id="u1", + chat_id="c1", + user_name="tester", + chat_type="dm", + ) + + +def _make_event(text: str) -> MessageEvent: + return MessageEvent(text=text, source=_make_source(), message_id="m1") + + +def _make_runner(): + from gateway.run import GatewayRunner + + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")} + ) + adapter = MagicMock() + adapter.send = AsyncMock() + runner.adapters = {Platform.TELEGRAM: adapter} + runner._voice_mode = {} + runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False) + runner._session_model_overrides = {} + runner._pending_model_notes = {} + runner._background_tasks = set() + + session_key = build_session_key(_make_source()) + session_entry = SessionEntry( + session_key=session_key, + session_id="sess-old", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + new_session_entry = SessionEntry( + session_key=session_key, + session_id="sess-new", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + runner.session_store = MagicMock() + runner.session_store.get_or_create_session.return_value = session_entry + runner.session_store.reset_session.return_value = new_session_entry + runner.session_store._entries = {session_key: session_entry} + runner._running_agents = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._session_db = None + runner._agent_cache_lock = None + runner._is_user_authorized = lambda _source: True + runner._format_session_info = lambda: "" + + return runner + + +@pytest.mark.asyncio +async def test_new_requires_confirmation_before_reset(): + runner = _make_runner() + + result = await runner._handle_reset_command(_make_event("/new")) + + assert "/new --yes" in result + runner.session_store.reset_session.assert_not_called() + + +@pytest.mark.asyncio +async def test_reset_alias_confirmation_mentions_reset(): + runner = _make_runner() + + result = await runner._handle_reset_command(_make_event("/reset")) + + assert "/reset --yes" in result + runner.session_store.reset_session.assert_not_called() + + +@pytest.mark.asyncio +async def test_undo_requires_confirmation_before_rewriting_transcript(): + runner = _make_runner() + runner.session_store.load_transcript.return_value = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + + result = await runner._handle_undo_command(_make_event("/undo")) + + assert "/undo --yes" in result + runner.session_store.rewrite_transcript.assert_not_called() + + +@pytest.mark.asyncio +async def test_undo_with_confirmation_rewrites_transcript(): + runner = _make_runner() + runner.session_store.load_transcript.return_value = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + {"role": "user", "content": "undo this"}, + {"role": "assistant", "content": "ok"}, + ] + + result = await runner._handle_undo_command(_make_event("/undo --yes")) + + runner.session_store.rewrite_transcript.assert_called_once_with( + "sess-old", + [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ], + ) + assert "Undid 2 message(s)." in result diff --git a/tests/gateway/test_session_boundary_hooks.py b/tests/gateway/test_session_boundary_hooks.py index a556624363..76328265aa 100644 --- a/tests/gateway/test_session_boundary_hooks.py +++ b/tests/gateway/test_session_boundary_hooks.py @@ -79,7 +79,7 @@ async def test_reset_fires_finalize_hook(mock_invoke_hook): """/new must fire on_session_finalize with the OLD session id.""" runner = _make_runner() - await runner._handle_reset_command(_make_event("/new")) + await runner._handle_reset_command(_make_event("/new --yes")) mock_invoke_hook.assert_any_call( "on_session_finalize", session_id="sess-old", platform="telegram" @@ -92,7 +92,7 @@ async def test_reset_fires_reset_hook(mock_invoke_hook): """/new must fire on_session_reset with the NEW session id.""" runner = _make_runner() - await runner._handle_reset_command(_make_event("/new")) + await runner._handle_reset_command(_make_event("/new --yes")) mock_invoke_hook.assert_any_call( "on_session_reset", session_id="sess-new", platform="telegram" @@ -105,7 +105,7 @@ async def test_finalize_before_reset(mock_invoke_hook): """on_session_finalize must fire before on_session_reset.""" runner = _make_runner() - await runner._handle_reset_command(_make_event("/new")) + await runner._handle_reset_command(_make_event("/new --yes")) calls = [c for c in mock_invoke_hook.call_args_list if c[0][0] in ("on_session_finalize", "on_session_reset")] @@ -162,7 +162,7 @@ async def test_hook_error_does_not_break_reset(mock_invoke_hook): """Plugin hook errors must not prevent /new from completing.""" runner = _make_runner() - result = await runner._handle_reset_command(_make_event("/new")) + result = await runner._handle_reset_command(_make_event("/new --yes")) # Should still return a success message despite hook errors assert "Session reset" in result or "New session" in result diff --git a/tests/gateway/test_session_model_reset.py b/tests/gateway/test_session_model_reset.py index 6529f3a11d..9be813a296 100644 --- a/tests/gateway/test_session_model_reset.py +++ b/tests/gateway/test_session_model_reset.py @@ -80,7 +80,7 @@ async def test_new_command_clears_session_model_override(): "api_mode": "openai", } - await runner._handle_reset_command(_make_event("/new")) + await runner._handle_reset_command(_make_event("/new --yes")) assert session_key not in runner._session_model_overrides @@ -93,7 +93,7 @@ async def test_new_command_no_override_is_noop(): assert session_key not in runner._session_model_overrides - await runner._handle_reset_command(_make_event("/new")) + await runner._handle_reset_command(_make_event("/new --yes")) assert session_key not in runner._session_model_overrides @@ -120,7 +120,7 @@ async def test_new_command_only_clears_own_session(): "api_mode": "anthropic", } - await runner._handle_reset_command(_make_event("/new")) + await runner._handle_reset_command(_make_event("/new --yes")) assert session_key not in runner._session_model_overrides assert other_key in runner._session_model_overrides