mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): avoid false failure reactions on restart cancellation
This commit is contained in:
parent
af7d809354
commit
4f2f09affa
8 changed files with 131 additions and 26 deletions
|
|
@ -502,6 +502,14 @@ class MessageType(Enum):
|
||||||
COMMAND = "command" # /command style
|
COMMAND = "command" # /command style
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingOutcome(Enum):
|
||||||
|
"""Result classification for message-processing lifecycle hooks."""
|
||||||
|
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILURE = "failure"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageEvent:
|
class MessageEvent:
|
||||||
"""
|
"""
|
||||||
|
|
@ -625,6 +633,7 @@ class BasePlatformAdapter(ABC):
|
||||||
# Gateway shutdown cancels these so an old gateway instance doesn't keep
|
# Gateway shutdown cancels these so an old gateway instance doesn't keep
|
||||||
# working on a task after --replace or manual restarts.
|
# working on a task after --replace or manual restarts.
|
||||||
self._background_tasks: set[asyncio.Task] = set()
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
self._expected_cancelled_tasks: set[asyncio.Task] = set()
|
||||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||||
self._auto_tts_disabled_chats: set = set()
|
self._auto_tts_disabled_chats: set = set()
|
||||||
# Chats where typing indicator is paused (e.g. during approval waits).
|
# Chats where typing indicator is paused (e.g. during approval waits).
|
||||||
|
|
@ -1133,7 +1142,7 @@ class BasePlatformAdapter(ABC):
|
||||||
async def on_processing_start(self, event: MessageEvent) -> None:
|
async def on_processing_start(self, event: MessageEvent) -> None:
|
||||||
"""Hook called when background processing begins."""
|
"""Hook called when background processing begins."""
|
||||||
|
|
||||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||||
"""Hook called when background processing completes."""
|
"""Hook called when background processing completes."""
|
||||||
|
|
||||||
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
|
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
|
||||||
|
|
@ -1352,6 +1361,7 @@ class BasePlatformAdapter(ABC):
|
||||||
return
|
return
|
||||||
if hasattr(task, "add_done_callback"):
|
if hasattr(task, "add_done_callback"):
|
||||||
task.add_done_callback(self._background_tasks.discard)
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
task.add_done_callback(self._expected_cancelled_tasks.discard)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_human_delay() -> float:
|
def _get_human_delay() -> float:
|
||||||
|
|
@ -1580,7 +1590,11 @@ class BasePlatformAdapter(ABC):
|
||||||
|
|
||||||
# Determine overall success for the processing hook
|
# Determine overall success for the processing hook
|
||||||
processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
|
processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
|
||||||
await self._run_processing_hook("on_processing_complete", event, processing_ok)
|
await self._run_processing_hook(
|
||||||
|
"on_processing_complete",
|
||||||
|
event,
|
||||||
|
ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE,
|
||||||
|
)
|
||||||
|
|
||||||
# Check if there's a pending message that was queued during our processing
|
# Check if there's a pending message that was queued during our processing
|
||||||
if session_key in self._pending_messages:
|
if session_key in self._pending_messages:
|
||||||
|
|
@ -1599,10 +1613,14 @@ class BasePlatformAdapter(ABC):
|
||||||
return # Already cleaned up
|
return # Already cleaned up
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
await self._run_processing_hook("on_processing_complete", event, False)
|
current_task = asyncio.current_task()
|
||||||
|
outcome = ProcessingOutcome.CANCELLED
|
||||||
|
if current_task is None or current_task not in self._expected_cancelled_tasks:
|
||||||
|
outcome = ProcessingOutcome.FAILURE
|
||||||
|
await self._run_processing_hook("on_processing_complete", event, outcome)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self._run_processing_hook("on_processing_complete", event, False)
|
await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE)
|
||||||
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
|
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
|
||||||
# Send the error to the user so they aren't left with radio silence
|
# Send the error to the user so they aren't left with radio silence
|
||||||
try:
|
try:
|
||||||
|
|
@ -1646,10 +1664,12 @@ class BasePlatformAdapter(ABC):
|
||||||
"""
|
"""
|
||||||
tasks = [task for task in self._background_tasks if not task.done()]
|
tasks = [task for task in self._background_tasks if not task.done()]
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
|
self._expected_cancelled_tasks.add(task)
|
||||||
task.cancel()
|
task.cancel()
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
self._background_tasks.clear()
|
self._background_tasks.clear()
|
||||||
|
self._expected_cancelled_tasks.clear()
|
||||||
self._pending_messages.clear()
|
self._pending_messages.clear()
|
||||||
self._active_sessions.clear()
|
self._active_sessions.clear()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ from gateway.platforms.base import (
|
||||||
BasePlatformAdapter,
|
BasePlatformAdapter,
|
||||||
MessageEvent,
|
MessageEvent,
|
||||||
MessageType,
|
MessageType,
|
||||||
|
ProcessingOutcome,
|
||||||
SendResult,
|
SendResult,
|
||||||
cache_image_from_url,
|
cache_image_from_url,
|
||||||
cache_audio_from_url,
|
cache_audio_from_url,
|
||||||
|
|
@ -754,14 +755,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
if hasattr(message, "add_reaction"):
|
if hasattr(message, "add_reaction"):
|
||||||
await self._add_reaction(message, "👀")
|
await self._add_reaction(message, "👀")
|
||||||
|
|
||||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||||
"""Swap the in-progress reaction for a final success/failure reaction."""
|
"""Swap the in-progress reaction for a final success/failure reaction."""
|
||||||
if not self._reactions_enabled():
|
if not self._reactions_enabled():
|
||||||
return
|
return
|
||||||
message = event.raw_message
|
message = event.raw_message
|
||||||
if hasattr(message, "add_reaction"):
|
if hasattr(message, "add_reaction"):
|
||||||
await self._remove_reaction(message, "👀")
|
await self._remove_reaction(message, "👀")
|
||||||
await self._add_reaction(message, "✅" if success else "❌")
|
if outcome == ProcessingOutcome.SUCCESS:
|
||||||
|
await self._add_reaction(message, "✅")
|
||||||
|
elif outcome == ProcessingOutcome.FAILURE:
|
||||||
|
await self._add_reaction(message, "❌")
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ from gateway.platforms.base import (
|
||||||
BasePlatformAdapter,
|
BasePlatformAdapter,
|
||||||
MessageEvent,
|
MessageEvent,
|
||||||
MessageType,
|
MessageType,
|
||||||
|
ProcessingOutcome,
|
||||||
SendResult,
|
SendResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1479,7 +1480,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||||
await self._send_reaction(room_id, msg_id, "\U0001f440")
|
await self._send_reaction(room_id, msg_id, "\U0001f440")
|
||||||
|
|
||||||
async def on_processing_complete(
|
async def on_processing_complete(
|
||||||
self, event: MessageEvent, success: bool,
|
self, event: MessageEvent, outcome: ProcessingOutcome,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Replace eyes with checkmark (success) or cross (failure)."""
|
"""Replace eyes with checkmark (success) or cross (failure)."""
|
||||||
if not self._reactions_enabled:
|
if not self._reactions_enabled:
|
||||||
|
|
@ -1488,11 +1489,15 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||||
room_id = event.source.chat_id
|
room_id = event.source.chat_id
|
||||||
if not msg_id or not room_id:
|
if not msg_id or not room_id:
|
||||||
return
|
return
|
||||||
|
if outcome == ProcessingOutcome.CANCELLED:
|
||||||
|
return
|
||||||
# Note: Matrix doesn't support removing a specific reaction easily
|
# Note: Matrix doesn't support removing a specific reaction easily
|
||||||
# without tracking the reaction event_id. We send the new reaction;
|
# without tracking the reaction event_id. We send the new reaction;
|
||||||
# the eyes stays (acceptable UX — both are visible).
|
# the eyes stays (acceptable UX — both are visible).
|
||||||
await self._send_reaction(
|
await self._send_reaction(
|
||||||
room_id, msg_id, "\u2705" if success else "\u274c",
|
room_id,
|
||||||
|
msg_id,
|
||||||
|
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_reaction(self, room: Any, event: Any) -> None:
|
async def _on_reaction(self, room: Any, event: Any) -> None:
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ from gateway.platforms.base import (
|
||||||
BasePlatformAdapter,
|
BasePlatformAdapter,
|
||||||
MessageEvent,
|
MessageEvent,
|
||||||
MessageType,
|
MessageType,
|
||||||
|
ProcessingOutcome,
|
||||||
SendResult,
|
SendResult,
|
||||||
cache_image_from_bytes,
|
cache_image_from_bytes,
|
||||||
cache_audio_from_bytes,
|
cache_audio_from_bytes,
|
||||||
|
|
@ -2732,7 +2733,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||||
if chat_id and message_id:
|
if chat_id and message_id:
|
||||||
await self._set_reaction(chat_id, message_id, "\U0001f440")
|
await self._set_reaction(chat_id, message_id, "\U0001f440")
|
||||||
|
|
||||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||||
"""Swap the in-progress reaction for a final success/failure reaction.
|
"""Swap the in-progress reaction for a final success/failure reaction.
|
||||||
|
|
||||||
Unlike Discord (additive reactions), Telegram's set_message_reaction
|
Unlike Discord (additive reactions), Telegram's set_message_reaction
|
||||||
|
|
@ -2742,5 +2743,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||||
return
|
return
|
||||||
chat_id = getattr(event.source, "chat_id", None)
|
chat_id = getattr(event.source, "chat_id", None)
|
||||||
message_id = getattr(event, "message_id", None)
|
message_id = getattr(event, "message_id", None)
|
||||||
if chat_id and message_id:
|
if chat_id and message_id and outcome != ProcessingOutcome.CANCELLED:
|
||||||
await self._set_reaction(chat_id, message_id, "\u2705" if success else "\u274c")
|
await self._set_reaction(
|
||||||
|
chat_id,
|
||||||
|
message_id,
|
||||||
|
"\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c",
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from types import SimpleNamespace
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gateway.config import Platform, PlatformConfig
|
from gateway.config import Platform, PlatformConfig
|
||||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, ProcessingOutcome, SendResult
|
||||||
from gateway.session import SessionSource, build_session_key
|
from gateway.session import SessionSource, build_session_key
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,8 +44,8 @@ class DummyTelegramAdapter(BasePlatformAdapter):
|
||||||
async def on_processing_start(self, event: MessageEvent) -> None:
|
async def on_processing_start(self, event: MessageEvent) -> None:
|
||||||
self.processing_hooks.append(("start", event.message_id))
|
self.processing_hooks.append(("start", event.message_id))
|
||||||
|
|
||||||
async def on_processing_complete(self, event: MessageEvent, success: bool) -> None:
|
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||||||
self.processing_hooks.append(("complete", event.message_id, success))
|
self.processing_hooks.append(("complete", event.message_id, outcome))
|
||||||
|
|
||||||
|
|
||||||
def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent:
|
def _make_event(chat_id: str, thread_id: str, message_id: str = "1") -> MessageEvent:
|
||||||
|
|
@ -142,7 +142,7 @@ class TestBasePlatformTopicSessions:
|
||||||
]
|
]
|
||||||
assert adapter.processing_hooks == [
|
assert adapter.processing_hooks == [
|
||||||
("start", "1"),
|
("start", "1"),
|
||||||
("complete", "1", True),
|
("complete", "1", ProcessingOutcome.SUCCESS),
|
||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -168,7 +168,7 @@ class TestBasePlatformTopicSessions:
|
||||||
|
|
||||||
assert adapter.processing_hooks == [
|
assert adapter.processing_hooks == [
|
||||||
("start", "1"),
|
("start", "1"),
|
||||||
("complete", "1", False),
|
("complete", "1", ProcessingOutcome.FAILURE),
|
||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -190,7 +190,7 @@ class TestBasePlatformTopicSessions:
|
||||||
|
|
||||||
assert adapter.processing_hooks == [
|
assert adapter.processing_hooks == [
|
||||||
("start", "1"),
|
("start", "1"),
|
||||||
("complete", "1", False),
|
("complete", "1", ProcessingOutcome.FAILURE),
|
||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -218,5 +218,31 @@ class TestBasePlatformTopicSessions:
|
||||||
|
|
||||||
assert adapter.processing_hooks == [
|
assert adapter.processing_hooks == [
|
||||||
("start", "1"),
|
("start", "1"),
|
||||||
("complete", "1", False),
|
("complete", "1", ProcessingOutcome.FAILURE),
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_background_tasks_marks_expected_cancellation_cancelled(self):
|
||||||
|
adapter = DummyTelegramAdapter()
|
||||||
|
release = asyncio.Event()
|
||||||
|
|
||||||
|
async def handler(_event):
|
||||||
|
await release.wait()
|
||||||
|
return "ack"
|
||||||
|
|
||||||
|
async def hold_typing(_chat_id, interval=2.0, metadata=None):
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
adapter.set_message_handler(handler)
|
||||||
|
adapter._keep_typing = hold_typing
|
||||||
|
|
||||||
|
event = _make_event("-1001", "17585")
|
||||||
|
await adapter.handle_message(event)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
await adapter.cancel_background_tasks()
|
||||||
|
|
||||||
|
assert adapter.processing_hooks == [
|
||||||
|
("start", "1"),
|
||||||
|
("complete", "1", ProcessingOutcome.CANCELLED),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gateway.config import Platform, PlatformConfig
|
from gateway.config import Platform, PlatformConfig
|
||||||
from gateway.platforms.base import MessageEvent, MessageType, SendResult
|
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome, SendResult
|
||||||
from gateway.session import SessionSource, build_session_key
|
from gateway.session import SessionSource, build_session_key
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -212,7 +212,7 @@ async def test_reactions_disabled_via_env_zero(adapter, monkeypatch):
|
||||||
|
|
||||||
event = _make_event("5", raw_message)
|
event = _make_event("5", raw_message)
|
||||||
await adapter.on_processing_start(event)
|
await adapter.on_processing_start(event)
|
||||||
await adapter.on_processing_complete(event, success=True)
|
await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
|
||||||
|
|
||||||
raw_message.add_reaction.assert_not_awaited()
|
raw_message.add_reaction.assert_not_awaited()
|
||||||
raw_message.remove_reaction.assert_not_awaited()
|
raw_message.remove_reaction.assert_not_awaited()
|
||||||
|
|
@ -232,3 +232,17 @@ async def test_reactions_enabled_by_default(adapter, monkeypatch):
|
||||||
await adapter.on_processing_start(event)
|
await adapter.on_processing_start(event)
|
||||||
|
|
||||||
raw_message.add_reaction.assert_awaited_once_with("👀")
|
raw_message.add_reaction.assert_awaited_once_with("👀")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_processing_complete_cancelled_removes_eyes_without_terminal_reaction(adapter):
|
||||||
|
raw_message = SimpleNamespace(
|
||||||
|
add_reaction=AsyncMock(),
|
||||||
|
remove_reaction=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
event = _make_event("7", raw_message)
|
||||||
|
await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
|
||||||
|
|
||||||
|
raw_message.remove_reaction.assert_awaited_once_with("👀", adapter._client.user)
|
||||||
|
raw_message.add_reaction.assert_not_awaited()
|
||||||
|
|
|
||||||
|
|
@ -1980,7 +1980,7 @@ class TestMatrixReactions:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_processing_complete_sends_check(self):
|
async def test_on_processing_complete_sends_check(self):
|
||||||
from gateway.platforms.base import MessageEvent, MessageType
|
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||||
|
|
||||||
self.adapter._reactions_enabled = True
|
self.adapter._reactions_enabled = True
|
||||||
self.adapter._send_reaction = AsyncMock(return_value=True)
|
self.adapter._send_reaction = AsyncMock(return_value=True)
|
||||||
|
|
@ -1994,9 +1994,28 @@ class TestMatrixReactions:
|
||||||
raw_message={},
|
raw_message={},
|
||||||
message_id="$msg1",
|
message_id="$msg1",
|
||||||
)
|
)
|
||||||
await self.adapter.on_processing_complete(event, success=True)
|
await self.adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
|
||||||
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅")
|
self.adapter._send_reaction.assert_called_once_with("!room:ex", "$msg1", "✅")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_processing_complete_cancelled_sends_no_terminal_reaction(self):
|
||||||
|
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||||
|
|
||||||
|
self.adapter._reactions_enabled = True
|
||||||
|
self.adapter._send_reaction = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
source = MagicMock()
|
||||||
|
source.chat_id = "!room:ex"
|
||||||
|
event = MessageEvent(
|
||||||
|
text="hello",
|
||||||
|
message_type=MessageType.TEXT,
|
||||||
|
source=source,
|
||||||
|
raw_message={},
|
||||||
|
message_id="$msg1",
|
||||||
|
)
|
||||||
|
await self.adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
|
||||||
|
self.adapter._send_reaction.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reactions_disabled(self):
|
async def test_reactions_disabled(self):
|
||||||
from gateway.platforms.base import MessageEvent, MessageType
|
from gateway.platforms.base import MessageEvent, MessageType
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from unittest.mock import AsyncMock
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gateway.config import Platform, PlatformConfig
|
from gateway.config import Platform, PlatformConfig
|
||||||
from gateway.platforms.base import MessageEvent, MessageType
|
from gateway.platforms.base import MessageEvent, MessageType, ProcessingOutcome
|
||||||
from gateway.session import SessionSource
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -180,7 +180,7 @@ async def test_on_processing_complete_success(monkeypatch):
|
||||||
adapter = _make_adapter()
|
adapter = _make_adapter()
|
||||||
event = _make_event()
|
event = _make_event()
|
||||||
|
|
||||||
await adapter.on_processing_complete(event, success=True)
|
await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
|
||||||
|
|
||||||
adapter._bot.set_message_reaction.assert_awaited_once_with(
|
adapter._bot.set_message_reaction.assert_awaited_once_with(
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
|
|
@ -196,7 +196,7 @@ async def test_on_processing_complete_failure(monkeypatch):
|
||||||
adapter = _make_adapter()
|
adapter = _make_adapter()
|
||||||
event = _make_event()
|
event = _make_event()
|
||||||
|
|
||||||
await adapter.on_processing_complete(event, success=False)
|
await adapter.on_processing_complete(event, ProcessingOutcome.FAILURE)
|
||||||
|
|
||||||
adapter._bot.set_message_reaction.assert_awaited_once_with(
|
adapter._bot.set_message_reaction.assert_awaited_once_with(
|
||||||
chat_id=123,
|
chat_id=123,
|
||||||
|
|
@ -212,7 +212,19 @@ async def test_on_processing_complete_skipped_when_disabled(monkeypatch):
|
||||||
adapter = _make_adapter()
|
adapter = _make_adapter()
|
||||||
event = _make_event()
|
event = _make_event()
|
||||||
|
|
||||||
await adapter.on_processing_complete(event, success=True)
|
await adapter.on_processing_complete(event, ProcessingOutcome.SUCCESS)
|
||||||
|
|
||||||
|
adapter._bot.set_message_reaction.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_processing_complete_cancelled_keeps_existing_reaction(monkeypatch):
|
||||||
|
"""Expected cancellation should not replace the in-progress reaction."""
|
||||||
|
monkeypatch.setenv("TELEGRAM_REACTIONS", "true")
|
||||||
|
adapter = _make_adapter()
|
||||||
|
event = _make_event()
|
||||||
|
|
||||||
|
await adapter.on_processing_complete(event, ProcessingOutcome.CANCELLED)
|
||||||
|
|
||||||
adapter._bot.set_message_reaction.assert_not_awaited()
|
adapter._bot.set_message_reaction.assert_not_awaited()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue