mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
fix(gateway): harden kanban and provider cleanup races
This commit is contained in:
parent
31a0100104
commit
1a7bb988fc
6 changed files with 259 additions and 101 deletions
|
|
@ -1869,6 +1869,77 @@ def copy_reasoning_content_for_api(agent, source_msg: dict, api_msg: dict) -> No
|
|||
|
||||
|
||||
|
||||
def _iter_pool_sockets(client: Any):
|
||||
"""Yield raw sockets reachable from an OpenAI/httpx client pool.
|
||||
|
||||
httpcore 1.x stores the concrete HTTP11/HTTP2 connection under
|
||||
``conn._connection``; older versions exposed stream attributes directly
|
||||
on the pool entry. Keep the traversal defensive because these are private
|
||||
transport internals and vary across httpx/httpcore releases.
|
||||
"""
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
seen: set[int] = set()
|
||||
for conn in list(connections):
|
||||
candidates = [conn]
|
||||
inner = getattr(conn, "_connection", None)
|
||||
if inner is not None:
|
||||
candidates.append(inner)
|
||||
for candidate in candidates:
|
||||
stream = (
|
||||
getattr(candidate, "_network_stream", None)
|
||||
or getattr(candidate, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
get_extra_info = getattr(stream, "get_extra_info", None)
|
||||
if callable(get_extra_info):
|
||||
try:
|
||||
sock = get_extra_info("socket")
|
||||
except Exception:
|
||||
sock = None
|
||||
if sock is None:
|
||||
wrapped = getattr(stream, "stream", None)
|
||||
if wrapped is not None:
|
||||
sock = getattr(wrapped, "_sock", None)
|
||||
if sock is None:
|
||||
# anyio-backed streams expose the raw socket through
|
||||
# SocketAttribute.raw_socket when available.
|
||||
wrapped = getattr(stream, "_stream", None)
|
||||
extra = getattr(wrapped, "extra", None)
|
||||
if callable(extra):
|
||||
try:
|
||||
from anyio.abc import SocketAttribute
|
||||
sock = extra(SocketAttribute.raw_socket)
|
||||
except Exception:
|
||||
sock = None
|
||||
if sock is None:
|
||||
continue
|
||||
marker = id(sock)
|
||||
if marker in seen:
|
||||
continue
|
||||
seen.add(marker)
|
||||
yield sock
|
||||
|
||||
|
||||
def cleanup_dead_connections(agent) -> bool:
|
||||
"""Detect and clean up dead TCP connections on the primary client.
|
||||
|
||||
|
|
@ -1882,36 +1953,8 @@ def cleanup_dead_connections(agent) -> bool:
|
|||
if client is None:
|
||||
return False
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return False
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return False
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return False
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
dead_count = 0
|
||||
for conn in list(connections):
|
||||
# Check for connections that are idle but have closed sockets
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
for sock in _iter_pool_sockets(client):
|
||||
# Probe socket health with a non-blocking recv peek
|
||||
import socket as _socket
|
||||
try:
|
||||
|
|
@ -2087,36 +2130,7 @@ def force_close_tcp_sockets(client: Any) -> int:
|
|||
|
||||
closed = 0
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return 0
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return 0
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return 0
|
||||
# httpx uses httpcore connection pools; connections live in
|
||||
# _connections (list) or _pool (list) depending on version.
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
for conn in list(connections):
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
for sock in _iter_pool_sockets(client):
|
||||
try:
|
||||
sock.shutdown(_socket.SHUT_RDWR)
|
||||
except OSError:
|
||||
|
|
@ -2154,5 +2168,6 @@ __all__ = [
|
|||
"cleanup_dead_connections",
|
||||
"extract_api_error_context",
|
||||
"apply_pending_steer_to_tool_results",
|
||||
"_iter_pool_sockets",
|
||||
"force_close_tcp_sockets",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -92,17 +92,36 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
"""
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
request_client_lock = threading.Lock()
|
||||
|
||||
def _set_request_client(client):
|
||||
with request_client_lock:
|
||||
request_client_holder["client"] = client
|
||||
return client
|
||||
|
||||
def _take_request_client():
|
||||
with request_client_lock:
|
||||
client = request_client_holder.get("client")
|
||||
request_client_holder["client"] = None
|
||||
return client
|
||||
|
||||
def _close_request_client_once(reason: str) -> None:
|
||||
request_client = _take_request_client()
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason=reason)
|
||||
|
||||
def _call():
|
||||
try:
|
||||
if agent.api_mode == "codex_responses":
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="codex_stream_request",
|
||||
api_kwargs=api_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="codex_stream_request",
|
||||
api_kwargs=api_kwargs,
|
||||
)
|
||||
)
|
||||
result["response"] = agent._run_codex_stream(
|
||||
api_kwargs,
|
||||
client=request_client_holder["client"],
|
||||
client=request_client,
|
||||
on_first_delta=getattr(agent, "_codex_on_first_delta", None),
|
||||
)
|
||||
elif agent.api_mode == "anthropic_messages":
|
||||
|
|
@ -131,17 +150,17 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
raise
|
||||
result["response"] = normalize_converse_response(raw_response)
|
||||
else:
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="chat_completion_request",
|
||||
api_kwargs=api_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="chat_completion_request",
|
||||
api_kwargs=api_kwargs,
|
||||
)
|
||||
)
|
||||
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
|
||||
result["response"] = request_client.chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="request_complete")
|
||||
_close_request_client_once("request_complete")
|
||||
|
||||
# ── Stale-call timeout (mirrors streaming stale detector) ────────
|
||||
# Non-streaming calls return nothing until the full response is
|
||||
|
|
@ -192,9 +211,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
rc = request_client_holder.get("client")
|
||||
if rc is not None:
|
||||
agent._close_request_openai_client(rc, reason="stale_call_kill")
|
||||
_close_request_client_once("stale_call_kill")
|
||||
except Exception:
|
||||
pass
|
||||
agent._touch_activity(
|
||||
|
|
@ -218,9 +235,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="interrupt_abort")
|
||||
_close_request_client_once("interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
|
|
@ -1257,6 +1272,24 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
|
||||
result = {"response": None, "error": None, "partial_tool_names": []}
|
||||
request_client_holder = {"client": None, "diag": None}
|
||||
request_client_lock = threading.Lock()
|
||||
|
||||
def _set_request_client(client):
|
||||
with request_client_lock:
|
||||
request_client_holder["client"] = client
|
||||
return client
|
||||
|
||||
def _take_request_client():
|
||||
with request_client_lock:
|
||||
client = request_client_holder.get("client")
|
||||
request_client_holder["client"] = None
|
||||
return client
|
||||
|
||||
def _close_request_client_once(reason: str) -> None:
|
||||
request_client = _take_request_client()
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason=reason)
|
||||
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
# Wall-clock timestamp of the last real streaming chunk. The outer
|
||||
|
|
@ -1313,9 +1346,11 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
pool=_conn_cap,
|
||||
),
|
||||
}
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="chat_completion_stream_request",
|
||||
api_kwargs=stream_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="chat_completion_stream_request",
|
||||
api_kwargs=stream_kwargs,
|
||||
)
|
||||
)
|
||||
# Reset stale-stream timer so the detector measures from this
|
||||
# attempt's start, not a previous attempt's last chunk.
|
||||
|
|
@ -1326,7 +1361,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
# ``request_client_holder["diag"]`` for closure access.
|
||||
_diag = agent._stream_diag_init()
|
||||
request_client_holder["diag"] = _diag
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
stream = request_client.chat.completions.create(**stream_kwargs)
|
||||
|
||||
# Capture rate limit headers from the initial HTTP response.
|
||||
# The OpenAI SDK Stream object exposes the underlying httpx
|
||||
|
|
@ -1765,12 +1800,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
mid_tool_call=True,
|
||||
diag=request_client_holder.get("diag"),
|
||||
)
|
||||
stale = request_client_holder.get("client")
|
||||
if stale is not None:
|
||||
agent._close_request_openai_client(
|
||||
stale, reason="stream_mid_tool_retry_cleanup"
|
||||
)
|
||||
request_client_holder["client"] = None
|
||||
_close_request_client_once("stream_mid_tool_retry_cleanup")
|
||||
try:
|
||||
agent._replace_primary_openai_client(
|
||||
reason="stream_mid_tool_retry_pool_cleanup"
|
||||
|
|
@ -1821,12 +1851,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
diag=request_client_holder.get("diag"),
|
||||
)
|
||||
# Close the stale request client before retry
|
||||
stale = request_client_holder.get("client")
|
||||
if stale is not None:
|
||||
agent._close_request_openai_client(
|
||||
stale, reason="stream_retry_cleanup"
|
||||
)
|
||||
request_client_holder["client"] = None
|
||||
_close_request_client_once("stream_retry_cleanup")
|
||||
# Also rebuild the primary client to purge
|
||||
# any dead connections from the pool.
|
||||
try:
|
||||
|
|
@ -1894,9 +1919,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
result["error"] = e
|
||||
return
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
_close_request_client_once("stream_request_complete")
|
||||
|
||||
# Provider-configured stale timeout takes priority over env default.
|
||||
_cfg_stale = get_provider_stale_timeout(agent.provider, agent.model)
|
||||
|
|
@ -1966,9 +1989,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
f"Reconnecting..."
|
||||
)
|
||||
try:
|
||||
rc = request_client_holder.get("client")
|
||||
if rc is not None:
|
||||
agent._close_request_openai_client(rc, reason="stale_stream_kill")
|
||||
_close_request_client_once("stale_stream_kill")
|
||||
except Exception:
|
||||
pass
|
||||
# Rebuild the primary client too — its connection pool
|
||||
|
|
@ -1990,9 +2011,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
||||
_close_request_client_once("stream_interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
|
|
|
|||
|
|
@ -951,6 +951,58 @@ CREATE INDEX IF NOT EXISTS idx_notify_task ON kanban_notify_subs(task_
|
|||
|
||||
_INITIALIZED_PATHS: set[str] = set()
|
||||
_INIT_LOCK = threading.RLock()
|
||||
_SQLITE_HEADER = b"SQLite format 3\x00"
|
||||
|
||||
|
||||
def _looks_like_tls_record_at(data: bytes, offset: int) -> bool:
|
||||
"""Return True for a TLS record header at ``data[offset:]``."""
|
||||
if len(data) < offset + 5:
|
||||
return False
|
||||
content_type = data[offset]
|
||||
major = data[offset + 1]
|
||||
minor = data[offset + 2]
|
||||
length = int.from_bytes(data[offset + 3:offset + 5], "big")
|
||||
return (
|
||||
content_type in {0x14, 0x15, 0x16, 0x17}
|
||||
and major == 0x03
|
||||
and minor in {0x00, 0x01, 0x02, 0x03, 0x04}
|
||||
and 0 < length <= 18432
|
||||
)
|
||||
|
||||
|
||||
def _validate_sqlite_header(path: Path) -> None:
|
||||
"""Fail early with an actionable error for non-SQLite Kanban DB files.
|
||||
|
||||
``sqlite3.connect()`` creates missing and zero-byte files, so those are
|
||||
allowed. Existing non-empty files must have the SQLite header before we
|
||||
hand them to SQLite/WAL setup. This keeps corrupted page-0 failures from
|
||||
being collapsed into a generic PRAGMA error and lets the gateway's corrupt
|
||||
board handling identify the board by fingerprint.
|
||||
"""
|
||||
try:
|
||||
stat = path.stat()
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except OSError:
|
||||
return
|
||||
if stat.st_size == 0:
|
||||
return
|
||||
try:
|
||||
with path.open("rb") as handle:
|
||||
head = handle.read(64)
|
||||
except OSError:
|
||||
return
|
||||
if head.startswith(_SQLITE_HEADER):
|
||||
return
|
||||
signature = ""
|
||||
if head.startswith(b"SQLit") and _looks_like_tls_record_at(head, 5):
|
||||
signature = " (TLS record header detected at byte offset 5)"
|
||||
elif _looks_like_tls_record_at(head, 0):
|
||||
signature = " (TLS record header detected at byte offset 0)"
|
||||
raise sqlite3.DatabaseError(
|
||||
"file is not a database: invalid SQLite header for "
|
||||
f"{path}{signature}; first_32={head[:32].hex(' ')}"
|
||||
)
|
||||
|
||||
|
||||
def connect(
|
||||
|
|
@ -981,6 +1033,7 @@ def connect(
|
|||
else:
|
||||
path = kanban_db_path(board=board)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_validate_sqlite_header(path)
|
||||
resolved = str(path.resolve())
|
||||
conn = sqlite3.connect(str(path), isolation_level=None, timeout=30)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -48,6 +48,27 @@ def test_init_creates_expected_tables(kanban_home):
|
|||
assert {"tasks", "task_links", "task_comments", "task_events"} <= names
|
||||
|
||||
|
||||
def test_connect_rejects_tls_record_in_sqlite_header(tmp_path, monkeypatch):
|
||||
"""Kanban should classify TLS-looking page-0 clobbers before WAL setup."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.delenv("HERMES_KANBAN_DB", raising=False)
|
||||
monkeypatch.delenv("HERMES_KANBAN_HOME", raising=False)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
corrupt = home / "kanban.db"
|
||||
corrupt.write_bytes(b"SQLit" + bytes.fromhex("17 03 03 00 13") + b"x" * 32)
|
||||
|
||||
with pytest.raises(sqlite3.DatabaseError) as exc_info:
|
||||
kb.connect(board="default")
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "file is not a database" in msg
|
||||
assert "TLS record header detected at byte offset 5" in msg
|
||||
assert "53 51 4c 69 74 17 03 03 00 13" in msg
|
||||
|
||||
|
||||
def test_connect_migrates_legacy_db_before_optional_column_indexes(tmp_path):
|
||||
"""Legacy DBs missing additive indexed columns must migrate cleanly.
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ with ``APIConnectionError('Connection error.')`` whose cause was
|
|||
That is the exact scenario this test reproduces at object level without a
|
||||
network, so it runs in CI on every PR.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
|
@ -186,3 +187,32 @@ def test_replace_primary_openai_client_survives_repeated_rebuilds():
|
|||
"Some _create_openai_client calls returned the same object across "
|
||||
"a teardown — rebuild is not producing fresh clients"
|
||||
)
|
||||
|
||||
|
||||
def test_force_close_tcp_sockets_descends_httpcore_1_connection_wrapper():
|
||||
"""httpcore 1.x stores the real stream below conn._connection."""
|
||||
from agent.agent_runtime_helpers import force_close_tcp_sockets
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self):
|
||||
self.shutdown_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
def shutdown(self, _how):
|
||||
self.shutdown_calls += 1
|
||||
|
||||
def close(self):
|
||||
self.close_calls += 1
|
||||
|
||||
sock = FakeSocket()
|
||||
stream = SimpleNamespace(_sock=sock)
|
||||
http11 = SimpleNamespace(_network_stream=stream)
|
||||
pool_entry = SimpleNamespace(_connection=http11)
|
||||
pool = SimpleNamespace(_connections=[pool_entry])
|
||||
transport = SimpleNamespace(_pool=pool)
|
||||
http_client = SimpleNamespace(_transport=transport)
|
||||
openai_client = SimpleNamespace(_client=http_client)
|
||||
|
||||
assert force_close_tcp_sockets(openai_client) == 1
|
||||
assert sock.shutdown_calls == 1
|
||||
assert sock.close_calls == 1
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
|
@ -64,6 +65,7 @@ def _build_agent(shared_client=None):
|
|||
agent.stream_delta_callback = None
|
||||
agent._stream_callback = None
|
||||
agent.reasoning_callback = None
|
||||
agent.status_callback = None
|
||||
return agent
|
||||
|
||||
|
||||
|
|
@ -93,6 +95,24 @@ def test_retry_after_api_connection_error_recreates_request_client(monkeypatch):
|
|||
assert second_request.close_calls >= 1
|
||||
|
||||
|
||||
def test_stale_non_stream_close_is_single_owner(monkeypatch):
|
||||
def slow_responder(**kwargs):
|
||||
time.sleep(0.1)
|
||||
raise _connection_error()
|
||||
|
||||
request_client = FakeRequestClient(slow_responder)
|
||||
factory = OpenAIFactory([request_client])
|
||||
monkeypatch.setattr(run_agent, "OpenAI", factory)
|
||||
|
||||
agent = _build_agent()
|
||||
agent._compute_non_stream_stale_timeout = lambda _messages: 0.01
|
||||
|
||||
with pytest.raises(APIConnectionError):
|
||||
agent._interruptible_api_call({"model": agent.model, "messages": []})
|
||||
|
||||
assert request_client.close_calls == 1
|
||||
|
||||
|
||||
def test_closed_shared_client_is_recreated_before_request(monkeypatch):
|
||||
stale_shared = FakeSharedClient(lambda **kwargs: (_ for _ in ()).throw(AssertionError("stale shared client used")))
|
||||
stale_shared._client.is_closed = True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue