mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway/api_server): deduplicate concurrent idempotent requests
This commit is contained in:
parent
f81c0394d0
commit
3f10c27cc0
2 changed files with 113 additions and 5 deletions
|
|
@ -469,6 +469,7 @@ class _IdempotencyCache:
|
|||
def __init__(self, max_items: int = 1000, ttl_seconds: int = 300):
|
||||
from collections import OrderedDict
|
||||
self._store = OrderedDict()
|
||||
self._inflight: Dict[tuple[str, str], "asyncio.Task[Any]"] = {}
|
||||
self._ttl = ttl_seconds
|
||||
self._max = max_items
|
||||
|
||||
|
|
@ -486,11 +487,27 @@ class _IdempotencyCache:
|
|||
item = self._store.get(key)
|
||||
if item and item["fp"] == fingerprint:
|
||||
return item["resp"]
|
||||
resp = await compute_coro()
|
||||
import time as _t
|
||||
self._store[key] = {"resp": resp, "fp": fingerprint, "ts": _t.time()}
|
||||
self._purge()
|
||||
return resp
|
||||
|
||||
inflight_key = (key, fingerprint)
|
||||
task = self._inflight.get(inflight_key)
|
||||
if task is None:
|
||||
async def _compute_and_store():
|
||||
resp = await compute_coro()
|
||||
import time as _t
|
||||
self._store[key] = {"resp": resp, "fp": fingerprint, "ts": _t.time()}
|
||||
self._purge()
|
||||
return resp
|
||||
|
||||
task = asyncio.create_task(_compute_and_store())
|
||||
self._inflight[inflight_key] = task
|
||||
|
||||
def _clear_inflight(done_task: "asyncio.Task[Any]") -> None:
|
||||
if self._inflight.get(inflight_key) is done_task:
|
||||
self._inflight.pop(inflight_key, None)
|
||||
|
||||
task.add_done_callback(_clear_inflight)
|
||||
|
||||
return await asyncio.shield(task)
|
||||
|
||||
|
||||
_idem_cache = _IdempotencyCache()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ Tests cover:
|
|||
- Error handling (invalid JSON, missing fields)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
|
|
@ -25,6 +26,7 @@ from gateway.config import GatewayConfig, Platform, PlatformConfig
|
|||
from gateway.platforms.api_server import (
|
||||
APIServerAdapter,
|
||||
ResponseStore,
|
||||
_IdempotencyCache,
|
||||
_CORS_HEADERS,
|
||||
_derive_chat_session_id,
|
||||
check_api_server_requirements,
|
||||
|
|
@ -104,6 +106,95 @@ class TestResponseStore:
|
|||
assert store.delete("resp_missing") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _IdempotencyCache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIdempotencyCache:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_same_key_and_fingerprint_runs_once(self):
|
||||
cache = _IdempotencyCache()
|
||||
gate = asyncio.Event()
|
||||
started = asyncio.Event()
|
||||
calls = 0
|
||||
|
||||
async def compute():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
started.set()
|
||||
await gate.wait()
|
||||
return ("response", {"total_tokens": 1})
|
||||
|
||||
first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
|
||||
await started.wait()
|
||||
assert calls == 1
|
||||
|
||||
gate.set()
|
||||
first_result, second_result = await asyncio.gather(first, second)
|
||||
|
||||
assert first_result == second_result == ("response", {"total_tokens": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_fingerprint_does_not_reuse_inflight_task(self):
|
||||
cache = _IdempotencyCache()
|
||||
gate = asyncio.Event()
|
||||
started = asyncio.Event()
|
||||
calls = 0
|
||||
|
||||
async def compute():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
result = calls
|
||||
if calls == 2:
|
||||
started.set()
|
||||
await gate.wait()
|
||||
return result
|
||||
|
||||
first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
second = asyncio.create_task(cache.get_or_set("idem-key", "fp-2", compute))
|
||||
|
||||
await started.wait()
|
||||
assert calls == 2
|
||||
|
||||
gate.set()
|
||||
results = await asyncio.gather(first, second)
|
||||
|
||||
assert sorted(results) == [1, 2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_waiter_does_not_drop_shared_inflight_task(self):
|
||||
cache = _IdempotencyCache()
|
||||
gate = asyncio.Event()
|
||||
started = asyncio.Event()
|
||||
calls = 0
|
||||
|
||||
async def compute():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
started.set()
|
||||
await gate.wait()
|
||||
return "response"
|
||||
|
||||
first = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
|
||||
await started.wait()
|
||||
assert calls == 1
|
||||
|
||||
first.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await first
|
||||
|
||||
second = asyncio.create_task(cache.get_or_set("idem-key", "fp-1", compute))
|
||||
await asyncio.sleep(0)
|
||||
assert calls == 1
|
||||
|
||||
gate.set()
|
||||
assert await second == "response"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue