mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-07 08:02:23 +00:00
fix(gateway): harden kanban and provider cleanup races
This commit is contained in:
parent
31a0100104
commit
1a7bb988fc
6 changed files with 259 additions and 101 deletions
|
|
@ -1869,6 +1869,77 @@ def copy_reasoning_content_for_api(agent, source_msg: dict, api_msg: dict) -> No
|
|||
|
||||
|
||||
|
||||
def _iter_pool_sockets(client: Any):
|
||||
"""Yield raw sockets reachable from an OpenAI/httpx client pool.
|
||||
|
||||
httpcore 1.x stores the concrete HTTP11/HTTP2 connection under
|
||||
``conn._connection``; older versions exposed stream attributes directly
|
||||
on the pool entry. Keep the traversal defensive because these are private
|
||||
transport internals and vary across httpx/httpcore releases.
|
||||
"""
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
seen: set[int] = set()
|
||||
for conn in list(connections):
|
||||
candidates = [conn]
|
||||
inner = getattr(conn, "_connection", None)
|
||||
if inner is not None:
|
||||
candidates.append(inner)
|
||||
for candidate in candidates:
|
||||
stream = (
|
||||
getattr(candidate, "_network_stream", None)
|
||||
or getattr(candidate, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
get_extra_info = getattr(stream, "get_extra_info", None)
|
||||
if callable(get_extra_info):
|
||||
try:
|
||||
sock = get_extra_info("socket")
|
||||
except Exception:
|
||||
sock = None
|
||||
if sock is None:
|
||||
wrapped = getattr(stream, "stream", None)
|
||||
if wrapped is not None:
|
||||
sock = getattr(wrapped, "_sock", None)
|
||||
if sock is None:
|
||||
# anyio-backed streams expose the raw socket through
|
||||
# SocketAttribute.raw_socket when available.
|
||||
wrapped = getattr(stream, "_stream", None)
|
||||
extra = getattr(wrapped, "extra", None)
|
||||
if callable(extra):
|
||||
try:
|
||||
from anyio.abc import SocketAttribute
|
||||
sock = extra(SocketAttribute.raw_socket)
|
||||
except Exception:
|
||||
sock = None
|
||||
if sock is None:
|
||||
continue
|
||||
marker = id(sock)
|
||||
if marker in seen:
|
||||
continue
|
||||
seen.add(marker)
|
||||
yield sock
|
||||
|
||||
|
||||
def cleanup_dead_connections(agent) -> bool:
|
||||
"""Detect and clean up dead TCP connections on the primary client.
|
||||
|
||||
|
|
@ -1882,36 +1953,8 @@ def cleanup_dead_connections(agent) -> bool:
|
|||
if client is None:
|
||||
return False
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return False
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return False
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return False
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
dead_count = 0
|
||||
for conn in list(connections):
|
||||
# Check for connections that are idle but have closed sockets
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
for sock in _iter_pool_sockets(client):
|
||||
# Probe socket health with a non-blocking recv peek
|
||||
import socket as _socket
|
||||
try:
|
||||
|
|
@ -2087,36 +2130,7 @@ def force_close_tcp_sockets(client: Any) -> int:
|
|||
|
||||
closed = 0
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return 0
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return 0
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return 0
|
||||
# httpx uses httpcore connection pools; connections live in
|
||||
# _connections (list) or _pool (list) depending on version.
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
for conn in list(connections):
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
for sock in _iter_pool_sockets(client):
|
||||
try:
|
||||
sock.shutdown(_socket.SHUT_RDWR)
|
||||
except OSError:
|
||||
|
|
@ -2154,5 +2168,6 @@ __all__ = [
|
|||
"cleanup_dead_connections",
|
||||
"extract_api_error_context",
|
||||
"apply_pending_steer_to_tool_results",
|
||||
"_iter_pool_sockets",
|
||||
"force_close_tcp_sockets",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -92,17 +92,36 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
"""
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
request_client_lock = threading.Lock()
|
||||
|
||||
def _set_request_client(client):
|
||||
with request_client_lock:
|
||||
request_client_holder["client"] = client
|
||||
return client
|
||||
|
||||
def _take_request_client():
|
||||
with request_client_lock:
|
||||
client = request_client_holder.get("client")
|
||||
request_client_holder["client"] = None
|
||||
return client
|
||||
|
||||
def _close_request_client_once(reason: str) -> None:
|
||||
request_client = _take_request_client()
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason=reason)
|
||||
|
||||
def _call():
|
||||
try:
|
||||
if agent.api_mode == "codex_responses":
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="codex_stream_request",
|
||||
api_kwargs=api_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="codex_stream_request",
|
||||
api_kwargs=api_kwargs,
|
||||
)
|
||||
)
|
||||
result["response"] = agent._run_codex_stream(
|
||||
api_kwargs,
|
||||
client=request_client_holder["client"],
|
||||
client=request_client,
|
||||
on_first_delta=getattr(agent, "_codex_on_first_delta", None),
|
||||
)
|
||||
elif agent.api_mode == "anthropic_messages":
|
||||
|
|
@ -131,17 +150,17 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
raise
|
||||
result["response"] = normalize_converse_response(raw_response)
|
||||
else:
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="chat_completion_request",
|
||||
api_kwargs=api_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="chat_completion_request",
|
||||
api_kwargs=api_kwargs,
|
||||
)
|
||||
)
|
||||
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
|
||||
result["response"] = request_client.chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="request_complete")
|
||||
_close_request_client_once("request_complete")
|
||||
|
||||
# ── Stale-call timeout (mirrors streaming stale detector) ────────
|
||||
# Non-streaming calls return nothing until the full response is
|
||||
|
|
@ -192,9 +211,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
rc = request_client_holder.get("client")
|
||||
if rc is not None:
|
||||
agent._close_request_openai_client(rc, reason="stale_call_kill")
|
||||
_close_request_client_once("stale_call_kill")
|
||||
except Exception:
|
||||
pass
|
||||
agent._touch_activity(
|
||||
|
|
@ -218,9 +235,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="interrupt_abort")
|
||||
_close_request_client_once("interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
|
|
@ -1257,6 +1272,24 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
|
||||
result = {"response": None, "error": None, "partial_tool_names": []}
|
||||
request_client_holder = {"client": None, "diag": None}
|
||||
request_client_lock = threading.Lock()
|
||||
|
||||
def _set_request_client(client):
|
||||
with request_client_lock:
|
||||
request_client_holder["client"] = client
|
||||
return client
|
||||
|
||||
def _take_request_client():
|
||||
with request_client_lock:
|
||||
client = request_client_holder.get("client")
|
||||
request_client_holder["client"] = None
|
||||
return client
|
||||
|
||||
def _close_request_client_once(reason: str) -> None:
|
||||
request_client = _take_request_client()
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason=reason)
|
||||
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
# Wall-clock timestamp of the last real streaming chunk. The outer
|
||||
|
|
@ -1313,9 +1346,11 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
pool=_conn_cap,
|
||||
),
|
||||
}
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="chat_completion_stream_request",
|
||||
api_kwargs=stream_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="chat_completion_stream_request",
|
||||
api_kwargs=stream_kwargs,
|
||||
)
|
||||
)
|
||||
# Reset stale-stream timer so the detector measures from this
|
||||
# attempt's start, not a previous attempt's last chunk.
|
||||
|
|
@ -1326,7 +1361,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
# ``request_client_holder["diag"]`` for closure access.
|
||||
_diag = agent._stream_diag_init()
|
||||
request_client_holder["diag"] = _diag
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
stream = request_client.chat.completions.create(**stream_kwargs)
|
||||
|
||||
# Capture rate limit headers from the initial HTTP response.
|
||||
# The OpenAI SDK Stream object exposes the underlying httpx
|
||||
|
|
@ -1765,12 +1800,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
mid_tool_call=True,
|
||||
diag=request_client_holder.get("diag"),
|
||||
)
|
||||
stale = request_client_holder.get("client")
|
||||
if stale is not None:
|
||||
agent._close_request_openai_client(
|
||||
stale, reason="stream_mid_tool_retry_cleanup"
|
||||
)
|
||||
request_client_holder["client"] = None
|
||||
_close_request_client_once("stream_mid_tool_retry_cleanup")
|
||||
try:
|
||||
agent._replace_primary_openai_client(
|
||||
reason="stream_mid_tool_retry_pool_cleanup"
|
||||
|
|
@ -1821,12 +1851,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
diag=request_client_holder.get("diag"),
|
||||
)
|
||||
# Close the stale request client before retry
|
||||
stale = request_client_holder.get("client")
|
||||
if stale is not None:
|
||||
agent._close_request_openai_client(
|
||||
stale, reason="stream_retry_cleanup"
|
||||
)
|
||||
request_client_holder["client"] = None
|
||||
_close_request_client_once("stream_retry_cleanup")
|
||||
# Also rebuild the primary client to purge
|
||||
# any dead connections from the pool.
|
||||
try:
|
||||
|
|
@ -1894,9 +1919,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
result["error"] = e
|
||||
return
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
_close_request_client_once("stream_request_complete")
|
||||
|
||||
# Provider-configured stale timeout takes priority over env default.
|
||||
_cfg_stale = get_provider_stale_timeout(agent.provider, agent.model)
|
||||
|
|
@ -1966,9 +1989,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
f"Reconnecting..."
|
||||
)
|
||||
try:
|
||||
rc = request_client_holder.get("client")
|
||||
if rc is not None:
|
||||
agent._close_request_openai_client(rc, reason="stale_stream_kill")
|
||||
_close_request_client_once("stale_stream_kill")
|
||||
except Exception:
|
||||
pass
|
||||
# Rebuild the primary client too — its connection pool
|
||||
|
|
@ -1990,9 +2011,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
||||
_close_request_client_once("stream_interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue