diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py
index b460754331..ae77100f6a 100644
--- a/gateway/platforms/api_server.py
+++ b/gateway/platforms/api_server.py
@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
 DEFAULT_HOST = "127.0.0.1"
 DEFAULT_PORT = 8642
 MAX_STORED_RESPONSES = 100
-MAX_REQUEST_BYTES = 1_000_000  # 1 MB default limit for POST bodies
+MAX_REQUEST_BYTES = 10_000_000  # 10 MB — accommodates long agent conversations with tool calls
 CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
 MAX_NORMALIZED_TEXT_LENGTH = 65_536  # 64 KB cap for normalized content parts
 MAX_CONTENT_LIST_SIZE = 1_000  # Max items when content is an array
@@ -1349,6 +1349,25 @@ class APIServerAdapter(BasePlatformAdapter):
                 except (asyncio.CancelledError, Exception):
                     pass
             logger.info("SSE client disconnected; interrupted agent task %s", completion_id)
+        except Exception:
+            # Anything else went wrong mid-stream.  Finish the SSE body
+            # with an error chunk plus [DONE] so the client gets a
+            # well-formed response instead of a TransferEncodingError
+            # from incomplete chunked encoding.
+            # logger.exception records the whole traceback; a truncated
+            # format_exc() prefix would cut off the exception itself.
+            logger.exception("Agent crashed mid-stream for %s", completion_id)
+            try:
+                error_chunk = {
+                    "id": completion_id, "object": "chat.completion.chunk",
+                    "created": created, "model": model,
+                    "choices": [{"index": 0, "delta": {}, "finish_reason": "error"}],
+                }
+                await response.write(f"data: {json.dumps(error_chunk)}\n\n".encode())
+                await response.write(b"data: [DONE]\n\n")
+            except Exception:
+                # Best effort only: the connection may already be gone.
+                logger.debug("Could not deliver error chunk for %s", completion_id, exc_info=True)
 
         return response
 
@@ -1669,20 +1688,64 @@ class APIServerAdapter(BasePlatformAdapter):
             async def _dispatch(it) -> None:
                 """Route a queue item to the correct SSE emitter.
 
-                Plain strings are text deltas.  Tagged tuples with
-                ``__tool_started__`` / ``__tool_completed__`` prefixes
-                are tool lifecycle events.
+                Plain strings are text deltas — they are batched (50ms)
+                to reduce Open WebUI re-render storms.  Tagged tuples
+                with ``__tool_started__`` / ``__tool_completed__``
+                prefixes are tool lifecycle events and flush the buffer
+                before emitting.
                 """
+                nonlocal _batch_timer
                 if isinstance(it, tuple) and len(it) == 2 and isinstance(it[0], str):
                     tag, payload = it
+                    # Always flush first (even with an empty buffer):
+                    # taking the lock waits for a timer flush that is
+                    # still writing, so events keep their order.
+                    await _flush_batch()
                     if tag == "__tool_started__":
                         await _emit_tool_started(payload)
                     elif tag == "__tool_completed__":
                         await _emit_tool_completed(payload)
                     # Unknown tags are silently ignored (forward-compat).
                 elif isinstance(it, str):
-                    await _emit_text_delta(it)
-                # Other types (non-string, non-tuple) are silently dropped.
+                    # Batch text deltas — append to buffer, flush on timer
+                    _batch_buf.append(it)
+                    if _batch_timer is None:
+                        _batch_timer = asyncio.create_task(
+                            _batch_flush_after(_BATCH_WINDOW_SECONDS)
+                        )
+                # Other types are silently dropped.
+
+            # ── Batching state ──
+            _BATCH_WINDOW_SECONDS = 0.05
+            _batch_buf: List[str] = []
+            _batch_timer: Optional[asyncio.Task] = None
+            _batch_lock = asyncio.Lock()  # serialises flushes so deltas stay ordered
+
+            async def _batch_flush_after(delay: float) -> None:
+                """Wait delay seconds, then flush accumulated text deltas."""
+                nonlocal _batch_timer
+                try:
+                    await asyncio.sleep(delay)
+                except asyncio.CancelledError:
+                    return
+                # Clear timer reference BEFORE flush so new deltas
+                # can start a fresh timer while we emit
+                _batch_timer = None
+                try:
+                    await _flush_batch()
+                except Exception:
+                    # Nobody awaits this task, so a write error would
+                    # otherwise surface only as "Task exception was
+                    # never retrieved".
+                    logger.debug("Batched SSE text flush failed", exc_info=True)
+
+            async def _flush_batch() -> None:
+                """Emit a single SSE delta for all accumulated text."""
+                async with _batch_lock:
+                    if _batch_buf:
+                        combined = "".join(_batch_buf)
+                        _batch_buf.clear()
+                        await _emit_text_delta(combined)
 
             loop = asyncio.get_running_loop()
             while True:
@@ -1707,11 +1770,20 @@ class APIServerAdapter(BasePlatformAdapter):
                     continue
 
                 if item is None:  # EOS sentinel
                     break
 
                 await _dispatch(item)
                 last_activity = time.monotonic()
 
+            # Runs once the loop has ended: stop the batch timer and
+            # drain the buffer before the final events go out.
+            # The flush is unconditional: taking the lock also waits for
+            # a timer flush that is still writing.
+            if _batch_timer is not None:
+                _batch_timer.cancel()
+                _batch_timer = None
+            await _flush_batch()
+
             # Pick up agent result + usage from the completed task
             try:
                 result, agent_usage = await agent_task
@@ -1762,6 +1834,56 @@ class APIServerAdapter(BasePlatformAdapter):
             # payload still see the assistant text.  This mirrors the
             # shape produced by _extract_output_items in the batch path.
             final_items: List[Dict[str, Any]] = list(emitted_items)
+
+            # Trim large payloads to keep the response.completed event
+            # under ~100KB.  Clients already received full details via
+            # incremental events.  Trimmed *copies* are built because
+            # final_items is a shallow copy: editing its dicts in place
+            # would also rewrite the entries held by emitted_items.
+            _TRIM_ARG_KEYS = ("content", "query", "pattern", "old_string", "new_string")
+
+            def _trim_for_completed(_item: Any) -> Any:
+                """Return _item, or a trimmed shallow copy when it is oversized."""
+                if not isinstance(_item, dict):
+                    return _item
+                if _item.get("type") == "function_call":
+                    _raw = _item.get("arguments", "{}")
+                    try:
+                        _args = json.loads(_raw) if isinstance(_raw, str) else _raw
+                        if not isinstance(_args, dict):
+                            return _item
+                        _big = [
+                            _k for _k in _TRIM_ARG_KEYS
+                            if isinstance(_args.get(_k), str) and len(_args[_k]) > 500
+                        ]
+                        if not _big:
+                            # Nothing to trim: leave the arguments untouched.
+                            return _item
+                        _args = dict(_args)
+                        for _k in _big:
+                            _args[_k] = "[" + str(len(_args[_k])) + " chars — truncated for response.completed]"
+                        return {**_item, "arguments": json.dumps(_args)}
+                    except (TypeError, ValueError):
+                        # Unparseable / unserialisable arguments pass through as-is.
+                        return _item
+                if _item.get("type") == "function_call_output":
+                    _output = _item.get("output", [])
+                    if not isinstance(_output, list):
+                        return _item
+                    _parts = []
+                    _changed = False
+                    for _part in _output:
+                        _text = _part.get("text") if isinstance(_part, dict) else None
+                        # isinstance guard: len() on a missing/None text would raise.
+                        if isinstance(_text, str) and len(_text) > 1000 and _part.get("type") == "input_text":
+                            _part = {**_part, "text": _text[:500] + "...[" + str(len(_text) - 500) + " more chars]"}
+                            _changed = True
+                        _parts.append(_part)
+                    return {**_item, "output": _parts} if _changed else _item
+                return _item
+
+            final_items = [_trim_for_completed(_item) for _item in final_items]
+
             final_items.append({
                 "type": "message",
                 "role": "assistant",
@@ -1852,6 +1974,35 @@ class APIServerAdapter(BasePlatformAdapter):
                 agent_task.cancel()
             logger.info("SSE task cancelled; persisted incomplete snapshot for %s", response_id)
             raise
+        except Exception as _exc:
+            # Unhandled failure mid-stream (e.g. a model API error such as
+            # BadRequestError or AuthenticationError).  Emit a response.failed
+            # event and properly terminate the SSE stream so the client doesn't
+            # get a TransferEncodingError from incomplete chunked encoding.
+            # Log first, with the full traceback, so the cause is on record
+            # even if the cleanup below raises.
+            logger.exception("Agent crashed mid-stream for %s", response_id)
+            _persist_incomplete_if_needed()
+            if not agent_task.done():
+                # Same cleanup as the cancellation handler: do not leave
+                # the agent running once the stream is being closed.
+                agent_task.cancel()
+            try:
+                failed_env = _envelope("failed")
+                failed_env["output"] = list(emitted_items)
+                failed_env["error"] = {"message": str(_exc)[:500], "type": "server_error"}
+                failed_env["usage"] = {
+                    "input_tokens": usage.get("input_tokens", 0),
+                    "output_tokens": usage.get("output_tokens", 0),
+                    "total_tokens": usage.get("total_tokens", 0),
+                }
+                await _write_event("response.failed", {
+                    "type": "response.failed",
+                    "response": failed_env,
+                })
+            except Exception:
+                # Best effort only: the connection may already be gone.
+                logger.debug("Could not deliver response.failed for %s", response_id, exc_info=True)
 
         return response
 
@@ -2935,7 +3086,9 @@ class APIServerAdapter(BasePlatformAdapter):
 
         try:
             mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
-            self._app = web.Application(middlewares=mws)
+            # aiohttp's own default body cap is 1 MiB; raise it so bodies
+            # up to MAX_REQUEST_BYTES can actually be read.
+            self._app = web.Application(middlewares=mws, client_max_size=MAX_REQUEST_BYTES)
             self._app["api_server_adapter"] = self
             self._app.router.add_get("/health", self._handle_health)
             self._app.router.add_get("/health/detailed", self._handle_health_detailed)