fix(gateway): harden kanban and provider cleanup races

This commit is contained in:
helix4u 2026-05-20 14:58:01 -06:00 committed by Teknium
parent 31a0100104
commit 1a7bb988fc
6 changed files with 259 additions and 101 deletions

View file

@ -1869,6 +1869,77 @@ def copy_reasoning_content_for_api(agent, source_msg: dict, api_msg: dict) -> No
def _iter_pool_sockets(client: Any):
"""Yield raw sockets reachable from an OpenAI/httpx client pool.
httpcore 1.x stores the concrete HTTP11/HTTP2 connection under
``conn._connection``; older versions exposed stream attributes directly
on the pool entry. Keep the traversal defensive because these are private
transport internals and vary across httpx/httpcore releases.
"""
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return
transport = getattr(http_client, "_transport", None)
if transport is None:
return
pool = getattr(transport, "_pool", None)
if pool is None:
return
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
except Exception:
return
seen: set[int] = set()
for conn in list(connections):
candidates = [conn]
inner = getattr(conn, "_connection", None)
if inner is not None:
candidates.append(inner)
for candidate in candidates:
stream = (
getattr(candidate, "_network_stream", None)
or getattr(candidate, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
get_extra_info = getattr(stream, "get_extra_info", None)
if callable(get_extra_info):
try:
sock = get_extra_info("socket")
except Exception:
sock = None
if sock is None:
wrapped = getattr(stream, "stream", None)
if wrapped is not None:
sock = getattr(wrapped, "_sock", None)
if sock is None:
# anyio-backed streams expose the raw socket through
# SocketAttribute.raw_socket when available.
wrapped = getattr(stream, "_stream", None)
extra = getattr(wrapped, "extra", None)
if callable(extra):
try:
from anyio.abc import SocketAttribute
sock = extra(SocketAttribute.raw_socket)
except Exception:
sock = None
if sock is None:
continue
marker = id(sock)
if marker in seen:
continue
seen.add(marker)
yield sock
def cleanup_dead_connections(agent) -> bool:
"""Detect and clean up dead TCP connections on the primary client.
@ -1882,36 +1953,8 @@ def cleanup_dead_connections(agent) -> bool:
if client is None:
return False
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return False
transport = getattr(http_client, "_transport", None)
if transport is None:
return False
pool = getattr(transport, "_pool", None)
if pool is None:
return False
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
dead_count = 0
for conn in list(connections):
# Check for connections that are idle but have closed sockets
stream = (
getattr(conn, "_network_stream", None)
or getattr(conn, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
sock = getattr(stream, "stream", None)
if sock is not None:
sock = getattr(sock, "_sock", None)
if sock is None:
continue
for sock in _iter_pool_sockets(client):
# Probe socket health with a non-blocking recv peek
import socket as _socket
try:
@ -2087,36 +2130,7 @@ def force_close_tcp_sockets(client: Any) -> int:
closed = 0
try:
http_client = getattr(client, "_client", None)
if http_client is None:
return 0
transport = getattr(http_client, "_transport", None)
if transport is None:
return 0
pool = getattr(transport, "_pool", None)
if pool is None:
return 0
# httpx uses httpcore connection pools; connections live in
# _connections (list) or _pool (list) depending on version.
connections = (
getattr(pool, "_connections", None)
or getattr(pool, "_pool", None)
or []
)
for conn in list(connections):
stream = (
getattr(conn, "_network_stream", None)
or getattr(conn, "_stream", None)
)
if stream is None:
continue
sock = getattr(stream, "_sock", None)
if sock is None:
sock = getattr(stream, "stream", None)
if sock is not None:
sock = getattr(sock, "_sock", None)
if sock is None:
continue
for sock in _iter_pool_sockets(client):
try:
sock.shutdown(_socket.SHUT_RDWR)
except OSError:
@ -2154,5 +2168,6 @@ __all__ = [
"cleanup_dead_connections",
"extract_api_error_context",
"apply_pending_steer_to_tool_results",
"_iter_pool_sockets",
"force_close_tcp_sockets",
]

View file

@ -92,17 +92,36 @@ def interruptible_api_call(agent, api_kwargs: dict):
"""
result = {"response": None, "error": None}
request_client_holder = {"client": None}
request_client_lock = threading.Lock()
def _set_request_client(client):
with request_client_lock:
request_client_holder["client"] = client
return client
def _take_request_client():
with request_client_lock:
client = request_client_holder.get("client")
request_client_holder["client"] = None
return client
def _close_request_client_once(reason: str) -> None:
request_client = _take_request_client()
if request_client is not None:
agent._close_request_openai_client(request_client, reason=reason)
def _call():
try:
if agent.api_mode == "codex_responses":
request_client_holder["client"] = agent._create_request_openai_client(
reason="codex_stream_request",
api_kwargs=api_kwargs,
request_client = _set_request_client(
agent._create_request_openai_client(
reason="codex_stream_request",
api_kwargs=api_kwargs,
)
)
result["response"] = agent._run_codex_stream(
api_kwargs,
client=request_client_holder["client"],
client=request_client,
on_first_delta=getattr(agent, "_codex_on_first_delta", None),
)
elif agent.api_mode == "anthropic_messages":
@ -131,17 +150,17 @@ def interruptible_api_call(agent, api_kwargs: dict):
raise
result["response"] = normalize_converse_response(raw_response)
else:
request_client_holder["client"] = agent._create_request_openai_client(
reason="chat_completion_request",
api_kwargs=api_kwargs,
request_client = _set_request_client(
agent._create_request_openai_client(
reason="chat_completion_request",
api_kwargs=api_kwargs,
)
)
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
result["response"] = request_client.chat.completions.create(**api_kwargs)
except Exception as e:
result["error"] = e
finally:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="request_complete")
_close_request_client_once("request_complete")
# ── Stale-call timeout (mirrors streaming stale detector) ────────
# Non-streaming calls return nothing until the full response is
@ -192,9 +211,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
agent._anthropic_client.close()
agent._rebuild_anthropic_client()
else:
rc = request_client_holder.get("client")
if rc is not None:
agent._close_request_openai_client(rc, reason="stale_call_kill")
_close_request_client_once("stale_call_kill")
except Exception:
pass
agent._touch_activity(
@ -218,9 +235,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
agent._anthropic_client.close()
agent._rebuild_anthropic_client()
else:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="interrupt_abort")
_close_request_client_once("interrupt_abort")
except Exception:
pass
raise InterruptedError("Agent interrupted during API call")
@ -1257,6 +1272,24 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
result = {"response": None, "error": None, "partial_tool_names": []}
request_client_holder = {"client": None, "diag": None}
request_client_lock = threading.Lock()
def _set_request_client(client):
with request_client_lock:
request_client_holder["client"] = client
return client
def _take_request_client():
with request_client_lock:
client = request_client_holder.get("client")
request_client_holder["client"] = None
return client
def _close_request_client_once(reason: str) -> None:
request_client = _take_request_client()
if request_client is not None:
agent._close_request_openai_client(request_client, reason=reason)
first_delta_fired = {"done": False}
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
# Wall-clock timestamp of the last real streaming chunk. The outer
@ -1313,9 +1346,11 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
pool=_conn_cap,
),
}
request_client_holder["client"] = agent._create_request_openai_client(
reason="chat_completion_stream_request",
api_kwargs=stream_kwargs,
request_client = _set_request_client(
agent._create_request_openai_client(
reason="chat_completion_stream_request",
api_kwargs=stream_kwargs,
)
)
# Reset stale-stream timer so the detector measures from this
# attempt's start, not a previous attempt's last chunk.
@ -1326,7 +1361,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
# ``request_client_holder["diag"]`` for closure access.
_diag = agent._stream_diag_init()
request_client_holder["diag"] = _diag
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
stream = request_client.chat.completions.create(**stream_kwargs)
# Capture rate limit headers from the initial HTTP response.
# The OpenAI SDK Stream object exposes the underlying httpx
@ -1765,12 +1800,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
mid_tool_call=True,
diag=request_client_holder.get("diag"),
)
stale = request_client_holder.get("client")
if stale is not None:
agent._close_request_openai_client(
stale, reason="stream_mid_tool_retry_cleanup"
)
request_client_holder["client"] = None
_close_request_client_once("stream_mid_tool_retry_cleanup")
try:
agent._replace_primary_openai_client(
reason="stream_mid_tool_retry_pool_cleanup"
@ -1821,12 +1851,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
diag=request_client_holder.get("diag"),
)
# Close the stale request client before retry
stale = request_client_holder.get("client")
if stale is not None:
agent._close_request_openai_client(
stale, reason="stream_retry_cleanup"
)
request_client_holder["client"] = None
_close_request_client_once("stream_retry_cleanup")
# Also rebuild the primary client to purge
# any dead connections from the pool.
try:
@ -1894,9 +1919,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
result["error"] = e
return
finally:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="stream_request_complete")
_close_request_client_once("stream_request_complete")
# Provider-configured stale timeout takes priority over env default.
_cfg_stale = get_provider_stale_timeout(agent.provider, agent.model)
@ -1966,9 +1989,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
f"Reconnecting..."
)
try:
rc = request_client_holder.get("client")
if rc is not None:
agent._close_request_openai_client(rc, reason="stale_stream_kill")
_close_request_client_once("stale_stream_kill")
except Exception:
pass
# Rebuild the primary client too — its connection pool
@ -1990,9 +2011,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
agent._anthropic_client.close()
agent._rebuild_anthropic_client()
else:
request_client = request_client_holder.get("client")
if request_client is not None:
agent._close_request_openai_client(request_client, reason="stream_interrupt_abort")
_close_request_client_once("stream_interrupt_abort")
except Exception:
pass
raise InterruptedError("Agent interrupted during streaming API call")

View file

@ -951,6 +951,58 @@ CREATE INDEX IF NOT EXISTS idx_notify_task ON kanban_notify_subs(task_
_INITIALIZED_PATHS: set[str] = set()
_INIT_LOCK = threading.RLock()
_SQLITE_HEADER = b"SQLite format 3\x00"
def _looks_like_tls_record_at(data: bytes, offset: int) -> bool:
"""Return True for a TLS record header at ``data[offset:]``."""
if len(data) < offset + 5:
return False
content_type = data[offset]
major = data[offset + 1]
minor = data[offset + 2]
length = int.from_bytes(data[offset + 3:offset + 5], "big")
return (
content_type in {0x14, 0x15, 0x16, 0x17}
and major == 0x03
and minor in {0x00, 0x01, 0x02, 0x03, 0x04}
and 0 < length <= 18432
)
def _validate_sqlite_header(path: Path) -> None:
"""Fail early with an actionable error for non-SQLite Kanban DB files.
``sqlite3.connect()`` creates missing and zero-byte files, so those are
allowed. Existing non-empty files must have the SQLite header before we
hand them to SQLite/WAL setup. This keeps corrupted page-0 failures from
being collapsed into a generic PRAGMA error and lets the gateway's corrupt
board handling identify the board by fingerprint.
"""
try:
stat = path.stat()
except FileNotFoundError:
return
except OSError:
return
if stat.st_size == 0:
return
try:
with path.open("rb") as handle:
head = handle.read(64)
except OSError:
return
if head.startswith(_SQLITE_HEADER):
return
signature = ""
if head.startswith(b"SQLit") and _looks_like_tls_record_at(head, 5):
signature = " (TLS record header detected at byte offset 5)"
elif _looks_like_tls_record_at(head, 0):
signature = " (TLS record header detected at byte offset 0)"
raise sqlite3.DatabaseError(
"file is not a database: invalid SQLite header for "
f"{path}{signature}; first_32={head[:32].hex(' ')}"
)
def connect(
@ -981,6 +1033,7 @@ def connect(
else:
path = kanban_db_path(board=board)
path.parent.mkdir(parents=True, exist_ok=True)
_validate_sqlite_header(path)
resolved = str(path.resolve())
conn = sqlite3.connect(str(path), isolation_level=None, timeout=30)
try:

View file

@ -48,6 +48,27 @@ def test_init_creates_expected_tables(kanban_home):
assert {"tasks", "task_links", "task_comments", "task_events"} <= names
def test_connect_rejects_tls_record_in_sqlite_header(tmp_path, monkeypatch):
"""Kanban should classify TLS-looking page-0 clobbers before WAL setup."""
home = tmp_path / ".hermes"
home.mkdir()
monkeypatch.setenv("HERMES_HOME", str(home))
monkeypatch.delenv("HERMES_KANBAN_DB", raising=False)
monkeypatch.delenv("HERMES_KANBAN_HOME", raising=False)
monkeypatch.setattr(Path, "home", lambda: tmp_path)
corrupt = home / "kanban.db"
corrupt.write_bytes(b"SQLit" + bytes.fromhex("17 03 03 00 13") + b"x" * 32)
with pytest.raises(sqlite3.DatabaseError) as exc_info:
kb.connect(board="default")
msg = str(exc_info.value)
assert "file is not a database" in msg
assert "TLS record header detected at byte offset 5" in msg
assert "53 51 4c 69 74 17 03 03 00 13" in msg
def test_connect_migrates_legacy_db_before_optional_column_indexes(tmp_path):
"""Legacy DBs missing additive indexed columns must migrate cleanly.

View file

@ -16,6 +16,7 @@ with ``APIConnectionError('Connection error.')`` whose cause was
That is the exact scenario this test reproduces at object level without a
network, so it runs in CI on every PR.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from run_agent import AIAgent
@ -186,3 +187,32 @@ def test_replace_primary_openai_client_survives_repeated_rebuilds():
"Some _create_openai_client calls returned the same object across "
"a teardown — rebuild is not producing fresh clients"
)
def test_force_close_tcp_sockets_descends_httpcore_1_connection_wrapper():
"""httpcore 1.x stores the real stream below conn._connection."""
from agent.agent_runtime_helpers import force_close_tcp_sockets
class FakeSocket:
def __init__(self):
self.shutdown_calls = 0
self.close_calls = 0
def shutdown(self, _how):
self.shutdown_calls += 1
def close(self):
self.close_calls += 1
sock = FakeSocket()
stream = SimpleNamespace(_sock=sock)
http11 = SimpleNamespace(_network_stream=stream)
pool_entry = SimpleNamespace(_connection=http11)
pool = SimpleNamespace(_connections=[pool_entry])
transport = SimpleNamespace(_pool=pool)
http_client = SimpleNamespace(_transport=transport)
openai_client = SimpleNamespace(_client=http_client)
assert force_close_tcp_sockets(openai_client) == 1
assert sock.shutdown_calls == 1
assert sock.close_calls == 1

View file

@ -1,5 +1,6 @@
import sys
import threading
import time
import types
from types import SimpleNamespace
@ -64,6 +65,7 @@ def _build_agent(shared_client=None):
agent.stream_delta_callback = None
agent._stream_callback = None
agent.reasoning_callback = None
agent.status_callback = None
return agent
@ -93,6 +95,24 @@ def test_retry_after_api_connection_error_recreates_request_client(monkeypatch):
assert second_request.close_calls >= 1
def test_stale_non_stream_close_is_single_owner(monkeypatch):
def slow_responder(**kwargs):
time.sleep(0.1)
raise _connection_error()
request_client = FakeRequestClient(slow_responder)
factory = OpenAIFactory([request_client])
monkeypatch.setattr(run_agent, "OpenAI", factory)
agent = _build_agent()
agent._compute_non_stream_stale_timeout = lambda _messages: 0.01
with pytest.raises(APIConnectionError):
agent._interruptible_api_call({"model": agent.model, "messages": []})
assert request_client.close_calls == 1
def test_closed_shared_client_is_recreated_before_request(monkeypatch):
stale_shared = FakeSharedClient(lambda **kwargs: (_ for _ in ()).throw(AssertionError("stale shared client used")))
stale_shared._client.is_closed = True