diff --git a/tests/gateway/conftest.py b/tests/gateway/conftest.py index d2f55ff9f6..3e734e0d40 100644 --- a/tests/gateway/conftest.py +++ b/tests/gateway/conftest.py @@ -88,11 +88,63 @@ def _ensure_discord_mock() -> None: discord_mod.Thread = type("Thread", (), {}) discord_mod.ForumChannel = type("ForumChannel", (), {}) discord_mod.Interaction = object - discord_mod.Embed = MagicMock + discord_mod.Message = type("Message", (), {}) + + # Embed: accept the kwargs production code / tests use + # (title, description, color). MagicMock auto-attributes work too, + # but some tests construct and inspect .title/.description directly. + class _FakeEmbed: + def __init__(self, *, title=None, description=None, color=None, **_): + self.title = title + self.description = description + self.color = color + discord_mod.Embed = _FakeEmbed + + # ui.View / ui.Select / ui.Button: real classes (not MagicMock) so + # tests that subclass ModelPickerView / iterate .children / clear + # items work. + class _FakeView: + def __init__(self, timeout=None): + self.timeout = timeout + self.children = [] + def add_item(self, item): + self.children.append(item) + def clear_items(self): + self.children.clear() + + class _FakeSelect: + def __init__(self, *, placeholder=None, options=None, custom_id=None, **_): + self.placeholder = placeholder + self.options = options or [] + self.custom_id = custom_id + self.callback = None + self.disabled = False + + class _FakeButton: + def __init__(self, *, label=None, style=None, custom_id=None, emoji=None, + url=None, disabled=False, row=None, sku_id=None, **_): + self.label = label + self.style = style + self.custom_id = custom_id + self.emoji = emoji + self.url = url + self.disabled = disabled + self.row = row + self.sku_id = sku_id + self.callback = None + + class _FakeSelectOption: + def __init__(self, *, label=None, value=None, description=None, **_): + self.label = label + self.value = value + self.description = description + discord_mod.SelectOption = _FakeSelectOption + discord_mod.ui = SimpleNamespace( - View=object, + View=_FakeView, + Select=_FakeSelect, + Button=_FakeButton, button=lambda *a, **k: (lambda fn: fn), - Button=object, ) discord_mod.ButtonStyle = SimpleNamespace( success=1, primary=2, secondary=2, danger=3, @@ -100,7 +152,7 @@ def _ensure_discord_mock() -> None: ) discord_mod.Color = SimpleNamespace( orange=lambda: 1, green=lambda: 2, blue=lambda: 3, - red=lambda: 4, purple=lambda: 5, + red=lambda: 4, purple=lambda: 5, greyple=lambda: 6, ) # app_commands — needed by _register_slash_commands auto-registration diff --git a/tests/gateway/test_discord_model_picker.py b/tests/gateway/test_discord_model_picker.py index 1fd8ac4de9..a1ff434bd3 100644 --- a/tests/gateway/test_discord_model_picker.py +++ b/tests/gateway/test_discord_model_picker.py @@ -1,118 +1,16 @@ -"""Regression tests for the Discord /model picker.""" +"""Regression tests for the Discord /model picker. -from types import ModuleType, SimpleNamespace -from unittest.mock import AsyncMock, MagicMock -import sys +Uses the shared discord mock from tests/gateway/conftest.py (installed +at collection time via _ensure_discord_mock()). Previously this file +installed its own mock at module-import time and clobbered sys.modules, +breaking other gateway tests under pytest-xdist. +""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock import pytest - -def _ensure_discord_mock(): - existing = sys.modules.get("discord") - if isinstance(existing, ModuleType) and getattr(existing, "__file__", None): - return - - class _FakeView: - def __init__(self, timeout=None): - self.timeout = timeout - self.children = [] - - def add_item(self, item): - self.children.append(item) - - def clear_items(self): - self.children.clear() - - class _FakeSelect: - def __init__(self, *, placeholder, options, custom_id): - self.placeholder = placeholder - self.options = options - self.custom_id = custom_id - self.callback = None - self.disabled = False - - class _FakeButton: - def __init__(self, *, label, style, custom_id=None, emoji=None, url=None, disabled=False, row=None, sku_id=None): - self.label = label - self.style = style - self.custom_id = custom_id - self.emoji = emoji - self.url = url - self.disabled = disabled - self.row = row - self.sku_id = sku_id - self.callback = None - - class _FakeSelectOption: - def __init__(self, *, label, value, description=None): - self.label = label - self.value = value - self.description = description - - class _FakeEmbed: - def __init__(self, *, title, description, color): - self.title = title - self.description = description - self.color = color - - class _FakeColor: - @staticmethod - def green(): - return "green" - - @staticmethod - def blue(): - return "blue" - - @staticmethod - def red(): - return "red" - - @staticmethod - def greyple(): - return "greyple" - - class _FakeButtonStyle: - green = "green" - grey = "grey" - red = "red" - blurple = "blurple" - - discord_mod = sys.modules.get("discord") or MagicMock() - discord_mod.Intents.default.return_value = MagicMock() - discord_mod.DMChannel = type("DMChannel", (), {}) - discord_mod.Thread = type("Thread", (), {}) - discord_mod.ForumChannel = type("ForumChannel", (), {}) - discord_mod.Interaction = object - discord_mod.Message = type("Message", (), {}) - discord_mod.SelectOption = _FakeSelectOption - discord_mod.Embed = _FakeEmbed - discord_mod.Color = _FakeColor - discord_mod.ButtonStyle = _FakeButtonStyle - discord_mod.app_commands = getattr( - discord_mod, - "app_commands", - SimpleNamespace(describe=lambda **kwargs: (lambda fn: fn)), - ) - discord_mod.ui = SimpleNamespace( - View=_FakeView, - Select=_FakeSelect, - Button=_FakeButton, - button=lambda **kwargs: (lambda fn: fn), - ) - - ext_mod = MagicMock() - commands_mod = MagicMock() - commands_mod.Bot = MagicMock - ext_mod.commands = commands_mod - - sys.modules["discord"] = discord_mod - sys.modules.setdefault("discord.ext", ext_mod) - sys.modules.setdefault("discord.ext.commands", commands_mod) - - -_ensure_discord_mock() - from gateway.platforms.discord import ModelPickerView @@ -135,7 +33,12 @@ async def test_model_picker_clears_controls_before_running_switch_callback(): ) async def edit_original_response(**kwargs): - events.append(("final-edit", kwargs["embed"].title, kwargs["embed"].description, kwargs["view"])) + events.append(( + "final-edit", + kwargs["embed"].title, + kwargs["embed"].description, + kwargs["view"], + )) view = ModelPickerView( providers=[