From 64f7f36713b429f139f0da6461dfbd73e58160a0 Mon Sep 17 00:00:00 2001 From: teknium1 <127238744+teknium1@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:37:50 -0700 Subject: [PATCH] fix(mcp): make non-MCP HTTP endpoint fast-fail robust and non-retryable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reworks the content-type preflight so a misconfigured HTTP MCP url (a web-app root serving HTML) fails in <1s instead of hanging the full 60s connect_timeout — and does so non-retryably, which neither original PR achieved. - Allow-list detection (application/json, text/event-stream) instead of a text/html-only denylist — catches text/plain, application/xml, etc. - New NonMcpEndpointError(ConnectionError); run() catches it in the same top-level fast-fail block as InvalidMcpUrlError, so it returns before the reconnect-backoff loop (truly non-retryable) and the probe runs once, not on every reconnect. - Probe runs on its own httpx client OUTSIDE the SDK anyio task group, so the error propagates as itself rather than wrapped in an ExceptionGroup (the trap that made the in-SDK event-hook approach a no-op). - Forwards ssl_verify + client_cert + headers; HEAD->GET fallback on 405/501; best-effort pass-through on missing content type, non-2xx, and network errors; skips SSE transport. CancelledError is never swallowed. - Replaces the malformed test file (which never imported the real method and failed CI) with 21 tests driving the actual _preflight_content_type against a real local HTTP server, plus full run() integration verifying <1s non-retryable failure. Co-authored-by: liuhao1024 Co-authored-by: uzunkuyruk --- .../tools/test_mcp_preflight_content_type.py | 318 ++++++++++++------ tools/mcp_tool.py | 141 +++++--- 2 files changed, 310 insertions(+), 149 deletions(-) diff --git a/tests/tools/test_mcp_preflight_content_type.py b/tests/tools/test_mcp_preflight_content_type.py index 2880ce4f4b5..312aa48dfc9 100644 --- a/tests/tools/test_mcp_preflight_content_type.py +++ b/tests/tools/test_mcp_preflight_content_type.py @@ -1,137 +1,237 @@ -"""Tests for _MCPServer._preflight_content_type early-fail behaviour.""" +"""Tests for MCPServerTask._preflight_content_type fast-fail behaviour. + +These drive the REAL ``_preflight_content_type`` method against a real local +HTTP server (via httpx's ASGI/transport plumbing through a stdlib server), +rather than reimplementing the probe inline. That distinction matters: the +production probe must run on its own httpx client outside the MCP SDK's anyio +task group, and a faithful test must exercise that actual method so the +content-type allow-list, HEAD->GET fallback, and best-effort pass-through are +all covered as shipped. +""" from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +import http.server +import socketserver +import threading +from contextlib import contextmanager import pytest -from tools.mcp_tool import MCPServerTask +from tools.mcp_tool import MCPServerTask, NonMcpEndpointError -@pytest.fixture() -def server(): - """Return a minimal MCPServerTask instance (bypasses __init__ complexity).""" - s = MCPServerTask.__new__(MCPServerTask) - s.name = "test-server" - return s +def _make_task(name: str = "probe_srv") -> MCPServerTask: + """Minimal MCPServerTask without running the heavy __init__.""" + task = MCPServerTask.__new__(MCPServerTask) + task.name = name + return task + + +@contextmanager +def _serve(handler_cls): + """Run *handler_cls* on a background thread; yield its base URL.""" + httpd = socketserver.TCPServer(("127.0.0.1", 0), handler_cls) + port = httpd.server_address[1] + t = threading.Thread(target=httpd.serve_forever, daemon=True) + t.start() + try: + yield f"http://127.0.0.1:{port}" + finally: + httpd.shutdown() + httpd.server_close() + t.join(timeout=5) + + +def _handler(status: int = 200, + content_type: "str | None" = "text/html; charset=utf-8", + body: bytes = b"x", head_status=None, record=None): + """Build a BaseHTTPRequestHandler that replies with the given shape. + + ``head_status`` lets HEAD return a different status than GET (to exercise + the HEAD->GET fallback). ``record`` is an optional list that captures the + HTTP methods the server actually saw. + """ + + class _H(http.server.BaseHTTPRequestHandler): + def _write(self, sc, ct, payload): + self.send_response(sc) + if ct is not None: + self.send_header("Content-Type", ct) + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + if payload: + self.wfile.write(payload) + + def do_HEAD(self): + if record is not None: + record.append("HEAD") + sc = head_status if head_status is not None else status + self._write(sc, content_type, b"") + + def do_GET(self): + if record is not None: + record.append("GET") + self._write(status, content_type, body) + + def log_message(self, format, *args): # noqa: A002 + pass + + return _H # --------------------------------------------------------------------------- -# HTML response → ConnectionError +# Reject: non-MCP content types on a 2xx response # --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_preflight_rejects_html(server): - """A text/html response must raise ConnectionError immediately.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-type": "text/html; charset=utf-8"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - with pytest.raises(ConnectionError, match="text/html"): - await server._preflight_content_type("https://example.com") +@pytest.mark.parametrize("content_type", [ + "text/html; charset=utf-8", + "text/html", + "text/plain", + "application/xml", + "text/HTML", # case-insensitivity +]) +def test_non_mcp_content_type_raises(content_type): + task = _make_task("bad_srv") + with _serve(_handler(status=200, content_type=content_type)) as base: + with pytest.raises(NonMcpEndpointError) as exc_info: + asyncio.run(task._preflight_content_type(f"{base}/", timeout=5.0)) + msg = str(exc_info.value) + assert "bad_srv" in msg + assert "application/json" in msg and "text/event-stream" in msg -@pytest.mark.asyncio -async def test_preflight_rejects_html_on_get_fallback(server): - """When HEAD returns 405, fall back to GET — still reject HTML.""" - head_response = MagicMock() - head_response.status_code = 405 - - get_response = MagicMock() - get_response.status_code = 200 - get_response.headers = {"content-type": "text/html"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=head_response) - mock_client.get = AsyncMock(return_value=get_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - with pytest.raises(ConnectionError, match="text/html"): - await server._preflight_content_type("https://example.com") +def test_non_mcp_error_is_non_retryable_connection_error(): + """NonMcpEndpointError must subclass ConnectionError (retry loop skips it + via an explicit except; broad ConnectionError catchers still work).""" + assert issubclass(NonMcpEndpointError, ConnectionError) # --------------------------------------------------------------------------- -# Non-HTML responses → silent pass-through +# Pass-through: valid MCP content types, ambiguous, and error responses # --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_preflight_accepts_json(server): - """application/json must NOT raise.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-type": "application/json"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - # Should not raise - await server._preflight_content_type("https://mcp-server.example.com/mcp") +@pytest.mark.parametrize("content_type", [ + "application/json", + "application/json; charset=utf-8", + "text/event-stream", + "TEXT/EVENT-STREAM", +]) +def test_valid_mcp_content_types_pass(content_type): + task = _make_task() + with _serve(_handler(status=200, content_type=content_type, body=b"{}")) as base: + # Must not raise. + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) -@pytest.mark.asyncio -async def test_preflight_accepts_no_content_type(server): - """Missing Content-Type header must NOT raise.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - await server._preflight_content_type("https://mcp-server.example.com/mcp") +def test_missing_content_type_passes(): + task = _make_task() + with _serve(_handler(status=200, content_type=None, body=b"")) as base: + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) -@pytest.mark.asyncio -async def test_preflight_swallows_network_errors(server): - """Network errors / timeouts must silently pass through.""" - - mock_client = AsyncMock() - mock_client.head = AsyncMock(side_effect=TimeoutError("connect timed out")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - # Should not raise — let the real MCP handshake deal with it - await server._preflight_content_type("https://unreachable.example.com") +@pytest.mark.parametrize("status", [401, 403, 404, 500, 503]) +def test_non_2xx_responses_pass(status): + """4xx/5xx are auth challenges or transient errors — let the SDK handle.""" + task = _make_task() + with _serve(_handler(status=status, content_type="text/html")) as base: + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) -@pytest.mark.asyncio -async def test_preflight_passes_headers_and_verify(server): - """Custom headers and ssl_verify are forwarded to the probe client.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-type": "application/json"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client) as client_cls: - await server._preflight_content_type( - "https://mcp.example.com/mcp", - headers={"Authorization": "Bearer tok"}, - ssl_verify=False, +def test_network_error_passes(): + """A connection failure (nothing listening) must pass through, not raise.""" + task = _make_task() + # Reserve a port then close it so the connection is refused. + s = socketserver.TCPServer(("127.0.0.1", 0), http.server.BaseHTTPRequestHandler) + dead_port = s.server_address[1] + s.server_close() + asyncio.run( + task._preflight_content_type( + f"http://127.0.0.1:{dead_port}/mcp", timeout=2.0 ) - # Verify the client was created with ssl_verify=False - client_cls.assert_called_once() - call_kwargs = client_cls.call_args - assert call_kwargs.kwargs.get("verify") is False + ) + + +def test_cancelled_error_is_not_swallowed(): + """The best-effort except must NOT catch CancelledError (BaseException).""" + task = _make_task() + + async def _run(): + import httpx + orig = httpx.AsyncClient + try: + # Patch the client so entering it raises CancelledError. + class _C(orig): + async def __aenter__(self): + raise asyncio.CancelledError() + + httpx.AsyncClient = _C + with pytest.raises(asyncio.CancelledError): + await task._preflight_content_type("http://x/mcp", timeout=1.0) + finally: + httpx.AsyncClient = orig + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# HEAD -> GET fallback +# --------------------------------------------------------------------------- + +def test_head_405_falls_back_to_get_and_rejects_html(): + task = _make_task("fallback_srv") + record: list[str] = [] + with _serve(_handler( + status=200, content_type="text/html", + head_status=405, record=record, + )) as base: + with pytest.raises(NonMcpEndpointError): + asyncio.run(task._preflight_content_type(f"{base}/", timeout=5.0)) + assert record == ["HEAD", "GET"] + + +def test_head_501_falls_back_to_get_and_passes_json(): + task = _make_task() + record: list[str] = [] + with _serve(_handler( + status=200, content_type="application/json", body=b"{}", + head_status=501, record=record, + )) as base: + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) + assert record == ["HEAD", "GET"] + + +# --------------------------------------------------------------------------- +# ssl_verify / client_cert forwarding to the probe client +# --------------------------------------------------------------------------- + +def test_ssl_verify_and_cert_forwarded(monkeypatch): + captured: dict = {} + + import httpx + + class _FakeClient: + def __init__(self, **kwargs): + captured.update(kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def head(self, url, headers=None): + return httpx.Response(200, headers={"content-type": "application/json"}) + + monkeypatch.setattr(httpx, "AsyncClient", _FakeClient) + task = _make_task() + asyncio.run(task._preflight_content_type( + "https://mcp.example.com/mcp", + ssl_verify=False, + client_cert="/path/to/cert.pem", + timeout=3.0, + )) + assert captured.get("verify") is False + assert captured.get("cert") == "/path/to/cert.pem" + assert captured.get("follow_redirects") is True diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index e9d1fc0b7fc..b15fcdaecb1 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -518,6 +518,21 @@ class InvalidMcpUrlError(ValueError): """ +class NonMcpEndpointError(ConnectionError): + """Raised when an HTTP MCP URL serves a non-MCP response. + + A genuine MCP Streamable-HTTP endpoint answers with ``application/json`` + or ``text/event-stream``. Anything else on a 2xx response (typically + ``text/html`` from a web-app root) means the configured ``url`` points at + the wrong place. This is non-retryable: every attempt returns the same + page, so the reconnect-backoff loop is skipped and the server is reported + failed immediately with an actionable message. + + Subclasses :class:`ConnectionError` so callers that only catch the broad + class still treat it as a connection problem. + """ + + def _validate_remote_mcp_url(server_name: str, url: Any) -> str: """Return the URL as a string if it's a valid http(s) remote MCP URL. @@ -1457,53 +1472,85 @@ class MCPServerTask: # PID-reuse can't surface stale pgroup state later. _stdio_pgids.pop(pid, None) - @staticmethod + # Content types a real MCP Streamable-HTTP endpoint may return on the + # initial POST/GET. Anything else on a 2xx response means the URL is not + # an MCP endpoint. + _MCP_CONTENT_TYPES = ("application/json", "text/event-stream") + async def _preflight_content_type( + self, url: str, *, headers: Optional[dict] = None, ssl_verify: bool = True, + client_cert=None, timeout: float = 5.0, ) -> None: - """Quick content-type probe before handing *url* to the MCP SDK. + """Probe *url* for an MCP-shaped response before the SDK connects. - A misconfigured ``mcp_servers..url`` that points at a plain web - app (returning ``text/html``) causes the MCP SDK to sit on the - connection for the full ``connect_timeout`` (default 60 s) before - surfacing ``CancelledError``. A cheap HEAD request lets us detect - this in ≤ 5 s and raise immediately with an actionable message. + A misconfigured ``mcp_servers..url`` pointed at a plain web app + returns HTML (or some other non-MCP body). The MCP SDK then sits on + the connection for the full ``connect_timeout`` (default 60 s) before + surfacing an opaque ``CancelledError``. A cheap, short-timeout probe + here catches that in ≤ ``timeout`` seconds and raises + :class:`NonMcpEndpointError` with an actionable message. - Non-HTML responses (``application/json``, missing header, network - errors) silently pass through so the normal MCP handshake proceeds. + Detection is allow-list based: a 2xx response is rejected only when it + carries a definite content type that is NOT one an MCP endpoint uses + (``application/json`` / ``text/event-stream``). A missing or empty + content type, non-2xx status, or any network/transport error passes + through silently — the probe is strictly best-effort, and the real + handshake remains the source of truth for everything except the + unambiguous "this is a web page, not MCP" case. + + Runs on its own httpx client OUTSIDE the SDK's anyio task group, so the + raised error propagates as itself rather than being wrapped in an + ``ExceptionGroup`` (which is what defeats hooks installed inside the + SDK transport). """ try: import httpx as _httpx + except ImportError: + return # No httpx → skip probe; SDK import would have failed first. - probe_headers = dict(headers) if headers else {} - # HEAD is idempotent and lightweight; fall back to GET if the - # server rejects HEAD (405 Method Not Allowed). - async with _httpx.AsyncClient( - verify=ssl_verify, - follow_redirects=True, - timeout=_httpx.Timeout(timeout), - ) as client: + client_kwargs: dict = { + "verify": ssl_verify, + "follow_redirects": True, + "timeout": _httpx.Timeout(timeout), + } + if client_cert is not None: + client_kwargs["cert"] = client_cert + + probe_headers = dict(headers) if headers else {} + try: + async with _httpx.AsyncClient(**client_kwargs) as client: + # HEAD is cheapest; fall back to GET if the server doesn't + # implement it (405 Method Not Allowed / 501 Not Implemented). resp = await client.head(url, headers=probe_headers) - if resp.status_code == 405: + if resp.status_code in (405, 501): resp = await client.get(url, headers=probe_headers) - ct = resp.headers.get("content-type", "") - if "text/html" in ct.lower(): - raise ConnectionError( - f"MCP server '{url}' returned Content-Type: {ct}. " - "This looks like a regular web page, not an MCP endpoint. " - "Verify the URL points to an MCP Streamable HTTP or SSE " - "endpoint (e.g. https://host/mcp, not https://host/)." - ) - except ConnectionError: - raise - except Exception: - # Network errors, timeouts, etc. — let the real MCP handshake - # deal with them; this is just a best-effort early check. - pass + except _httpx.HTTPError: + return # DNS/connect/timeout/transport error — let the SDK try. + + # Only judge successful responses. A 4xx/5xx may be an auth challenge + # or a transient error the real handshake handles correctly. + if not (200 <= resp.status_code < 300): + return + + ct_base = resp.headers.get("content-type", "").split(";")[0].strip().lower() + if not ct_base: + return # No content type advertised — don't second-guess the SDK. + if ct_base in self._MCP_CONTENT_TYPES: + return # Looks like a real MCP endpoint. + + raise NonMcpEndpointError( + f"MCP server '{self.name}' at {url} returned Content-Type " + f"'{ct_base}', not an MCP response (expected one of: " + f"{', '.join(self._MCP_CONTENT_TYPES)}). The URL most likely " + "points at a web page rather than an MCP endpoint — check it " + "resolves to a Streamable HTTP / SSE endpoint " + "(e.g. https://host/mcp, not https://host/)." + ) async def _run_http(self, config: dict): """Run the server using HTTP/StreamableHTTP transport.""" @@ -1515,14 +1562,6 @@ class MCPServerTask: ) url = config["url"] - # Pre-flight: reject obvious non-MCP endpoints (e.g. a web app - # returning HTML) in seconds instead of waiting the full - # connect_timeout (default 60 s). - await self._preflight_content_type( - url, - headers=dict(config.get("headers") or {}), - ssl_verify=config.get("ssl_verify", True), - ) headers = dict(config.get("headers") or {}) # Some MCP servers require MCP-Protocol-Version on the initial # initialize request and reject session-less POSTs otherwise. @@ -1754,6 +1793,28 @@ class MCPServerTask: self._ready.set() return + # Pre-flight content-type probe (Streamable HTTP only; SSE is + # exercised by its own client and legitimately serves + # text/event-stream). A URL pointed at a web-app root returns + # HTML, which makes the SDK hang for the full connect_timeout + # before surfacing an opaque CancelledError. Probing here — once, + # outside the SDK task group — fails fast and non-retryably with + # an actionable message, mirroring the URL-validation path above. + if config.get("transport") != "sse": + try: + _probe_headers = dict(config.get("headers") or {}) + await self._preflight_content_type( + config["url"], + headers=_probe_headers, + ssl_verify=config.get("ssl_verify", True), + client_cert=_resolve_client_cert(self.name, config), + ) + except NonMcpEndpointError as exc: + logger.warning("%s", exc) + self._error = exc + self._ready.set() + return + retries = 0 initial_retries = 0 backoff = 1.0