From b8d7e0e6d386eceb081cab8123db26474b1a6b9d Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Fri, 8 May 2026 09:38:51 -0700 Subject: [PATCH] fix(msgraph_webhook): harden auth surface + IP allowlisting + response hygiene Defense-in-depth polish on top of the webhook listener before it becomes a real attack surface once the pipeline starts creating subscriptions and Graph starts POSTing to the configured public URL. - Timing-safe clientState comparison. Previously used `==` on strings; switches to hmac.compare_digest so a mismatch does not leak how many leading characters matched. client_state is documented as a strong shared secret (openssl rand -hex 32 in the setup docs), so a timing-safe primitive is the right call. - Split GET and POST handlers. Graph validates a subscription by sending GET with validationToken in the query; anything else on GET is now a 400 so the endpoint cannot be probed or mistakenly used for data exfil. Previously a bare GET fell through to the POST path and blew up on request.json() with a confusing 400. - Empty response bodies on success. 202 is returned with no body so internal counters (accepted / duplicates / scheduled) do not leak to any caller that can reach the endpoint; counters remain observable via /health for operators. 403 on every-item-bad-clientState batches (so forged POSTs stop retrying), 400 on malformed / unknown-resource batches (sender configuration issue). - Optional source-IP allowlist. New `allowed_source_cidrs` extra field (list or comma-separated string) and `MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS` env var let operators restrict the webhook to Microsoft Graph's published webhook source ranges in production. Empty = allow all, preserving dev-tunnel / localhost workflows. Invalid CIDRs are logged and ignored rather than crashing. Also gates the handshake endpoint so disallowed IPs cannot probe it. - Tests updated for the new response contract (empty-body 202, auth-only 403, config-error 400) and extended to cover: bare GET rejection, POST-with-validationToken handshake tolerance, timing-safe compare actually invoked via hmac.compare_digest spy, malformed body / missing value array, IP allowlist accept/reject paths, handshake IP allowlist, invalid CIDR entries, comma-string CIDR list parsing. 52/52 passed (was 40). Full gateway suite: 5049 passed / 1 pre-existing failure in test_discord_free_response (unrelated, reproduces on clean origin/main). --- gateway/config.py | 14 ++ gateway/platforms/msgraph_webhook.py | 132 ++++++++++++++--- tests/gateway/test_msgraph_webhook.py | 204 ++++++++++++++++++++++---- 3 files changed, 301 insertions(+), 49 deletions(-) diff --git a/gateway/config.py b/gateway/config.py index 7813b16b4a..6b09b34d18 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -1418,12 +1418,16 @@ def _apply_env_overrides(config: GatewayConfig) -> None: msgraph_webhook_port = os.getenv("MSGRAPH_WEBHOOK_PORT") msgraph_webhook_client_state = os.getenv("MSGRAPH_WEBHOOK_CLIENT_STATE", "") msgraph_webhook_resources = os.getenv("MSGRAPH_WEBHOOK_ACCEPTED_RESOURCES", "") + msgraph_webhook_allowed_cidrs = os.getenv( + "MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS", "" + ) if ( msgraph_webhook_enabled or Platform.MSGRAPH_WEBHOOK in config.platforms or msgraph_webhook_port or msgraph_webhook_client_state or msgraph_webhook_resources + or msgraph_webhook_allowed_cidrs ): if Platform.MSGRAPH_WEBHOOK not in config.platforms: config.platforms[Platform.MSGRAPH_WEBHOOK] = PlatformConfig() @@ -1450,6 +1454,16 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.MSGRAPH_WEBHOOK].extra[ "accepted_resources" ] = resources + if msgraph_webhook_allowed_cidrs: + cidrs = [ + cidr.strip() + for cidr in msgraph_webhook_allowed_cidrs.split(",") + if cidr.strip() + ] + if cidrs: + config.platforms[Platform.MSGRAPH_WEBHOOK].extra[ + "allowed_source_cidrs" + ] = cidrs # DingTalk dingtalk_client_id = os.getenv("DINGTALK_CLIENT_ID") diff --git a/gateway/platforms/msgraph_webhook.py b/gateway/platforms/msgraph_webhook.py index e157cd22a3..46430a25bc 100644 --- a/gateway/platforms/msgraph_webhook.py +++ b/gateway/platforms/msgraph_webhook.py @@ -3,6 +3,8 @@ from __future__ import annotations import asyncio +import hmac +import ipaddress import json import logging from collections import deque @@ -60,6 +62,9 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): self._max_seen_receipts = max( 1, int(extra.get("max_seen_receipts", DEFAULT_MAX_SEEN_RECEIPTS)) ) + self._allowed_source_networks: list[ipaddress._BaseNetwork] = ( + self._parse_allowed_source_cidrs(extra.get("allowed_source_cidrs")) + ) self._runner = None self._notification_scheduler: Optional[NotificationScheduler] = None self._seen_receipts: set[str] = set() @@ -90,13 +95,47 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): def _normalize_resource_value(resource: str) -> str: return str(resource or "").strip().strip("/") + @staticmethod + def _parse_allowed_source_cidrs( + raw: Any, + ) -> list[ipaddress._BaseNetwork]: + """Parse an optional list of CIDR ranges allowed to POST to the webhook. + + An empty or missing value means "allow everything" (same behavior as + before this field existed). When populated, requests from source IPs + outside every listed CIDR are rejected with 403 before the body is + parsed. Use this to restrict the endpoint to Microsoft Graph's + published webhook source ranges in production deployments. + """ + if raw is None: + return [] + if isinstance(raw, str): + candidates = [chunk.strip() for chunk in raw.split(",")] + elif isinstance(raw, (list, tuple, set)): + candidates = [str(chunk).strip() for chunk in raw] + else: + return [] + + networks: list[ipaddress._BaseNetwork] = [] + for chunk in candidates: + if not chunk: + continue + try: + networks.append(ipaddress.ip_network(chunk, strict=False)) + except ValueError: + logger.warning( + "[msgraph_webhook] Ignoring invalid allowed_source_cidrs entry: %r", + chunk, + ) + return networks + def set_notification_scheduler(self, scheduler: Optional[NotificationScheduler]) -> None: self._notification_scheduler = scheduler async def connect(self) -> bool: app = web.Application() app.router.add_get(self._health_path, self._handle_health) - app.router.add_get(self._webhook_path, self._handle_notification) + app.router.add_get(self._webhook_path, self._handle_validation) app.router.add_post(self._webhook_path, self._handle_notification) self._runner = web.AppRunner(app) @@ -142,7 +181,28 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): } ) + async def _handle_validation(self, request: "web.Request") -> "web.Response": + """Handle Microsoft Graph subscription validation handshake. + + Graph validates a subscription endpoint by sending a GET with + ``validationToken`` in the query string; the service must echo the + token verbatim as ``text/plain`` within 10 seconds. Anything else + (bare GET, GET without the token) is rejected so the endpoint can't + be enumerated or mistakenly used for data exfiltration. + """ + if not self._source_ip_allowed(request): + return web.Response(status=403) + validation_token = request.query.get("validationToken", "") + if not validation_token: + return web.Response(status=400) + return web.Response(text=validation_token, content_type="text/plain") + async def _handle_notification(self, request: "web.Request") -> "web.Response": + if not self._source_ip_allowed(request): + return web.Response(status=403) + + # Graph never sends validationToken on POST, but tolerate it for + # defensive clients that replay the handshake in-band. validation_token = request.query.get("validationToken", "") if validation_token: return web.Response(text=validation_token, content_type="text/plain") @@ -150,27 +210,31 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): try: body = await request.json() except Exception: - return web.json_response({"error": "Invalid JSON body"}, status=400) + return web.Response(status=400) notifications = body.get("value") if not isinstance(notifications, list): - return web.json_response({"error": "Missing notification batch"}, status=400) + return web.Response(status=400) accepted = 0 duplicates = 0 - rejected = 0 - scheduled = 0 + auth_rejected = 0 + other_rejected = 0 for raw_notification in notifications: if not isinstance(raw_notification, dict): - rejected += 1 + other_rejected += 1 continue notification = dict(raw_notification) if not self._resource_accepted(str(notification.get("resource") or "")): - rejected += 1 + other_rejected += 1 continue if not self._verify_client_state(notification): - rejected += 1 + # Treat bad clientState as an auth failure: if the whole + # batch is forged, we want to signal 403 so the sender + # stops retrying. Legitimate Graph retries have valid + # clientState and hit the accepted/duplicate paths. + auth_rejected += 1 continue receipt_key = self._build_receipt_key(notification) @@ -181,23 +245,39 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): self._remember_receipt(receipt_key) accepted += 1 - scheduled += 1 self._accepted_count += 1 event = self._build_message_event(notification, receipt_key) self._schedule_notification(notification, event) self._duplicate_count += duplicates - status = 202 if accepted or duplicates else 403 - return web.json_response( - { - "status": "accepted" if accepted or duplicates else "rejected", - "accepted": accepted, - "duplicates": duplicates, - "rejected": rejected, - "scheduled": scheduled, - }, - status=status, - ) + # If anything ingested OR deduped, return 202 with empty body so + # Graph acks successfully and we don't leak internal counters. If + # every item failed auth, return 403 so an attacker POSTing fake + # notifications gets a clear reject. Other failures (malformed, + # resource-not-accepted) are the sender's configuration problem, + # so 400. + if accepted or duplicates: + return web.Response(status=202) + if auth_rejected and not other_rejected: + return web.Response(status=403) + return web.Response(status=400) + + def _source_ip_allowed(self, request: "web.Request") -> bool: + """Return True if the request's source IP is in the configured allowlist. + + When ``allowed_source_cidrs`` is empty (the default), everything is + allowed — preserves behavior for dev tunnels / localhost setups. + """ + if not self._allowed_source_networks: + return True + peer = request.remote or "" + if not peer: + return False + try: + peer_addr = ipaddress.ip_address(peer) + except ValueError: + return False + return any(peer_addr in network for network in self._allowed_source_networks) def _resource_accepted(self, resource: str) -> bool: if not self._accepted_resources: @@ -220,11 +300,21 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): return False def _verify_client_state(self, notification: Dict[str, Any]) -> bool: + """Verify the Graph-supplied clientState matches the configured secret. + + Uses ``hmac.compare_digest`` instead of ``==`` so that a mismatch + doesn't leak how many leading characters matched via string-compare + timing. The configured client_state is a shared secret (documented in + the setup guide as "generate with ``openssl rand -hex 32``"), so a + timing-safe compare is the right primitive. + """ expected = self._client_state if expected is None: return True provided = self._string_or_none(notification.get("clientState")) - return provided == expected + if provided is None: + return False + return hmac.compare_digest(provided, expected) def _has_seen_receipt(self, receipt_key: str) -> bool: return receipt_key in self._seen_receipts diff --git a/tests/gateway/test_msgraph_webhook.py b/tests/gateway/test_msgraph_webhook.py index 3c6a4daceb..d97c98492a 100644 --- a/tests/gateway/test_msgraph_webhook.py +++ b/tests/gateway/test_msgraph_webhook.py @@ -19,9 +19,10 @@ def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter: class _FakeRequest: - def __init__(self, *, query=None, json_payload=None): + def __init__(self, *, query=None, json_payload=None, remote="127.0.0.1"): self.query = query or {} self._json_payload = json_payload + self.remote = remote async def json(self): if isinstance(self._json_payload, Exception): @@ -70,14 +71,31 @@ class TestMSGraphWebhookConfig: class TestMSGraphValidationHandshake: @pytest.mark.anyio - async def test_validation_token_echo(self): + async def test_validation_token_echo_on_get(self): + adapter = _make_adapter() + resp = await adapter._handle_validation( + _FakeRequest(query={"validationToken": "abc123"}) + ) + assert resp.status == 200 + assert resp.text == "abc123" + assert resp.content_type == "text/plain" + + @pytest.mark.anyio + async def test_bare_get_without_validation_token_rejected(self): + """GET without validationToken is 400 so the endpoint can't be enumerated.""" + adapter = _make_adapter() + resp = await adapter._handle_validation(_FakeRequest()) + assert resp.status == 400 + + @pytest.mark.anyio + async def test_post_with_validation_token_still_echoes(self): + """Tolerate defensive clients that send validationToken on POST.""" adapter = _make_adapter() resp = await adapter._handle_notification( _FakeRequest(query={"validationToken": "abc123"}) ) assert resp.status == 200 assert resp.text == "abc123" - assert resp.content_type == "text/plain" class TestMSGraphNotifications: @@ -104,12 +122,10 @@ class TestMSGraphNotifications: } resp = await adapter._handle_notification(_FakeRequest(json_payload=payload)) + # Success is 202 with empty body: internal counters must not leak to + # the wire. Counters are still observable via /health. assert resp.status == 202 - data = json.loads(resp.text) - assert data["accepted"] == 1 - assert data["duplicates"] == 0 - assert data["rejected"] == 0 - assert data["scheduled"] == 1 + assert resp.body is None or not resp.body await asyncio.sleep(0.05) @@ -121,7 +137,8 @@ class TestMSGraphNotifications: assert event.message_id == "id:notif-1" @pytest.mark.anyio - async def test_bad_client_state_rejected(self): + async def test_bad_client_state_rejected_as_auth_failure(self): + """Every-item-bad-clientState batches return 403 so forged POSTs stop retrying.""" adapter = _make_adapter() scheduled: list[tuple[dict, object]] = [] @@ -143,15 +160,46 @@ class TestMSGraphNotifications: resp = await adapter._handle_notification(_FakeRequest(json_payload=payload)) assert resp.status == 403 - data = json.loads(resp.text) - assert data["accepted"] == 0 - assert data["duplicates"] == 0 - assert data["rejected"] == 1 await asyncio.sleep(0.05) assert scheduled == [] + @pytest.mark.anyio + async def test_client_state_compare_is_timing_safe(self, monkeypatch): + """Ensure hmac.compare_digest is used for clientState comparison.""" + import hmac + + calls: list[tuple[str, str]] = [] + real_compare = hmac.compare_digest + + def _spy(a, b): + calls.append((a, b)) + return real_compare(a, b) + + monkeypatch.setattr( + "gateway.platforms.msgraph_webhook.hmac.compare_digest", _spy + ) + + adapter = _make_adapter() + payload = { + "value": [ + { + "id": "notif-timing", + "subscriptionId": "sub-1", + "changeType": "updated", + "resource": "communications/onlineMeetings/meeting-x", + "clientState": "expected-client-state", + } + ] + } + await adapter._handle_notification(_FakeRequest(json_payload=payload)) + + assert calls, "hmac.compare_digest was never called; clientState check is not timing-safe" + provided, expected = calls[0] + assert provided == "expected-client-state" + assert expected == "expected-client-state" + @pytest.mark.anyio async def test_duplicate_notification_deduped(self): adapter = _make_adapter() @@ -176,11 +224,9 @@ class TestMSGraphNotifications: first = await adapter._handle_notification(_FakeRequest(json_payload=payload)) assert first.status == 202 second = await adapter._handle_notification(_FakeRequest(json_payload=payload)) + # Duplicate-only batch still returns 202 so Graph stops retrying. assert second.status == 202 - second_data = json.loads(second.text) - assert second_data["accepted"] == 0 - assert second_data["duplicates"] == 1 - assert second_data["scheduled"] == 0 + assert adapter._duplicate_count == 1 await asyncio.sleep(0.05) @@ -212,10 +258,6 @@ class TestMSGraphNotifications: assert first.status == 202 assert second.status == 202 - second_data = json.loads(second.text) - assert second_data["accepted"] == 1 - assert second_data["duplicates"] == 0 - assert second_data["scheduled"] == 1 await asyncio.sleep(0.05) @@ -237,11 +279,39 @@ class TestMSGraphNotifications: } resp = await adapter._handle_notification(_FakeRequest(json_payload=payload)) - data = json.loads(resp.text) - assert resp.status == 202 - assert data["accepted"] == 1 - assert data["rejected"] == 0 + + @pytest.mark.anyio + async def test_resource_not_in_allowlist_returns_400(self): + """Every-item-rejected-for-non-auth returns 400 (configuration issue).""" + adapter = _make_adapter(accepted_resources=["communications/onlineMeetings"]) + payload = { + "value": [ + { + "id": "notif-bad-resource", + "resource": "users/u1/messages", + "clientState": "expected-client-state", + } + ] + } + resp = await adapter._handle_notification(_FakeRequest(json_payload=payload)) + assert resp.status == 400 + + @pytest.mark.anyio + async def test_malformed_body_returns_400(self): + adapter = _make_adapter() + resp = await adapter._handle_notification( + _FakeRequest(json_payload=ValueError("bad json")) + ) + assert resp.status == 400 + + @pytest.mark.anyio + async def test_missing_value_array_returns_400(self): + adapter = _make_adapter() + resp = await adapter._handle_notification( + _FakeRequest(json_payload={"not_value": []}) + ) + assert resp.status == 400 @pytest.mark.anyio async def test_seen_receipts_are_bounded(self): @@ -277,6 +347,84 @@ class TestMSGraphNotifications: assert list(adapter._seen_receipt_order) == ["id:notif-b", "id:notif-c"] replay = await _post("notif-a") - replay_data = json.loads(replay.text) - assert replay_data["accepted"] == 1 - assert replay_data["duplicates"] == 0 + # notif-a evicted from the bounded cache, so it's accepted again (202) + # rather than treated as a duplicate. + assert replay.status == 202 + assert adapter._accepted_count == 4 + + +class TestMSGraphSourceIPAllowlist: + @pytest.mark.anyio + async def test_disabled_by_default_allows_all(self): + """Empty allowlist preserves pre-existing behavior (dev tunnels, localhost).""" + adapter = _make_adapter() # no allowed_source_cidrs set + payload = { + "value": [ + { + "id": "notif-ip", + "resource": "communications/onlineMeetings/m", + "clientState": "expected-client-state", + } + ] + } + resp = await adapter._handle_notification( + _FakeRequest(json_payload=payload, remote="203.0.113.99") + ) + assert resp.status == 202 + + @pytest.mark.anyio + async def test_post_from_disallowed_ip_rejected(self): + adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"]) + payload = { + "value": [ + { + "id": "notif-ip-bad", + "resource": "communications/onlineMeetings/m", + "clientState": "expected-client-state", + } + ] + } + resp = await adapter._handle_notification( + _FakeRequest(json_payload=payload, remote="203.0.113.99") + ) + assert resp.status == 403 + + @pytest.mark.anyio + async def test_post_from_allowed_ip_accepted(self): + adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8", "203.0.113.0/24"]) + payload = { + "value": [ + { + "id": "notif-ip-ok", + "resource": "communications/onlineMeetings/m", + "clientState": "expected-client-state", + } + ] + } + resp = await adapter._handle_notification( + _FakeRequest(json_payload=payload, remote="203.0.113.5") + ) + assert resp.status == 202 + + @pytest.mark.anyio + async def test_validation_handshake_also_respects_allowlist(self): + """A disallowed IP shouldn't be able to probe the handshake endpoint.""" + adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"]) + resp = await adapter._handle_validation( + _FakeRequest(query={"validationToken": "probe"}, remote="203.0.113.99") + ) + assert resp.status == 403 + + @pytest.mark.anyio + async def test_invalid_cidr_entries_are_ignored_at_init(self): + """Malformed CIDR strings should log a warning and be ignored, not crash.""" + adapter = _make_adapter( + allowed_source_cidrs=["10.0.0.0/8", "not-a-cidr", "", "203.0.113.0/24"] + ) + assert len(adapter._allowed_source_networks) == 2 + + @pytest.mark.anyio + async def test_cidr_list_accepts_comma_string(self): + """Env-var-style 'cidr1, cidr2' strings parse as a list.""" + adapter = _make_adapter(allowed_source_cidrs="10.0.0.0/8, 203.0.113.0/24") + assert len(adapter._allowed_source_networks) == 2