diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 7efb756c9c..8bbf16e17e 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -469,6 +469,7 @@ class _IdempotencyCache: def __init__(self, max_items: int = 1000, ttl_seconds: int = 300): from collections import OrderedDict self._store = OrderedDict() + self._inflight: Dict[tuple[str, str], "asyncio.Task[Any]"] = {} self._ttl = ttl_seconds self._max = max_items @@ -486,11 +487,27 @@ class _IdempotencyCache: item = self._store.get(key) if item and item["fp"] == fingerprint: return item["resp"] - resp = await compute_coro() - import time as _t - self._store[key] = {"resp": resp, "fp": fingerprint, "ts": _t.time()} - self._purge() - return resp + + inflight_key = (key, fingerprint) + task = self._inflight.get(inflight_key) + if task is None: + async def _compute_and_store(): + resp = await compute_coro() + import time as _t + self._store[key] = {"resp": resp, "fp": fingerprint, "ts": _t.time()} + self._purge() + return resp + + task = asyncio.create_task(_compute_and_store()) + self._inflight[inflight_key] = task + + def _clear_inflight(done_task: "asyncio.Task[Any]") -> None: + if self._inflight.get(inflight_key) is done_task: + self._inflight.pop(inflight_key, None) + + task.add_done_callback(_clear_inflight) + + return await asyncio.shield(task) _idem_cache = _IdempotencyCache() diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index d0cebacb88..ca229f26f7 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -12,6 +12,7 @@ Tests cover: - Error handling (invalid JSON, missing fields) """ +import asyncio import json import time import uuid @@ -25,6 +26,7 @@ from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.api_server import ( APIServerAdapter, ResponseStore, + _IdempotencyCache, _CORS_HEADERS, _derive_chat_session_id, check_api_server_requirements, @@ -104,6 +106,95 @@ class TestResponseStore: assert store.delete("resp_missing") is False +# --------------------------------------------------------------------------- +# _IdempotencyCache +# --------------------------------------------------------------------------- + + +class TestIdempotencyCache: + @pytest.mark.asyncio + async def test_concurrent_same_key_and_fingerprint_runs_once(self): + cache = _IdempotencyCache() + gate = asyncio.Event() + started = asyncio.Event() + calls = 0 + + async def compute(): + nonlocal calls + calls += 1 + started.set() + await gate.wait() + return ("response", {"total_tokens": 1}) + + first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute)) + second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute)) + + await started.wait() + assert calls == 1 + + gate.set() + first_result, second_result = await asyncio.gather(first, second) + + assert first_result == second_result == ("response", {"total_tokens": 1}) + + @pytest.mark.asyncio + async def test_different_fingerprint_does_not_reuse_inflight_task(self): + cache = _IdempotencyCache() + gate = asyncio.Event() + started = asyncio.Event() + calls = 0 + + async def compute(): + nonlocal calls + calls += 1 + result = calls + if calls == 2: + started.set() + await gate.wait() + return result + + first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute)) + second = asyncio.create_task(cache.get_or_set("idem-key", "fp-2", compute)) + + await started.wait() + assert calls == 2 + + gate.set() + results = await asyncio.gather(first, second) + + assert sorted(results) == [1, 2] + + @pytest.mark.asyncio + async def test_cancelled_waiter_does_not_drop_shared_inflight_task(self): + cache = _IdempotencyCache() + gate = asyncio.Event() + started = asyncio.Event() + calls = 0 + + async def compute(): + nonlocal calls + calls += 1 + started.set() + await gate.wait() + return "response" + + first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute)) + + await started.wait() + assert calls == 1 + + first.cancel() + with pytest.raises(asyncio.CancelledError): + await first + + second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute)) + await asyncio.sleep(0) + assert calls == 1 + + gate.set() + assert await second == "response" + + # --------------------------------------------------------------------------- # Adapter initialization # ---------------------------------------------------------------------------