mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Sweep ~74 redundant local imports across 21 files where the same module was already imported at the top level. Also includes type fixes and lint cleanups on the same branch.
1466 lines
58 KiB
Python
1466 lines
58 KiB
Python
"""
|
||
WeCom (Enterprise WeChat) platform adapter.
|
||
|
||
Uses the WeCom AI Bot WebSocket gateway for inbound and outbound messages.
|
||
The adapter focuses on the core gateway path:
|
||
|
||
- authenticate via ``aibot_subscribe``
|
||
- receive inbound ``aibot_msg_callback`` events
|
||
- send outbound markdown messages via ``aibot_send_msg``
|
||
- upload outbound media via ``aibot_upload_media_*`` and send native attachments
|
||
- best-effort download of inbound image/file attachments for agent context
|
||
|
||
Configuration in config.yaml:
|
||
platforms:
|
||
wecom:
|
||
enabled: true
|
||
extra:
|
||
bot_id: "your-bot-id" # or WECOM_BOT_ID env var
|
||
secret: "your-secret" # or WECOM_SECRET env var
|
||
websocket_url: "wss://openws.work.weixin.qq.com"
|
||
dm_policy: "open" # open | allowlist | disabled | pairing
|
||
allow_from: ["user_id_1"]
|
||
group_policy: "open" # open | allowlist | disabled
|
||
group_allow_from: ["group_id_1"]
|
||
groups:
|
||
group_id_1:
|
||
allow_from: ["user_id_1"]
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import mimetypes
|
||
import os
|
||
import re
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
from urllib.parse import unquote, urlparse
|
||
|
||
try:
|
||
import aiohttp
|
||
AIOHTTP_AVAILABLE = True
|
||
except ImportError:
|
||
AIOHTTP_AVAILABLE = False
|
||
aiohttp = None # type: ignore[assignment]
|
||
|
||
try:
|
||
import httpx
|
||
HTTPX_AVAILABLE = True
|
||
except ImportError:
|
||
HTTPX_AVAILABLE = False
|
||
httpx = None # type: ignore[assignment]
|
||
|
||
from gateway.config import Platform, PlatformConfig
|
||
from gateway.platforms.helpers import MessageDeduplicator
|
||
from gateway.platforms.base import (
|
||
BasePlatformAdapter,
|
||
MessageEvent,
|
||
MessageType,
|
||
SendResult,
|
||
cache_document_from_bytes,
|
||
cache_image_from_bytes,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
DEFAULT_WS_URL = "wss://openws.work.weixin.qq.com"
|
||
|
||
APP_CMD_SUBSCRIBE = "aibot_subscribe"
|
||
APP_CMD_CALLBACK = "aibot_msg_callback"
|
||
APP_CMD_LEGACY_CALLBACK = "aibot_callback"
|
||
APP_CMD_EVENT_CALLBACK = "aibot_event_callback"
|
||
APP_CMD_SEND = "aibot_send_msg"
|
||
APP_CMD_RESPONSE = "aibot_respond_msg"
|
||
APP_CMD_PING = "ping"
|
||
APP_CMD_UPLOAD_MEDIA_INIT = "aibot_upload_media_init"
|
||
APP_CMD_UPLOAD_MEDIA_CHUNK = "aibot_upload_media_chunk"
|
||
APP_CMD_UPLOAD_MEDIA_FINISH = "aibot_upload_media_finish"
|
||
|
||
CALLBACK_COMMANDS = {APP_CMD_CALLBACK, APP_CMD_LEGACY_CALLBACK}
|
||
NON_RESPONSE_COMMANDS = CALLBACK_COMMANDS | {APP_CMD_EVENT_CALLBACK}
|
||
|
||
MAX_MESSAGE_LENGTH = 4000
|
||
CONNECT_TIMEOUT_SECONDS = 20.0
|
||
REQUEST_TIMEOUT_SECONDS = 15.0
|
||
HEARTBEAT_INTERVAL_SECONDS = 30.0
|
||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||
|
||
DEDUP_MAX_SIZE = 1000
|
||
|
||
IMAGE_MAX_BYTES = 10 * 1024 * 1024
|
||
VIDEO_MAX_BYTES = 10 * 1024 * 1024
|
||
VOICE_MAX_BYTES = 2 * 1024 * 1024
|
||
FILE_MAX_BYTES = 20 * 1024 * 1024
|
||
ABSOLUTE_MAX_BYTES = FILE_MAX_BYTES
|
||
UPLOAD_CHUNK_SIZE = 512 * 1024
|
||
MAX_UPLOAD_CHUNKS = 100
|
||
VOICE_SUPPORTED_MIMES = {"audio/amr"}
|
||
|
||
|
||
def check_wecom_requirements() -> bool:
    """Report whether both optional HTTP libraries (aiohttp and httpx) imported."""
    if not AIOHTTP_AVAILABLE:
        return False
    return HTTPX_AVAILABLE
|
||
|
||
|
||
def _coerce_list(value: Any) -> List[str]:
|
||
"""Coerce config values into a trimmed string list."""
|
||
if value is None:
|
||
return []
|
||
if isinstance(value, str):
|
||
return [item.strip() for item in value.split(",") if item.strip()]
|
||
if isinstance(value, (list, tuple, set)):
|
||
return [str(item).strip() for item in value if str(item).strip()]
|
||
return [str(value).strip()] if str(value).strip() else []
|
||
|
||
|
||
def _normalize_entry(raw: str) -> str:
|
||
"""Normalize allowlist entries such as ``wecom:user:foo``."""
|
||
value = str(raw).strip()
|
||
value = re.sub(r"^wecom:", "", value, flags=re.IGNORECASE)
|
||
value = re.sub(r"^(user|group):", "", value, flags=re.IGNORECASE)
|
||
return value.strip()
|
||
|
||
|
||
def _entry_matches(entries: List[str], target: str) -> bool:
|
||
"""Case-insensitive allowlist match with ``*`` support."""
|
||
normalized_target = str(target).strip().lower()
|
||
for entry in entries:
|
||
normalized = _normalize_entry(entry).lower()
|
||
if normalized == "*" or normalized == normalized_target:
|
||
return True
|
||
return False
|
||
|
||
|
||
class WeComAdapter(BasePlatformAdapter):
|
||
"""WeCom AI Bot adapter backed by a persistent WebSocket connection."""
|
||
|
||
MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH
|
||
# Threshold for detecting WeCom client-side message splits.
|
||
# When a chunk is near the 4000-char limit, a continuation is almost certain.
|
||
_SPLIT_THRESHOLD = 3900
|
||
|
||
def __init__(self, config: PlatformConfig):
    """Read adapter settings from ``config.extra`` / environment and reset state.

    For scalar settings the config value wins and the environment variable
    (or built-in default) is the fallback. No network activity happens here.
    """
    super().__init__(config, Platform.WECOM)

    extra = config.extra or {}

    def _setting(key: str, env_name: str, default: str) -> str:
        # Config first, then environment/default; always a trimmed string.
        return str(extra.get(key) or os.getenv(env_name, default)).strip()

    # Credentials and endpoint.
    self._bot_id = _setting("bot_id", "WECOM_BOT_ID", "")
    self._secret = _setting("secret", "WECOM_SECRET", "")
    configured_url = (
        extra.get("websocket_url")
        or extra.get("websocketUrl")
        or os.getenv("WECOM_WEBSOCKET_URL", DEFAULT_WS_URL)
    )
    self._ws_url = str(configured_url).strip() or DEFAULT_WS_URL

    # Access policies for direct messages and group chats.
    self._dm_policy = _setting("dm_policy", "WECOM_DM_POLICY", "open").lower()
    self._allow_from = _coerce_list(extra.get("allow_from") or extra.get("allowFrom"))

    self._group_policy = _setting("group_policy", "WECOM_GROUP_POLICY", "open").lower()
    self._group_allow_from = _coerce_list(
        extra.get("group_allow_from") or extra.get("groupAllowFrom")
    )
    groups_cfg = extra.get("groups")
    self._groups = groups_cfg if isinstance(groups_cfg, dict) else {}

    # Connection handles and background tasks; populated by connect().
    self._session: Optional["aiohttp.ClientSession"] = None
    self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
    self._http_client: Optional["httpx.AsyncClient"] = None
    self._listen_task: Optional[asyncio.Task] = None
    self._heartbeat_task: Optional[asyncio.Task] = None

    # Request/response correlation and inbound de-duplication.
    self._pending_responses: Dict[str, asyncio.Future] = {}
    self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
    self._reply_req_ids: Dict[str, str] = {}

    # Text batching: merge rapid successive messages (Telegram-style).
    # WeCom clients split long messages around 4000 chars.
    self._text_batch_delay_seconds = float(
        os.getenv("HERMES_WECOM_TEXT_BATCH_DELAY_SECONDS", "0.6")
    )
    self._text_batch_split_delay_seconds = float(
        os.getenv("HERMES_WECOM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")
    )
    self._pending_text_batches: Dict[str, MessageEvent] = {}
    self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}

    # Random per-instance id sent with the subscribe frame.
    self._device_id = uuid.uuid4().hex
    self._last_chat_req_ids: Dict[str, str] = {}
|
||
|
||
# ------------------------------------------------------------------
|
||
# Connection lifecycle
|
||
# ------------------------------------------------------------------
|
||
|
||
async def connect(self) -> bool:
    """Check prerequisites, open the gateway socket, and start background tasks.

    Returns:
        ``True`` once the websocket is authenticated and the listen and
        heartbeat tasks are scheduled. ``False`` when a dependency or
        credential is missing or the connection attempt raised; a fatal
        error is recorded via ``_set_fatal_error`` in every failure case.
    """
    # Each row: (requirement met?, error code, message, log format).
    startup_checks = (
        (
            AIOHTTP_AVAILABLE,
            "wecom_missing_dependency",
            "WeCom startup failed: aiohttp not installed",
            "[%s] %s. Run: pip install aiohttp",
        ),
        (
            HTTPX_AVAILABLE,
            "wecom_missing_dependency",
            "WeCom startup failed: httpx not installed",
            "[%s] %s. Run: pip install httpx",
        ),
        (
            bool(self._bot_id and self._secret),
            "wecom_missing_credentials",
            "WeCom startup failed: WECOM_BOT_ID and WECOM_SECRET are required",
            "[%s] %s",
        ),
    )
    for satisfied, error_code, message, log_format in startup_checks:
        if satisfied:
            continue
        self._set_fatal_error(error_code, message, retryable=True)
        logger.warning(log_format, self.name, message)
        return False

    try:
        self._http_client = httpx.AsyncClient(timeout=30.0, follow_redirects=True)
        await self._open_connection()
        self._mark_connected()
        self._listen_task = asyncio.create_task(self._listen_loop())
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        logger.info("[%s] Connected to %s", self.name, self._ws_url)
    except Exception as exc:
        message = f"WeCom startup failed: {exc}"
        self._set_fatal_error("wecom_connect_error", message, retryable=True)
        logger.error("[%s] Failed to connect: %s", self.name, exc, exc_info=True)
        # Undo whatever was partially opened before reporting failure.
        await self._cleanup_ws()
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
        return False
    return True
|
||
|
||
async def disconnect(self) -> None:
    """Stop background tasks, fail in-flight requests, and close all connections."""
    self._running = False
    self._mark_disconnected()

    # Cancel the listener first, then the heartbeat, waiting for each to unwind.
    for task_attr in ("_listen_task", "_heartbeat_task"):
        task = getattr(self, task_attr)
        if not task:
            continue
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        setattr(self, task_attr, None)

    self._fail_pending_responses(RuntimeError("WeCom adapter disconnected"))
    await self._cleanup_ws()

    http_client = self._http_client
    if http_client:
        await http_client.aclose()
        self._http_client = None

    self._dedup.clear()
    logger.info("[%s] Disconnected", self.name)
|
||
|
||
async def _cleanup_ws(self) -> None:
|
||
"""Close the live websocket/session, if any."""
|
||
if self._ws and not self._ws.closed:
|
||
await self._ws.close()
|
||
self._ws = None
|
||
|
||
if self._session and not self._session.closed:
|
||
await self._session.close()
|
||
self._session = None
|
||
|
||
async def _open_connection(self) -> None:
    """Open a fresh websocket and authenticate with ``aibot_subscribe``.

    Any existing socket/session is closed first.

    Raises:
        RuntimeError: if the subscribe acknowledgement carries a non-zero
            ``errcode``. Errors from the connect or handshake wait propagate.
    """
    await self._cleanup_ws()
    self._session = aiohttp.ClientSession(trust_env=True)
    self._ws = await self._session.ws_connect(
        self._ws_url,
        heartbeat=HEARTBEAT_INTERVAL_SECONDS * 2,
        timeout=CONNECT_TIMEOUT_SECONDS,
    )

    subscribe_id = self._new_req_id("subscribe")
    credentials = {
        "bot_id": self._bot_id,
        "secret": self._secret,
        "device_id": self._device_id,
    }
    await self._send_json(
        {
            "cmd": APP_CMD_SUBSCRIBE,
            "headers": {"req_id": subscribe_id},
            "body": credentials,
        }
    )

    ack = await self._wait_for_handshake(subscribe_id)
    code = ack.get("errcode", 0)
    if code in (0, None):
        # A missing, null, or zero errcode counts as success.
        return
    reason = ack.get("errmsg", "authentication failed")
    raise RuntimeError(f"{reason} (errcode={code})")
|
||
|
||
async def _wait_for_handshake(self, req_id: str) -> Dict[str, Any]:
    """Read frames until the one answering *req_id* arrives, and return it.

    Pings, unparseable frames, and unrelated payloads are skipped. The
    overall wait is bounded by ``CONNECT_TIMEOUT_SECONDS``.

    Raises:
        RuntimeError: if no websocket exists or it closes/errors mid-wait.
        TimeoutError: if the deadline passes without an acknowledgement.
    """
    if not self._ws:
        raise RuntimeError("WebSocket not initialized")

    loop = asyncio.get_running_loop()
    deadline = loop.time() + CONNECT_TIMEOUT_SECONDS
    while True:
        remaining = deadline - loop.time()
        if remaining <= 0:
            raise TimeoutError("Timed out waiting for WeCom subscribe acknowledgement")

        frame = await asyncio.wait_for(self._ws.receive(), timeout=remaining)
        if frame.type in (
            aiohttp.WSMsgType.CLOSED,
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.ERROR,
        ):
            raise RuntimeError("WeCom websocket closed during authentication")
        if frame.type != aiohttp.WSMsgType.TEXT:
            continue

        payload = self._parse_json(frame.data)
        if not payload or payload.get("cmd") == APP_CMD_PING:
            continue
        if self._payload_req_id(payload) == req_id:
            return payload
        logger.debug("[%s] Ignoring pre-auth payload: %s", self.name, payload.get("cmd"))
|
||
|
||
async def _listen_loop(self) -> None:
|
||
"""Read websocket events forever, reconnecting on errors."""
|
||
backoff_idx = 0
|
||
while self._running:
|
||
try:
|
||
await self._read_events()
|
||
backoff_idx = 0
|
||
except asyncio.CancelledError:
|
||
return
|
||
except Exception as exc:
|
||
if not self._running:
|
||
return
|
||
logger.warning("[%s] WebSocket error: %s", self.name, exc)
|
||
self._fail_pending_responses(RuntimeError("WeCom connection interrupted"))
|
||
|
||
delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)]
|
||
backoff_idx += 1
|
||
await asyncio.sleep(delay)
|
||
|
||
try:
|
||
await self._open_connection()
|
||
backoff_idx = 0
|
||
logger.info("[%s] Reconnected", self.name)
|
||
except Exception as reconnect_exc:
|
||
logger.warning("[%s] Reconnect failed: %s", self.name, reconnect_exc)
|
||
|
||
async def _read_events(self) -> None:
|
||
"""Read websocket frames until the connection closes."""
|
||
if not self._ws:
|
||
raise RuntimeError("WebSocket not connected")
|
||
|
||
while self._running and self._ws and not self._ws.closed:
|
||
msg = await self._ws.receive()
|
||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||
payload = self._parse_json(msg.data)
|
||
if payload:
|
||
await self._dispatch_payload(payload)
|
||
elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
|
||
raise RuntimeError("WeCom websocket closed")
|
||
|
||
async def _heartbeat_loop(self) -> None:
|
||
"""Send lightweight application-level pings."""
|
||
try:
|
||
while self._running:
|
||
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
|
||
if not self._ws or self._ws.closed:
|
||
continue
|
||
try:
|
||
await self._send_json(
|
||
{
|
||
"cmd": APP_CMD_PING,
|
||
"headers": {"req_id": self._new_req_id("ping")},
|
||
"body": {},
|
||
}
|
||
)
|
||
except Exception as exc:
|
||
logger.debug("[%s] Heartbeat send failed: %s", self.name, exc)
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
async def _dispatch_payload(self, payload: Dict[str, Any]) -> None:
    """Route one inbound frame to a waiting request, the message handler, or nowhere.

    A frame whose ``req_id`` matches a pending request (and whose command is
    not a callback/event command) resolves that request's future. Message
    callbacks go to ``_on_message``; pings and event callbacks are dropped
    silently; anything else is logged at debug level.
    """
    correlation_id = self._payload_req_id(payload)
    command = str(payload.get("cmd") or "")

    answers_pending_request = (
        correlation_id
        and correlation_id in self._pending_responses
        and command not in NON_RESPONSE_COMMANDS
    )
    if answers_pending_request:
        waiter = self._pending_responses.get(correlation_id)
        if waiter and not waiter.done():
            waiter.set_result(payload)
        return

    if command in CALLBACK_COMMANDS:
        await self._on_message(payload)
    elif command not in {APP_CMD_PING, APP_CMD_EVENT_CALLBACK}:
        logger.debug("[%s] Ignoring websocket payload: %s", self.name, command or payload)
|
||
|
||
def _fail_pending_responses(self, exc: Exception) -> None:
|
||
"""Fail all outstanding request futures."""
|
||
for req_id, future in list(self._pending_responses.items()):
|
||
if not future.done():
|
||
future.set_exception(exc)
|
||
self._pending_responses.pop(req_id, None)
|
||
|
||
async def _send_json(self, payload: Dict[str, Any]) -> None:
|
||
"""Send a raw JSON frame over the active websocket."""
|
||
if not self._ws or self._ws.closed:
|
||
raise RuntimeError("WeCom websocket is not connected")
|
||
await self._ws.send_json(payload)
|
||
|
||
async def _send_request(self, cmd: str, body: Dict[str, Any], timeout: float = REQUEST_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Send *cmd* under a fresh ``req_id`` and wait for the matching response.

    Args:
        cmd: Gateway command name.
        body: JSON-serialisable request body.
        timeout: Seconds to wait for the correlated response.

    Returns:
        The response payload delivered for the generated ``req_id``.

    Raises:
        RuntimeError: if the websocket is missing or closed.
        Timeout and send errors propagate; the pending entry is always removed.
    """
    if not self._ws or self._ws.closed:
        raise RuntimeError("WeCom websocket is not connected")

    correlation_id = self._new_req_id(cmd)
    waiter = asyncio.get_running_loop().create_future()
    self._pending_responses[correlation_id] = waiter
    frame = {"cmd": cmd, "headers": {"req_id": correlation_id}, "body": body}
    try:
        await self._send_json(frame)
        return await asyncio.wait_for(waiter, timeout=timeout)
    finally:
        self._pending_responses.pop(correlation_id, None)
|
||
|
||
async def _send_reply_request(
    self,
    reply_req_id: str,
    body: Dict[str, Any],
    cmd: str = APP_CMD_RESPONSE,
    timeout: float = REQUEST_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
    """Send a frame that reuses an inbound callback's ``req_id`` and await the reply.

    Args:
        reply_req_id: The ``req_id`` of the inbound frame being answered.
        body: JSON-serialisable request body.
        cmd: Gateway command; defaults to the respond-message command.
        timeout: Seconds to wait for the correlated response.

    Raises:
        RuntimeError: if the websocket is missing or closed.
        ValueError: if *reply_req_id* is empty after trimming.
    """
    if not self._ws or self._ws.closed:
        raise RuntimeError("WeCom websocket is not connected")

    correlation_id = str(reply_req_id or "").strip()
    if not correlation_id:
        raise ValueError("reply_req_id is required")

    waiter = asyncio.get_running_loop().create_future()
    self._pending_responses[correlation_id] = waiter
    frame = {"cmd": cmd, "headers": {"req_id": correlation_id}, "body": body}
    try:
        await self._send_json(frame)
        return await asyncio.wait_for(waiter, timeout=timeout)
    finally:
        self._pending_responses.pop(correlation_id, None)
|
||
|
||
@staticmethod
|
||
def _new_req_id(prefix: str) -> str:
|
||
return f"{prefix}-{uuid.uuid4().hex}"
|
||
|
||
@staticmethod
|
||
def _payload_req_id(payload: Dict[str, Any]) -> str:
|
||
headers = payload.get("headers")
|
||
if isinstance(headers, dict):
|
||
return str(headers.get("req_id") or "")
|
||
return ""
|
||
|
||
@staticmethod
|
||
def _parse_json(raw: Any) -> Optional[Dict[str, Any]]:
|
||
try:
|
||
payload = json.loads(raw)
|
||
except Exception:
|
||
logger.debug("Failed to parse WeCom payload: %r", raw)
|
||
return None
|
||
return payload if isinstance(payload, dict) else None
|
||
|
||
# ------------------------------------------------------------------
|
||
# Inbound message parsing
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _on_message(self, payload: Dict[str, Any]) -> None:
    """Turn an inbound message callback into a ``MessageEvent`` and dispatch it.

    Steps, in order: validate the body, drop duplicates, remember the
    reply ``req_id``, apply the DM/group policy, extract text and media,
    then either buffer the event for text batching or hand it to
    ``handle_message`` immediately.
    """
    body = payload.get("body")
    if not isinstance(body, dict):
        return

    inbound_req_id = self._payload_req_id(payload)

    # Fall back to the frame's req_id, then a random id, when msgid is absent.
    msg_id = str(body.get("msgid") or inbound_req_id or uuid.uuid4().hex)
    if self._dedup.is_duplicate(msg_id):
        logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
        return
    self._remember_reply_req_id(msg_id, inbound_req_id)

    raw_sender = body.get("from")
    sender = raw_sender if isinstance(raw_sender, dict) else {}
    sender_id = str(sender.get("userid") or "").strip()
    # Without an explicit chat id the sender id doubles as the chat id.
    chat_id = str(body.get("chatid") or sender_id).strip()
    if not chat_id:
        logger.debug("[%s] Missing chat id, skipping message", self.name)
        return

    is_group = str(body.get("chattype") or "").lower() == "group"
    if is_group:
        if not self._is_group_allowed(chat_id, sender_id):
            logger.debug("[%s] Group %s / sender %s blocked by policy", self.name, chat_id, sender_id)
            return
    else:
        if not self._is_dm_allowed(sender_id):
            logger.debug("[%s] DM sender %s blocked by policy", self.name, sender_id)
            return

    # Cache the inbound req_id after policy checks so proactive sends to
    # this chat can fall back to APP_CMD_RESPONSE (required for groups —
    # WeCom AI Bots cannot initiate APP_CMD_SEND in group chats).
    self._remember_chat_req_id(chat_id, inbound_req_id)

    text, reply_text = self._extract_text(body)
    media_urls, media_types = await self._extract_media(body)
    message_type = self._derive_message_type(body, text, media_types)

    # Quoted text counts as reply context only when the message has content
    # of its own; this is decided before the fallback below rewrites `text`.
    has_reply_context = bool(reply_text and (text or media_urls))

    if reply_text and not text and not media_urls:
        text = reply_text

    if not (text or media_urls):
        logger.debug("[%s] Empty WeCom message skipped", self.name)
        return

    source = self.build_source(
        chat_id=chat_id,
        chat_type="group" if is_group else "dm",
        user_id=sender_id or None,
        user_name=sender_id or None,
    )

    event = MessageEvent(
        text=text,
        message_type=message_type,
        source=source,
        raw_message=payload,
        message_id=msg_id,
        media_urls=media_urls,
        media_types=media_types,
        reply_to_message_id=f"quote:{msg_id}" if has_reply_context else None,
        reply_to_text=reply_text if has_reply_context else None,
        timestamp=datetime.now(tz=timezone.utc),
    )

    # Only batch plain text messages — commands, media, etc. dispatch
    # immediately since they won't be split by the WeCom client.
    batchable = message_type == MessageType.TEXT and self._text_batch_delay_seconds > 0
    if not batchable:
        await self.handle_message(event)
        return
    self._enqueue_text_event(event)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Text message aggregation (handles WeCom client-side splits)
|
||
# ------------------------------------------------------------------
|
||
|
||
def _text_batch_key(self, event: MessageEvent) -> str:
    """Return the session-scoped key under which *event* is batched.

    The key comes from ``build_session_key`` applied to the event's source,
    using the ``group_sessions_per_user`` (default ``True``) and
    ``thread_sessions_per_user`` (default ``False``) config options.
    """
    # Imported at call time; presumably to avoid an import cycle — TODO confirm.
    from gateway.session import build_session_key

    # ``config.extra`` may be unset; ``__init__`` guards it the same way.
    extra = self.config.extra or {}
    return build_session_key(
        event.source,
        group_sessions_per_user=extra.get("group_sessions_per_user", True),
        thread_sessions_per_user=extra.get("thread_sessions_per_user", False),
    )
|
||
|
||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||
"""Buffer a text event and reset the flush timer.
|
||
|
||
When WeCom splits a long user message at 4000 chars, the chunks
|
||
arrive within a few hundred milliseconds. This merges them into
|
||
a single event before dispatching.
|
||
"""
|
||
key = self._text_batch_key(event)
|
||
existing = self._pending_text_batches.get(key)
|
||
chunk_len = len(event.text or "")
|
||
if existing is None:
|
||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||
self._pending_text_batches[key] = event
|
||
else:
|
||
if event.text:
|
||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||
# Merge any media that might be attached
|
||
if event.media_urls:
|
||
existing.media_urls.extend(event.media_urls)
|
||
existing.media_types.extend(event.media_types)
|
||
|
||
# Cancel any pending flush and restart the timer
|
||
prior_task = self._pending_text_batch_tasks.get(key)
|
||
if prior_task and not prior_task.done():
|
||
prior_task.cancel()
|
||
self._pending_text_batch_tasks[key] = asyncio.create_task(
|
||
self._flush_text_batch(key)
|
||
)
|
||
|
||
async def _flush_text_batch(self, key: str) -> None:
|
||
"""Wait for the quiet period then dispatch the aggregated text.
|
||
|
||
Uses a longer delay when the latest chunk is near WeCom's 4000-char
|
||
split point, since a continuation chunk is almost certain.
|
||
"""
|
||
current_task = asyncio.current_task()
|
||
try:
|
||
pending = self._pending_text_batches.get(key)
|
||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||
if last_len >= self._SPLIT_THRESHOLD:
|
||
delay = self._text_batch_split_delay_seconds
|
||
else:
|
||
delay = self._text_batch_delay_seconds
|
||
await asyncio.sleep(delay)
|
||
event = self._pending_text_batches.pop(key, None)
|
||
if not event:
|
||
return
|
||
logger.info(
|
||
"[WeCom] Flushing text batch %s (%d chars)",
|
||
key, len(event.text or ""),
|
||
)
|
||
await self.handle_message(event)
|
||
finally:
|
||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||
self._pending_text_batch_tasks.pop(key, None)
|
||
|
||
@staticmethod
|
||
def _extract_text(body: Dict[str, Any]) -> Tuple[str, Optional[str]]:
|
||
"""Extract plain text and quoted text from a callback payload."""
|
||
text_parts: List[str] = []
|
||
reply_text: Optional[str] = None
|
||
msgtype = str(body.get("msgtype") or "").lower()
|
||
|
||
if msgtype == "mixed":
|
||
_raw_mixed = body.get("mixed")
|
||
mixed = _raw_mixed if isinstance(_raw_mixed, dict) else {}
|
||
_raw_items = mixed.get("msg_item")
|
||
items = _raw_items if isinstance(_raw_items, list) else []
|
||
for item in items:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
if str(item.get("msgtype") or "").lower() == "text":
|
||
_raw_text = item.get("text")
|
||
text_block = _raw_text if isinstance(_raw_text, dict) else {}
|
||
content = str(text_block.get("content") or "").strip()
|
||
if content:
|
||
text_parts.append(content)
|
||
else:
|
||
text_block = body.get("text") if isinstance(body.get("text"), dict) else {}
|
||
content = str(text_block.get("content") or "").strip()
|
||
if content:
|
||
text_parts.append(content)
|
||
|
||
if msgtype == "voice":
|
||
voice_block = body.get("voice") if isinstance(body.get("voice"), dict) else {}
|
||
voice_text = str(voice_block.get("content") or "").strip()
|
||
if voice_text:
|
||
text_parts.append(voice_text)
|
||
|
||
# Extract appmsg title (filename) for WeCom AI Bot attachments
|
||
if msgtype == "appmsg":
|
||
appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {}
|
||
title = str(appmsg.get("title") or "").strip()
|
||
if title:
|
||
text_parts.append(title)
|
||
|
||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||
quote_type = str(quote.get("msgtype") or "").lower()
|
||
if quote_type == "text":
|
||
quote_text = quote.get("text") if isinstance(quote.get("text"), dict) else {}
|
||
reply_text = str(quote_text.get("content") or "").strip() or None
|
||
elif quote_type == "voice":
|
||
quote_voice = quote.get("voice") if isinstance(quote.get("voice"), dict) else {}
|
||
reply_text = str(quote_voice.get("content") or "").strip() or None
|
||
|
||
return "\n".join(part for part in text_parts if part).strip(), reply_text
|
||
|
||
async def _extract_media(self, body: Dict[str, Any]) -> Tuple[List[str], List[str]]:
|
||
"""Best-effort extraction of inbound media to local cache paths."""
|
||
media_paths: List[str] = []
|
||
media_types: List[str] = []
|
||
refs: List[Tuple[str, Dict[str, Any]]] = []
|
||
msgtype = str(body.get("msgtype") or "").lower()
|
||
|
||
if msgtype == "mixed":
|
||
_raw_mixed = body.get("mixed")
|
||
mixed = _raw_mixed if isinstance(_raw_mixed, dict) else {}
|
||
_raw_items = mixed.get("msg_item")
|
||
items = _raw_items if isinstance(_raw_items, list) else []
|
||
for item in items:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
item_type = str(item.get("msgtype") or "").lower()
|
||
if item_type == "image" and isinstance(item.get("image"), dict):
|
||
refs.append(("image", item["image"]))
|
||
else:
|
||
if isinstance(body.get("image"), dict):
|
||
refs.append(("image", body["image"]))
|
||
if msgtype == "file" and isinstance(body.get("file"), dict):
|
||
refs.append(("file", body["file"]))
|
||
# Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel)
|
||
if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict):
|
||
appmsg = body["appmsg"]
|
||
if isinstance(appmsg.get("file"), dict):
|
||
refs.append(("file", appmsg["file"]))
|
||
elif isinstance(appmsg.get("image"), dict):
|
||
refs.append(("image", appmsg["image"]))
|
||
|
||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||
quote_type = str(quote.get("msgtype") or "").lower()
|
||
if quote_type == "image" and isinstance(quote.get("image"), dict):
|
||
refs.append(("image", quote["image"]))
|
||
elif quote_type == "file" and isinstance(quote.get("file"), dict):
|
||
refs.append(("file", quote["file"]))
|
||
|
||
for kind, ref in refs:
|
||
cached = await self._cache_media(kind, ref)
|
||
if cached:
|
||
path, content_type = cached
|
||
media_paths.append(path)
|
||
media_types.append(content_type)
|
||
|
||
return media_paths, media_types
|
||
|
||
async def _cache_media(self, kind: str, media: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    """Cache one inbound media reference and return ``(local_path, content_type)``.

    Inline ``base64`` data is preferred; otherwise the ``url`` is downloaded
    and, when an ``aeskey`` is present, decrypted. Every failure (decode,
    download, decrypt, image validation) is logged and yields ``None``.

    For downloaded images the reported content type is the server's only
    when it is an ``image/*`` type. Otherwise the extension and MIME type
    are derived from the bytes, so a missing or generic Content-Type does
    not label an image as ``application/octet-stream``.
    """
    if media.get("base64"):
        try:
            raw = self._decode_base64(media["base64"])
        except Exception as exc:
            logger.debug("[%s] Failed to decode %s base64 media: %s", self.name, kind, exc)
            return None

        if kind == "image":
            ext = self._detect_image_ext(raw)
            try:
                return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg")
            except ValueError as exc:
                logger.warning("[%s] Rejected non-image bytes: %s", self.name, exc)
                return None

        filename = str(media.get("filename") or media.get("name") or "wecom_file")
        return cache_document_from_bytes(raw, filename), mimetypes.guess_type(filename)[0] or "application/octet-stream"

    url = str(media.get("url") or "").strip()
    if not url:
        return None

    try:
        raw, headers = await self._download_remote_bytes(url, max_bytes=ABSOLUTE_MAX_BYTES)
    except Exception as exc:
        logger.debug("[%s] Failed to download %s from %s: %s", self.name, kind, url, exc)
        return None

    aes_key = str(media.get("aeskey") or "").strip()
    if aes_key:
        try:
            raw = self._decrypt_file_bytes(raw, aes_key)
        except Exception as exc:
            logger.debug("[%s] Failed to decrypt %s from %s: %s", self.name, kind, url, exc)
            return None

    content_type = str(headers.get("content-type") or "").split(";", 1)[0].strip() or "application/octet-stream"
    if kind == "image":
        if content_type.lower().startswith("image/"):
            ext = self._guess_extension(url, content_type, fallback=self._detect_image_ext(raw))
            image_mime = content_type
        else:
            # Header is missing or generic: trust the magic bytes instead.
            ext = self._detect_image_ext(raw)
            image_mime = self._mime_for_ext(ext, fallback="image/jpeg")
        try:
            return cache_image_from_bytes(raw, ext), image_mime
        except ValueError as exc:
            logger.warning("[%s] Rejected non-image bytes from %s: %s", self.name, url, exc)
            return None

    filename = self._guess_filename(url, headers.get("content-disposition"), content_type)
    return cache_document_from_bytes(raw, filename), content_type
|
||
|
||
@staticmethod
|
||
def _decode_base64(data: str) -> bytes:
|
||
payload = data.split(",", 1)[-1].strip()
|
||
return base64.b64decode(payload)
|
||
|
||
@staticmethod
|
||
def _detect_image_ext(data: bytes) -> str:
|
||
if data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||
return ".png"
|
||
if data.startswith(b"\xff\xd8\xff"):
|
||
return ".jpg"
|
||
if data.startswith((b"GIF87a", b"GIF89a")):
|
||
return ".gif"
|
||
if data.startswith(b"RIFF") and data[8:12] == b"WEBP":
|
||
return ".webp"
|
||
return ".jpg"
|
||
|
||
@staticmethod
|
||
def _mime_for_ext(ext: str, fallback: str = "application/octet-stream") -> str:
|
||
return mimetypes.types_map.get(ext.lower(), fallback)
|
||
|
||
@staticmethod
|
||
def _guess_extension(url: str, content_type: str, fallback: str) -> str:
|
||
ext = mimetypes.guess_extension(content_type) if content_type else None
|
||
if ext:
|
||
return ext
|
||
path_ext = Path(urlparse(url).path).suffix
|
||
if path_ext:
|
||
return path_ext
|
||
return fallback
|
||
|
||
@staticmethod
|
||
def _guess_filename(url: str, content_disposition: Optional[str], content_type: str) -> str:
|
||
if content_disposition:
|
||
match = re.search(r'filename="?([^";]+)"?', content_disposition)
|
||
if match:
|
||
return match.group(1)
|
||
|
||
name = Path(urlparse(url).path).name or "document"
|
||
if "." not in name:
|
||
ext = mimetypes.guess_extension(content_type) or ".bin"
|
||
name = f"{name}{ext}"
|
||
return name
|
||
|
||
@staticmethod
def _derive_message_type(body: Dict[str, Any], text: str, media_types: List[str]) -> MessageType:
    """Pick the normalized message type for an inbound message.

    Precedence: any ``application/*`` or ``text/*`` attachment makes it a
    DOCUMENT; otherwise an image attachment makes it TEXT when there is
    accompanying text and PHOTO when there is not; otherwise a ``voice``
    msgtype yields VOICE; everything else is TEXT.
    """
    has_document = any(mime.startswith(("application/", "text/")) for mime in media_types)
    if has_document:
        return MessageType.DOCUMENT

    has_image = any(mime.startswith("image/") for mime in media_types)
    if has_image:
        if text:
            return MessageType.TEXT
        return MessageType.PHOTO

    declared_kind = str(body.get("msgtype") or "").lower()
    return MessageType.VOICE if declared_kind == "voice" else MessageType.TEXT
|
||
|
||
# ------------------------------------------------------------------
|
||
# Policy helpers
|
||
# ------------------------------------------------------------------
|
||
|
||
def _is_dm_allowed(self, sender_id: str) -> bool:
|
||
if self._dm_policy == "disabled":
|
||
return False
|
||
if self._dm_policy == "allowlist":
|
||
return _entry_matches(self._allow_from, sender_id)
|
||
return True
|
||
|
||
def _is_group_allowed(self, chat_id: str, sender_id: str) -> bool:
|
||
if self._group_policy == "disabled":
|
||
return False
|
||
if self._group_policy == "allowlist" and not _entry_matches(self._group_allow_from, chat_id):
|
||
return False
|
||
|
||
group_cfg = self._resolve_group_cfg(chat_id)
|
||
sender_allow = _coerce_list(group_cfg.get("allow_from") or group_cfg.get("allowFrom"))
|
||
if sender_allow:
|
||
return _entry_matches(sender_allow, sender_id)
|
||
return True
|
||
|
||
def _resolve_group_cfg(self, chat_id: str) -> Dict[str, Any]:
|
||
if not isinstance(self._groups, dict):
|
||
return {}
|
||
if chat_id in self._groups and isinstance(self._groups[chat_id], dict):
|
||
return self._groups[chat_id]
|
||
lowered = chat_id.lower()
|
||
for key, value in self._groups.items():
|
||
if isinstance(key, str) and key.lower() == lowered and isinstance(value, dict):
|
||
return value
|
||
wildcard = self._groups.get("*")
|
||
return wildcard if isinstance(wildcard, dict) else {}
|
||
|
||
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
|
||
normalized_message_id = str(message_id or "").strip()
|
||
normalized_req_id = str(req_id or "").strip()
|
||
if not normalized_message_id or not normalized_req_id:
|
||
return
|
||
self._reply_req_ids[normalized_message_id] = normalized_req_id
|
||
while len(self._reply_req_ids) > DEDUP_MAX_SIZE:
|
||
self._reply_req_ids.pop(next(iter(self._reply_req_ids)))
|
||
|
||
def _remember_chat_req_id(self, chat_id: str, req_id: str) -> None:
|
||
"""Cache the most recent inbound req_id per chat.
|
||
|
||
Used as a fallback reply target when we need to send into a group
|
||
without an explicit ``reply_to`` — WeCom AI Bots are blocked from
|
||
APP_CMD_SEND in groups and must use APP_CMD_RESPONSE bound to some
|
||
prior req_id. Bounded like _reply_req_ids so long-running gateways
|
||
don't leak memory across many chats.
|
||
"""
|
||
normalized_chat_id = str(chat_id or "").strip()
|
||
normalized_req_id = str(req_id or "").strip()
|
||
if not normalized_chat_id or not normalized_req_id:
|
||
return
|
||
self._last_chat_req_ids[normalized_chat_id] = normalized_req_id
|
||
while len(self._last_chat_req_ids) > DEDUP_MAX_SIZE:
|
||
self._last_chat_req_ids.pop(next(iter(self._last_chat_req_ids)))
|
||
|
||
def _reply_req_id_for_message(self, reply_to: Optional[str]) -> Optional[str]:
|
||
normalized = str(reply_to or "").strip()
|
||
if not normalized or normalized.startswith("quote:"):
|
||
return None
|
||
return self._reply_req_ids.get(normalized)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Outbound messaging
|
||
# ------------------------------------------------------------------
|
||
|
||
@staticmethod
|
||
def _guess_mime_type(filename: str) -> str:
|
||
mime_type = mimetypes.guess_type(filename)[0]
|
||
if mime_type:
|
||
return mime_type
|
||
if Path(filename).suffix.lower() == ".amr":
|
||
return "audio/amr"
|
||
return "application/octet-stream"
|
||
|
||
@staticmethod
def _normalize_content_type(content_type: str, filename: str) -> str:
    """Reduce *content_type* to a bare lower-case media type.

    Parameters after ``;`` are discarded. When the result is empty or one
    of the generic values ``application/octet-stream`` / ``text/plain``,
    the type guessed from *filename* is returned instead.
    """
    bare_type = str(content_type or "").split(";", 1)[0].strip().lower()
    type_from_name = WeComAdapter._guess_mime_type(filename)
    # Empty and generic values carry no real information about the payload.
    uninformative = {"", "application/octet-stream", "text/plain"}
    if bare_type in uninformative:
        return type_from_name
    return bare_type
|
||
|
||
@staticmethod
|
||
def _detect_wecom_media_type(content_type: str) -> str:
|
||
mime_type = str(content_type or "").strip().lower()
|
||
if mime_type.startswith("image/"):
|
||
return "image"
|
||
if mime_type.startswith("video/"):
|
||
return "video"
|
||
if mime_type.startswith("audio/") or mime_type == "application/ogg":
|
||
return "voice"
|
||
return "file"
|
||
|
||
@staticmethod
def _apply_file_size_limits(file_size: int, detected_type: str, content_type: Optional[str] = None) -> Dict[str, Any]:
    """Check an outbound attachment against the size and format limits.

    Returns a dict with the keys ``final_type``, ``rejected``,
    ``reject_reason``, ``downgraded`` and ``downgrade_note``:

    - larger than ``ABSOLUTE_MAX_BYTES``: rejected, type unchanged;
    - image / video above their own limit: downgraded to ``file``;
    - voice with a MIME type outside ``VOICE_SUPPORTED_MIMES`` or above
      ``VOICE_MAX_BYTES``: downgraded to ``file``;
    - otherwise the detected type is kept as is.

    The reason / note strings are user-facing and written in Chinese.
    """
    size_mb = file_size / (1024 * 1024)
    media_kind = str(detected_type or "file").lower()
    mime = str(content_type or "").strip().lower()

    def verdict(final_type, reject_reason=None, downgrade_note=None):
        # Every outcome shares one shape; the flags mirror whether a
        # reason / note was supplied.
        return {
            "final_type": final_type,
            "rejected": reject_reason is not None,
            "reject_reason": reject_reason,
            "downgraded": downgrade_note is not None,
            "downgrade_note": downgrade_note,
        }

    if file_size > ABSOLUTE_MAX_BYTES:
        return verdict(
            media_kind,
            reject_reason=(
                f"文件大小 {size_mb:.2f}MB 超过了企业微信允许的最大限制 20MB,无法发送。"
                "请尝试压缩文件或减小文件大小。"
            ),
        )

    if media_kind == "image" and file_size > IMAGE_MAX_BYTES:
        return verdict(
            "file",
            downgrade_note=f"图片大小 {size_mb:.2f}MB 超过 10MB 限制,已转为文件格式发送",
        )

    if media_kind == "video" and file_size > VIDEO_MAX_BYTES:
        return verdict(
            "file",
            downgrade_note=f"视频大小 {size_mb:.2f}MB 超过 10MB 限制,已转为文件格式发送",
        )

    if media_kind == "voice":
        # The format check runs before the size check.
        if mime and mime not in VOICE_SUPPORTED_MIMES:
            return verdict(
                "file",
                downgrade_note=(
                    f"语音格式 {mime} 不支持,企微仅支持 AMR 格式,已转为文件格式发送"
                ),
            )
        if file_size > VOICE_MAX_BYTES:
            return verdict(
                "file",
                downgrade_note=f"语音大小 {size_mb:.2f}MB 超过 2MB 限制,已转为文件格式发送",
            )

    return verdict(media_kind)
|
||
|
||
@staticmethod
|
||
def _response_error(response: Dict[str, Any]) -> Optional[str]:
|
||
errcode = response.get("errcode", 0)
|
||
if errcode in (0, None):
|
||
return None
|
||
errmsg = str(response.get("errmsg") or "unknown error")
|
||
return f"WeCom errcode {errcode}: {errmsg}"
|
||
|
||
@classmethod
def _raise_for_wecom_error(cls, response: Dict[str, Any], operation: str) -> None:
    """Raise ``RuntimeError`` when *response* carries an error.

    *operation* is a short label that is prefixed to the exception message.
    """
    failure = cls._response_error(response)
    if not failure:
        return
    raise RuntimeError(f"{operation} failed: {failure}")
|
||
|
||
@staticmethod
|
||
def _decrypt_file_bytes(encrypted_data: bytes, aes_key: str) -> bytes:
|
||
if not encrypted_data:
|
||
raise ValueError("encrypted_data is empty")
|
||
if not aes_key:
|
||
raise ValueError("aes_key is required")
|
||
|
||
key = base64.b64decode(aes_key)
|
||
if len(key) != 32:
|
||
raise ValueError(f"Invalid WeCom AES key length: expected 32 bytes, got {len(key)}")
|
||
|
||
try:
|
||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||
except ImportError as exc: # pragma: no cover - dependency is environment-specific
|
||
raise RuntimeError("cryptography is required for WeCom media decryption") from exc
|
||
|
||
cipher = Cipher(algorithms.AES(key), modes.CBC(key[:16]))
|
||
decryptor = cipher.decryptor()
|
||
decrypted = decryptor.update(encrypted_data) + decryptor.finalize()
|
||
|
||
pad_len = decrypted[-1]
|
||
if pad_len < 1 or pad_len > 32 or pad_len > len(decrypted):
|
||
raise ValueError(f"Invalid PKCS#7 padding value: {pad_len}")
|
||
if any(byte != pad_len for byte in decrypted[-pad_len:]):
|
||
raise ValueError("Invalid PKCS#7 padding: padding bytes mismatch")
|
||
|
||
return decrypted[:-pad_len]
|
||
|
||
async def _download_remote_bytes(
    self,
    url: str,
    max_bytes: int,
) -> Tuple[bytes, Dict[str, str]]:
    """Download *url* into memory, enforcing a size cap.

    Uses ``self._http_client`` when set; otherwise a temporary
    ``httpx.AsyncClient`` is created and closed before returning.

    Returns:
        ``(data, headers)`` where header names are lower-cased.

    Raises:
        ValueError: the URL (or the URL reached after redirects) fails the
            ``is_safe_url`` check, or the body exceeds *max_bytes*.
        RuntimeError: httpx is not available.
    """
    from tools.url_safety import is_safe_url

    if not is_safe_url(url):
        raise ValueError(f"Blocked unsafe URL (SSRF protection): {url[:80]}")

    if not HTTPX_AVAILABLE:
        raise RuntimeError("httpx is required for WeCom media download")

    client = self._http_client or httpx.AsyncClient(timeout=30.0, follow_redirects=True)
    created_client = client is not self._http_client
    try:
        async with client.stream(
            "GET",
            url,
            headers={
                "User-Agent": "HermesAgent/1.0",
                "Accept": "*/*",
            },
        ) as response:
            # Redirects may have been followed, so the check above only
            # covered the first hop. Validate the URL actually answered
            # before reading any of its body.
            # NOTE(review): the redirected request itself has already been
            # issued at this point; a per-hop guard would need redirect
            # handling at the client level.
            final_url = str(response.url)
            if final_url != url and not is_safe_url(final_url):
                raise ValueError(
                    f"Blocked unsafe redirect target (SSRF protection): {final_url[:80]}"
                )

            response.raise_for_status()
            headers = {key.lower(): value for key, value in response.headers.items()}
            content_length = headers.get("content-length")
            # Fail fast when the server announces an oversized body.
            if content_length and content_length.isdigit() and int(content_length) > max_bytes:
                raise ValueError(
                    f"Remote media exceeds WeCom limit: {int(content_length)} bytes > {max_bytes} bytes"
                )

            data = bytearray()
            async for chunk in response.aiter_bytes():
                data.extend(chunk)
                # Content-Length can be absent or wrong; enforce the cap
                # on what was actually received as well.
                if len(data) > max_bytes:
                    raise ValueError(
                        f"Remote media exceeds WeCom limit while downloading: {len(data)} bytes > {max_bytes} bytes"
                    )

            return bytes(data), headers
    finally:
        if created_client:
            await client.aclose()
|
||
|
||
@staticmethod
|
||
def _looks_like_url(media_source: str) -> bool:
|
||
parsed = urlparse(str(media_source or ""))
|
||
return parsed.scheme in {"http", "https"}
|
||
|
||
async def _load_outbound_media(
|
||
self,
|
||
media_source: str,
|
||
file_name: Optional[str] = None,
|
||
) -> Tuple[bytes, str, str]:
|
||
source = str(media_source or "").strip()
|
||
if not source:
|
||
raise ValueError("media source is required")
|
||
if re.fullmatch(r"<[^>\n]+>", source):
|
||
raise ValueError(f"Media placeholder was not replaced with a real file path: {source}")
|
||
|
||
parsed = urlparse(source)
|
||
if parsed.scheme in {"http", "https"}:
|
||
data, headers = await self._download_remote_bytes(source, max_bytes=ABSOLUTE_MAX_BYTES)
|
||
content_disposition = headers.get("content-disposition")
|
||
resolved_name = file_name or self._guess_filename(source, content_disposition, headers.get("content-type", ""))
|
||
content_type = self._normalize_content_type(headers.get("content-type", ""), resolved_name)
|
||
return data, content_type, resolved_name
|
||
|
||
if parsed.scheme == "file":
|
||
local_path = Path(unquote(parsed.path)).expanduser()
|
||
else:
|
||
local_path = Path(source).expanduser()
|
||
|
||
if not local_path.is_absolute():
|
||
local_path = (Path.cwd() / local_path).resolve()
|
||
|
||
if not local_path.exists() or not local_path.is_file():
|
||
raise FileNotFoundError(f"Media file not found: {local_path}")
|
||
|
||
data = local_path.read_bytes()
|
||
resolved_name = file_name or local_path.name
|
||
content_type = self._normalize_content_type("", resolved_name)
|
||
return data, content_type, resolved_name
|
||
|
||
async def _prepare_outbound_media(
|
||
self,
|
||
media_source: str,
|
||
file_name: Optional[str] = None,
|
||
) -> Dict[str, Any]:
|
||
data, content_type, resolved_name = await self._load_outbound_media(media_source, file_name=file_name)
|
||
detected_type = self._detect_wecom_media_type(content_type)
|
||
size_check = self._apply_file_size_limits(len(data), detected_type, content_type)
|
||
return {
|
||
"data": data,
|
||
"content_type": content_type,
|
||
"file_name": resolved_name,
|
||
"detected_type": detected_type,
|
||
**size_check,
|
||
}
|
||
|
||
async def _upload_media_bytes(self, data: bytes, media_type: str, filename: str) -> Dict[str, Any]:
|
||
if not data:
|
||
raise ValueError("Cannot upload empty media")
|
||
|
||
total_size = len(data)
|
||
total_chunks = (total_size + UPLOAD_CHUNK_SIZE - 1) // UPLOAD_CHUNK_SIZE
|
||
if total_chunks > MAX_UPLOAD_CHUNKS:
|
||
raise ValueError(
|
||
f"File too large: {total_chunks} chunks exceeds maximum of {MAX_UPLOAD_CHUNKS} chunks"
|
||
)
|
||
|
||
init_response = await self._send_request(
|
||
APP_CMD_UPLOAD_MEDIA_INIT,
|
||
{
|
||
"type": media_type,
|
||
"filename": filename,
|
||
"total_size": total_size,
|
||
"total_chunks": total_chunks,
|
||
"md5": hashlib.md5(data).hexdigest(),
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(init_response, "media upload init")
|
||
|
||
init_body = init_response.get("body") if isinstance(init_response.get("body"), dict) else {}
|
||
upload_id = str(init_body.get("upload_id") or "").strip()
|
||
if not upload_id:
|
||
raise RuntimeError(f"media upload init failed: missing upload_id in response {init_response}")
|
||
|
||
for chunk_index, start in enumerate(range(0, total_size, UPLOAD_CHUNK_SIZE)):
|
||
chunk = data[start : start + UPLOAD_CHUNK_SIZE]
|
||
chunk_response = await self._send_request(
|
||
APP_CMD_UPLOAD_MEDIA_CHUNK,
|
||
{
|
||
"upload_id": upload_id,
|
||
# Match the official SDK implementation, which currently uses 0-based chunk indexes.
|
||
"chunk_index": chunk_index,
|
||
"base64_data": base64.b64encode(chunk).decode("ascii"),
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(chunk_response, f"media upload chunk {chunk_index}")
|
||
|
||
finish_response = await self._send_request(
|
||
APP_CMD_UPLOAD_MEDIA_FINISH,
|
||
{"upload_id": upload_id},
|
||
)
|
||
self._raise_for_wecom_error(finish_response, "media upload finish")
|
||
|
||
finish_body = finish_response.get("body") if isinstance(finish_response.get("body"), dict) else {}
|
||
media_id = str(finish_body.get("media_id") or "").strip()
|
||
if not media_id:
|
||
raise RuntimeError(f"media upload finish failed: missing media_id in response {finish_response}")
|
||
|
||
return {
|
||
"type": str(finish_body.get("type") or media_type),
|
||
"media_id": media_id,
|
||
"created_at": finish_body.get("created_at"),
|
||
}
|
||
|
||
async def _send_media_message(self, chat_id: str, media_type: str, media_id: str) -> Dict[str, Any]:
    """Send an already-uploaded attachment to *chat_id* with ``APP_CMD_SEND``.

    The payload names the media type under ``msgtype`` and nests the
    ``media_id`` under a key equal to that type. Raises ``RuntimeError``
    when the response reports an error; otherwise returns the response.
    """
    payload: Dict[str, Any] = {"chatid": chat_id, "msgtype": media_type}
    payload[media_type] = {"media_id": media_id}
    response = await self._send_request(APP_CMD_SEND, payload)
    self._raise_for_wecom_error(response, "send media message")
    return response
|
||
|
||
async def _send_reply_markdown(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||
response = await self._send_reply_request(
|
||
reply_req_id,
|
||
{
|
||
"msgtype": "markdown",
|
||
"markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]},
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(response, "send reply markdown")
|
||
return response
|
||
|
||
async def _send_reply_media_message(
|
||
self,
|
||
reply_req_id: str,
|
||
media_type: str,
|
||
media_id: str,
|
||
) -> Dict[str, Any]:
|
||
response = await self._send_reply_request(
|
||
reply_req_id,
|
||
{
|
||
"msgtype": media_type,
|
||
media_type: {"media_id": media_id},
|
||
},
|
||
)
|
||
self._raise_for_wecom_error(response, "send reply media message")
|
||
return response
|
||
|
||
async def _send_followup_markdown(
|
||
self,
|
||
chat_id: str,
|
||
content: str,
|
||
reply_to: Optional[str] = None,
|
||
) -> Optional[SendResult]:
|
||
if not content:
|
||
return None
|
||
result = await self.send(chat_id=chat_id, content=content, reply_to=reply_to)
|
||
if not result.success:
|
||
logger.warning("[%s] Follow-up markdown send failed: %s", self.name, result.error)
|
||
return result
|
||
|
||
async def _send_media_source(
    self,
    chat_id: str,
    media_source: str,
    caption: Optional[str] = None,
    file_name: Optional[str] = None,
    reply_to: Optional[str] = None,
) -> SendResult:
    """Load, upload and send one attachment, plus optional follow-up notes.

    Steps:
      1. prepare the media (load + size/format limits); failures are
         returned as an unsuccessful result;
      2. when the limits reject it, tell the chat why and stop;
      3. upload, then send as a reply when a req_id is known for
         *reply_to* (or cached for the chat), otherwise as a direct send;
      4. send *caption* and any downgrade note as follow-up markdown. A
         failing follow-up does not fail the overall result; its error is
         recorded in ``raw_response``.
    """
    if not chat_id:
        return SendResult(success=False, error="chat_id is required")

    try:
        prepared = await self._prepare_outbound_media(media_source, file_name=file_name)
    except FileNotFoundError as exc:
        # A missing file is reported back without an error log entry.
        return SendResult(success=False, error=str(exc))
    except Exception as exc:
        logger.error("[%s] Failed to prepare outbound media %s: %s", self.name, media_source, exc)
        return SendResult(success=False, error=str(exc))

    if prepared["rejected"]:
        await self._send_followup_markdown(
            chat_id,
            f"⚠️ {prepared['reject_reason']}",
            reply_to=reply_to,
        )
        return SendResult(success=False, error=prepared["reject_reason"])

    reply_req_id = self._reply_req_id_for_message(reply_to)
    if not reply_req_id:
        # Fall back to the req_id cached for this chat, if there is one.
        reply_req_id = self._last_chat_req_ids.get(chat_id, reply_req_id)

    try:
        upload_result = await self._upload_media_bytes(
            prepared["data"],
            prepared["final_type"],
            prepared["file_name"],
        )
        if reply_req_id:
            media_response = await self._send_reply_media_message(
                reply_req_id,
                prepared["final_type"],
                upload_result["media_id"],
            )
        else:
            media_response = await self._send_media_message(
                chat_id,
                prepared["final_type"],
                upload_result["media_id"],
            )
    except asyncio.TimeoutError:
        return SendResult(success=False, error="Timeout sending media to WeCom")
    except Exception as exc:
        logger.error("[%s] Failed to send media %s: %s", self.name, media_source, exc)
        return SendResult(success=False, error=str(exc))

    caption_result = None
    if caption:
        caption_result = await self._send_followup_markdown(chat_id, caption, reply_to=reply_to)

    downgrade_result = None
    if prepared["downgraded"] and prepared["downgrade_note"]:
        downgrade_result = await self._send_followup_markdown(
            chat_id,
            f"ℹ️ {prepared['downgrade_note']}",
            reply_to=reply_to,
        )

    def followup_error(outcome):
        # Report an error only for a follow-up that was sent and failed.
        if outcome and not outcome.success:
            return outcome.error
        return None

    raw_response = {
        "upload": upload_result,
        "media": media_response,
        "caption": caption_result.raw_response if caption_result else None,
        "caption_error": followup_error(caption_result),
        "downgrade": downgrade_result.raw_response if downgrade_result else None,
        "downgrade_error": followup_error(downgrade_result),
    }
    return SendResult(
        success=True,
        message_id=self._payload_req_id(media_response) or uuid.uuid4().hex[:12],
        raw_response=raw_response,
    )
|
||
|
||
async def send(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Send markdown to a WeCom chat.

    When a req_id is known for *reply_to*, or one is cached for the chat,
    the message goes out as a reply bound to that request; otherwise it is
    sent with ``APP_CMD_SEND``. *content* is cut to
    ``self.MAX_MESSAGE_LENGTH`` characters and *metadata* is ignored.

    Timeouts, exceptions and error responses are all returned as an
    unsuccessful ``SendResult`` rather than raised.
    """
    del metadata

    if not chat_id:
        return SendResult(success=False, error="chat_id is required")

    try:
        reply_req_id = self._reply_req_id_for_message(reply_to)
        if not reply_req_id:
            # Fall back to the req_id cached for this chat, if there is one.
            reply_req_id = self._last_chat_req_ids.get(chat_id, reply_req_id)

        if reply_req_id:
            response = await self._send_reply_markdown(reply_req_id, content)
        else:
            payload = {
                "chatid": chat_id,
                "msgtype": "markdown",
                "markdown": {"content": content[: self.MAX_MESSAGE_LENGTH]},
            }
            response = await self._send_request(APP_CMD_SEND, payload)
    except asyncio.TimeoutError:
        return SendResult(success=False, error="Timeout sending message to WeCom")
    except Exception as exc:
        logger.error("[%s] Send failed: %s", self.name, exc)
        return SendResult(success=False, error=str(exc))

    failure = self._response_error(response)
    if failure:
        return SendResult(success=False, error=failure)

    message_id = self._payload_req_id(response) or uuid.uuid4().hex[:12]
    return SendResult(success=True, message_id=message_id, raw_response=response)
|
||
|
||
async def send_image(
    self,
    chat_id: str,
    image_url: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """Send an image as a native attachment, falling back to a text link.

    The fallback only applies when the attachment send fails *and*
    *image_url* is an http(s) URL: the URL (preceded by *caption*, if any)
    is then sent as a plain message. *metadata* is ignored.
    """
    del metadata

    attempt = await self._send_media_source(
        chat_id=chat_id,
        media_source=image_url,
        caption=caption,
        reply_to=reply_to,
    )
    if attempt.success:
        return attempt
    if not self._looks_like_url(image_url):
        # Nothing useful to link to for a local path; report the failure.
        return attempt

    logger.warning("[%s] Falling back to text send for image URL %s: %s", self.name, image_url, attempt.error)
    if caption:
        fallback_text = f"{caption}\n{image_url}"
    else:
        fallback_text = image_url
    return await self.send(chat_id=chat_id, content=fallback_text, reply_to=reply_to)
|
||
|
||
async def send_image_file(
    self,
    chat_id: str,
    image_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the image at *image_path* as a native attachment.

    Extra keyword arguments are accepted and ignored.
    """
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=image_path,
        caption=caption,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_document(
    self,
    chat_id: str,
    file_path: str,
    caption: Optional[str] = None,
    file_name: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the file at *file_path* as a native attachment.

    *file_name* is forwarded as the name to use for the attachment. Extra
    keyword arguments are accepted and ignored.
    """
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=file_path,
        caption=caption,
        file_name=file_name,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_voice(
    self,
    chat_id: str,
    audio_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the audio at *audio_path* through the shared media pipeline.

    Extra keyword arguments are accepted and ignored.
    """
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=audio_path,
        caption=caption,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_video(
    self,
    chat_id: str,
    video_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """Send the video at *video_path* through the shared media pipeline.

    Extra keyword arguments are accepted and ignored.
    """
    del kwargs
    outcome = await self._send_media_source(
        chat_id=chat_id,
        media_source=video_path,
        caption=caption,
        reply_to=reply_to,
    )
    return outcome
|
||
|
||
async def send_typing(self, chat_id: str, metadata=None) -> None:
    """WeCom does not expose typing indicators in this adapter."""
    # Deliberate no-op: both arguments are accepted and discarded.
    del metadata, chat_id
    return None
|
||
|
||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
    """Return minimal chat info.

    ``name`` echoes *chat_id*; ``type`` is ``"group"`` when the id starts
    with ``group`` (case-insensitive) and ``"dm"`` otherwise, including for
    an empty id.
    """
    if chat_id and chat_id.lower().startswith("group"):
        chat_type = "group"
    else:
        chat_type = "dm"
    return {"name": chat_id, "type": chat_type}
|