hermes-agent/tests/test_yuanbao_pipeline.py
teknium1 24a934295f test(yuanbao): add missing patch import to pipeline tests
The salvaged refactor's new tests use unittest.mock.patch (25 call sites)
but the import line only brought in AsyncMock and MagicMock, so 10 of the
new tests failed with NameError. Add patch to the import.
2026-06-09 03:17:00 -07:00

1383 lines
49 KiB
Python

"""
test_yuanbao_pipeline.py - Unit tests for the inbound middleware pipeline.
Tests cover:
1. InboundPipeline engine (use, use_before, use_after, remove, execute)
2. InboundContext dataclass
3. Individual middlewares (DecodeMiddleware, DedupMiddleware, SkipSelfMiddleware, etc.)
4. InboundPipelineBuilder
5. End-to-end pipeline integration
6. OOP middleware ABC and class tests
"""
import sys
import os
import json
# Ensure project root is on the path
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
sys.path.insert(0, _REPO_ROOT)
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from gateway.platforms.yuanbao import (
InboundContext,
InboundMiddleware,
InboundPipeline,
DecodeMiddleware,
ExtractFieldsMiddleware,
DedupMiddleware,
SkipSelfMiddleware,
ChatRoutingMiddleware,
AccessPolicy,
AccessGuardMiddleware,
ExtractContentMiddleware,
PlaceholderFilterMiddleware,
OwnerCommandMiddleware,
BuildSourceMiddleware,
GroupAtGuardMiddleware,
QuoteContextMiddleware,
MediaResolveMiddleware,
PatchAnchorsMiddleware,
DispatchMiddleware,
InboundPipelineBuilder,
YuanbaoAdapter,
)
from gateway.config import PlatformConfig
# ============================================================
# Helpers
# ============================================================
def make_config(**kwargs):
    """Build a ``PlatformConfig`` whose ``extra`` carries test credentials.

    Args:
        **kwargs: Forwarded to ``PlatformConfig``. An optional ``extra``
            mapping is popped out and completed with default ``app_id``,
            ``app_secret``, ``ws_url`` and ``api_domain`` values; keys the
            caller already supplied are kept as-is.

    Returns:
        The constructed ``PlatformConfig``.
    """
    # Copy before filling defaults so the caller's dict is never mutated
    # (and so an explicit ``extra=None`` behaves like "no extra").
    extra = dict(kwargs.pop("extra", None) or {})
    extra.setdefault("app_id", "test_key")
    extra.setdefault("app_secret", "test_secret")
    extra.setdefault("ws_url", "wss://test.example.com/ws")
    extra.setdefault("api_domain", "https://test.example.com")
    return PlatformConfig(
        extra=extra,
        **kwargs,
    )
def make_adapter(**kwargs) -> YuanbaoAdapter:
    """Return a ``YuanbaoAdapter`` built from ``make_config(**kwargs)``.

    The adapter's ``_bot_id`` is preset to ``"bot_123"`` so tests have a
    known bot identity to compare against.
    """
    test_adapter = YuanbaoAdapter(make_config(**kwargs))
    test_adapter._bot_id = "bot_123"
    return test_adapter
def make_ctx(adapter=None, conn_data=b"", **overrides) -> InboundContext:
    """Build an ``InboundContext`` for a test.

    Args:
        adapter: Adapter to attach; a fresh ``make_adapter()`` when ``None``.
        conn_data: Raw frame bytes. When falsy, ``raw_frames`` is empty;
            otherwise it holds this single frame.
        **overrides: Attribute names and values set on the context after
            construction.
    """
    frames = []
    if conn_data:
        frames.append(conn_data)
    context = InboundContext(
        adapter=make_adapter() if adapter is None else adapter,
        raw_frames=frames,
    )
    for field_name, field_value in overrides.items():
        setattr(context, field_name, field_value)
    return context
def make_json_push(
    from_account="alice",
    to_account="bot_123",
    group_code="",
    text="Hello!",
    msg_id="msg-001",
) -> bytes:
    """Serialize a callback-style push as UTF-8 encoded JSON bytes.

    A truthy ``group_code`` selects the group ``CallbackCommand`` and adds a
    ``GroupId`` entry; otherwise the C2C command is emitted.

    Note: the key inside ``MsgContent`` is lowercase ("text", not "Text")
    because _extract_text() looks for lowercase keys.
    """
    is_group = bool(group_code)
    payload = {
        "CallbackCommand": (
            "Group.CallbackAfterSendMsg" if is_group else "C2C.CallbackAfterSendMsg"
        ),
        "From_Account": from_account,
        "To_Account": to_account,
        "MsgBody": [{"MsgType": "TIMTextElem", "MsgContent": {"text": text}}],
        "MsgKey": msg_id,
    }
    if is_group:
        # Appended last so the serialized key order stays stable.
        payload["GroupId"] = group_code
    return json.dumps(payload).encode("utf-8")
# ============================================================
# 1. InboundPipeline Engine Tests
# ============================================================
class TestInboundPipeline:
    """Exercise the pipeline engine: registration, ordering, guards, errors."""

    @staticmethod
    def _recorder(log, label):
        """Return a pass-through middleware that appends *label* to *log*."""

        async def _middleware(ctx, next_fn):
            log.append(label)
            await next_fn()

        return _middleware

    @staticmethod
    async def _passthrough(ctx, next_fn):
        """Middleware that does nothing except hand over to the next stage."""
        await next_fn()

    @pytest.mark.asyncio
    async def test_empty_pipeline(self):
        """Executing a pipeline with nothing registered must not raise."""
        await InboundPipeline().execute(make_ctx())

    @pytest.mark.asyncio
    async def test_single_middleware(self):
        """A lone registered middleware is invoked exactly once."""
        seen = []
        pipeline = InboundPipeline().use("test", self._recorder(seen, "mw"))

        await pipeline.execute(make_ctx())

        assert seen == ["mw"]

    @pytest.mark.asyncio
    async def test_middleware_order(self):
        """Middlewares run in the order they were registered."""
        trace = []
        pipeline = InboundPipeline()
        for label in ("a", "b", "c"):
            pipeline = pipeline.use(label, self._recorder(trace, label))

        await pipeline.execute(make_ctx())

        assert trace == ["a", "b", "c"]

    @pytest.mark.asyncio
    async def test_middleware_can_stop_pipeline(self):
        """Not awaiting next_fn prevents every later middleware from running."""
        trace = []

        async def halting(ctx, next_fn):
            # Deliberately never awaits next_fn, so the chain ends here.
            trace.append("stop")

        pipeline = (
            InboundPipeline()
            .use("stop", halting)
            .use("after", self._recorder(trace, "after"))
        )
        await pipeline.execute(make_ctx())

        # "after" must not have been reached.
        assert trace == ["stop"]

    @pytest.mark.asyncio
    async def test_conditional_guard_skip(self):
        """A middleware whose ``when`` guard returns False is bypassed."""
        trace = []
        pipeline = (
            InboundPipeline()
            .use("a", self._recorder(trace, "a"))
            .use("skipped", self._recorder(trace, "skipped"), when=lambda ctx: False)
            .use("c", self._recorder(trace, "c"))
        )

        await pipeline.execute(make_ctx())

        assert trace == ["a", "c"]

    @pytest.mark.asyncio
    async def test_conditional_guard_pass(self):
        """A middleware whose ``when`` guard returns True still runs."""
        trace = []
        pipeline = InboundPipeline().use(
            "mw", self._recorder(trace, "mw"), when=lambda ctx: True
        )

        await pipeline.execute(make_ctx())

        assert trace == ["mw"]

    def test_use_before(self):
        """use_before places the new middleware ahead of the named target."""
        pipeline = (
            InboundPipeline()
            .use("a", self._passthrough)
            .use("c", self._passthrough)
        )
        pipeline.use_before("c", "b", self._passthrough)
        assert pipeline.middleware_names == ["a", "b", "c"]

    def test_use_before_nonexistent_appends(self):
        """use_before falls back to appending when the target is unknown."""
        pipeline = InboundPipeline().use("a", self._passthrough)
        pipeline.use_before("nonexistent", "b", self._passthrough)
        assert pipeline.middleware_names == ["a", "b"]

    def test_use_after(self):
        """use_after places the new middleware right behind the named target."""
        pipeline = (
            InboundPipeline()
            .use("a", self._passthrough)
            .use("c", self._passthrough)
        )
        pipeline.use_after("a", "b", self._passthrough)
        assert pipeline.middleware_names == ["a", "b", "c"]

    def test_use_after_nonexistent_appends(self):
        """use_after falls back to appending when the target is unknown."""
        pipeline = InboundPipeline().use("a", self._passthrough)
        pipeline.use_after("nonexistent", "b", self._passthrough)
        assert pipeline.middleware_names == ["a", "b"]

    def test_remove(self):
        """remove drops the middleware registered under the given name."""
        pipeline = (
            InboundPipeline()
            .use("a", self._passthrough)
            .use("b", self._passthrough)
            .use("c", self._passthrough)
        )
        pipeline.remove("b")
        assert pipeline.middleware_names == ["a", "c"]

    def test_remove_nonexistent_is_noop(self):
        """Removing an unknown name leaves the pipeline untouched."""
        pipeline = InboundPipeline().use("a", self._passthrough)
        pipeline.remove("nonexistent")
        assert pipeline.middleware_names == ["a"]

    @pytest.mark.asyncio
    async def test_error_propagation(self):
        """An exception raised inside a middleware reaches the caller."""

        async def exploding(ctx, next_fn):
            raise ValueError("test error")

        pipeline = InboundPipeline().use("error", exploding)
        with pytest.raises(ValueError, match="test error"):
            await pipeline.execute(make_ctx())

    def test_middleware_names_property(self):
        """middleware_names lists registered names in registration order."""
        pipeline = (
            InboundPipeline()
            .use("decode", self._passthrough)
            .use("dedup", self._passthrough)
            .use("dispatch", self._passthrough)
        )
        assert pipeline.middleware_names == ["decode", "dedup", "dispatch"]

    @pytest.mark.asyncio
    async def test_onion_model(self):
        """Code after ``await next_fn()`` runs once inner middlewares finish."""
        trace = []

        async def wrapping(ctx, next_fn):
            trace.append("outer-before")
            await next_fn()
            trace.append("outer-after")

        pipeline = (
            InboundPipeline()
            .use("outer", wrapping)
            .use("inner", self._recorder(trace, "inner"))
        )
        await pipeline.execute(make_ctx())

        assert trace == ["outer-before", "inner", "outer-after"]
# ============================================================
# 3. Individual Middleware Tests
# ============================================================
class TestDecodeMiddleware:
    """DecodeMiddleware behaviour for JSON, empty and arbitrary binary frames."""

    @pytest.mark.asyncio
    async def test_json_decode(self):
        """A JSON push is parsed into ctx.push and the pipeline continues."""
        payload = make_json_push(from_account="alice", text="hi")
        context = make_ctx(conn_data=payload)
        proceed = AsyncMock()

        await DecodeMiddleware()(context, proceed)

        assert context.push is not None
        assert context.decoded_via == "json"
        assert context.push.get("from_account") == "alice"
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_empty_data_stops_pipeline(self):
        """With empty conn_data nothing is decoded and next_fn is not awaited."""
        context = make_ctx(conn_data=b"")
        proceed = AsyncMock()

        await DecodeMiddleware()(context, proceed)

        assert context.push is None
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_data_may_produce_garbage(self):
        """Arbitrary bytes must not make DecodeMiddleware raise.

        The protobuf parser is lenient and may turn random bytes into
        "seemingly valid" fields; that is expected. Downstream middlewares
        (dedup, skip-self, etc.) are responsible for filtering such garbage.
        """
        context = make_ctx(conn_data=b"\x00\x01\x02\x03")
        proceed = AsyncMock()

        await DecodeMiddleware()(context, proceed)

        # Whether or not a result was produced is irrelevant here; the only
        # invariant is that we got this far without an exception.
        assert True
class TestExtractFieldsMiddleware:
    """ExtractFieldsMiddleware copies push-dict entries onto the context."""

    @pytest.mark.asyncio
    async def test_extracts_fields(self):
        """Every field present in the push dict shows up on the context."""
        push = {
            "from_account": "alice",
            "group_code": "grp-1",
            "group_name": "Test Group",
            "sender_nickname": "Alice",
            "msg_body": [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
            "msg_id": "msg-001",
            "cloud_custom_data": '{"key": "val"}',
        }
        context = make_ctx(push=push)
        proceed = AsyncMock()

        await ExtractFieldsMiddleware()(context, proceed)

        assert context.from_account == "alice"
        assert context.group_code == "grp-1"
        assert context.group_name == "Test Group"
        assert context.sender_nickname == "Alice"
        assert len(context.msg_body) == 1
        assert context.msg_id == "msg-001"
        assert context.cloud_custom_data == '{"key": "val"}'
        proceed.assert_awaited_once()
class TestDedupMiddleware:
    """DedupMiddleware lets first-seen messages through and halts repeats."""

    @pytest.mark.asyncio
    async def test_new_message_passes(self):
        """A msg_id not seen before continues down the pipeline."""
        bot = make_adapter()
        context = make_ctx(adapter=bot, msg_id="unique-msg-001")
        proceed = AsyncMock()

        await DedupMiddleware()(context, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_duplicate_stops_pipeline(self):
        """A msg_id already registered with the adapter's dedup halts the pipeline."""
        bot = make_adapter()
        # Register the id up front so the middleware sees it as a repeat.
        bot._dedup.is_duplicate("dup-msg-001")
        context = make_ctx(adapter=bot, msg_id="dup-msg-001")
        proceed = AsyncMock()

        await DedupMiddleware()(context, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_empty_msg_id_passes(self):
        """An empty msg_id is passed through."""
        context = make_ctx(msg_id="")
        proceed = AsyncMock()

        await DedupMiddleware()(context, proceed)

        proceed.assert_awaited_once()
class TestSkipSelfMiddleware:
    """SkipSelfMiddleware halts on the bot's own messages only."""

    @pytest.mark.asyncio
    async def test_self_message_stops(self):
        """A message whose sender equals the bot id halts the pipeline."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        context = make_ctx(adapter=bot, from_account="bot_123")
        proceed = AsyncMock()

        await SkipSelfMiddleware()(context, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_other_message_passes(self):
        """A message from any other account continues down the pipeline."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        context = make_ctx(adapter=bot, from_account="alice")
        proceed = AsyncMock()

        await SkipSelfMiddleware()(context, proceed)

        proceed.assert_awaited_once()
class TestChatRoutingMiddleware:
    """ChatRoutingMiddleware derives chat_id / chat_type / chat_name."""

    @pytest.mark.asyncio
    async def test_group_routing(self):
        """A group_code yields a ``group:``-prefixed id and the group name."""
        context = make_ctx(group_code="grp-1", group_name="Test Group")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(context, proceed)

        assert context.chat_id == "group:grp-1"
        assert context.chat_type == "group"
        assert context.chat_name == "Test Group"
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing(self):
        """Without a group_code the chat is a DM named after the nickname."""
        context = make_ctx(from_account="alice", sender_nickname="Alice")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(context, proceed)

        assert context.chat_id == "direct:alice"
        assert context.chat_type == "dm"
        assert context.chat_name == "Alice"
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing_no_nickname(self):
        """With an empty nickname, chat_name falls back to from_account."""
        context = make_ctx(from_account="alice", sender_nickname="")
        proceed = AsyncMock()

        await ChatRoutingMiddleware()(context, proceed)

        assert context.chat_name == "alice"
class TestAccessGuardMiddleware:
    """AccessGuardMiddleware under open, disabled and allowlist policies."""

    @staticmethod
    def _policy(dm="open", dm_allow=(), group="open", group_allow=()):
        """Build an ``AccessPolicy``; every field defaults to open / empty."""
        return AccessPolicy(
            dm_policy=dm,
            dm_allow_from=list(dm_allow),
            group_policy=group,
            group_allow_from=list(group_allow),
        )

    @staticmethod
    async def _run_guard(policy, **ctx_fields):
        """Run the guard once with *policy* and return the ``next_fn`` mock."""
        bot = make_adapter()
        bot._access_policy = policy
        context = make_ctx(adapter=bot, **ctx_fields)
        proceed = AsyncMock()
        await AccessGuardMiddleware()(context, proceed)
        return proceed

    @pytest.mark.asyncio
    async def test_open_policy_passes(self):
        """A DM continues when both policies are open."""
        proceed = await self._run_guard(
            self._policy(), chat_type="dm", from_account="alice"
        )
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_disabled_dm_stops(self):
        """A DM is halted when dm_policy=disabled."""
        proceed = await self._run_guard(
            self._policy(dm="disabled"), chat_type="dm", from_account="alice"
        )
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_dm_allowed(self):
        """A DM continues when the sender is on the DM allowlist."""
        proceed = await self._run_guard(
            self._policy(dm="allowlist", dm_allow=["alice"]),
            chat_type="dm",
            from_account="alice",
        )
        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_allowlist_dm_blocked(self):
        """A DM is halted when the sender is missing from the DM allowlist."""
        proceed = await self._run_guard(
            self._policy(dm="allowlist", dm_allow=["bob"]),
            chat_type="dm",
            from_account="alice",
        )
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_disabled_group_stops(self):
        """A group message is halted when group_policy=disabled."""
        proceed = await self._run_guard(
            self._policy(group="disabled"), chat_type="group", group_code="grp-1"
        )
        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_group_allowed(self):
        """A group message continues when its group_code is allowlisted."""
        proceed = await self._run_guard(
            self._policy(group="allowlist", group_allow=["grp-1"]),
            chat_type="group",
            group_code="grp-1",
        )
        proceed.assert_awaited_once()
class TestExtractContentMiddleware:
    """ExtractContentMiddleware fills raw_text and media_refs from msg_body."""

    @pytest.mark.asyncio
    async def test_extracts_text_and_media(self):
        """A text element plus an image element yield text and one image ref."""
        text_elem = {"msg_type": "TIMTextElem", "msg_content": {"text": "Hello!"}}
        image_elem = {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "image_info_array": [{"url": "https://img.example.com/1.jpg"}]
            },
        }
        bot = make_adapter()
        context = make_ctx(adapter=bot, msg_body=[text_elem, image_elem])
        proceed = AsyncMock()

        await ExtractContentMiddleware()(context, proceed)

        assert "Hello!" in context.raw_text
        assert len(context.media_refs) == 1
        assert context.media_refs[0]["kind"] == "image"
        proceed.assert_awaited_once()
class TestPlaceholderFilterMiddleware:
    """PlaceholderFilterMiddleware halts only on media-less placeholders."""

    @pytest.mark.asyncio
    async def test_placeholder_stops(self):
        """A bare "[image]" placeholder without media refs halts the pipeline."""
        context = make_ctx(raw_text="[image]", media_refs=[])
        proceed = AsyncMock()

        await PlaceholderFilterMiddleware()(context, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_placeholder_with_media_passes(self):
        """The same placeholder continues when a media ref accompanies it."""
        image_ref = {"kind": "image", "url": "https://img.example.com/1.jpg"}
        context = make_ctx(raw_text="[image]", media_refs=[image_ref])
        proceed = AsyncMock()

        await PlaceholderFilterMiddleware()(context, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_normal_text_passes(self):
        """Ordinary text continues down the pipeline."""
        context = make_ctx(raw_text="Hello world!")
        proceed = AsyncMock()

        await PlaceholderFilterMiddleware()(context, proceed)

        proceed.assert_awaited_once()
class TestGroupAtGuardMiddleware:
    """GroupAtGuardMiddleware: DMs, @bot mentions and owner commands."""

    @pytest.mark.asyncio
    async def test_dm_passes(self):
        """A DM continues down the pipeline."""
        context = make_ctx(adapter=make_adapter(), chat_type="dm")
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_with_at_bot_passes(self):
        """A group message carrying an @bot custom element continues."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        mention = json.dumps({"elem_type": 1002, "text": "@Bot", "user_id": "bot_123"})
        body = [{"msg_type": "TIMCustomElem", "msg_content": {"data": mention}}]
        context = make_ctx(
            adapter=bot,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=body,
            from_account="alice",
            sender_nickname="Alice",
            raw_text="Hello",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_without_at_bot_observes(self):
        """A group message without an @bot mention does not continue."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        # With no session store, the observe step has nothing to write to.
        bot._session_store = None
        body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}]
        context = make_ctx(
            adapter=bot,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=body,
            from_account="alice",
            sender_nickname="Alice",
            raw_text="hi",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_owner_command_skips_at_check(self):
        """A group message with owner_command set continues without a mention."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        context = make_ctx(
            adapter=bot,
            chat_type="group",
            msg_body=[],
            owner_command="/new",
            source=MagicMock(),
        )
        proceed = AsyncMock()

        await GroupAtGuardMiddleware()(context, proceed)

        proceed.assert_awaited_once()
# ============================================================
# 4. Factory Tests
# ============================================================
class TestCreateInboundPipeline:
    """Pin the default middleware line-up produced by the builder."""

    def test_default_pipeline_has_all_middlewares(self):
        """InboundPipelineBuilder.build() registers these names, in this order."""
        built = InboundPipelineBuilder.build()
        assert built.middleware_names == [
            "decode",
            "extract-fields",
            "recall_guard",
            "dedup",
            "skip-self",
            "chat-routing",
            "access-guard",
            "auto-sethome",
            "extract-content",
            "placeholder-filter",
            "owner-command",
            "build-source",
            "group-at-guard",
            "group-attribution",
            "classify-msg-type",
            "quote-context",
            "media-resolve",
            "patch-anchors",
            "dispatch",
        ]
# ============================================================
# 5. End-to-End Pipeline Integration Tests
# ============================================================
class TestPipelineIntegration:
    """Run JSON pushes through the pipeline built by InboundPipelineBuilder."""

    @pytest.mark.asyncio
    async def test_full_dm_message_flow(self):
        """A permitted DM populates the routing, content and source fields."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        bot._access_policy = AccessPolicy(
            dm_policy="open",
            dm_allow_from=[],
            group_policy="open",
            group_allow_from=[],
        )
        bot.handle_message = AsyncMock()
        bot._resolve_inbound_media_urls = AsyncMock(return_value=([], []))
        frame = make_json_push(
            from_account="alice",
            to_account="bot_123",
            text="Hello bot!",
            msg_id="msg-e2e-001",
        )
        context = InboundContext(adapter=bot, raw_frames=[frame])

        await InboundPipelineBuilder.build().execute(context)

        # The context should now reflect every stage that ran.
        assert context.decoded_via == "json"
        assert context.from_account == "alice"
        assert context.chat_type == "dm"
        assert context.chat_id == "direct:alice"
        assert "Hello bot!" in context.raw_text
        assert context.source is not None

    @pytest.mark.asyncio
    async def test_self_message_filtered(self):
        """A push sent by the bot itself never gets a source built."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        frame = make_json_push(
            from_account="bot_123",
            to_account="bot_123",
            text="echo",
            msg_id="msg-self-001",
        )
        context = InboundContext(adapter=bot, raw_frames=[frame])

        await InboundPipelineBuilder.build().execute(context)

        # Stopping at skip-self means build-source was never reached.
        assert context.source is None

    @pytest.mark.asyncio
    async def test_duplicate_message_filtered(self):
        """Replaying a msg_id through the same pipeline is halted by dedup."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        frame = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-dup-001",
        )
        pipeline = InboundPipelineBuilder.build()

        # First delivery is processed normally.
        first = InboundContext(adapter=bot, raw_frames=[frame])
        await pipeline.execute(first)
        assert first.from_account == "alice"

        # Second delivery carries the same msg_id and must be dropped.
        second = InboundContext(adapter=bot, raw_frames=[frame])
        await pipeline.execute(second)
        # Dedup halts before chat routing, so chat_type keeps its default.
        assert second.chat_type == ""

    @pytest.mark.asyncio
    async def test_blocked_dm_filtered(self):
        """A DM rejected by policy never reaches content extraction."""
        bot = make_adapter()
        bot._bot_id = "bot_123"
        bot._access_policy = AccessPolicy(
            dm_policy="disabled",
            dm_allow_from=[],
            group_policy="open",
            group_allow_from=[],
        )
        frame = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-blocked-001",
        )
        context = InboundContext(adapter=bot, raw_frames=[frame])

        await InboundPipelineBuilder.build().execute(context)

        # Stopping at access-guard leaves raw_text at its default.
        assert context.raw_text == ""

    @pytest.mark.asyncio
    async def test_adapter_has_pipeline(self):
        """A freshly built adapter exposes an InboundPipeline as _inbound_pipeline."""
        bot = make_adapter()
        assert hasattr(bot, "_inbound_pipeline")
        assert isinstance(bot._inbound_pipeline, InboundPipeline)
if __name__ == "__main__":
    # Script entry point: run this file's tests verbosely and propagate
    # pytest's exit status, so a failing run is not reported to the shell
    # as success (the return value of pytest.main was previously dropped).
    sys.exit(pytest.main([__file__, "-v"]))
# ============================================================
# 6. OOP Middleware Tests
# ============================================================
class TestInboundMiddlewareABC:
    """Check the InboundMiddleware OOP protocol: a ``name`` plus ``__call__``."""

    def test_subclass_with_handle_works(self):
        """A subclass that defines handle() can be instantiated."""

        class PassThrough(InboundMiddleware):
            name = "good"

            async def handle(self, ctx, next_fn):
                await next_fn()

        assert PassThrough().name == "good"

    @pytest.mark.asyncio
    async def test_callable_protocol(self):
        """Calling an instance runs its handle() implementation."""

        class Marker(InboundMiddleware):
            name = "test"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "called"
                await next_fn()

        context = make_ctx()
        proceed = AsyncMock()

        # Invoke through __call__ rather than handle() directly.
        await Marker()(context, proceed)

        assert context.raw_text == "called"
        proceed.assert_awaited_once()
class TestMiddlewareClasses:
    """Pin the canonical ``name`` of each concrete middleware class.

    The names are used by ``InboundPipelineBuilder.build()`` ordering, by
    ``use_before`` / ``use_after`` insertion in extensions, and in log
    messages, which makes them a downstream contract worth pinning.
    """

    # (middleware class, expected ``name``) pairs fed to the parametrized test.
    MIDDLEWARE_CLASSES = [
        (DecodeMiddleware, "decode"),
        (ExtractFieldsMiddleware, "extract-fields"),
        (DedupMiddleware, "dedup"),
        (SkipSelfMiddleware, "skip-self"),
        (ChatRoutingMiddleware, "chat-routing"),
        (AccessGuardMiddleware, "access-guard"),
        (ExtractContentMiddleware, "extract-content"),
        (PlaceholderFilterMiddleware, "placeholder-filter"),
        (OwnerCommandMiddleware, "owner-command"),
        (BuildSourceMiddleware, "build-source"),
        (GroupAtGuardMiddleware, "group-at-guard"),
        (DispatchMiddleware, "dispatch"),
    ]

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_has_correct_name(self, cls, expected_name):
        """An instance built with no arguments reports the expected name."""
        assert cls().name == expected_name
class TestPipelineOOPRegistration:
    """InboundPipeline accepts InboundMiddleware instances via use()."""

    @pytest.mark.asyncio
    async def test_use_with_middleware_instance(self):
        """use(instance) registers the middleware under the instance's name."""

        class Stamping(InboundMiddleware):
            name = "test-mw"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "oop-works"
                await next_fn()

        pipeline = InboundPipeline().use(Stamping())
        assert pipeline.middleware_names == ["test-mw"]

        context = make_ctx()
        await pipeline.execute(context)
        assert context.raw_text == "oop-works"

    @pytest.mark.asyncio
    async def test_mixed_oop_and_functional(self):
        """Class-based and plain-function middlewares can share one pipeline."""
        trace = []

        class ClassBased(InboundMiddleware):
            name = "oop"

            async def handle(self, ctx, next_fn):
                trace.append("oop")
                await next_fn()

        async def function_based(ctx, next_fn):
            trace.append("func")
            await next_fn()

        pipeline = InboundPipeline().use(ClassBased()).use("func", function_based)
        assert pipeline.middleware_names == ["oop", "func"]

        await pipeline.execute(make_ctx())
        assert trace == ["oop", "func"]
# ============================================================
# QuoteContextMiddleware Tests
# ============================================================
#
# Quote-media resolution used to depend on a process-local
# msg_id→resids cache populated by ExtractContentMiddleware. After #27866
# made gateway/run.py write @bot user transcript entries with
# message_id (symmetric with the observed-group writer at yuanbao.py:2091),
# QuoteContextMiddleware's transcript-lookup path covers every quote case
# we used to rely on the cache for, so the cache (and those tests) were
# removed. ``_extract_quote_context()`` is now a pure (quote_id, quote_text)
# extractor; quote media references are populated separately by
# ``_extract_media_refs_from_transcript()`` against the transcript store.
class TestQuoteContextMiddleware:
    """Tests for QuoteContextMiddleware._extract_quote_context and handle()."""

    @staticmethod
    def _cloud_data(quote):
        """Serialize *quote* under the ``"quote"`` key as a JSON string."""
        return json.dumps({"quote": quote})

    def test_extract_quote_context_no_cloud_data(self):
        """Empty cloud_custom_data yields (None, None)."""
        assert QuoteContextMiddleware()._extract_quote_context("") == (None, None)

    def test_extract_quote_context_no_quote_key(self):
        """JSON without a 'quote' key yields (None, None)."""
        payload = json.dumps({"foo": "bar"})
        assert QuoteContextMiddleware()._extract_quote_context(payload) == (None, None)

    def test_extract_quote_context_with_desc(self):
        """id becomes quote_id; quote_text is "<nickname>: <desc>"."""
        payload = self._cloud_data(
            {
                "id": "quoted-msg-001",
                "desc": "Hello world",
                "sender_nickname": "Alice",
            }
        )
        quote_id, quote_text = QuoteContextMiddleware()._extract_quote_context(payload)
        assert quote_id == "quoted-msg-001"
        assert quote_text == "Alice: Hello world"

    def test_extract_quote_context_empty_desc(self):
        """An empty desc gives quote_text None while quote_id is kept."""
        payload = self._cloud_data(
            {
                "id": "quoted-msg-003",
                "desc": "",
                "sender_nickname": "Carol",
            }
        )
        quote_id, quote_text = QuoteContextMiddleware()._extract_quote_context(payload)
        assert quote_id == "quoted-msg-003"
        assert quote_text is None

    def test_extract_quote_context_no_quote_id(self):
        """An empty quote.id gives quote_id None."""
        payload = self._cloud_data({"id": "", "desc": "some text"})
        quote_id, _ignored_text = QuoteContextMiddleware()._extract_quote_context(payload)
        assert quote_id is None

    @pytest.mark.asyncio
    async def test_handle_sets_ctx_fields(self):
        """handle() sets reply_to_message_id, reply_to_text and quote_media_refs.

        The adapter here has no session store, so quote_media_refs ends up as
        an empty list; resolving media from the transcript is covered by
        separate tests.
        """
        payload = self._cloud_data(
            {
                "id": "quoted-msg-004",
                "desc": "Check this image",
                "sender_nickname": "Dave",
            }
        )
        bot = make_adapter()
        bot._session_store = None  # disables the transcript lookup path
        context = make_ctx(adapter=bot, cloud_custom_data=payload)
        proceed = AsyncMock()

        await QuoteContextMiddleware()(context, proceed)

        assert context.reply_to_message_id == "quoted-msg-004"
        assert context.reply_to_text == "Dave: Check this image"
        assert context.quote_media_refs == []
        proceed.assert_awaited_once()
# ============================================================
# MediaResolveMiddleware Tests
# ============================================================
#
# After the dispatch refactor, MediaResolveMiddleware is the single entry
# point for all inbound media downloads. It merges up to three sources
# into ``ctx.media_urls`` / ``ctx.media_types`` (deduped, in this order):
#
# 1) media carried by the current message itself (always),
# 2) quote_media_refs (when reply_to_message_id is set),
# 3) recent group-observed media (only when chat_type == "group" and
# no quote is present).
#
# Direct messages skip the observed backfill entirely.
class TestResolveYbresRefs:
    """Direct tests for ``MediaResolveMiddleware._resolve_ybres_refs``.

    ``_resolve_ybres_refs`` is the shared engine behind both
    ``_resolve_quote_media`` and ``_collect_observed_media``.  The routing
    tests replace those two callers with mocks, which leaves the engine's
    own filtering / error-swallowing rules unexercised, so those contracts
    are pinned here.

    Every ref used below is a ``(resource_id, kind, file_name)`` tuple.
    """

    @pytest.mark.asyncio
    async def test_resolves_each_ref_in_order(self):
        """Successful resolution returns ``(paths, mimes)`` aligned with input order."""
        adapter = make_adapter()
        ybres_refs = [
            ("rid-1", "image", "a.jpg"),
            ("rid-2", "file", "doc.pdf"),
        ]
        fetch_mock = AsyncMock(side_effect=["https://fresh/1", "https://fresh/2"])
        cache_mock = AsyncMock(side_effect=[
            ("/cache/a.jpg", "image/jpeg"),
            ("/cache/doc.pdf", "application/pdf"),
        ])

        with patch.multiple(
            MediaResolveMiddleware,
            _fetch_resource_url=fetch_mock,
            _download_and_cache=cache_mock,
        ):
            resolved_paths, resolved_mimes = (
                await MediaResolveMiddleware._resolve_ybres_refs(
                    adapter, ybres_refs, log_prefix="test",
                )
            )

        assert resolved_paths == ["/cache/a.jpg", "/cache/doc.pdf"]
        assert resolved_mimes == ["image/jpeg", "application/pdf"]
        assert fetch_mock.await_count == 2

        # Each element of the ref tuple must reach _download_and_cache as a
        # keyword argument.
        first_kwargs = cache_mock.await_args_list[0].kwargs
        second_kwargs = cache_mock.await_args_list[1].kwargs
        assert first_kwargs["file_name"] == "a.jpg"
        assert first_kwargs["kind"] == "image"
        assert first_kwargs["resource_id"] == "rid-1"
        assert second_kwargs["file_name"] == "doc.pdf"

    @pytest.mark.asyncio
    async def test_skips_unresolvable_kinds(self):
        """Refs whose kind is outside ``_RESOLVABLE_MEDIA_KINDS`` are dropped silently."""
        adapter = make_adapter()
        ybres_refs = [
            ("rid-v", "video", ""),        # not resolvable
            ("rid-i", "image", "ok.jpg"),  # resolvable
            ("rid-?", "unknown", ""),      # not resolvable
        ]
        fetch_mock = AsyncMock(return_value="https://fresh/i")
        cache_mock = AsyncMock(return_value=("/cache/ok.jpg", "image/jpeg"))

        with patch.multiple(
            MediaResolveMiddleware,
            _fetch_resource_url=fetch_mock,
            _download_and_cache=cache_mock,
        ):
            resolved_paths, resolved_mimes = (
                await MediaResolveMiddleware._resolve_ybres_refs(
                    adapter, ybres_refs, log_prefix="test",
                )
            )

        assert resolved_paths == ["/cache/ok.jpg"]
        assert resolved_mimes == ["image/jpeg"]
        # Neither mocked step may run for the two unresolvable refs.
        fetch_mock.assert_awaited_once()
        cache_mock.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_fetch_failure_is_swallowed_per_ref(self):
        """If ``_fetch_resource_url`` raises, that ref is skipped — not the whole batch."""
        adapter = make_adapter()
        ybres_refs = [
            ("rid-bad", "image", ""),
            ("rid-ok", "image", ""),
        ]
        # First fetch blows up, second one succeeds.
        fetch_mock = AsyncMock(
            side_effect=[RuntimeError("boom"), "https://fresh/ok"],
        )
        cache_mock = AsyncMock(return_value=("/cache/ok.jpg", "image/jpeg"))

        with patch.multiple(
            MediaResolveMiddleware,
            _fetch_resource_url=fetch_mock,
            _download_and_cache=cache_mock,
        ):
            resolved_paths, resolved_mimes = (
                await MediaResolveMiddleware._resolve_ybres_refs(
                    adapter, ybres_refs, log_prefix="test",
                )
            )

        # The failing ref vanishes; the healthy one survives.
        assert resolved_paths == ["/cache/ok.jpg"]
        assert resolved_mimes == ["image/jpeg"]
        # No download attempt is made for the ref whose fetch failed.
        cache_mock.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_cache_miss_drops_ref(self):
        """If ``_download_and_cache`` returns None, the ref is dropped."""
        adapter = make_adapter()
        ybres_refs = [("rid-1", "image", "")]

        with patch.multiple(
            MediaResolveMiddleware,
            _fetch_resource_url=AsyncMock(return_value="https://fresh/1"),
            _download_and_cache=AsyncMock(return_value=None),
        ):
            resolved_paths, resolved_mimes = (
                await MediaResolveMiddleware._resolve_ybres_refs(
                    adapter, ybres_refs, log_prefix="test",
                )
            )

        assert resolved_paths == []
        assert resolved_mimes == []
class TestMediaResolveMiddlewareRouting:
    """Branch-routing tests for ``MediaResolveMiddleware``.

    Each test invokes the middleware as ``MediaResolveMiddleware()(ctx, next_fn)``
    with its resolver hooks (``_resolve_media_urls``, ``_resolve_quote_media``,
    ``_collect_observed_media``) replaced by ``AsyncMock`` objects, then
    asserts which hooks were awaited for a given chat type / quote
    combination and what was written to ``ctx.media_urls`` /
    ``ctx.media_types``.
    """

    def _make_resolved_ctx(self, *, chat_type: str, reply_to: "str | None" = None,
                           quote_media_refs=None, raw_text: str = "hello"):
        """Build an ``(adapter, ctx)`` pair for a message with no own attachments.

        Args:
            chat_type: Value stored on the context as ``chat_type``
                (the tests use ``"dm"`` and ``"group"``).
            reply_to: Forwarded as ``reply_to_message_id``; ``None`` means
                the message does not quote anything.
            quote_media_refs: Optional iterable of refs for the quoted
                message; copied into a fresh list (``None`` becomes ``[]``).
            raw_text: Text body of the message.

        Returns:
            The adapter created by ``make_adapter()`` and the context
            created by ``make_ctx()``.
        """
        adapter = make_adapter()
        ctx = make_ctx(
            adapter=adapter,
            chat_type=chat_type,
            reply_to_message_id=reply_to,
            quote_media_refs=list(quote_media_refs or []),
            raw_text=raw_text,
            media_refs=[],  # no own attachments by default
        )
        return adapter, ctx

    @pytest.mark.asyncio
    async def test_dm_no_quote_skips_observed_backfill(self):
        """In dm chats, observed-media backfill is never invoked."""
        _, ctx = self._make_resolved_ctx(chat_type="dm")
        own_mock = AsyncMock(return_value=([], []))
        quote_mock = AsyncMock(return_value=([], []))
        observed_mock = AsyncMock(return_value=([], []))
        downstream = AsyncMock()

        with patch.multiple(
            MediaResolveMiddleware,
            _resolve_media_urls=own_mock,
            _resolve_quote_media=quote_mock,
            _collect_observed_media=observed_mock,
        ):
            await MediaResolveMiddleware()(ctx, downstream)

        own_mock.assert_awaited_once()
        quote_mock.assert_not_awaited()
        observed_mock.assert_not_awaited()
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_with_quote_resolves_quote_media_only(self):
        """In dm chats with a quote, only quote media (plus own) is resolved."""
        _, ctx = self._make_resolved_ctx(
            chat_type="dm",
            reply_to="quoted-001",
            quote_media_refs=[("rid-q1", "image", "")],
        )
        own_mock = AsyncMock(return_value=([], []))
        quote_mock = AsyncMock(return_value=(["/cache/q1.jpg"], ["image/jpeg"]))
        observed_mock = AsyncMock(return_value=([], []))
        downstream = AsyncMock()

        with patch.multiple(
            MediaResolveMiddleware,
            _resolve_media_urls=own_mock,
            _resolve_quote_media=quote_mock,
            _collect_observed_media=observed_mock,
        ):
            await MediaResolveMiddleware()(ctx, downstream)

        own_mock.assert_awaited_once()
        quote_mock.assert_awaited_once()
        observed_mock.assert_not_awaited()
        assert ctx.media_urls == ["/cache/q1.jpg"]
        assert ctx.media_types == ["image/jpeg"]

    @pytest.mark.asyncio
    async def test_group_no_quote_runs_observed_backfill(self):
        """In group chats without quote, observed-media backfill is invoked."""
        _, ctx = self._make_resolved_ctx(chat_type="group")
        quote_mock = AsyncMock(return_value=([], []))
        observed_mock = AsyncMock(return_value=(["/cache/o1.jpg"], ["image/jpeg"]))
        downstream = AsyncMock()

        with patch.multiple(
            MediaResolveMiddleware,
            _resolve_media_urls=AsyncMock(return_value=([], [])),
            _resolve_quote_media=quote_mock,
            _collect_observed_media=observed_mock,
        ):
            await MediaResolveMiddleware()(ctx, downstream)

        quote_mock.assert_not_awaited()
        observed_mock.assert_awaited_once()
        assert ctx.media_urls == ["/cache/o1.jpg"]

    @pytest.mark.asyncio
    async def test_group_with_quote_skips_observed_backfill(self):
        """In group chats with a quote, only quote media is resolved (no backfill)."""
        _, ctx = self._make_resolved_ctx(
            chat_type="group",
            reply_to="quoted-002",
            quote_media_refs=[("rid-q2", "image", "")],
        )
        quote_mock = AsyncMock(return_value=(["/cache/q2.jpg"], ["image/jpeg"]))
        # Would contribute a second path if the backfill were (wrongly) run.
        observed_mock = AsyncMock(return_value=(["/cache/o2.jpg"], ["image/jpeg"]))
        downstream = AsyncMock()

        with patch.multiple(
            MediaResolveMiddleware,
            _resolve_media_urls=AsyncMock(return_value=([], [])),
            _resolve_quote_media=quote_mock,
            _collect_observed_media=observed_mock,
        ):
            await MediaResolveMiddleware()(ctx, downstream)

        quote_mock.assert_awaited_once()
        observed_mock.assert_not_awaited()
        assert ctx.media_urls == ["/cache/q2.jpg"]

    @pytest.mark.asyncio
    async def test_merges_and_dedupes_three_sources(self):
        """Own + quote sources are merged with dedup applied.

        Despite the historical name, only two sources are exercised: this
        is a dm message with a quote, and the dm tests in this class pin
        that the observed-media backfill is not awaited for dm chats.
        ``_collect_observed_media`` is therefore left unpatched here.
        """
        _, ctx = self._make_resolved_ctx(
            chat_type="dm",
            reply_to="quoted-003",
            quote_media_refs=[("rid", "image", "")],
        )
        own_mock = AsyncMock(return_value=(["/cache/a.jpg"], ["image/jpeg"]))
        # The quote repeats the own path -> it must appear only once.
        quote_mock = AsyncMock(
            return_value=(
                ["/cache/a.jpg", "/cache/b.jpg"],
                ["image/jpeg", "image/png"],
            ),
        )
        downstream = AsyncMock()

        with patch.multiple(
            MediaResolveMiddleware,
            _resolve_media_urls=own_mock,
            _resolve_quote_media=quote_mock,
        ):
            await MediaResolveMiddleware()(ctx, downstream)

        assert ctx.media_urls == ["/cache/a.jpg", "/cache/b.jpg"]
        assert ctx.media_types == ["image/jpeg", "image/png"]

    @pytest.mark.asyncio
    async def test_placeholder_recheck_uses_own_count_only(self):
        """Placeholder retry-skip uses ``own_count``, not the merged total.

        A bare placeholder text (e.g. ``[image]``) accompanied only by a
        quote-resolved image is still skippable — quote media must not
        flip a placeholder into a non-placeholder.
        """
        _, ctx = self._make_resolved_ctx(
            chat_type="dm",
            reply_to="quoted-004",
            quote_media_refs=[("rid", "image", "")],
            raw_text="[image]",
        )
        downstream = AsyncMock()

        with patch.multiple(
            MediaResolveMiddleware,
            _resolve_media_urls=AsyncMock(return_value=([], [])),  # no own media
            _resolve_quote_media=AsyncMock(
                return_value=(["/cache/q.jpg"], ["image/jpeg"]),
            ),
        ), patch.object(
            PlaceholderFilterMiddleware, "is_skippable_placeholder",
            return_value=True,
        ) as check_mock:
            await MediaResolveMiddleware()(ctx, downstream)

        # Pipeline short-circuited despite quote media being present.
        downstream.assert_not_awaited()
        # The second-pass check received (text, own_count) with own_count == 0.
        check_mock.assert_called_once()
        _, own_count = check_mock.call_args.args
        assert own_count == 0
# ============================================================
# PatchAnchorsMiddleware Tests
# ============================================================
class TestPatchAnchorsMiddleware:
    """Tests for ``PatchAnchorsMiddleware._patch()`` and its middleware entry point.

    ``_patch(text, media_urls, media_types)`` is called directly with plain
    strings and lists; the final test drives the middleware through
    ``PatchAnchorsMiddleware()(ctx, next_fn)``.
    """

    def test_no_op_when_text_or_urls_empty(self):
        """With no resolved urls the text comes back unchanged (empty or not)."""
        for text in ("", "hello"):
            assert PatchAnchorsMiddleware._patch(text, [], []) == text

    def test_replaces_image_anchor_with_local_path(self):
        """An ``[image|ybres:...]`` anchor is rewritten to ``[image: <path>]``."""
        patched = PatchAnchorsMiddleware._patch(
            "look [image|ybres:abc] please",
            ["/cache/x.jpg"],
            ["image/jpeg"],
        )
        assert patched == "look [image: /cache/x.jpg] please"

    def test_replaces_file_anchor_with_filename_label(self):
        """A ``[file:<name>|ybres:...]`` anchor keeps its filename as a label."""
        patched = PatchAnchorsMiddleware._patch(
            "see [file:doc.pdf|ybres:rid-1]",
            ["/cache/doc.pdf"],
            ["application/pdf"],
        )
        assert "[file: doc.pdf → /cache/doc.pdf]" in patched

    def test_skips_non_local_paths(self):
        """URLs not starting with '/' are left untouched."""
        original = "[image|ybres:abc]"
        patched = PatchAnchorsMiddleware._patch(
            original, ["https://example.com/x.jpg"], ["image/jpeg"],
        )
        # The resolved url is remote, so the anchor must survive verbatim.
        assert patched == original

    def test_anchor_kind_image_requires_image_mime(self):
        """An [image|...] anchor with a non-image mime is left alone."""
        original = "[image|ybres:rid]"
        patched = PatchAnchorsMiddleware._patch(
            original, ["/cache/odd.bin"], ["application/octet-stream"],
        )
        assert patched == original

    @pytest.mark.asyncio
    async def test_handle_writes_back_to_ctx(self):
        """Running the middleware stores the patched text on ``ctx.raw_text``."""
        adapter = make_adapter()
        ctx = make_ctx(
            adapter=adapter,
            raw_text="hi [image|ybres:rid]",
            media_urls=["/cache/y.png"],
            media_types=["image/png"],
        )
        downstream = AsyncMock()

        await PatchAnchorsMiddleware()(ctx, downstream)

        assert ctx.raw_text == "hi [image: /cache/y.png]"
        downstream.assert_awaited_once()