diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index e25bb350f..5e0208c77 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -63,6 +63,7 @@ from gateway.platforms.base import ( cache_image_from_bytes, ) from hermes_constants import get_hermes_home +from utils import atomic_json_write ILINK_BASE_URL = "https://ilinkai.weixin.qq.com" WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c" @@ -206,7 +207,7 @@ def save_weixin_account( "saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), } path = _account_file(hermes_home, account_id) - path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + atomic_json_write(path, payload) try: path.chmod(0o600) except OSError: @@ -269,7 +270,7 @@ class ContextTokenStore: if key.startswith(prefix) } try: - self._path(account_id).write_text(json.dumps(payload), encoding="utf-8") + atomic_json_write(self._path(account_id), payload) except Exception as exc: logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc) @@ -868,7 +869,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str: def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None: path = _sync_buf_path(hermes_home, account_id) - path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8") + atomic_json_write(path, {"get_updates_buf": sync_buf}) async def qr_login( diff --git a/tests/gateway/test_weixin.py b/tests/gateway/test_weixin.py index caf4a7eba..815ea75ef 100644 --- a/tests/gateway/test_weixin.py +++ b/tests/gateway/test_weixin.py @@ -1,12 +1,14 @@ """Tests for the Weixin platform adapter.""" import asyncio +import json import os from unittest.mock import AsyncMock, patch from gateway.config import PlatformConfig from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides -from gateway.platforms.weixin import WeixinAdapter +from gateway.platforms import weixin +from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter from tools.send_message_tool import _parse_target_ref, _send_to_platform @@ -187,6 +189,70 @@ class TestWeixinConfig: assert config.get_connected_platforms() == [] +class TestWeixinStatePersistence: + def test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + account_path = tmp_path / "weixin" / "accounts" / "acct.json" + account_path.parent.mkdir(parents=True, exist_ok=True) + original = {"token": "old-token", "base_url": "https://old.example.com"} + account_path.write_text(json.dumps(original), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + try: + weixin.save_weixin_account( + str(tmp_path), + account_id="acct", + token="new-token", + base_url="https://new.example.com", + user_id="wxid_new", + ) + except OSError: + pass + else: + raise AssertionError("expected save_weixin_account to propagate replace failure") + + assert json.loads(account_path.read_text(encoding="utf-8")) == original + + def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json" + token_path.parent.mkdir(parents=True, exist_ok=True) + token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + store = ContextTokenStore(str(tmp_path)) + with patch.object(weixin.logger, "warning") as warning_mock: + store.set("acct", "user-b", "new-token") + + assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"} + warning_mock.assert_called_once() + + def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch): + sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json" + sync_path.parent.mkdir(parents=True, exist_ok=True) + sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8") + + def _boom(_src, _dst): + raise OSError("disk full") + + monkeypatch.setattr("utils.os.replace", _boom) + + try: + weixin._save_sync_buf(str(tmp_path), "acct", "new-sync") + except OSError: + pass + else: + raise AssertionError("expected _save_sync_buf to propagate replace failure") + + assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"} + + class TestWeixinSendMessageIntegration: def test_parse_target_ref_accepts_weixin_ids(self): assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)