diff --git a/tests/tools/test_mcp_preflight_content_type.py b/tests/tools/test_mcp_preflight_content_type.py index 2880ce4f4b5..312aa48dfc9 100644 --- a/tests/tools/test_mcp_preflight_content_type.py +++ b/tests/tools/test_mcp_preflight_content_type.py @@ -1,137 +1,237 @@ -"""Tests for _MCPServer._preflight_content_type early-fail behaviour.""" +"""Tests for MCPServerTask._preflight_content_type fast-fail behaviour. + +These drive the REAL ``_preflight_content_type`` method against a real local +HTTP server (via httpx's ASGI/transport plumbing through a stdlib server), +rather than reimplementing the probe inline. That distinction matters: the +production probe must run on its own httpx client outside the MCP SDK's anyio +task group, and a faithful test must exercise that actual method so the +content-type allow-list, HEAD->GET fallback, and best-effort pass-through are +all covered as shipped. +""" from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +import http.server +import socketserver +import threading +from contextlib import contextmanager import pytest -from tools.mcp_tool import MCPServerTask +from tools.mcp_tool import MCPServerTask, NonMcpEndpointError -@pytest.fixture() -def server(): - """Return a minimal MCPServerTask instance (bypasses __init__ complexity).""" - s = MCPServerTask.__new__(MCPServerTask) - s.name = "test-server" - return s +def _make_task(name: str = "probe_srv") -> MCPServerTask: + """Minimal MCPServerTask without running the heavy __init__.""" + task = MCPServerTask.__new__(MCPServerTask) + task.name = name + return task + + +@contextmanager +def _serve(handler_cls): + """Run *handler_cls* on a background thread; yield its base URL.""" + httpd = socketserver.TCPServer(("127.0.0.1", 0), handler_cls) + port = httpd.server_address[1] + t = threading.Thread(target=httpd.serve_forever, daemon=True) + t.start() + try: + yield f"http://127.0.0.1:{port}" + finally: + httpd.shutdown() + httpd.server_close() + t.join(timeout=5) + + +def _handler(status: int = 200, + content_type: "str | None" = "text/html; charset=utf-8", + body: bytes = b"x", head_status=None, record=None): + """Build a BaseHTTPRequestHandler that replies with the given shape. + + ``head_status`` lets HEAD return a different status than GET (to exercise + the HEAD->GET fallback). ``record`` is an optional list that captures the + HTTP methods the server actually saw. + """ + + class _H(http.server.BaseHTTPRequestHandler): + def _write(self, sc, ct, payload): + self.send_response(sc) + if ct is not None: + self.send_header("Content-Type", ct) + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + if payload: + self.wfile.write(payload) + + def do_HEAD(self): + if record is not None: + record.append("HEAD") + sc = head_status if head_status is not None else status + self._write(sc, content_type, b"") + + def do_GET(self): + if record is not None: + record.append("GET") + self._write(status, content_type, body) + + def log_message(self, format, *args): # noqa: A002 + pass + + return _H # --------------------------------------------------------------------------- -# HTML response → ConnectionError +# Reject: non-MCP content types on a 2xx response # --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_preflight_rejects_html(server): - """A text/html response must raise ConnectionError immediately.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-type": "text/html; charset=utf-8"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - with pytest.raises(ConnectionError, match="text/html"): - await server._preflight_content_type("https://example.com") +@pytest.mark.parametrize("content_type", [ + "text/html; charset=utf-8", + "text/html", + "text/plain", + "application/xml", + "text/HTML", # case-insensitivity +]) +def test_non_mcp_content_type_raises(content_type): + task = _make_task("bad_srv") + with _serve(_handler(status=200, content_type=content_type)) as base: + with pytest.raises(NonMcpEndpointError) as exc_info: + asyncio.run(task._preflight_content_type(f"{base}/", timeout=5.0)) + msg = str(exc_info.value) + assert "bad_srv" in msg + assert "application/json" in msg and "text/event-stream" in msg -@pytest.mark.asyncio -async def test_preflight_rejects_html_on_get_fallback(server): - """When HEAD returns 405, fall back to GET — still reject HTML.""" - head_response = MagicMock() - head_response.status_code = 405 - - get_response = MagicMock() - get_response.status_code = 200 - get_response.headers = {"content-type": "text/html"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=head_response) - mock_client.get = AsyncMock(return_value=get_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - with pytest.raises(ConnectionError, match="text/html"): - await server._preflight_content_type("https://example.com") +def test_non_mcp_error_is_non_retryable_connection_error(): + """NonMcpEndpointError must subclass ConnectionError (retry loop skips it + via an explicit except; broad ConnectionError catchers still work).""" + assert issubclass(NonMcpEndpointError, ConnectionError) # --------------------------------------------------------------------------- -# Non-HTML responses → silent pass-through +# Pass-through: valid MCP content types, ambiguous, and error responses # --------------------------------------------------------------------------- - -@pytest.mark.asyncio -async def test_preflight_accepts_json(server): - """application/json must NOT raise.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-type": "application/json"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - # Should not raise - await server._preflight_content_type("https://mcp-server.example.com/mcp") +@pytest.mark.parametrize("content_type", [ + "application/json", + "application/json; charset=utf-8", + "text/event-stream", + "TEXT/EVENT-STREAM", +]) +def test_valid_mcp_content_types_pass(content_type): + task = _make_task() + with _serve(_handler(status=200, content_type=content_type, body=b"{}")) as base: + # Must not raise. + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) -@pytest.mark.asyncio -async def test_preflight_accepts_no_content_type(server): - """Missing Content-Type header must NOT raise.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - await server._preflight_content_type("https://mcp-server.example.com/mcp") +def test_missing_content_type_passes(): + task = _make_task() + with _serve(_handler(status=200, content_type=None, body=b"")) as base: + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) -@pytest.mark.asyncio -async def test_preflight_swallows_network_errors(server): - """Network errors / timeouts must silently pass through.""" - - mock_client = AsyncMock() - mock_client.head = AsyncMock(side_effect=TimeoutError("connect timed out")) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client): - # Should not raise — let the real MCP handshake deal with it - await server._preflight_content_type("https://unreachable.example.com") +@pytest.mark.parametrize("status", [401, 403, 404, 500, 503]) +def test_non_2xx_responses_pass(status): + """4xx/5xx are auth challenges or transient errors — let the SDK handle.""" + task = _make_task() + with _serve(_handler(status=status, content_type="text/html")) as base: + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) -@pytest.mark.asyncio -async def test_preflight_passes_headers_and_verify(server): - """Custom headers and ssl_verify are forwarded to the probe client.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-type": "application/json"} - - mock_client = AsyncMock() - mock_client.head = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - - with patch("httpx.AsyncClient", return_value=mock_client) as client_cls: - await server._preflight_content_type( - "https://mcp.example.com/mcp", - headers={"Authorization": "Bearer tok"}, - ssl_verify=False, +def test_network_error_passes(): + """A connection failure (nothing listening) must pass through, not raise.""" + task = _make_task() + # Reserve a port then close it so the connection is refused. + s = socketserver.TCPServer(("127.0.0.1", 0), http.server.BaseHTTPRequestHandler) + dead_port = s.server_address[1] + s.server_close() + asyncio.run( + task._preflight_content_type( + f"http://127.0.0.1:{dead_port}/mcp", timeout=2.0 ) - # Verify the client was created with ssl_verify=False - client_cls.assert_called_once() - call_kwargs = client_cls.call_args - assert call_kwargs.kwargs.get("verify") is False + ) + + +def test_cancelled_error_is_not_swallowed(): + """The best-effort except must NOT catch CancelledError (BaseException).""" + task = _make_task() + + async def _run(): + import httpx + orig = httpx.AsyncClient + try: + # Patch the client so entering it raises CancelledError. + class _C(orig): + async def __aenter__(self): + raise asyncio.CancelledError() + + httpx.AsyncClient = _C + with pytest.raises(asyncio.CancelledError): + await task._preflight_content_type("http://x/mcp", timeout=1.0) + finally: + httpx.AsyncClient = orig + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# HEAD -> GET fallback +# --------------------------------------------------------------------------- + +def test_head_405_falls_back_to_get_and_rejects_html(): + task = _make_task("fallback_srv") + record: list[str] = [] + with _serve(_handler( + status=200, content_type="text/html", + head_status=405, record=record, + )) as base: + with pytest.raises(NonMcpEndpointError): + asyncio.run(task._preflight_content_type(f"{base}/", timeout=5.0)) + assert record == ["HEAD", "GET"] + + +def test_head_501_falls_back_to_get_and_passes_json(): + task = _make_task() + record: list[str] = [] + with _serve(_handler( + status=200, content_type="application/json", body=b"{}", + head_status=501, record=record, + )) as base: + asyncio.run(task._preflight_content_type(f"{base}/mcp", timeout=5.0)) + assert record == ["HEAD", "GET"] + + +# --------------------------------------------------------------------------- +# ssl_verify / client_cert forwarding to the probe client +# --------------------------------------------------------------------------- + +def test_ssl_verify_and_cert_forwarded(monkeypatch): + captured: dict = {} + + import httpx + + class _FakeClient: + def __init__(self, **kwargs): + captured.update(kwargs) + + async def __aenter__(self): + return self + + async def __aexit__(self, *a): + return False + + async def head(self, url, headers=None): + return httpx.Response(200, headers={"content-type": "application/json"}) + + monkeypatch.setattr(httpx, "AsyncClient", _FakeClient) + task = _make_task() + asyncio.run(task._preflight_content_type( + "https://mcp.example.com/mcp", + ssl_verify=False, + client_cert="/path/to/cert.pem", + timeout=3.0, + )) + assert captured.get("verify") is False + assert captured.get("cert") == "/path/to/cert.pem" + assert captured.get("follow_redirects") is True diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index e9d1fc0b7fc..b15fcdaecb1 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -518,6 +518,21 @@ class InvalidMcpUrlError(ValueError): """ +class NonMcpEndpointError(ConnectionError): + """Raised when an HTTP MCP URL serves a non-MCP response. + + A genuine MCP Streamable-HTTP endpoint answers with ``application/json`` + or ``text/event-stream``. Anything else on a 2xx response (typically + ``text/html`` from a web-app root) means the configured ``url`` points at + the wrong place. This is non-retryable: every attempt returns the same + page, so the reconnect-backoff loop is skipped and the server is reported + failed immediately with an actionable message. + + Subclasses :class:`ConnectionError` so callers that only catch the broad + class still treat it as a connection problem. + """ + + def _validate_remote_mcp_url(server_name: str, url: Any) -> str: """Return the URL as a string if it's a valid http(s) remote MCP URL. @@ -1457,53 +1472,85 @@ class MCPServerTask: # PID-reuse can't surface stale pgroup state later. _stdio_pgids.pop(pid, None) - @staticmethod + # Content types a real MCP Streamable-HTTP endpoint may return on the + # initial POST/GET. Anything else on a 2xx response means the URL is not + # an MCP endpoint. + _MCP_CONTENT_TYPES = ("application/json", "text/event-stream") + async def _preflight_content_type( + self, url: str, *, headers: Optional[dict] = None, ssl_verify: bool = True, + client_cert=None, timeout: float = 5.0, ) -> None: - """Quick content-type probe before handing *url* to the MCP SDK. + """Probe *url* for an MCP-shaped response before the SDK connects. - A misconfigured ``mcp_servers..url`` that points at a plain web - app (returning ``text/html``) causes the MCP SDK to sit on the - connection for the full ``connect_timeout`` (default 60 s) before - surfacing ``CancelledError``. A cheap HEAD request lets us detect - this in ≤ 5 s and raise immediately with an actionable message. + A misconfigured ``mcp_servers..url`` pointed at a plain web app + returns HTML (or some other non-MCP body). The MCP SDK then sits on + the connection for the full ``connect_timeout`` (default 60 s) before + surfacing an opaque ``CancelledError``. A cheap, short-timeout probe + here catches that in ≤ ``timeout`` seconds and raises + :class:`NonMcpEndpointError` with an actionable message. - Non-HTML responses (``application/json``, missing header, network - errors) silently pass through so the normal MCP handshake proceeds. + Detection is allow-list based: a 2xx response is rejected only when it + carries a definite content type that is NOT one an MCP endpoint uses + (``application/json`` / ``text/event-stream``). A missing or empty + content type, non-2xx status, or any network/transport error passes + through silently — the probe is strictly best-effort, and the real + handshake remains the source of truth for everything except the + unambiguous "this is a web page, not MCP" case. + + Runs on its own httpx client OUTSIDE the SDK's anyio task group, so the + raised error propagates as itself rather than being wrapped in an + ``ExceptionGroup`` (which is what defeats hooks installed inside the + SDK transport). """ try: import httpx as _httpx + except ImportError: + return # No httpx → skip probe; SDK import would have failed first. - probe_headers = dict(headers) if headers else {} - # HEAD is idempotent and lightweight; fall back to GET if the - # server rejects HEAD (405 Method Not Allowed). - async with _httpx.AsyncClient( - verify=ssl_verify, - follow_redirects=True, - timeout=_httpx.Timeout(timeout), - ) as client: + client_kwargs: dict = { + "verify": ssl_verify, + "follow_redirects": True, + "timeout": _httpx.Timeout(timeout), + } + if client_cert is not None: + client_kwargs["cert"] = client_cert + + probe_headers = dict(headers) if headers else {} + try: + async with _httpx.AsyncClient(**client_kwargs) as client: + # HEAD is cheapest; fall back to GET if the server doesn't + # implement it (405 Method Not Allowed / 501 Not Implemented). resp = await client.head(url, headers=probe_headers) - if resp.status_code == 405: + if resp.status_code in (405, 501): resp = await client.get(url, headers=probe_headers) - ct = resp.headers.get("content-type", "") - if "text/html" in ct.lower(): - raise ConnectionError( - f"MCP server '{url}' returned Content-Type: {ct}. " - "This looks like a regular web page, not an MCP endpoint. " - "Verify the URL points to an MCP Streamable HTTP or SSE " - "endpoint (e.g. https://host/mcp, not https://host/)." - ) - except ConnectionError: - raise - except Exception: - # Network errors, timeouts, etc. — let the real MCP handshake - # deal with them; this is just a best-effort early check. - pass + except _httpx.HTTPError: + return # DNS/connect/timeout/transport error — let the SDK try. + + # Only judge successful responses. A 4xx/5xx may be an auth challenge + # or a transient error the real handshake handles correctly. + if not (200 <= resp.status_code < 300): + return + + ct_base = resp.headers.get("content-type", "").split(";")[0].strip().lower() + if not ct_base: + return # No content type advertised — don't second-guess the SDK. + if ct_base in self._MCP_CONTENT_TYPES: + return # Looks like a real MCP endpoint. + + raise NonMcpEndpointError( + f"MCP server '{self.name}' at {url} returned Content-Type " + f"'{ct_base}', not an MCP response (expected one of: " + f"{', '.join(self._MCP_CONTENT_TYPES)}). The URL most likely " + "points at a web page rather than an MCP endpoint — check it " + "resolves to a Streamable HTTP / SSE endpoint " + "(e.g. https://host/mcp, not https://host/)." + ) async def _run_http(self, config: dict): """Run the server using HTTP/StreamableHTTP transport.""" @@ -1515,14 +1562,6 @@ class MCPServerTask: ) url = config["url"] - # Pre-flight: reject obvious non-MCP endpoints (e.g. a web app - # returning HTML) in seconds instead of waiting the full - # connect_timeout (default 60 s). - await self._preflight_content_type( - url, - headers=dict(config.get("headers") or {}), - ssl_verify=config.get("ssl_verify", True), - ) headers = dict(config.get("headers") or {}) # Some MCP servers require MCP-Protocol-Version on the initial # initialize request and reject session-less POSTs otherwise. @@ -1754,6 +1793,28 @@ class MCPServerTask: self._ready.set() return + # Pre-flight content-type probe (Streamable HTTP only; SSE is + # exercised by its own client and legitimately serves + # text/event-stream). A URL pointed at a web-app root returns + # HTML, which makes the SDK hang for the full connect_timeout + # before surfacing an opaque CancelledError. Probing here — once, + # outside the SDK task group — fails fast and non-retryably with + # an actionable message, mirroring the URL-validation path above. + if config.get("transport") != "sse": + try: + _probe_headers = dict(config.get("headers") or {}) + await self._preflight_content_type( + config["url"], + headers=_probe_headers, + ssl_verify=config.get("ssl_verify", True), + client_cert=_resolve_client_cert(self.name, config), + ) + except NonMcpEndpointError as exc: + logger.warning("%s", exc) + self._error = exc + self._ready.set() + return + retries = 0 initial_retries = 0 backoff = 1.0