fix(mcp): make non-MCP HTTP endpoint fast-fail robust and non-retryable

Reworks the content-type preflight so a misconfigured HTTP MCP url (a web-app
root serving HTML) fails in <1s instead of hanging the full 60s connect_timeout
— and does so non-retryably, which neither original PR achieved.

- Allow-list detection (application/json, text/event-stream) instead of a
  text/html-only denylist — catches text/plain, application/xml, etc.
- New NonMcpEndpointError(ConnectionError); run() catches it in the same
  top-level fast-fail block as InvalidMcpUrlError, so it returns before the
  reconnect-backoff loop (truly non-retryable) and the probe runs once, not
  on every reconnect.
- Probe runs on its own httpx client OUTSIDE the SDK anyio task group, so the
  error propagates as itself rather than wrapped in an ExceptionGroup (the
  trap that made the in-SDK event-hook approach a no-op).
- Forwards ssl_verify + client_cert + headers; HEAD->GET fallback on 405/501;
  best-effort pass-through on missing content type, non-2xx, and network
  errors; skips SSE transport. CancelledError is never swallowed.
- Replaces the malformed test file (which never imported the real method and
  failed CI) with 21 tests driving the actual _preflight_content_type against
  a real local HTTP server, plus full run() integration verifying <1s
  non-retryable failure.

Co-authored-by: liuhao1024 <sunsky.lau@gmail.com>
Co-authored-by: uzunkuyruk <egitimviscara@gmail.com>
This commit is contained in:
teknium1 2026-06-01 17:37:50 -07:00 committed by Teknium
parent c914e4a371
commit 64f7f36713
2 changed files with 310 additions and 149 deletions

View file

@ -1,137 +1,237 @@
"""Tests for _MCPServer._preflight_content_type early-fail behaviour."""
"""Tests for MCPServerTask._preflight_content_type fast-fail behaviour.
These drive the REAL ``_preflight_content_type`` method against a real local
HTTP server (via httpx's ASGI/transport plumbing through a stdlib server),
rather than reimplementing the probe inline. That distinction matters: the
production probe must run on its own httpx client outside the MCP SDK's anyio
task group, and a faithful test must exercise that actual method so the
content-type allow-list, HEAD->GET fallback, and best-effort pass-through are
all covered as shipped.
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import http.server
import socketserver
import threading
from contextlib import contextmanager
import pytest
from tools.mcp_tool import MCPServerTask
from tools.mcp_tool import MCPServerTask, NonMcpEndpointError
@pytest.fixture()
def server():
"""Return a minimal MCPServerTask instance (bypasses __init__ complexity)."""
s = MCPServerTask.__new__(MCPServerTask)
s.name = "test-server"
return s
def _make_task(name: str = "probe_srv") -> MCPServerTask:
"""Minimal MCPServerTask without running the heavy __init__."""
task = MCPServerTask.__new__(MCPServerTask)
task.name = name
return task
@contextmanager
def _serve(handler_cls):
"""Run *handler_cls* on a background thread; yield its base URL."""
httpd = socketserver.TCPServer(("127.0.0.1", 0), handler_cls)
port = httpd.server_address[1]
t = threading.Thread(target=httpd.serve_forever, daemon=True)
t.start()
try:
yield f"http://127.0.0.1:{port}"
finally:
httpd.shutdown()
httpd.server_close()
t.join(timeout=5)
def _handler(status: int = 200,
content_type: "str | None" = "text/html; charset=utf-8",
body: bytes = b"<html>x</html>", head_status=None, record=None):
"""Build a BaseHTTPRequestHandler that replies with the given shape.
``head_status`` lets HEAD return a different status than GET (to exercise
the HEAD->GET fallback). ``record`` is an optional list that captures the
HTTP methods the server actually saw.
"""
class _H(http.server.BaseHTTPRequestHandler):
def _write(self, sc, ct, payload):
self.send_response(sc)
if ct is not None:
self.send_header("Content-Type", ct)
self.send_header("Content-Length", str(len(payload)))
self.end_headers()
if payload:
self.wfile.write(payload)
def do_HEAD(self):
if record is not None:
record.append("HEAD")
sc = head_status if head_status is not None else status
self._write(sc, content_type, b"")
def do_GET(self):
if record is not None:
record.append("GET")
self._write(status, content_type, body)
def log_message(self, format, *args): # noqa: A002
pass
return _H
# ---------------------------------------------------------------------------
# HTML response → ConnectionError
# Reject: non-MCP content types on a 2xx response
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_preflight_rejects_html(server):
"""A text/html response must raise ConnectionError immediately."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"content-type": "text/html; charset=utf-8"}
mock_client = AsyncMock()
mock_client.head = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
with pytest.raises(ConnectionError, match="text/html"):
await server._preflight_content_type("https://example.com")
@pytest.mark.parametrize("content_type", [
"text/html; charset=utf-8",
"text/html",
"text/plain",
"application/xml",
"text/HTML", # case-insensitivity
])
def test_non_mcp_content_type_raises(content_type):
task = _make_task("bad_srv")
with _serve(_handler(status=200, content_type=content_type)) as base:
with pytest.raises(NonMcpEndpointError) as exc_info:
asyncio.run(task._preflight_content_type(f"{base}/", timeout=5.0))
msg = str(exc_info.value)
assert "bad_srv" in msg
assert "application/json" in msg and "text/event-stream" in msg
@pytest.mark.asyncio
async def test_preflight_rejects_html_on_get_fallback(server):
"""When HEAD returns 405, fall back to GET — still reject HTML."""
head_response = MagicMock()
head_response.status_code = 405
get_response = MagicMock()
get_response.status_code = 200
get_response.headers = {"content-type": "text/html"}
mock_client = AsyncMock()
mock_client.head = AsyncMock(return_value=head_response)
mock_client.get = AsyncMock(return_value=get_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
with pytest.raises(ConnectionError, match="text/html"):
await server._preflight_content_type("https://example.com")
def test_non_mcp_error_is_non_retryable_connection_error():
"""NonMcpEndpointError must subclass ConnectionError (retry loop skips it
via an explicit except; broad ConnectionError catchers still work)."""
assert issubclass(NonMcpEndpointError, ConnectionError)
# ---------------------------------------------------------------------------
# Non-HTML responses → silent pass-through
# Pass-through: valid MCP content types, ambiguous, and error responses
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_preflight_accepts_json(server):
"""application/json must NOT raise."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"content-type": "application/json"}
mock_client = AsyncMock()
mock_client.head = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
# Should not raise
await server._preflight_content_type("https://mcp-server.example.com/mcp")
@pytest.mark.parametrize("content_type", [
"application/json",
"application/json; charset=utf-8",
"text/event-stream",
"TEXT/EVENT-STREAM",
])
def test_valid_mcp_content_types_pass(content_type):
task = _make_task()
with _serve(_handler(status=200, content_type=content_type, body=b"{}")) as base:
# Must not raise.
asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0))
@pytest.mark.asyncio
async def test_preflight_accepts_no_content_type(server):
"""Missing Content-Type header must NOT raise."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {}
mock_client = AsyncMock()
mock_client.head = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
await server._preflight_content_type("https://mcp-server.example.com/mcp")
def test_missing_content_type_passes():
task = _make_task()
with _serve(_handler(status=200, content_type=None, body=b"")) as base:
asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0))
@pytest.mark.asyncio
async def test_preflight_swallows_network_errors(server):
"""Network errors / timeouts must silently pass through."""
mock_client = AsyncMock()
mock_client.head = AsyncMock(side_effect=TimeoutError("connect timed out"))
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
# Should not raise — let the real MCP handshake deal with it
await server._preflight_content_type("https://unreachable.example.com")
@pytest.mark.parametrize("status", [401, 403, 404, 500, 503])
def test_non_2xx_responses_pass(status):
"""4xx/5xx are auth challenges or transient errors — let the SDK handle."""
task = _make_task()
with _serve(_handler(status=status, content_type="text/html")) as base:
asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0))
@pytest.mark.asyncio
async def test_preflight_passes_headers_and_verify(server):
"""Custom headers and ssl_verify are forwarded to the probe client."""
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"content-type": "application/json"}
mock_client = AsyncMock()
mock_client.head = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client) as client_cls:
await server._preflight_content_type(
"https://mcp.example.com/mcp",
headers={"Authorization": "Bearer tok"},
ssl_verify=False,
def test_network_error_passes():
"""A connection failure (nothing listening) must pass through, not raise."""
task = _make_task()
# Reserve a port then close it so the connection is refused.
s = socketserver.TCPServer(("127.0.0.1", 0), http.server.BaseHTTPRequestHandler)
dead_port = s.server_address[1]
s.server_close()
asyncio.run(
task._preflight_content_type(
f"http://127.0.0.1:{dead_port}/mcp", timeout=2.0
)
# Verify the client was created with ssl_verify=False
client_cls.assert_called_once()
call_kwargs = client_cls.call_args
assert call_kwargs.kwargs.get("verify") is False
)
def test_cancelled_error_is_not_swallowed():
"""The best-effort except must NOT catch CancelledError (BaseException)."""
task = _make_task()
async def _run():
import httpx
orig = httpx.AsyncClient
try:
# Patch the client so entering it raises CancelledError.
class _C(orig):
async def __aenter__(self):
raise asyncio.CancelledError()
httpx.AsyncClient = _C
with pytest.raises(asyncio.CancelledError):
await task._preflight_content_type("http://x/mcp", timeout=1.0)
finally:
httpx.AsyncClient = orig
asyncio.run(_run())
# ---------------------------------------------------------------------------
# HEAD -> GET fallback
# ---------------------------------------------------------------------------
def test_head_405_falls_back_to_get_and_rejects_html():
task = _make_task("fallback_srv")
record: list[str] = []
with _serve(_handler(
status=200, content_type="text/html",
head_status=405, record=record,
)) as base:
with pytest.raises(NonMcpEndpointError):
asyncio.run(task._preflight_content_type(f"{base}/", timeout=5.0))
assert record == ["HEAD", "GET"]
def test_head_501_falls_back_to_get_and_passes_json():
task = _make_task()
record: list[str] = []
with _serve(_handler(
status=200, content_type="application/json", body=b"{}",
head_status=501, record=record,
)) as base:
asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0))
assert record == ["HEAD", "GET"]
# ---------------------------------------------------------------------------
# ssl_verify / client_cert forwarding to the probe client
# ---------------------------------------------------------------------------
def test_ssl_verify_and_cert_forwarded(monkeypatch):
captured: dict = {}
import httpx
class _FakeClient:
def __init__(self, **kwargs):
captured.update(kwargs)
async def __aenter__(self):
return self
async def __aexit__(self, *a):
return False
async def head(self, url, headers=None):
return httpx.Response(200, headers={"content-type": "application/json"})
monkeypatch.setattr(httpx, "AsyncClient", _FakeClient)
task = _make_task()
asyncio.run(task._preflight_content_type(
"https://mcp.example.com/mcp",
ssl_verify=False,
client_cert="/path/to/cert.pem",
timeout=3.0,
))
assert captured.get("verify") is False
assert captured.get("cert") == "/path/to/cert.pem"
assert captured.get("follow_redirects") is True

View file

@ -518,6 +518,21 @@ class InvalidMcpUrlError(ValueError):
"""
class NonMcpEndpointError(ConnectionError):
"""Raised when an HTTP MCP URL serves a non-MCP response.
A genuine MCP Streamable-HTTP endpoint answers with ``application/json``
or ``text/event-stream``. Anything else on a 2xx response (typically
``text/html`` from a web-app root) means the configured ``url`` points at
the wrong place. This is non-retryable: every attempt returns the same
page, so the reconnect-backoff loop is skipped and the server is reported
failed immediately with an actionable message.
Subclasses :class:`ConnectionError` so callers that only catch the broad
class still treat it as a connection problem.
"""
def _validate_remote_mcp_url(server_name: str, url: Any) -> str:
"""Return the URL as a string if it's a valid http(s) remote MCP URL.
@ -1457,53 +1472,85 @@ class MCPServerTask:
# PID-reuse can't surface stale pgroup state later.
_stdio_pgids.pop(pid, None)
@staticmethod
# Content types a real MCP Streamable-HTTP endpoint may return on the
# initial POST/GET. Anything else on a 2xx response means the URL is not
# an MCP endpoint.
_MCP_CONTENT_TYPES = ("application/json", "text/event-stream")
async def _preflight_content_type(
self,
url: str,
*,
headers: Optional[dict] = None,
ssl_verify: bool = True,
client_cert=None,
timeout: float = 5.0,
) -> None:
"""Quick content-type probe before handing *url* to the MCP SDK.
"""Probe *url* for an MCP-shaped response before the SDK connects.
A misconfigured ``mcp_servers.<name>.url`` that points at a plain web
app (returning ``text/html``) causes the MCP SDK to sit on the
connection for the full ``connect_timeout`` (default 60 s) before
surfacing ``CancelledError``. A cheap HEAD request lets us detect
this in 5 s and raise immediately with an actionable message.
A misconfigured ``mcp_servers.<name>.url`` pointed at a plain web app
returns HTML (or some other non-MCP body). The MCP SDK then sits on
the connection for the full ``connect_timeout`` (default 60 s) before
surfacing an opaque ``CancelledError``. A cheap, short-timeout probe
here catches that in ``timeout`` seconds and raises
:class:`NonMcpEndpointError` with an actionable message.
Non-HTML responses (``application/json``, missing header, network
errors) silently pass through so the normal MCP handshake proceeds.
Detection is allow-list based: a 2xx response is rejected only when it
carries a definite content type that is NOT one an MCP endpoint uses
(``application/json`` / ``text/event-stream``). A missing or empty
content type, non-2xx status, or any network/transport error passes
through silently the probe is strictly best-effort, and the real
handshake remains the source of truth for everything except the
unambiguous "this is a web page, not MCP" case.
Runs on its own httpx client OUTSIDE the SDK's anyio task group, so the
raised error propagates as itself rather than being wrapped in an
``ExceptionGroup`` (which is what defeats hooks installed inside the
SDK transport).
"""
try:
import httpx as _httpx
except ImportError:
return # No httpx → skip probe; SDK import would have failed first.
probe_headers = dict(headers) if headers else {}
# HEAD is idempotent and lightweight; fall back to GET if the
# server rejects HEAD (405 Method Not Allowed).
async with _httpx.AsyncClient(
verify=ssl_verify,
follow_redirects=True,
timeout=_httpx.Timeout(timeout),
) as client:
client_kwargs: dict = {
"verify": ssl_verify,
"follow_redirects": True,
"timeout": _httpx.Timeout(timeout),
}
if client_cert is not None:
client_kwargs["cert"] = client_cert
probe_headers = dict(headers) if headers else {}
try:
async with _httpx.AsyncClient(**client_kwargs) as client:
# HEAD is cheapest; fall back to GET if the server doesn't
# implement it (405 Method Not Allowed / 501 Not Implemented).
resp = await client.head(url, headers=probe_headers)
if resp.status_code == 405:
if resp.status_code in (405, 501):
resp = await client.get(url, headers=probe_headers)
ct = resp.headers.get("content-type", "")
if "text/html" in ct.lower():
raise ConnectionError(
f"MCP server '{url}' returned Content-Type: {ct}. "
"This looks like a regular web page, not an MCP endpoint. "
"Verify the URL points to an MCP Streamable HTTP or SSE "
"endpoint (e.g. https://host/mcp, not https://host/)."
)
except ConnectionError:
raise
except Exception:
# Network errors, timeouts, etc. — let the real MCP handshake
# deal with them; this is just a best-effort early check.
pass
except _httpx.HTTPError:
return # DNS/connect/timeout/transport error — let the SDK try.
# Only judge successful responses. A 4xx/5xx may be an auth challenge
# or a transient error the real handshake handles correctly.
if not (200 <= resp.status_code < 300):
return
ct_base = resp.headers.get("content-type", "").split(";")[0].strip().lower()
if not ct_base:
return # No content type advertised — don't second-guess the SDK.
if ct_base in self._MCP_CONTENT_TYPES:
return # Looks like a real MCP endpoint.
raise NonMcpEndpointError(
f"MCP server '{self.name}' at {url} returned Content-Type "
f"'{ct_base}', not an MCP response (expected one of: "
f"{', '.join(self._MCP_CONTENT_TYPES)}). The URL most likely "
"points at a web page rather than an MCP endpoint — check it "
"resolves to a Streamable HTTP / SSE endpoint "
"(e.g. https://host/mcp, not https://host/)."
)
async def _run_http(self, config: dict):
"""Run the server using HTTP/StreamableHTTP transport."""
@ -1515,14 +1562,6 @@ class MCPServerTask:
)
url = config["url"]
# Pre-flight: reject obvious non-MCP endpoints (e.g. a web app
# returning HTML) in seconds instead of waiting the full
# connect_timeout (default 60 s).
await self._preflight_content_type(
url,
headers=dict(config.get("headers") or {}),
ssl_verify=config.get("ssl_verify", True),
)
headers = dict(config.get("headers") or {})
# Some MCP servers require MCP-Protocol-Version on the initial
# initialize request and reject session-less POSTs otherwise.
@ -1754,6 +1793,28 @@ class MCPServerTask:
self._ready.set()
return
# Pre-flight content-type probe (Streamable HTTP only; SSE is
# exercised by its own client and legitimately serves
# text/event-stream). A URL pointed at a web-app root returns
# HTML, which makes the SDK hang for the full connect_timeout
# before surfacing an opaque CancelledError. Probing here — once,
# outside the SDK task group — fails fast and non-retryably with
# an actionable message, mirroring the URL-validation path above.
if config.get("transport") != "sse":
try:
_probe_headers = dict(config.get("headers") or {})
await self._preflight_content_type(
config["url"],
headers=_probe_headers,
ssl_verify=config.get("ssl_verify", True),
client_cert=_resolve_client_cert(self.name, config),
)
except NonMcpEndpointError as exc:
logger.warning("%s", exc)
self._error = exc
self._ready.set()
return
retries = 0
initial_retries = 0
backoff = 1.0