diff --git a/gateway/run.py b/gateway/run.py index 0dad9af10..66b3a3705 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -9108,6 +9108,7 @@ class GatewayRunner: chat_id=source.chat_id, config=_consumer_cfg, metadata=_thread_metadata, + run_still_current=_run_still_current, ) except Exception as _sc_err: logger.debug("Proxy: could not set up stream consumer: %s", _sc_err) @@ -9731,6 +9732,7 @@ class GatewayRunner: chat_id=source.chat_id, config=_consumer_cfg, metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None, + run_still_current=_run_still_current, ) if _want_stream_deltas: def _stream_delta_cb(text: str) -> None: diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 78e365712..54e794925 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -21,7 +21,7 @@ import queue import re import time from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Callable, Optional logger = logging.getLogger("gateway.stream_consumer") @@ -83,6 +83,7 @@ class GatewayStreamConsumer: chat_id: str, config: Optional[StreamConsumerConfig] = None, metadata: Optional[dict] = None, + run_still_current: Optional[Callable[[], bool]] = None, ): self.adapter = adapter self.chat_id = chat_id @@ -109,6 +110,11 @@ class GatewayStreamConsumer: getattr(adapter, "REQUIRES_EDIT_FINALIZE", False) is True ) + # Session staleness guard — when this callback returns False (e.g. after /new or + # /stop), the run() loop will abandon the stream early instead of + # continuing to edit and deliver stale deltas. + self._run_still_current = run_still_current or (lambda: True) + # Think-block filter state (mirrors CLI's _stream_delta tag suppression) self._in_think_block = False self._think_buffer = "" @@ -271,6 +277,12 @@ class GatewayStreamConsumer: try: while True: + # Abandon the stream early if the session has been reset + # (e.g. /new or /stop).
Prevents stale deltas from being + # delivered after the user has already moved on. + if not self._run_still_current(): + return + # Drain all available items from the queue got_done = False got_segment_break = False diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index 7ae587dad..0519da74a 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -1337,3 +1337,138 @@ class TestCursorStrippingOnFallback: assert consumer._already_sent is True # _last_sent_text must NOT be updated when the edit failed assert consumer._last_sent_text == "Hello ▉" + + +# ── run_still_current staleness guard ──────────────────────────────────── + +class TestRunStillCurrentGuard: + """Verify that the stream consumer abandons delivery when the session is + reset (e.g. /new or /stop), preventing stale deltas from reaching the user.""" + + @pytest.mark.asyncio + async def test_abandons_stream_when_session_reset_before_first_send(self): + """If _run_still_current returns False immediately, the consumer + exits without sending anything — even with queued deltas.""" + adapter = MagicMock() + adapter.send = AsyncMock() + adapter.edit_message = AsyncMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=3) + consumer = GatewayStreamConsumer( + adapter, "chat_123", config, + run_still_current=lambda: False, + ) + + consumer.on_delta("ABC") + consumer.on_delta("DEF") + consumer.on_delta("GHI") + + await consumer.run() + + adapter.send.assert_not_called() + adapter.edit_message.assert_not_called() + assert consumer._final_response_sent is False + + @pytest.mark.asyncio + async def test_abandons_stream_after_one_edit_when_session_reset(self): + """If staleness flips after the first edit, the consumer stops + on the next loop iteration and does not send the final response.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, 
message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + call_count = [0] + + def is_current(): + call_count[0] += 1 + return call_count[0] == 1 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=3) + consumer = GatewayStreamConsumer( + adapter, "chat_123", config, + run_still_current=is_current, + ) + + consumer.on_delta("First segment") + consumer.on_delta(None) # segment break → resets message_id + consumer.on_delta("Second segment text that will be stale") + # No finish() — staleness should prevent second segment from sending + + await consumer.run() + + # First segment was sent, second was abandoned + assert adapter.send.call_count == 1 + assert "First segment" in adapter.send.call_args_list[0][1]["content"] + assert consumer._final_response_sent is False + + @pytest.mark.asyncio + async def test_normal_delivery_when_session_stays_current(self): + """When _run_still_current always returns True, the consumer + behaves normally and delivers the full response.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer( + adapter, "chat_123", config, + run_still_current=lambda: True, + ) + + consumer.on_delta("Hello, world!") + consumer.finish() + + await consumer.run() + + assert adapter.send.call_count >= 1 + assert consumer._final_response_sent is True + + @pytest.mark.asyncio + async def test_no_callback_defaults_to_always_current(self): + """When run_still_current is not provided (default), the consumer + always considers the session current — backward compatible.""" + adapter = MagicMock() + 
send_result = SimpleNamespace(success=True, message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Normal message") + consumer.finish() + + await consumer.run() + + assert adapter.send.call_count >= 1 + assert consumer._final_response_sent is True + + @pytest.mark.asyncio + async def test_abandons_even_with_pending_finish(self): + """If finish() has been called but the session is already reset + before the run loop starts, nothing is sent.""" + adapter = MagicMock() + adapter.send = AsyncMock() + adapter.edit_message = AsyncMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer( + adapter, "chat_123", config, + run_still_current=lambda: False, + ) + + consumer.on_delta("Stale text") + consumer.finish() + + await consumer.run() + + adapter.send.assert_not_called() + adapter.edit_message.assert_not_called() + assert consumer._final_response_sent is False