diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py
index 2cda33642..5522c631d 100644
--- a/gateway/stream_consumer.py
+++ b/gateway/stream_consumer.py
@@ -74,6 +74,8 @@ class GatewayStreamConsumer:
         self._edit_supported = True  # Disabled on first edit failure (Signal/Email/HA)
         self._last_edit_time = 0.0
         self._last_sent_text = ""  # Track last-sent text to skip redundant edits
+        self._fallback_final_send = False
+        self._fallback_prefix = ""
 
     @property
     def already_sent(self) -> bool:
@@ -138,12 +140,19 @@ class GatewayStreamConsumer:
                     while (
                         len(self._accumulated) > _safe_limit
                         and self._message_id is not None
+                        and self._edit_supported
                     ):
                         split_at = self._accumulated.rfind("\n", 0, _safe_limit)
                         if split_at < _safe_limit // 2:
                             split_at = _safe_limit
                         chunk = self._accumulated[:split_at]
                         await self._send_or_edit(chunk)
+                        if self._fallback_final_send:
+                            # Edit failed while attempting to split an oversized
+                            # message. Keep the full accumulated text intact so
+                            # the fallback final-send path can deliver the
+                            # remaining continuation without dropping content.
+                            break
                         self._accumulated = self._accumulated[split_at:].lstrip("\n")
                         self._message_id = None
                         self._last_sent_text = ""
@@ -156,9 +165,17 @@ class GatewayStreamConsumer:
                     self._last_edit_time = time.monotonic()
 
                 if got_done:
-                    # Final edit without cursor
-                    if self._accumulated and self._message_id:
-                        await self._send_or_edit(self._accumulated)
+                    # Final edit without cursor. If progressive editing failed
+                    # mid-stream, send a single continuation/fallback message
+                    # here instead of letting the base gateway path send the
+                    # full response again.
+                    if self._accumulated:
+                        if self._fallback_final_send:
+                            await self._send_fallback_final(self._accumulated)
+                        elif self._message_id:
+                            await self._send_or_edit(self._accumulated)
+                        elif not self._already_sent:
+                            await self._send_or_edit(self._accumulated)
                     return
 
                 # Tool boundary: the should_edit block above already flushed
@@ -169,6 +186,8 @@ class GatewayStreamConsumer:
                     self._message_id = None
                     self._accumulated = ""
                     self._last_sent_text = ""
+                    self._fallback_final_send = False
+                    self._fallback_prefix = ""
 
                 await asyncio.sleep(0.05)  # Small yield to not busy-loop
 
@@ -207,6 +226,86 @@ class GatewayStreamConsumer:
         # Strip trailing whitespace/newlines but preserve leading content
         return cleaned.rstrip()
 
+    def _visible_prefix(self) -> str:
+        """Return the visible text already shown in the streamed message."""
+        prefix = self._last_sent_text or ""
+        if self.cfg.cursor and prefix.endswith(self.cfg.cursor):
+            prefix = prefix[:-len(self.cfg.cursor)]
+        return self._clean_for_display(prefix)
+
+    def _continuation_text(self, final_text: str) -> str:
+        """Return only the part of final_text the user has not already seen."""
+        prefix = self._fallback_prefix or self._visible_prefix()
+        if prefix and final_text.startswith(prefix):
+            return final_text[len(prefix):].lstrip()
+        return final_text
+
+    @staticmethod
+    def _split_text_chunks(text: str, limit: int) -> list[str]:
+        """Split text into reasonably sized chunks for fallback sends."""
+        if len(text) <= limit:
+            return [text]
+        chunks: list[str] = []
+        remaining = text
+        while len(remaining) > limit:
+            split_at = remaining.rfind("\n", 0, limit)
+            if split_at < limit // 2:
+                split_at = limit
+            chunks.append(remaining[:split_at])
+            remaining = remaining[split_at:].lstrip("\n")
+        if remaining:
+            chunks.append(remaining)
+        return chunks
+
+    async def _send_fallback_final(self, text: str) -> None:
+        """Send the final continuation after streaming edits stop working."""
+        final_text = self._clean_for_display(text)
+        continuation = self._continuation_text(final_text)
+        self._fallback_final_send = False
+        if not continuation.strip():
+            # Nothing new to send — the visible partial already matches final text.
+            self._already_sent = True
+            return
+
+        raw_limit = getattr(self.adapter, "MAX_MESSAGE_LENGTH", 4096)
+        safe_limit = max(500, raw_limit - 100)
+        chunks = self._split_text_chunks(continuation, safe_limit)
+
+        last_message_id: Optional[str] = None
+        last_successful_chunk = ""
+        sent_any_chunk = False
+        for chunk in chunks:
+            result = await self.adapter.send(
+                chat_id=self.chat_id,
+                content=chunk,
+                metadata=self.metadata,
+            )
+            if not result.success:
+                if sent_any_chunk:
+                    # Some continuation text already reached the user. Suppress
+                    # the base gateway final-send path so we don't resend the
+                    # full response and create another duplicate.
+                    self._already_sent = True
+                    self._message_id = last_message_id
+                    self._last_sent_text = last_successful_chunk
+                    self._fallback_prefix = ""
+                    return
+                # No fallback chunk reached the user — allow the normal gateway
+                # final-send path to try one more time.
+                self._already_sent = False
+                self._message_id = None
+                self._last_sent_text = ""
+                self._fallback_prefix = ""
+                return
+            sent_any_chunk = True
+            last_successful_chunk = chunk
+            last_message_id = result.message_id or last_message_id
+
+        self._message_id = last_message_id
+        self._already_sent = True
+        self._last_sent_text = chunks[-1]
+        self._fallback_prefix = ""
+
     async def _send_or_edit(self, text: str) -> None:
         """Send or edit the streaming message."""
         # Strip MEDIA: directives so they don't appear as visible text.
@@ -232,14 +331,16 @@ class GatewayStreamConsumer:
                         self._last_sent_text = text
                     else:
                         # If an edit fails mid-stream (especially Telegram flood control),
-                        # stop progressive edits and let the normal final send path deliver
-                        # the complete answer instead of leaving the user with a partial.
+                        # stop progressive edits and send only the missing tail once the
+                        # final response is available.
                         logger.debug("Edit failed, disabling streaming for this adapter")
+                        self._fallback_prefix = self._visible_prefix()
+                        self._fallback_final_send = True
                         self._edit_supported = False
-                        self._already_sent = False
+                        self._already_sent = True
                 else:
                     # Editing not supported — skip intermediate updates.
-                    # The final response will be sent by the normal path.
+                    # The final response will be sent by the fallback path.
                     pass
             else:
                 # First message — send new
diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py
index 6c908bbe4..ddc88fc2f 100644
--- a/tests/gateway/test_stream_consumer.py
+++ b/tests/gateway/test_stream_consumer.py
@@ -324,3 +324,91 @@ class TestSegmentBreakOnToolBoundary:
         await consumer.run()
 
         assert consumer.already_sent
+
+    @pytest.mark.asyncio
+    async def test_edit_failure_sends_only_unsent_tail_at_finish(self):
+        """If an edit fails mid-stream, send only the missing tail once at finish."""
+        adapter = MagicMock()
+        send_results = [
+            SimpleNamespace(success=True, message_id="msg_1"),
+            SimpleNamespace(success=True, message_id="msg_2"),
+        ]
+        adapter.send = AsyncMock(side_effect=send_results)
+        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
+        adapter.MAX_MESSAGE_LENGTH = 4096
+
+        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
+        consumer = GatewayStreamConsumer(adapter, "chat_123", config)
+
+        consumer.on_delta("Hello")
+        task = asyncio.create_task(consumer.run())
+        await asyncio.sleep(0.08)
+        consumer.on_delta(" world")
+        await asyncio.sleep(0.08)
+        consumer.finish()
+        await task
+
+        assert adapter.send.call_count == 2
+        first_text = adapter.send.call_args_list[0][1]["content"]
+        second_text = adapter.send.call_args_list[1][1]["content"]
+        assert "Hello" in first_text
+        assert second_text.strip() == "world"
+        assert consumer.already_sent
+
+    @pytest.mark.asyncio
+    async def test_segment_break_clears_failed_edit_fallback_state(self):
+        """A tool boundary after edit failure must not duplicate the next segment."""
+        adapter = MagicMock()
+        send_results = [
+            SimpleNamespace(success=True, message_id="msg_1"),
+            SimpleNamespace(success=True, message_id="msg_2"),
+        ]
+        adapter.send = AsyncMock(side_effect=send_results)
+        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
+        adapter.MAX_MESSAGE_LENGTH = 4096
+
+        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
+        consumer = GatewayStreamConsumer(adapter, "chat_123", config)
+
+        consumer.on_delta("Hello")
+        task = asyncio.create_task(consumer.run())
+        await asyncio.sleep(0.08)
+        consumer.on_delta(" world")
+        await asyncio.sleep(0.08)
+        consumer.on_delta(None)
+        consumer.on_delta("Next segment")
+        consumer.finish()
+        await task
+
+        sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
+        assert sent_texts == ["Hello ▉", "Next segment"]
+
+    @pytest.mark.asyncio
+    async def test_fallback_final_splits_long_continuation_without_dropping_text(self):
+        """Long continuation tails should be chunked when fallback final-send runs."""
+        adapter = MagicMock()
+        adapter.send = AsyncMock(side_effect=[
+            SimpleNamespace(success=True, message_id="msg_1"),
+            SimpleNamespace(success=True, message_id="msg_2"),
+            SimpleNamespace(success=True, message_id="msg_3"),
+        ])
+        adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
+        adapter.MAX_MESSAGE_LENGTH = 610
+
+        config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
+        consumer = GatewayStreamConsumer(adapter, "chat_123", config)
+
+        prefix = "abc"
+        tail = "x" * 620
+        consumer.on_delta(prefix)
+        task = asyncio.create_task(consumer.run())
+        await asyncio.sleep(0.08)
+        consumer.on_delta(tail)
+        await asyncio.sleep(0.08)
+        consumer.finish()
+        await task
+
+        sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
+        assert len(sent_texts) == 3
+        assert sent_texts[0].startswith(prefix)
+        assert sum(len(t) for t in sent_texts[1:]) == len(tail)