fix(gateway): harden kanban and provider cleanup races

This commit is contained in:
helix4u 2026-05-20 14:58:01 -06:00 committed by Teknium
parent 31a0100104
commit 1a7bb988fc
6 changed files with 259 additions and 101 deletions

View file

@ -1869,6 +1869,77 @@ def copy_reasoning_content_for_api(agent, source_msg: dict, api_msg: dict) -> No
def _iter_pool_sockets(client: Any):
"""Yield raw sockets reachable from an OpenAI/httpx client pool.
httpcore 1.x stores the concrete HTTP11/HTTP2 connection under
``conn._connection``; older versions exposed stream attributes directly
on the pool entry. Keep the traversal defensive because these are private
transport internals and vary across httpx/httpcore releases.
"""
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return
transport = getattr(http_client, "_transport", None)
if transport is None:
return
pool = getattr(transport, "_pool", None)
if pool is None:
return
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
except Exception:
return
seen: set[int] = set()
for conn in list(connections):
candidates = [conn]
inner = getattr(conn, "_connection", None)
if inner is not None:
candidates.append(inner)
for candidate in candidates:
stream = (
getattr(candidate, "_network_stream", None)
or getattr(candidate, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
get_extra_info = getattr(stream, "get_extra_info", None)
if callable(get_extra_info):
try:
sock = get_extra_info("socket")
except Exception:
sock = None
if sock is None:
wrapped = getattr(stream, "stream", None)
if wrapped is not None:
sock = getattr(wrapped, "_sock", None)
if sock is None:
# anyio-backed streams expose the raw socket through
# SocketAttribute.raw_socket when available.
wrapped = getattr(stream, "_stream", None)
extra = getattr(wrapped, "extra", None)
if callable(extra):
try:
from anyio.abc import SocketAttribute
sock = extra(SocketAttribute.raw_socket)
except Exception:
sock = None
if sock is None:
continue
marker = id(sock)
if marker in seen:
continue
seen.add(marker)
yield sock
def cleanup_dead_connections(agent) -> bool:
"""Detect and clean up dead TCP connections on the primary client.
@ -1882,36 +1953,8 @@ def cleanup_dead_connections(agent) -> bool:
if client is None:
return False
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return False
transport = getattr(http_client, "_transport", None)
if transport is None:
return False
pool = getattr(transport, "_pool", None)
if pool is None:
return False
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
dead_count = 0
for conn in list(connections):
# Check for connections that are idle but have closed sockets
stream = (
getattr(conn, "_network_stream", None)
or getattr(conn, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
sock = getattr(stream, "stream", None)
if sock is not None:
sock = getattr(sock, "_sock", None)
if sock is None:
continue
for sock in _iter_pool_sockets(client):
# Probe socket health with a non-blocking recv peek
import socket as _socket
try:
@ -2087,36 +2130,7 @@ def force_close_tcp_sockets(client: Any) -> int:
closed = 0
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return 0
transport = getattr(http_client, "_transport", None)
if transport is None:
return 0
pool = getattr(transport, "_pool", None)
if pool is None:
return 0
# httpx uses httpcore connection pools; connections live in
# _connections (list) or _pool (list) depending on version.
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
for conn in list(connections):
stream = (
getattr(conn, "_network_stream", None)
or getattr(conn, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
sock = getattr(stream, "stream", None)
if sock is not None:
sock = getattr(sock, "_sock", None)
if sock is None:
continue
for sock in _iter_pool_sockets(client):
try:
sock.shutdown(_socket.SHUT_RDWR)
except OSError:
@ -2154,5 +2168,6 @@ __all__ = [
"cleanup_dead_connections",
"extract_api_error_context",
"apply_pending_steer_to_tool_results",
"_iter_pool_sockets",
"force_close_tcp_sockets",
]

View file

@ -92,17 +92,36 @@ def interruptible_api_call(agent, api_kwargs: dict):
"""
result = {"response": None, "error": None}
request_client_holder = {"client": None}
request_client_lock = threading.Lock()
def _set_request_client(client):
with request_client_lock:
request_client_holder["client"] = client
return client
def _take_request_client():
with request_client_lock:
client = request_client_holder.get("client")
request_client_holder["client"] = None
return client
def _close_request_client_once(reason: str) -> None:
request_client = _take_request_client()
if request_client is not None:
agent._close_request_openai_client(request_client, reason=reason)
def _call():
try:
if agent.api_mode == "codex_responses":
request_client_holder["client"] = agent._create_request_openai_client(
reason="codex_stream_request",
api_kwargs=api_kwargs,
request_client = _set_request_client(
agent._create_request_openai_client(
reason="codex_stream_request",
api_kwargs=api_kwargs,
)
)
result["response"] = agent._run_codex_stream(
api_kwargs,
client=request_client_holder["client"],
client=request_client,
on_first_delta=getattr(agent, "_codex_on_first_delta", None),
)
elif agent.api_mode == "anthropic_messages":
@ -131,17 +150,17 @@ def interruptible_api_call(agent, api_kwargs: dict):
raise
result["response"] = normalize_converse_response(raw_response)
else:
request_client_holder["client"] = agent._create_request_openai_client(
reason="chat_completion_request",
api_kwargs=api_kwargs,
request_client = _set_request_client(
agent._create_request_openai_client(
reason="chat_completion_request",
api_kwargs=api_kwargs,
)
)
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
result["response"] = request_client.chat.completions.create(**api_kwargs)
except Exception as e:
result["error"] = e
finally:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="request_complete")
_close_request_client_once("request_complete")
# ── Stale-call timeout (mirrors streaming stale detector) ────────
# Non-streaming calls return nothing until the full response is
@ -192,9 +211,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
agent._anthropic_client.close()
agent._rebuild_anthropic_client()
else:
rc = request_client_holder.get("client")
if rc is not None:
agent._close_request_openai_client(rc, reason="stale_call_kill")
_close_request_client_once("stale_call_kill")
except Exception:
pass
agent._touch_activity(
@ -218,9 +235,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
agent._anthropic_client.close()
agent._rebuild_anthropic_client()
else:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="interrupt_abort")
_close_request_client_once("interrupt_abort")
except Exception:
pass
raise InterruptedError("Agent interrupted during API call")
@ -1257,6 +1272,24 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
result = {"response": None, "error": None, "partial_tool_names": []}
request_client_holder = {"client": None, "diag": None}
request_client_lock = threading.Lock()
def _set_request_client(client):
with request_client_lock:
request_client_holder["client"] = client
return client
def _take_request_client():
with request_client_lock:
client = request_client_holder.get("client")
request_client_holder["client"] = None
return client
def _close_request_client_once(reason: str) -> None:
request_client = _take_request_client()
if request_client is not None:
agent._close_request_openai_client(request_client, reason=reason)
first_delta_fired = {"done": False}
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
# Wall-clock timestamp of the last real streaming chunk. The outer
@ -1313,9 +1346,11 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
pool=_conn_cap,
),
}
request_client_holder["client"] = agent._create_request_openai_client(
reason="chat_completion_stream_request",
api_kwargs=stream_kwargs,
request_client = _set_request_client(
agent._create_request_openai_client(
reason="chat_completion_stream_request",
api_kwargs=stream_kwargs,
)
)
# Reset stale-stream timer so the detector measures from this
# attempt's start, not a previous attempt's last chunk.
@ -1326,7 +1361,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
# ``request_client_holder["diag"]`` for closure access.
_diag = agent._stream_diag_init()
request_client_holder["diag"] = _diag
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
stream = request_client.chat.completions.create(**stream_kwargs)
# Capture rate limit headers from the initial HTTP response.
# The OpenAI SDK Stream object exposes the underlying httpx
@ -1765,12 +1800,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
mid_tool_call=True,
diag=request_client_holder.get("diag"),
)
stale = request_client_holder.get("client")
if stale is not None:
agent._close_request_openai_client(
stale, reason="stream_mid_tool_retry_cleanup"
)
request_client_holder["client"] = None
_close_request_client_once("stream_mid_tool_retry_cleanup")
try:
agent._replace_primary_openai_client(
reason="stream_mid_tool_retry_pool_cleanup"
@ -1821,12 +1851,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
diag=request_client_holder.get("diag"),
)
# Close the stale request client before retry
stale = request_client_holder.get("client")
if stale is not None:
agent._close_request_openai_client(
stale, reason="stream_retry_cleanup"
)
request_client_holder["client"] = None
_close_request_client_once("stream_retry_cleanup")
# Also rebuild the primary client to purge
# any dead connections from the pool.
try:
@ -1894,9 +1919,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
result["error"] = e
return
finally:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="stream_request_complete")
_close_request_client_once("stream_request_complete")
# Provider-configured stale timeout takes priority over env default.
_cfg_stale = get_provider_stale_timeout(agent.provider, agent.model)
@ -1966,9 +1989,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
f"Reconnecting..."
)
try:
rc = request_client_holder.get("client")
if rc is not None:
agent._close_request_openai_client(rc, reason="stale_stream_kill")
_close_request_client_once("stale_stream_kill")
except Exception:
pass
# Rebuild the primary client too — its connection pool
@ -1990,9 +2011,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
agent._anthropic_client.close()
agent._rebuild_anthropic_client()
else:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="stream_interrupt_abort")
_close_request_client_once("stream_interrupt_abort")
except Exception:
pass
raise InterruptedError("Agent interrupted during streaming API call")