mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-14 04:02:26 +00:00
fix(msgraph_webhook): harden auth surface + IP allowlisting + response hygiene
Defense-in-depth polish on top of the webhook listener before it becomes a real attack surface once the pipeline starts creating subscriptions and Graph starts POSTing to the configured public URL. - Timing-safe clientState comparison. Previously used `==` on strings; switches to hmac.compare_digest so a mismatch does not leak how many leading characters matched. client_state is documented as a strong shared secret (openssl rand -hex 32 in the setup docs), so a timing-safe primitive is the right call. - Split GET and POST handlers. Graph validates a subscription by sending GET with validationToken in the query; anything else on GET is now a 400 so the endpoint cannot be probed or mistakenly used for data exfil. Previously a bare GET fell through to the POST path and blew up on request.json() with a confusing 400. - Empty response bodies on success. 202 is returned with no body so internal counters (accepted / duplicates / scheduled) do not leak to any caller that can reach the endpoint; counters remain observable via /health for operators. 403 on every-item-bad-clientState batches (so forged POSTs stop retrying), 400 on malformed / unknown-resource batches (sender configuration issue). - Optional source-IP allowlist. New `allowed_source_cidrs` extra field (list or comma-separated string) and `MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS` env var let operators restrict the webhook to Microsoft Graph's published webhook source ranges in production. Empty = allow all, preserving dev-tunnel / localhost workflows. Invalid CIDRs are logged and ignored rather than crashing. Also gates the handshake endpoint so disallowed IPs cannot probe it. - Tests updated for the new response contract (empty-body 202, auth-only 403, config-error 400) and extended to cover: bare GET rejection, POST-with-validationToken handshake tolerance, timing-safe compare actually invoked via hmac.compare_digest spy, malformed body / missing value array, IP allowlist accept/reject paths, handshake IP allowlist, invalid CIDR entries, comma-string CIDR list parsing. 52/52 passed (was 40). Full gateway suite: 5049 passed / 1 pre-existing failure in test_discord_free_response (unrelated, reproduces on clean origin/main).
This commit is contained in:
parent
26a59e4f6c
commit
b8d7e0e6d3
3 changed files with 301 additions and 49 deletions
|
|
@ -1418,12 +1418,16 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||||
msgraph_webhook_port = os.getenv("MSGRAPH_WEBHOOK_PORT")
|
msgraph_webhook_port = os.getenv("MSGRAPH_WEBHOOK_PORT")
|
||||||
msgraph_webhook_client_state = os.getenv("MSGRAPH_WEBHOOK_CLIENT_STATE", "")
|
msgraph_webhook_client_state = os.getenv("MSGRAPH_WEBHOOK_CLIENT_STATE", "")
|
||||||
msgraph_webhook_resources = os.getenv("MSGRAPH_WEBHOOK_ACCEPTED_RESOURCES", "")
|
msgraph_webhook_resources = os.getenv("MSGRAPH_WEBHOOK_ACCEPTED_RESOURCES", "")
|
||||||
|
msgraph_webhook_allowed_cidrs = os.getenv(
|
||||||
|
"MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS", ""
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
msgraph_webhook_enabled
|
msgraph_webhook_enabled
|
||||||
or Platform.MSGRAPH_WEBHOOK in config.platforms
|
or Platform.MSGRAPH_WEBHOOK in config.platforms
|
||||||
or msgraph_webhook_port
|
or msgraph_webhook_port
|
||||||
or msgraph_webhook_client_state
|
or msgraph_webhook_client_state
|
||||||
or msgraph_webhook_resources
|
or msgraph_webhook_resources
|
||||||
|
or msgraph_webhook_allowed_cidrs
|
||||||
):
|
):
|
||||||
if Platform.MSGRAPH_WEBHOOK not in config.platforms:
|
if Platform.MSGRAPH_WEBHOOK not in config.platforms:
|
||||||
config.platforms[Platform.MSGRAPH_WEBHOOK] = PlatformConfig()
|
config.platforms[Platform.MSGRAPH_WEBHOOK] = PlatformConfig()
|
||||||
|
|
@ -1450,6 +1454,16 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||||
config.platforms[Platform.MSGRAPH_WEBHOOK].extra[
|
config.platforms[Platform.MSGRAPH_WEBHOOK].extra[
|
||||||
"accepted_resources"
|
"accepted_resources"
|
||||||
] = resources
|
] = resources
|
||||||
|
if msgraph_webhook_allowed_cidrs:
|
||||||
|
cidrs = [
|
||||||
|
cidr.strip()
|
||||||
|
for cidr in msgraph_webhook_allowed_cidrs.split(",")
|
||||||
|
if cidr.strip()
|
||||||
|
]
|
||||||
|
if cidrs:
|
||||||
|
config.platforms[Platform.MSGRAPH_WEBHOOK].extra[
|
||||||
|
"allowed_source_cidrs"
|
||||||
|
] = cidrs
|
||||||
|
|
||||||
# DingTalk
|
# DingTalk
|
||||||
dingtalk_client_id = os.getenv("DINGTALK_CLIENT_ID")
|
dingtalk_client_id = os.getenv("DINGTALK_CLIENT_ID")
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hmac
|
||||||
|
import ipaddress
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
@ -60,6 +62,9 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
|
||||||
self._max_seen_receipts = max(
|
self._max_seen_receipts = max(
|
||||||
1, int(extra.get("max_seen_receipts", DEFAULT_MAX_SEEN_RECEIPTS))
|
1, int(extra.get("max_seen_receipts", DEFAULT_MAX_SEEN_RECEIPTS))
|
||||||
)
|
)
|
||||||
|
self._allowed_source_networks: list[ipaddress._BaseNetwork] = (
|
||||||
|
self._parse_allowed_source_cidrs(extra.get("allowed_source_cidrs"))
|
||||||
|
)
|
||||||
self._runner = None
|
self._runner = None
|
||||||
self._notification_scheduler: Optional[NotificationScheduler] = None
|
self._notification_scheduler: Optional[NotificationScheduler] = None
|
||||||
self._seen_receipts: set[str] = set()
|
self._seen_receipts: set[str] = set()
|
||||||
|
|
@ -90,13 +95,47 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
|
||||||
def _normalize_resource_value(resource: str) -> str:
|
def _normalize_resource_value(resource: str) -> str:
|
||||||
return str(resource or "").strip().strip("/")
|
return str(resource or "").strip().strip("/")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_allowed_source_cidrs(
|
||||||
|
raw: Any,
|
||||||
|
) -> list[ipaddress._BaseNetwork]:
|
||||||
|
"""Parse an optional list of CIDR ranges allowed to POST to the webhook.
|
||||||
|
|
||||||
|
An empty or missing value means "allow everything" (same behavior as
|
||||||
|
before this field existed). When populated, requests from source IPs
|
||||||
|
outside every listed CIDR are rejected with 403 before the body is
|
||||||
|
parsed. Use this to restrict the endpoint to Microsoft Graph's
|
||||||
|
published webhook source ranges in production deployments.
|
||||||
|
"""
|
||||||
|
if raw is None:
|
||||||
|
return []
|
||||||
|
if isinstance(raw, str):
|
||||||
|
candidates = [chunk.strip() for chunk in raw.split(",")]
|
||||||
|
elif isinstance(raw, (list, tuple, set)):
|
||||||
|
candidates = [str(chunk).strip() for chunk in raw]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
networks: list[ipaddress._BaseNetwork] = []
|
||||||
|
for chunk in candidates:
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
networks.append(ipaddress.ip_network(chunk, strict=False))
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
"[msgraph_webhook] Ignoring invalid allowed_source_cidrs entry: %r",
|
||||||
|
chunk,
|
||||||
|
)
|
||||||
|
return networks
|
||||||
|
|
||||||
def set_notification_scheduler(self, scheduler: Optional[NotificationScheduler]) -> None:
|
def set_notification_scheduler(self, scheduler: Optional[NotificationScheduler]) -> None:
|
||||||
self._notification_scheduler = scheduler
|
self._notification_scheduler = scheduler
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_get(self._health_path, self._handle_health)
|
app.router.add_get(self._health_path, self._handle_health)
|
||||||
app.router.add_get(self._webhook_path, self._handle_notification)
|
app.router.add_get(self._webhook_path, self._handle_validation)
|
||||||
app.router.add_post(self._webhook_path, self._handle_notification)
|
app.router.add_post(self._webhook_path, self._handle_notification)
|
||||||
|
|
||||||
self._runner = web.AppRunner(app)
|
self._runner = web.AppRunner(app)
|
||||||
|
|
@ -142,7 +181,28 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _handle_validation(self, request: "web.Request") -> "web.Response":
|
||||||
|
"""Handle Microsoft Graph subscription validation handshake.
|
||||||
|
|
||||||
|
Graph validates a subscription endpoint by sending a GET with
|
||||||
|
``validationToken`` in the query string; the service must echo the
|
||||||
|
token verbatim as ``text/plain`` within 10 seconds. Anything else
|
||||||
|
(bare GET, GET without the token) is rejected so the endpoint can't
|
||||||
|
be enumerated or mistakenly used for data exfiltration.
|
||||||
|
"""
|
||||||
|
if not self._source_ip_allowed(request):
|
||||||
|
return web.Response(status=403)
|
||||||
|
validation_token = request.query.get("validationToken", "")
|
||||||
|
if not validation_token:
|
||||||
|
return web.Response(status=400)
|
||||||
|
return web.Response(text=validation_token, content_type="text/plain")
|
||||||
|
|
||||||
async def _handle_notification(self, request: "web.Request") -> "web.Response":
|
async def _handle_notification(self, request: "web.Request") -> "web.Response":
|
||||||
|
if not self._source_ip_allowed(request):
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
# Graph never sends validationToken on POST, but tolerate it for
|
||||||
|
# defensive clients that replay the handshake in-band.
|
||||||
validation_token = request.query.get("validationToken", "")
|
validation_token = request.query.get("validationToken", "")
|
||||||
if validation_token:
|
if validation_token:
|
||||||
return web.Response(text=validation_token, content_type="text/plain")
|
return web.Response(text=validation_token, content_type="text/plain")
|
||||||
|
|
@ -150,27 +210,31 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
return web.json_response({"error": "Invalid JSON body"}, status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
notifications = body.get("value")
|
notifications = body.get("value")
|
||||||
if not isinstance(notifications, list):
|
if not isinstance(notifications, list):
|
||||||
return web.json_response({"error": "Missing notification batch"}, status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
accepted = 0
|
accepted = 0
|
||||||
duplicates = 0
|
duplicates = 0
|
||||||
rejected = 0
|
auth_rejected = 0
|
||||||
scheduled = 0
|
other_rejected = 0
|
||||||
|
|
||||||
for raw_notification in notifications:
|
for raw_notification in notifications:
|
||||||
if not isinstance(raw_notification, dict):
|
if not isinstance(raw_notification, dict):
|
||||||
rejected += 1
|
other_rejected += 1
|
||||||
continue
|
continue
|
||||||
notification = dict(raw_notification)
|
notification = dict(raw_notification)
|
||||||
if not self._resource_accepted(str(notification.get("resource") or "")):
|
if not self._resource_accepted(str(notification.get("resource") or "")):
|
||||||
rejected += 1
|
other_rejected += 1
|
||||||
continue
|
continue
|
||||||
if not self._verify_client_state(notification):
|
if not self._verify_client_state(notification):
|
||||||
rejected += 1
|
# Treat bad clientState as an auth failure: if the whole
|
||||||
|
# batch is forged, we want to signal 403 so the sender
|
||||||
|
# stops retrying. Legitimate Graph retries have valid
|
||||||
|
# clientState and hit the accepted/duplicate paths.
|
||||||
|
auth_rejected += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
receipt_key = self._build_receipt_key(notification)
|
receipt_key = self._build_receipt_key(notification)
|
||||||
|
|
@ -181,23 +245,39 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
|
||||||
self._remember_receipt(receipt_key)
|
self._remember_receipt(receipt_key)
|
||||||
|
|
||||||
accepted += 1
|
accepted += 1
|
||||||
scheduled += 1
|
|
||||||
self._accepted_count += 1
|
self._accepted_count += 1
|
||||||
event = self._build_message_event(notification, receipt_key)
|
event = self._build_message_event(notification, receipt_key)
|
||||||
self._schedule_notification(notification, event)
|
self._schedule_notification(notification, event)
|
||||||
|
|
||||||
self._duplicate_count += duplicates
|
self._duplicate_count += duplicates
|
||||||
status = 202 if accepted or duplicates else 403
|
# If anything ingested OR deduped, return 202 with empty body so
|
||||||
return web.json_response(
|
# Graph acks successfully and we don't leak internal counters. If
|
||||||
{
|
# every item failed auth, return 403 so an attacker POSTing fake
|
||||||
"status": "accepted" if accepted or duplicates else "rejected",
|
# notifications gets a clear reject. Other failures (malformed,
|
||||||
"accepted": accepted,
|
# resource-not-accepted) are the sender's configuration problem,
|
||||||
"duplicates": duplicates,
|
# so 400.
|
||||||
"rejected": rejected,
|
if accepted or duplicates:
|
||||||
"scheduled": scheduled,
|
return web.Response(status=202)
|
||||||
},
|
if auth_rejected and not other_rejected:
|
||||||
status=status,
|
return web.Response(status=403)
|
||||||
)
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
def _source_ip_allowed(self, request: "web.Request") -> bool:
|
||||||
|
"""Return True if the request's source IP is in the configured allowlist.
|
||||||
|
|
||||||
|
When ``allowed_source_cidrs`` is empty (the default), everything is
|
||||||
|
allowed — preserves behavior for dev tunnels / localhost setups.
|
||||||
|
"""
|
||||||
|
if not self._allowed_source_networks:
|
||||||
|
return True
|
||||||
|
peer = request.remote or ""
|
||||||
|
if not peer:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
peer_addr = ipaddress.ip_address(peer)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return any(peer_addr in network for network in self._allowed_source_networks)
|
||||||
|
|
||||||
def _resource_accepted(self, resource: str) -> bool:
|
def _resource_accepted(self, resource: str) -> bool:
|
||||||
if not self._accepted_resources:
|
if not self._accepted_resources:
|
||||||
|
|
@ -220,11 +300,21 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _verify_client_state(self, notification: Dict[str, Any]) -> bool:
|
def _verify_client_state(self, notification: Dict[str, Any]) -> bool:
|
||||||
|
"""Verify the Graph-supplied clientState matches the configured secret.
|
||||||
|
|
||||||
|
Uses ``hmac.compare_digest`` instead of ``==`` so that a mismatch
|
||||||
|
doesn't leak how many leading characters matched via string-compare
|
||||||
|
timing. The configured client_state is a shared secret (documented in
|
||||||
|
the setup guide as "generate with ``openssl rand -hex 32``"), so a
|
||||||
|
timing-safe compare is the right primitive.
|
||||||
|
"""
|
||||||
expected = self._client_state
|
expected = self._client_state
|
||||||
if expected is None:
|
if expected is None:
|
||||||
return True
|
return True
|
||||||
provided = self._string_or_none(notification.get("clientState"))
|
provided = self._string_or_none(notification.get("clientState"))
|
||||||
return provided == expected
|
if provided is None:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(provided, expected)
|
||||||
|
|
||||||
def _has_seen_receipt(self, receipt_key: str) -> bool:
|
def _has_seen_receipt(self, receipt_key: str) -> bool:
|
||||||
return receipt_key in self._seen_receipts
|
return receipt_key in self._seen_receipts
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,10 @@ def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter:
|
||||||
|
|
||||||
|
|
||||||
class _FakeRequest:
|
class _FakeRequest:
|
||||||
def __init__(self, *, query=None, json_payload=None):
|
def __init__(self, *, query=None, json_payload=None, remote="127.0.0.1"):
|
||||||
self.query = query or {}
|
self.query = query or {}
|
||||||
self._json_payload = json_payload
|
self._json_payload = json_payload
|
||||||
|
self.remote = remote
|
||||||
|
|
||||||
async def json(self):
|
async def json(self):
|
||||||
if isinstance(self._json_payload, Exception):
|
if isinstance(self._json_payload, Exception):
|
||||||
|
|
@ -70,14 +71,31 @@ class TestMSGraphWebhookConfig:
|
||||||
|
|
||||||
class TestMSGraphValidationHandshake:
|
class TestMSGraphValidationHandshake:
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_validation_token_echo(self):
|
async def test_validation_token_echo_on_get(self):
|
||||||
|
adapter = _make_adapter()
|
||||||
|
resp = await adapter._handle_validation(
|
||||||
|
_FakeRequest(query={"validationToken": "abc123"})
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.text == "abc123"
|
||||||
|
assert resp.content_type == "text/plain"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_bare_get_without_validation_token_rejected(self):
|
||||||
|
"""GET without validationToken is 400 so the endpoint can't be enumerated."""
|
||||||
|
adapter = _make_adapter()
|
||||||
|
resp = await adapter._handle_validation(_FakeRequest())
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_post_with_validation_token_still_echoes(self):
|
||||||
|
"""Tolerate defensive clients that send validationToken on POST."""
|
||||||
adapter = _make_adapter()
|
adapter = _make_adapter()
|
||||||
resp = await adapter._handle_notification(
|
resp = await adapter._handle_notification(
|
||||||
_FakeRequest(query={"validationToken": "abc123"})
|
_FakeRequest(query={"validationToken": "abc123"})
|
||||||
)
|
)
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert resp.text == "abc123"
|
assert resp.text == "abc123"
|
||||||
assert resp.content_type == "text/plain"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMSGraphNotifications:
|
class TestMSGraphNotifications:
|
||||||
|
|
@ -104,12 +122,10 @@ class TestMSGraphNotifications:
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
|
# Success is 202 with empty body: internal counters must not leak to
|
||||||
|
# the wire. Counters are still observable via /health.
|
||||||
assert resp.status == 202
|
assert resp.status == 202
|
||||||
data = json.loads(resp.text)
|
assert resp.body is None or not resp.body
|
||||||
assert data["accepted"] == 1
|
|
||||||
assert data["duplicates"] == 0
|
|
||||||
assert data["rejected"] == 0
|
|
||||||
assert data["scheduled"] == 1
|
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
|
@ -121,7 +137,8 @@ class TestMSGraphNotifications:
|
||||||
assert event.message_id == "id:notif-1"
|
assert event.message_id == "id:notif-1"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_bad_client_state_rejected(self):
|
async def test_bad_client_state_rejected_as_auth_failure(self):
|
||||||
|
"""Every-item-bad-clientState batches return 403 so forged POSTs stop retrying."""
|
||||||
adapter = _make_adapter()
|
adapter = _make_adapter()
|
||||||
scheduled: list[tuple[dict, object]] = []
|
scheduled: list[tuple[dict, object]] = []
|
||||||
|
|
||||||
|
|
@ -143,15 +160,46 @@ class TestMSGraphNotifications:
|
||||||
|
|
||||||
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
assert resp.status == 403
|
assert resp.status == 403
|
||||||
data = json.loads(resp.text)
|
|
||||||
assert data["accepted"] == 0
|
|
||||||
assert data["duplicates"] == 0
|
|
||||||
assert data["rejected"] == 1
|
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
assert scheduled == []
|
assert scheduled == []
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_client_state_compare_is_timing_safe(self, monkeypatch):
|
||||||
|
"""Ensure hmac.compare_digest is used for clientState comparison."""
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
calls: list[tuple[str, str]] = []
|
||||||
|
real_compare = hmac.compare_digest
|
||||||
|
|
||||||
|
def _spy(a, b):
|
||||||
|
calls.append((a, b))
|
||||||
|
return real_compare(a, b)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"gateway.platforms.msgraph_webhook.hmac.compare_digest", _spy
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = _make_adapter()
|
||||||
|
payload = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": "notif-timing",
|
||||||
|
"subscriptionId": "sub-1",
|
||||||
|
"changeType": "updated",
|
||||||
|
"resource": "communications/onlineMeetings/meeting-x",
|
||||||
|
"clientState": "expected-client-state",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
|
|
||||||
|
assert calls, "hmac.compare_digest was never called; clientState check is not timing-safe"
|
||||||
|
provided, expected = calls[0]
|
||||||
|
assert provided == "expected-client-state"
|
||||||
|
assert expected == "expected-client-state"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_duplicate_notification_deduped(self):
|
async def test_duplicate_notification_deduped(self):
|
||||||
adapter = _make_adapter()
|
adapter = _make_adapter()
|
||||||
|
|
@ -176,11 +224,9 @@ class TestMSGraphNotifications:
|
||||||
first = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
first = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
assert first.status == 202
|
assert first.status == 202
|
||||||
second = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
second = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
|
# Duplicate-only batch still returns 202 so Graph stops retrying.
|
||||||
assert second.status == 202
|
assert second.status == 202
|
||||||
second_data = json.loads(second.text)
|
assert adapter._duplicate_count == 1
|
||||||
assert second_data["accepted"] == 0
|
|
||||||
assert second_data["duplicates"] == 1
|
|
||||||
assert second_data["scheduled"] == 0
|
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
|
@ -212,10 +258,6 @@ class TestMSGraphNotifications:
|
||||||
|
|
||||||
assert first.status == 202
|
assert first.status == 202
|
||||||
assert second.status == 202
|
assert second.status == 202
|
||||||
second_data = json.loads(second.text)
|
|
||||||
assert second_data["accepted"] == 1
|
|
||||||
assert second_data["duplicates"] == 0
|
|
||||||
assert second_data["scheduled"] == 1
|
|
||||||
|
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
|
@ -237,11 +279,39 @@ class TestMSGraphNotifications:
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
data = json.loads(resp.text)
|
|
||||||
|
|
||||||
assert resp.status == 202
|
assert resp.status == 202
|
||||||
assert data["accepted"] == 1
|
|
||||||
assert data["rejected"] == 0
|
@pytest.mark.anyio
|
||||||
|
async def test_resource_not_in_allowlist_returns_400(self):
|
||||||
|
"""Every-item-rejected-for-non-auth returns 400 (configuration issue)."""
|
||||||
|
adapter = _make_adapter(accepted_resources=["communications/onlineMeetings"])
|
||||||
|
payload = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": "notif-bad-resource",
|
||||||
|
"resource": "users/u1/messages",
|
||||||
|
"clientState": "expected-client-state",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_malformed_body_returns_400(self):
|
||||||
|
adapter = _make_adapter()
|
||||||
|
resp = await adapter._handle_notification(
|
||||||
|
_FakeRequest(json_payload=ValueError("bad json"))
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_missing_value_array_returns_400(self):
|
||||||
|
adapter = _make_adapter()
|
||||||
|
resp = await adapter._handle_notification(
|
||||||
|
_FakeRequest(json_payload={"not_value": []})
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_seen_receipts_are_bounded(self):
|
async def test_seen_receipts_are_bounded(self):
|
||||||
|
|
@ -277,6 +347,84 @@ class TestMSGraphNotifications:
|
||||||
assert list(adapter._seen_receipt_order) == ["id:notif-b", "id:notif-c"]
|
assert list(adapter._seen_receipt_order) == ["id:notif-b", "id:notif-c"]
|
||||||
|
|
||||||
replay = await _post("notif-a")
|
replay = await _post("notif-a")
|
||||||
replay_data = json.loads(replay.text)
|
# notif-a evicted from the bounded cache, so it's accepted again (202)
|
||||||
assert replay_data["accepted"] == 1
|
# rather than treated as a duplicate.
|
||||||
assert replay_data["duplicates"] == 0
|
assert replay.status == 202
|
||||||
|
assert adapter._accepted_count == 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphSourceIPAllowlist:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_disabled_by_default_allows_all(self):
|
||||||
|
"""Empty allowlist preserves pre-existing behavior (dev tunnels, localhost)."""
|
||||||
|
adapter = _make_adapter() # no allowed_source_cidrs set
|
||||||
|
payload = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": "notif-ip",
|
||||||
|
"resource": "communications/onlineMeetings/m",
|
||||||
|
"clientState": "expected-client-state",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
resp = await adapter._handle_notification(
|
||||||
|
_FakeRequest(json_payload=payload, remote="203.0.113.99")
|
||||||
|
)
|
||||||
|
assert resp.status == 202
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_post_from_disallowed_ip_rejected(self):
|
||||||
|
adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"])
|
||||||
|
payload = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": "notif-ip-bad",
|
||||||
|
"resource": "communications/onlineMeetings/m",
|
||||||
|
"clientState": "expected-client-state",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
resp = await adapter._handle_notification(
|
||||||
|
_FakeRequest(json_payload=payload, remote="203.0.113.99")
|
||||||
|
)
|
||||||
|
assert resp.status == 403
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_post_from_allowed_ip_accepted(self):
|
||||||
|
adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8", "203.0.113.0/24"])
|
||||||
|
payload = {
|
||||||
|
"value": [
|
||||||
|
{
|
||||||
|
"id": "notif-ip-ok",
|
||||||
|
"resource": "communications/onlineMeetings/m",
|
||||||
|
"clientState": "expected-client-state",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
resp = await adapter._handle_notification(
|
||||||
|
_FakeRequest(json_payload=payload, remote="203.0.113.5")
|
||||||
|
)
|
||||||
|
assert resp.status == 202
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_validation_handshake_also_respects_allowlist(self):
|
||||||
|
"""A disallowed IP shouldn't be able to probe the handshake endpoint."""
|
||||||
|
adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"])
|
||||||
|
resp = await adapter._handle_validation(
|
||||||
|
_FakeRequest(query={"validationToken": "probe"}, remote="203.0.113.99")
|
||||||
|
)
|
||||||
|
assert resp.status == 403
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_invalid_cidr_entries_are_ignored_at_init(self):
|
||||||
|
"""Malformed CIDR strings should log a warning and be ignored, not crash."""
|
||||||
|
adapter = _make_adapter(
|
||||||
|
allowed_source_cidrs=["10.0.0.0/8", "not-a-cidr", "", "203.0.113.0/24"]
|
||||||
|
)
|
||||||
|
assert len(adapter._allowed_source_networks) == 2
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_cidr_list_accepts_comma_string(self):
|
||||||
|
"""Env-var-style 'cidr1, cidr2' strings parse as a list."""
|
||||||
|
adapter = _make_adapter(allowed_source_cidrs="10.0.0.0/8, 203.0.113.0/24")
|
||||||
|
assert len(adapter._allowed_source_networks) == 2
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue