fix(gateway): add session staleness guard to stream consumer

GatewayStreamConsumer's async run() loop processes queued deltas
without checking if the session has been reset (e.g. /new or /stop).
This causes stale response fragments to be delivered after the user
has already moved to a new session.

Add an optional run_still_current callback to GatewayStreamConsumer.
When it returns False, the run() loop exits early, abandoning any
remaining queued deltas. The gateway's _handle_message_with_agent
binds the existing _run_still_current closure at consumer creation
time, consistent with other async callbacks.

Fixes the bug where /new during active streaming delivers the old
response content alongside the "Session reset!" acknowledgment.

- gateway/stream_consumer.py: add run_still_current param and check
- gateway/run.py: pass _run_still_current at both consumer call sites
- tests: 5 new tests in TestRunStillCurrentGuard
This commit is contained in:
jason 2026-04-24 19:21:22 +08:00
parent 18f3fc8a6f
commit 383239248f
3 changed files with 150 additions and 1 deletions

View file

@ -9108,6 +9108,7 @@ class GatewayRunner:
chat_id=source.chat_id, chat_id=source.chat_id,
config=_consumer_cfg, config=_consumer_cfg,
metadata=_thread_metadata, metadata=_thread_metadata,
run_still_current=_run_still_current,
) )
except Exception as _sc_err: except Exception as _sc_err:
logger.debug("Proxy: could not set up stream consumer: %s", _sc_err) logger.debug("Proxy: could not set up stream consumer: %s", _sc_err)
@ -9731,6 +9732,7 @@ class GatewayRunner:
chat_id=source.chat_id, chat_id=source.chat_id,
config=_consumer_cfg, config=_consumer_cfg,
metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None, metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None,
run_still_current=_run_still_current,
) )
if _want_stream_deltas: if _want_stream_deltas:
def _stream_delta_cb(text: str) -> None: def _stream_delta_cb(text: str) -> None:

View file

@ -21,7 +21,7 @@ import queue
import re import re
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Callable, Optional
logger = logging.getLogger("gateway.stream_consumer") logger = logging.getLogger("gateway.stream_consumer")
@ -83,6 +83,7 @@ class GatewayStreamConsumer:
chat_id: str, chat_id: str,
config: Optional[StreamConsumerConfig] = None, config: Optional[StreamConsumerConfig] = None,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
run_still_current: Optional[Callable[[], bool]] = None,
): ):
self.adapter = adapter self.adapter = adapter
self.chat_id = chat_id self.chat_id = chat_id
@ -109,6 +110,11 @@ class GatewayStreamConsumer:
getattr(adapter, "REQUIRES_EDIT_FINALIZE", False) is True getattr(adapter, "REQUIRES_EDIT_FINALIZE", False) is True
) )
# Session staleness guard — when this callback returns False (e.g. after
# /new or /stop), the run() loop will abandon the stream early instead of
# continuing to edit and deliver stale deltas.
self._run_still_current = run_still_current or (lambda: True)
# Think-block filter state (mirrors CLI's _stream_delta tag suppression) # Think-block filter state (mirrors CLI's _stream_delta tag suppression)
self._in_think_block = False self._in_think_block = False
self._think_buffer = "" self._think_buffer = ""
@ -271,6 +277,12 @@ class GatewayStreamConsumer:
try: try:
while True: while True:
# Abandon the stream early if the session has been reset
# (e.g. /new or /stop). Prevents stale deltas from being
# delivered after the user has already moved on.
if not self._run_still_current():
return
# Drain all available items from the queue # Drain all available items from the queue
got_done = False got_done = False
got_segment_break = False got_segment_break = False

View file

@ -1337,3 +1337,138 @@ class TestCursorStrippingOnFallback:
assert consumer._already_sent is True assert consumer._already_sent is True
# _last_sent_text must NOT be updated when the edit failed # _last_sent_text must NOT be updated when the edit failed
assert consumer._last_sent_text == "Hello ▉" assert consumer._last_sent_text == "Hello ▉"
# ── run_still_current staleness guard ────────────────────────────────────
class TestRunStillCurrentGuard:
    """Exercise the ``run_still_current`` staleness guard of the consumer.

    Each test builds a ``GatewayStreamConsumer`` around a mocked adapter and
    checks what ``run()`` delivers depending on what the callback reports.
    """

    @staticmethod
    def _silent_adapter():
        """Return a mock adapter whose send/edit coroutines have no preset result."""
        mock_adapter = MagicMock()
        mock_adapter.send = AsyncMock()
        mock_adapter.edit_message = AsyncMock()
        mock_adapter.MAX_MESSAGE_LENGTH = 4096
        return mock_adapter

    @staticmethod
    def _working_adapter():
        """Return a mock adapter whose send/edit coroutines report success."""
        mock_adapter = MagicMock()
        mock_adapter.send = AsyncMock(
            return_value=SimpleNamespace(success=True, message_id="msg_1")
        )
        mock_adapter.edit_message = AsyncMock(
            return_value=SimpleNamespace(success=True)
        )
        mock_adapter.MAX_MESSAGE_LENGTH = 4096
        return mock_adapter

    @pytest.mark.asyncio
    async def test_abandons_stream_when_session_reset_before_first_send(self):
        """A callback that is False from the start means nothing is sent,
        even though deltas are already queued."""
        adapter = self._silent_adapter()
        cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=3)
        consumer = GatewayStreamConsumer(
            adapter,
            "chat_123",
            cfg,
            run_still_current=lambda: False,
        )

        for chunk in ("ABC", "DEF", "GHI"):
            consumer.on_delta(chunk)
        await consumer.run()

        adapter.send.assert_not_called()
        adapter.edit_message.assert_not_called()
        assert consumer._final_response_sent is False

    @pytest.mark.asyncio
    async def test_abandons_stream_after_one_edit_when_session_reset(self):
        """When the callback turns False after its first call, the consumer
        stops on the following loop iteration without a final response."""
        adapter = self._working_adapter()

        checks = 0

        def current_on_first_check_only():
            # True for the very first staleness check, False for every later one.
            nonlocal checks
            checks += 1
            return checks == 1

        cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=3)
        consumer = GatewayStreamConsumer(
            adapter,
            "chat_123",
            cfg,
            run_still_current=current_on_first_check_only,
        )

        consumer.on_delta("First segment")
        consumer.on_delta(None)  # segment break → resets message_id
        consumer.on_delta("Second segment text that will be stale")
        # finish() is deliberately not called: the staleness guard alone must
        # keep the second segment from going out.
        await consumer.run()

        # Only the first segment reached the adapter; the second was dropped.
        assert adapter.send.call_count == 1
        first_send = adapter.send.call_args_list[0]
        assert "First segment" in first_send.kwargs["content"]
        assert consumer._final_response_sent is False

    @pytest.mark.asyncio
    async def test_normal_delivery_when_session_stays_current(self):
        """A callback that is always True leaves delivery untouched: the
        full response goes out."""
        adapter = self._working_adapter()
        cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(
            adapter,
            "chat_123",
            cfg,
            run_still_current=lambda: True,
        )

        consumer.on_delta("Hello, world!")
        consumer.finish()
        await consumer.run()

        assert adapter.send.call_count >= 1
        assert consumer._final_response_sent is True

    @pytest.mark.asyncio
    async def test_no_callback_defaults_to_always_current(self):
        """Omitting run_still_current keeps the session permanently current,
        so existing callers stay backward compatible."""
        adapter = self._working_adapter()
        cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(adapter, "chat_123", cfg)

        consumer.on_delta("Normal message")
        consumer.finish()
        await consumer.run()

        assert adapter.send.call_count >= 1
        assert consumer._final_response_sent is True

    @pytest.mark.asyncio
    async def test_abandons_even_with_pending_finish(self):
        """A queued finish() does not override the guard: with the session
        already reset before run() starts, nothing is sent."""
        adapter = self._silent_adapter()
        cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
        consumer = GatewayStreamConsumer(
            adapter,
            "chat_123",
            cfg,
            run_still_current=lambda: False,
        )

        consumer.on_delta("Stale text")
        consumer.finish()
        await consumer.run()

        adapter.send.assert_not_called()
        adapter.edit_message.assert_not_called()
        assert consumer._final_response_sent is False