hermes-agent/tests/test_yuanbao_pipeline.py
loongfay e20e0bd744
feat(Yuanbao): support wechat forward msg (#43508)
* feat(yuanbao): support wechat forward msg

* feat(yuanbao): support wechat forward msg

---------

Co-authored-by: loongfay <izhaolongfei@gmail.com>
2026-06-12 02:06:47 -07:00

1384 lines
49 KiB
Python

"""
test_yuanbao_pipeline.py - Unit tests for the inbound middleware pipeline.
Tests cover:
1. InboundPipeline engine (use, use_before, use_after, remove, execute)
2. InboundContext dataclass
3. Individual middlewares (DecodeMiddleware, DedupMiddleware, SkipSelfMiddleware, etc.)
4. InboundPipelineBuilder
5. End-to-end pipeline integration
6. OOP middleware ABC and class tests
"""
import sys
import os
import json
# Ensure project root is on the path
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
sys.path.insert(0, _REPO_ROOT)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from gateway.platforms.yuanbao import (
InboundContext,
InboundMiddleware,
InboundPipeline,
DecodeMiddleware,
ExtractFieldsMiddleware,
DedupMiddleware,
SkipSelfMiddleware,
ChatRoutingMiddleware,
AccessPolicy,
AccessGuardMiddleware,
ExtractContentMiddleware,
PlaceholderFilterMiddleware,
OwnerCommandMiddleware,
BuildSourceMiddleware,
GroupAtGuardMiddleware,
QuoteContextMiddleware,
MediaResolveMiddleware,
PatchAnchorsMiddleware,
DispatchMiddleware,
InboundPipelineBuilder,
YuanbaoAdapter,
)
from gateway.config import PlatformConfig
# ============================================================
# Helpers
# ============================================================
def make_config(**kwargs):
extra = kwargs.pop("extra", {})
extra.setdefault("app_id", "test_key")
extra.setdefault("app_secret", "test_secret")
extra.setdefault("ws_url", "wss://test.example.com/ws")
extra.setdefault("api_domain", "https://test.example.com")
return PlatformConfig(
extra=extra,
**kwargs,
)
def make_adapter(**kwargs) -> YuanbaoAdapter:
"""Create a YuanbaoAdapter with test config."""
config = make_config(**kwargs)
adapter = YuanbaoAdapter(config)
adapter._bot_id = "bot_123"
return adapter
def make_ctx(adapter=None, conn_data=b"", **overrides) -> InboundContext:
"""Create an InboundContext with sensible defaults for testing."""
if adapter is None:
adapter = make_adapter()
raw_frames = [conn_data] if conn_data else []
ctx = InboundContext(adapter=adapter, raw_frames=raw_frames)
for k, v in overrides.items():
setattr(ctx, k, v)
return ctx
def make_json_push(
from_account="alice",
to_account="bot_123",
group_code="",
text="Hello!",
msg_id="msg-001",
) -> bytes:
"""Build a JSON callback_command push payload.
Note: MsgContent inner fields use lowercase ("text" not "Text")
because _extract_text() looks for lowercase keys.
"""
msg_body = [{"MsgType": "TIMTextElem", "MsgContent": {"text": text}}]
push = {
"CallbackCommand": "C2C.CallbackAfterSendMsg",
"From_Account": from_account,
"To_Account": to_account,
"MsgBody": msg_body,
"MsgKey": msg_id,
}
if group_code:
push["CallbackCommand"] = "Group.CallbackAfterSendMsg"
push["GroupId"] = group_code
return json.dumps(push).encode("utf-8")
# ============================================================
# 1. InboundPipeline Engine Tests
# ============================================================
class TestInboundPipeline:
"""Test the pipeline engine itself."""
@pytest.mark.asyncio
async def test_empty_pipeline(self):
"""Empty pipeline executes without error."""
pipeline = InboundPipeline()
ctx = make_ctx()
await pipeline.execute(ctx) # Should not raise
@pytest.mark.asyncio
async def test_single_middleware(self):
"""Single middleware is called with ctx and next_fn."""
called = []
async def mw(ctx, next_fn):
called.append("mw")
await next_fn()
pipeline = InboundPipeline().use("test", mw)
ctx = make_ctx()
await pipeline.execute(ctx)
assert called == ["mw"]
@pytest.mark.asyncio
async def test_middleware_order(self):
"""Middlewares execute in registration order."""
order = []
async def mw_a(ctx, next_fn):
order.append("a")
await next_fn()
async def mw_b(ctx, next_fn):
order.append("b")
await next_fn()
async def mw_c(ctx, next_fn):
order.append("c")
await next_fn()
pipeline = InboundPipeline().use("a", mw_a).use("b", mw_b).use("c", mw_c)
await pipeline.execute(make_ctx())
assert order == ["a", "b", "c"]
@pytest.mark.asyncio
async def test_middleware_can_stop_pipeline(self):
"""A middleware that doesn't call next_fn stops the pipeline."""
order = []
async def mw_stop(ctx, next_fn):
order.append("stop")
# Don't call next_fn — pipeline stops here
async def mw_after(ctx, next_fn):
order.append("after")
await next_fn()
pipeline = InboundPipeline().use("stop", mw_stop).use("after", mw_after)
await pipeline.execute(make_ctx())
assert order == ["stop"] # "after" should NOT be called
@pytest.mark.asyncio
async def test_conditional_guard_skip(self):
"""Middleware with when=False is skipped."""
order = []
async def mw_a(ctx, next_fn):
order.append("a")
await next_fn()
async def mw_skipped(ctx, next_fn):
order.append("skipped")
await next_fn()
async def mw_c(ctx, next_fn):
order.append("c")
await next_fn()
pipeline = (
InboundPipeline()
.use("a", mw_a)
.use("skipped", mw_skipped, when=lambda ctx: False)
.use("c", mw_c)
)
await pipeline.execute(make_ctx())
assert order == ["a", "c"]
@pytest.mark.asyncio
async def test_conditional_guard_pass(self):
"""Middleware with when=True is executed."""
order = []
async def mw(ctx, next_fn):
order.append("mw")
await next_fn()
pipeline = InboundPipeline().use("mw", mw, when=lambda ctx: True)
await pipeline.execute(make_ctx())
assert order == ["mw"]
def test_use_before(self):
"""use_before inserts middleware before the target."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = InboundPipeline().use("a", noop).use("c", noop)
pipeline.use_before("c", "b", noop)
assert pipeline.middleware_names == ["a", "b", "c"]
def test_use_before_nonexistent_appends(self):
"""use_before with nonexistent target appends to end."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = InboundPipeline().use("a", noop)
pipeline.use_before("nonexistent", "b", noop)
assert pipeline.middleware_names == ["a", "b"]
def test_use_after(self):
"""use_after inserts middleware after the target."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = InboundPipeline().use("a", noop).use("c", noop)
pipeline.use_after("a", "b", noop)
assert pipeline.middleware_names == ["a", "b", "c"]
def test_use_after_nonexistent_appends(self):
"""use_after with nonexistent target appends to end."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = InboundPipeline().use("a", noop)
pipeline.use_after("nonexistent", "b", noop)
assert pipeline.middleware_names == ["a", "b"]
def test_remove(self):
"""remove deletes middleware by name."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = InboundPipeline().use("a", noop).use("b", noop).use("c", noop)
pipeline.remove("b")
assert pipeline.middleware_names == ["a", "c"]
def test_remove_nonexistent_is_noop(self):
"""remove with nonexistent name is a no-op."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = InboundPipeline().use("a", noop)
pipeline.remove("nonexistent")
assert pipeline.middleware_names == ["a"]
@pytest.mark.asyncio
async def test_error_propagation(self):
"""Errors in middlewares propagate to the caller."""
async def mw_error(ctx, next_fn):
raise ValueError("test error")
pipeline = InboundPipeline().use("error", mw_error)
with pytest.raises(ValueError, match="test error"):
await pipeline.execute(make_ctx())
def test_middleware_names_property(self):
"""middleware_names returns ordered list of names."""
async def noop(ctx, next_fn):
await next_fn()
pipeline = (
InboundPipeline()
.use("decode", noop)
.use("dedup", noop)
.use("dispatch", noop)
)
assert pipeline.middleware_names == ["decode", "dedup", "dispatch"]
@pytest.mark.asyncio
async def test_onion_model(self):
"""Middlewares support before/after processing (onion model)."""
order = []
async def mw_outer(ctx, next_fn):
order.append("outer-before")
await next_fn()
order.append("outer-after")
async def mw_inner(ctx, next_fn):
order.append("inner")
await next_fn()
pipeline = InboundPipeline().use("outer", mw_outer).use("inner", mw_inner)
await pipeline.execute(make_ctx())
assert order == ["outer-before", "inner", "outer-after"]
# ============================================================
# 2. Individual Middleware Tests
# ============================================================
class TestDecodeMiddleware:
@pytest.mark.asyncio
async def test_json_decode(self):
"""DecodeMiddleware parses JSON push correctly."""
push_data = make_json_push(from_account="alice", text="hi")
ctx = make_ctx(conn_data=push_data)
next_fn = AsyncMock()
await DecodeMiddleware()(ctx, next_fn)
assert ctx.push is not None
assert ctx.decoded_via == "json"
assert ctx.push.get("from_account") == "alice"
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_empty_data_stops_pipeline(self):
"""DecodeMiddleware stops pipeline on empty conn_data."""
ctx = make_ctx(conn_data=b"")
next_fn = AsyncMock()
await DecodeMiddleware()(ctx, next_fn)
assert ctx.push is None
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_invalid_data_may_produce_garbage(self):
"""DecodeMiddleware: binary data may be parsed by protobuf as garbage fields.
This is expected behavior — the protobuf parser is lenient and may
produce "seemingly valid" fields from arbitrary bytes. The downstream
middlewares (dedup, skip-self, etc.) will filter out such garbage.
"""
ctx = make_ctx(conn_data=b"\x00\x01\x02\x03")
next_fn = AsyncMock()
await DecodeMiddleware()(ctx, next_fn)
# Protobuf parser may or may not produce a result — either is acceptable.
# The key invariant: no exception is raised.
assert True # Reached here without error
class TestExtractFieldsMiddleware:
@pytest.mark.asyncio
async def test_extracts_fields(self):
"""ExtractFieldsMiddleware populates ctx from push dict."""
ctx = make_ctx(push={
"from_account": "alice",
"group_code": "grp-1",
"group_name": "Test Group",
"sender_nickname": "Alice",
"msg_body": [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
"msg_id": "msg-001",
"cloud_custom_data": '{"key": "val"}',
})
next_fn = AsyncMock()
await ExtractFieldsMiddleware()(ctx, next_fn)
assert ctx.from_account == "alice"
assert ctx.group_code == "grp-1"
assert ctx.group_name == "Test Group"
assert ctx.sender_nickname == "Alice"
assert len(ctx.msg_body) == 1
assert ctx.msg_id == "msg-001"
assert ctx.cloud_custom_data == '{"key": "val"}'
next_fn.assert_awaited_once()
class TestDedupMiddleware:
@pytest.mark.asyncio
async def test_new_message_passes(self):
"""DedupMiddleware passes new messages through."""
adapter = make_adapter()
ctx = make_ctx(adapter=adapter, msg_id="unique-msg-001")
next_fn = AsyncMock()
await DedupMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_duplicate_stops_pipeline(self):
"""DedupMiddleware stops pipeline for duplicate messages."""
adapter = make_adapter()
# Mark message as seen
adapter._dedup.is_duplicate("dup-msg-001")
ctx = make_ctx(adapter=adapter, msg_id="dup-msg-001")
next_fn = AsyncMock()
await DedupMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_empty_msg_id_passes(self):
"""DedupMiddleware passes messages with empty msg_id."""
ctx = make_ctx(msg_id="")
next_fn = AsyncMock()
await DedupMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
class TestSkipSelfMiddleware:
@pytest.mark.asyncio
async def test_self_message_stops(self):
"""SkipSelfMiddleware stops pipeline for bot's own messages."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
ctx = make_ctx(adapter=adapter, from_account="bot_123")
next_fn = AsyncMock()
await SkipSelfMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_other_message_passes(self):
"""SkipSelfMiddleware passes messages from other users."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
ctx = make_ctx(adapter=adapter, from_account="alice")
next_fn = AsyncMock()
await SkipSelfMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
class TestChatRoutingMiddleware:
@pytest.mark.asyncio
async def test_group_routing(self):
"""ChatRoutingMiddleware sets group chat fields."""
ctx = make_ctx(group_code="grp-1", group_name="Test Group")
next_fn = AsyncMock()
await ChatRoutingMiddleware()(ctx, next_fn)
assert ctx.chat_id == "group:grp-1"
assert ctx.chat_type == "group"
assert ctx.chat_name == "Test Group"
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_dm_routing(self):
"""ChatRoutingMiddleware sets DM chat fields."""
ctx = make_ctx(from_account="alice", sender_nickname="Alice")
next_fn = AsyncMock()
await ChatRoutingMiddleware()(ctx, next_fn)
assert ctx.chat_id == "direct:alice"
assert ctx.chat_type == "dm"
assert ctx.chat_name == "Alice"
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_dm_routing_no_nickname(self):
"""ChatRoutingMiddleware falls back to from_account when no nickname."""
ctx = make_ctx(from_account="alice", sender_nickname="")
next_fn = AsyncMock()
await ChatRoutingMiddleware()(ctx, next_fn)
assert ctx.chat_name == "alice"
class TestAccessGuardMiddleware:
@pytest.mark.asyncio
async def test_open_policy_passes(self):
"""AccessGuardMiddleware passes with open policy."""
adapter = make_adapter()
adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="open", group_allow_from=[])
ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
next_fn = AsyncMock()
await AccessGuardMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_disabled_dm_stops(self):
"""AccessGuardMiddleware stops DM when dm_policy=disabled."""
adapter = make_adapter()
adapter._access_policy = AccessPolicy(dm_policy="disabled", dm_allow_from=[], group_policy="open", group_allow_from=[])
ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
next_fn = AsyncMock()
await AccessGuardMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_allowlist_dm_allowed(self):
"""AccessGuardMiddleware passes DM when sender is in allowlist."""
adapter = make_adapter()
adapter._access_policy = AccessPolicy(dm_policy="allowlist", dm_allow_from=["alice"], group_policy="open", group_allow_from=[])
ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
next_fn = AsyncMock()
await AccessGuardMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_allowlist_dm_blocked(self):
"""AccessGuardMiddleware blocks DM when sender is not in allowlist."""
adapter = make_adapter()
adapter._access_policy = AccessPolicy(dm_policy="allowlist", dm_allow_from=["bob"], group_policy="open", group_allow_from=[])
ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
next_fn = AsyncMock()
await AccessGuardMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_disabled_group_stops(self):
"""AccessGuardMiddleware stops group when group_policy=disabled."""
adapter = make_adapter()
adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="disabled", group_allow_from=[])
ctx = make_ctx(adapter=adapter, chat_type="group", group_code="grp-1")
next_fn = AsyncMock()
await AccessGuardMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_allowlist_group_allowed(self):
"""AccessGuardMiddleware passes group when group_code is in allowlist."""
adapter = make_adapter()
adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="allowlist", group_allow_from=["grp-1"])
ctx = make_ctx(adapter=adapter, chat_type="group", group_code="grp-1")
next_fn = AsyncMock()
await AccessGuardMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
class TestExtractContentMiddleware:
@pytest.mark.asyncio
async def test_extracts_text_and_media(self):
"""ExtractContentMiddleware extracts text and media refs."""
adapter = make_adapter()
msg_body = [
{"msg_type": "TIMTextElem", "msg_content": {"text": "Hello!"}},
{"msg_type": "TIMImageElem", "msg_content": {
"image_info_array": [{"url": "https://img.example.com/1.jpg"}]
}},
]
ctx = make_ctx(adapter=adapter, msg_body=msg_body)
next_fn = AsyncMock()
await ExtractContentMiddleware()(ctx, next_fn)
assert "Hello!" in ctx.raw_text
assert len(ctx.media_refs) == 1
assert ctx.media_refs[0]["kind"] == "image"
next_fn.assert_awaited_once()
class TestPlaceholderFilterMiddleware:
@pytest.mark.asyncio
async def test_placeholder_stops(self):
"""PlaceholderFilterMiddleware stops on pure placeholder."""
ctx = make_ctx(raw_text="[image]", media_refs=[])
next_fn = AsyncMock()
await PlaceholderFilterMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_placeholder_with_media_passes(self):
"""PlaceholderFilterMiddleware passes placeholder when media exists."""
ctx = make_ctx(
raw_text="[image]",
media_refs=[{"kind": "image", "url": "https://img.example.com/1.jpg"}],
)
next_fn = AsyncMock()
await PlaceholderFilterMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_normal_text_passes(self):
"""PlaceholderFilterMiddleware passes normal text."""
ctx = make_ctx(raw_text="Hello world!")
next_fn = AsyncMock()
await PlaceholderFilterMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
class TestGroupAtGuardMiddleware:
@pytest.mark.asyncio
async def test_dm_passes(self):
"""GroupAtGuardMiddleware passes DM messages."""
adapter = make_adapter()
ctx = make_ctx(adapter=adapter, chat_type="dm")
next_fn = AsyncMock()
await GroupAtGuardMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_group_with_at_bot_passes(self):
"""GroupAtGuardMiddleware passes group messages that @bot."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
msg_body = [
{"msg_type": "TIMCustomElem", "msg_content": {
"data": json.dumps({"elem_type": 1002, "text": "@Bot", "user_id": "bot_123"})
}},
]
ctx = make_ctx(
adapter=adapter,
chat_type="group",
chat_id="group:grp-1",
msg_body=msg_body,
from_account="alice",
sender_nickname="Alice",
raw_text="Hello",
source=MagicMock(),
)
next_fn = AsyncMock()
await GroupAtGuardMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_group_without_at_bot_observes(self):
"""GroupAtGuardMiddleware observes group messages without @bot."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
adapter._session_store = None # No session store -> observe is a no-op
ctx = make_ctx(
adapter=adapter,
chat_type="group",
chat_id="group:grp-1",
msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
from_account="alice",
sender_nickname="Alice",
raw_text="hi",
source=MagicMock(),
)
next_fn = AsyncMock()
await GroupAtGuardMiddleware()(ctx, next_fn)
next_fn.assert_not_awaited()
@pytest.mark.asyncio
async def test_owner_command_skips_at_check(self):
"""GroupAtGuardMiddleware passes when owner_command is set."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
ctx = make_ctx(
adapter=adapter,
chat_type="group",
msg_body=[],
owner_command="/new",
source=MagicMock(),
)
next_fn = AsyncMock()
await GroupAtGuardMiddleware()(ctx, next_fn)
next_fn.assert_awaited_once()
# ============================================================
# 4. Factory Tests
# ============================================================
class TestCreateInboundPipeline:
def test_default_pipeline_has_all_middlewares(self):
"""InboundPipelineBuilder.build() creates pipeline with all expected middlewares."""
pipeline = InboundPipelineBuilder.build()
expected = [
"decode",
"extract-fields",
"recall_guard",
"dedup",
"skip-self",
"chat-routing",
"access-guard",
"auto-sethome",
"extract-content",
"placeholder-filter",
"owner-command",
"build-source",
"group-at-guard",
"group-attribution",
"classify-msg-type",
"quote-context",
"forwarded-records-parse",
"media-resolve",
"patch-anchors",
"dispatch",
]
assert pipeline.middleware_names == expected
# ============================================================
# 5. End-to-End Pipeline Integration Tests
# ============================================================
class TestPipelineIntegration:
@pytest.mark.asyncio
async def test_full_dm_message_flow(self):
"""Full pipeline processes a DM message end-to-end."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
adapter._access_policy = AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="open", group_allow_from=[])
adapter.handle_message = AsyncMock()
adapter._resolve_inbound_media_urls = AsyncMock(return_value=([], []))
push_data = make_json_push(
from_account="alice",
to_account="bot_123",
text="Hello bot!",
msg_id="msg-e2e-001",
)
ctx = InboundContext(adapter=adapter, raw_frames=[push_data])
pipeline = InboundPipelineBuilder.build()
await pipeline.execute(ctx)
# Verify context was populated correctly
assert ctx.decoded_via == "json"
assert ctx.from_account == "alice"
assert ctx.chat_type == "dm"
assert ctx.chat_id == "direct:alice"
assert "Hello bot!" in ctx.raw_text
assert ctx.source is not None
@pytest.mark.asyncio
async def test_self_message_filtered(self):
"""Pipeline stops when message is from bot itself."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
push_data = make_json_push(
from_account="bot_123",
to_account="bot_123",
text="echo",
msg_id="msg-self-001",
)
ctx = InboundContext(adapter=adapter, raw_frames=[push_data])
pipeline = InboundPipelineBuilder.build()
await pipeline.execute(ctx)
# Pipeline should have stopped at skip-self — no source built
assert ctx.source is None
@pytest.mark.asyncio
async def test_duplicate_message_filtered(self):
"""Pipeline stops on duplicate message."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
# First message goes through
push_data = make_json_push(
from_account="alice",
text="Hello!",
msg_id="msg-dup-001",
)
ctx1 = InboundContext(adapter=adapter, raw_frames=[push_data])
pipeline = InboundPipelineBuilder.build()
await pipeline.execute(ctx1)
assert ctx1.from_account == "alice"
# Second message with same msg_id is filtered
ctx2 = InboundContext(adapter=adapter, raw_frames=[push_data])
await pipeline.execute(ctx2)
# Dedup should stop pipeline before chat routing
assert ctx2.chat_type == ""
@pytest.mark.asyncio
async def test_blocked_dm_filtered(self):
"""Pipeline stops when DM is blocked by policy."""
adapter = make_adapter()
adapter._bot_id = "bot_123"
adapter._access_policy = AccessPolicy(dm_policy="disabled", dm_allow_from=[], group_policy="open", group_allow_from=[])
push_data = make_json_push(
from_account="alice",
text="Hello!",
msg_id="msg-blocked-001",
)
ctx = InboundContext(adapter=adapter, raw_frames=[push_data])
pipeline = InboundPipelineBuilder.build()
await pipeline.execute(ctx)
# Pipeline stopped at access-guard — no content extracted
assert ctx.raw_text == ""
@pytest.mark.asyncio
async def test_adapter_has_pipeline(self):
"""YuanbaoAdapter.__init__ creates an inbound pipeline."""
adapter = make_adapter()
assert hasattr(adapter, "_inbound_pipeline")
assert isinstance(adapter._inbound_pipeline, InboundPipeline)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
# ============================================================
# 6. OOP Middleware Tests
# ============================================================
class TestInboundMiddlewareABC:
"""Test the InboundMiddleware OOP protocol (callable + named)."""
def test_subclass_with_handle_works(self):
"""Subclass with handle() can be instantiated."""
class GoodMiddleware(InboundMiddleware):
name = "good"
async def handle(self, ctx, next_fn):
await next_fn()
mw = GoodMiddleware()
assert mw.name == "good"
@pytest.mark.asyncio
async def test_callable_protocol(self):
"""Middleware instances are callable via __call__."""
class TestMW(InboundMiddleware):
name = "test"
async def handle(self, ctx, next_fn):
ctx.raw_text = "called"
await next_fn()
mw = TestMW()
ctx = make_ctx()
next_fn = AsyncMock()
await mw(ctx, next_fn) # Call via __call__
assert ctx.raw_text == "called"
next_fn.assert_awaited_once()
class TestMiddlewareClasses:
"""Pin the canonical ``name`` of each concrete middleware class.
These names are referenced by ``InboundPipelineBuilder.build()`` ordering,
by ``use_before`` / ``use_after`` insertion in extensions, and by log
messages — so they're a real downstream contract worth pinning.
"""
MIDDLEWARE_CLASSES = [
(DecodeMiddleware, "decode"),
(ExtractFieldsMiddleware, "extract-fields"),
(DedupMiddleware, "dedup"),
(SkipSelfMiddleware, "skip-self"),
(ChatRoutingMiddleware, "chat-routing"),
(AccessGuardMiddleware, "access-guard"),
(ExtractContentMiddleware, "extract-content"),
(PlaceholderFilterMiddleware, "placeholder-filter"),
(OwnerCommandMiddleware, "owner-command"),
(BuildSourceMiddleware, "build-source"),
(GroupAtGuardMiddleware, "group-at-guard"),
(DispatchMiddleware, "dispatch"),
]
@pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
def test_has_correct_name(self, cls, expected_name):
"""Each middleware class has the expected name."""
mw = cls()
assert mw.name == expected_name
class TestPipelineOOPRegistration:
"""Test that InboundPipeline works with OOP middleware instances."""
@pytest.mark.asyncio
async def test_use_with_middleware_instance(self):
"""pipeline.use(SomeMiddleware()) auto-extracts name."""
class TestMW(InboundMiddleware):
name = "test-mw"
async def handle(self, ctx, next_fn):
ctx.raw_text = "oop-works"
await next_fn()
pipeline = InboundPipeline().use(TestMW())
assert pipeline.middleware_names == ["test-mw"]
ctx = make_ctx()
await pipeline.execute(ctx)
assert ctx.raw_text == "oop-works"
@pytest.mark.asyncio
async def test_mixed_oop_and_functional(self):
"""Pipeline supports mixing OOP and functional middlewares."""
order = []
class OopMW(InboundMiddleware):
name = "oop"
async def handle(self, ctx, next_fn):
order.append("oop")
await next_fn()
async def func_mw(ctx, next_fn):
order.append("func")
await next_fn()
pipeline = (
InboundPipeline()
.use(OopMW())
.use("func", func_mw)
)
assert pipeline.middleware_names == ["oop", "func"]
await pipeline.execute(make_ctx())
assert order == ["oop", "func"]
# ============================================================
# QuoteContextMiddleware Tests
# ============================================================
#
# Quote-media resolution used to depend on a process-local
# msg_id→resids cache populated by ExtractContentMiddleware. After #27866
# made gateway/run.py write @bot user transcript entries with
# message_id (symmetric with the observed-group writer at yuanbao.py:2091),
# QuoteContextMiddleware's transcript-lookup path covers every quote case
# we used to rely on the cache for, so the cache (and those tests) were
# removed. ``_extract_quote_context()`` is now a pure (quote_id, quote_text)
# extractor; quote media references are populated separately by
# ``_extract_media_refs_from_transcript()`` against the transcript store.
class TestQuoteContextMiddleware:
"""Tests for QuoteContextMiddleware._extract_quote_context."""
def test_extract_quote_context_no_cloud_data(self):
"""Returns (None, None) when cloud_custom_data is empty."""
result = QuoteContextMiddleware()._extract_quote_context("")
assert result == (None, None)
def test_extract_quote_context_no_quote_key(self):
"""Returns (None, None) when JSON has no 'quote' key."""
cloud_data = json.dumps({"foo": "bar"})
result = QuoteContextMiddleware()._extract_quote_context(cloud_data)
assert result == (None, None)
def test_extract_quote_context_with_desc(self):
"""Extracts quote_id and quote_text from desc."""
cloud_data = json.dumps({
"quote": {
"id": "quoted-msg-001",
"desc": "Hello world",
"sender_nickname": "Alice",
}
})
quote_id, quote_text = QuoteContextMiddleware()._extract_quote_context(cloud_data)
assert quote_id == "quoted-msg-001"
assert quote_text == "Alice: Hello world"
def test_extract_quote_context_empty_desc(self):
"""When desc is empty, quote_text is None but quote_id is preserved."""
cloud_data = json.dumps({
"quote": {
"id": "quoted-msg-003",
"desc": "",
"sender_nickname": "Carol",
}
})
quote_id, quote_text = QuoteContextMiddleware()._extract_quote_context(cloud_data)
assert quote_id == "quoted-msg-003"
assert quote_text is None
def test_extract_quote_context_no_quote_id(self):
"""When quote.id is empty, quote_id is None."""
cloud_data = json.dumps({
"quote": {
"id": "",
"desc": "some text",
}
})
quote_id, _quote_text = QuoteContextMiddleware()._extract_quote_context(cloud_data)
assert quote_id is None
@pytest.mark.asyncio
async def test_handle_sets_ctx_fields(self):
"""QuoteContextMiddleware.handle() sets ctx.reply_to_message_id, reply_to_text, quote_media_refs.
With no transcript store wired up, quote_media_refs falls back to []
— media resolution from transcript is covered by separate tests.
"""
cloud_data = json.dumps({
"quote": {
"id": "quoted-msg-004",
"desc": "Check this image",
"sender_nickname": "Dave",
}
})
adapter = make_adapter()
adapter._session_store = None # no transcript lookup path
ctx = make_ctx(adapter=adapter, cloud_custom_data=cloud_data)
next_fn = AsyncMock()
await QuoteContextMiddleware()(ctx, next_fn)
assert ctx.reply_to_message_id == "quoted-msg-004"
assert ctx.reply_to_text == "Dave: Check this image"
assert ctx.quote_media_refs == []
next_fn.assert_awaited_once()
# ============================================================
# MediaResolveMiddleware Tests
# ============================================================
#
# After the dispatch refactor, MediaResolveMiddleware is the single entry
# point for all inbound media downloads. It merges up to three sources
# into ``ctx.media_urls`` / ``ctx.media_types`` (deduped, in this order):
#
# 1) media carried by the current message itself (always),
# 2) quote_media_refs (when reply_to_message_id is set),
# 3) recent group-observed media (only when chat_type == "group" and
# no quote is present).
#
# Direct messages skip the observed backfill entirely.
class TestResolveYbresRefs:
"""Direct tests for ``MediaResolveMiddleware._resolve_ybres_refs``.
This classmethod is the shared engine for both ``_resolve_quote_media``
and ``_collect_observed_media``. Patching the two upstream callers from
routing tests doesn't exercise its filtering / error-swallowing
behavior, so we pin those contracts directly here.
"""
@pytest.mark.asyncio
async def test_resolves_each_ref_in_order(self):
"""Successful resolution returns ``(paths, mimes)`` aligned with input order."""
adapter = make_adapter()
refs = [
("rid-1", "image", "a.jpg"),
("rid-2", "file", "doc.pdf"),
]
with patch.object(
MediaResolveMiddleware, "_fetch_resource_url",
new=AsyncMock(side_effect=["https://fresh/1", "https://fresh/2"]),
) as p_fetch, patch.object(
MediaResolveMiddleware, "_download_and_cache",
new=AsyncMock(side_effect=[
("/cache/a.jpg", "image/jpeg"),
("/cache/doc.pdf", "application/pdf"),
]),
) as p_cache:
paths, mimes = await MediaResolveMiddleware._resolve_ybres_refs(
adapter, refs, log_prefix="test",
)
assert paths == ["/cache/a.jpg", "/cache/doc.pdf"]
assert mimes == ["image/jpeg", "application/pdf"]
assert p_fetch.await_count == 2
# filename from the ref tuple is forwarded to download_and_cache
cache_kwargs = [c.kwargs for c in p_cache.await_args_list]
assert cache_kwargs[0]["file_name"] == "a.jpg"
assert cache_kwargs[0]["kind"] == "image"
assert cache_kwargs[0]["resource_id"] == "rid-1"
assert cache_kwargs[1]["file_name"] == "doc.pdf"
@pytest.mark.asyncio
async def test_skips_unresolvable_kinds(self):
"""Refs whose kind is outside ``_RESOLVABLE_MEDIA_KINDS`` are dropped silently."""
adapter = make_adapter()
refs = [
("rid-a", "voice", ""), # not resolvable
("rid-i", "image", "ok.jpg"), # resolvable
("rid-?", "unknown", ""), # not resolvable
]
with patch.object(
MediaResolveMiddleware, "_fetch_resource_url",
new=AsyncMock(return_value="https://fresh/i"),
) as p_fetch, patch.object(
MediaResolveMiddleware, "_download_and_cache",
new=AsyncMock(return_value=("/cache/ok.jpg", "image/jpeg")),
) as p_cache:
paths, mimes = await MediaResolveMiddleware._resolve_ybres_refs(
adapter, refs, log_prefix="test",
)
assert paths == ["/cache/ok.jpg"]
assert mimes == ["image/jpeg"]
# Only the resolvable ref hit the network.
p_fetch.assert_awaited_once()
p_cache.assert_awaited_once()
@pytest.mark.asyncio
async def test_fetch_failure_is_swallowed_per_ref(self):
"""If ``_fetch_resource_url`` raises, that ref is skipped — not the whole batch."""
adapter = make_adapter()
refs = [
("rid-bad", "image", ""),
("rid-ok", "image", ""),
]
with patch.object(
MediaResolveMiddleware, "_fetch_resource_url",
new=AsyncMock(side_effect=[RuntimeError("boom"), "https://fresh/ok"]),
), patch.object(
MediaResolveMiddleware, "_download_and_cache",
new=AsyncMock(return_value=("/cache/ok.jpg", "image/jpeg")),
) as p_cache:
paths, mimes = await MediaResolveMiddleware._resolve_ybres_refs(
adapter, refs, log_prefix="test",
)
# bad ref dropped; good ref preserved
assert paths == ["/cache/ok.jpg"]
assert mimes == ["image/jpeg"]
# download_and_cache was only invoked for the surviving ref
p_cache.assert_awaited_once()
@pytest.mark.asyncio
async def test_cache_miss_drops_ref(self):
"""If ``_download_and_cache`` returns None, the ref is dropped."""
adapter = make_adapter()
refs = [("rid-1", "image", "")]
with patch.object(
MediaResolveMiddleware, "_fetch_resource_url",
new=AsyncMock(return_value="https://fresh/1"),
), patch.object(
MediaResolveMiddleware, "_download_and_cache",
new=AsyncMock(return_value=None),
):
paths, mimes = await MediaResolveMiddleware._resolve_ybres_refs(
adapter, refs, log_prefix="test",
)
assert paths == []
assert mimes == []
class TestMediaResolveMiddlewareRouting:
"""Branch-routing tests for MediaResolveMiddleware.handle()."""
def _make_resolved_ctx(self, *, chat_type: str, reply_to: str = None,
quote_media_refs=None, raw_text: str = "hello"):
adapter = make_adapter()
ctx = make_ctx(
adapter=adapter,
chat_type=chat_type,
reply_to_message_id=reply_to,
quote_media_refs=list(quote_media_refs or []),
raw_text=raw_text,
media_refs=[], # no own attachments by default
)
return adapter, ctx
@pytest.mark.asyncio
async def test_dm_no_quote_skips_observed_backfill(self):
"""In dm chats, observed-media backfill is never invoked."""
_adapter, ctx = self._make_resolved_ctx(chat_type="dm")
with patch.object(
MediaResolveMiddleware, "_resolve_media_urls",
new=AsyncMock(return_value=([], [])),
) as p_own, patch.object(
MediaResolveMiddleware, "_resolve_quote_media",
new=AsyncMock(return_value=([], [])),
) as p_quote, patch.object(
MediaResolveMiddleware, "_collect_observed_media",
new=AsyncMock(return_value=([], [])),
) as p_observed:
next_fn = AsyncMock()
await MediaResolveMiddleware()(ctx, next_fn)
p_own.assert_awaited_once()
p_quote.assert_not_awaited()
p_observed.assert_not_awaited()
next_fn.assert_awaited_once()
@pytest.mark.asyncio
async def test_dm_with_quote_resolves_quote_media_only(self):
"""In dm chats with a quote, only quote media (plus own) is resolved."""
_adapter, ctx = self._make_resolved_ctx(
chat_type="dm",
reply_to="quoted-001",
quote_media_refs=[("rid-q1", "image", "")],
)
with patch.object(
MediaResolveMiddleware, "_resolve_media_urls",
new=AsyncMock(return_value=([], [])),
) as p_own, patch.object(
MediaResolveMiddleware, "_resolve_quote_media",
new=AsyncMock(return_value=(["/cache/q1.jpg"], ["image/jpeg"])),
) as p_quote, patch.object(
MediaResolveMiddleware, "_collect_observed_media",
new=AsyncMock(return_value=([], [])),
) as p_observed:
next_fn = AsyncMock()
await MediaResolveMiddleware()(ctx, next_fn)
p_own.assert_awaited_once()
p_quote.assert_awaited_once()
p_observed.assert_not_awaited()
assert ctx.media_urls == ["/cache/q1.jpg"]
assert ctx.media_types == ["image/jpeg"]
@pytest.mark.asyncio
async def test_group_no_quote_runs_observed_backfill(self):
"""In group chats without quote, observed-media backfill is invoked."""
_adapter, ctx = self._make_resolved_ctx(chat_type="group")
with patch.object(
MediaResolveMiddleware, "_resolve_media_urls",
new=AsyncMock(return_value=([], [])),
), patch.object(
MediaResolveMiddleware, "_resolve_quote_media",
new=AsyncMock(return_value=([], [])),
) as p_quote, patch.object(
MediaResolveMiddleware, "_collect_observed_media",
new=AsyncMock(return_value=(["/cache/o1.jpg"], ["image/jpeg"])),
) as p_observed:
next_fn = AsyncMock()
await MediaResolveMiddleware()(ctx, next_fn)
p_quote.assert_not_awaited()
p_observed.assert_awaited_once()
assert ctx.media_urls == ["/cache/o1.jpg"]
@pytest.mark.asyncio
async def test_group_with_quote_skips_observed_backfill(self):
"""In group chats with a quote, only quote media is resolved (no backfill)."""
_adapter, ctx = self._make_resolved_ctx(
chat_type="group",
reply_to="quoted-002",
quote_media_refs=[("rid-q2", "image", "")],
)
with patch.object(
MediaResolveMiddleware, "_resolve_media_urls",
new=AsyncMock(return_value=([], [])),
), patch.object(
MediaResolveMiddleware, "_resolve_quote_media",
new=AsyncMock(return_value=(["/cache/q2.jpg"], ["image/jpeg"])),
) as p_quote, patch.object(
MediaResolveMiddleware, "_collect_observed_media",
new=AsyncMock(return_value=(["/cache/o2.jpg"], ["image/jpeg"])),
) as p_observed:
next_fn = AsyncMock()
await MediaResolveMiddleware()(ctx, next_fn)
p_quote.assert_awaited_once()
p_observed.assert_not_awaited()
assert ctx.media_urls == ["/cache/q2.jpg"]
@pytest.mark.asyncio
async def test_merges_and_dedupes_three_sources(self):
"""Own + quote sources are merged with dedup applied."""
_adapter, ctx = self._make_resolved_ctx(
chat_type="dm",
reply_to="quoted-003",
quote_media_refs=[("rid", "image", "")],
)
with patch.object(
MediaResolveMiddleware, "_resolve_media_urls",
new=AsyncMock(return_value=(["/cache/a.jpg"], ["image/jpeg"])),
), patch.object(
MediaResolveMiddleware, "_resolve_quote_media",
# Same path as own → must be deduped.
new=AsyncMock(return_value=(["/cache/a.jpg", "/cache/b.jpg"],
["image/jpeg", "image/png"])),
):
next_fn = AsyncMock()
await MediaResolveMiddleware()(ctx, next_fn)
assert ctx.media_urls == ["/cache/a.jpg", "/cache/b.jpg"]
assert ctx.media_types == ["image/jpeg", "image/png"]
@pytest.mark.asyncio
async def test_placeholder_recheck_uses_own_count_only(self):
"""Placeholder retry-skip uses ``own_count``, not the merged total.
A bare placeholder text (e.g. ``[image]``) accompanied only by a
quote-resolved image is still skippable — quote media must not
flip a placeholder into a non-placeholder.
"""
_adapter, ctx = self._make_resolved_ctx(
chat_type="dm",
reply_to="quoted-004",
quote_media_refs=[("rid", "image", "")],
raw_text="[image]",
)
with patch.object(
MediaResolveMiddleware, "_resolve_media_urls",
new=AsyncMock(return_value=([], [])), # no own media
), patch.object(
MediaResolveMiddleware, "_resolve_quote_media",
new=AsyncMock(return_value=(["/cache/q.jpg"], ["image/jpeg"])),
), patch.object(
PlaceholderFilterMiddleware, "is_skippable_placeholder",
return_value=True,
) as p_check:
next_fn = AsyncMock()
await MediaResolveMiddleware()(ctx, next_fn)
# Pipeline short-circuited despite quote media being present.
next_fn.assert_not_awaited()
# And the second-pass check was called with own_count == 0.
p_check.assert_called_once()
_text_arg, count_arg = p_check.call_args.args
assert count_arg == 0
# ============================================================
# PatchAnchorsMiddleware Tests
# ============================================================
class TestPatchAnchorsMiddleware:
"""Tests for PatchAnchorsMiddleware._patch()."""
def test_no_op_when_text_or_urls_empty(self):
assert PatchAnchorsMiddleware._patch("", [], []) == ""
assert PatchAnchorsMiddleware._patch("hello", [], []) == "hello"
def test_replaces_image_anchor_with_local_path(self):
text = "look [image|ybres:abc] please"
out = PatchAnchorsMiddleware._patch(
text, ["/cache/x.jpg"], ["image/jpeg"],
)
assert out == "look [image: /cache/x.jpg] please"
def test_replaces_file_anchor_with_filename_label(self):
text = "see [file:doc.pdf|ybres:rid-1]"
out = PatchAnchorsMiddleware._patch(
text, ["/cache/doc.pdf"], ["application/pdf"],
)
assert "[file: doc.pdf → /cache/doc.pdf]" in out
def test_skips_non_local_paths(self):
"""URLs not starting with '/' are left untouched."""
text = "[image|ybres:abc]"
out = PatchAnchorsMiddleware._patch(
text, ["https://example.com/x.jpg"], ["image/jpeg"],
)
# Anchor preserved verbatim because the resolved url is remote.
assert out == text
def test_anchor_kind_image_requires_image_mime(self):
"""An [image|...] anchor with a non-image mime is left alone."""
text = "[image|ybres:rid]"
out = PatchAnchorsMiddleware._patch(
text, ["/cache/odd.bin"], ["application/octet-stream"],
)
assert out == text
@pytest.mark.asyncio
async def test_handle_writes_back_to_ctx(self):
adapter = make_adapter()
ctx = make_ctx(
adapter=adapter,
raw_text="hi [image|ybres:rid]",
media_urls=["/cache/y.png"],
media_types=["image/png"],
)
next_fn = AsyncMock()
await PatchAnchorsMiddleware()(ctx, next_fn)
assert ctx.raw_text == "hi [image: /cache/y.png]"
next_fn.assert_awaited_once()