diff --git a/gateway/platforms/yuanbao.py b/gateway/platforms/yuanbao.py index 9eec7079f53..26a151304da 100644 --- a/gateway/platforms/yuanbao.py +++ b/gateway/platforms/yuanbao.py @@ -18,6 +18,8 @@ Configuration in config.yaml (or via env vars): from __future__ import annotations import asyncio +import base64 +import binascii import collections import dataclasses import hashlib @@ -31,9 +33,10 @@ import time import urllib.parse import uuid from datetime import datetime, timezone, timedelta +from enum import Enum from pathlib import Path from abc import ABC, abstractmethod -from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple +from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple import sys @@ -55,6 +58,7 @@ from gateway.platforms.base import ( SendResult, cache_document_from_bytes, cache_image_from_bytes, + cache_video_from_bytes, ) from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.yuanbao_media import ( @@ -77,6 +81,7 @@ from gateway.platforms.yuanbao_proto import ( HERMES_INSTANCE_ID, decode_conn_msg, decode_inbound_push, + decode_forward_msg_data, decode_query_group_info_rsp, decode_get_group_member_list_rsp, encode_auth_bind, @@ -164,7 +169,7 @@ _YB_RES_REF_RE = re.compile( _YB_LOCAL_MEDIA_RE = re.compile(r"\[(\w+):[^\]]*?(/[^\]]+?)\s*\]") # Media kinds that can be resolved and injected into the model context -_RESOLVABLE_MEDIA_KINDS = frozenset({"image", "file"}) +_RESOLVABLE_MEDIA_KINDS = frozenset({"image", "file", "video"}) # Strip page indicators like (1/3) appended by BasePlatformAdapter _INDICATOR_RE = re.compile(r'\s*\(\d+/\d+\)$') @@ -932,6 +937,10 @@ class InboundContext: raw_text: str = "" media_refs: list = dc_field(default_factory=list) + # Populated by ExtractContentMiddleware for elem_type 1009 (WeChat forward). + # Contains the parsed ForwardMsgData dict (sub_type / nick_name / msg list). + forwarded_records: Optional[dict] = None + # Owner command detection owner_command: Optional[str] = None @@ -939,7 +948,7 @@ class InboundContext: source: Optional[Any] = None # SessionSource # Populated by ClassifyMessageTypeMiddleware - msg_type: Optional[Any] = None # MessageType + msg_type: Optional[Any] = None # MessageType | YuanbaoMessageType # Populated by QuoteContextMiddleware reply_to_message_id: Optional[str] = None @@ -1761,6 +1770,9 @@ class ExtractContentMiddleware(InboundMiddleware): parts.append(text) else: parts.append("[unsupported message type]") + elif ctype == 1009: + # WeChat forwarded chat record: use the truncated summary text. + parts.append(custom.get("text", "[chat record]")) else: parts.append("[unsupported message type]") except (json.JSONDecodeError, TypeError): @@ -1872,10 +1884,70 @@ class ExtractContentMiddleware(InboundMiddleware): pass return urls + @staticmethod + def _extract_forwarded_records(msg_body: list, user_id: str = "") -> Optional[dict]: + """Extract ForwardMsgData from ext_map for elem_type 1009 (WeChat forward). + + The detailed chat-record payload lives in ``msg_content.ext_map`` + (protobuf field 999, ``map``): + - key format: ``wexin_forward_msg_[forward_msg_id]_[userid]`` + - value: a **base64-encoded protobuf** ``ForwardMsgData`` (NOT JSON). + Decode with base64 then ``decode_forward_msg_data`` to recover the + ``sub_type`` / ``nick_name`` / ``msg`` structure. + + Matching strategy: take the first ``wexin_forward_msg_`` entry whose + decoded payload is a valid ``ForwardMsgData`` (``sub_type == 1``). + + Returns the parsed ``ForwardMsgData`` dict or ``None``. + """ + for elem in msg_body or []: + if not isinstance(elem, dict) or elem.get("msg_type") != "TIMCustomElem": + continue + content = elem.get("msg_content", {}) or {} + if not isinstance(content, dict): + continue + data_str = content.get("data", "") + if not data_str: + continue + try: + custom = json.loads(data_str) + except (json.JSONDecodeError, TypeError): + continue + if not (isinstance(custom, dict) and custom.get("elem_type") == 1009): + continue + + ext_map = content.get("ext_map") or {} + if not isinstance(ext_map, dict) or not ext_map: + return None + + def _parse_value(value): + # ext_map values are base64-encoded ForwardMsgData protobuf. + if not isinstance(value, str) or not value: + return None + try: + pb = base64.b64decode(value) + except (binascii.Error, ValueError): + return None + data = decode_forward_msg_data(pb) + if isinstance(data, dict) and data.get("sub_type") == 1: + return data + return None + + # Take the first valid wexin_forward_msg_ entry. + for key, value in ext_map.items(): + if not key.startswith("wexin_forward_msg_"): + continue + parsed = _parse_value(value) + if parsed is not None: + return parsed + + return None + async def handle(self, ctx: InboundContext, next_fn) -> None: ctx.raw_text = self._rewrite_slash_command(self._extract_text(ctx.msg_body)) ctx.media_refs = self._extract_inbound_media_refs(ctx.msg_body) ctx.link_urls = self._extract_link_urls(ctx.msg_body) + ctx.forwarded_records = self._extract_forwarded_records(ctx.msg_body, ctx.from_account) await next_fn() class PlaceholderFilterMiddleware(InboundMiddleware): @@ -2085,10 +2157,14 @@ class GroupAtGuardMiddleware(InboundMiddleware): "and answer it directly." ) - @staticmethod + @classmethod def _observe_group_message( + cls, adapter, source, sender_display: str, text: str, - *, msg_id: Optional[str] = None, + *, + ctx: InboundContext, + msg_id: Optional[str] = None, + forwarded_records: Optional[dict] = None, ) -> None: """Write a group message into the session transcript without triggering the agent. @@ -2103,7 +2179,14 @@ class GroupAtGuardMiddleware(InboundMiddleware): try: session_entry = store.get_or_create_session(source) user_id = source.user_id or "unknown" - attributed = f"[{sender_display}|{user_id}]\n{text}" + body_text = text + if forwarded_records: + summary = ForwardedRecordsParseMiddleware.build_forward_text( + forwarded_records, ctx=ctx, is_dispatch=False, + ) + if summary: + body_text = f"{text}\n{summary}" if text else summary + attributed = f"[{sender_display}|{user_id}]\n{body_text}" entry: dict = { "role": "user", "content": attributed, @@ -2125,6 +2208,8 @@ class GroupAtGuardMiddleware(InboundMiddleware): self._observe_group_message( adapter, ctx.source, ctx.sender_nickname or ctx.from_account, ctx.raw_text, msg_id=ctx.msg_id or None, + forwarded_records=ctx.forwarded_records, + ctx=ctx, ) logger.info( "[%s] Group message observed (no @bot): chat=%s from=%s", @@ -2165,14 +2250,26 @@ class GroupAttributionMiddleware(InboundMiddleware): await next_fn() +class YuanbaoMessageType(Enum): + """Yuanbao-local message subtypes; coerced back to :class:`MessageType` + before leaving the adapter (see :class:`DispatchMiddleware`).""" + + # WeChat forwarded chat records (TIMCustomElem, elem_type 1009). + CHAT_RECORD = "chat_record" + + class ClassifyMessageTypeMiddleware(InboundMiddleware): """Determine MessageType from text content and msg_body elements.""" name = "classify-msg-type" @staticmethod - def _classify(text: str, msg_body: list) -> MessageType: - """Classify message type based on text and msg_body.""" + def _classify(text: str, msg_body: list): + """Classify message type based on text and msg_body. + + Returns a base :class:`MessageType`, or a yuanbao-local + :class:`YuanbaoMessageType` for platform-specific subtypes. + """ if text.startswith("/"): return MessageType.COMMAND for elem in msg_body: @@ -2185,6 +2282,14 @@ class ClassifyMessageTypeMiddleware(InboundMiddleware): return MessageType.VIDEO if etype == "TIMFileElem": return MessageType.DOCUMENT + if etype == "TIMCustomElem": + data_str = (elem.get("msg_content") or {}).get("data", "") + try: + custom = json.loads(data_str) + except (json.JSONDecodeError, TypeError): + custom = None + if isinstance(custom, dict) and custom.get("elem_type") == 1009: + return YuanbaoMessageType.CHAT_RECORD return MessageType.TEXT async def handle(self, ctx: InboundContext, next_fn) -> None: @@ -2266,6 +2371,180 @@ class QuoteContextMiddleware(InboundMiddleware): await next_fn() +class ForwardedRecordsParseMiddleware(InboundMiddleware): + """Deep-parse WeChat forwarded chat records (elem_type 1009) for dispatch. + + Activates when a full ``ForwardMsgData`` dict is available on the current + turn, carried by the current message (``ctx.forwarded_records``). + Resolves media to ``[kind|ybres:RID]`` + placeholders, appends downloadable refs to ``ctx.media_refs`` (for + :class:`MediaResolveMiddleware`), and rewrites ``ctx.raw_text``. + + Group @bot turns *without* a forward on the current message rely on the + eagerly-rendered summaries that :class:`GroupAtGuardMiddleware` writes to + the transcript at observe time — there is no run-time summary fallback + here. + + On any failure the middleware leaves ``ctx.raw_text`` untouched + (graceful degradation, design §2.8). + """ + + name = "forwarded-records-parse" + + async def handle(self, ctx: InboundContext, next_fn) -> None: + try: + if ctx.forwarded_records: + self._send_loading_heartbeat(ctx) + ctx.raw_text = self.build_forward_text(ctx.forwarded_records, ctx=ctx, is_dispatch=True) + except Exception as exc: + # Degrade gracefully: leave ctx.raw_text as-is. + logger.warning( + "[%s] forwarded-records deep parse failed: %s", + getattr(ctx.adapter, "name", "yuanbao"), exc, + ) + + await next_fn() + + # -- Heartbeat --------------------------------------------------------- + + @staticmethod + async def _send_loading_heartbeat(ctx: InboundContext) -> None: + """Best-effort RUNNING heartbeat so the user sees a loading bubble.""" + try: + await ctx.adapter._outbound.heartbeat.send_heartbeat_once( + ctx.chat_id, WS_HEARTBEAT_RUNNING, + ) + except Exception: + pass + + # -- Record rendering helpers ----------------------------------------- + + @classmethod + def _media_marker( + cls, media: dict, plain_text: str = "", + ) -> Tuple[str, Optional[Dict[str, str]]]: + """Render one ``msgContent.multimedia`` entry as a textual marker. + + Returns ``(marker, ref)``. Downloadable media emits a + ``[kind|ybres:RID]`` marker and a ``ctx.media_refs`` ref dict when a + usable RID/URL is present; otherwise a plain ``[kind] name`` marker + and ``ref=None``. + """ + media_type = (media.get("type", "") or media.get("doc_type", "")).strip().lower() + url = str(media.get("url") or "").strip() + media_id = str(media.get("media_id") or "").strip() + file_name = str(media.get("file_name") or "").strip() + # media_id is directly usable as a ybres RID (design §2.10.9); + # fall back to parsing the resourceId out of the URL. + rid = media_id or ExtractContentMiddleware._parse_resource_id(url) + + if media_type == "image": + if url and rid: + return f"[image|ybres:{rid}] {file_name}".rstrip(), {"kind": "image", "url": url} + return f"[image] {file_name or plain_text}".rstrip(), None + + if media_type in ("file", "document", "code"): + if url and rid: + ref: Dict[str, str] = {"kind": "file", "url": url} + if file_name: + ref["name"] = file_name + return f"[file|ybres:{rid}] {file_name}".rstrip(), ref + return f"[file] {file_name}".rstrip(), None + + if media_type == "url": + # Link share (e.g. WeChat article) — keep URL for the agent. + link_title = file_name or str(media.get("title") or "") + return f"[link] {link_title} {url}".rstrip(), None + + if media_type == "video": + if url and rid: + return f"[video|ybres:{rid}] {file_name}".rstrip(), {"kind": "video", "url": url} + return f"[video] {file_name or url}".rstrip(), None + + return f"[{media_type or 'media'}] {url or file_name}".rstrip(), None + + # Per-record combined-text cap; record count is NOT capped (design §2.10.3). + FORWARD_MSG_TEXT_MAX_CHARS = 1000 + + @classmethod + def _walk_forward_msgs( + cls, + forward_data: dict, + ) -> Iterator[Tuple[str, str, List[Dict[str, str]]]]: + """Walk ``ForwardMsgData['msg']`` and yield ``(sender, body, refs)``. + + Per-record dispatch over ``msgContent`` (text / multimedia / nested + forward / fallback); ``body`` is capped at + :attr:`FORWARD_MSG_TEXT_MAX_CHARS`. Media goes through + :meth:`_media_marker`, always building full ``[kind|ybres:RID]`` + markers; ``refs`` holds that record's downloadable ``ctx.media_refs`` + entries in textual order — the order PatchAnchorsMiddleware relies on + (design §2.10.6). Headers / footers are the caller's job. + """ + for msg in (forward_data.get("msg") if isinstance(forward_data, dict) else None) or []: + if not isinstance(msg, dict): + continue + sender = msg.get("sender", "") + plain_text = msg.get("plainText", "") + msg_contents = msg.get("msgContent", []) or [] + + refs: List[Dict[str, str]] = [] + if not msg_contents: + rendered = plain_text + else: + parts: List[str] = [] + for mc in msg_contents: + if not isinstance(mc, dict): + continue + mc_type = mc.get("type", 0) # EnumMsgContentType + if mc_type == 1: # TEXT + parts.append(mc.get("text", "")) + elif mc_type == 2: # MULTIMEDIA + for media in mc.get("multimedia", []) or []: + if isinstance(media, dict): + marker, ref = cls._media_marker( + media, plain_text, + ) + parts.append(marker) + if ref is not None: + refs.append(ref) + elif mc_type == 3: # nested FORWARD_MSG (design §2.10.10) + parts.append("[嵌套聊天记录]") + else: + if plain_text: + parts.append(plain_text) + rendered = " ".join(p for p in parts if p) or plain_text + + if len(rendered) > cls.FORWARD_MSG_TEXT_MAX_CHARS: + rendered = rendered[: cls.FORWARD_MSG_TEXT_MAX_CHARS] + "…(已截断)" + yield sender, rendered, refs + + # -- Prompt builders --------------------------------------------------- + + @classmethod + def build_forward_text( + cls, forward_data: dict, *, ctx: InboundContext, is_dispatch: bool, + ) -> str: + """Render ``ForwardMsgData`` into forward text. + + Body lines are ``发送人:正文`` with full ``[kind|ybres:RID]`` media + markers preserved. When ``is_dispatch`` is true, refs are appended to + ``ctx.media_refs`` for downstream resolution and a ``用户附言: + {ctx.raw_text}`` footer is added; observed callers skip both since + no later middleware runs. + """ + nickname = ctx.sender_nickname or "用户" + lines = [f"当前用户的昵称为{nickname}", "以下为用户的聊天记录"] + for sender, body, refs in cls._walk_forward_msgs(forward_data): + lines.append(f"{sender}:{body}") + if is_dispatch: + ctx.media_refs.extend(refs) + text = "\n".join(lines) + if is_dispatch and ctx.raw_text.strip(): + text += f"\n\n用户附言:{ctx.raw_text.strip()}" + return text + + class MediaResolveMiddleware(InboundMiddleware): """Resolve inbound media references to downloadable URLs.""" @@ -2273,9 +2552,6 @@ class MediaResolveMiddleware(InboundMiddleware): # --- Resource download cache (keyed by resourceId) --- # Avoids redundant downloads of the same resource within the TTL window. - # The same resourceId can be referenced multiple times in a session (own - # attachment, then quoted again, then observed in a group backfill); each - # reference otherwise triggers a fresh token exchange + download. _resource_cache: ClassVar[Dict[str, Tuple[str, str, float]]] = {} # rid -> (local_path, mime, ts) _RESOURCE_CACHE_TTL_S: ClassVar[int] = 24 * 60 * 60 # 24 hours _RESOURCE_CACHE_MAX_SIZE: ClassVar[int] = 256 @@ -2451,6 +2727,15 @@ class MediaResolveMiddleware(InboundMiddleware): cls._put_cached_resource(resource_id, local_path, mime) return local_path, mime + if kind == "video": + # Yuanbao video resources carry no reliable extension; default to mp4. + local_path = cache_video_from_bytes(file_bytes) + mime = guess_mime_type(local_path) or ( + content_type if content_type.startswith("video/") else "video/mp4" + ) + cls._put_cached_resource(resource_id, local_path, mime) + return local_path, mime + # kind == "file" if not file_name: parsed = urllib.parse.urlparse(fetch_url) @@ -2572,14 +2857,22 @@ class MediaResolveMiddleware(InboundMiddleware): if not history: return [], [] - start = max(0, len(history) - OBSERVED_MEDIA_BACKFILL_LOOKBACK) + # Walk the most recent LOOKBACK messages newest→oldest so that when we + # hit the per-turn resolve cap we keep the *latest* media references, + # not the oldest ones in the window. Within a single message, also + # iterate matches in reverse so the last-added image wins on ties. + # Final ``order`` is reversed back to chronological (old→new) before + # handing off to ``_resolve_ybres_refs`` so downstream prompt insertion + # preserves natural reading order. + window = history[-OBSERVED_MEDIA_BACKFILL_LOOKBACK:] order: List[Tuple[str, str, str]] = [] # (rid, kind, filename) seen: set = set() - for msg in history[start:]: + for msg in reversed(window): content = msg.get("content") if not isinstance(content, str) or "|ybres:" not in content: continue - for m in _YB_RES_REF_RE.finditer(content): + matches = list(_YB_RES_REF_RE.finditer(content)) + for m in reversed(matches): head = m.group(1) # "image" | "file:" | "voice" | "video" rid = m.group(2) kind, _, filename = head.partition(":") @@ -2595,6 +2888,9 @@ class MediaResolveMiddleware(InboundMiddleware): if len(order) >= OBSERVED_MEDIA_BACKFILL_MAX_RESOLVE_PER_TURN: break + # Restore chronological order (oldest→newest) for downstream resolution. + order.reverse() + if not order: return [], [] @@ -2640,9 +2936,7 @@ class MediaResolveMiddleware(InboundMiddleware): if not isinstance(text, str) or not text: return paths, mimes - # Already-local media paths written by PatchAnchorsMiddleware. The - # generic anchor regex covers every kind _patch emits (image/file today, - # video/audio if they later become resolvable) without per-kind upkeep. + # Already-local media paths written by PatchAnchorsMiddleware. seen: set = set() for m in _YB_LOCAL_MEDIA_RE.finditer(text): kind = (m.group(1) or "").strip().lower() @@ -2756,6 +3050,8 @@ class PatchAnchorsMiddleware(InboundMiddleware): elif kind == "file": label = filename.strip() or os.path.basename(u) replacement = f"[file: {label} → {u}]" + elif kind == "video": + replacement = f"[video: {u}]" else: continue patched = ( @@ -2790,7 +3086,11 @@ class DispatchMiddleware(InboundMiddleware): message_type=( MessageType.DOCUMENT if any(mt.startswith(("application/", "text/")) for mt in ctx.media_types) - else ctx.msg_type + # Coerce yuanbao-local subtypes (e.g. CHAT_RECORD) back to a + # base MessageType: chat records are deep-parsed into a text + # prompt, so TEXT is the right kind for downstream routing. + else ctx.msg_type if isinstance(ctx.msg_type, MessageType) + else MessageType.TEXT ), source=ctx.source, message_id=ctx.msg_id or None, @@ -2889,6 +3189,7 @@ class InboundPipelineBuilder: GroupAttributionMiddleware, ClassifyMessageTypeMiddleware, QuoteContextMiddleware, + ForwardedRecordsParseMiddleware, MediaResolveMiddleware, PatchAnchorsMiddleware, DispatchMiddleware, diff --git a/gateway/platforms/yuanbao_proto.py b/gateway/platforms/yuanbao_proto.py index cbe176aee5a..4eeb54c5559 100644 --- a/gateway/platforms/yuanbao_proto.py +++ b/gateway/platforms/yuanbao_proto.py @@ -492,6 +492,29 @@ def decode_biz_msg(data: bytes) -> dict: # field 10: url (string) # field 11: file_size (uint32) # field 12: file_name (string) +# field 999: ext_map (map) ← extension info for WeChat chat-history forwarding +# protobuf map is wire-encoded as a repeated message entry; each entry has: +# field 1: key (string) +# field 2: value (string) +# key format: wexin_forward_msg_[forward_msg_id]_[userid] +# value: base64(ForwardMsgData protobuf) ← NOT JSON; it is base64-encoded +# protobuf bytes that must be parsed with decode_forward_msg_data(). + + +def _encode_map_entry(key: str, value: str) -> bytes: + """Encode a single entry of a protobuf map (field 1 key, field 2 value).""" + buf = b"" + if key: + buf += _encode_field(1, WT_LEN, _encode_string(str(key))) + if value: + buf += _encode_field(2, WT_LEN, _encode_string(str(value))) + return buf + + +def _decode_map_entry(data: bytes) -> tuple[str, str]: + """Decode a single entry of a protobuf map, returning (key, value).""" + fdict = _fields_to_dict(_parse_fields(data)) + return _get_string(fdict, 1), _get_string(fdict, 2) def _encode_msg_content(content: dict) -> bytes: @@ -518,6 +541,12 @@ def _encode_msg_content(content: dict) -> bytes: if url: img_buf += _encode_field(5, WT_LEN, _encode_string(url)) buf += _encode_field(8, WT_LEN, _encode_message(img_buf)) + # ext_map (map, field 999) — repeated message entries + ext_map = content.get("ext_map") + if isinstance(ext_map, dict): + for k, v in ext_map.items(): + entry_bytes = _encode_map_entry(str(k), str(v)) + buf += _encode_field(999, WT_LEN, _encode_message(entry_bytes)) return buf @@ -550,6 +579,14 @@ def _decode_msg_content(data: bytes) -> dict: imgs.append(img) if imgs: content["image_info_array"] = imgs + # ext_map (field 999) — decode repeated map entries into a plain dict + ext_map: dict[str, str] = {} + for entry_bytes in _get_repeated_bytes(fdict, 999): + k, v = _decode_map_entry(entry_bytes) + if k: + ext_map[k] = v + if ext_map: + content["ext_map"] = ext_map return content @@ -710,9 +747,178 @@ def decode_inbound_push(data: bytes) -> Optional[dict]: # ============================================================ -# 出站消息编码 +# WeChat forwarded chat-history parsing (ForwardMsgData) # ============================================================ +# +# The value of ext_map["wexin_forward_msg__"] is a base64-encoded +# ForwardMsgData protobuf (NOT JSON). Structure (verified against live captures): +# +# message ForwardMsgData { +# uint32 sub_type = 1; // 1 = WeChat chat-history forward +# uint32 begin_time = 2; +# uint32 end_time = 3; +# string nick_name = 4; // forwarder's WeChat nickname +# repeated ForwardMsg msg = 5; +# } +# message ForwardMsg { +# string sender = 1; +# uint32 time = 2; +# string plainText = 3; +# repeated MsgContent msgContent = 4; +# } +# message MsgContent { +# uint32 type = 1; // 1=TEXT, 2=MULTIMEDIA, 3=nested forward +# string text = 2; // type==1 +# repeated Multimedia multimedia = 3; // type==2 +# } +# message Multimedia { +# string type = 1; // image / file / document / url / video +# string url = 2; +# string file_name = 4; +# uint32 file_size = 5; +# uint32 width = 6; +# uint32 height = 7; +# string media_id = 15; // can be used directly as a ybres RID +# string res_type = 24; +# } + +def _decode_forward_multimedia(data: bytes) -> dict: + """Decode a single Multimedia sub-message into the dict shape expected by _format_multimedia.""" + fdict = _fields_to_dict(_parse_fields(data)) + media: dict = {} + mtype = _get_string(fdict, 1) + if mtype: + media["type"] = mtype + url = _get_string(fdict, 2) + if url: + media["url"] = url + file_name = _get_string(fdict, 4) + if file_name: + media["file_name"] = file_name + file_size = _get_varint(fdict, 5) + if file_size: + media["file_size"] = file_size + media_id = _get_string(fdict, 15) + if media_id: + media["media_id"] = media_id + return media + + +def _decode_forward_msg_content(data: bytes) -> dict: + """Decode a single MsgContent sub-message into {type, text?, multimedia?}.""" + fdict = _fields_to_dict(_parse_fields(data)) + content: dict = {"type": _get_varint(fdict, 1)} + text = _get_string(fdict, 2) + if text: + content["text"] = text + multimedia = [ + _decode_forward_multimedia(b) for b in _get_repeated_bytes(fdict, 3) + ] + if multimedia: + content["multimedia"] = multimedia + return content + + +def _decode_forward_msg(data: bytes) -> dict: + """Decode a single ForwardMsg sub-message into {sender, plainText, msgContent}.""" + fdict = _fields_to_dict(_parse_fields(data)) + return { + "sender": _get_string(fdict, 1), + "time": _get_varint(fdict, 2), + "plainText": _get_string(fdict, 3), + "msgContent": [ + _decode_forward_msg_content(b) for b in _get_repeated_bytes(fdict, 4) + ], + } + + +def decode_forward_msg_data(data: bytes) -> Optional[dict]: + """Parse ForwardMsgData protobuf bytes (the base64-decoded ext_map value). + + Args: + data: ForwardMsgData protobuf bytes, after base64 decoding. + + Returns: + A dict matching the structure consumed by + ``ForwardedRecordsParseMiddleware.build_forward_text`` + (``sub_type`` / ``nick_name`` / ``msg`` list); ``None`` on parse failure. + """ + try: + fdict = _fields_to_dict(_parse_fields(data)) + return { + "sub_type": _get_varint(fdict, 1), + "begin_time": _get_varint(fdict, 2), + "end_time": _get_varint(fdict, 3), + "nick_name": _get_string(fdict, 4), + "msg": [_decode_forward_msg(b) for b in _get_repeated_bytes(fdict, 5)], + } + except Exception as e: + if DEBUG_MODE: + logger.debug("[yuanbao_proto] decode_forward_msg_data failed: %s", e) + return None + + +def _encode_forward_multimedia(media: dict) -> bytes: + buf = b"" + for fn, key in [(1, "type"), (2, "url"), (4, "file_name"), (15, "media_id")]: + v = media.get(key, "") + if v: + buf += _encode_field(fn, WT_LEN, _encode_string(str(v))) + for fn, key in [(5, "file_size"), (6, "width"), (7, "height")]: + v = media.get(key, 0) + if v: + buf += _encode_field(fn, WT_VARINT, _encode_varint(int(v))) + return buf + + +def _encode_forward_msg_content(content: dict) -> bytes: + buf = _encode_field(1, WT_VARINT, _encode_varint(int(content.get("type", 0)))) + text = content.get("text", "") + if text: + buf += _encode_field(2, WT_LEN, _encode_string(str(text))) + for media in content.get("multimedia") or []: + buf += _encode_field(3, WT_LEN, _encode_message(_encode_forward_multimedia(media))) + return buf + + +def _encode_forward_msg(msg: dict) -> bytes: + buf = b"" + sender = msg.get("sender", "") + if sender: + buf += _encode_field(1, WT_LEN, _encode_string(str(sender))) + time_val = msg.get("time", 0) + if time_val: + buf += _encode_field(2, WT_VARINT, _encode_varint(int(time_val))) + plain = msg.get("plainText", "") + if plain: + buf += _encode_field(3, WT_LEN, _encode_string(str(plain))) + for mc in msg.get("msgContent") or []: + buf += _encode_field(4, WT_LEN, _encode_message(_encode_forward_msg_content(mc))) + return buf + + +def encode_forward_msg_data(data: dict) -> bytes: + """Encode ForwardMsgData protobuf bytes (inverse of ``decode_forward_msg_data``). + + Mainly used to build mock / test data; production code never needs to encode this. + """ + buf = _encode_field(1, WT_VARINT, _encode_varint(int(data.get("sub_type", 0)))) + for fn, key in [(2, "begin_time"), (3, "end_time")]: + v = data.get(key, 0) + if v: + buf += _encode_field(fn, WT_VARINT, _encode_varint(int(v))) + nick = data.get("nick_name", "") + if nick: + buf += _encode_field(4, WT_LEN, _encode_string(str(nick))) + for msg in data.get("msg") or []: + buf += _encode_field(5, WT_LEN, _encode_message(_encode_forward_msg(msg))) + return buf + + +# ============================================================ +# Outbound message encoding +# ============================================================ def _encode_send_c2c_req( to_account: str, from_account: str, @@ -724,7 +930,7 @@ def _encode_send_c2c_req( trace_id: str = "", ) -> bytes: """ - 编码 SendC2CMessageReq biz payload。 + Encode a SendC2CMessageReq biz payload. SendC2CMessageReq fields: 1: msg_id (string) @@ -769,7 +975,7 @@ def _encode_send_group_req( trace_id: str = "", ) -> bytes: """ - 编码 SendGroupMessageReq biz payload。 + Encode a SendGroupMessageReq biz payload. SendGroupMessageReq fields: 1: msg_id (string) @@ -816,18 +1022,20 @@ def encode_send_c2c_message( trace_id: str = "", ) -> bytes: """ - 编码 C2C 发消息请求,返回完整 ConnMsg bytes(可直接发送到 WebSocket)。 + Encode a C2C send-message request and return the full ConnMsg bytes + (ready to be sent over WebSocket). Args: - to_account: 收件人账号 - msg_body: 消息体列表,每个元素: {"msg_type": str, "msg_content": dict} - 例如: [{"msg_type": "TIMTextElem", "msg_content": {"text": "hello"}}] - from_account: 发件人账号(机器人账号) - msg_id: 消息唯一 ID(空时使用 req_id) - msg_random: 随机数(防重) - msg_seq: 消息序列号(可选) - group_code: 来自群聊的私聊场景时填写 - trace_id: 链路追踪 ID + to_account: recipient account + msg_body: list of message-body elements; each item is + {"msg_type": str, "msg_content": dict}. + Example: [{"msg_type": "TIMTextElem", "msg_content": {"text": "hello"}}] + from_account: sender account (the bot account) + msg_id: unique message ID (req_id is used when empty) + msg_random: random number for de-duplication + msg_seq: message sequence number (optional) + group_code: filled in for the "private chat originating from a group" case + trace_id: trace ID for request tracing Returns: ConnMsg bytes @@ -866,18 +1074,19 @@ def encode_send_group_message( trace_id: str = "", ) -> bytes: """ - 编码群消息发送请求,返回完整 ConnMsg bytes(可直接发送到 WebSocket)。 + Encode a group send-message request and return the full ConnMsg bytes + (ready to be sent over WebSocket). Args: - group_code: 群号 - msg_body: 消息体列表 - from_account: 发件人账号(机器人账号) - msg_id: 消息唯一 ID - to_account: 指定接收者(一般为空) - random: 去重随机字符串 - msg_seq: 消息序列号 - ref_msg_id: 引用消息 ID - trace_id: 链路追踪 ID + group_code: group ID + msg_body: list of message-body elements + from_account: sender account (the bot account) + msg_id: unique message ID + to_account: targeted recipient (usually empty) + random: random string for de-duplication + msg_seq: message sequence number + ref_msg_id: ID of the referenced (quoted) message + trace_id: trace ID for request tracing Returns: ConnMsg bytes diff --git a/tests/test_yuanbao_pipeline.py b/tests/test_yuanbao_pipeline.py index 023d949b2cc..ac35f49647d 100644 --- a/tests/test_yuanbao_pipeline.py +++ b/tests/test_yuanbao_pipeline.py @@ -704,6 +704,7 @@ class TestCreateInboundPipeline: "group-attribution", "classify-msg-type", "quote-context", + "forwarded-records-parse", "media-resolve", "patch-anchors", "dispatch", @@ -1082,9 +1083,9 @@ class TestResolveYbresRefs: """Refs whose kind is outside ``_RESOLVABLE_MEDIA_KINDS`` are dropped silently.""" adapter = make_adapter() refs = [ - ("rid-v", "video", ""), # not resolvable + ("rid-a", "voice", ""), # not resolvable ("rid-i", "image", "ok.jpg"), # resolvable - ("rid-?", "unknown", ""), # not resolvable + ("rid-?", "unknown", ""), # not resolvable ] with patch.object(