fix(msgraph_webhook): harden auth surface + IP allowlisting + response hygiene

Defense-in-depth polish on top of the webhook listener before it becomes
a real attack surface once the pipeline starts creating subscriptions
and Graph starts POSTing to the configured public URL.

- Timing-safe clientState comparison. Previously used `==` on strings;
  switches to hmac.compare_digest so a mismatch does not leak how many
  leading characters matched. client_state is documented as a strong
  shared secret (openssl rand -hex 32 in the setup docs), so a
  timing-safe primitive is the right call.

- Split GET and POST handlers. Graph validates a subscription by sending
  GET with validationToken in the query; anything else on GET is now a
  400 so the endpoint cannot be probed or mistakenly used for data
  exfil. Previously a bare GET fell through to the POST path and blew
  up on request.json() with a confusing 400.

- Empty response bodies on success. 202 is returned with no body so
  internal counters (accepted / duplicates / scheduled) do not leak to
  any caller that can reach the endpoint; counters remain observable
  via /health for operators. 403 on every-item-bad-clientState batches
  (so forged POSTs stop retrying), 400 on malformed / unknown-resource
  batches (sender configuration issue).

- Optional source-IP allowlist. New `allowed_source_cidrs` extra field
  (list or comma-separated string) and `MSGRAPH_WEBHOOK_ALLOWED_SOURCE_CIDRS`
  env var let operators restrict the webhook to Microsoft Graph's
  published webhook source ranges in production. Empty = allow all,
  preserving dev-tunnel / localhost workflows. Invalid CIDRs are
  logged and ignored rather than crashing. Also gates the handshake
  endpoint so disallowed IPs cannot probe it.

- Tests updated for the new response contract (empty-body 202,
  auth-only 403, config-error 400) and extended to cover: bare GET
  rejection, POST-with-validationToken handshake tolerance,
  timing-safe compare actually invoked via hmac.compare_digest spy,
  malformed body / missing value array, IP allowlist accept/reject
  paths, handshake IP allowlist, invalid CIDR entries, comma-string
  CIDR list parsing. 52/52 passed (was 40).

Full gateway suite: 5049 passed / 1 pre-existing failure in
test_discord_free_response (unrelated, reproduces on clean origin/main).
This commit is contained in:
Teknium 2026-05-08 09:38:51 -07:00
parent 26a59e4f6c
commit b8d7e0e6d3
3 changed files with 301 additions and 49 deletions

View file

@ -19,9 +19,10 @@ def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter:
class _FakeRequest:
def __init__(self, *, query=None, json_payload=None):
def __init__(self, *, query=None, json_payload=None, remote="127.0.0.1"):
self.query = query or {}
self._json_payload = json_payload
self.remote = remote
async def json(self):
if isinstance(self._json_payload, Exception):
@ -70,14 +71,31 @@ class TestMSGraphWebhookConfig:
class TestMSGraphValidationHandshake:
@pytest.mark.anyio
async def test_validation_token_echo(self):
async def test_validation_token_echo_on_get(self):
adapter = _make_adapter()
resp = await adapter._handle_validation(
_FakeRequest(query={"validationToken": "abc123"})
)
assert resp.status == 200
assert resp.text == "abc123"
assert resp.content_type == "text/plain"
@pytest.mark.anyio
async def test_bare_get_without_validation_token_rejected(self):
"""GET without validationToken is 400 so the endpoint can't be enumerated."""
adapter = _make_adapter()
resp = await adapter._handle_validation(_FakeRequest())
assert resp.status == 400
@pytest.mark.anyio
async def test_post_with_validation_token_still_echoes(self):
"""Tolerate defensive clients that send validationToken on POST."""
adapter = _make_adapter()
resp = await adapter._handle_notification(
_FakeRequest(query={"validationToken": "abc123"})
)
assert resp.status == 200
assert resp.text == "abc123"
assert resp.content_type == "text/plain"
class TestMSGraphNotifications:
@ -104,12 +122,10 @@ class TestMSGraphNotifications:
}
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
# Success is 202 with empty body: internal counters must not leak to
# the wire. Counters are still observable via /health.
assert resp.status == 202
data = json.loads(resp.text)
assert data["accepted"] == 1
assert data["duplicates"] == 0
assert data["rejected"] == 0
assert data["scheduled"] == 1
assert resp.body is None or not resp.body
await asyncio.sleep(0.05)
@ -121,7 +137,8 @@ class TestMSGraphNotifications:
assert event.message_id == "id:notif-1"
@pytest.mark.anyio
async def test_bad_client_state_rejected(self):
async def test_bad_client_state_rejected_as_auth_failure(self):
"""Every-item-bad-clientState batches return 403 so forged POSTs stop retrying."""
adapter = _make_adapter()
scheduled: list[tuple[dict, object]] = []
@ -143,15 +160,46 @@ class TestMSGraphNotifications:
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
assert resp.status == 403
data = json.loads(resp.text)
assert data["accepted"] == 0
assert data["duplicates"] == 0
assert data["rejected"] == 1
await asyncio.sleep(0.05)
assert scheduled == []
@pytest.mark.anyio
async def test_client_state_compare_is_timing_safe(self, monkeypatch):
"""Ensure hmac.compare_digest is used for clientState comparison."""
import hmac
calls: list[tuple[str, str]] = []
real_compare = hmac.compare_digest
def _spy(a, b):
calls.append((a, b))
return real_compare(a, b)
monkeypatch.setattr(
"gateway.platforms.msgraph_webhook.hmac.compare_digest", _spy
)
adapter = _make_adapter()
payload = {
"value": [
{
"id": "notif-timing",
"subscriptionId": "sub-1",
"changeType": "updated",
"resource": "communications/onlineMeetings/meeting-x",
"clientState": "expected-client-state",
}
]
}
await adapter._handle_notification(_FakeRequest(json_payload=payload))
assert calls, "hmac.compare_digest was never called; clientState check is not timing-safe"
provided, expected = calls[0]
assert provided == "expected-client-state"
assert expected == "expected-client-state"
@pytest.mark.anyio
async def test_duplicate_notification_deduped(self):
adapter = _make_adapter()
@ -176,11 +224,9 @@ class TestMSGraphNotifications:
first = await adapter._handle_notification(_FakeRequest(json_payload=payload))
assert first.status == 202
second = await adapter._handle_notification(_FakeRequest(json_payload=payload))
# Duplicate-only batch still returns 202 so Graph stops retrying.
assert second.status == 202
second_data = json.loads(second.text)
assert second_data["accepted"] == 0
assert second_data["duplicates"] == 1
assert second_data["scheduled"] == 0
assert adapter._duplicate_count == 1
await asyncio.sleep(0.05)
@ -212,10 +258,6 @@ class TestMSGraphNotifications:
assert first.status == 202
assert second.status == 202
second_data = json.loads(second.text)
assert second_data["accepted"] == 1
assert second_data["duplicates"] == 0
assert second_data["scheduled"] == 1
await asyncio.sleep(0.05)
@ -237,11 +279,39 @@ class TestMSGraphNotifications:
}
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
data = json.loads(resp.text)
assert resp.status == 202
assert data["accepted"] == 1
assert data["rejected"] == 0
@pytest.mark.anyio
async def test_resource_not_in_allowlist_returns_400(self):
"""Every-item-rejected-for-non-auth returns 400 (configuration issue)."""
adapter = _make_adapter(accepted_resources=["communications/onlineMeetings"])
payload = {
"value": [
{
"id": "notif-bad-resource",
"resource": "users/u1/messages",
"clientState": "expected-client-state",
}
]
}
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
assert resp.status == 400
@pytest.mark.anyio
async def test_malformed_body_returns_400(self):
adapter = _make_adapter()
resp = await adapter._handle_notification(
_FakeRequest(json_payload=ValueError("bad json"))
)
assert resp.status == 400
@pytest.mark.anyio
async def test_missing_value_array_returns_400(self):
adapter = _make_adapter()
resp = await adapter._handle_notification(
_FakeRequest(json_payload={"not_value": []})
)
assert resp.status == 400
@pytest.mark.anyio
async def test_seen_receipts_are_bounded(self):
@ -277,6 +347,84 @@ class TestMSGraphNotifications:
assert list(adapter._seen_receipt_order) == ["id:notif-b", "id:notif-c"]
replay = await _post("notif-a")
replay_data = json.loads(replay.text)
assert replay_data["accepted"] == 1
assert replay_data["duplicates"] == 0
# notif-a evicted from the bounded cache, so it's accepted again (202)
# rather than treated as a duplicate.
assert replay.status == 202
assert adapter._accepted_count == 4
class TestMSGraphSourceIPAllowlist:
@pytest.mark.anyio
async def test_disabled_by_default_allows_all(self):
"""Empty allowlist preserves pre-existing behavior (dev tunnels, localhost)."""
adapter = _make_adapter() # no allowed_source_cidrs set
payload = {
"value": [
{
"id": "notif-ip",
"resource": "communications/onlineMeetings/m",
"clientState": "expected-client-state",
}
]
}
resp = await adapter._handle_notification(
_FakeRequest(json_payload=payload, remote="203.0.113.99")
)
assert resp.status == 202
@pytest.mark.anyio
async def test_post_from_disallowed_ip_rejected(self):
adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"])
payload = {
"value": [
{
"id": "notif-ip-bad",
"resource": "communications/onlineMeetings/m",
"clientState": "expected-client-state",
}
]
}
resp = await adapter._handle_notification(
_FakeRequest(json_payload=payload, remote="203.0.113.99")
)
assert resp.status == 403
@pytest.mark.anyio
async def test_post_from_allowed_ip_accepted(self):
adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8", "203.0.113.0/24"])
payload = {
"value": [
{
"id": "notif-ip-ok",
"resource": "communications/onlineMeetings/m",
"clientState": "expected-client-state",
}
]
}
resp = await adapter._handle_notification(
_FakeRequest(json_payload=payload, remote="203.0.113.5")
)
assert resp.status == 202
@pytest.mark.anyio
async def test_validation_handshake_also_respects_allowlist(self):
"""A disallowed IP shouldn't be able to probe the handshake endpoint."""
adapter = _make_adapter(allowed_source_cidrs=["10.0.0.0/8"])
resp = await adapter._handle_validation(
_FakeRequest(query={"validationToken": "probe"}, remote="203.0.113.99")
)
assert resp.status == 403
@pytest.mark.anyio
async def test_invalid_cidr_entries_are_ignored_at_init(self):
"""Malformed CIDR strings should log a warning and be ignored, not crash."""
adapter = _make_adapter(
allowed_source_cidrs=["10.0.0.0/8", "not-a-cidr", "", "203.0.113.0/24"]
)
assert len(adapter._allowed_source_networks) == 2
@pytest.mark.anyio
async def test_cidr_list_accepts_comma_string(self):
"""Env-var-style 'cidr1, cidr2' strings parse as a list."""
adapter = _make_adapter(allowed_source_cidrs="10.0.0.0/8, 203.0.113.0/24")
assert len(adapter._allowed_source_networks) == 2