fix(gateway): avoid false failure reactions on restart cancellation

This commit is contained in:
Kenny Xie 2026-04-08 16:07:07 -07:00 committed by Teknium
parent af7d809354
commit 4f2f09affa
8 changed files with 131 additions and 26 deletions

View file

@ -502,6 +502,14 @@ class MessageType(Enum):
COMMAND = "command" # /command style COMMAND = "command" # /command style
class ProcessingOutcome(Enum):
"""Result classification for message-processing lifecycle hooks."""
SUCCESS = "success"
FAILURE = "failure"
CANCELLED = "cancelled"
@dataclass @dataclass
class MessageEvent: class MessageEvent:
""" """
@ -625,6 +633,7 @@ class BasePlatformAdapter(ABC):
# Gateway shutdown cancels these so an old gateway instance doesn't keep # Gateway shutdown cancels these so an old gateway instance doesn't keep
# working on a task after --replace or manual restarts. # working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set() self._background_tasks: set[asyncio.Task] = set()
self._expected_cancelled_tasks: set[asyncio.Task] = set()
# Chats where auto-TTS on voice input is disabled (set by /voice off) # Chats where auto-TTS on voice input is disabled (set by /voice off)
self._auto_tts_disabled_chats: set = set() self._auto_tts_disabled_chats: set = set()
# Chats where typing indicator is paused (e.g. during approval waits). # Chats where typing indicator is paused (e.g. during approval waits).
@ -1133,7 +1142,7 @@ class BasePlatformAdapter(ABC):
async def on_processing_start(self, event: MessageEvent) -> None: async def on_processing_start(self, event: MessageEvent) -> None:
"""Hook called when background processing begins.""" """Hook called when background processing begins."""
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
"""Hook called when background processing completes.""" """Hook called when background processing completes."""
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None: async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
@ -1352,6 +1361,7 @@ class BasePlatformAdapter(ABC):
return return
if hasattr(task, "add_done_callback"): if hasattr(task, "add_done_callback"):
task.add_done_callback(self._background_tasks.discard) task.add_done_callback(self._background_tasks.discard)
task.add_done_callback(self._expected_cancelled_tasks.discard)
@staticmethod @staticmethod
def _get_human_delay() -> float: def _get_human_delay() -> float:
@ -1580,7 +1590,11 @@ class BasePlatformAdapter(ABC):
# Determine overall success for the processing hook # Determine overall success for the processing hook
processing_ok = delivery_succeeded if delivery_attempted else not bool(response) processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
await self._run_processing_hook("on_processing_complete", event, processing_ok) await self._run_processing_hook(
"on_processing_complete",
event,
ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE,
)
# Check if there's a pending message that was queued during our processing # Check if there's a pending message that was queued during our processing
if session_key in self._pending_messages: if session_key in self._pending_messages:
@ -1599,10 +1613,14 @@ class BasePlatformAdapter(ABC):
return # Already cleaned up return # Already cleaned up
except asyncio.CancelledError: except asyncio.CancelledError:
await self._run_processing_hook("on_processing_complete", event, False) current_task = asyncio.current_task()
outcome = ProcessingOutcome.CANCELLED
if current_task is None or current_task not in self._expected_cancelled_tasks:
outcome = ProcessingOutcome.FAILURE
await self._run_processing_hook("on_processing_complete", event, outcome)
raise raise
except Exception as e: except Exception as e:
await self._run_processing_hook("on_processing_complete", event, False) await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE)
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True) logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
# Send the error to the user so they aren't left with radio silence # Send the error to the user so they aren't left with radio silence
try: try:
@ -1646,10 +1664,12 @@ class BasePlatformAdapter(ABC):
""" """
tasks = [task for task in self._background_tasks if not task.done()] tasks = [task for task in self._background_tasks if not task.done()]
for task in tasks: for task in tasks:
self._expected_cancelled_tasks.add(task)
task.cancel() task.cancel()
if tasks: if tasks:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
self._background_tasks.clear() self._background_tasks.clear()
self._expected_cancelled_tasks.clear()
self._pending_messages.clear() self._pending_messages.clear()
self._active_sessions.clear() self._active_sessions.clear()

View file

@ -49,6 +49,7 @@ from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
MessageType, MessageType,
ProcessingOutcome,
SendResult, SendResult,
cache_image_from_url, cache_image_from_url,
cache_audio_from_url, cache_audio_from_url,
@ -754,14 +755,17 @@ class DiscordAdapter(BasePlatformAdapter):
if hasattr(message, "add_reaction"): if hasattr(message, "add_reaction"):
await self._add_reaction(message, "👀") await self._add_reaction(message, "👀")
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
"""Swap the in-progress reaction for a final success/failure reaction.""" """Swap the in-progress reaction for a final success/failure reaction."""
if not self._reactions_enabled(): if not self._reactions_enabled():
return return
message = event.raw_message message = event.raw_message
if hasattr(message, "add_reaction"): if hasattr(message, "add_reaction"):
await self._remove_reaction(message, "👀") await self._remove_reaction(message, "👀")
await self._add_reaction(message, "✅" if success else "❌") if outcome == ProcessingOutcome.SUCCESS:
await self._add_reaction(message, "✅")
elif outcome == ProcessingOutcome.FAILURE:
await self._add_reaction(message, "❌")
async def send( async def send(
self, self,

View file

@ -40,6 +40,7 @@ from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
MessageType, MessageType,
ProcessingOutcome,
SendResult, SendResult,
) )
@ -1479,7 +1480,7 @@ class MatrixAdapter(BasePlatformAdapter):
await self._send_reaction(room_id, msg_id, "\U0001f440") await self._send_reaction(room_id, msg_id, "\U0001f440")
async def on_processing_complete( async def on_processing_complete(
self, event: MessageEvent, success: bool, self, event: MessageEvent, outcome: ProcessingOutcome,
) -> None: ) -> None:
"""Replace eyes with checkmark (success) or cross (failure).""" """Replace eyes with checkmark (success) or cross (failure)."""
if not self._reactions_enabled: if not self._reactions_enabled:
@ -1488,11 +1489,15 @@ class MatrixAdapter(BasePlatformAdapter):
room_id = event.source.chat_id room_id = event.source.chat_id
if not msg_id or not room_id: if not msg_id or not room_id:
return return
if outcome == ProcessingOutcome.CANCELLED:
return
# Note: Matrix doesn't support removing a specific reaction easily # Note: Matrix doesn't support removing a specific reaction easily
# without tracking the reaction event_id. We send the new reaction; # without tracking the reaction event_id. We send the new reaction;
# the eyes stays (acceptable UX — both are visible). # the eyes stays (acceptable UX — both are visible).
await self._send_reaction( await self._send_reaction(
room_id, msg_id, "\u2705" if success else "\u274c", room_id,
msg_id,
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
) )
async def _on_reaction(self, room: Any, event: Any) -> None: async def _on_reaction(self, room: Any, event: Any) -> None:

View file

@ -60,6 +60,7 @@ from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
MessageType, MessageType,
ProcessingOutcome,
SendResult, SendResult,
cache_image_from_bytes, cache_image_from_bytes,
cache_audio_from_bytes, cache_audio_from_bytes,
@ -2732,7 +2733,7 @@ class TelegramAdapter(BasePlatformAdapter):
if chat_id and message_id: if chat_id and message_id:
await self._set_reaction(chat_id, message_id, "\U0001f440") await self._set_reaction(chat_id, message_id, "\U0001f440")
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
"""Swap the in-progress reaction for a final success/failure reaction. """Swap the in-progress reaction for a final success/failure reaction.
Unlike Discord (additive reactions), Telegram's set_message_reaction Unlike Discord (additive reactions), Telegram's set_message_reaction
@ -2742,5 +2743,9 @@ class TelegramAdapter(BasePlatformAdapter):
return return
chat_id = getattr(event.source, "chat_id", None) chat_id = getattr(event.source, "chat_id", None)
message_id = getattr(event, "message_id", None) message_id = getattr(event, "message_id", None)
if chat_id and message_id: if chat_id and message_id and outcome != ProcessingOutcome.CANCELLED:
await self._set_reaction(chat_id, message_id, "\u2705" if success else "\u274c") await self._set_reaction(
chat_id,
message_id,
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
)

View file

@ -6,7 +6,7 @@ from types import SimpleNamespace
import pytest import pytest
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult from gateway.platforms.base import BasePlatformAdapter, MessageEvent, ProcessingOutcome, SendResult
from gateway.session import SessionSource, build_session_key from gateway.session import SessionSource, build_session_key
@ -44,8 +44,8 @@ class DummyTelegramAdapter(BasePlatformAdapter):
async def on_processing_start(self, event: MessageEvent) -> None: async def on_processing_start(self, event: MessageEvent) -> None:
self.processing_hooks.append(("start", event.message_id)) self.processing_hooks.append(("start", event.message_id))
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
self.processing_hooks.append(("complete", event.message_id, success)) self.processing_hooks.append(("complete", event.message_id, outcome))
def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent: def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent:
@ -142,7 +142,7 @@ class TestBasePlatformTopicSessions:
] ]
assert adapter.processing_hooks == [ assert adapter.processing_hooks == [
("start", "1"), ("start", "1"),
("complete", "1", True), ("complete", "1", ProcessingOutcome.SUCCESS),
] ]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -168,7 +168,7 @@ class TestBasePlatformTopicSessions:
assert adapter.processing_hooks == [ assert adapter.processing_hooks == [
("start", "1"), ("start", "1"),
("complete", "1", False), ("complete", "1", ProcessingOutcome.FAILURE),
] ]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -190,7 +190,7 @@ class TestBasePlatformTopicSessions:
assert adapter.processing_hooks == [ assert adapter.processing_hooks == [
("start", "1"), ("start", "1"),
("complete", "1", False), ("complete", "1", ProcessingOutcome.FAILURE),
] ]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -218,5 +218,31 @@ class TestBasePlatformTopicSessions:
assert adapter.processing_hooks == [ assert adapter.processing_hooks == [
("start", "1"), ("start", "1"),
("complete", "1", False), ("complete", "1", ProcessingOutcome.FAILURE),
]
@pytest.mark.asyncio
async def test_cancel_background_tasks_marks_expected_cancellation_cancelled(self):
adapter = DummyTelegramAdapter()
release = asyncio.Event()
async def handler(_event):
await release.wait()
return "ack"
async def hold_typing(_chat_id, interval=2.0, metadata=None):
await asyncio.Event().wait()
adapter.set_message_handler(handler)
adapter._keep_typing = hold_typing
event = _make_event("-1001", "17585")
await adapter.handle_message(event)
await asyncio.sleep(0)
await adapter.cancel_background_tasks()
assert adapter.processing_hooks == [
("start", "1"),
("complete", "1", ProcessingOutcome.CANCELLED),
] ]

View file

@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType, SendResult from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome, SendResult
from gateway.session import SessionSource, build_session_key from gateway.session import SessionSource, build_session_key
@ -212,7 +212,7 @@ async def test_reactions_disabled_via_env_zero(adapter, monkeypatch):
event = _make_event("5", raw_message) event = _make_event("5", raw_message)
await adapter.on_processing_start(event) await adapter.on_processing_start(event)
await adapter.on_processing_complete(event, success=True) await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
raw_message.add_reaction.assert_not_awaited() raw_message.add_reaction.assert_not_awaited()
raw_message.remove_reaction.assert_not_awaited() raw_message.remove_reaction.assert_not_awaited()
@ -232,3 +232,17 @@ async def test_reactions_enabled_by_default(adapter, monkeypatch):
await adapter.on_processing_start(event) await adapter.on_processing_start(event)
raw_message.add_reaction.assert_awaited_once_with("👀") raw_message.add_reaction.assert_awaited_once_with("👀")
@pytest.mark.asyncio
async def test_on_processing_complete_cancelled_removes_eyes_without_terminal_reaction(adapter):
raw_message = SimpleNamespace(
add_reaction=AsyncMock(),
remove_reaction=AsyncMock(),
)
event = _make_event("7", raw_message)
await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
raw_message.remove_reaction.assert_awaited_once_with("👀", adapter._client.user)
raw_message.add_reaction.assert_not_awaited()

View file

@ -1980,7 +1980,7 @@ class TestMatrixReactions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_processing_complete_sends_check(self): async def test_on_processing_complete_sends_check(self):
from gateway.platforms.base import MessageEvent, MessageType from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
self.adapter._reactions_enabled = True self.adapter._reactions_enabled = True
self.adapter._send_reaction = AsyncMock(return_value=True) self.adapter._send_reaction = AsyncMock(return_value=True)
@ -1994,9 +1994,28 @@ class TestMatrixReactions:
raw_message={}, raw_message={},
message_id="$msg1", message_id="$msg1",
) )
await self.adapter.on_processing_complete(event, success=True) await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅") self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅")
@pytest.mark.asyncio
async def test_on_processing_complete_cancelled_sends_no_terminal_reaction(self):
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
self.adapter._reactions_enabled = True
self.adapter._send_reaction = AsyncMock(return_value=True)
source = MagicMock()
source.chat_id = "!room:ex"
event = MessageEvent(
text="hello",
message_type=MessageType.TEXT,
source=source,
raw_message={},
message_id="$msg1",
)
await self.adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
self.adapter._send_reaction.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reactions_disabled(self): async def test_reactions_disabled(self):
from gateway.platforms.base import MessageEvent, MessageType from gateway.platforms.base import MessageEvent, MessageType

View file

@ -6,7 +6,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
from gateway.session import SessionSource from gateway.session import SessionSource
@ -180,7 +180,7 @@ async def test_on_processing_complete_success(monkeypatch):
adapter = _make_adapter() adapter = _make_adapter()
event = _make_event() event = _make_event()
await adapter.on_processing_complete(event, success=True) await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
adapter._bot.set_message_reaction.assert_awaited_once_with( adapter._bot.set_message_reaction.assert_awaited_once_with(
chat_id=123, chat_id=123,
@ -196,7 +196,7 @@ async def test_on_processing_complete_failure(monkeypatch):
adapter = _make_adapter() adapter = _make_adapter()
event = _make_event() event = _make_event()
await adapter.on_processing_complete(event, success=False) await adapter.on_processing_complete(event, ProcessingOutcome.FAILURE)
adapter._bot.set_message_reaction.assert_awaited_once_with( adapter._bot.set_message_reaction.assert_awaited_once_with(
chat_id=123, chat_id=123,
@ -212,7 +212,19 @@ async def test_on_processing_complete_skipped_when_disabled(monkeypatch):
adapter = _make_adapter() adapter = _make_adapter()
event = _make_event() event = _make_event()
await adapter.on_processing_complete(event, success=True) await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
adapter._bot.set_message_reaction.assert_not_awaited()
@pytest.mark.asyncio
async def test_on_processing_complete_cancelled_keeps_existing_reaction(monkeypatch):
"""Expected cancellation should not replace the in-progress reaction."""
monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
adapter = _make_adapter()
event = _make_event()
await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
adapter._bot.set_message_reaction.assert_not_awaited() adapter._bot.set_message_reaction.assert_not_awaited()