fix: concurrency across multiple platform adapters

This commit is contained in:
charliekerfoot 2026-04-06 18:13:54 -05:00 committed by Teknium
parent c1818b7e9e
commit e9b5864b3f

View file

@@ -21,6 +21,8 @@ Storage: ~/.hermes/pairing/
import json import json
import os import os
import secrets import secrets
import tempfile
import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -45,13 +47,29 @@ PAIRING_DIR = get_hermes_dir("platforms/pairing", "pairing")
def _secure_write(path: Path, data: str) -> None: def _secure_write(path: Path, data: str) -> None:
"""Write data to file with restrictive permissions (owner read/write only).""" """Write data to file with restrictive permissions (owner read/write only).
Uses a temp-file + atomic rename so readers always see either the old
complete file or the new one — never a partial write.
"""
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(data, encoding="utf-8") fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
try: try:
os.chmod(path, 0o600) with os.fdopen(fd, "w", encoding="utf-8") as f:
except OSError: f.write(data)
pass # Windows doesn't support chmod the same way f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, str(path))
try:
os.chmod(path, 0o600)
except OSError:
pass # Windows doesn't support chmod the same way
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
class PairingStore: class PairingStore:
@@ -66,6 +84,9 @@ class PairingStore:
def __init__(self): def __init__(self):
PAIRING_DIR.mkdir(parents=True, exist_ok=True) PAIRING_DIR.mkdir(parents=True, exist_ok=True)
# Protects all read-modify-write cycles. The gateway runs multiple
# platform adapters concurrently in threads sharing one PairingStore.
self._lock = threading.RLock()
def _pending_path(self, platform: str) -> Path: def _pending_path(self, platform: str) -> Path:
return PAIRING_DIR / f"{platform}-pending.json" return PAIRING_DIR / f"{platform}-pending.json"
@@ -105,7 +126,7 @@ class PairingStore:
return results return results
def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None: def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None:
"""Add a user to the approved list.""" """Add a user to the approved list. Must be called under self._lock."""
approved = self._load_json(self._approved_path(platform)) approved = self._load_json(self._approved_path(platform))
approved[user_id] = { approved[user_id] = {
"user_name": user_name, "user_name": user_name,
@@ -116,11 +137,12 @@ class PairingStore:
def revoke(self, platform: str, user_id: str) -> bool: def revoke(self, platform: str, user_id: str) -> bool:
"""Remove a user from the approved list. Returns True if found.""" """Remove a user from the approved list. Returns True if found."""
path = self._approved_path(platform) path = self._approved_path(platform)
approved = self._load_json(path) with self._lock:
if user_id in approved: approved = self._load_json(path)
del approved[user_id] if user_id in approved:
self._save_json(path, approved) del approved[user_id]
return True self._save_json(path, approved)
return True
return False return False
# ----- Pending codes ----- # ----- Pending codes -----
@@ -136,36 +158,37 @@ class PairingStore:
- Max pending codes reached for this platform - Max pending codes reached for this platform
- User/platform is in lockout due to failed attempts - User/platform is in lockout due to failed attempts
""" """
self._cleanup_expired(platform) with self._lock:
self._cleanup_expired(platform)
# Check lockout # Check lockout
if self._is_locked_out(platform): if self._is_locked_out(platform):
return None return None
# Check rate limit for this specific user # Check rate limit for this specific user
if self._is_rate_limited(platform, user_id): if self._is_rate_limited(platform, user_id):
return None return None
# Check max pending # Check max pending
pending = self._load_json(self._pending_path(platform)) pending = self._load_json(self._pending_path(platform))
if len(pending) >= MAX_PENDING_PER_PLATFORM: if len(pending) >= MAX_PENDING_PER_PLATFORM:
return None return None
# Generate cryptographically random code # Generate cryptographically random code
code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH)) code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH))
# Store pending request # Store pending request
pending[code] = { pending[code] = {
"user_id": user_id, "user_id": user_id,
"user_name": user_name, "user_name": user_name,
"created_at": time.time(), "created_at": time.time(),
} }
self._save_json(self._pending_path(platform), pending) self._save_json(self._pending_path(platform), pending)
# Record rate limit # Record rate limit
self._record_rate_limit(platform, user_id) self._record_rate_limit(platform, user_id)
return code return code
def approve_code(self, platform: str, code: str) -> Optional[dict]: def approve_code(self, platform: str, code: str) -> Optional[dict]:
""" """
@@ -173,24 +196,25 @@ class PairingStore:
Returns {user_id, user_name} on success, None if code is invalid/expired. Returns {user_id, user_name} on success, None if code is invalid/expired.
""" """
self._cleanup_expired(platform) with self._lock:
code = code.upper().strip() self._cleanup_expired(platform)
code = code.upper().strip()
pending = self._load_json(self._pending_path(platform)) pending = self._load_json(self._pending_path(platform))
if code not in pending: if code not in pending:
self._record_failed_attempt(platform) self._record_failed_attempt(platform)
return None return None
entry = pending.pop(code) entry = pending.pop(code)
self._save_json(self._pending_path(platform), pending) self._save_json(self._pending_path(platform), pending)
# Add to approved list # Add to approved list
self._approve_user(platform, entry["user_id"], entry.get("user_name", "")) self._approve_user(platform, entry["user_id"], entry.get("user_name", ""))
return { return {
"user_id": entry["user_id"], "user_id": entry["user_id"],
"user_name": entry.get("user_name", ""), "user_name": entry.get("user_name", ""),
} }
def list_pending(self, platform: str = None) -> list: def list_pending(self, platform: str = None) -> list:
"""List pending pairing requests, optionally filtered by platform.""" """List pending pairing requests, optionally filtered by platform."""
@@ -212,12 +236,13 @@ class PairingStore:
def clear_pending(self, platform: str = None) -> int: def clear_pending(self, platform: str = None) -> int:
"""Clear all pending requests. Returns count removed.""" """Clear all pending requests. Returns count removed."""
count = 0 with self._lock:
platforms = [platform] if platform else self._all_platforms("pending") count = 0
for p in platforms: platforms = [platform] if platform else self._all_platforms("pending")
pending = self._load_json(self._pending_path(p)) for p in platforms:
count += len(pending) pending = self._load_json(self._pending_path(p))
self._save_json(self._pending_path(p), {}) count += len(pending)
self._save_json(self._pending_path(p), {})
return count return count
# ----- Rate limiting and lockout ----- # ----- Rate limiting and lockout -----