diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 0a8390a7a..e57a84bb3 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -502,6 +502,14 @@ class MessageType(Enum): COMMAND = "command" # /command style +class ProcessingOutcome(Enum): + """Result classification for message-processing lifecycle hooks.""" + + SUCCESS = "success" + FAILURE = "failure" + CANCELLED = "cancelled" + + @dataclass class MessageEvent: """ @@ -625,6 +633,7 @@ class BasePlatformAdapter(ABC): # Gateway shutdown cancels these so an old gateway instance doesn't keep # working on a task after --replace or manual restarts. self._background_tasks: set[asyncio.Task] = set() + self._expected_cancelled_tasks: set[asyncio.Task] = set() # Chats where auto-TTS on voice input is disabled (set by /voice off) self._auto_tts_disabled_chats: set = set() # Chats where typing indicator is paused (e.g. during approval waits). @@ -1133,7 +1142,7 @@ class BasePlatformAdapter(ABC): async def on_processing_start(self, event: MessageEvent) -> None: """Hook called when background processing begins.""" - async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None: """Hook called when background processing completes.""" async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None: @@ -1352,6 +1361,7 @@ class BasePlatformAdapter(ABC): return if hasattr(task, "add_done_callback"): task.add_done_callback(self._background_tasks.discard) + task.add_done_callback(self._expected_cancelled_tasks.discard) @staticmethod def _get_human_delay() -> float: @@ -1580,7 +1590,11 @@ class BasePlatformAdapter(ABC): # Determine overall success for the processing hook processing_ok = delivery_succeeded if delivery_attempted else not bool(response) - await self._run_processing_hook("on_processing_complete", event, processing_ok) + await self._run_processing_hook( + "on_processing_complete", + event, + ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE, + ) # Check if there's a pending message that was queued during our processing if session_key in self._pending_messages: @@ -1599,10 +1613,14 @@ class BasePlatformAdapter(ABC): return # Already cleaned up except asyncio.CancelledError: - await self._run_processing_hook("on_processing_complete", event, False) + current_task = asyncio.current_task() + outcome = ProcessingOutcome.CANCELLED + if current_task is None or current_task not in self._expected_cancelled_tasks: + outcome = ProcessingOutcome.FAILURE + await self._run_processing_hook("on_processing_complete", event, outcome) raise except Exception as e: - await self._run_processing_hook("on_processing_complete", event, False) + await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE) logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True) # Send the error to the user so they aren't left with radio silence try: @@ -1646,10 +1664,12 @@ class BasePlatformAdapter(ABC): """ tasks = [task for task in self._background_tasks if not task.done()] for task in tasks: + self._expected_cancelled_tasks.add(task) task.cancel() if tasks: await asyncio.gather(*tasks, return_exceptions=True) self._background_tasks.clear() + self._expected_cancelled_tasks.clear() self._pending_messages.clear() self._active_sessions.clear() diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 34a51e721..e503f0edd 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -49,6 +49,7 @@ from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + ProcessingOutcome, SendResult, cache_image_from_url, cache_audio_from_url, @@ -754,14 +755,17 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(message, "add_reaction"): await self._add_reaction(message, "👀") - async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None: """Swap the in-progress reaction for a final success/failure reaction.""" if not self._reactions_enabled(): return message = event.raw_message if hasattr(message, "add_reaction"): await self._remove_reaction(message, "👀") - await self._add_reaction(message, "✅" if success else "❌") + if outcome == ProcessingOutcome.SUCCESS: + await self._add_reaction(message, "✅") + elif outcome == ProcessingOutcome.FAILURE: + await self._add_reaction(message, "❌") async def send( self, diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 750df7a29..cf72d9566 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -40,6 +40,7 @@ from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + ProcessingOutcome, SendResult, ) @@ -1479,7 +1480,7 @@ class MatrixAdapter(BasePlatformAdapter): await self._send_reaction(room_id, msg_id, "\U0001f440") async def on_processing_complete( - self, event: MessageEvent, success: bool, + self, event: MessageEvent, outcome: ProcessingOutcome, ) -> None: """Replace eyes with checkmark (success) or cross (failure).""" if not self._reactions_enabled: @@ -1488,11 +1489,15 @@ class MatrixAdapter(BasePlatformAdapter): room_id = event.source.chat_id if not msg_id or not room_id: return + if outcome == ProcessingOutcome.CANCELLED: + return # Note: Matrix doesn't support removing a specific reaction easily # without tracking the reaction event_id. We send the new reaction; # the eyes stays (acceptable UX — both are visible). await self._send_reaction( - room_id, msg_id, "\u2705" if success else "\u274c", + room_id, + msg_id, + "\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c", ) async def _on_reaction(self, room: Any, event: Any) -> None: diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 91de45fe8..ac5b7fb8c 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -60,6 +60,7 @@ from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + ProcessingOutcome, SendResult, cache_image_from_bytes, cache_audio_from_bytes, @@ -2732,7 +2733,7 @@ class TelegramAdapter(BasePlatformAdapter): if chat_id and message_id: await self._set_reaction(chat_id, message_id, "\U0001f440") - async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None: """Swap the in-progress reaction for a final success/failure reaction. Unlike Discord (additive reactions), Telegram's set_message_reaction @@ -2742,5 +2743,9 @@ class TelegramAdapter(BasePlatformAdapter): return chat_id = getattr(event.source, "chat_id", None) message_id = getattr(event, "message_id", None) - if chat_id and message_id: - await self._set_reaction(chat_id, message_id, "\u2705" if success else "\u274c") + if chat_id and message_id and outcome != ProcessingOutcome.CANCELLED: + await self._set_reaction( + chat_id, + message_id, + "\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c", + ) diff --git a/tests/gateway/test_base_topic_sessions.py b/tests/gateway/test_base_topic_sessions.py index 37e00b279..901bc3468 100644 --- a/tests/gateway/test_base_topic_sessions.py +++ b/tests/gateway/test_base_topic_sessions.py @@ -6,7 +6,7 @@ from types import SimpleNamespace import pytest from gateway.config import Platform, PlatformConfig -from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, ProcessingOutcome, SendResult from gateway.session import SessionSource, build_session_key @@ -44,8 +44,8 @@ class DummyTelegramAdapter(BasePlatformAdapter): async def on_processing_start(self, event: MessageEvent) -> None: self.processing_hooks.append(("start", event.message_id)) - async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: - self.processing_hooks.append(("complete", event.message_id, success)) + async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None: + self.processing_hooks.append(("complete", event.message_id, outcome)) def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent: @@ -142,7 +142,7 @@ class TestBasePlatformTopicSessions: ] assert adapter.processing_hooks == [ ("start", "1"), - ("complete", "1", True), + ("complete", "1", ProcessingOutcome.SUCCESS), ] @pytest.mark.asyncio @@ -168,7 +168,7 @@ class TestBasePlatformTopicSessions: assert adapter.processing_hooks == [ ("start", "1"), - ("complete", "1", False), + ("complete", "1", ProcessingOutcome.FAILURE), ] @pytest.mark.asyncio @@ -190,7 +190,7 @@ class TestBasePlatformTopicSessions: assert adapter.processing_hooks == [ ("start", "1"), - ("complete", "1", False), + ("complete", "1", ProcessingOutcome.FAILURE), ] @pytest.mark.asyncio @@ -218,5 +218,31 @@ class TestBasePlatformTopicSessions: assert adapter.processing_hooks == [ ("start", "1"), - ("complete", "1", False), + ("complete", "1", ProcessingOutcome.FAILURE), + ] + + @pytest.mark.asyncio + async def test_cancel_background_tasks_marks_expected_cancellation_cancelled(self): + adapter = DummyTelegramAdapter() + release = asyncio.Event() + + async def handler(_event): + await release.wait() + return "ack" + + async def hold_typing(_chat_id, interval=2.0, metadata=None): + await asyncio.Event().wait() + + adapter.set_message_handler(handler) + adapter._keep_typing = hold_typing + + event = _make_event("-1001", "17585") + await adapter.handle_message(event) + await asyncio.sleep(0) + + await adapter.cancel_background_tasks() + + assert adapter.processing_hooks == [ + ("start", "1"), + ("complete", "1", ProcessingOutcome.CANCELLED), ] diff --git a/tests/gateway/test_discord_reactions.py b/tests/gateway/test_discord_reactions.py index 3988c67b5..2d7b2a2c9 100644 --- a/tests/gateway/test_discord_reactions.py +++ b/tests/gateway/test_discord_reactions.py @@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest from gateway.config import Platform, PlatformConfig -from gateway.platforms.base import MessageEvent, MessageType, SendResult +from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome, SendResult from gateway.session import SessionSource, build_session_key @@ -212,7 +212,7 @@ async def test_reactions_disabled_via_env_zero(adapter, monkeypatch): event = _make_event("5", raw_message) await adapter.on_processing_start(event) - await adapter.on_processing_complete(event, success=True) + await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS) raw_message.add_reaction.assert_not_awaited() raw_message.remove_reaction.assert_not_awaited() @@ -232,3 +232,17 @@ async def test_reactions_enabled_by_default(adapter, monkeypatch): await adapter.on_processing_start(event) raw_message.add_reaction.assert_awaited_once_with("👀") + + +@pytest.mark.asyncio +async def test_on_processing_complete_cancelled_removes_eyes_without_terminal_reaction(adapter): + raw_message = SimpleNamespace( + add_reaction=AsyncMock(), + remove_reaction=AsyncMock(), + ) + + event = _make_event("7", raw_message) + await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED) + + raw_message.remove_reaction.assert_awaited_once_with("👀", adapter._client.user) + raw_message.add_reaction.assert_not_awaited() diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 0de00b736..09cdd8a44 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -1980,7 +1980,7 @@ class TestMatrixReactions: @pytest.mark.asyncio async def test_on_processing_complete_sends_check(self): - from gateway.platforms.base import MessageEvent, MessageType + from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome self.adapter._reactions_enabled = True self.adapter._send_reaction = AsyncMock(return_value=True) @@ -1994,9 +1994,28 @@ class TestMatrixReactions: raw_message={}, message_id="$msg1", ) - await self.adapter.on_processing_complete(event, success=True) + await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS) self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅") + @pytest.mark.asyncio + async def test_on_processing_complete_cancelled_sends_no_terminal_reaction(self): + from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome + + self.adapter._reactions_enabled = True + self.adapter._send_reaction = AsyncMock(return_value=True) + + source = MagicMock() + source.chat_id = "!room:ex" + event = MessageEvent( + text="hello", + message_type=MessageType.TEXT, + source=source, + raw_message={}, + message_id="$msg1", + ) + await self.adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED) + self.adapter._send_reaction.assert_not_called() + @pytest.mark.asyncio async def test_reactions_disabled(self): from gateway.platforms.base import MessageEvent, MessageType diff --git a/tests/gateway/test_telegram_reactions.py b/tests/gateway/test_telegram_reactions.py index 5068adb9f..98a75afbe 100644 --- a/tests/gateway/test_telegram_reactions.py +++ b/tests/gateway/test_telegram_reactions.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock import pytest from gateway.config import Platform, PlatformConfig -from gateway.platforms.base import MessageEvent, MessageType +from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome from gateway.session import SessionSource @@ -180,7 +180,7 @@ async def test_on_processing_complete_success(monkeypatch): adapter = _make_adapter() event = _make_event() - await adapter.on_processing_complete(event, success=True) + await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS) adapter._bot.set_message_reaction.assert_awaited_once_with( chat_id=123, @@ -196,7 +196,7 @@ async def test_on_processing_complete_failure(monkeypatch): adapter = _make_adapter() event = _make_event() - await adapter.on_processing_complete(event, success=False) + await adapter.on_processing_complete(event, ProcessingOutcome.FAILURE) adapter._bot.set_message_reaction.assert_awaited_once_with( chat_id=123, @@ -212,7 +212,19 @@ async def test_on_processing_complete_skipped_when_disabled(monkeypatch): adapter = _make_adapter() event = _make_event() - await adapter.on_processing_complete(event, success=True) + await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS) + + adapter._bot.set_message_reaction.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_processing_complete_cancelled_keeps_existing_reaction(monkeypatch): + """Expected cancellation should not replace the in-progress reaction.""" + monkeypatch.setenv("TELEGRAM_REACTIONS", "true") + adapter = _make_adapter() + event = _make_event() + + await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED) adapter._bot.set_message_reaction.assert_not_awaited()