hermes-agent/tests/gateway/test_telegram_approval_buttons.py
rahimsais 737314fe91 fix(telegram): normalize dm threads and retry control sends
Cherry-picked from PR #10371. Two-layer defense for the spurious-thread_id
issue (#3206):

1. _build_message_event filters DM thread_ids: only preserve thread_id
   for real topic messages (is_topic_message=True). Telegram puts
   message_thread_id on every DM that is a reply, but reply-chain ids
   route to nonexistent threads on send.

2. _send_message_with_thread_fallback helper: control sends
   (send_update_prompt, send_exec_approval / send_slash_confirm,
   send_model_picker) retry once without message_thread_id when
   Telegram returns BadRequest 'Message thread not found'. Mirrors
   the pattern PR #3390 added for the streaming send path.

Salvage notes:
- Conflict 1 (line ~4099): merged the contributor's DM is_topic_message
  filter with the existing forum General-topic default from #22423,
  preserving both behaviors.
- Conflict 2 (line ~1664 / 1690): kept main's delete_message (PR #23416)
  alongside the new helper. Tightened the helper's exception catch
  from bare 'Exception' to use the existing _is_bad_request_error +
  _is_thread_not_found_error helpers (lines 484-496) for consistency
  with the streaming send path.
- Widened the fix to send_update_prompt (was bare self._bot.send_message,
  same bug class).

Authored by rahimsais via PR #10371 (re-attributed from donrhmexe@
local commit author).
2026-05-10 18:09:31 -07:00

478 lines
17 KiB
Python

"""Tests for Telegram inline keyboard approval buttons."""
import asyncio
import os
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Ensure the repo root is importable
# ---------------------------------------------------------------------------
# ``parents[2]`` is the directory two levels above the one containing this
# file; the ``gateway`` imports further down rely on it being on sys.path.
_repo = os.fspath(Path(__file__).resolve().parents[2])
if _repo not in sys.path:
    # Prepend (rather than append) so this entry is searched first.
    sys.path.insert(0, _repo)
# ---------------------------------------------------------------------------
# Minimal Telegram mock so TelegramAdapter can be imported
# ---------------------------------------------------------------------------
def _ensure_telegram_mock():
    """Install a stand-in ``telegram`` package so TelegramAdapter can be imported.

    Returns immediately when a ``telegram`` module exposing ``__file__`` is
    already present in ``sys.modules``.  Otherwise one ``MagicMock`` is
    registered under ``telegram``, ``telegram.ext``, ``telegram.constants``
    and ``telegram.request``, and its ``error`` attribute is registered as
    ``telegram.error``.  Names already present in ``sys.modules`` are never
    overwritten.
    """
    loaded = sys.modules.get("telegram")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    mod = MagicMock()
    mod.ext.ContextTypes.DEFAULT_TYPE = type(None)

    parse_modes = {
        "MARKDOWN": "Markdown",
        "MARKDOWN_V2": "MarkdownV2",
        "HTML": "HTML",
    }
    for attr, value in parse_modes.items():
        setattr(mod.constants.ParseMode, attr, value)

    chat_types = {
        "PRIVATE": "private",
        "GROUP": "group",
        "SUPERGROUP": "supergroup",
        "CHANNEL": "channel",
    }
    for attr, value in chat_types.items():
        setattr(mod.constants.ChatType, attr, value)

    # Provide real exception classes so ``except (NetworkError, ...)`` in
    # connect() doesn't blow up under xdist when this mock leaks.
    for exc_name, base in (
        ("NetworkError", OSError),
        ("TimedOut", OSError),
        ("BadRequest", Exception),
    ):
        setattr(mod.error, exc_name, type(exc_name, (base,), {}))

    # NOTE(review): the three submodule names below all alias the root mock,
    # so ``from telegram.constants import X`` yields ``mod.X`` rather than
    # ``mod.constants.X`` -- confirm this is what the adapter's imports expect.
    for dotted in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
        sys.modules.setdefault(dotted, mod)
    sys.modules.setdefault("telegram.error", mod.error)
_ensure_telegram_mock()
from gateway.platforms.telegram import TelegramAdapter
from gateway.config import Platform, PlatformConfig
def _make_adapter(extra=None):
    """Return a TelegramAdapter whose bot and application objects are mocks.

    ``extra`` is passed through as the platform config's ``extra`` value; any
    falsy value (including the default ``None``) is replaced by an empty dict.
    """
    settings = extra if extra else {}
    adapter = TelegramAdapter(
        PlatformConfig(enabled=True, token="test-token", extra=settings)
    )
    # The bot is awaited by the code under test, so it must be an AsyncMock.
    adapter._bot = AsyncMock()
    adapter._app = MagicMock()
    return adapter
class _AuthRunner:
    """Stand-in runner whose authorization verdict is fixed at construction.

    The tests bind ``_handle_message`` as the adapter's message handler and
    afterwards inspect ``last_source``.
    NOTE(review): presumably the adapter reaches ``_is_user_authorized``
    through that bound handler -- confirm against TelegramAdapter.
    """

    def __init__(self, authorized: bool):
        # ``last_source`` stays None until an authorization check happens.
        self.last_source = None
        self.authorized = authorized

    async def _handle_message(self, event):
        """Accept any event and ignore it."""

    def _is_user_authorized(self, source):
        """Record ``source`` and return the preconfigured verdict."""
        verdict = self.authorized
        self.last_source = source
        return verdict
# ===========================================================================
# send_exec_approval — inline keyboard buttons
# ===========================================================================
class TestTelegramExecApproval:
    """Verify that ``send_exec_approval`` posts an inline-keyboard prompt."""

    @pytest.mark.asyncio
    async def test_sends_inline_keyboard(self):
        """The prompt carries the command, the description and a keyboard."""
        adapter = _make_adapter()
        adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=42))

        result = await adapter.send_exec_approval(
            chat_id="12345",
            command="rm -rf /important",
            session_key="agent:main:telegram:group:12345:99",
            description="dangerous deletion",
        )

        assert result.success is True
        assert result.message_id == "42"

        send = adapter._bot.send_message
        send.assert_called_once()
        sent = send.call_args.kwargs
        assert sent["chat_id"] == 12345
        body = sent["text"]
        assert "rm -rf /important" in body
        assert "dangerous deletion" in body
        assert sent["reply_markup"] is not None  # InlineKeyboardMarkup

    @pytest.mark.asyncio
    async def test_stores_approval_state(self):
        """Sending a prompt records exactly one approval id -> session key."""
        adapter = _make_adapter()
        adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=42))

        await adapter.send_exec_approval(
            chat_id="12345",
            command="echo test",
            session_key="my-session-key",
        )

        # The approval_id should map to the session_key
        state = adapter._approval_state
        assert len(state) == 1
        approval_id = next(iter(state))
        assert state[approval_id] == "my-session-key"

    @pytest.mark.asyncio
    async def test_sends_in_thread(self):
        """A ``thread_id`` in the metadata is sent as ``message_thread_id``."""
        adapter = _make_adapter()
        adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=42))

        await adapter.send_exec_approval(
            chat_id="12345",
            command="ls",
            session_key="s",
            metadata={"thread_id": "999"},
        )

        sent = adapter._bot.send_message.call_args.kwargs
        assert sent.get("message_thread_id") == 999

    @pytest.mark.asyncio
    async def test_retries_without_thread_when_thread_not_found(self):
        """A 'Message thread not found' failure is retried without the thread."""

        class FakeBadRequest(Exception):
            pass

        attempts = []

        async def fake_send(**kwargs):
            # Snapshot the kwargs so later mutation cannot change the record.
            attempts.append(dict(kwargs))
            if kwargs.get("message_thread_id") is None:
                return SimpleNamespace(message_id=42)
            raise FakeBadRequest("Message thread not found")

        adapter = _make_adapter()
        adapter._bot.send_message = AsyncMock(side_effect=fake_send)

        result = await adapter.send_exec_approval(
            chat_id="12345",
            command="ls",
            session_key="s",
            metadata={"thread_id": "999"},
        )

        assert result.success is True
        assert len(attempts) == 2
        assert attempts[0]["message_thread_id"] == 999
        # The retry may omit the key entirely or pass it as None.
        assert attempts[1].get("message_thread_id") is None

    @pytest.mark.asyncio
    async def test_not_connected(self):
        """Without a bot the call reports failure."""
        adapter = _make_adapter()
        adapter._bot = None

        result = await adapter.send_exec_approval(
            chat_id="12345", command="ls", session_key="s"
        )

        assert result.success is False

    @pytest.mark.asyncio
    async def test_disable_link_previews_sets_preview_kwargs(self):
        """``disable_link_previews`` adds one of the two preview-suppressing kwargs."""
        adapter = _make_adapter(extra={"disable_link_previews": True})
        adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=42))

        await adapter.send_exec_approval(
            chat_id="12345", command="ls", session_key="s"
        )

        sent = adapter._bot.send_message.call_args.kwargs
        legacy_flag = sent.get("disable_web_page_preview") is True
        preview_options = sent.get("link_preview_options") is not None
        assert legacy_flag or preview_options

    @pytest.mark.asyncio
    async def test_truncates_long_command(self):
        """A 5000-character command is shortened and marked with an ellipsis."""
        adapter = _make_adapter()
        adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1))

        long_cmd = "x" * 5000
        await adapter.send_exec_approval(
            chat_id="12345", command=long_cmd, session_key="s"
        )

        body = adapter._bot.send_message.call_args.kwargs["text"]
        assert "..." in body
        assert len(body) < 5000
# ===========================================================================
# _handle_callback_query — approval button clicks
# ===========================================================================
class TestTelegramApprovalCallback:
    """Exercise ``_handle_callback_query`` for approval and neighbouring buttons."""

    @staticmethod
    def _click(data, *, user_id=None, first_name=None, chat_type=None):
        """Fabricate the ``(update, query)`` mocks for one inline-button press.

        ``data`` becomes the callback payload and the chat id is always 12345.
        ``user_id``, ``first_name`` and ``chat_type`` are only assigned when
        given; otherwise the attribute is left as an auto-created mock.
        """
        query = AsyncMock()
        query.data = data
        query.message = MagicMock()
        query.message.chat_id = 12345
        if chat_type is not None:
            query.message.chat.type = chat_type
        query.from_user = MagicMock()
        if user_id is not None:
            query.from_user.id = user_id
        if first_name is not None:
            query.from_user.first_name = first_name
        query.answer = AsyncMock()
        query.edit_message_text = AsyncMock()

        update = MagicMock()
        update.callback_query = query
        return update, query

    @pytest.mark.asyncio
    async def test_resolves_approval_on_click(self):
        """An 'once' click resolves the stored session and clears its state."""
        session_key = "agent:main:telegram:group:12345:99"
        adapter = _make_adapter()
        adapter._approval_state[1] = session_key
        update, query = self._click("ea:once:1", user_id="12345", first_name="Norbert")

        allow_all = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "*"}, clear=False)
        resolver = patch("tools.approval.resolve_gateway_approval", return_value=1)
        with allow_all, resolver as mock_resolve:
            await adapter._handle_callback_query(update, MagicMock())

        mock_resolve.assert_called_once_with(session_key, "once")
        query.answer.assert_called_once()
        query.edit_message_text.assert_called_once()
        # State should be cleaned up
        assert 1 not in adapter._approval_state

    @pytest.mark.asyncio
    async def test_deny_button(self):
        """A 'deny' click resolves with ``deny`` and the edit mentions 'Denied'."""
        adapter = _make_adapter()
        adapter._approval_state[2] = "some-session"
        update, query = self._click("ea:deny:2", user_id="12345", first_name="Alice")

        allow_all = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "*"}, clear=False)
        resolver = patch("tools.approval.resolve_gateway_approval", return_value=1)
        with allow_all, resolver as mock_resolve:
            await adapter._handle_callback_query(update, MagicMock())

        mock_resolve.assert_called_once_with("some-session", "deny")
        edited = query.edit_message_text.call_args.kwargs
        assert "Denied" in edited["text"]

    @pytest.mark.asyncio
    async def test_approval_callback_rejects_user_blocked_by_global_allowlist(self):
        """A click the runner refuses leaves the approval pending and untouched."""
        session_key = "agent:main:telegram:group:12345:99"
        adapter = _make_adapter()
        adapter._approval_state[7] = session_key
        runner = _AuthRunner(authorized=False)
        adapter._message_handler = runner._handle_message
        update, query = self._click(
            "ea:once:7", user_id=222, first_name="Mallory", chat_type="private"
        )

        with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
            await adapter._handle_callback_query(update, MagicMock())

        mock_resolve.assert_not_called()
        query.answer.assert_called_once()
        assert "not authorized" in query.answer.call_args.kwargs["text"].lower()
        query.edit_message_text.assert_not_called()
        assert adapter._approval_state[7] == session_key

        source = runner.last_source
        assert source is not None
        assert source.platform == Platform.TELEGRAM
        assert source.user_id == "222"
        assert source.chat_id == "12345"

    @pytest.mark.asyncio
    async def test_already_resolved(self):
        """Clicking an approval with no stored state only acknowledges the click."""
        adapter = _make_adapter()
        # No state for approval_id 99 — already resolved
        update, query = self._click("ea:once:99", user_id="12345", first_name="Bob")

        allow_all = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "*"}, clear=False)
        resolver = patch("tools.approval.resolve_gateway_approval")
        with allow_all, resolver as mock_resolve:
            await adapter._handle_callback_query(update, MagicMock())

        # Should NOT resolve — already handled
        mock_resolve.assert_not_called()
        # Should still ack with "already resolved" message
        query.answer.assert_called_once()
        assert "already been resolved" in query.answer.call_args.kwargs["text"]

    @pytest.mark.asyncio
    async def test_model_picker_callback_not_affected(self):
        """Ensure model picker callbacks still route correctly."""
        adapter = _make_adapter()
        update, _query = self._click("mp:some_provider")

        # Model picker callback should be handled (not crash)
        # We just verify it doesn't try to resolve an approval
        resolver = patch("tools.approval.resolve_gateway_approval")
        picker = patch.object(
            adapter, "_handle_model_picker_callback", new_callable=AsyncMock
        )
        with resolver as mock_resolve, picker:
            await adapter._handle_callback_query(update, MagicMock())

        mock_resolve.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_prompt_callback_not_affected(self, tmp_path):
        """Ensure update prompt callbacks still work."""
        adapter = _make_adapter()
        update, _query = self._click("update_prompt:y", user_id=123)

        resolver = patch("tools.approval.resolve_gateway_approval")
        home = patch("hermes_constants.get_hermes_home", return_value=tmp_path)
        allowlist = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": ""})
        with resolver as mock_resolve, home, allowlist:
            await adapter._handle_callback_query(update, MagicMock())

        # Should NOT have triggered approval resolution
        mock_resolve.assert_not_called()
        assert (tmp_path / ".update_response").read_text() == "y"

    @pytest.mark.asyncio
    async def test_update_prompt_callback_rejects_unauthorized_user(self, tmp_path):
        """Update prompt buttons should honor TELEGRAM_ALLOWED_USERS."""
        adapter = _make_adapter()
        update, query = self._click("update_prompt:y", user_id=222)

        home = patch("hermes_constants.get_hermes_home", return_value=tmp_path)
        allowlist = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "111"})
        with home, allowlist:
            await adapter._handle_callback_query(update, MagicMock())

        query.answer.assert_called_once()
        assert "not authorized" in query.answer.call_args.kwargs["text"].lower()
        query.edit_message_text.assert_not_called()
        assert not (tmp_path / ".update_response").exists()

    @pytest.mark.asyncio
    async def test_update_prompt_callback_rejects_user_blocked_by_global_allowlist(self, tmp_path):
        """A runner refusal blocks the update prompt even with an empty allowlist."""
        adapter = _make_adapter()
        runner = _AuthRunner(authorized=False)
        adapter._message_handler = runner._handle_message
        update, query = self._click(
            "update_prompt:y", user_id=222, first_name="Mallory", chat_type="private"
        )

        home = patch("hermes_constants.get_hermes_home", return_value=tmp_path)
        allowlist = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": ""})
        with home, allowlist:
            await adapter._handle_callback_query(update, MagicMock())

        query.answer.assert_called_once()
        assert "not authorized" in query.answer.call_args.kwargs["text"].lower()
        query.edit_message_text.assert_not_called()
        assert not (tmp_path / ".update_response").exists()

        source = runner.last_source
        assert source is not None
        assert source.platform == Platform.TELEGRAM
        assert source.user_id == "222"

    @pytest.mark.asyncio
    async def test_update_prompt_callback_allows_authorized_user(self, tmp_path):
        """Allowed Telegram users can still answer update prompt buttons."""
        adapter = _make_adapter()
        update, query = self._click("update_prompt:n", user_id=111)

        home = patch("hermes_constants.get_hermes_home", return_value=tmp_path)
        allowlist = patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "111"})
        with home, allowlist:
            await adapter._handle_callback_query(update, MagicMock())

        query.answer.assert_called_once()
        query.edit_message_text.assert_called_once()
        assert (tmp_path / ".update_response").read_text() == "n"