fix(api_server): SSE token batching + error handling for Open WebUI performance

Reduces SSE event rate ~500/turn → ~20/turn via 50ms text-delta batching in
_dispatch(), which eliminates markdown re-render storms on Open WebUI. Also:

- Trim tool_call.arguments in the response.completed event to 100KB
  (prevents silent hangs on 848KB+ single-line SSE events).
- Catch-all exception handlers in _write_sse_responses() + _write_sse_chat_completion()
  emit a proper error chunk instead of TransferEncodingError from incomplete
  chunked encoding when the agent crashes mid-stream.
- MAX_REQUEST_BYTES 1MB → 10MB; pass client_max_size to aiohttp Application to
  avoid silent 400s on truncated request bodies for long conversations.

Salvage of #17552 (api_server portion only). The contrib/openwebui-filter/
payload from that PR — Open WebUI Filter Function + benchmark writeup — is
a client-side user-installable add-on and doesn't need to live in the repo;
dropped here. Closes #17537.

Co-authored-by: bogerman1 <93757150+bogerman1@users.noreply.github.com>
This commit is contained in:
bogerman1 2026-05-05 15:12:21 -07:00 committed by Teknium
parent 3082fa0829
commit 3188e63b05

View file

@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8642
MAX_STORED_RESPONSES = 100
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
MAX_REQUEST_BYTES = 10_000_000 # 10 MB — accommodates long agent conversations with tool calls
CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
MAX_NORMALIZED_TEXT_LENGTH = 65_536 # 64 KB cap for normalized content parts
MAX_CONTENT_LIST_SIZE = 1_000 # Max items when content is an array
@@ -1349,6 +1349,22 @@ class APIServerAdapter(BasePlatformAdapter):
except (asyncio.CancelledError, Exception):
pass
logger.info("SSE client disconnected; interrupted agent task %s", completion_id)
except Exception as _exc:
# Agent crashed mid-stream. Try to emit an error chunk
# so the client gets a proper response instead of a
# TransferEncodingError from incomplete chunked encoding.
import traceback as _tb
logger.error("Agent crashed mid-stream for %s: %s", completion_id, _tb.format_exc()[:300])
try:
error_chunk = {
"id": completion_id, "object": "chat.completion.chunk",
"created": created, "model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "error"}],
}
await response.write(f"data: {json.dumps(error_chunk)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
except Exception:
pass
return response
@@ -1669,20 +1685,54 @@ class APIServerAdapter(BasePlatformAdapter):
async def _dispatch(it) -> None:
"""Route a queue item to the correct SSE emitter.
Plain strings are text deltas. Tagged tuples with
``__tool_started__`` / ``__tool_completed__`` prefixes
are tool lifecycle events.
Plain strings are text deltas; they are batched (50ms)
to reduce Open WebUI re-render storms. Tagged tuples
with ``__tool_started__`` / ``__tool_completed__``
prefixes are tool lifecycle events and flush the buffer
before emitting.
"""
nonlocal _batch_timer
if isinstance(it, tuple) and len(it) == 2 and isinstance(it[0], str):
tag, payload = it
# Flush batched text before tool events
if _batch_buf:
await _flush_batch()
if tag == "__tool_started__":
await _emit_tool_started(payload)
elif tag == "__tool_completed__":
await _emit_tool_completed(payload)
# Unknown tags are silently ignored (forward-compat).
elif isinstance(it, str):
await _emit_text_delta(it)
# Other types (non-string, non-tuple) are silently dropped.
# Batch text deltas — append to buffer, flush on timer
_batch_buf.append(it)
if _batch_timer is None:
_batch_timer = asyncio.create_task(_batch_flush_after(0.05))
# Other types are silently dropped.
# ── Batching state ──
_batch_buf: List[str] = []
_batch_timer: Optional[asyncio.Task] = None
_batch_lock = asyncio.Lock()
async def _batch_flush_after(delay: float) -> None:
"""Wait delay seconds, then flush accumulated text deltas."""
try:
await asyncio.sleep(delay)
except asyncio.CancelledError:
return
# Clear timer reference BEFORE flush so new deltas
# can start a fresh timer while we emit
nonlocal _batch_buf, _batch_timer
_batch_timer = None
await _flush_batch()
async def _flush_batch() -> None:
"""Emit a single SSE delta for all accumulated text."""
nonlocal _batch_buf
async with _batch_lock:
if _batch_buf:
combined = "".join(_batch_buf)
_batch_buf = []
await _emit_text_delta(combined)
loop = asyncio.get_running_loop()
while True:
@@ -1707,11 +1757,21 @@ class APIServerAdapter(BasePlatformAdapter):
continue
if item is None: # EOS sentinel
# Cancel pending timer and flush remaining batched text
if _batch_timer and not _batch_timer.done():
_batch_timer.cancel()
_batch_timer = None
if _batch_buf:
await _flush_batch()
break
await _dispatch(item)
last_activity = time.monotonic()
# Flush any final batched text before processing result
if _batch_buf:
await _flush_batch()
# Pick up agent result + usage from the completed task
try:
result, agent_usage = await agent_task
@@ -1762,6 +1822,31 @@ class APIServerAdapter(BasePlatformAdapter):
# payload still see the assistant text. This mirrors the
# shape produced by _extract_output_items in the batch path.
final_items: List[Dict[str, Any]] = list(emitted_items)
# Trim large content from tool call arguments to keep the
# response.completed event under ~100KB. Clients already
# received full details via incremental events.
for _item in final_items:
if _item.get("type") == "function_call":
try:
_args = json.loads(_item.get("arguments", "{}")) if isinstance(_item.get("arguments"), str) else _item.get("arguments", {})
if isinstance(_args, dict):
for _k in ("content", "query", "pattern", "old_string", "new_string"):
if isinstance(_args.get(_k), str) and len(_args[_k]) > 500:
_args[_k] = "[" + str(len(_args[_k])) + " chars — truncated for response.completed]"
_item["arguments"] = json.dumps(_args)
except Exception:
pass
elif _item.get("type") == "function_call_output":
_output = _item.get("output", [])
if isinstance(_output, list) and _output:
_first = _output[0]
if isinstance(_first, dict) and _first.get("type") == "input_text":
_text = _first.get("text", "")
if len(_text) > 1000:
_first["text"] = _text[:500] + "...[" + str(len(_text) - 500) + " more chars]"
_item["output"] = [_first]
final_items.append({
"type": "message",
"role": "assistant",
@@ -1852,6 +1937,30 @@ class APIServerAdapter(BasePlatformAdapter):
agent_task.cancel()
logger.info("SSE task cancelled; persisted incomplete snapshot for %s", response_id)
raise
except Exception as _exc:
# Agent crashed with an unhandled error (e.g. model API error like
# BadRequestError, AuthenticationError). Emit a response.failed
# event and properly terminate the SSE stream so the client doesn't
# get a TransferEncodingError from incomplete chunked encoding.
import traceback as _tb
_persist_incomplete_if_needed()
agent_error = _tb.format_exc()
try:
failed_env = _envelope("failed")
failed_env["output"] = list(emitted_items)
failed_env["error"] = {"message": str(_exc)[:500], "type": "server_error"}
failed_env["usage"] = {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
}
await _write_event("response.failed", {
"type": "response.failed",
"response": failed_env,
})
except Exception:
pass
logger.error("Agent crashed mid-stream for %s: %s", response_id, str(agent_error)[:300])
return response
@@ -2935,7 +3044,7 @@ class APIServerAdapter(BasePlatformAdapter):
try:
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
self._app = web.Application(middlewares=mws)
self._app = web.Application(middlewares=mws, client_max_size=MAX_REQUEST_BYTES)
self._app["api_server_adapter"] = self
self._app.router.add_get("/health", self._handle_health)
self._app.router.add_get("/health/detailed", self._handle_health_detailed)