def _clear_session_boundary_security_state(self, session_key: str) -> None:
    """Clear approval state that must not survive a real conversation switch.

    Removes this runner's pending-approval entry for ``session_key`` and
    then asks ``tools.approval.clear_session`` to clear its own state for
    the same key. The cleanup is best-effort: any failure is logged and
    swallowed so that it can never abort the session switch that
    triggered it.

    Args:
        session_key: Key of the session whose conversation is being
            switched. A falsy key makes the call a no-op.
    """
    if not session_key:
        return

    # The attribute may be absent or not a dict; only touch it when it is
    # a real dict so this helper never raises on a partially built runner.
    pending_approvals = getattr(self, "_pending_approvals", None)
    if isinstance(pending_approvals, dict):
        pending_approvals.pop(session_key, None)

    try:
        # Imported lazily and inside the guarded region: an import failure
        # must leave a trace just like a failing call does, instead of
        # silently skipping the security cleanup.
        from tools.approval import clear_session as _clear_approval_session

        _clear_approval_session(session_key)
    except Exception as e:
        logger.debug(
            "Failed to clear approval state for session boundary %s: %s",
            session_key,
            e,
        )
"""Regression tests for approval-state cleanup on session boundaries.

Both ``/resume`` and ``/branch`` switch the conversation behind a session
key. These tests check that the switch drops the approvals, YOLO flag and
pending approval stored for that key, while leaving another session key's
state untouched.
"""

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

import pytest

from gateway.config import Platform
from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key
from tools import approval as approval_mod
from tools.approval import (
    approve_session,
    enable_session_yolo,
    is_approved,
    is_session_yolo_enabled,
)

# Key of an unrelated session, used to prove the cleanup is scoped to the
# session that was actually switched.
_OTHER_SESSION_KEY = "agent:main:telegram:dm:other-chat"

# Approval pattern granted to both sessions before the switch.
_APPROVED_PATTERN = "recursive delete"


def _reset_approval_state() -> None:
    """Empty every module-level store of ``tools.approval`` used here."""
    for store in (
        approval_mod._gateway_queues,
        approval_mod._gateway_notify_cbs,
        approval_mod._session_approved,
        approval_mod._session_yolo,
        approval_mod._permanent_approved,
        approval_mod._pending,
    ):
        store.clear()


@pytest.fixture(autouse=True)
def _clear_approval_state():
    """Isolate each test from approval state left behind by other tests."""
    _reset_approval_state()
    yield
    _reset_approval_state()


def _make_source() -> SessionSource:
    """Return the Telegram DM source shared by every test in this module."""
    return SessionSource(
        platform=Platform.TELEGRAM,
        user_id="u1",
        chat_id="c1",
        user_name="tester",
        chat_type="dm",
    )


def _make_event(text: str) -> MessageEvent:
    """Wrap ``text`` in a message event coming from the shared source."""
    return MessageEvent(text=text, source=_make_source(), message_id="m1")


def _make_entry(session_id: str, source: SessionSource | None = None) -> SessionEntry:
    """Build a session entry for ``session_id`` keyed by ``source``.

    Args:
        session_id: Identifier of the conversation the entry points at.
        source: Origin of the session; defaults to the shared test source.
    """
    source = source or _make_source()
    # A single timestamp keeps created_at and updated_at equal.
    now = datetime.now()
    return SessionEntry(
        session_key=build_session_key(source),
        session_id=session_id,
        created_at=now,
        updated_at=now,
        origin=source,
        platform=source.platform,
        chat_type=source.chat_type,
    )


def _make_runner(switched_session_id: str, transcript: list):
    """Build a ``GatewayRunner`` stub with the setup both commands share.

    The runner is created without running ``__init__``; only the
    attributes assigned here (and by the command-specific factories) exist.

    Args:
        switched_session_id: Session id of the entry that the mocked
            ``switch_session`` returns.
        transcript: Value returned by the mocked ``load_transcript``.

    Returns:
        A ``(runner, session_key)`` tuple.
    """
    # NOTE(review): imported inside the helper as in the original factories,
    # presumably to defer loading gateway.run until a test needs it.
    from gateway.run import GatewayRunner

    source = _make_source()
    session_key = build_session_key(source)

    runner = object.__new__(GatewayRunner)
    runner.adapters = {}
    runner._running_agents = {}
    runner._running_agents_ts = {}
    runner._busy_ack_ts = {}
    runner._pending_approvals = {}
    runner._agent_cache_lock = None

    store = MagicMock()
    store.get_or_create_session.return_value = _make_entry("current-session", source)
    store.switch_session.return_value = _make_entry(switched_session_id, source)
    store.load_transcript.return_value = transcript
    runner.session_store = store

    runner._session_db = MagicMock()
    return runner, session_key


def _make_resume_runner():
    """Return a runner stub (and its session key) prepared for ``/resume``."""
    runner, session_key = _make_runner("resumed-session", [])
    runner._background_tasks = set()
    runner._async_flush_memories = AsyncMock()
    runner._session_db.resolve_session_by_title.return_value = "resumed-session"
    runner._session_db.get_session_title.return_value = "Resumed Work"
    return runner, session_key


def _make_branch_runner():
    """Return a runner stub (and its session key) prepared for ``/branch``."""
    runner, session_key = _make_runner(
        "branched-session",
        [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "world"},
        ],
    )
    runner.config = {}
    runner._session_db.get_session_title.return_value = "Current Work"
    runner._session_db.get_next_title_in_lineage.return_value = "Current Work #2"
    return runner, session_key


def _seed_security_state(runner, session_key: str) -> None:
    """Grant approvals, YOLO and a pending approval to both session keys."""
    approve_session(session_key, _APPROVED_PATTERN)
    approve_session(_OTHER_SESSION_KEY, _APPROVED_PATTERN)
    enable_session_yolo(session_key)
    enable_session_yolo(_OTHER_SESSION_KEY)
    runner._pending_approvals[session_key] = {"command": "rm -rf /tmp/demo"}
    runner._pending_approvals[_OTHER_SESSION_KEY] = {"command": "rm -rf /tmp/other"}


def _assert_only_switched_session_cleared(runner, session_key: str) -> None:
    """Check the switched key lost its state and the other key kept it."""
    assert is_approved(session_key, _APPROVED_PATTERN) is False
    assert is_session_yolo_enabled(session_key) is False
    assert session_key not in runner._pending_approvals
    assert is_approved(_OTHER_SESSION_KEY, _APPROVED_PATTERN) is True
    assert is_session_yolo_enabled(_OTHER_SESSION_KEY) is True
    assert _OTHER_SESSION_KEY in runner._pending_approvals


@pytest.mark.asyncio
async def test_resume_clears_session_scoped_approval_and_yolo_state():
    """``/resume`` must drop the switched session's approval state."""
    runner, session_key = _make_resume_runner()
    _seed_security_state(runner, session_key)

    result = await runner._handle_resume_command(_make_event("/resume Resumed Work"))

    assert "Resumed session" in result
    _assert_only_switched_session_cleared(runner, session_key)


@pytest.mark.asyncio
async def test_branch_clears_session_scoped_approval_and_yolo_state():
    """``/branch`` must drop the switched session's approval state."""
    runner, session_key = _make_branch_runner()
    _seed_security_state(runner, session_key)

    result = await runner._handle_branch_command(_make_event("/branch"))

    assert "Branched to" in result
    _assert_only_switched_session_cleared(runner, session_key)