diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index f0dfc6466f..a7776c0c3c 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -307,6 +307,7 @@ class APIServerAdapter(BasePlatformAdapter): if "*" in self._cors_origins: headers = dict(_CORS_HEADERS) headers["Access-Control-Allow-Origin"] = "*" + headers["Access-Control-Max-Age"] = "600" return headers if origin not in self._cors_origins: @@ -315,6 +316,7 @@ class APIServerAdapter(BasePlatformAdapter): headers = dict(_CORS_HEADERS) headers["Access-Control-Allow-Origin"] = origin headers["Vary"] = "Origin" + headers["Access-Control-Max-Age"] = "600" return headers def _origin_allowed(self, origin: str) -> bool: diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 9f5eb0baf1..e40902a587 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -1356,6 +1356,21 @@ class TestCORS: assert "Authorization" in resp.headers.get("Access-Control-Allow-Headers", "") + @pytest.mark.asyncio + async def test_cors_preflight_sets_max_age(self): + adapter = _make_adapter(cors_origins=["http://localhost:3000"]) + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.options( + "/v1/chat/completions", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization, Content-Type", + }, + ) + assert resp.status == 200 + assert resp.headers.get("Access-Control-Max-Age") == "600" # --------------------------------------------------------------------------- # Conversation parameter # ---------------------------------------------------------------------------