From e9b5864b3f059d111b031f3506db739d81122315 Mon Sep 17 00:00:00 2001 From: charliekerfoot Date: Mon, 6 Apr 2026 18:13:54 -0500 Subject: [PATCH] fix: multiple platform adaptors concurrency --- gateway/pairing.py | 133 +++++++++++++++++++++++++++------------------ 1 file changed, 79 insertions(+), 54 deletions(-) diff --git a/gateway/pairing.py b/gateway/pairing.py index 34b3d9023..09b61fef2 100644 --- a/gateway/pairing.py +++ b/gateway/pairing.py @@ -21,6 +21,8 @@ Storage: ~/.hermes/pairing/ import json import os import secrets +import tempfile +import threading import time from pathlib import Path from typing import Optional @@ -45,13 +47,29 @@ PAIRING_DIR = get_hermes_dir("platforms/pairing", "pairing") def _secure_write(path: Path, data: str) -> None: - """Write data to file with restrictive permissions (owner read/write only).""" + """Write data to file with restrictive permissions (owner read/write only). + + Uses a temp-file + atomic rename so readers always see either the old + complete file or the new one — never a partial write. + """ path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(data, encoding="utf-8") + fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp") try: - os.chmod(path, 0o600) - except OSError: - pass # Windows doesn't support chmod the same way + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, str(path)) + try: + os.chmod(path, 0o600) + except OSError: + pass # Windows doesn't support chmod the same way + except BaseException: + try: + os.unlink(tmp_path) + except OSError: + pass + raise class PairingStore: @@ -66,6 +84,9 @@ class PairingStore: def __init__(self): PAIRING_DIR.mkdir(parents=True, exist_ok=True) + # Protects all read-modify-write cycles. The gateway runs multiple + # platform adapters concurrently in threads sharing one PairingStore. + self._lock = threading.RLock() def _pending_path(self, platform: str) -> Path: return PAIRING_DIR / f"{platform}-pending.json" @@ -105,7 +126,7 @@ class PairingStore: return results def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None: - """Add a user to the approved list.""" + """Add a user to the approved list. Must be called under self._lock.""" approved = self._load_json(self._approved_path(platform)) approved[user_id] = { "user_name": user_name, @@ -116,11 +137,12 @@ class PairingStore: def revoke(self, platform: str, user_id: str) -> bool: """Remove a user from the approved list. Returns True if found.""" path = self._approved_path(platform) - approved = self._load_json(path) - if user_id in approved: - del approved[user_id] - self._save_json(path, approved) - return True + with self._lock: + approved = self._load_json(path) + if user_id in approved: + del approved[user_id] + self._save_json(path, approved) + return True return False # ----- Pending codes ----- @@ -136,36 +158,37 @@ class PairingStore: - Max pending codes reached for this platform - User/platform is in lockout due to failed attempts """ - self._cleanup_expired(platform) + with self._lock: + self._cleanup_expired(platform) - # Check lockout - if self._is_locked_out(platform): - return None + # Check lockout + if self._is_locked_out(platform): + return None - # Check rate limit for this specific user - if self._is_rate_limited(platform, user_id): - return None + # Check rate limit for this specific user + if self._is_rate_limited(platform, user_id): + return None - # Check max pending - pending = self._load_json(self._pending_path(platform)) - if len(pending) >= MAX_PENDING_PER_PLATFORM: - return None + # Check max pending + pending = self._load_json(self._pending_path(platform)) + if len(pending) >= MAX_PENDING_PER_PLATFORM: + return None - # Generate cryptographically random code - code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH)) + # Generate cryptographically random code + code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH)) - # Store pending request - pending[code] = { - "user_id": user_id, - "user_name": user_name, - "created_at": time.time(), - } - self._save_json(self._pending_path(platform), pending) + # Store pending request + pending[code] = { + "user_id": user_id, + "user_name": user_name, + "created_at": time.time(), + } + self._save_json(self._pending_path(platform), pending) - # Record rate limit - self._record_rate_limit(platform, user_id) + # Record rate limit + self._record_rate_limit(platform, user_id) - return code + return code def approve_code(self, platform: str, code: str) -> Optional[dict]: """ @@ -173,24 +196,25 @@ class PairingStore: Returns {user_id, user_name} on success, None if code is invalid/expired. """ - self._cleanup_expired(platform) - code = code.upper().strip() + with self._lock: + self._cleanup_expired(platform) + code = code.upper().strip() - pending = self._load_json(self._pending_path(platform)) - if code not in pending: - self._record_failed_attempt(platform) - return None + pending = self._load_json(self._pending_path(platform)) + if code not in pending: + self._record_failed_attempt(platform) + return None - entry = pending.pop(code) - self._save_json(self._pending_path(platform), pending) + entry = pending.pop(code) + self._save_json(self._pending_path(platform), pending) - # Add to approved list - self._approve_user(platform, entry["user_id"], entry.get("user_name", "")) + # Add to approved list + self._approve_user(platform, entry["user_id"], entry.get("user_name", "")) - return { - "user_id": entry["user_id"], - "user_name": entry.get("user_name", ""), - } + return { + "user_id": entry["user_id"], + "user_name": entry.get("user_name", ""), + } def list_pending(self, platform: str = None) -> list: """List pending pairing requests, optionally filtered by platform.""" @@ -212,12 +236,13 @@ class PairingStore: def clear_pending(self, platform: str = None) -> int: """Clear all pending requests. Returns count removed.""" - count = 0 - platforms = [platform] if platform else self._all_platforms("pending") - for p in platforms: - pending = self._load_json(self._pending_path(p)) - count += len(pending) - self._save_json(self._pending_path(p), {}) + with self._lock: + count = 0 + platforms = [platform] if platform else self._all_platforms("pending") + for p in platforms: + pending = self._load_json(self._pending_path(p)) + count += len(pending) + self._save_json(self._pending_path(p), {}) return count # ----- Rate limiting and lockout -----