mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
fix(gateway): preserve WhatsApp pairing approvals across JID/LID alias flips
This commit is contained in:
parent
3127a41cb1
commit
52a368fa72
2 changed files with 136 additions and 10 deletions
|
|
@ -28,6 +28,10 @@ import time
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from gateway.whatsapp_identity import (
|
||||
expand_whatsapp_aliases,
|
||||
normalize_whatsapp_identifier,
|
||||
)
|
||||
from hermes_constants import get_hermes_dir
|
||||
from utils import atomic_replace
|
||||
|
||||
|
|
@ -110,12 +114,40 @@ class PairingStore:
|
|||
def _save_json(self, path: Path, data: dict) -> None:
|
||||
_secure_write(path, json.dumps(data, indent=2, ensure_ascii=False))
|
||||
|
||||
def _normalize_user_id(self, platform: str, user_id: str) -> str:
|
||||
"""Normalize platform-specific user IDs before persisting them."""
|
||||
raw_user_id = str(user_id or "").strip()
|
||||
if platform == "whatsapp":
|
||||
return normalize_whatsapp_identifier(raw_user_id) or raw_user_id
|
||||
return raw_user_id
|
||||
|
||||
def _user_id_aliases(self, platform: str, user_id: str) -> set[str]:
|
||||
"""Return all known equivalent user IDs for auth/rate-limit checks."""
|
||||
raw_user_id = str(user_id or "").strip()
|
||||
if not raw_user_id:
|
||||
return set()
|
||||
|
||||
aliases = {raw_user_id, self._normalize_user_id(platform, raw_user_id)}
|
||||
if platform == "whatsapp":
|
||||
aliases.update(expand_whatsapp_aliases(raw_user_id))
|
||||
aliases.discard("")
|
||||
return aliases
|
||||
|
||||
def _user_ids_match(self, platform: str, left: str, right: str) -> bool:
|
||||
"""Return True when two user IDs represent the same principal."""
|
||||
left_aliases = self._user_id_aliases(platform, left)
|
||||
right_aliases = self._user_id_aliases(platform, right)
|
||||
return bool(left_aliases and right_aliases and (left_aliases & right_aliases))
|
||||
|
||||
# ----- Approved users -----
|
||||
|
||||
def is_approved(self, platform: str, user_id: str) -> bool:
|
||||
"""Check if a user is approved (paired) on a platform."""
|
||||
approved = self._load_json(self._approved_path(platform))
|
||||
return user_id in approved
|
||||
for approved_user_id in approved:
|
||||
if self._user_ids_match(platform, approved_user_id, user_id):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_approved(self, platform: str = None) -> list:
|
||||
"""List approved users, optionally filtered by platform."""
|
||||
|
|
@ -130,7 +162,16 @@ class PairingStore:
|
|||
def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None:
|
||||
"""Add a user to the approved list. Must be called under self._lock."""
|
||||
approved = self._load_json(self._approved_path(platform))
|
||||
approved[user_id] = {
|
||||
normalized_user_id = self._normalize_user_id(platform, user_id)
|
||||
duplicate_ids = [
|
||||
approved_user_id
|
||||
for approved_user_id in approved
|
||||
if self._user_ids_match(platform, approved_user_id, normalized_user_id)
|
||||
]
|
||||
for approved_user_id in duplicate_ids:
|
||||
del approved[approved_user_id]
|
||||
|
||||
approved[normalized_user_id] = {
|
||||
"user_name": user_name,
|
||||
"approved_at": time.time(),
|
||||
}
|
||||
|
|
@ -141,8 +182,14 @@ class PairingStore:
|
|||
path = self._approved_path(platform)
|
||||
with self._lock:
|
||||
approved = self._load_json(path)
|
||||
if user_id in approved:
|
||||
del approved[user_id]
|
||||
matching_ids = [
|
||||
approved_user_id
|
||||
for approved_user_id in approved
|
||||
if self._user_ids_match(platform, approved_user_id, user_id)
|
||||
]
|
||||
if matching_ids:
|
||||
for approved_user_id in matching_ids:
|
||||
del approved[approved_user_id]
|
||||
self._save_json(path, approved)
|
||||
return True
|
||||
return False
|
||||
|
|
@ -170,6 +217,7 @@ class PairingStore:
|
|||
"""
|
||||
with self._lock:
|
||||
self._cleanup_expired(platform)
|
||||
normalized_user_id = self._normalize_user_id(platform, user_id)
|
||||
|
||||
# Check lockout
|
||||
if self._is_locked_out(platform):
|
||||
|
|
@ -198,7 +246,7 @@ class PairingStore:
|
|||
pending[entry_id] = {
|
||||
"hash": code_hash,
|
||||
"salt": salt.hex(),
|
||||
"user_id": user_id,
|
||||
"user_id": normalized_user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
|
@ -326,15 +374,20 @@ class PairingStore:
|
|||
def _is_rate_limited(self, platform: str, user_id: str) -> bool:
|
||||
"""Check if a user has requested a code too recently."""
|
||||
limits = self._load_json(self._rate_limit_path())
|
||||
key = f"{platform}:{user_id}"
|
||||
last_request = limits.get(key, 0)
|
||||
return (time.time() - last_request) < RATE_LIMIT_SECONDS
|
||||
for alias in self._user_id_aliases(platform, user_id):
|
||||
key = f"{platform}:{alias}"
|
||||
last_request = limits.get(key, 0)
|
||||
if (time.time() - last_request) < RATE_LIMIT_SECONDS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _record_rate_limit(self, platform: str, user_id: str) -> None:
|
||||
"""Record the time of a pairing request for rate limiting."""
|
||||
limits = self._load_json(self._rate_limit_path())
|
||||
key = f"{platform}:{user_id}"
|
||||
limits[key] = time.time()
|
||||
now = time.time()
|
||||
for alias in self._user_id_aliases(platform, user_id):
|
||||
key = f"{platform}:{alias}"
|
||||
limits[key] = now
|
||||
self._save_json(self._rate_limit_path(), limits)
|
||||
|
||||
def _is_locked_out(self, platform: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -2,10 +2,13 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.pairing import (
|
||||
PairingStore,
|
||||
ALPHABET,
|
||||
|
|
@ -37,6 +40,10 @@ class TestSecureWrite:
|
|||
assert target.exists()
|
||||
assert json.loads(target.read_text()) == {"hello": "world"}
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"),
|
||||
reason="POSIX file modes are not enforced on Windows",
|
||||
)
|
||||
def test_sets_file_permissions(self, tmp_path):
|
||||
target = tmp_path / "secret.json"
|
||||
_secure_write(target, "data")
|
||||
|
|
@ -305,6 +312,23 @@ class TestRateLimiting:
|
|||
assert isinstance(code2, str) and len(code2) == CODE_LENGTH
|
||||
assert code2 != code1
|
||||
|
||||
def test_whatsapp_alias_flip_hits_same_rate_limit(self, tmp_path, monkeypatch):
|
||||
mapping_dir = tmp_path / "whatsapp" / "session"
|
||||
mapping_dir.mkdir(parents=True, exist_ok=True)
|
||||
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
|
||||
json.dumps("15551234567@s.whatsapp.net"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code1 = store.generate_code("whatsapp", "15551234567@s.whatsapp.net")
|
||||
code2 = store.generate_code("whatsapp", "999999999999999@lid")
|
||||
|
||||
assert isinstance(code1, str) and len(code1) == CODE_LENGTH
|
||||
assert code2 is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Max pending limit
|
||||
|
|
@ -397,6 +421,55 @@ class TestApprovalFlow:
|
|||
result = store.approve_code("telegram", "INVALIDCODE")
|
||||
assert result is None
|
||||
|
||||
def test_whatsapp_approved_user_survives_alias_flip(self, tmp_path, monkeypatch):
|
||||
mapping_dir = tmp_path / "whatsapp" / "session"
|
||||
mapping_dir.mkdir(parents=True, exist_ok=True)
|
||||
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
|
||||
json.dumps("15551234567@s.whatsapp.net"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("whatsapp", "15551234567@s.whatsapp.net", "Alice")
|
||||
store.approve_code("whatsapp", code)
|
||||
|
||||
assert store.is_approved("whatsapp", "15551234567@s.whatsapp.net") is True
|
||||
assert store.is_approved("whatsapp", "999999999999999@lid") is True
|
||||
|
||||
approved = store.list_approved("whatsapp")
|
||||
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["user_id"] == "15551234567"
|
||||
|
||||
def test_whatsapp_legacy_raw_jid_approval_survives_alias_flip(self, tmp_path, monkeypatch):
|
||||
mapping_dir = tmp_path / "whatsapp" / "session"
|
||||
mapping_dir.mkdir(parents=True, exist_ok=True)
|
||||
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
|
||||
json.dumps("15551234567@s.whatsapp.net"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
approved_path = tmp_path / "whatsapp-approved.json"
|
||||
approved_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"15551234567@s.whatsapp.net": {
|
||||
"user_name": "Legacy Alice",
|
||||
"approved_at": time.time(),
|
||||
}
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
assert store.is_approved("whatsapp", "999999999999999@lid") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lockout after failed attempts
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue