From 52a368fa722f91337786739181eb653f779d4b86 Mon Sep 17 00:00:00 2001 From: QuenVix Date: Sat, 23 May 2026 10:25:35 +0300 Subject: [PATCH] fix(gateway): preserve WhatsApp pairing approvals across JID/LID alias flips --- gateway/pairing.py | 73 ++++++++++++++++++++++++++++++----- tests/gateway/test_pairing.py | 73 +++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 10 deletions(-) diff --git a/gateway/pairing.py b/gateway/pairing.py index 96949eba14c..b8bfe46a9a8 100644 --- a/gateway/pairing.py +++ b/gateway/pairing.py @@ -28,6 +28,10 @@ import time from pathlib import Path from typing import Optional +from gateway.whatsapp_identity import ( + expand_whatsapp_aliases, + normalize_whatsapp_identifier, +) from hermes_constants import get_hermes_dir from utils import atomic_replace @@ -110,12 +114,40 @@ class PairingStore: def _save_json(self, path: Path, data: dict) -> None: _secure_write(path, json.dumps(data, indent=2, ensure_ascii=False)) + def _normalize_user_id(self, platform: str, user_id: str) -> str: + """Normalize platform-specific user IDs before persisting them.""" + raw_user_id = str(user_id or "").strip() + if platform == "whatsapp": + return normalize_whatsapp_identifier(raw_user_id) or raw_user_id + return raw_user_id + + def _user_id_aliases(self, platform: str, user_id: str) -> set[str]: + """Return all known equivalent user IDs for auth/rate-limit checks.""" + raw_user_id = str(user_id or "").strip() + if not raw_user_id: + return set() + + aliases = {raw_user_id, self._normalize_user_id(platform, raw_user_id)} + if platform == "whatsapp": + aliases.update(expand_whatsapp_aliases(raw_user_id)) + aliases.discard("") + return aliases + + def _user_ids_match(self, platform: str, left: str, right: str) -> bool: + """Return True when two user IDs represent the same principal.""" + left_aliases = self._user_id_aliases(platform, left) + right_aliases = self._user_id_aliases(platform, right) + return bool(left_aliases and right_aliases and (left_aliases & right_aliases)) + # ----- Approved users ----- def is_approved(self, platform: str, user_id: str) -> bool: """Check if a user is approved (paired) on a platform.""" approved = self._load_json(self._approved_path(platform)) - return user_id in approved + for approved_user_id in approved: + if self._user_ids_match(platform, approved_user_id, user_id): + return True + return False def list_approved(self, platform: str = None) -> list: """List approved users, optionally filtered by platform.""" @@ -130,7 +162,16 @@ class PairingStore: def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None: """Add a user to the approved list. Must be called under self._lock.""" approved = self._load_json(self._approved_path(platform)) - approved[user_id] = { + normalized_user_id = self._normalize_user_id(platform, user_id) + duplicate_ids = [ + approved_user_id + for approved_user_id in approved + if self._user_ids_match(platform, approved_user_id, normalized_user_id) + ] + for approved_user_id in duplicate_ids: + del approved[approved_user_id] + + approved[normalized_user_id] = { "user_name": user_name, "approved_at": time.time(), } @@ -141,8 +182,14 @@ class PairingStore: path = self._approved_path(platform) with self._lock: approved = self._load_json(path) - if user_id in approved: - del approved[user_id] + matching_ids = [ + approved_user_id + for approved_user_id in approved + if self._user_ids_match(platform, approved_user_id, user_id) + ] + if matching_ids: + for approved_user_id in matching_ids: + del approved[approved_user_id] self._save_json(path, approved) return True return False @@ -170,6 +217,7 @@ class PairingStore: """ with self._lock: self._cleanup_expired(platform) + normalized_user_id = self._normalize_user_id(platform, user_id) # Check lockout if self._is_locked_out(platform): @@ -198,7 +246,7 @@ class PairingStore: pending[entry_id] = { "hash": code_hash, "salt": salt.hex(), - "user_id": user_id, + "user_id": normalized_user_id, "user_name": user_name, "created_at": time.time(), } @@ -326,15 +374,20 @@ class PairingStore: def _is_rate_limited(self, platform: str, user_id: str) -> bool: """Check if a user has requested a code too recently.""" limits = self._load_json(self._rate_limit_path()) - key = f"{platform}:{user_id}" - last_request = limits.get(key, 0) - return (time.time() - last_request) < RATE_LIMIT_SECONDS + for alias in self._user_id_aliases(platform, user_id): + key = f"{platform}:{alias}" + last_request = limits.get(key, 0) + if (time.time() - last_request) < RATE_LIMIT_SECONDS: + return True + return False def _record_rate_limit(self, platform: str, user_id: str) -> None: """Record the time of a pairing request for rate limiting.""" limits = self._load_json(self._rate_limit_path()) - key = f"{platform}:{user_id}" - limits[key] = time.time() + now = time.time() + for alias in self._user_id_aliases(platform, user_id): + key = f"{platform}:{alias}" + limits[key] = now self._save_json(self._rate_limit_path(), limits) def _is_locked_out(self, platform: str) -> bool: diff --git a/tests/gateway/test_pairing.py b/tests/gateway/test_pairing.py index ca58e2d8269..0bff131ed1a 100644 --- a/tests/gateway/test_pairing.py +++ b/tests/gateway/test_pairing.py @@ -2,10 +2,13 @@ import json import os +import sys import time from pathlib import Path from unittest.mock import patch +import pytest + from gateway.pairing import ( PairingStore, ALPHABET, @@ -37,6 +40,10 @@ class TestSecureWrite: assert target.exists() assert json.loads(target.read_text()) == {"hello": "world"} + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="POSIX file modes are not enforced on Windows", + ) def test_sets_file_permissions(self, tmp_path): target = tmp_path / "secret.json" _secure_write(target, "data") @@ -305,6 +312,23 @@ class TestRateLimiting: assert isinstance(code2, str) and len(code2) == CODE_LENGTH assert code2 != code1 + def test_whatsapp_alias_flip_hits_same_rate_limit(self, tmp_path, monkeypatch): + mapping_dir = tmp_path / "whatsapp" / "session" + mapping_dir.mkdir(parents=True, exist_ok=True) + (mapping_dir / "lid-mapping-999999999999999.json").write_text( + json.dumps("15551234567@s.whatsapp.net"), + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code1 = store.generate_code("whatsapp", "15551234567@s.whatsapp.net") + code2 = store.generate_code("whatsapp", "999999999999999@lid") + + assert isinstance(code1, str) and len(code1) == CODE_LENGTH + assert code2 is None + # --------------------------------------------------------------------------- # Max pending limit @@ -397,6 +421,55 @@ class TestApprovalFlow: result = store.approve_code("telegram", "INVALIDCODE") assert result is None + def test_whatsapp_approved_user_survives_alias_flip(self, tmp_path, monkeypatch): + mapping_dir = tmp_path / "whatsapp" / "session" + mapping_dir.mkdir(parents=True, exist_ok=True) + (mapping_dir / "lid-mapping-999999999999999.json").write_text( + json.dumps("15551234567@s.whatsapp.net"), + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("whatsapp", "15551234567@s.whatsapp.net", "Alice") + store.approve_code("whatsapp", code) + + assert store.is_approved("whatsapp", "15551234567@s.whatsapp.net") is True + assert store.is_approved("whatsapp", "999999999999999@lid") is True + + approved = store.list_approved("whatsapp") + + assert len(approved) == 1 + assert approved[0]["user_id"] == "15551234567" + + def test_whatsapp_legacy_raw_jid_approval_survives_alias_flip(self, tmp_path, monkeypatch): + mapping_dir = tmp_path / "whatsapp" / "session" + mapping_dir.mkdir(parents=True, exist_ok=True) + (mapping_dir / "lid-mapping-999999999999999.json").write_text( + json.dumps("15551234567@s.whatsapp.net"), + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + approved_path = tmp_path / "whatsapp-approved.json" + approved_path.write_text( + json.dumps( + { + "15551234567@s.whatsapp.net": { + "user_name": "Legacy Alice", + "approved_at": time.time(), + } + }, + indent=2, + ), + encoding="utf-8", + ) + + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + assert store.is_approved("whatsapp", "999999999999999@lid") is True + # --------------------------------------------------------------------------- # Lockout after failed attempts