fix(telegram): normalize dm threads and retry control sends

Cherry-picked from PR #10371. Two-layer defense for the spurious-thread_id
issue (#3206):

1. _build_message_event filters DM thread_ids: only preserve thread_id
   for real topic messages (is_topic_message=True). Telegram puts
   message_thread_id on every DM that is a reply, but reply-chain ids
   route to nonexistent threads on send.

2. _send_message_with_thread_fallback helper: control sends
   (send_update_prompt, send_exec_approval / send_slash_confirm,
   send_model_picker) retry once without message_thread_id when
   Telegram returns BadRequest 'Message thread not found'. Mirrors
   the pattern PR #3390 added for the streaming send path.

Salvage notes:
- Conflict 1 (line ~4099): merged the contributor's DM is_topic_message
  filter with the existing forum General-topic default from #22423,
  preserving both behaviors.
- Conflict 2 (line ~1664 / 1690): kept main's delete_message (PR #23416)
  alongside the new helper. Tightened the helper's exception catch
  from bare 'Exception' to use the existing _is_bad_request_error +
  _is_thread_not_found_error helpers (line 484-496) for consistency
  with the streaming send path.
- Widened the fix to send_update_prompt (was bare self._bot.send_message,
  same bug class).

Authored by rahimsais via PR #10371 (re-attributed from donrhmexe@
local commit author).
This commit is contained in:
rahimsais 2026-05-10 16:25:13 -07:00 committed by Teknium
parent 404640a2b7
commit 737314fe91
4 changed files with 210 additions and 13 deletions

View file

@ -1686,6 +1686,38 @@ class TelegramAdapter(BasePlatformAdapter):
) )
return False return False
async def _send_message_with_thread_fallback(self, **kwargs):
"""Send a Telegram message, retrying once without message_thread_id
if Telegram returns 'Message thread not found'.
Used for control-style sends (approval prompts, model picker,
update prompts) that can carry a stale thread_id from a DM
reply chain. The streaming send loop has its own equivalent
(PR #3390) at the body of ``send``; this helper applies the
same retry pattern to the non-streaming control paths.
"""
if not self._bot:
raise RuntimeError("Not connected")
message_thread_id = kwargs.get("message_thread_id")
try:
return await self._bot.send_message(**kwargs)
except Exception as send_err:
if (
message_thread_id is not None
and self._is_bad_request_error(send_err)
and self._is_thread_not_found_error(send_err)
):
logger.warning(
"[%s] Thread %s not found for control message, retrying without message_thread_id",
self.name,
message_thread_id,
)
retry_kwargs = dict(kwargs)
retry_kwargs.pop("message_thread_id", None)
return await self._bot.send_message(**retry_kwargs)
raise
async def send_update_prompt( async def send_update_prompt(
self, chat_id: str, prompt: str, default: str = "", self, chat_id: str, prompt: str, default: str = "",
session_key: str = "", session_key: str = "",
@ -1709,7 +1741,7 @@ class TelegramAdapter(BasePlatformAdapter):
]) ])
thread_id = self._metadata_thread_id(metadata) thread_id = self._metadata_thread_id(metadata)
reply_to_id = self._reply_to_message_id_for_send(None, metadata) reply_to_id = self._reply_to_message_id_for_send(None, metadata)
msg = await self._bot.send_message( msg = await self._send_message_with_thread_fallback(
chat_id=int(chat_id), chat_id=int(chat_id),
text=text, text=text,
parse_mode=ParseMode.MARKDOWN, parse_mode=ParseMode.MARKDOWN,
@ -1789,7 +1821,7 @@ class TelegramAdapter(BasePlatformAdapter):
) )
) )
msg = await self._bot.send_message(**kwargs) msg = await self._send_message_with_thread_fallback(**kwargs)
# Store session_key keyed by approval_id for the callback handler # Store session_key keyed by approval_id for the callback handler
self._approval_state[approval_id] = session_key self._approval_state[approval_id] = session_key
@ -1841,7 +1873,7 @@ class TelegramAdapter(BasePlatformAdapter):
) )
) )
msg = await self._bot.send_message(**kwargs) msg = await self._send_message_with_thread_fallback(**kwargs)
self._slash_confirm_state[confirm_id] = session_key self._slash_confirm_state[confirm_id] = session_key
return SendResult(success=True, message_id=str(msg.message_id)) return SendResult(success=True, message_id=str(msg.message_id))
except Exception as e: except Exception as e:
@ -1899,7 +1931,7 @@ class TelegramAdapter(BasePlatformAdapter):
thread_id = metadata.get("thread_id") if metadata else None thread_id = metadata.get("thread_id") if metadata else None
reply_to_id = self._reply_to_message_id_for_send(None, metadata) reply_to_id = self._reply_to_message_id_for_send(None, metadata)
msg = await self._bot.send_message( msg = await self._send_message_with_thread_fallback(
chat_id=int(chat_id), chat_id=int(chat_id),
text=text, text=text,
parse_mode=ParseMode.MARKDOWN, parse_mode=ParseMode.MARKDOWN,
@ -4069,9 +4101,24 @@ class TelegramAdapter(BasePlatformAdapter):
elif chat.type == ChatType.CHANNEL: elif chat.type == ChatType.CHANNEL:
chat_type = "channel" chat_type = "channel"
# Resolve DM topic name and skill binding # Resolve DM topic name and skill binding.
# In private chats, only preserve thread ids for real topic messages
# (is_topic_message=True). Telegram puts message_thread_id on every
# DM that is a reply, even when the user is just replying to a
# previous message in the same DM — that bogus id then routes to a
# nonexistent thread and Telegram returns 'Message thread not found'
# on send (#3206).
thread_id_raw = message.message_thread_id thread_id_raw = message.message_thread_id
thread_id_str = str(thread_id_raw) if thread_id_raw is not None else None is_topic_message = bool(getattr(message, "is_topic_message", False))
thread_id_str = None
if thread_id_raw is not None:
if chat_type == "group":
thread_id_str = str(thread_id_raw)
elif chat_type == "dm" and is_topic_message:
thread_id_str = str(thread_id_raw)
# For forum groups without an explicit topic, default to the
# General-topic id so the gateway routes back to the General topic
# rather than dropping into the bot's main channel (#22423).
if chat_type == "group" and thread_id_str is None and getattr(chat, "is_forum", False): if chat_type == "group" and thread_id_str is None and getattr(chat, "is_forum", False):
thread_id_str = self._GENERAL_TOPIC_THREAD_ID thread_id_str = self._GENERAL_TOPIC_THREAD_ID
chat_topic = None chat_topic = None

View file

@ -448,7 +448,8 @@ def test_cache_dm_topic_from_message_no_overwrite():
def _make_mock_message(chat_id=111, chat_type="private", text="hello", thread_id=None, def _make_mock_message(chat_id=111, chat_type="private", text="hello", thread_id=None,
user_id=42, user_name="Test User", forum_topic_created=None): user_id=42, user_name="Test User", forum_topic_created=None,
is_topic_message=None):
"""Create a mock Telegram Message for _build_message_event tests.""" """Create a mock Telegram Message for _build_message_event tests."""
chat = SimpleNamespace( chat = SimpleNamespace(
id=chat_id, id=chat_id,
@ -464,11 +465,15 @@ def _make_mock_message(chat_id=111, chat_type="private", text="hello", thread_id
full_name=user_name, full_name=user_name,
) )
if is_topic_message is None:
is_topic_message = bool(thread_id) if chat_type == "private" else None
msg = SimpleNamespace( msg = SimpleNamespace(
chat=chat, chat=chat,
from_user=user, from_user=user,
text=text, text=text,
message_thread_id=thread_id, message_thread_id=thread_id,
is_topic_message=is_topic_message,
message_id=1001, message_id=1001,
reply_to_message=None, reply_to_message=None,
date=None, date=None,
@ -531,6 +536,40 @@ def test_build_message_event_no_auto_skill_without_thread():
assert event.auto_skill is None assert event.auto_skill is None
def test_build_message_event_filters_non_topic_dm_thread_id():
"""A DM reply-thread id should not be persisted unless Telegram marks it as a topic message."""
from gateway.platforms.base import MessageType
adapter = _make_adapter()
msg = _make_mock_message(chat_id=111, thread_id=777, is_topic_message=False)
event = adapter._build_message_event(msg, MessageType.TEXT)
assert event.source.thread_id is None
assert event.source.chat_topic is None
assert event.auto_skill is None
def test_build_message_event_preserves_true_dm_topic_thread_id():
"""True DM topic messages should keep their thread id for routing."""
from gateway.platforms.base import MessageType
adapter = _make_adapter([
{
"chat_id": 111,
"topics": [
{"name": "General", "thread_id": 200},
],
}
])
adapter._dm_topics["111:General"] = 200
msg = _make_mock_message(chat_id=111, thread_id=200, is_topic_message=True)
event = adapter._build_message_event(msg, MessageType.TEXT)
assert event.source.thread_id == "200"
assert event.source.chat_topic == "General"
# ── _build_message_event: group_topics skill binding ── # ── _build_message_event: group_topics skill binding ──
# The telegram mock sets sys.modules["telegram.constants"] = telegram_mod (root mock), # The telegram mock sets sys.modules["telegram.constants"] = telegram_mod (root mock),

View file

@ -4,6 +4,7 @@ import asyncio
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -140,6 +141,34 @@ class TestTelegramExecApproval:
kwargs = adapter._bot.send_message.call_args[1] kwargs = adapter._bot.send_message.call_args[1]
assert kwargs.get("message_thread_id") == 999 assert kwargs.get("message_thread_id") == 999
@pytest.mark.asyncio
async def test_retries_without_thread_when_thread_not_found(self):
adapter = _make_adapter()
call_log = []
class FakeBadRequest(Exception):
pass
async def mock_send_message(**kwargs):
call_log.append(dict(kwargs))
if kwargs.get("message_thread_id") is not None:
raise FakeBadRequest("Message thread not found")
return SimpleNamespace(message_id=42)
adapter._bot.send_message = AsyncMock(side_effect=mock_send_message)
result = await adapter.send_exec_approval(
chat_id="12345",
command="ls",
session_key="s",
metadata={"thread_id": "999"},
)
assert result.success is True
assert len(call_log) == 2
assert call_log[0]["message_thread_id"] == 999
assert "message_thread_id" not in call_log[1] or call_log[1]["message_thread_id"] is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_not_connected(self): async def test_not_connected(self):
adapter = _make_adapter() adapter = _make_adapter()
@ -209,9 +238,11 @@ class TestTelegramApprovalCallback:
update = MagicMock() update = MagicMock()
update.callback_query = query update.callback_query = query
context = MagicMock() context = MagicMock()
query.from_user.id = "12345"
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: with patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "*"}, clear=False):
await adapter._handle_callback_query(update, context) with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._handle_callback_query(update, context)
mock_resolve.assert_called_once_with("agent:main:telegram:group:12345:99", "once") mock_resolve.assert_called_once_with("agent:main:telegram:group:12345:99", "once")
query.answer.assert_called_once() query.answer.assert_called_once()
@ -237,9 +268,11 @@ class TestTelegramApprovalCallback:
update = MagicMock() update = MagicMock()
update.callback_query = query update.callback_query = query
context = MagicMock() context = MagicMock()
query.from_user.id = "12345"
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: with patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "*"}, clear=False):
await adapter._handle_callback_query(update, context) with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
await adapter._handle_callback_query(update, context)
mock_resolve.assert_called_once_with("some-session", "deny") mock_resolve.assert_called_once_with("some-session", "deny")
edit_kwargs = query.edit_message_text.call_args[1] edit_kwargs = query.edit_message_text.call_args[1]
@ -296,9 +329,11 @@ class TestTelegramApprovalCallback:
update = MagicMock() update = MagicMock()
update.callback_query = query update.callback_query = query
context = MagicMock() context = MagicMock()
query.from_user.id = "12345"
with patch("tools.approval.resolve_gateway_approval") as mock_resolve: with patch.dict(os.environ, {"TELEGRAM_ALLOWED_USERS": "*"}, clear=False):
await adapter._handle_callback_query(update, context) with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
await adapter._handle_callback_query(update, context)
# Should NOT resolve — already handled # Should NOT resolve — already handled
mock_resolve.assert_not_called() mock_resolve.assert_not_called()

View file

@ -0,0 +1,76 @@
"""Tests for Telegram model picker thread fallback."""
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
def _ensure_telegram_mock():
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
return
mod = MagicMock()
mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
mod.constants.ParseMode.MARKDOWN = "Markdown"
mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
mod.constants.ParseMode.HTML = "HTML"
mod.constants.ChatType.PRIVATE = "private"
mod.constants.ChatType.GROUP = "group"
mod.constants.ChatType.SUPERGROUP = "supergroup"
mod.constants.ChatType.CHANNEL = "channel"
mod.error.NetworkError = type("NetworkError", (OSError,), {})
mod.error.TimedOut = type("TimedOut", (OSError,), {})
mod.error.BadRequest = type("BadRequest", (Exception,), {})
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
sys.modules.setdefault(name, mod)
sys.modules.setdefault("telegram.error", mod.error)
_ensure_telegram_mock()
from gateway.config import PlatformConfig
from gateway.platforms.telegram import TelegramAdapter
def _make_adapter():
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="test-token"))
adapter._bot = AsyncMock()
adapter._app = MagicMock()
return adapter
class TestTelegramModelPicker:
@pytest.mark.asyncio
async def test_retries_without_thread_when_thread_not_found(self):
adapter = _make_adapter()
providers = [{"slug": "openai", "name": "OpenAI", "total_models": 2, "is_current": True}]
call_log = []
class FakeBadRequest(Exception):
pass
async def mock_send_message(**kwargs):
call_log.append(dict(kwargs))
if kwargs.get("message_thread_id") is not None:
raise FakeBadRequest("Message thread not found")
return SimpleNamespace(message_id=99)
adapter._bot.send_message = AsyncMock(side_effect=mock_send_message)
result = await adapter.send_model_picker(
chat_id="12345",
providers=providers,
current_model="gpt-5",
current_provider="openai",
session_key="s",
on_model_selected=AsyncMock(),
metadata={"thread_id": "99999"},
)
assert result.success is True
assert len(call_log) == 2
assert call_log[0]["message_thread_id"] == 99999
assert "message_thread_id" not in call_log[1] or call_log[1]["message_thread_id"] is None