fix(streaming): filter <think> blocks from gateway stream consumer

Models like MiniMax emit inline <think>...</think> reasoning blocks in
their content field. The CLI already suppresses these via a state machine
in _stream_delta, but the gateway's GatewayStreamConsumer had no
equivalent filtering — raw think blocks were streamed directly to
Discord/Telegram/Slack.

The fix adds a _filter_and_accumulate() method that mirrors the CLI's
approach: a state machine tracks whether we're inside a reasoning block
and silently discards the content. Includes the same block-boundary
check (tag must appear at line start or after whitespace-only prefix)
to avoid false positives when models mention <think> in prose.

Handles all tag variants: <think>, <thinking>, <THINKING>, <thought>,
<reasoning>, <REASONING_SCRATCHPAD>.

Also handles edge cases:
- Tags split across streaming deltas (partial tag buffering)
- Unclosed blocks (content suppressed until stream ends)
- Multiple consecutive blocks
- _flush_think_buffer on stream end for held-back partial tags

Adds 22 unit tests + 1 integration test covering all scenarios.
This commit is contained in:
Teknium 2026-04-13 22:10:33 -07:00 committed by Teknium
parent e08590888a
commit 3de2b98503
2 changed files with 328 additions and 1 deletions

View file

@ -64,6 +64,18 @@ class GatewayStreamConsumer:
# progressive edits for the remainder of the stream.
_MAX_FLOOD_STRIKES = 3
# Reasoning/thinking tags that models emit inline in content.
# Must stay in sync with cli.py _OPEN_TAGS/_CLOSE_TAGS and
# run_agent.py _strip_think_blocks() tag variants.
# Opening-tag spellings, matched verbatim and case-sensitively (str.find).
_OPEN_THINK_TAGS = (
    "<REASONING_SCRATCHPAD>", "<think>", "<reasoning>",
    "<THINKING>", "<thinking>", "<thought>",
)
# Closing counterparts, listed in the same order as the openers.
_CLOSE_THINK_TAGS = (
    "</REASONING_SCRATCHPAD>", "</think>", "</reasoning>",
    "</THINKING>", "</thinking>", "</thought>",
)
def __init__(
self,
adapter: Any,
@ -88,6 +100,10 @@ class GatewayStreamConsumer:
self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff
self._final_response_sent = False
# Think-block filter state (mirrors CLI's _stream_delta tag suppression)
self._in_think_block = False
self._think_buffer = ""
@property
def already_sent(self) -> bool:
"""True if at least one message was sent or edited during the run."""
@ -132,6 +148,112 @@ class GatewayStreamConsumer:
"""Signal that the stream is complete."""
self._queue.put(_DONE)
# ── Think-block filtering ────────────────────────────────────────
# Models like MiniMax emit inline <think>...</think> blocks in their
# content. The CLI's _stream_delta suppresses these via a state
# machine; we do the same here so gateway users never see raw
# reasoning tags. The agent also strips them from the final
# response (run_agent.py _strip_think_blocks), but the stream
# consumer sends intermediate edits before that stripping happens.
def _filter_and_accumulate(self, text: str) -> None:
    """Add a text delta to the accumulated buffer, suppressing think blocks.

    Uses a state machine that tracks whether we are inside a
    reasoning/thinking block.  Text inside such blocks is silently
    discarded.  Partial tags at buffer boundaries are held back in
    ``_think_buffer`` until enough characters arrive to decide.

    Args:
        text: Raw streaming content delta from the model.
    """

    def earliest_close(chunk: str) -> tuple[int, int]:
        # (index, tag length) of the first closing tag in chunk, or (-1, 0).
        at, width = -1, 0
        for tag in self._CLOSE_THINK_TAGS:
            pos = chunk.find(tag)
            if pos != -1 and (at == -1 or pos < at):
                at, width = pos, len(tag)
        return at, width

    def at_boundary(chunk: str, pos: int) -> bool:
        # True when a tag at *pos* sits at a block boundary: start of the
        # stream or of a line, optionally preceded by whitespace only.
        # Mirrors cli.py's check; prevents false positives when models
        # merely *mention* tags in prose (e.g. "the <think> tag is ...").
        if pos == 0:
            # Tag opens this delta — boundary depends on emitted text so far.
            return not self._accumulated or self._accumulated.endswith("\n")
        before = chunk[:pos]
        last_nl = before.rfind("\n")
        if last_nl == -1:
            # No newline inside this chunk before the tag: the in-chunk
            # prefix must be whitespace-only AND already-emitted text must
            # end at a line break (or be empty).
            return (
                (not self._accumulated or self._accumulated.endswith("\n"))
                and before.strip() == ""
            )
        return before[last_nl + 1:].strip() == ""

    def earliest_boundary_open(chunk: str) -> tuple[int, int]:
        # (index, tag length) of the first boundary-valid opening tag,
        # or (-1, 0) when none qualifies.
        at, width = -1, 0
        for tag in self._OPEN_THINK_TAGS:
            start = 0
            while True:
                pos = chunk.find(tag, start)
                if pos == -1:
                    break
                if at_boundary(chunk, pos):
                    if at == -1 or pos < at:
                        at, width = pos, len(tag)
                    break  # first boundary hit for this tag is enough
                start = pos + 1
        return at, width

    buf = self._think_buffer + text
    self._think_buffer = ""
    while buf:
        if self._in_think_block:
            at, width = earliest_close(buf)
            if width:
                # Found closing tag — discard the block, process remainder.
                self._in_think_block = False
                buf = buf[at + width:]
                continue
            # No closing tag yet — hold a tail that could be a partial
            # closing-tag prefix; the rest is in-block and discarded.
            max_tag = max(len(t) for t in self._CLOSE_THINK_TAGS)
            self._think_buffer = buf[-max_tag:] if len(buf) > max_tag else buf
            return
        at, width = earliest_boundary_open(buf)
        if width:
            # Emit text before the tag, then enter the think block.
            self._accumulated += buf[:at]
            self._in_think_block = True
            buf = buf[at + width:]
            continue
        # No opening tag — hold back a tail matching a partial opening-tag
        # prefix (a tag may be split across streaming deltas).
        held_back = 0
        for tag in self._OPEN_THINK_TAGS:
            for i in range(1, len(tag)):
                if i > held_back and buf.endswith(tag[:i]):
                    held_back = i
        if held_back:
            self._accumulated += buf[:-held_back]
            self._think_buffer = buf[-held_back:]
        else:
            self._accumulated += buf
        return
def _flush_think_buffer(self) -> None:
    """Move any parked partial-tag text into the accumulated output.

    Invoked when the stream finishes (got_done): text that was held back
    because it looked like the start of an opening tag turned out to be
    ordinary trailing content, so it must not be lost.  Text parked while
    inside an (unclosed) think block stays suppressed.
    """
    if self._in_think_block or not self._think_buffer:
        return
    self._accumulated += self._think_buffer
    self._think_buffer = ""
async def run(self) -> None:
"""Async task that drains the queue and edits the platform message."""
# Platform message length limit — leave room for cursor + formatting
@ -156,10 +278,16 @@ class GatewayStreamConsumer:
if isinstance(item, tuple) and len(item) == 2 and item[0] is _COMMENTARY:
commentary_text = item[1]
break
self._accumulated += item
self._filter_and_accumulate(item)
except queue.Empty:
break
# Flush any held-back partial-tag buffer on stream end
# so trailing text that was waiting for a potential open
# tag is not lost.
if got_done:
self._flush_think_buffer()
# Decide whether to flush an edit
now = time.monotonic()
elapsed = now - self._last_edit_time