diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 7f4a73d0b7..59e72755a0 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -28,6 +28,10 @@ logger = logging.getLogger("gateway.stream_consumer") # Sentinel to signal the stream is complete _DONE = object() +# Sentinel to signal a tool boundary — finalize current message and start a +# new one so that subsequent text appears below tool progress messages. +_NEW_SEGMENT = object() + @dataclass class StreamConsumerConfig: @@ -78,9 +82,16 @@ class GatewayStreamConsumer: return self._already_sent def on_delta(self, text: str) -> None: - """Thread-safe callback — called from the agent's worker thread.""" + """Thread-safe callback — called from the agent's worker thread. + + When *text* is ``None``, signals a tool boundary: the current message + is finalized and subsequent text will be sent as a new message so it + appears below any tool-progress messages the gateway sent in between. + """ if text: self._queue.put(text) + elif text is None: + self._queue.put(_NEW_SEGMENT) def finish(self) -> None: """Signal that the stream is complete.""" @@ -96,12 +107,16 @@ class GatewayStreamConsumer: while True: # Drain all available items from the queue got_done = False + got_segment_break = False while True: try: item = self._queue.get_nowait() if item is _DONE: got_done = True break + if item is _NEW_SEGMENT: + got_segment_break = True + break self._accumulated += item except queue.Empty: break @@ -111,6 +126,7 @@ class GatewayStreamConsumer: elapsed = now - self._last_edit_time should_edit = ( got_done + or got_segment_break or (elapsed >= self.cfg.edit_interval and len(self._accumulated) > 0) or len(self._accumulated) >= self.cfg.buffer_threshold @@ -133,7 +149,7 @@ class GatewayStreamConsumer: self._last_sent_text = "" display_text = self._accumulated - if not got_done: + if not got_done and not got_segment_break: display_text += self.cfg.cursor await self._send_or_edit(display_text) @@ -145,6 +161,15 @@ class GatewayStreamConsumer: await self._send_or_edit(self._accumulated) return + # Tool boundary: the should_edit block above already flushed + # accumulated text without a cursor. Reset state so the next + # text chunk creates a fresh message below any tool-progress + # messages the gateway sent in between. + if got_segment_break: + self._message_id = None + self._accumulated = "" + self._last_sent_text = "" + await asyncio.sleep(0.05) # Small yield to not busy-loop except asyncio.CancelledError: diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index 1234307ca2..6c908bbe40 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -177,3 +177,150 @@ class TestStreamRunMediaStripping: assert "MEDIA:" not in sent_text, f"MEDIA: leaked into display: {sent_text!r}" assert consumer.already_sent + + +# ── Segment break (tool boundary) tests ────────────────────────────────── + + +class TestSegmentBreakOnToolBoundary: + """Verify that on_delta(None) finalizes the current message and starts a + new one so the final response appears below tool-progress messages.""" + + @pytest.mark.asyncio + async def test_segment_break_creates_new_message(self): + """After a None boundary, next text creates a fresh message.""" + adapter = MagicMock() + send_result_1 = SimpleNamespace(success=True, message_id="msg_1") + send_result_2 = SimpleNamespace(success=True, message_id="msg_2") + edit_result = SimpleNamespace(success=True) + adapter.send = AsyncMock(side_effect=[send_result_1, send_result_2]) + adapter.edit_message = AsyncMock(return_value=edit_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + # Phase 1: intermediate text before tool calls + consumer.on_delta("Let me search for that...") + # Tool boundary — model is about to call tools + consumer.on_delta(None) + # Phase 2: final response text after tools finished + consumer.on_delta("Here are the results.") + consumer.finish() + + await consumer.run() + + # Should have sent TWO separate messages (two adapter.send calls), + # not just edited the first one. + assert adapter.send.call_count == 2 + first_text = adapter.send.call_args_list[0][1]["content"] + second_text = adapter.send.call_args_list[1][1]["content"] + assert "search" in first_text + assert "results" in second_text + + @pytest.mark.asyncio + async def test_segment_break_no_text_before(self): + """A None boundary with no preceding text is a no-op.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + # No text before the boundary — model went straight to tool calls + consumer.on_delta(None) + consumer.on_delta("Final answer.") + consumer.finish() + + await consumer.run() + + # Only one send call (the final answer) + assert adapter.send.call_count == 1 + assert "Final answer" in adapter.send.call_args_list[0][1]["content"] + + @pytest.mark.asyncio + async def test_segment_break_removes_cursor(self): + """The finalized segment message should not have a cursor.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + edit_result = SimpleNamespace(success=True) + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=edit_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉") + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Thinking...") + consumer.on_delta(None) + consumer.on_delta("Done.") + consumer.finish() + + await consumer.run() + + # The first segment should have been finalized without cursor. + # Check all edit_message calls + the initial send for the first segment. + # The last state of msg_1 should NOT have the cursor. + all_texts = [] + for call in adapter.send.call_args_list: + all_texts.append(call[1].get("content", "")) + for call in adapter.edit_message.call_args_list: + all_texts.append(call[1].get("content", "")) + + # Find the text(s) that contain "Thinking" — the finalized version + # should not have the cursor. + thinking_texts = [t for t in all_texts if "Thinking" in t] + assert thinking_texts, "Expected at least one message with 'Thinking'" + # The LAST occurrence is the finalized version + assert "▉" not in thinking_texts[-1], ( + f"Cursor found in finalized segment: {thinking_texts[-1]!r}" + ) + + @pytest.mark.asyncio + async def test_multiple_segment_breaks(self): + """Multiple tool boundaries create multiple message segments.""" + adapter = MagicMock() + msg_counter = iter(["msg_1", "msg_2", "msg_3"]) + adapter.send = AsyncMock( + side_effect=lambda **kw: SimpleNamespace(success=True, message_id=next(msg_counter)) + ) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Phase 1") + consumer.on_delta(None) # tool boundary + consumer.on_delta("Phase 2") + consumer.on_delta(None) # another tool boundary + consumer.on_delta("Phase 3") + consumer.finish() + + await consumer.run() + + # Three separate messages + assert adapter.send.call_count == 3 + + @pytest.mark.asyncio + async def test_already_sent_stays_true_after_segment(self): + """already_sent remains True after a segment break.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5) + consumer = GatewayStreamConsumer(adapter, "chat_123", config) + + consumer.on_delta("Text") + consumer.on_delta(None) + consumer.finish() + + await consumer.run() + + assert consumer.already_sent