fix(gateway/weixin): ensure atomic persistence for critical session state

This commit is contained in:
WAXLYY 2026-04-11 19:38:32 +03:00 committed by Teknium
parent 59e630a64d
commit f4f4078ad9
2 changed files with 71 additions and 4 deletions

View file

@ -63,6 +63,7 @@ from gateway.platforms.base import (
cache_image_from_bytes, cache_image_from_bytes,
) )
from hermes_constants import get_hermes_home from hermes_constants import get_hermes_home
from utils import atomic_json_write
ILINK_BASE_URL = "https://ilinkai.weixin.qq.com" ILINK_BASE_URL = "https://ilinkai.weixin.qq.com"
WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c" WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
@ -206,7 +207,7 @@ def save_weixin_account(
"saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
} }
path = _account_file(hermes_home, account_id) path = _account_file(hermes_home, account_id)
path.write_text(json.dumps(payload, indent=2), encoding="utf-8") atomic_json_write(path, payload)
try: try:
path.chmod(0o600) path.chmod(0o600)
except OSError: except OSError:
@ -269,7 +270,7 @@ class ContextTokenStore:
if key.startswith(prefix) if key.startswith(prefix)
} }
try: try:
self._path(account_id).write_text(json.dumps(payload), encoding="utf-8") atomic_json_write(self._path(account_id), payload)
except Exception as exc: except Exception as exc:
logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc) logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc)
@ -868,7 +869,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str:
def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None: def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None:
path = _sync_buf_path(hermes_home, account_id) path = _sync_buf_path(hermes_home, account_id)
path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8") atomic_json_write(path, {"get_updates_buf": sync_buf})
async def qr_login( async def qr_login(

View file

@ -1,12 +1,14 @@
"""Tests for the Weixin platform adapter.""" """Tests for the Weixin platform adapter."""
import asyncio import asyncio
import json
import os import os
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from gateway.config import PlatformConfig from gateway.config import PlatformConfig
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
from gateway.platforms.weixin import WeixinAdapter from gateway.platforms import weixin
from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter
from tools.send_message_tool import _parse_target_ref, _send_to_platform from tools.send_message_tool import _parse_target_ref, _send_to_platform
@ -187,6 +189,70 @@ class TestWeixinConfig:
assert config.get_connected_platforms() == [] assert config.get_connected_platforms() == []
class TestWeixinStatePersistence:
def test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
account_path = tmp_path / "weixin" / "accounts" / "acct.json"
account_path.parent.mkdir(parents=True, exist_ok=True)
original = {"token": "old-token", "base_url": "https://old.example.com"}
account_path.write_text(json.dumps(original), encoding="utf-8")
def _boom(_src, _dst):
raise OSError("disk full")
monkeypatch.setattr("utils.os.replace", _boom)
try:
weixin.save_weixin_account(
str(tmp_path),
account_id="acct",
token="new-token",
base_url="https://new.example.com",
user_id="wxid_new",
)
except OSError:
pass
else:
raise AssertionError("expected save_weixin_account to propagate replace failure")
assert json.loads(account_path.read_text(encoding="utf-8")) == original
def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json"
token_path.parent.mkdir(parents=True, exist_ok=True)
token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8")
def _boom(_src, _dst):
raise OSError("disk full")
monkeypatch.setattr("utils.os.replace", _boom)
store = ContextTokenStore(str(tmp_path))
with patch.object(weixin.logger, "warning") as warning_mock:
store.set("acct", "user-b", "new-token")
assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"}
warning_mock.assert_called_once()
def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json"
sync_path.parent.mkdir(parents=True, exist_ok=True)
sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8")
def _boom(_src, _dst):
raise OSError("disk full")
monkeypatch.setattr("utils.os.replace", _boom)
try:
weixin._save_sync_buf(str(tmp_path), "acct", "new-sync")
except OSError:
pass
else:
raise AssertionError("expected _save_sync_buf to propagate replace failure")
assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"}
class TestWeixinSendMessageIntegration: class TestWeixinSendMessageIntegration:
def test_parse_target_ref_accepts_weixin_ids(self): def test_parse_target_ref_accepts_weixin_ids(self):
assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True) assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)