fix(msgraph_webhook): harden auth surface + IP allowlisting + response hygiene

Defense-in-depth polish on top of the webhook listener before it becomes
a real attack surface once the pipeline starts creating subscriptions
and Graph starts POSTing to the configured public URL.

- Timing-safe clientState comparison. Previously used `==` on strings;
  switches to hmac.compare_digest so a mismatch does not leak how many
  leading characters matched. client_state is documented as a strong
  shared secret (openssl rand -hex 32 in the setup docs), so a
  timing-safe primitive is the right call.

- Split GET and POST handlers. Graph validates a subscription by sending
  GET with validationToken in the query; anything else on GET is now a
  400 so the endpoint cannot be probed or mistakenly used for data
  exfil. Previously a bare GET fell through to the POST path and blew
  up on request.json() with a confusing 400.

- Empty response bodies on success. 202 is returned with no body so
  internal counters (accepted / duplicates / scheduled) do not leak to
  any caller that can reach the endpoint; counters remain observable
  via /health for operators. 403 on every-item-bad-clientState batches
  (so forged POSTs stop retrying), 400 on malformed / unknown-resource
  batches (sender configuration issue).

- Optional source-IP allowlist. New `allowed_source_cidrs` extra field
  (list or comma-separated string) and `MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS`
  env var let operators restrict the webhook to Microsoft Graph's
  published webhook source ranges in production. Empty = allow all,
  preserving dev-tunnel / localhost workflows. Invalid CIDRs are
  logged and ignored rather than crashing. Also gates the handshake
  endpoint so disallowed IPs cannot probe it.

- Tests updated for the new response contract (empty-body 202,
  auth-only 403, config-error 400) and extended to cover: bare GET
  rejection, POST-with-validationToken handshake tolerance,
  timing-safe compare actually invoked via hmac.compare_digest spy,
  malformed body / missing value array, IP allowlist accept/reject
  paths, handshake IP allowlist, invalid CIDR entries, comma-string
  CIDR list parsing. 52/52 passed (was 40).

Full gateway suite: 5049 passed / 1 pre-existing failure in
test_discord_free_response (unrelated, reproduces on clean origin/main).
This commit is contained in:
Teknium 2026-05-08 09:38:51 -07:00
parent 26a59e4f6c
commit b8d7e0e6d3
3 changed files with 301 additions and 49 deletions

View file

@ -1418,12 +1418,16 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
msgraph_webhook_port = os.getenv("MSGRAPH_WEBHOOK_PORT")
msgraph_webhook_client_state = os.getenv("MSGRAPH_WEBHOOK_CLIENT_STATE", "")
msgraph_webhook_resources = os.getenv("MSGRAPH_WEBHOOK_ACCEPTED_RESOURCES", "")
msgraph_webhook_allowed_cidrs = os.getenv(
"MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS", ""
)
if (
msgraph_webhook_enabled
or Platform.MSGRAPH_WEBHOOK in config.platforms
or msgraph_webhook_port
or msgraph_webhook_client_state
or msgraph_webhook_resources
or msgraph_webhook_allowed_cidrs
):
if Platform.MSGRAPH_WEBHOOK not in config.platforms:
config.platforms[Platform.MSGRAPH_WEBHOOK] = PlatformConfig()
@ -1450,6 +1454,16 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
config.platforms[Platform.MSGRAPH_WEBHOOK].extra[
"accepted_resources"
] = resources
if msgraph_webhook_allowed_cidrs:
cidrs = [
cidr.strip()
for cidr in msgraph_webhook_allowed_cidrs.split(",")
if cidr.strip()
]
if cidrs:
config.platforms[Platform.MSGRAPH_WEBHOOK].extra[
"allowed_source_cidrs"
] = cidrs
# DingTalk
dingtalk_client_id = os.getenv("DINGTALK_CLIENT_ID")

View file

@ -3,6 +3,8 @@
from __future__ import annotations
import asyncio
import hmac
import ipaddress
import json
import logging
from collections import deque
@ -60,6 +62,9 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
self._max_seen_receipts = max(
1, int(extra.get("max_seen_receipts", DEFAULT_MAX_SEEN_RECEIPTS))
)
self._allowed_source_networks: list[ipaddress._BaseNetwork] = (
self._parse_allowed_source_cidrs(extra.get("allowed_source_cidrs"))
)
self._runner = None
self._notification_scheduler: Optional[NotificationScheduler] = None
self._seen_receipts: set[str] = set()
@ -90,13 +95,47 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
def _normalize_resource_value(resource: str) -> str:
return str(resource or "").strip().strip("/")
@staticmethod
def _parse_allowed_source_cidrs(
raw: Any,
) -> list[ipaddress._BaseNetwork]:
"""Parse an optional list of CIDR ranges allowed to POST to the webhook.
An empty or missing value means "allow everything" (same behavior as
before this field existed). When populated, requests from source IPs
outside every listed CIDR are rejected with 403 before the body is
parsed. Use this to restrict the endpoint to Microsoft Graph's
published webhook source ranges in production deployments.
"""
if raw is None:
return []
if isinstance(raw, str):
candidates = [chunk.strip() for chunk in raw.split(",")]
elif isinstance(raw, (list, tuple, set)):
candidates = [str(chunk).strip() for chunk in raw]
else:
return []
networks: list[ipaddress._BaseNetwork] = []
for chunk in candidates:
if not chunk:
continue
try:
networks.append(ipaddress.ip_network(chunk, strict=False))
except ValueError:
logger.warning(
"[msgraph_webhook] Ignoring invalid allowed_source_cidrs entry: %r",
chunk,
)
return networks
# Replace the scheduler stored on this adapter; passing ``None`` clears it.
def set_notification_scheduler(self, scheduler: Optional[NotificationScheduler]) -> None:
self._notification_scheduler = scheduler
async def connect(self) -> bool:
app = web.Application()
app.router.add_get(self._health_path, self._handle_health)
app.router.add_get(self._webhook_path, self._handle_notification)
app.router.add_get(self._webhook_path, self._handle_validation)
app.router.add_post(self._webhook_path, self._handle_notification)
self._runner = web.AppRunner(app)
@ -142,7 +181,28 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
}
)
async def _handle_validation(self, request: "web.Request") -> "web.Response":
    """Answer a GET on the webhook path by echoing ``validationToken``.

    Checks run in this order:

    * a source address rejected by the allowlist gets an empty 403;
    * a query without a (non-empty) ``validationToken`` gets an empty 400,
      so a bare GET reveals nothing about the endpoint;
    * otherwise the token is returned verbatim as ``text/plain``.

    NOTE(review): Microsoft's change-notification documentation describes
    the validation request as a POST with ``validationToken`` in the query
    string. Confirm whether Graph ever issues this handshake as a GET.
    """
    if self._source_ip_allowed(request):
        token = request.query.get("validationToken", "")
        if token:
            return web.Response(text=token, content_type="text/plain")
        return web.Response(status=400)
    return web.Response(status=403)
async def _handle_notification(self, request: "web.Request") -> "web.Response":
if not self._source_ip_allowed(request):
return web.Response(status=403)
# Microsoft's change-notification docs describe the subscription validation
# handshake as a POST carrying validationToken in the query string, so echo
# the token from this handler as well.
validation_token = request.query.get("validationToken", "")
if validation_token:
return web.Response(text=validation_token, content_type="text/plain")
@ -150,27 +210,31 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
try:
body = await request.json()
except Exception:
return web.json_response({"error": "Invalid JSON body"}, status=400)
return web.Response(status=400)
notifications = body.get("value")
if not isinstance(notifications, list):
return web.json_response({"error": "Missing notification batch"}, status=400)
return web.Response(status=400)
accepted = 0
duplicates = 0
rejected = 0
scheduled = 0
auth_rejected = 0
other_rejected = 0
for raw_notification in notifications:
if not isinstance(raw_notification, dict):
rejected += 1
other_rejected += 1
continue
notification = dict(raw_notification)
if not self._resource_accepted(str(notification.get("resource") or "")):
rejected += 1
other_rejected += 1
continue
if not self._verify_client_state(notification):
rejected += 1
# Treat bad clientState as an auth failure: if the whole
# batch is forged, we want to signal 403 so the sender
# stops retrying. Legitimate Graph retries have valid
# clientState and hit the accepted/duplicate paths.
auth_rejected += 1
continue
receipt_key = self._build_receipt_key(notification)
@ -181,23 +245,39 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
self._remember_receipt(receipt_key)
accepted += 1
scheduled += 1
self._accepted_count += 1
event = self._build_message_event(notification, receipt_key)
self._schedule_notification(notification, event)
self._duplicate_count += duplicates
status = 202 if accepted or duplicates else 403
return web.json_response(
{
"status": "accepted" if accepted or duplicates else "rejected",
"accepted": accepted,
"duplicates": duplicates,
"rejected": rejected,
"scheduled": scheduled,
},
status=status,
)
# If anything ingested OR deduped, return 202 with empty body so
# Graph acks successfully and we don't leak internal counters. If
# every item failed auth, return 403 so an attacker POSTing fake
# notifications gets a clear reject. Other failures (malformed,
# resource-not-accepted) are the sender's configuration problem,
# so 400.
if accepted or duplicates:
return web.Response(status=202)
if auth_rejected and not other_rejected:
return web.Response(status=403)
return web.Response(status=400)
def _source_ip_allowed(self, request: "web.Request") -> bool:
    """Return True if the request's source IP passes the configured allowlist.

    When ``allowed_source_cidrs`` is empty (the default), everything is
    allowed, which preserves behavior for dev tunnels / localhost setups.

    Otherwise ``request.remote`` is parsed with ``ipaddress.ip_address``.
    A missing or unparseable peer address is rejected. An IPv4-mapped
    IPv6 peer (``::ffff:a.b.c.d``) is additionally checked in its plain
    IPv4 form, because an IPv6 address never tests as a member of an IPv4
    network and such a peer would otherwise be rejected despite matching
    an allowlisted IPv4 CIDR.
    """
    networks = self._allowed_source_networks
    if not networks:
        return True

    peer = request.remote or ""
    if not peer:
        return False
    try:
        peer_addr = ipaddress.ip_address(peer)
    except ValueError:
        return False

    # Keep the address as received so IPv6 allowlist entries still match,
    # and add the embedded IPv4 address when there is one.
    candidates = [peer_addr]
    mapped = getattr(peer_addr, "ipv4_mapped", None)
    if mapped is not None:
        candidates.append(mapped)

    return any(addr in network for network in networks for addr in candidates)
def _resource_accepted(self, resource: str) -> bool:
if not self._accepted_resources:
@ -220,11 +300,21 @@ class MSGraphWebhookAdapter(BasePlatformAdapter):
return False
def _verify_client_state(self, notification: Dict[str, Any]) -> bool:
    """Verify the Graph-supplied clientState matches the configured secret.

    Returns True when no client_state is configured (verification is
    disabled), False when the notification carries no usable clientState,
    and otherwise the result of a constant-time comparison.

    Uses ``hmac.compare_digest`` instead of ``==`` so that a mismatch
    doesn't leak how many leading characters matched via string-compare
    timing. ``compare_digest`` only accepts ``str`` arguments that are pure
    ASCII and raises ``TypeError`` otherwise. Because ``clientState`` comes
    from the request body, that case falls back to comparing the encoded
    bytes so a non-ASCII value is rejected instead of escaping as an
    unhandled exception.

    NOTE(review): assumes the configured client_state is a ``str`` when it
    is not ``None``. Confirm against where ``_client_state`` is assigned.
    """
    expected = self._client_state
    if expected is None:
        return True
    provided = self._string_or_none(notification.get("clientState"))
    if provided is None:
        return False
    try:
        return hmac.compare_digest(provided, expected)
    except TypeError:
        # ``surrogatepass`` keeps lone surrogates (legal in JSON escapes)
        # from raising during encoding.
        return hmac.compare_digest(
            provided.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )
def _has_seen_receipt(self, receipt_key: str) -> bool:
return receipt_key in self._seen_receipts

View file

@ -19,9 +19,10 @@ def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter:
class _FakeRequest:
def __init__(self, *, query=None, json_payload=None):
def __init__(self, *, query=None, json_payload=None, remote="127.0.0.1"):
self.query = query or {}
self._json_payload = json_payload
self.remote = remote
async def json(self):
if isinstance(self._json_payload, Exception):
@ -70,14 +71,31 @@ class TestMSGraphWebhookConfig:
class TestMSGraphValidationHandshake:
@pytest.mark.anyio
async def test_validation_token_echo(self):
async def test_validation_token_echo_on_get(self):
adapter = _make_adapter()
resp = await adapter._handle_validation(
_FakeRequest(query={"validationToken": "abc123"})
)
assert resp.status == 200
assert resp.text == "abc123"
assert resp.content_type == "text/plain"
@pytest.mark.anyio
async def test_bare_get_without_validation_token_rejected(self):
    """A GET whose query has no validationToken must be answered with 400."""
    tokenless_request = _FakeRequest()
    response = await _make_adapter()._handle_validation(tokenless_request)
    assert response.status == 400
@pytest.mark.anyio
async def test_post_with_validation_token_still_echoes(self):
    """The POST handler echoes a validationToken found in the query string."""
    token = "abc123"
    request = _FakeRequest(query={"validationToken": token})
    response = await _make_adapter()._handle_notification(request)
    observed = (response.status, response.text, response.content_type)
    assert observed == (200, token, "text/plain")
class TestMSGraphNotifications:
@ -104,12 +122,10 @@ class TestMSGraphNotifications:
}
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
# Success is 202 with empty body: internal counters must not leak to
# the wire. Counters are still observable via /health.
assert resp.status == 202
data = json.loads(resp.text)
assert data["accepted"] == 1
assert data["duplicates"] == 0
assert data["rejected"] == 0
assert data["scheduled"] == 1
assert resp.body is None or not resp.body
await asyncio.sleep(0.05)
@ -121,7 +137,8 @@ class TestMSGraphNotifications:
assert event.message_id == "id:notif-1"
@pytest.mark.anyio
async def test_bad_client_state_rejected(self):
async def test_bad_client_state_rejected_as_auth_failure(self):
"""Every-item-bad-clientState batches return 403 so forged POSTs stop retrying."""
adapter = _make_adapter()
scheduled: list[tuple[dict, object]] = []
@ -143,15 +160,46 @@ class TestMSGraphNotifications:
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
assert resp.status == 403
data = json.loads(resp.text)
assert data["accepted"] == 0
assert data["duplicates"] == 0
assert data["rejected"] == 1
await asyncio.sleep(0.05)
assert scheduled == []
@pytest.mark.anyio
async def test_client_state_compare_is_timing_safe(self, monkeypatch):
    """The clientState check must go through ``hmac.compare_digest``.

    A recording wrapper replaces ``compare_digest`` in the adapter module.
    The test then asserts that the wrapper was invoked and that its first
    call received the provided and expected clientState values.
    """
    import hmac

    original_compare = hmac.compare_digest
    recorded: list[tuple[str, str]] = []

    def _recording_compare(left, right):
        recorded.append((left, right))
        return original_compare(left, right)

    monkeypatch.setattr(
        "gateway.platforms.msgraph_webhook.hmac.compare_digest", _recording_compare
    )

    adapter = _make_adapter()
    notification = {
        "id": "notif-timing",
        "subscriptionId": "sub-1",
        "changeType": "updated",
        "resource": "communications/onlineMeetings/meeting-x",
        "clientState": "expected-client-state",
    }
    request = _FakeRequest(json_payload={"value": [notification]})
    await adapter._handle_notification(request)

    assert recorded, "hmac.compare_digest was never called; clientState check is not timing-safe"
    first_provided, first_expected = recorded[0]
    assert first_provided == "expected-client-state"
    assert first_expected == "expected-client-state"
@pytest.mark.anyio
async def test_duplicate_notification_deduped(self):
adapter = _make_adapter()
@ -176,11 +224,9 @@ class TestMSGraphNotifications:
first = await adapter._handle_notification(_FakeRequest(json_payload=payload))
assert first.status == 202
second = await adapter._handle_notification(_FakeRequest(json_payload=payload))
# Duplicate-only batch still returns 202 so Graph stops retrying.
assert second.status == 202
second_data = json.loads(second.text)
assert second_data["accepted"] == 0
assert second_data["duplicates"] == 1
assert second_data["scheduled"] == 0
assert adapter._duplicate_count == 1
await asyncio.sleep(0.05)
@ -212,10 +258,6 @@ class TestMSGraphNotifications:
assert first.status == 202
assert second.status == 202
second_data = json.loads(second.text)
assert second_data["accepted"] == 1
assert second_data["duplicates"] == 0
assert second_data["scheduled"] == 1
await asyncio.sleep(0.05)
@ -237,11 +279,39 @@ class TestMSGraphNotifications:
}
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
data = json.loads(resp.text)
assert resp.status == 202
assert data["accepted"] == 1
assert data["rejected"] == 0
@pytest.mark.anyio
async def test_resource_not_in_allowlist_returns_400(self):
    """A batch rejected only for a non-accepted resource is answered with 400."""
    notification = {
        "id": "notif-bad-resource",
        "resource": "users/u1/messages",
        "clientState": "expected-client-state",
    }
    request = _FakeRequest(json_payload={"value": [notification]})
    adapter = _make_adapter(accepted_resources=["communications/onlineMeetings"])
    response = await adapter._handle_notification(request)
    assert response.status == 400
@pytest.mark.anyio
async def test_malformed_body_returns_400(self):
    """A body that fails JSON decoding is answered with 400."""
    broken_request = _FakeRequest(json_payload=ValueError("bad json"))
    response = await _make_adapter()._handle_notification(broken_request)
    assert response.status == 400
@pytest.mark.anyio
async def test_missing_value_array_returns_400(self):
    """A JSON body without a ``value`` list is answered with 400."""
    request = _FakeRequest(json_payload={"not_value": []})
    response = await _make_adapter()._handle_notification(request)
    assert response.status == 400
@pytest.mark.anyio
async def test_seen_receipts_are_bounded(self):
@ -277,6 +347,84 @@ class TestMSGraphNotifications:
assert list(adapter._seen_receipt_order) == ["id:notif-b", "id:notif-c"]
replay = await _post("notif-a")
replay_data = json.loads(replay.text)
assert replay_data["accepted"] == 1
assert replay_data["duplicates"] == 0
# notif-a evicted from the bounded cache, so it's accepted again (202)
# rather than treated as a duplicate.
assert replay.status == 202
assert adapter._accepted_count == 4
class TestMSGraphSourceIPAllowlist:
    """Exercises the optional ``allowed_source_cidrs`` source-IP allowlist."""

    @staticmethod
    def _single_notification(notification_id):
        """Build a one-item batch with the resource and clientState used here."""
        return {
            "value": [
                {
                    "id": notification_id,
                    "resource": "communications/onlineMeetings/m",
                    "clientState": "expected-client-state",
                }
            ]
        }

    @pytest.mark.anyio
    async def test_disabled_by_default_allows_all(self):
        """With no allowlist configured, a POST from any address is accepted."""
        adapter = _make_adapter()  # allowed_source_cidrs left unset
        request = _FakeRequest(
            json_payload=self._single_notification("notif-ip"),
            remote="203.0.113.99",
        )
        response = await adapter._handle_notification(request)
        assert response.status == 202

    @pytest.mark.anyio
    async def test_post_from_disallowed_ip_rejected(self):
        """A POST from outside every configured CIDR is answered with 403."""
        adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"])
        request = _FakeRequest(
            json_payload=self._single_notification("notif-ip-bad"),
            remote="203.0.113.99",
        )
        response = await adapter._handle_notification(request)
        assert response.status == 403

    @pytest.mark.anyio
    async def test_post_from_allowed_ip_accepted(self):
        """A POST from inside one of the configured CIDRs is accepted."""
        adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8", "203.0.113.0/24"])
        request = _FakeRequest(
            json_payload=self._single_notification("notif-ip-ok"),
            remote="203.0.113.5",
        )
        response = await adapter._handle_notification(request)
        assert response.status == 202

    @pytest.mark.anyio
    async def test_validation_handshake_also_respects_allowlist(self):
        """The handshake handler returns 403 to a disallowed source address."""
        adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"])
        probe = _FakeRequest(query={"validationToken": "probe"}, remote="203.0.113.99")
        response = await adapter._handle_validation(probe)
        assert response.status == 403

    @pytest.mark.anyio
    async def test_invalid_cidr_entries_are_ignored_at_init(self):
        """Malformed and blank CIDR entries are dropped instead of crashing init."""
        configured = ["10.0.0.0/8", "not-a-cidr", "", "203.0.113.0/24"]
        adapter = _make_adapter(allowed_source_cidrs=configured)
        assert len(adapter._allowed_source_networks) == 2

    @pytest.mark.anyio
    async def test_cidr_list_accepts_comma_string(self):
        """An env-var-style ``"cidr1, cidr2"`` string is parsed as two entries."""
        adapter = _make_adapter(allowed_source_cidrs="10.0.0.0/8, 203.0.113.0/24")
        assert len(adapter._allowed_source_networks) == 2