diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index b2d5f6e22ab..d507c60bd77 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -3187,13 +3187,25 @@ class BasePlatformAdapter(ABC): logger.warning("[%s] Auto-TTS failed: %s", self.name, tts_err) # Play TTS audio before text (voice-first experience) + _tts_caption_delivered = False if _tts_path and Path(_tts_path).exists(): try: - await self.play_tts( + telegram_tts_caption = None + if ( + self.platform == Platform.TELEGRAM + and text_content + and text_content[:1024] == text_content + ): + telegram_tts_caption = text_content + tts_result = await self.play_tts( chat_id=event.source.chat_id, audio_path=_tts_path, + caption=telegram_tts_caption, metadata=_thread_metadata, ) + _tts_caption_delivered = bool( + telegram_tts_caption and getattr(tts_result, "success", False) + ) finally: try: os.remove(_tts_path) @@ -3201,7 +3213,7 @@ class BasePlatformAdapter(ABC): pass # Send the text portion - if text_content: + if text_content and not _tts_caption_delivered: logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id) _reply_anchor = _reply_anchor_for_event(event) # Mark final response messages for notification delivery. diff --git a/tests/gateway/test_base_topic_sessions.py b/tests/gateway/test_base_topic_sessions.py index 665f99ac4c2..a55fcb1d8ff 100644 --- a/tests/gateway/test_base_topic_sessions.py +++ b/tests/gateway/test_base_topic_sessions.py @@ -1,12 +1,14 @@ """Tests for BasePlatformAdapter topic-aware session handling.""" import asyncio +import json from types import SimpleNamespace +from unittest.mock import AsyncMock, patch import pytest from gateway.config import Platform, PlatformConfig -from gateway.platforms.base import BasePlatformAdapter, MessageEvent, ProcessingOutcome, SendResult +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, ProcessingOutcome, SendResult from gateway.session import SessionSource, build_session_key @@ -246,3 +248,107 @@ class TestBasePlatformTopicSessions: ("start", "1"), ("complete", "1", ProcessingOutcome.CANCELLED), ] + + +class TestTelegramAutoTtsCaptionDelivery: + @staticmethod + def _make_voice_event(chat_id: str = "-1001", thread_id: str = "17585") -> MessageEvent: + return MessageEvent( + text="hello", + message_type=MessageType.VOICE, + source=SessionSource( + platform=Platform.TELEGRAM, + chat_id=chat_id, + chat_type="group", + thread_id=thread_id, + ), + message_id="voice-1", + ) + + @staticmethod + def _hold_typing(): + async def hold(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + return hold + + @pytest.mark.asyncio + async def test_short_telegram_auto_tts_uses_caption_without_followup_text(self, tmp_path): + adapter = DummyTelegramAdapter() + adapter._keep_typing = self._hold_typing() + adapter._should_auto_tts_for_chat = lambda _chat_id: True + adapter.play_tts = AsyncMock(return_value=SendResult(success=True, message_id="tts-1")) + adapter.set_message_handler(lambda _event: asyncio.sleep(0, result="Short reply")) + + tts_path = tmp_path / "reply.ogg" + tts_path.write_text("audio", encoding="utf-8") + event = self._make_voice_event() + + with patch("tools.tts_tool.check_tts_requirements", return_value=True), patch( + "tools.tts_tool.text_to_speech_tool", + return_value=json.dumps({"file_path": str(tts_path)}), + ): + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.play_tts.assert_awaited_once() + assert adapter.play_tts.await_args.kwargs["caption"] == "Short reply" + assert adapter.sent == [] + + @pytest.mark.asyncio + async def test_long_telegram_auto_tts_keeps_followup_text_when_caption_would_truncate(self, tmp_path): + adapter = DummyTelegramAdapter() + adapter._keep_typing = self._hold_typing() + adapter._should_auto_tts_for_chat = lambda _chat_id: True + adapter.play_tts = AsyncMock(return_value=SendResult(success=True, message_id="tts-1")) + long_reply = "x" * 1025 + adapter.set_message_handler(lambda _event: asyncio.sleep(0, result=long_reply)) + + tts_path = tmp_path / "reply.ogg" + tts_path.write_text("audio", encoding="utf-8") + event = self._make_voice_event() + + with patch("tools.tts_tool.check_tts_requirements", return_value=True), patch( + "tools.tts_tool.text_to_speech_tool", + return_value=json.dumps({"file_path": str(tts_path)}), + ): + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.play_tts.assert_awaited_once() + assert adapter.play_tts.await_args.kwargs["caption"] is None + assert adapter.sent == [ + { + "chat_id": "-1001", + "content": long_reply, + "reply_to": None, + "metadata": {"thread_id": "17585", "notify": True}, + } + ] + + @pytest.mark.asyncio + async def test_telegram_auto_tts_send_failure_keeps_followup_text(self, tmp_path): + adapter = DummyTelegramAdapter() + adapter._keep_typing = self._hold_typing() + adapter._should_auto_tts_for_chat = lambda _chat_id: True + adapter.play_tts = AsyncMock(return_value=SendResult(success=False, error="boom")) + adapter.set_message_handler(lambda _event: asyncio.sleep(0, result="Short reply")) + + tts_path = tmp_path / "reply.ogg" + tts_path.write_text("audio", encoding="utf-8") + event = self._make_voice_event() + + with patch("tools.tts_tool.check_tts_requirements", return_value=True), patch( + "tools.tts_tool.text_to_speech_tool", + return_value=json.dumps({"file_path": str(tts_path)}), + ): + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.play_tts.assert_awaited_once() + assert adapter.play_tts.await_args.kwargs["caption"] == "Short reply" + assert adapter.sent == [ + { + "chat_id": "-1001", + "content": "Short reply", + "reply_to": None, + "metadata": {"thread_id": "17585", "notify": True}, + } + ]