fix(gateway): use monotonic deadlines in QR onboarding flows

This commit is contained in:
Zyproth 2026-05-05 19:51:51 +03:00 committed by Teknium
parent 73d6371762
commit 6e8f1e09a9
6 changed files with 105 additions and 13 deletions

View file

@ -4591,12 +4591,12 @@ def _poll_registration(
Returns dict with app_id, app_secret, domain, open_id on success. Returns dict with app_id, app_secret, domain, open_id on success.
Returns None on failure. Returns None on failure.
""" """
deadline = time.time() + expire_in deadline = time.monotonic() + expire_in
current_domain = domain current_domain = domain
domain_switched = False domain_switched = False
poll_count = 0 poll_count = 0
while time.time() < deadline: while time.monotonic() < deadline:
base_url = _accounts_base_url(current_domain) base_url = _accounts_base_url(current_domain)
try: try:
res = _post_registration(base_url, { res = _post_registration(base_url, {

View file

@ -37,6 +37,7 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
import time
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@ -1562,12 +1563,11 @@ def qr_scan_for_bot_info(
print(" Fetching configuration results...", end="", flush=True) print(" Fetching configuration results...", end="", flush=True)
# ── Step 3: Poll for result ── # ── Step 3: Poll for result ──
import time deadline = time.monotonic() + timeout_seconds
deadline = time.time() + timeout_seconds
query_url = f"{_QR_QUERY_URL}?scode={urllib.parse.quote(scode)}" query_url = f"{_QR_QUERY_URL}?scode={urllib.parse.quote(scode)}"
poll_count = 0 poll_count = 0
while time.time() < deadline: while time.monotonic() < deadline:
try: try:
req = urllib.request.Request(query_url, headers={"User-Agent": "HermesAgent/1.0"}) req = urllib.request.Request(query_url, headers={"User-Agent": "HermesAgent/1.0"})
with urllib.request.urlopen(req, timeout=10) as resp: with urllib.request.urlopen(req, timeout=10) as resp:

View file

@ -1037,11 +1037,11 @@ async def qr_login(
except Exception as _qr_exc: except Exception as _qr_exc:
print(f"(终端二维码渲染失败: {_qr_exc},请直接打开上面的二维码链接)") print(f"(终端二维码渲染失败: {_qr_exc},请直接打开上面的二维码链接)")
deadline = time.time() + timeout_seconds deadline = time.monotonic() + timeout_seconds
current_base_url = ILINK_BASE_URL current_base_url = ILINK_BASE_URL
refresh_count = 0 refresh_count = 0
while time.time() < deadline: while time.monotonic() < deadline:
try: try:
status_resp = await _api_get( status_resp = await _api_get(
session, session,

View file

@ -127,7 +127,7 @@ class TestPollRegistration:
def test_poll_returns_credentials_on_success(self, mock_urlopen_fn, mock_time): def test_poll_returns_credentials_on_success(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1] mock_time.monotonic.side_effect = [0, 1]
mock_time.sleep = MagicMock() mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({ mock_urlopen_fn.return_value = _mock_urlopen({
@ -149,7 +149,7 @@ class TestPollRegistration:
def test_poll_switches_domain_on_lark_tenant_brand(self, mock_urlopen_fn, mock_time): def test_poll_switches_domain_on_lark_tenant_brand(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1, 2] mock_time.monotonic.side_effect = [0, 1, 2]
mock_time.sleep = MagicMock() mock_time.sleep = MagicMock()
pending_resp = _mock_urlopen({ pending_resp = _mock_urlopen({
@ -175,7 +175,7 @@ class TestPollRegistration:
"""Credentials and lark tenant_brand in one response must not be discarded.""" """Credentials and lark tenant_brand in one response must not be discarded."""
from gateway.platforms.feishu import _poll_registration from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1] mock_time.monotonic.side_effect = [0, 1]
mock_time.sleep = MagicMock() mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({ mock_urlopen_fn.return_value = _mock_urlopen({
@ -196,7 +196,7 @@ class TestPollRegistration:
def test_poll_returns_none_on_access_denied(self, mock_urlopen_fn, mock_time): def test_poll_returns_none_on_access_denied(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 1] mock_time.monotonic.side_effect = [0, 1]
mock_time.sleep = MagicMock() mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({ mock_urlopen_fn.return_value = _mock_urlopen({
@ -212,7 +212,7 @@ class TestPollRegistration:
def test_poll_returns_none_on_timeout(self, mock_urlopen_fn, mock_time): def test_poll_returns_none_on_timeout(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration from gateway.platforms.feishu import _poll_registration
mock_time.time.side_effect = [0, 999] mock_time.monotonic.side_effect = [0, 999]
mock_time.sleep = MagicMock() mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({ mock_urlopen_fn.return_value = _mock_urlopen({
@ -223,6 +223,25 @@ class TestPollRegistration:
) )
assert result is None assert result is None
@patch("gateway.platforms.feishu.time")
@patch("gateway.platforms.feishu.urlopen")
def test_poll_timeout_uses_monotonic_clock(self, mock_urlopen_fn, mock_time):
from gateway.platforms.feishu import _poll_registration
mock_time.monotonic.side_effect = [1000, 1000.2, 1001.1]
mock_time.time.side_effect = [1000, 900, 901, 902]
mock_time.sleep = MagicMock()
mock_urlopen_fn.return_value = _mock_urlopen({
"error": "authorization_pending",
})
result = _poll_registration(
device_code="dc_123", interval=1, expire_in=1, domain="feishu"
)
assert result is None
mock_urlopen_fn.assert_called_once()
class TestRenderQr: class TestRenderQr:
"""Tests for QR code terminal rendering.""" """Tests for QR code terminal rendering."""

View file

@ -4,7 +4,7 @@ import base64
import os import os
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -122,6 +122,48 @@ class TestWeComConnect:
assert "invalid secret" in (adapter.fatal_error_message or "") assert "invalid secret" in (adapter.fatal_error_message or "")
class TestWeComQrScan:
@patch("gateway.platforms.wecom.time")
@patch("gateway.platforms.wecom.json.loads")
@patch("gateway.platforms.wecom.logger")
@patch("urllib.request.urlopen")
@patch("urllib.request.Request")
def test_qr_scan_timeout_uses_monotonic_clock(
self,
mock_request,
mock_urlopen,
_mock_logger,
mock_json_loads,
mock_time,
):
from gateway.platforms.wecom import qr_scan_for_bot_info
generate_resp = MagicMock()
generate_resp.read.return_value = b'{"data":{"scode":"abc","auth_url":"https://example.com/qr"}}'
generate_resp.__enter__.return_value = generate_resp
generate_resp.__exit__.return_value = False
poll_resp = MagicMock()
poll_resp.read.return_value = b'{"data":{"status":"pending"}}'
poll_resp.__enter__.return_value = poll_resp
poll_resp.__exit__.return_value = False
mock_urlopen.side_effect = [generate_resp, poll_resp]
mock_json_loads.side_effect = [
{"data": {"scode": "abc", "auth_url": "https://example.com/qr"}},
{"data": {"status": "pending"}},
]
mock_time.monotonic.side_effect = [1000, 1000.2, 1001.1]
mock_time.time.side_effect = [1000, 900, 901, 902]
mock_time.sleep = MagicMock()
with patch("builtins.print"), patch.dict("sys.modules", {"qrcode": None}):
result = qr_scan_for_bot_info(timeout_seconds=1)
assert result is None
assert mock_urlopen.call_count == 2
class TestWeComReplyMode: class TestWeComReplyMode:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_uses_passive_reply_markdown_when_reply_context_exists(self): async def test_send_uses_passive_reply_markdown_when_reply_context_exists(self):

View file

@ -7,6 +7,8 @@ import os
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest
from gateway.config import PlatformConfig from gateway.config import PlatformConfig
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
from gateway.platforms.base import SendResult from gateway.platforms.base import SendResult
@ -279,6 +281,35 @@ class TestWeixinStatePersistence:
assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"} assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"}
class TestWeixinQrLogin:
@pytest.mark.asyncio
async def test_qr_login_timeout_uses_monotonic_clock(self, tmp_path):
first_qr = {
"qrcode": "qr-1",
"qrcode_img_content": "https://example.com/qr-1",
}
pending = {"status": "wait"}
with patch("gateway.platforms.weixin._api_get", new_callable=AsyncMock) as api_get_mock, \
patch("gateway.platforms.weixin.time") as mock_time, \
patch("gateway.platforms.weixin.AIOHTTP_AVAILABLE", True), \
patch("gateway.platforms.weixin.aiohttp.ClientSession", create=True) as session_cls, \
patch("builtins.print"):
api_get_mock.side_effect = [first_qr, pending]
mock_time.monotonic.side_effect = [1000, 1000.2, 1001.1]
mock_time.time.side_effect = [1000, 900, 901, 902]
session = AsyncMock()
session.__aenter__.return_value = session
session.__aexit__.return_value = False
session_cls.return_value = session
result = await weixin.qr_login(str(tmp_path), timeout_seconds=1)
assert result is None
assert api_get_mock.await_count == 2
class TestWeixinSendMessageIntegration: class TestWeixinSendMessageIntegration:
def test_parse_target_ref_accepts_weixin_ids(self): def test_parse_target_ref_accepts_weixin_ids(self):
assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True) assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)