diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 5e00c9f1ddd..2d940499e26 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -33,6 +33,7 @@ _AUDIO_EXTS = frozenset({'.ogg', '.opus', '.mp3', '.wav', '.m4a', '.flac'}) # delivered as a regular document. _TELEGRAM_AUDIO_ATTACHMENT_EXTS = frozenset({'.mp3', '.m4a'}) _TELEGRAM_VOICE_EXTS = frozenset({'.ogg', '.opus'}) +_POST_DELIVERY_CALLBACK_TIMEOUT_SECONDS = 30.0 def _platform_name(platform) -> str: @@ -4462,6 +4463,15 @@ class BasePlatformAdapter(ABC): except Exception: pass # Last resort — don't let error reporting crash the handler finally: + # Stop typing before any deferred callback work. Post-delivery + # callbacks may perform platform I/O; a stuck callback must not + # leave the typing refresh task running indefinitely. + await _stop_typing_task() + try: + if hasattr(self, "stop_typing"): + await self.stop_typing(event.source.chat_id) + except Exception: + pass # Fire any one-shot post-delivery callback registered for this # session (e.g. deferred background-review notifications). # @@ -4489,11 +4499,12 @@ class BasePlatformAdapter(ABC): try: _post_result = _post_cb() if inspect.isawaitable(_post_result): - await _post_result - except Exception: + await asyncio.wait_for( + _post_result, + timeout=_POST_DELIVERY_CALLBACK_TIMEOUT_SECONDS, + ) + except (asyncio.TimeoutError, Exception): pass - # Stop typing indicator - await _stop_typing_task() # Also cancel any platform-level persistent typing tasks (e.g. Discord) # that may have been recreated by _keep_typing after the last stop_typing() try: diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 28d7327fcdd..646ad92976b 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -9,6 +9,7 @@ from types import SimpleNamespace import pytest +import gateway.platforms.base as base_platform from gateway.config import Platform, PlatformConfig, StreamingConfig from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult from gateway.session import SessionSource @@ -1076,6 +1077,54 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send() assert released == [True] +@pytest.mark.asyncio +async def test_base_processing_stops_typing_before_hung_post_delivery_callback( + monkeypatch, +): + """A stuck post-delivery callback must not keep the typing task alive.""" + monkeypatch.setattr(base_platform, "_POST_DELIVERY_CALLBACK_TIMEOUT_SECONDS", 0.01) + adapter = ProgressCaptureAdapter() + events = [] + + async def _handler(event): + return "done" + + async def _post_delivery_cb(): + events.append("callback-start") + await asyncio.Event().wait() + + async def _stop_typing(chat_id): + events.append("typing-stopped") + await ProgressCaptureAdapter.stop_typing(adapter, chat_id) + + adapter.set_message_handler(_handler) + adapter.stop_typing = _stop_typing + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_type="group", + thread_id="17585", + ) + event = MessageEvent( + text="hello", + message_type=MessageType.TEXT, + source=source, + message_id="msg-1", + ) + session_key = "agent:main:telegram:group:-1001:17585" + adapter._active_sessions[session_key] = asyncio.Event() + adapter._post_delivery_callbacks[session_key] = _post_delivery_cb + + await asyncio.wait_for( + adapter._process_message_background(event, session_key), timeout=1.0 + ) + + assert [call["content"] for call in adapter.sent] == ["done"] + assert events[:2] == ["typing-stopped", "callback-start"] + assert any(call["metadata"] == {"stopped": True} for call in adapter.typing) + + @pytest.mark.asyncio async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path): import yaml