fix(gateway): avoid false failure reactions on restart cancellation

This commit is contained in:
Kenny Xie 2026-04-08 16:07:07 -07:00 committed by Teknium
parent af7d809354
commit 4f2f09affa
8 changed files with 131 additions and 26 deletions

View file

@ -502,6 +502,14 @@ class MessageType(Enum):
COMMAND = "command" # /command style
class ProcessingOutcome(Enum):
"""Result classification for message-processing lifecycle hooks."""
SUCCESS = "success"
FAILURE = "failure"
CANCELLED = "cancelled"
@dataclass
class MessageEvent:
"""
@ -625,6 +633,7 @@ class BasePlatformAdapter(ABC):
# Gateway shutdown cancels these so an old gateway instance doesn't keep
# working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set()
self._expected_cancelled_tasks: set[asyncio.Task] = set()
# Chats where auto-TTS on voice input is disabled (set by /voice off)
self._auto_tts_disabled_chats: set = set()
# Chats where typing indicator is paused (e.g. during approval waits).
@ -1133,7 +1142,7 @@ class BasePlatformAdapter(ABC):
async def on_processing_start(self, event: MessageEvent) -> None:
"""Hook called when background processing begins."""
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
"""Hook called when background processing completes."""
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
@ -1352,6 +1361,7 @@ class BasePlatformAdapter(ABC):
return
if hasattr(task, "add_done_callback"):
task.add_done_callback(self._background_tasks.discard)
task.add_done_callback(self._expected_cancelled_tasks.discard)
@staticmethod
def _get_human_delay() -> float:
@ -1580,7 +1590,11 @@ class BasePlatformAdapter(ABC):
# Determine overall success for the processing hook
processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
await self._run_processing_hook("on_processing_complete", event, processing_ok)
await self._run_processing_hook(
"on_processing_complete",
event,
ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE,
)
# Check if there's a pending message that was queued during our processing
if session_key in self._pending_messages:
@ -1599,10 +1613,14 @@ class BasePlatformAdapter(ABC):
return # Already cleaned up
except asyncio.CancelledError:
await self._run_processing_hook("on_processing_complete", event, False)
current_task = asyncio.current_task()
outcome = ProcessingOutcome.CANCELLED
if current_task is None or current_task not in self._expected_cancelled_tasks:
outcome = ProcessingOutcome.FAILURE
await self._run_processing_hook("on_processing_complete", event, outcome)
raise
except Exception as e:
await self._run_processing_hook("on_processing_complete", event, False)
await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE)
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
# Send the error to the user so they aren't left with radio silence
try:
@ -1646,10 +1664,12 @@ class BasePlatformAdapter(ABC):
"""
tasks = [task for task in self._background_tasks if not task.done()]
for task in tasks:
self._expected_cancelled_tasks.add(task)
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
self._background_tasks.clear()
self._expected_cancelled_tasks.clear()
self._pending_messages.clear()
self._active_sessions.clear()