diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 627829ca3..0cdf2a331 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -84,6 +84,9 @@ class SlackAdapter(BasePlatformAdapter): self._seen_messages: Dict[str, float] = {} self._SEEN_TTL = 300 # 5 minutes self._SEEN_MAX = 2000 # prune threshold + # Track pending approval message_ts → resolved flag to prevent + # double-clicks on approval buttons. + self._approval_resolved: Dict[str, bool] = {} async def connect(self) -> bool: """Connect to Slack via Socket Mode.""" @@ -176,6 +179,15 @@ class SlackAdapter(BasePlatformAdapter): await ack() await self._handle_slash_command(command) + # Register Block Kit action handlers for approval buttons + for _action_id in ( + "hermes_approve_once", + "hermes_approve_session", + "hermes_approve_always", + "hermes_deny", + ): + self._app.action(_action_id)(self._handle_approval_action) + # Start Socket Mode handler in background self._handler = AsyncSocketModeHandler(self._app, app_token) self._socket_mode_task = asyncio.create_task(self._handler.start_async()) @@ -791,6 +803,24 @@ class SlackAdapter(BasePlatformAdapter): # Strip the bot mention from the text text = text.replace(f"<@{bot_uid}>", "").strip() + # When first mentioned in an existing thread, fetch thread context + # so the agent understands the conversation it's joining. + event_thread_ts = event.get("thread_ts") + is_thread_reply = event_thread_ts and event_thread_ts != ts + if is_thread_reply and not self._has_active_session_for_thread( + channel_id=channel_id, + thread_ts=event_thread_ts, + user_id=user_id, + ): + thread_context = await self._fetch_thread_context( + channel_id=channel_id, + thread_ts=event_thread_ts, + current_ts=ts, + team_id=team_id, + ) + if thread_context: + text = thread_context + text + # Determine message type msg_type = MessageType.TEXT if text.startswith("/"): @@ -912,6 +942,233 @@ class SlackAdapter(BasePlatformAdapter): await self._remove_reaction(channel_id, ts, "eyes") await self._add_reaction(channel_id, ts, "white_check_mark") + # ----- Approval button support (Block Kit) ----- + + async def send_exec_approval( + self, chat_id: str, command: str, session_key: str, + description: str = "dangerous command", + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a Block Kit approval prompt with interactive buttons. + + The buttons call ``resolve_gateway_approval()`` to unblock the waiting + agent thread — same mechanism as the text ``/approve`` flow. + """ + if not self._app: + return SendResult(success=False, error="Not connected") + + try: + cmd_preview = command[:2900] + "..." if len(command) > 2900 else command + thread_ts = self._resolve_thread_ts(None, metadata) + + blocks = [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": ( + f":warning: *Command Approval Required*\n" + f"```{cmd_preview}```\n" + f"Reason: {description}" + ), + }, + }, + { + "type": "actions", + "elements": [ + { + "type": "button", + "text": {"type": "plain_text", "text": "Allow Once"}, + "style": "primary", + "action_id": "hermes_approve_once", + "value": session_key, + }, + { + "type": "button", + "text": {"type": "plain_text", "text": "Allow Session"}, + "action_id": "hermes_approve_session", + "value": session_key, + }, + { + "type": "button", + "text": {"type": "plain_text", "text": "Always Allow"}, + "action_id": "hermes_approve_always", + "value": session_key, + }, + { + "type": "button", + "text": {"type": "plain_text", "text": "Deny"}, + "style": "danger", + "action_id": "hermes_deny", + "value": session_key, + }, + ], + }, + ] + + kwargs: Dict[str, Any] = { + "channel": chat_id, + "text": f"⚠️ Command approval required: {cmd_preview[:100]}", + "blocks": blocks, + } + if thread_ts: + kwargs["thread_ts"] = thread_ts + + result = await self._get_client(chat_id).chat_postMessage(**kwargs) + msg_ts = result.get("ts", "") + if msg_ts: + self._approval_resolved[msg_ts] = False + + return SendResult(success=True, message_id=msg_ts, raw_response=result) + except Exception as e: + logger.error("[Slack] send_exec_approval failed: %s", e, exc_info=True) + return SendResult(success=False, error=str(e)) + + async def _handle_approval_action(self, ack, body, action) -> None: + """Handle an approval button click from Block Kit.""" + await ack() + + action_id = action.get("action_id", "") + session_key = action.get("value", "") + message = body.get("message", {}) + msg_ts = message.get("ts", "") + channel_id = body.get("channel", {}).get("id", "") + user_name = body.get("user", {}).get("name", "unknown") + + # Map action_id to approval choice + choice_map = { + "hermes_approve_once": "once", + "hermes_approve_session": "session", + "hermes_approve_always": "always", + "hermes_deny": "deny", + } + choice = choice_map.get(action_id, "deny") + + # Prevent double-clicks + if self._approval_resolved.get(msg_ts, False): + return + self._approval_resolved[msg_ts] = True + + # Update the message to show the decision and remove buttons + label_map = { + "once": f"✅ Approved once by {user_name}", + "session": f"✅ Approved for session by {user_name}", + "always": f"✅ Approved permanently by {user_name}", + "deny": f"❌ Denied by {user_name}", + } + decision_text = label_map.get(choice, f"Resolved by {user_name}") + + # Get original text from the section block + original_text = "" + for block in message.get("blocks", []): + if block.get("type") == "section": + original_text = block.get("text", {}).get("text", "") + break + + updated_blocks = [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": original_text or "Command approval request", + }, + }, + { + "type": "context", + "elements": [ + {"type": "mrkdwn", "text": decision_text}, + ], + }, + ] + + try: + await self._get_client(channel_id).chat_update( + channel=channel_id, + ts=msg_ts, + text=decision_text, + blocks=updated_blocks, + ) + except Exception as e: + logger.warning("[Slack] Failed to update approval message: %s", e) + + # Resolve the approval — this unblocks the agent thread + try: + from tools.approval import resolve_gateway_approval + count = resolve_gateway_approval(session_key, choice) + logger.info( + "Slack button resolved %d approval(s) for session %s (choice=%s, user=%s)", + count, session_key, choice, user_name, + ) + except Exception as exc: + logger.error("Failed to resolve gateway approval from Slack button: %s", exc) + + # Clean up stale approval state + self._approval_resolved.pop(msg_ts, None) + + # ----- Thread context fetching ----- + + async def _fetch_thread_context( + self, channel_id: str, thread_ts: str, current_ts: str, + team_id: str = "", limit: int = 30, + ) -> str: + """Fetch recent thread messages to provide context when the bot is + mentioned mid-thread for the first time. + + Returns a formatted string with thread history, or empty string on + failure or if the thread is empty (just the parent message). + """ + try: + client = self._get_client(channel_id) + result = await client.conversations_replies( + channel=channel_id, + ts=thread_ts, + limit=limit + 1, # +1 because it includes the current message + inclusive=True, + ) + messages = result.get("messages", []) + if not messages: + return "" + + context_parts = [] + for msg in messages: + msg_ts = msg.get("ts", "") + # Skip the current message (the one that triggered this fetch) + if msg_ts == current_ts: + continue + # Skip bot messages from ourselves + if msg.get("bot_id") or msg.get("subtype") == "bot_message": + continue + + msg_user = msg.get("user", "unknown") + msg_text = msg.get("text", "").strip() + if not msg_text: + continue + + # Strip bot mentions from context messages + bot_uid = self._team_bot_user_ids.get(team_id, self._bot_user_id) + if bot_uid: + msg_text = msg_text.replace(f"<@{bot_uid}>", "").strip() + + # Mark the thread parent + is_parent = msg_ts == thread_ts + prefix = "[thread parent] " if is_parent else "" + + # Resolve user name (cached) + name = await self._resolve_user_name(msg_user, chat_id=channel_id) + context_parts.append(f"{prefix}{name}: {msg_text}") + + if not context_parts: + return "" + + return ( + "[Thread context — previous messages in this thread:]\n" + + "\n".join(context_parts) + + "\n[End of thread context]\n\n" + ) + except Exception as e: + logger.warning("[Slack] Failed to fetch thread context: %s", e) + return "" + async def _handle_slash_command(self, command: dict) -> None: """Handle /hermes slash command.""" text = command.get("text", "").strip() @@ -960,50 +1217,44 @@ class SlackAdapter(BasePlatformAdapter): user_id: str, ) -> bool: """Check if there's an active session for a thread. - + Used to determine if thread replies without @mentions should be processed (they should if there's an active session). - - Args: - channel_id: The Slack channel ID - thread_ts: The thread timestamp (parent message ts) - user_id: The user ID of the sender - - Returns: - True if there's an active session for this thread + + Uses ``build_session_key()`` as the single source of truth for key + construction — avoids the bug where manual key building didn't + respect ``thread_sessions_per_user`` and ``group_sessions_per_user`` + settings correctly. """ session_store = getattr(self, "_session_store", None) if not session_store: return False - - try: - # Build a SessionSource for this thread - from gateway.session import SessionSource - # Generate the session key using the same logic as SessionStore - # This mirrors the logic in build_session_key for group sessions - key_parts = ["agent:main", "slack", "group", channel_id, thread_ts] - - # Include user_id if group_sessions_per_user is enabled - # We check the session store config if available - group_sessions_per_user = getattr( - session_store, "config", {} + try: + from gateway.session import SessionSource, build_session_key + + source = SessionSource( + platform=Platform.SLACK, + chat_id=channel_id, + chat_type="group", + user_id=user_id, + thread_id=thread_ts, ) - if hasattr(group_sessions_per_user, "group_sessions_per_user"): - group_sessions_per_user = group_sessions_per_user.group_sessions_per_user - else: - group_sessions_per_user = True # Default - - if group_sessions_per_user and user_id: - key_parts.append(str(user_id)) - - session_key = ":".join(key_parts) - - # Check if the session exists in the store + + # Read session isolation settings from the store's config + store_cfg = getattr(session_store, "config", None) + gspu = getattr(store_cfg, "group_sessions_per_user", True) if store_cfg else True + tspu = getattr(store_cfg, "thread_sessions_per_user", False) if store_cfg else False + + session_key = build_session_key( + source, + group_sessions_per_user=gspu, + thread_sessions_per_user=tspu, + ) + session_store._ensure_loaded() return session_key in session_store._entries except Exception: - # If anything goes wrong, default to False (require mention) return False async def _download_slack_file(self, url: str, ext: str, audio: bool = False, team_id: str = "") -> str: diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 8c69c47b8..26b0e4263 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -153,6 +153,8 @@ class TelegramAdapter(BasePlatformAdapter): self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", []) # Interactive model picker state per chat self._model_picker_state: Dict[str, dict] = {} + # Approval button state: message_id → session_key + self._approval_state: Dict[int, str] = {} def _fallback_ips(self) -> list[str]: """Return validated fallback IPs from config (populated by _apply_env_overrides).""" @@ -1010,6 +1012,70 @@ class TelegramAdapter(BasePlatformAdapter): logger.warning("[%s] send_update_prompt failed: %s", self.name, e) return SendResult(success=False, error=str(e)) + async def send_exec_approval( + self, chat_id: str, command: str, session_key: str, + description: str = "dangerous command", + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an inline-keyboard approval prompt with interactive buttons. + + The buttons call ``resolve_gateway_approval()`` to unblock the waiting + agent thread — same mechanism as the text ``/approve`` flow. + """ + if not self._bot: + return SendResult(success=False, error="Not connected") + + try: + cmd_preview = command[:3800] + "..." if len(command) > 3800 else command + text = ( + f"⚠️ *Command Approval Required*\n\n" + f"`{cmd_preview}`\n\n" + f"Reason: {description}" + ) + + # Resolve thread context for thread replies + thread_id = None + if metadata: + thread_id = metadata.get("thread_id") or metadata.get("message_thread_id") + + # We'll use the message_id as part of callback_data to look up session_key + # Send a placeholder first, then update — or use a counter. + # Simpler: use a monotonic counter to generate short IDs. + import itertools + if not hasattr(self, "_approval_counter"): + self._approval_counter = itertools.count(1) + approval_id = next(self._approval_counter) + + keyboard = InlineKeyboardMarkup([ + [ + InlineKeyboardButton("✅ Allow Once", callback_data=f"ea:once:{approval_id}"), + InlineKeyboardButton("✅ Session", callback_data=f"ea:session:{approval_id}"), + ], + [ + InlineKeyboardButton("✅ Always", callback_data=f"ea:always:{approval_id}"), + InlineKeyboardButton("❌ Deny", callback_data=f"ea:deny:{approval_id}"), + ], + ]) + + kwargs: Dict[str, Any] = { + "chat_id": int(chat_id), + "text": text, + "parse_mode": ParseMode.MARKDOWN, + "reply_markup": keyboard, + } + if thread_id: + kwargs["message_thread_id"] = int(thread_id) + + msg = await self._bot.send_message(**kwargs) + + # Store session_key keyed by approval_id for the callback handler + self._approval_state[approval_id] = session_key + + return SendResult(success=True, message_id=str(msg.message_id)) + except Exception as e: + logger.warning("[%s] send_exec_approval failed: %s", self.name, e) + return SendResult(success=False, error=str(e)) + async def send_model_picker( self, chat_id: str, @@ -1321,6 +1387,56 @@ class TelegramAdapter(BasePlatformAdapter): await self._handle_model_picker_callback(query, data, chat_id) return + # --- Exec approval callbacks (ea:choice:id) --- + if data.startswith("ea:"): + parts = data.split(":", 2) + if len(parts) == 3: + choice = parts[1] # once, session, always, deny + try: + approval_id = int(parts[2]) + except (ValueError, IndexError): + await query.answer(text="Invalid approval data.") + return + + session_key = self._approval_state.pop(approval_id, None) + if not session_key: + await query.answer(text="This approval has already been resolved.") + return + + # Map choice to human-readable label + label_map = { + "once": "✅ Approved once", + "session": "✅ Approved for session", + "always": "✅ Approved permanently", + "deny": "❌ Denied", + } + user_display = getattr(query.from_user, "first_name", "User") + label = label_map.get(choice, "Resolved") + + await query.answer(text=label) + + # Edit message to show decision, remove buttons + try: + await query.edit_message_text( + text=f"{label} by {user_display}", + parse_mode=ParseMode.MARKDOWN, + reply_markup=None, + ) + except Exception: + pass # non-fatal if edit fails + + # Resolve the approval — unblocks the agent thread + try: + from tools.approval import resolve_gateway_approval + count = resolve_gateway_approval(session_key, choice) + logger.info( + "Telegram button resolved %d approval(s) for session %s (choice=%s, user=%s)", + count, session_key, choice, user_display, + ) + except Exception as exc: + logger.error("Failed to resolve gateway approval from Telegram button: %s", exc) + return + # --- Update prompt callbacks --- if not data.startswith("update_prompt:"): return diff --git a/tests/gateway/test_slack_approval_buttons.py b/tests/gateway/test_slack_approval_buttons.py new file mode 100644 index 000000000..496f472c2 --- /dev/null +++ b/tests/gateway/test_slack_approval_buttons.py @@ -0,0 +1,373 @@ +"""Tests for Slack Block Kit approval buttons and thread context fetching.""" + +import asyncio +import os +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Ensure the repo root is importable +# --------------------------------------------------------------------------- +_repo = str(Path(__file__).resolve().parents[2]) +if _repo not in sys.path: + sys.path.insert(0, _repo) + + +# --------------------------------------------------------------------------- +# Minimal Slack SDK mock so SlackAdapter can be imported +# --------------------------------------------------------------------------- +def _ensure_slack_mock(): + """Wire up the minimal mocks required to import SlackAdapter.""" + if "slack_bolt" in sys.modules: + return + slack_bolt = MagicMock() + slack_bolt.async_app.AsyncApp = MagicMock + sys.modules["slack_bolt"] = slack_bolt + sys.modules["slack_bolt.async_app"] = slack_bolt.async_app + handler_mod = MagicMock() + handler_mod.AsyncSocketModeHandler = MagicMock + sys.modules["slack_bolt.adapter"] = MagicMock() + sys.modules["slack_bolt.adapter.socket_mode"] = MagicMock() + sys.modules["slack_bolt.adapter.socket_mode.async_handler"] = handler_mod + sdk_mod = MagicMock() + sdk_mod.web = MagicMock() + sdk_mod.web.async_client = MagicMock() + sdk_mod.web.async_client.AsyncWebClient = MagicMock + sys.modules["slack_sdk"] = sdk_mod + sys.modules["slack_sdk.web"] = sdk_mod.web + sys.modules["slack_sdk.web.async_client"] = sdk_mod.web.async_client + + +_ensure_slack_mock() + +from gateway.platforms.slack import SlackAdapter +from gateway.config import Platform, PlatformConfig + + +def _make_adapter(): + """Create a SlackAdapter instance with mocked internals.""" + config = PlatformConfig(enabled=True, token="xoxb-test-token") + adapter = SlackAdapter(config) + adapter._app = MagicMock() + adapter._bot_user_id = "U_BOT" + adapter._team_clients = {"T1": AsyncMock()} + adapter._team_bot_user_ids = {"T1": "U_BOT"} + adapter._channel_team = {"C1": "T1"} + return adapter + + +# =========================================================================== +# send_exec_approval — Block Kit buttons +# =========================================================================== + +class TestSlackExecApproval: + """Test the send_exec_approval method sends Block Kit buttons.""" + + @pytest.mark.asyncio + async def test_sends_blocks_with_buttons(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "1234.5678"}) + + result = await adapter.send_exec_approval( + chat_id="C1", + command="rm -rf /important", + session_key="agent:main:slack:group:C1:1111", + description="dangerous deletion", + ) + + assert result.success is True + assert result.message_id == "1234.5678" + + # Verify chat_postMessage was called with blocks + mock_client.chat_postMessage.assert_called_once() + kwargs = mock_client.chat_postMessage.call_args[1] + assert "blocks" in kwargs + blocks = kwargs["blocks"] + assert len(blocks) == 2 + assert blocks[0]["type"] == "section" + assert "rm -rf /important" in blocks[0]["text"]["text"] + assert "dangerous deletion" in blocks[0]["text"]["text"] + assert blocks[1]["type"] == "actions" + elements = blocks[1]["elements"] + assert len(elements) == 4 + action_ids = [e["action_id"] for e in elements] + assert "hermes_approve_once" in action_ids + assert "hermes_approve_session" in action_ids + assert "hermes_approve_always" in action_ids + assert "hermes_deny" in action_ids + # Each button carries the session key as value + for e in elements: + assert e["value"] == "agent:main:slack:group:C1:1111" + + @pytest.mark.asyncio + async def test_sends_in_thread(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "1234.5678"}) + + await adapter.send_exec_approval( + chat_id="C1", + command="echo test", + session_key="test-session", + metadata={"thread_id": "9999.0000"}, + ) + + kwargs = mock_client.chat_postMessage.call_args[1] + assert kwargs.get("thread_ts") == "9999.0000" + + @pytest.mark.asyncio + async def test_not_connected(self): + adapter = _make_adapter() + adapter._app = None + result = await adapter.send_exec_approval( + chat_id="C1", command="ls", session_key="s" + ) + assert result.success is False + + @pytest.mark.asyncio + async def test_truncates_long_command(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "1.2"}) + + long_cmd = "x" * 5000 + await adapter.send_exec_approval( + chat_id="C1", command=long_cmd, session_key="s" + ) + + kwargs = mock_client.chat_postMessage.call_args[1] + section_text = kwargs["blocks"][0]["text"]["text"] + assert "..." in section_text + assert len(section_text) < 5000 + + +# =========================================================================== +# _handle_approval_action — button click handler +# =========================================================================== + +class TestSlackApprovalAction: + """Test the approval button click handler.""" + + @pytest.mark.asyncio + async def test_resolves_approval(self): + adapter = _make_adapter() + adapter._approval_resolved["1234.5678"] = False + + ack = AsyncMock() + body = { + "message": { + "ts": "1234.5678", + "blocks": [ + {"type": "section", "text": {"type": "mrkdwn", "text": "original text"}}, + {"type": "actions", "elements": []}, + ], + }, + "channel": {"id": "C1"}, + "user": {"name": "norbert"}, + } + action = { + "action_id": "hermes_approve_once", + "value": "agent:main:slack:group:C1:1111", + } + + mock_client = adapter._team_clients["T1"] + mock_client.chat_update = AsyncMock() + + with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: + await adapter._handle_approval_action(ack, body, action) + + ack.assert_called_once() + mock_resolve.assert_called_once_with("agent:main:slack:group:C1:1111", "once") + + # Message should be updated with decision + mock_client.chat_update.assert_called_once() + update_kwargs = mock_client.chat_update.call_args[1] + assert "Approved once by norbert" in update_kwargs["text"] + + @pytest.mark.asyncio + async def test_prevents_double_click(self): + adapter = _make_adapter() + adapter._approval_resolved["1234.5678"] = True # Already resolved + + ack = AsyncMock() + body = { + "message": {"ts": "1234.5678", "blocks": []}, + "channel": {"id": "C1"}, + "user": {"name": "norbert"}, + } + action = { + "action_id": "hermes_approve_once", + "value": "some-session", + } + + with patch("tools.approval.resolve_gateway_approval") as mock_resolve: + await adapter._handle_approval_action(ack, body, action) + + # Should have acked but NOT resolved + ack.assert_called_once() + mock_resolve.assert_not_called() + + @pytest.mark.asyncio + async def test_deny_action(self): + adapter = _make_adapter() + adapter._approval_resolved["1.2"] = False + + ack = AsyncMock() + body = { + "message": {"ts": "1.2", "blocks": [ + {"type": "section", "text": {"type": "mrkdwn", "text": "cmd"}}, + ]}, + "channel": {"id": "C1"}, + "user": {"name": "alice"}, + } + action = {"action_id": "hermes_deny", "value": "session-key"} + + mock_client = adapter._team_clients["T1"] + mock_client.chat_update = AsyncMock() + + with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: + await adapter._handle_approval_action(ack, body, action) + + mock_resolve.assert_called_once_with("session-key", "deny") + update_kwargs = mock_client.chat_update.call_args[1] + assert "Denied by alice" in update_kwargs["text"] + + +# =========================================================================== +# _fetch_thread_context +# =========================================================================== + +class TestSlackThreadContext: + """Test thread context fetching.""" + + @pytest.mark.asyncio + async def test_fetches_and_formats_context(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.conversations_replies = AsyncMock(return_value={ + "messages": [ + {"ts": "1000.0", "user": "U1", "text": "This is the parent message"}, + {"ts": "1000.1", "user": "U2", "text": "I think we should refactor"}, + {"ts": "1000.2", "user": "U1", "text": "Good idea, <@U_BOT> what do you think?"}, + ] + }) + + # Mock user name resolution + adapter._user_name_cache = {"U1": "Alice", "U2": "Bob"} + + context = await adapter._fetch_thread_context( + channel_id="C1", + thread_ts="1000.0", + current_ts="1000.2", # The message that triggered the fetch + team_id="T1", + ) + + assert "[Thread context" in context + assert "[thread parent] Alice: This is the parent message" in context + assert "Bob: I think we should refactor" in context + # Current message should be excluded + assert "what do you think" not in context + # Bot mention should be stripped from context + assert "<@U_BOT>" not in context + + @pytest.mark.asyncio + async def test_skips_bot_messages(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.conversations_replies = AsyncMock(return_value={ + "messages": [ + {"ts": "1000.0", "user": "U1", "text": "Parent"}, + {"ts": "1000.1", "bot_id": "B1", "text": "Bot reply (should be skipped)"}, + {"ts": "1000.2", "user": "U1", "text": "Current"}, + ] + }) + adapter._user_name_cache = {"U1": "Alice"} + + context = await adapter._fetch_thread_context( + channel_id="C1", thread_ts="1000.0", current_ts="1000.2", team_id="T1" + ) + + assert "Bot reply" not in context + assert "Alice: Parent" in context + + @pytest.mark.asyncio + async def test_empty_thread(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.conversations_replies = AsyncMock(return_value={"messages": []}) + + context = await adapter._fetch_thread_context( + channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1" + ) + assert context == "" + + @pytest.mark.asyncio + async def test_api_failure_returns_empty(self): + adapter = _make_adapter() + mock_client = adapter._team_clients["T1"] + mock_client.conversations_replies = AsyncMock(side_effect=Exception("API error")) + + context = await adapter._fetch_thread_context( + channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1" + ) + assert context == "" + + +# =========================================================================== +# _has_active_session_for_thread — session key fix (#5833) +# =========================================================================== + +class TestSessionKeyFix: + """Test that _has_active_session_for_thread uses build_session_key.""" + + def test_uses_build_session_key(self): + """Verify the fix uses build_session_key instead of manual key construction.""" + adapter = _make_adapter() + + # Mock session store with a known entry + mock_store = MagicMock() + mock_store._entries = { + "agent:main:slack:group:C1:1000.0": MagicMock() + } + mock_store._ensure_loaded = MagicMock() + mock_store.config = MagicMock() + mock_store.config.group_sessions_per_user = False # threads don't include user_id + mock_store.config.thread_sessions_per_user = False + adapter._session_store = mock_store + + # With the fix, build_session_key should be called which respects + # group_sessions_per_user=False (no user_id appended) + result = adapter._has_active_session_for_thread( + channel_id="C1", thread_ts="1000.0", user_id="U123" + ) + + # Should find the session because build_session_key with + # group_sessions_per_user=False doesn't append user_id + assert result is True + + def test_no_session_returns_false(self): + adapter = _make_adapter() + mock_store = MagicMock() + mock_store._entries = {} + mock_store._ensure_loaded = MagicMock() + mock_store.config = MagicMock() + mock_store.config.group_sessions_per_user = True + mock_store.config.thread_sessions_per_user = False + adapter._session_store = mock_store + + result = adapter._has_active_session_for_thread( + channel_id="C1", thread_ts="1000.0", user_id="U123" + ) + assert result is False + + def test_no_session_store(self): + adapter = _make_adapter() + # No _session_store attribute + result = adapter._has_active_session_for_thread( + channel_id="C1", thread_ts="1000.0", user_id="U123" + ) + assert result is False diff --git a/tests/gateway/test_telegram_approval_buttons.py b/tests/gateway/test_telegram_approval_buttons.py new file mode 100644 index 000000000..1b8249bc2 --- /dev/null +++ b/tests/gateway/test_telegram_approval_buttons.py @@ -0,0 +1,284 @@ +"""Tests for Telegram inline keyboard approval buttons.""" + +import asyncio +import os +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Ensure the repo root is importable +# --------------------------------------------------------------------------- +_repo = str(Path(__file__).resolve().parents[2]) +if _repo not in sys.path: + sys.path.insert(0, _repo) + + +# --------------------------------------------------------------------------- +# Minimal Telegram mock so TelegramAdapter can be imported +# --------------------------------------------------------------------------- +def _ensure_telegram_mock(): + """Wire up the minimal mocks required to import TelegramAdapter.""" + if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"): + return + + mod = MagicMock() + mod.ext.ContextTypes.DEFAULT_TYPE = type(None) + mod.constants.ParseMode.MARKDOWN = "Markdown" + mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2" + mod.constants.ParseMode.HTML = "HTML" + mod.constants.ChatType.PRIVATE = "private" + mod.constants.ChatType.GROUP = "group" + mod.constants.ChatType.SUPERGROUP = "supergroup" + mod.constants.ChatType.CHANNEL = "channel" + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request", "telegram.error"): + sys.modules.setdefault(name, mod) + + +_ensure_telegram_mock() + +from gateway.platforms.telegram import TelegramAdapter +from gateway.config import Platform, PlatformConfig + + +def _make_adapter(): + """Create a TelegramAdapter with mocked internals.""" + config = PlatformConfig(enabled=True, token="test-token") + adapter = TelegramAdapter(config) + adapter._bot = AsyncMock() + adapter._app = MagicMock() + return adapter + + +# =========================================================================== +# send_exec_approval — inline keyboard buttons +# =========================================================================== + +class TestTelegramExecApproval: + """Test the send_exec_approval method sends InlineKeyboard buttons.""" + + @pytest.mark.asyncio + async def test_sends_inline_keyboard(self): + adapter = _make_adapter() + mock_msg = MagicMock() + mock_msg.message_id = 42 + adapter._bot.send_message = AsyncMock(return_value=mock_msg) + + result = await adapter.send_exec_approval( + chat_id="12345", + command="rm -rf /important", + session_key="agent:main:telegram:group:12345:99", + description="dangerous deletion", + ) + + assert result.success is True + assert result.message_id == "42" + + adapter._bot.send_message.assert_called_once() + kwargs = adapter._bot.send_message.call_args[1] + assert kwargs["chat_id"] == 12345 + assert "rm -rf /important" in kwargs["text"] + assert "dangerous deletion" in kwargs["text"] + assert kwargs["reply_markup"] is not None # InlineKeyboardMarkup + + @pytest.mark.asyncio + async def test_stores_approval_state(self): + adapter = _make_adapter() + mock_msg = MagicMock() + mock_msg.message_id = 42 + adapter._bot.send_message = AsyncMock(return_value=mock_msg) + + await adapter.send_exec_approval( + chat_id="12345", + command="echo test", + session_key="my-session-key", + ) + + # The approval_id should map to the session_key + assert len(adapter._approval_state) == 1 + approval_id = list(adapter._approval_state.keys())[0] + assert adapter._approval_state[approval_id] == "my-session-key" + + @pytest.mark.asyncio + async def test_sends_in_thread(self): + adapter = _make_adapter() + mock_msg = MagicMock() + mock_msg.message_id = 42 + adapter._bot.send_message = AsyncMock(return_value=mock_msg) + + await adapter.send_exec_approval( + chat_id="12345", + command="ls", + session_key="s", + metadata={"thread_id": "999"}, + ) + + kwargs = adapter._bot.send_message.call_args[1] + assert kwargs.get("message_thread_id") == 999 + + @pytest.mark.asyncio + async def test_not_connected(self): + adapter = _make_adapter() + adapter._bot = None + result = await adapter.send_exec_approval( + chat_id="12345", command="ls", session_key="s" + ) + assert result.success is False + + @pytest.mark.asyncio + async def test_truncates_long_command(self): + adapter = _make_adapter() + mock_msg = MagicMock() + mock_msg.message_id = 1 + adapter._bot.send_message = AsyncMock(return_value=mock_msg) + + long_cmd = "x" * 5000 + await adapter.send_exec_approval( + chat_id="12345", command=long_cmd, session_key="s" + ) + + kwargs = adapter._bot.send_message.call_args[1] + assert "..." in kwargs["text"] + assert len(kwargs["text"]) < 5000 + + +# =========================================================================== +# _handle_callback_query — approval button clicks +# =========================================================================== + +class TestTelegramApprovalCallback: + """Test the approval callback handling in _handle_callback_query.""" + + @pytest.mark.asyncio + async def test_resolves_approval_on_click(self): + adapter = _make_adapter() + # Set up approval state + adapter._approval_state[1] = "agent:main:telegram:group:12345:99" + + # Mock callback query + query = AsyncMock() + query.data = "ea:once:1" + query.message = MagicMock() + query.message.chat_id = 12345 + query.from_user = MagicMock() + query.from_user.first_name = "Norbert" + query.answer = AsyncMock() + query.edit_message_text = AsyncMock() + + update = MagicMock() + update.callback_query = query + context = MagicMock() + + with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: + await adapter._handle_callback_query(update, context) + + mock_resolve.assert_called_once_with("agent:main:telegram:group:12345:99", "once") + query.answer.assert_called_once() + query.edit_message_text.assert_called_once() + + # State should be cleaned up + assert 1 not in adapter._approval_state + + @pytest.mark.asyncio + async def test_deny_button(self): + adapter = _make_adapter() + adapter._approval_state[2] = "some-session" + + query = AsyncMock() + query.data = "ea:deny:2" + query.message = MagicMock() + query.message.chat_id = 12345 + query.from_user = MagicMock() + query.from_user.first_name = "Alice" + query.answer = AsyncMock() + query.edit_message_text = AsyncMock() + + update = MagicMock() + update.callback_query = query + context = MagicMock() + + with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: + await adapter._handle_callback_query(update, context) + + mock_resolve.assert_called_once_with("some-session", "deny") + edit_kwargs = query.edit_message_text.call_args[1] + assert "Denied" in edit_kwargs["text"] + + @pytest.mark.asyncio + async def test_already_resolved(self): + adapter = _make_adapter() + # No state for approval_id 99 — already resolved + + query = AsyncMock() + query.data = "ea:once:99" + query.message = MagicMock() + query.message.chat_id = 12345 + query.from_user = MagicMock() + query.from_user.first_name = "Bob" + query.answer = AsyncMock() + + update = MagicMock() + update.callback_query = query + context = MagicMock() + + with patch("tools.approval.resolve_gateway_approval") as mock_resolve: + await adapter._handle_callback_query(update, context) + + # Should NOT resolve — already handled + mock_resolve.assert_not_called() + # Should still ack with "already resolved" message + query.answer.assert_called_once() + assert "already been resolved" in query.answer.call_args[1]["text"] + + @pytest.mark.asyncio + async def test_model_picker_callback_not_affected(self): + """Ensure model picker callbacks still route correctly.""" + adapter = _make_adapter() + + query = AsyncMock() + query.data = "mp:some_provider" + query.message = MagicMock() + query.message.chat_id = 12345 + query.from_user = MagicMock() + + update = MagicMock() + update.callback_query = query + context = MagicMock() + + # Model picker callback should be handled (not crash) + # We just verify it doesn't try to resolve an approval + with patch("tools.approval.resolve_gateway_approval") as mock_resolve: + with patch.object(adapter, "_handle_model_picker_callback", new_callable=AsyncMock): + await adapter._handle_callback_query(update, context) + + mock_resolve.assert_not_called() + + @pytest.mark.asyncio + async def test_update_prompt_callback_not_affected(self): + """Ensure update prompt callbacks still work.""" + adapter = _make_adapter() + + query = AsyncMock() + query.data = "update_prompt:y" + query.message = MagicMock() + query.message.chat_id = 12345 + query.from_user = MagicMock() + query.from_user.id = 123 + query.answer = AsyncMock() + query.edit_message_text = AsyncMock() + + update = MagicMock() + update.callback_query = query + context = MagicMock() + + with patch("tools.approval.resolve_gateway_approval") as mock_resolve: + with patch("hermes_constants.get_hermes_home", return_value=Path("/tmp/test")): + try: + await adapter._handle_callback_query(update, context) + except Exception: + pass # May fail on file write, that's fine + + # Should NOT have triggered approval resolution + mock_resolve.assert_not_called()