def _raw_socket_from_stream(stream: Any) -> Any:
    """Return the raw socket behind one network stream, or ``None``.

    The lookups are tried in a fixed order and the first non-``None`` hit
    wins:

    1. ``stream._sock``
    2. ``stream.get_extra_info("socket")``
    3. ``stream.stream._sock``
    4. ``stream._stream.extra(SocketAttribute.raw_socket)`` (anyio)

    Any exception raised by a lookup is treated as "no socket here" so a
    single odd stream object cannot abort the surrounding pool traversal.
    """
    sock = getattr(stream, "_sock", None)
    if sock is not None:
        return sock

    get_extra_info = getattr(stream, "get_extra_info", None)
    if callable(get_extra_info):
        try:
            sock = get_extra_info("socket")
        except Exception:
            sock = None
        if sock is not None:
            return sock

    wrapped = getattr(stream, "stream", None)
    if wrapped is not None:
        sock = getattr(wrapped, "_sock", None)
        if sock is not None:
            return sock

    # anyio-backed streams expose the raw socket through
    # SocketAttribute.raw_socket when available.
    extra = getattr(getattr(stream, "_stream", None), "extra", None)
    if callable(extra):
        try:
            # Imported lazily: anyio is optional and may be missing.
            from anyio.abc import SocketAttribute
            return extra(SocketAttribute.raw_socket)
        except Exception:
            return None
    return None


def _iter_pool_sockets(client: Any):
    """Yield raw sockets reachable from an OpenAI/httpx client pool.

    The walk is ``client._client._transport._pool`` and then the pool's
    ``_connections`` (falling back to ``_pool``). For every pool entry both
    the entry itself and its ``_connection`` attribute are inspected:
    httpcore 1.x stores the concrete HTTP11/HTTP2 connection under
    ``conn._connection``, while older versions exposed the stream attributes
    directly on the pool entry.

    These are private transport internals that vary across httpx/httpcore
    releases, so the traversal is deliberately best-effort: a missing
    attribute, or a connection container that cannot be iterated, ends the
    generator without yielding instead of raising.

    Args:
        client: Object expected to carry the httpx client under ``_client``.

    Yields:
        Each distinct socket object found, at most once (deduplicated by
        identity), in pool order.
    """
    try:
        http_client = getattr(client, "_client", None)
        transport = getattr(http_client, "_transport", None)
        pool = getattr(transport, "_pool", None)
        if pool is None:
            return
        connections = (
            getattr(pool, "_connections", None)
            or getattr(pool, "_pool", None)
            or ()
        )
        # Snapshot inside the guard: the container is a private attribute,
        # so it may not be iterable at all, and the pool may be mutated
        # while the caller is still consuming this generator.
        snapshot = list(connections)
    except Exception:
        # Best-effort helper; unknown internals mean "nothing to yield".
        return

    seen = set()
    for conn in snapshot:
        inner = getattr(conn, "_connection", None)
        candidates = (conn,) if inner is None else (conn, inner)
        for candidate in candidates:
            stream = (
                getattr(candidate, "_network_stream", None)
                or getattr(candidate, "_stream", None)
            )
            if stream is None:
                continue
            sock = _raw_socket_from_stream(stream)
            if sock is None:
                continue
            marker = id(sock)
            if marker in seen:
                # The same socket can be reachable from the wrapper and
                # from the wrapped connection; hand it out only once.
                continue
            seen.add(marker)
            yield sock
- connections = ( - getattr(pool, "_connections", None) - or getattr(pool, "_pool", None) - or [] - ) - for conn in list(connections): - stream = ( - getattr(conn, "_network_stream", None) - or getattr(conn, "_stream", None) - ) - if stream is None: - continue - sock = getattr(stream, "_sock", None) - if sock is None: - sock = getattr(stream, "stream", None) - if sock is not None: - sock = getattr(sock, "_sock", None) - if sock is None: - continue + for sock in _iter_pool_sockets(client): try: sock.shutdown(_socket.SHUT_RDWR) except OSError: @@ -2154,5 +2168,6 @@ __all__ = [ "cleanup_dead_connections", "extract_api_error_context", "apply_pending_steer_to_tool_results", + "_iter_pool_sockets", "force_close_tcp_sockets", ] diff --git a/agent/chat_completion_helpers.py b/agent/chat_completion_helpers.py index 2e0caebcbe3..c68f2271f5b 100644 --- a/agent/chat_completion_helpers.py +++ b/agent/chat_completion_helpers.py @@ -92,17 +92,36 @@ def interruptible_api_call(agent, api_kwargs: dict): """ result = {"response": None, "error": None} request_client_holder = {"client": None} + request_client_lock = threading.Lock() + + def _set_request_client(client): + with request_client_lock: + request_client_holder["client"] = client + return client + + def _take_request_client(): + with request_client_lock: + client = request_client_holder.get("client") + request_client_holder["client"] = None + return client + + def _close_request_client_once(reason: str) -> None: + request_client = _take_request_client() + if request_client is not None: + agent._close_request_openai_client(request_client, reason=reason) def _call(): try: if agent.api_mode == "codex_responses": - request_client_holder["client"] = agent._create_request_openai_client( - reason="codex_stream_request", - api_kwargs=api_kwargs, + request_client = _set_request_client( + agent._create_request_openai_client( + reason="codex_stream_request", + api_kwargs=api_kwargs, + ) ) result["response"] = agent._run_codex_stream( 
api_kwargs, - client=request_client_holder["client"], + client=request_client, on_first_delta=getattr(agent, "_codex_on_first_delta", None), ) elif agent.api_mode == "anthropic_messages": @@ -131,17 +150,17 @@ def interruptible_api_call(agent, api_kwargs: dict): raise result["response"] = normalize_converse_response(raw_response) else: - request_client_holder["client"] = agent._create_request_openai_client( - reason="chat_completion_request", - api_kwargs=api_kwargs, + request_client = _set_request_client( + agent._create_request_openai_client( + reason="chat_completion_request", + api_kwargs=api_kwargs, + ) ) - result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs) + result["response"] = request_client.chat.completions.create(**api_kwargs) except Exception as e: result["error"] = e finally: - request_client = request_client_holder.get("client") - if request_client is not None: - agent._close_request_openai_client(request_client, reason="request_complete") + _close_request_client_once("request_complete") # ── Stale-call timeout (mirrors streaming stale detector) ──────── # Non-streaming calls return nothing until the full response is @@ -192,9 +211,7 @@ def interruptible_api_call(agent, api_kwargs: dict): agent._anthropic_client.close() agent._rebuild_anthropic_client() else: - rc = request_client_holder.get("client") - if rc is not None: - agent._close_request_openai_client(rc, reason="stale_call_kill") + _close_request_client_once("stale_call_kill") except Exception: pass agent._touch_activity( @@ -218,9 +235,7 @@ def interruptible_api_call(agent, api_kwargs: dict): agent._anthropic_client.close() agent._rebuild_anthropic_client() else: - request_client = request_client_holder.get("client") - if request_client is not None: - agent._close_request_openai_client(request_client, reason="interrupt_abort") + _close_request_client_once("interrupt_abort") except Exception: pass raise InterruptedError("Agent interrupted during API 
call") @@ -1257,6 +1272,24 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= result = {"response": None, "error": None, "partial_tool_names": []} request_client_holder = {"client": None, "diag": None} + request_client_lock = threading.Lock() + + def _set_request_client(client): + with request_client_lock: + request_client_holder["client"] = client + return client + + def _take_request_client(): + with request_client_lock: + client = request_client_holder.get("client") + request_client_holder["client"] = None + return client + + def _close_request_client_once(reason: str) -> None: + request_client = _take_request_client() + if request_client is not None: + agent._close_request_openai_client(request_client, reason=reason) + first_delta_fired = {"done": False} deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback) # Wall-clock timestamp of the last real streaming chunk. The outer @@ -1313,9 +1346,11 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= pool=_conn_cap, ), } - request_client_holder["client"] = agent._create_request_openai_client( - reason="chat_completion_stream_request", - api_kwargs=stream_kwargs, + request_client = _set_request_client( + agent._create_request_openai_client( + reason="chat_completion_stream_request", + api_kwargs=stream_kwargs, + ) ) # Reset stale-stream timer so the detector measures from this # attempt's start, not a previous attempt's last chunk. @@ -1326,7 +1361,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= # ``request_client_holder["diag"]`` for closure access. _diag = agent._stream_diag_init() request_client_holder["diag"] = _diag - stream = request_client_holder["client"].chat.completions.create(**stream_kwargs) + stream = request_client.chat.completions.create(**stream_kwargs) # Capture rate limit headers from the initial HTTP response. 
# The OpenAI SDK Stream object exposes the underlying httpx @@ -1765,12 +1800,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= mid_tool_call=True, diag=request_client_holder.get("diag"), ) - stale = request_client_holder.get("client") - if stale is not None: - agent._close_request_openai_client( - stale, reason="stream_mid_tool_retry_cleanup" - ) - request_client_holder["client"] = None + _close_request_client_once("stream_mid_tool_retry_cleanup") try: agent._replace_primary_openai_client( reason="stream_mid_tool_retry_pool_cleanup" @@ -1821,12 +1851,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= diag=request_client_holder.get("diag"), ) # Close the stale request client before retry - stale = request_client_holder.get("client") - if stale is not None: - agent._close_request_openai_client( - stale, reason="stream_retry_cleanup" - ) - request_client_holder["client"] = None + _close_request_client_once("stream_retry_cleanup") # Also rebuild the primary client to purge # any dead connections from the pool. try: @@ -1894,9 +1919,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= result["error"] = e return finally: - request_client = request_client_holder.get("client") - if request_client is not None: - agent._close_request_openai_client(request_client, reason="stream_request_complete") + _close_request_client_once("stream_request_complete") # Provider-configured stale timeout takes priority over env default. _cfg_stale = get_provider_stale_timeout(agent.provider, agent.model) @@ -1966,9 +1989,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= f"Reconnecting..." 
# Magic string that opens every SQLite 3 database file (16 bytes).
_SQLITE_HEADER = b"SQLite format 3\x00"


def _looks_like_tls_record_at(data: bytes, offset: int) -> bool:
    """Return True for a TLS record header at ``data[offset:]``.

    A header is five bytes: content type, major version, minor version and
    a big-endian 16-bit payload length. It is accepted when the content
    type is one of 0x14-0x17, the major version is 3, the minor version is
    0-4 and the length is in ``1..18432``.

    Args:
        data: Buffer to inspect.
        offset: Index of the candidate header's first byte.

    Returns:
        ``False`` when ``offset`` is negative or fewer than five bytes are
        available from ``offset``; otherwise whether the five bytes form a
        plausible TLS record header.
    """
    # A negative offset would silently wrap around through Python's
    # negative indexing and compare unrelated bytes.
    if offset < 0 or len(data) < offset + 5:
        return False
    content_type = data[offset]
    major = data[offset + 1]
    minor = data[offset + 2]
    length = int.from_bytes(data[offset + 3:offset + 5], "big")
    return (
        content_type in {0x14, 0x15, 0x16, 0x17}
        and major == 0x03
        and minor in {0x00, 0x01, 0x02, 0x03, 0x04}
        and 0 < length <= 18432
    )


def _validate_sqlite_header(path: Path) -> None:
    """Fail early with an actionable error for non-SQLite Kanban DB files.

    ``sqlite3.connect()`` creates missing and zero-byte files, so those are
    allowed. Existing non-empty files must start with the SQLite header
    before they are handed to SQLite/WAL setup, so a clobbered first page is
    reported with the file path and its leading bytes instead of surfacing
    later as a generic PRAGMA error.

    Args:
        path: Location of the database file to check.

    Raises:
        sqlite3.DatabaseError: The file is non-empty and does not begin
            with the SQLite magic string. The message carries the path, a
            hex dump of the first 32 bytes and, when the bytes resemble a
            TLS record header at offset 0 or 5, a note saying so.
    """
    try:
        size = path.stat().st_size
    except OSError:
        # Covers a missing file as well; leave the outcome to
        # sqlite3.connect() rather than guessing here.
        return
    if size == 0:
        return
    try:
        with path.open("rb") as handle:
            head = handle.read(64)
    except OSError:
        return
    if head.startswith(_SQLITE_HEADER):
        return

    signature = ""
    # Offset 5 catches a header whose first five bytes survived ("SQLit")
    # while the bytes after them were overwritten.
    if head.startswith(b"SQLit") and _looks_like_tls_record_at(head, 5):
        signature = " (TLS record header detected at byte offset 5)"
    elif _looks_like_tls_record_at(head, 0):
        signature = " (TLS record header detected at byte offset 0)"
    raise sqlite3.DatabaseError(
        "file is not a database: invalid SQLite header for "
        f"{path}{signature}; first_32={head[:32].hex(' ')}"
    )
def test_force_close_tcp_sockets_descends_httpcore_1_connection_wrapper():
    """A socket nested under ``conn._connection`` must still be force-closed."""
    from agent.agent_runtime_helpers import force_close_tcp_sockets

    class RecordingSocket:
        """Socket stand-in that counts shutdown() and close() invocations."""

        def __init__(self):
            self.close_calls = 0
            self.shutdown_calls = 0

        def close(self):
            self.close_calls += 1

        def shutdown(self, _how):
            self.shutdown_calls += 1

    fake_socket = RecordingSocket()

    # Build the attribute chain from the outside in:
    # client -> httpx client -> transport -> pool -> pool entry
    #        -> wrapped HTTP/1.1 connection -> network stream -> socket
    wrapped_connection = SimpleNamespace(
        _network_stream=SimpleNamespace(_sock=fake_socket),
    )
    connection_pool = SimpleNamespace(
        _connections=[SimpleNamespace(_connection=wrapped_connection)],
    )
    client = SimpleNamespace(
        _client=SimpleNamespace(
            _transport=SimpleNamespace(_pool=connection_pool),
        ),
    )

    closed_count = force_close_tcp_sockets(client)

    assert closed_count == 1
    assert fake_socket.shutdown_calls == 1
    assert fake_socket.close_calls == 1