mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-24 10:52:21 +00:00
Discord enforces a hard 100-command limit per app and rejects an upsert that would push the live total over 100 (error 30032), which silently breaks ALL slash commands. The sync deleted obsolete commands AFTER creating new ones, so an app already at the cap momentarily exceeded it and the whole sync failed. Reorder: delete no-longer-desired commands up front, then create/update. Removes the now-redundant trailing delete loop. Adapts @infinitycrew39 PR #50890 to current main (the original adapter diff no longer applied after the platform refactor); test commit cherry-picked with authorship preserved.
7150 lines
316 KiB
Python
7150 lines
316 KiB
Python
from __future__ import annotations
|
||
|
||
"""
|
||
Discord platform adapter.
|
||
|
||
Uses discord.py library for:
|
||
- Receiving messages from servers and DMs
|
||
- Sending responses back
|
||
- Handling threads and channels
|
||
"""
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import struct
|
||
import subprocess
|
||
import tempfile
|
||
import threading
|
||
import time
|
||
from collections import defaultdict
|
||
from contextlib import suppress
|
||
from typing import Callable, Dict, List, Optional, Any, Tuple
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class _Snowflake:
|
||
"""Minimal object exposing ``.id`` — satisfies discord.py's Snowflake
|
||
protocol for ``channel.history(before=...)`` without constructing a
|
||
``discord.Object`` (which test doubles that stub the discord module
|
||
cannot build). Used to anchor reply-context scans inclusively.
|
||
"""
|
||
|
||
__slots__ = ("id",)
|
||
|
||
def __init__(self, id: int) -> None: # noqa: A002 - matches discord API
|
||
self.id = id
|
||
|
||
VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080}
|
||
_DISCORD_COMMAND_SYNC_POLICIES = {"safe", "bulk", "off"}
|
||
_DISCORD_COMMAND_SYNC_STATE_SUBDIR = "gateway"
|
||
_DISCORD_COMMAND_SYNC_STATE_FILENAME = "discord_command_sync_state.json"
|
||
_DISCORD_NONCONVERSATIONAL_STATE_FILENAME = "discord_nonconversational_messages.json"
|
||
_DISCORD_COMMAND_SYNC_MUTATION_INTERVAL_SECONDS = 4.5
|
||
_DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS = 30.0
|
||
# Discord enforces a hard cap of 100 global application (slash) commands per
|
||
# app. Registering more makes the ENTIRE sync fail with error 30032
|
||
# ("Maximum number of application commands reached"), which silently breaks
|
||
# every slash command — not just the overflow ones. We keep the desired set
|
||
# at or below this limit at registration time.
|
||
_DISCORD_MAX_APP_COMMANDS = 100
|
||
_DISCORD_NONCONVERSATIONAL_METADATA_KEYS = frozenset({
|
||
"non_conversational",
|
||
"non_conversational_history",
|
||
})
|
||
# Upgrade-bridge fallback only. The primary mechanism is the persisted
|
||
# non-conversational message-ID set populated from explicitly marked sends
|
||
# (metadata["non_conversational"]). These regexes exist solely to recognize
|
||
# status bumps emitted by an older gateway version that pre-dates the marking,
|
||
# so they don't partition history after an upgrade. New emitters should set the
|
||
# metadata flag, not rely on a regex here.
|
||
_DISCORD_NONCONVERSATIONAL_HISTORY_MESSAGE_PATTERNS = (
|
||
re.compile(r"^\s*💾\s*Self-improvement review:\s+\S[\s\S]*$", re.IGNORECASE),
|
||
# Legacy/background-review test doubles used this shorter form before the
|
||
# self-improvement prefix became the stable emitter contract.
|
||
re.compile(
|
||
r"^\s*💾\s+Skill\s+['\"].+?['\"]\s+(?:created|updated|improved|patched)\.?\s*$",
|
||
re.IGNORECASE,
|
||
),
|
||
re.compile(r"^\s*⏳\s+Working\s+—\s+\d+\s+min(?:\s|$)", re.IGNORECASE),
|
||
re.compile(
|
||
r"^\s*\[Background process\s+\S+\s+"
|
||
r"(?:finished with exit code|is still running~)[\s\S]*\]\s*$",
|
||
re.IGNORECASE,
|
||
),
|
||
re.compile(
|
||
r"^\s*(?:✅|❌)\s+Hermes update\s+"
|
||
r"(?:finished|failed|timed out)[\s\S]*$",
|
||
re.IGNORECASE,
|
||
),
|
||
re.compile(r"^\s*♻️?\s+Gateway\s+(?:restarted successfully|online\b)[\s\S]*$", re.IGNORECASE),
|
||
)
|
||
|
||
try:
|
||
import discord
|
||
from discord import Message as DiscordMessage, Intents
|
||
from discord.ext import commands
|
||
DISCORD_AVAILABLE = True
|
||
except ImportError:
|
||
DISCORD_AVAILABLE = False
|
||
discord = None
|
||
DiscordMessage = Any
|
||
Intents = Any
|
||
commands = None
|
||
|
||
import sys
|
||
from pathlib import Path as _Path
|
||
sys.path.insert(0, str(_Path(__file__).resolve().parents[3]))
|
||
|
||
from gateway.config import Platform, PlatformConfig
|
||
|
||
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
|
||
from utils import atomic_json_write, env_float
|
||
from gateway.platforms.base import (
|
||
BasePlatformAdapter,
|
||
MessageEvent,
|
||
MessageType,
|
||
ProcessingOutcome,
|
||
SendResult,
|
||
cache_image_from_url,
|
||
cache_image_from_bytes,
|
||
cache_audio_from_url,
|
||
cache_audio_from_bytes,
|
||
cache_document_from_bytes,
|
||
SUPPORTED_DOCUMENT_TYPES,
|
||
_TEXT_INJECT_EXTENSIONS,
|
||
validate_inbound_media_size,
|
||
)
|
||
from tools.url_safety import is_safe_url
|
||
|
||
|
||
async def _wait_for_ready_or_bot_exit(
|
||
ready_event: asyncio.Event,
|
||
bot_task: asyncio.Task,
|
||
timeout: float,
|
||
) -> None:
|
||
"""Wait until Discord is ready, or surface early bot startup failure.
|
||
|
||
``discord.py`` startup errors (including SOCKS/proxy failures from
|
||
aiohttp-socks/python-socks) happen inside ``Bot.start()``. If ``connect()``
|
||
only waits on ``ready_event``, a dead background task still burns the full
|
||
ready timeout before the gateway supervisor can reconnect. Racing the ready
|
||
event against the bot task keeps failures fast and preserves the original
|
||
exception for logging/classification.
|
||
"""
|
||
ready_task = asyncio.create_task(ready_event.wait())
|
||
try:
|
||
done, _pending = await asyncio.wait(
|
||
{ready_task, bot_task},
|
||
timeout=timeout,
|
||
return_when=asyncio.FIRST_COMPLETED,
|
||
)
|
||
if not done:
|
||
raise asyncio.TimeoutError
|
||
if bot_task in done:
|
||
exc = bot_task.exception()
|
||
if exc is not None:
|
||
raise exc
|
||
if not ready_task.done():
|
||
raise RuntimeError("Discord bot task exited before ready")
|
||
await ready_task
|
||
finally:
|
||
if not ready_task.done():
|
||
ready_task.cancel()
|
||
with suppress(asyncio.CancelledError):
|
||
await ready_task
|
||
|
||
|
||
def _find_discord_windows_bundled_opus(discord_module: Any = None) -> Optional[str]:
|
||
"""Return discord.py's bundled Windows opus DLL path when present."""
|
||
if sys.platform != "win32":
|
||
return None
|
||
discord_module = discord if discord_module is None else discord_module
|
||
if discord_module is None:
|
||
return None
|
||
|
||
opus_module = getattr(discord_module, "opus", None)
|
||
opus_file = getattr(opus_module, "__file__", None)
|
||
if not opus_file:
|
||
return None
|
||
|
||
target = "x64" if struct.calcsize("P") * 8 > 32 else "x86"
|
||
bundled = _Path(opus_file).resolve().parent / "bin" / f"libopus-0.{target}.dll"
|
||
if bundled.is_file():
|
||
return str(bundled)
|
||
return None
|
||
|
||
|
||
class _DiscordNonConversationalMessageTracker:
|
||
"""Persistent bounded set of Discord message IDs that are status noise."""
|
||
|
||
_MAX_TRACKED = 2000
|
||
|
||
def __init__(self, max_tracked: int = _MAX_TRACKED):
|
||
self._max_tracked = max_tracked
|
||
self._ids: dict[str, None] = dict.fromkeys(self._load())
|
||
|
||
def _state_path(self) -> _Path:
|
||
from hermes_constants import get_hermes_home
|
||
|
||
return (
|
||
get_hermes_home()
|
||
/ _DISCORD_COMMAND_SYNC_STATE_SUBDIR
|
||
/ _DISCORD_NONCONVERSATIONAL_STATE_FILENAME
|
||
)
|
||
|
||
def _load(self) -> list[str]:
|
||
path = self._state_path()
|
||
if not path.exists():
|
||
return []
|
||
try:
|
||
data = json.loads(path.read_text(encoding="utf-8"))
|
||
if isinstance(data, list):
|
||
return [str(message_id) for message_id in data if str(message_id).strip()]
|
||
except Exception:
|
||
logger.debug("[%s] Failed to load non-conversational Discord IDs", "Discord")
|
||
return []
|
||
|
||
def _save(self) -> None:
|
||
ids = list(self._ids)
|
||
if len(ids) > self._max_tracked:
|
||
ids = ids[-self._max_tracked:]
|
||
self._ids = dict.fromkeys(ids)
|
||
try:
|
||
atomic_json_write(self._state_path(), ids, indent=None)
|
||
except Exception:
|
||
logger.debug("[%s] Failed to save non-conversational Discord IDs", "Discord", exc_info=True)
|
||
|
||
def mark_many(self, message_ids: List[str]) -> None:
|
||
changed = False
|
||
for message_id in message_ids:
|
||
key = str(message_id or "").strip()
|
||
if key and key not in self._ids:
|
||
self._ids[key] = None
|
||
changed = True
|
||
if changed:
|
||
self._save()
|
||
|
||
def __contains__(self, message_id: str) -> bool:
|
||
return str(message_id or "") in self._ids
|
||
|
||
|
||
def _metadata_marks_nonconversational(metadata: Optional[Dict[str, Any]]) -> bool:
|
||
"""Return True when an outbound send was explicitly marked as status-only."""
|
||
if not isinstance(metadata, dict):
|
||
return False
|
||
return any(bool(metadata.get(key)) for key in _DISCORD_NONCONVERSATIONAL_METADATA_KEYS)
|
||
|
||
|
||
def _looks_like_nonconversational_history_message(content: str) -> bool:
|
||
"""Fallback recognizer for legacy status bumps missing persisted IDs."""
|
||
text = content or ""
|
||
return any(pattern.match(text) for pattern in _DISCORD_NONCONVERSATIONAL_HISTORY_MESSAGE_PATTERNS)
|
||
|
||
|
||
def _clean_discord_id(entry: str) -> str:
|
||
"""Strip common prefixes from a Discord user ID or username entry.
|
||
|
||
Users sometimes paste IDs with prefixes like ``user:123``, ``<@123>``,
|
||
or ``<@!123>`` from Discord's UI or other tools. This normalises the
|
||
entry to just the bare ID or username.
|
||
"""
|
||
entry = entry.strip()
|
||
# Strip Discord mention syntax: <@123> or <@!123>
|
||
if entry.startswith("<@") and entry.endswith(">"):
|
||
entry = entry.lstrip("<@!").rstrip(">")
|
||
# Strip "user:" prefix (seen in some Discord tools / onboarding pastes)
|
||
if entry.lower().startswith("user:"):
|
||
entry = entry[5:]
|
||
return entry.strip()
|
||
|
||
|
||
def check_discord_requirements() -> bool:
|
||
"""Check if Discord dependencies are available.
|
||
|
||
Lazy-installs discord.py via ``tools.lazy_deps.ensure("platform.discord")``
|
||
on first call if not present. After successful install, re-binds module
|
||
globals so ``DISCORD_AVAILABLE`` becomes True.
|
||
"""
|
||
global DISCORD_AVAILABLE, discord, DiscordMessage, Intents, commands
|
||
if DISCORD_AVAILABLE:
|
||
return True
|
||
try:
|
||
from tools.lazy_deps import ensure as _lazy_ensure
|
||
_lazy_ensure("platform.discord", prompt=False)
|
||
except Exception:
|
||
return False
|
||
try:
|
||
import discord as _discord
|
||
from discord import Message as _DM, Intents as _Intents
|
||
from discord.ext import commands as _commands
|
||
except ImportError:
|
||
return False
|
||
discord = _discord
|
||
DiscordMessage = _DM
|
||
Intents = _Intents
|
||
commands = _commands
|
||
DISCORD_AVAILABLE = True
|
||
_define_discord_view_classes()
|
||
return True
|
||
|
||
|
||
def _build_allowed_mentions():
|
||
"""Build Discord ``AllowedMentions`` with safe defaults, overridable via env.
|
||
|
||
Discord bots default to parsing ``@everyone``, ``@here``, role pings, and
|
||
user pings when ``allowed_mentions`` is unset on the client — any LLM
|
||
output or echoed user content that contains ``@everyone`` would therefore
|
||
ping the whole server. We explicitly deny ``@everyone`` and role pings
|
||
by default and keep user / replied-user pings enabled so normal
|
||
conversation still works.
|
||
|
||
Override via environment variables (or ``discord.allow_mentions.*`` in
|
||
config.yaml):
|
||
|
||
DISCORD_ALLOW_MENTION_EVERYONE default false — @everyone + @here
|
||
DISCORD_ALLOW_MENTION_ROLES default false — @role pings
|
||
DISCORD_ALLOW_MENTION_USERS default true — @user pings
|
||
DISCORD_ALLOW_MENTION_REPLIED_USER default true — reply-ping author
|
||
"""
|
||
if not DISCORD_AVAILABLE:
|
||
return None
|
||
|
||
def _b(name: str, default: bool) -> bool:
|
||
raw = os.getenv(name, "").strip().lower()
|
||
if not raw:
|
||
return default
|
||
return raw in {"true", "1", "yes", "on"}
|
||
|
||
return discord.AllowedMentions(
|
||
everyone=_b("DISCORD_ALLOW_MENTION_EVERYONE", False),
|
||
roles=_b("DISCORD_ALLOW_MENTION_ROLES", False),
|
||
users=_b("DISCORD_ALLOW_MENTION_USERS", True),
|
||
replied_user=_b("DISCORD_ALLOW_MENTION_REPLIED_USER", True),
|
||
)
|
||
|
||
|
||
class VoiceReceiver:
|
||
"""Captures and decodes voice audio from a Discord voice channel.
|
||
|
||
Attaches to a VoiceClient's socket listener, decrypts RTP packets
|
||
(NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers
|
||
per-user audio. A polling loop detects silence and delivers
|
||
completed utterances via a callback.
|
||
"""
|
||
|
||
SILENCE_THRESHOLD = 1.5 # seconds of silence → end of utterance
|
||
MIN_SPEECH_DURATION = 0.5 # minimum seconds to process (skip noise)
|
||
SAMPLE_RATE = 48000 # Discord native rate
|
||
CHANNELS = 2 # Discord sends stereo
|
||
|
||
def __init__(self, voice_client, allowed_user_ids: set = None):
|
||
self._vc = voice_client
|
||
self._allowed_user_ids = allowed_user_ids or set()
|
||
self._running = False
|
||
|
||
# Decryption
|
||
self._secret_key: Optional[bytes] = None
|
||
self._dave_session = None
|
||
self._bot_ssrc: int = 0
|
||
|
||
# SSRC -> user_id mapping (populated from SPEAKING events)
|
||
self._ssrc_to_user: Dict[int, int] = {}
|
||
self._lock = threading.Lock()
|
||
|
||
# Per-user audio buffers
|
||
self._buffers: Dict[int, bytearray] = defaultdict(bytearray)
|
||
self._last_packet_time: Dict[int, float] = {}
|
||
|
||
# Opus decoder per SSRC (each user needs own decoder state)
|
||
self._decoders: Dict[int, object] = {}
|
||
|
||
# Pause flag: don't capture while bot is playing TTS
|
||
self._paused = False
|
||
|
||
# Debug logging counter (instance-level to avoid cross-instance races)
|
||
self._packet_debug_count = 0
|
||
|
||
# ------------------------------------------------------------------
|
||
# Lifecycle
|
||
# ------------------------------------------------------------------
|
||
|
||
def start(self):
|
||
"""Start listening for voice packets."""
|
||
conn = self._vc._connection
|
||
self._secret_key = bytes(conn.secret_key)
|
||
self._dave_session = conn.dave_session
|
||
self._bot_ssrc = conn.ssrc
|
||
|
||
self._install_speaking_hook(conn)
|
||
conn.add_socket_listener(self._on_packet)
|
||
self._running = True
|
||
logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc)
|
||
|
||
def stop(self):
|
||
"""Stop listening and clean up."""
|
||
self._running = False
|
||
try:
|
||
self._vc._connection.remove_socket_listener(self._on_packet)
|
||
except Exception:
|
||
pass
|
||
with self._lock:
|
||
self._buffers.clear()
|
||
self._last_packet_time.clear()
|
||
self._decoders.clear()
|
||
self._ssrc_to_user.clear()
|
||
logger.info("VoiceReceiver stopped")
|
||
|
||
def pause(self):
|
||
self._paused = True
|
||
|
||
def resume(self):
|
||
self._paused = False
|
||
|
||
# ------------------------------------------------------------------
|
||
# SSRC -> user_id mapping via SPEAKING opcode hook
|
||
# ------------------------------------------------------------------
|
||
|
||
def map_ssrc(self, ssrc: int, user_id: int):
|
||
with self._lock:
|
||
self._ssrc_to_user[ssrc] = user_id
|
||
|
||
def _install_speaking_hook(self, conn):
|
||
"""Wrap the voice websocket hook to capture SPEAKING events (op 5).
|
||
|
||
VoiceConnectionState stores the hook as ``conn.hook`` (public attr).
|
||
It is passed to DiscordVoiceWebSocket on each (re)connect, so we
|
||
must wrap it on the VoiceConnectionState level AND on the current
|
||
live websocket instance.
|
||
"""
|
||
original_hook = conn.hook
|
||
receiver_self = self
|
||
|
||
async def wrapped_hook(ws, msg):
|
||
if isinstance(msg, dict) and msg.get("op") == 5:
|
||
data = msg.get("d", {})
|
||
ssrc = data.get("ssrc")
|
||
user_id = data.get("user_id")
|
||
if ssrc and user_id:
|
||
logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id)
|
||
receiver_self.map_ssrc(int(ssrc), int(user_id))
|
||
if original_hook:
|
||
await original_hook(ws, msg)
|
||
|
||
# Set on connection state (for future reconnects)
|
||
conn.hook = wrapped_hook
|
||
# Set on the current live websocket (for immediate effect)
|
||
try:
|
||
from discord.utils import MISSING
|
||
if hasattr(conn, 'ws') and conn.ws is not MISSING:
|
||
conn.ws._hook = wrapped_hook
|
||
logger.info("Speaking hook installed on live websocket")
|
||
except Exception as e:
|
||
logger.warning("Could not install hook on live ws: %s", e)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Packet handler (called from SocketReader thread)
|
||
# ------------------------------------------------------------------
|
||
|
||
def _on_packet(self, data: bytes):
|
||
if not self._running or self._paused:
|
||
return
|
||
|
||
# Log first few raw packets for debugging
|
||
self._packet_debug_count += 1
|
||
if self._packet_debug_count <= 5:
|
||
logger.debug(
|
||
"Raw UDP packet: len=%d, first_bytes=%s",
|
||
len(data), data[:4].hex() if len(data) >= 4 else "short",
|
||
)
|
||
|
||
if len(data) < 16:
|
||
return
|
||
|
||
# RTP version check: top 2 bits must be 10 (version 2).
|
||
# Lower bits may vary (padding, extension, CSRC count).
|
||
# Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice.
|
||
if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78:
|
||
if self._packet_debug_count <= 5:
|
||
logger.debug("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1])
|
||
return
|
||
|
||
first_byte = data[0]
|
||
_, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0)
|
||
|
||
# Skip bot's own audio
|
||
if ssrc == self._bot_ssrc:
|
||
return
|
||
|
||
# Calculate dynamic RTP header size (RFC 9335 / rtpsize mode)
|
||
cc = first_byte & 0x0F # CSRC count
|
||
has_extension = bool(first_byte & 0x10) # extension bit
|
||
has_padding = bool(first_byte & 0x20) # padding bit (RFC 3550 §5.1)
|
||
header_size = 12 + (4 * cc) + (4 if has_extension else 0)
|
||
|
||
if len(data) < header_size + 4: # need at least header + nonce
|
||
return
|
||
|
||
# Read extension length from preamble (for skipping after decrypt)
|
||
ext_data_len = 0
|
||
if has_extension:
|
||
ext_preamble_offset = 12 + (4 * cc)
|
||
ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0]
|
||
ext_data_len = ext_words * 4
|
||
|
||
if self._packet_debug_count <= 10:
|
||
with self._lock:
|
||
known_user = self._ssrc_to_user.get(ssrc, "unknown")
|
||
logger.debug(
|
||
"RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d",
|
||
ssrc, seq, known_user, header_size, ext_data_len,
|
||
)
|
||
|
||
header = bytes(data[:header_size])
|
||
payload_with_nonce = data[header_size:]
|
||
|
||
# --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) ---
|
||
if len(payload_with_nonce) < 4:
|
||
return
|
||
nonce = bytearray(24)
|
||
nonce[:4] = payload_with_nonce[-4:]
|
||
encrypted = bytes(payload_with_nonce[:-4])
|
||
|
||
try:
|
||
import nacl.secret # noqa: E402 — delayed import, only in voice path
|
||
box = nacl.secret.Aead(self._secret_key)
|
||
decrypted = box.decrypt(encrypted, header, bytes(nonce))
|
||
except Exception as e:
|
||
if self._packet_debug_count <= 10:
|
||
logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted))
|
||
return
|
||
|
||
# Skip encrypted extension data to get the actual opus payload
|
||
if ext_data_len and len(decrypted) > ext_data_len:
|
||
decrypted = decrypted[ext_data_len:]
|
||
|
||
# --- Strip RTP padding (RFC 3550 §5.1) ---
|
||
# When the P bit is set, the last payload byte holds the count of
|
||
# trailing padding bytes (including itself) that must be removed
|
||
# before further processing. Skipping this passes padding-contaminated
|
||
# bytes into DAVE/Opus and corrupts inbound audio.
|
||
if has_padding:
|
||
if not decrypted:
|
||
if self._packet_debug_count <= 10:
|
||
logger.warning(
|
||
"RTP padding bit set but no payload (ssrc=%d)", ssrc,
|
||
)
|
||
return
|
||
pad_len = decrypted[-1]
|
||
if pad_len == 0 or pad_len > len(decrypted):
|
||
if self._packet_debug_count <= 10:
|
||
logger.warning(
|
||
"Invalid RTP padding length %d for payload size %d (ssrc=%d)",
|
||
pad_len, len(decrypted), ssrc,
|
||
)
|
||
return
|
||
decrypted = decrypted[:-pad_len]
|
||
if not decrypted:
|
||
# Padding consumed entire payload — nothing to decode
|
||
return
|
||
|
||
# --- DAVE E2EE decrypt ---
|
||
if self._dave_session:
|
||
with self._lock:
|
||
user_id = self._ssrc_to_user.get(ssrc, 0)
|
||
if user_id:
|
||
try:
|
||
import davey
|
||
decrypted = self._dave_session.decrypt(
|
||
user_id, davey.MediaType.audio, decrypted
|
||
)
|
||
except Exception as e:
|
||
# Unencrypted passthrough — use NaCl-decrypted data as-is
|
||
if "Unencrypted" not in str(e):
|
||
if self._packet_debug_count <= 10:
|
||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||
return
|
||
# If SSRC unknown (no SPEAKING event yet), skip DAVE and try
|
||
# Opus decode directly — audio may be in passthrough mode.
|
||
# Buffer will get a user_id when SPEAKING event arrives later.
|
||
|
||
# --- Opus decode -> PCM ---
|
||
try:
|
||
if ssrc not in self._decoders:
|
||
self._decoders[ssrc] = discord.opus.Decoder()
|
||
pcm = self._decoders[ssrc].decode(decrypted)
|
||
with self._lock:
|
||
self._buffers[ssrc].extend(pcm)
|
||
self._last_packet_time[ssrc] = time.monotonic()
|
||
except Exception as e:
|
||
with self._lock:
|
||
self._decoders.pop(ssrc, None)
|
||
logger.debug(
|
||
"Opus decode error for SSRC %s; reset decoder: %s",
|
||
ssrc,
|
||
e,
|
||
)
|
||
return
|
||
|
||
# ------------------------------------------------------------------
|
||
# Silence detection
|
||
# ------------------------------------------------------------------
|
||
|
||
def _infer_user_for_ssrc(self, ssrc: int) -> int:
|
||
"""Try to infer user_id for an unmapped SSRC.
|
||
|
||
When the bot rejoins a voice channel, Discord may not resend
|
||
SPEAKING events for users already speaking. If exactly one
|
||
allowed user is in the channel, map the SSRC to them.
|
||
"""
|
||
try:
|
||
channel = self._vc.channel
|
||
if not channel:
|
||
return 0
|
||
bot_id = self._vc.user.id if self._vc.user else 0
|
||
allowed = self._allowed_user_ids
|
||
candidates = [
|
||
m.id for m in channel.members
|
||
if m.id != bot_id and (not allowed or str(m.id) in allowed)
|
||
]
|
||
if len(candidates) == 1:
|
||
uid = candidates[0]
|
||
self._ssrc_to_user[ssrc] = uid
|
||
logger.info("Auto-mapped ssrc=%d -> user=%d (sole allowed member)", ssrc, uid)
|
||
return uid
|
||
except Exception:
|
||
pass
|
||
return 0
|
||
|
||
def check_silence(self) -> list:
|
||
"""Return list of (user_id, pcm_bytes) for completed utterances."""
|
||
now = time.monotonic()
|
||
completed = []
|
||
|
||
with self._lock:
|
||
ssrc_user_map = dict(self._ssrc_to_user)
|
||
ssrc_list = list(self._buffers.keys())
|
||
|
||
for ssrc in ssrc_list:
|
||
last_time = self._last_packet_time.get(ssrc, now)
|
||
silence_duration = now - last_time
|
||
buf = self._buffers[ssrc]
|
||
# 48kHz, 16-bit, stereo = 192000 bytes/sec
|
||
buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2)
|
||
|
||
if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
|
||
user_id = ssrc_user_map.get(ssrc, 0)
|
||
if not user_id:
|
||
# SSRC not mapped (SPEAKING event missing after bot rejoin).
|
||
# Infer from allowed users in the voice channel.
|
||
user_id = self._infer_user_for_ssrc(ssrc)
|
||
if user_id:
|
||
completed.append((user_id, bytes(buf)))
|
||
self._buffers[ssrc] = bytearray()
|
||
self._last_packet_time.pop(ssrc, None)
|
||
elif silence_duration >= self.SILENCE_THRESHOLD * 2:
|
||
# Stale buffer with no valid user — discard
|
||
self._buffers.pop(ssrc, None)
|
||
self._last_packet_time.pop(ssrc, None)
|
||
|
||
return completed
|
||
|
||
# ------------------------------------------------------------------
|
||
# PCM -> WAV conversion (for Whisper STT)
|
||
# ------------------------------------------------------------------
|
||
|
||
@staticmethod
|
||
def pcm_to_wav(pcm_data: bytes, output_path: str,
|
||
src_rate: int = 48000, src_channels: int = 2):
|
||
"""Convert raw PCM to 16kHz mono WAV via ffmpeg."""
|
||
with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f:
|
||
f.write(pcm_data)
|
||
pcm_path = f.name
|
||
try:
|
||
subprocess.run(
|
||
[
|
||
"ffmpeg", "-y", "-loglevel", "error",
|
||
"-f", "s16le",
|
||
"-ar", str(src_rate),
|
||
"-ac", str(src_channels),
|
||
"-i", pcm_path,
|
||
"-ar", "16000",
|
||
"-ac", "1",
|
||
output_path,
|
||
],
|
||
check=True,
|
||
timeout=10,
|
||
stdin=subprocess.DEVNULL,
|
||
)
|
||
finally:
|
||
try:
|
||
os.unlink(pcm_path)
|
||
except OSError:
|
||
pass
|
||
|
||
|
||
def _read_dm_role_auth_guild() -> Optional[int]:
|
||
"""Return the guild ID opted-in for DM role-based auth, or None.
|
||
|
||
Reads ``discord.dm_role_auth_guild`` from config.yaml. This is
|
||
deliberately a config.yaml-only setting (not an env var): per repo
|
||
policy, ``~/.hermes/.env`` is for secrets only, and this is a
|
||
behavioral setting. Guild IDs aren't secrets.
|
||
|
||
Accepts ints or numeric strings in the config. Anything else
|
||
(empty, malformed, None) returns None, which keeps the secure
|
||
default (DM role-auth disabled).
|
||
"""
|
||
try:
|
||
from hermes_cli.config import read_raw_config
|
||
cfg = read_raw_config() or {}
|
||
discord_cfg = cfg.get("discord", {}) or {}
|
||
raw = discord_cfg.get("dm_role_auth_guild")
|
||
except Exception:
|
||
return None
|
||
if raw is None or raw == "":
|
||
return None
|
||
try:
|
||
guild_id = int(raw)
|
||
except (TypeError, ValueError):
|
||
return None
|
||
return guild_id if guild_id > 0 else None
|
||
|
||
|
||
class DiscordAdapter(BasePlatformAdapter):
|
||
"""
|
||
Discord bot adapter.
|
||
|
||
Handles:
|
||
- Receiving messages from servers and DMs
|
||
- Sending responses with Discord markdown
|
||
- Thread support
|
||
- Native slash commands (/ask, /reset, /status, /stop)
|
||
- Button-based exec approvals
|
||
- Auto-threading for long conversations
|
||
- Reaction-based feedback
|
||
"""
|
||
|
||
# Discord message limits
|
||
MAX_MESSAGE_LENGTH = 2000
|
||
_SPLIT_THRESHOLD = 1900 # near the 2000-char split point
|
||
supports_code_blocks = True # Discord markdown renders fenced code blocks natively
|
||
splits_long_messages = True # send() chunks via truncate_message(MAX_MESSAGE_LENGTH)
|
||
|
||
# Auto-disconnect from voice channel after this many seconds of inactivity
|
||
VOICE_TIMEOUT = 300
|
||
|
||
def __init__(self, config: PlatformConfig):
|
||
super().__init__(config, Platform.DISCORD)
|
||
self._client: Optional[commands.Bot] = None
|
||
self._ready_event = asyncio.Event()
|
||
self._allowed_user_ids: set = set() # For button approval authorization
|
||
self._allowed_role_ids: set = set() # For DISCORD_ALLOWED_ROLES filtering
|
||
self.gateway_runner = None # Set by gateway/run.py for cross-platform delivery
|
||
# Voice channel state (per-guild)
|
||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||
self._voice_locks: Dict[int, asyncio.Lock] = {} # guild_id -> serialize join/leave
|
||
# Text batching: merge rapid successive messages (Telegram-style)
|
||
self._text_batch_delay_seconds = env_float("HERMES_DISCORD_TEXT_BATCH_DELAY_SECONDS", 0.6)
|
||
self._text_batch_split_delay_seconds = env_float("HERMES_DISCORD_TEXT_BATCH_SPLIT_DELAY_SECONDS", 2.0)
|
||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
|
||
self._voice_sources: Dict[int, Dict[str, Any]] = {} # guild_id -> linked text channel source metadata
|
||
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
|
||
# Phase 2: voice listening
|
||
self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver
|
||
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
|
||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||
# Resolves the current voice-reply mode ("off"|"voice_only"|"all") for a
|
||
# linked text-channel id; set by run.py. Lets the inactivity timer leave
|
||
# the bot in the channel when the user deliberately picked text-only
|
||
# (/voice off) instead of leaving (/voice leave).
|
||
self._voice_mode_getter: Optional[Callable] = None # set by run.py
|
||
# Phase 3: continuous voice mixer (ambient idle bed + ducked speech).
|
||
# Installed once per guild on join; lets acks / TTS / the "thinking"
|
||
# loop overlap in one outgoing stream instead of stop-and-swap.
|
||
self._voice_mixers: Dict[int, Any] = {} # guild_id -> VoiceMixer
|
||
self._ambient_pcm_cache: Optional[bytes] = None # decoded ambient bed
|
||
self._voice_fx_cfg: Dict[str, Any] = self._load_voice_fx_config()
|
||
# Track threads where the bot has participated so follow-up messages
|
||
# in those threads don't require @mention. Persisted to disk so the
|
||
# set survives gateway restarts.
|
||
self._threads = ThreadParticipationTracker("discord")
|
||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||
# show the standard typing gateway event for bots)
|
||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||
self._bot_task: Optional[asyncio.Task] = None
|
||
self._post_connect_task: Optional[asyncio.Task] = None
|
||
# True while disconnect() is intentionally closing discord.py. The
|
||
# bot task's done callback uses this to distinguish an operator/service
|
||
# shutdown from a runtime websocket crash.
|
||
self._disconnecting = False
|
||
# Dedup cache: prevents duplicate bot responses when Discord
|
||
# RESUME replays events after reconnects.
|
||
self._dedup = MessageDeduplicator()
|
||
# Reply threading mode: "off" (no replies), "first" (reply on first
|
||
# chunk only, default), "all" (reply-reference on every chunk).
|
||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||
self._slash_commands: bool = self.config.extra.get("slash_commands", True)
|
||
# In-memory cache of the bot's last message ID per channel, used by
|
||
# history backfill to skip the full scan on hot paths. Falls back to
|
||
# scanning channel.history() on cache miss (cold start / restart).
|
||
self._last_self_message_id: Dict[str, str] = {}
|
||
# Persistent set of bot-authored lifecycle/status message IDs that
|
||
# should not act as conversational history boundaries after restart.
|
||
self._nonconversational_messages = _DiscordNonConversationalMessageTracker()
|
||
|
||
def _handle_bot_task_done(self, task: asyncio.Task) -> None:
|
||
"""Surface post-startup discord.py task exits to the gateway supervisor.
|
||
|
||
discord.py reconnects normal gateway interruptions internally. When its
|
||
top-level ``Bot.start()`` task actually exits after the adapter has been
|
||
marked running, the Discord websocket is dead while the Hermes gateway
|
||
process can remain alive. Treat that split-brain state as a retryable
|
||
fatal adapter error so ``GatewayRunner._handle_adapter_fatal_error`` can
|
||
remove this adapter and queue Discord for the existing reconnect watcher.
|
||
"""
|
||
if getattr(self, "_disconnecting", False):
|
||
# Intentional service/operator shutdown. Drain the task result so
|
||
# asyncio doesn't emit "exception was never retrieved" warnings.
|
||
with suppress(asyncio.CancelledError, Exception):
|
||
task.exception()
|
||
return
|
||
|
||
# Ignore stale callbacks from an older client if a reconnect already
|
||
# installed a newer Bot.start() task on this adapter instance.
|
||
if self._bot_task is not None and task is not self._bot_task:
|
||
with suppress(asyncio.CancelledError, Exception):
|
||
task.exception()
|
||
return
|
||
|
||
if not self._running:
|
||
# Startup failures are handled by _wait_for_ready_or_bot_exit() in
|
||
# connect(); this callback is only for post-startup split-brain.
|
||
with suppress(asyncio.CancelledError, Exception):
|
||
task.exception()
|
||
return
|
||
|
||
try:
|
||
exc = task.exception()
|
||
except asyncio.CancelledError:
|
||
return
|
||
except Exception as err: # pragma: no cover - defensive
|
||
exc = err
|
||
|
||
if exc is None:
|
||
message = "Discord gateway task exited without an exception"
|
||
else:
|
||
message = f"Discord gateway task exited: {exc}"
|
||
|
||
logger.error("[%s] %s", self.name, message, exc_info=exc if exc else False)
|
||
self._set_fatal_error("discord_gateway_task_exited", message, retryable=True)
|
||
|
||
async def _notify() -> None:
|
||
try:
|
||
await self._notify_fatal_error()
|
||
except Exception as notify_exc: # pragma: no cover - defensive logging
|
||
logger.warning(
|
||
"[%s] Failed to notify gateway supervisor about Discord task exit: %s",
|
||
self.name,
|
||
notify_exc,
|
||
exc_info=True,
|
||
)
|
||
|
||
asyncio.create_task(_notify())
|
||
|
||
async def connect(self) -> bool:
|
||
"""Connect to Discord and start receiving events."""
|
||
if not DISCORD_AVAILABLE:
|
||
logger.error("[%s] discord.py not installed. Run: pip install discord.py", self.name)
|
||
return False
|
||
|
||
# Load opus codec for voice channel support
|
||
if not discord.opus.is_loaded():
|
||
import ctypes.util
|
||
opus_candidates = []
|
||
bundled_opus = _find_discord_windows_bundled_opus(discord)
|
||
if bundled_opus:
|
||
opus_candidates.append(bundled_opus)
|
||
opus_path = ctypes.util.find_library("opus")
|
||
if opus_path:
|
||
opus_candidates.append(opus_path)
|
||
# ctypes.util.find_library fails on macOS with Homebrew-installed libs,
|
||
# so fall back to known Homebrew paths if needed.
|
||
if not opus_path:
|
||
_homebrew_paths = (
|
||
"/opt/homebrew/lib/libopus.dylib", # Apple Silicon
|
||
"/usr/local/lib/libopus.dylib", # Intel Mac
|
||
)
|
||
if sys.platform == "darwin":
|
||
for _hp in _homebrew_paths:
|
||
if os.path.isfile(_hp):
|
||
opus_candidates.append(_hp)
|
||
break
|
||
for opus_path in opus_candidates:
|
||
try:
|
||
discord.opus.load_opus(opus_path)
|
||
if discord.opus.is_loaded():
|
||
break
|
||
except Exception:
|
||
logger.warning("Opus codec found at %s but failed to load", opus_path)
|
||
if not discord.opus.is_loaded():
|
||
logger.warning("Opus codec not found — voice channel playback disabled")
|
||
|
||
if not self.config.token:
|
||
logger.error("[%s] No bot token configured", self.name)
|
||
return False
|
||
|
||
try:
|
||
if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
|
||
return False
|
||
|
||
# Parse allowed user entries (may contain usernames or IDs)
|
||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||
if allowed_env:
|
||
self._allowed_user_ids = {
|
||
_clean_discord_id(uid) for uid in allowed_env.split(",")
|
||
if uid.strip()
|
||
}
|
||
|
||
# Parse DISCORD_ALLOWED_ROLES — comma-separated role IDs.
|
||
# Users with ANY of these roles can interact with the bot.
|
||
roles_env = os.getenv("DISCORD_ALLOWED_ROLES", "")
|
||
if roles_env:
|
||
self._allowed_role_ids = {
|
||
int(rid.strip()) for rid in roles_env.split(",")
|
||
if rid.strip().isdigit()
|
||
}
|
||
|
||
# Set up intents.
|
||
# Message Content is required for normal text replies.
|
||
# Server Members is only needed when the allowlist contains usernames
|
||
# that must be resolved to numeric IDs. Requesting privileged intents
|
||
# that aren't enabled in the Discord Developer Portal can prevent the
|
||
# bot from coming online at all, so avoid requesting members intent
|
||
# unless it is actually necessary.
|
||
intents = Intents.default()
|
||
intents.message_content = True
|
||
intents.dm_messages = True
|
||
intents.guild_messages = True
|
||
intents.members = (
|
||
any(not entry.isdigit() for entry in self._allowed_user_ids)
|
||
or bool(self._allowed_role_ids) # Need members intent for role lookup
|
||
)
|
||
intents.voice_states = True
|
||
|
||
# Resolve proxy (DISCORD_PROXY > generic env vars > macOS system proxy)
|
||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_bot
|
||
proxy_url = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||
if proxy_url:
|
||
logger.info("[%s] Using proxy for Discord: %s", self.name, proxy_url)
|
||
|
||
# Create bot — proxy= for HTTP, connector= for SOCKS.
|
||
# allowed_mentions is set with safe defaults (no @everyone/roles)
|
||
# so LLM output or echoed user content can't ping the whole
|
||
# server; override per DISCORD_ALLOW_MENTION_* env vars or the
|
||
# discord.allow_mentions.* block in config.yaml.
|
||
|
||
# Close any existing client to prevent zombie websocket connections
|
||
# on reconnect (see #18187). Without this, the old client remains
|
||
# connected to Discord gateway and both fire on_message, causing
|
||
# double responses.
|
||
if self._client is not None:
|
||
try:
|
||
if not self._client.is_closed():
|
||
await self._client.close()
|
||
except Exception:
|
||
logger.debug("[%s] Failed to close previous Discord client", self.name)
|
||
finally:
|
||
self._client = None
|
||
self._ready_event.clear()
|
||
|
||
self._client = commands.Bot(
|
||
command_prefix="!", # Not really used, we handle raw messages
|
||
intents=intents,
|
||
allowed_mentions=_build_allowed_mentions(),
|
||
**proxy_kwargs_for_bot(proxy_url),
|
||
)
|
||
adapter_self = self # capture for closure
|
||
|
||
# Register event handlers
|
||
@self._client.event
|
||
async def on_ready():
|
||
logger.info("[%s] Connected as %s", adapter_self.name, adapter_self._client.user)
|
||
|
||
# Resolve any usernames in the allowed list to numeric IDs
|
||
await adapter_self._resolve_allowed_usernames()
|
||
adapter_self._ready_event.set()
|
||
|
||
if adapter_self._post_connect_task and not adapter_self._post_connect_task.done():
|
||
adapter_self._post_connect_task.cancel()
|
||
adapter_self._post_connect_task = asyncio.create_task(
|
||
adapter_self._run_post_connect_initialization()
|
||
)
|
||
|
||
@self._client.event
|
||
async def on_message(message: DiscordMessage):
|
||
# Block until _resolve_allowed_usernames has swapped
|
||
# any raw usernames in DISCORD_ALLOWED_USERS for numeric
|
||
# IDs (otherwise on_message's author.id lookup can miss).
|
||
if not adapter_self._ready_event.is_set():
|
||
try:
|
||
await asyncio.wait_for(adapter_self._ready_event.wait(), timeout=30.0)
|
||
except asyncio.TimeoutError:
|
||
pass
|
||
|
||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||
if adapter_self._dedup.is_duplicate(str(message.id)):
|
||
return
|
||
|
||
# Always ignore our own messages
|
||
if message.author == self._client.user:
|
||
return
|
||
|
||
# Ignore Discord system messages (thread renames, pins, member joins, etc.)
|
||
# Allow both default and reply types — replies have a distinct MessageType.
|
||
if message.type not in {discord.MessageType.default, discord.MessageType.reply}:
|
||
return
|
||
|
||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||
# "none" — ignore all other bots (default)
|
||
# "mentions" — accept bot messages only when they @mention us
|
||
# "all" — accept all bot messages
|
||
# Must run BEFORE the user allowlist check so that bots
|
||
# permitted by DISCORD_ALLOW_BOTS are not rejected for
|
||
# not being in DISCORD_ALLOWED_USERS (fixes #4466).
|
||
_role_authorized = False
|
||
if getattr(message.author, "bot", False):
|
||
allow_bots = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||
if allow_bots == "none":
|
||
return
|
||
elif allow_bots == "mentions":
|
||
if not self._client.user or self._client.user not in message.mentions:
|
||
return
|
||
# "all" falls through; bot is permitted — skip the
|
||
# human-user allowlist below (bots aren't in it).
|
||
else:
|
||
# Non-bot: enforce the configured user/role allowlists.
|
||
# Pass guild + is_dm so role checks are scoped to the
|
||
# originating guild (prevents cross-guild DM bypass, see
|
||
# _is_allowed_user docstring).
|
||
_msg_guild = getattr(message, "guild", None)
|
||
_is_dm = isinstance(message.channel, discord.DMChannel) or _msg_guild is None
|
||
if not self._is_allowed_user(
|
||
str(message.author.id),
|
||
message.author,
|
||
guild=_msg_guild,
|
||
is_dm=_is_dm,
|
||
):
|
||
return
|
||
_role_authorized = bool(getattr(self, "_allowed_role_ids", set()))
|
||
|
||
# Multi-agent filtering: if the message mentions specific bots
|
||
# but NOT this bot, the sender is talking to another agent —
|
||
# stay silent. Messages with no bot mentions (general chat)
|
||
# still fall through to _handle_message for the existing
|
||
# DISCORD_REQUIRE_MENTION check.
|
||
#
|
||
# This replaces the older DISCORD_IGNORE_NO_MENTION logic
|
||
# with bot-aware filtering that works correctly when multiple
|
||
# agents share a channel.
|
||
if not isinstance(message.channel, discord.DMChannel) and message.mentions:
|
||
_self_mentioned = (
|
||
self._client.user is not None
|
||
and self._client.user in message.mentions
|
||
)
|
||
_other_bots_mentioned = any(
|
||
m.bot and m != self._client.user
|
||
for m in message.mentions
|
||
)
|
||
# If other bots are mentioned but we're not → not for us
|
||
if _other_bots_mentioned and not _self_mentioned:
|
||
return
|
||
# If humans are mentioned but we're not → not for us
|
||
# (preserves old DISCORD_IGNORE_NO_MENTION=true behavior)
|
||
# EXCEPT in free-response channels where the bot should
|
||
# answer regardless of who is mentioned.
|
||
_ignore_no_mention = os.getenv(
|
||
"DISCORD_IGNORE_NO_MENTION", "true"
|
||
).lower() in {"true", "1", "yes"}
|
||
if _ignore_no_mention and not _self_mentioned and not _other_bots_mentioned:
|
||
_channel_id = str(message.channel.id)
|
||
_parent_id = None
|
||
if hasattr(message.channel, "parent_id") and message.channel.parent_id:
|
||
_parent_id = str(message.channel.parent_id)
|
||
_free_channels = adapter_self._discord_free_response_channels()
|
||
_channel_ids = {_channel_id}
|
||
if _parent_id:
|
||
_channel_ids.add(_parent_id)
|
||
if "*" not in _free_channels and not (_channel_ids & _free_channels):
|
||
return
|
||
|
||
await self._handle_message(message, role_authorized=_role_authorized)
|
||
|
||
@self._client.event
|
||
async def on_voice_state_update(member, before, after):
|
||
"""Track voice channel join/leave events."""
|
||
# Only track channels where the bot is connected
|
||
bot_guild_ids = set(adapter_self._voice_clients.keys())
|
||
if not bot_guild_ids:
|
||
return
|
||
guild_id = member.guild.id
|
||
if guild_id not in bot_guild_ids:
|
||
return
|
||
# Ignore the bot itself
|
||
if member == adapter_self._client.user:
|
||
return
|
||
|
||
joined = before.channel is None and after.channel is not None
|
||
left = before.channel is not None and after.channel is None
|
||
switched = (
|
||
before.channel is not None
|
||
and after.channel is not None
|
||
and before.channel != after.channel
|
||
)
|
||
|
||
if joined or left or switched:
|
||
logger.info(
|
||
"Voice state: %s (%d) %s (guild %d)",
|
||
member.display_name,
|
||
member.id,
|
||
"joined " + after.channel.name if joined
|
||
else "left " + before.channel.name if left
|
||
else f"moved {before.channel.name} -> {after.channel.name}",
|
||
guild_id,
|
||
)
|
||
|
||
# Register slash commands
|
||
if self._slash_commands:
|
||
self._register_slash_commands()
|
||
|
||
# Start the bot in background
|
||
self._disconnecting = False
|
||
self._bot_task = asyncio.create_task(self._client.start(self.config.token))
|
||
self._bot_task.add_done_callback(self._handle_bot_task_done)
|
||
|
||
# Wait for ready, but fail fast if discord.py's background startup
|
||
# task dies first (for example on SOCKS/proxy connect errors).
|
||
await _wait_for_ready_or_bot_exit(self._ready_event, self._bot_task, timeout=30)
|
||
|
||
self._running = True
|
||
return True
|
||
|
||
except asyncio.TimeoutError:
|
||
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
|
||
# Cancel the background bot task so it cannot fire on_message after
|
||
# this adapter is discarded. Without this, the task keeps running and
|
||
# a later successful reconnect leaves two active Discord clients that
|
||
# each process every message, producing duplicate threads/responses.
|
||
await self._cancel_bot_task()
|
||
self._release_platform_lock()
|
||
return False
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
|
||
# Same zombie-client hazard as the timeout branch: the background
|
||
# client.start() task may already be running when a later setup
|
||
# step raises. Cancel it so the discarded adapter cannot connect.
|
||
await self._cancel_bot_task()
|
||
self._release_platform_lock()
|
||
return False
|
||
|
||
async def _cancel_bot_task(self) -> None:
|
||
"""Cancel and await the background client.start() task, if running."""
|
||
if self._bot_task and not self._bot_task.done():
|
||
self._bot_task.cancel()
|
||
try:
|
||
await self._bot_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
self._bot_task = None
|
||
|
||
async def disconnect(self) -> None:
|
||
"""Disconnect from Discord."""
|
||
self._disconnecting = True
|
||
# Cancel the bot task before closing the client. If connect() timed out
|
||
# and returned False, the background client.start() task may still be
|
||
# running; calling client.close() alone is not enough to stop it because
|
||
# discord.py's reconnect loop can ignore the closed flag while a
|
||
# WebSocket handshake is in flight. Explicitly cancelling the task here
|
||
# ensures the zombie client cannot receive or dispatch any further events.
|
||
await self._cancel_bot_task()
|
||
# Clean up all active voice connections before closing the client
|
||
for guild_id in list(self._voice_clients.keys()):
|
||
try:
|
||
await self.leave_voice_channel(guild_id)
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.debug("[%s] Error leaving voice channel %s: %s", self.name, guild_id, e)
|
||
|
||
if self._client:
|
||
try:
|
||
await self._client.close()
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True)
|
||
|
||
if self._post_connect_task and not self._post_connect_task.done():
|
||
self._post_connect_task.cancel()
|
||
try:
|
||
await self._post_connect_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
self._running = False
|
||
self._client = None
|
||
self._ready_event.clear()
|
||
self._post_connect_task = None
|
||
|
||
self._release_platform_lock()
|
||
|
||
logger.info("[%s] Disconnected", self.name)
|
||
|
||
def _command_sync_state_path(self) -> _Path:
|
||
from hermes_constants import get_hermes_home
|
||
|
||
directory = get_hermes_home() / _DISCORD_COMMAND_SYNC_STATE_SUBDIR
|
||
try:
|
||
directory.mkdir(parents=True, exist_ok=True)
|
||
except Exception:
|
||
pass
|
||
return directory / _DISCORD_COMMAND_SYNC_STATE_FILENAME
|
||
|
||
def _read_command_sync_state(self) -> dict:
|
||
try:
|
||
path = self._command_sync_state_path()
|
||
if not path.exists():
|
||
return {}
|
||
data = json.loads(path.read_text(encoding="utf-8"))
|
||
except Exception:
|
||
return {}
|
||
return data if isinstance(data, dict) else {}
|
||
|
||
def _write_command_sync_state(self, state: dict) -> None:
|
||
atomic_json_write(
|
||
self._command_sync_state_path(),
|
||
state,
|
||
indent=None,
|
||
separators=(",", ":"),
|
||
)
|
||
|
||
def _command_sync_state_key(self, app_id: Any) -> str:
|
||
return str(app_id or "unknown")
|
||
|
||
def _desired_command_sync_fingerprint(self) -> str:
|
||
tree = self._client.tree if self._client else None
|
||
desired = []
|
||
if tree is not None:
|
||
desired = [
|
||
self._canonicalize_app_command_payload(command.to_dict(tree))
|
||
for command in tree.get_commands()
|
||
]
|
||
desired.sort(key=lambda item: (item.get("type", 1), item.get("name", "")))
|
||
payload = json.dumps(desired, sort_keys=True, separators=(",", ":"))
|
||
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||
|
||
def _command_sync_skip_reason(self, app_id: Any, fingerprint: str) -> Optional[str]:
|
||
entry = self._read_command_sync_state().get(self._command_sync_state_key(app_id))
|
||
if not isinstance(entry, dict):
|
||
return None
|
||
now = time.time()
|
||
retry_after_until = float(entry.get("retry_after_until") or 0)
|
||
if retry_after_until > now:
|
||
remaining = max(1, int(retry_after_until - now))
|
||
return f"Discord asked us to wait before syncing slash commands; retry in {remaining}s"
|
||
if entry.get("fingerprint") == fingerprint and entry.get("last_success_at"):
|
||
return "same slash-command fingerprint already synced"
|
||
return None
|
||
|
||
def _record_command_sync_attempt(self, app_id: Any, fingerprint: str) -> None:
|
||
state = self._read_command_sync_state()
|
||
state[self._command_sync_state_key(app_id)] = {
|
||
**(
|
||
state.get(self._command_sync_state_key(app_id))
|
||
if isinstance(state.get(self._command_sync_state_key(app_id)), dict)
|
||
else {}
|
||
),
|
||
"fingerprint": fingerprint,
|
||
"last_attempt_at": time.time(),
|
||
}
|
||
self._write_command_sync_state(state)
|
||
|
||
def _record_command_sync_rate_limit(self, app_id: Any, fingerprint: str, retry_after: float) -> None:
|
||
retry_after = max(1.0, float(retry_after))
|
||
state = self._read_command_sync_state()
|
||
state[self._command_sync_state_key(app_id)] = {
|
||
**(
|
||
state.get(self._command_sync_state_key(app_id))
|
||
if isinstance(state.get(self._command_sync_state_key(app_id)), dict)
|
||
else {}
|
||
),
|
||
"fingerprint": fingerprint,
|
||
"last_attempt_at": time.time(),
|
||
"retry_after_until": time.time() + retry_after,
|
||
"retry_after": retry_after,
|
||
}
|
||
self._write_command_sync_state(state)
|
||
|
||
def _record_command_sync_success(self, app_id: Any, fingerprint: str, summary: dict) -> None:
|
||
state = self._read_command_sync_state()
|
||
state[self._command_sync_state_key(app_id)] = {
|
||
"fingerprint": fingerprint,
|
||
"last_attempt_at": time.time(),
|
||
"last_success_at": time.time(),
|
||
"summary": summary,
|
||
}
|
||
self._write_command_sync_state(state)
|
||
|
||
@staticmethod
|
||
def _extract_discord_retry_after(exc: BaseException) -> Optional[float]:
|
||
value = getattr(exc, "retry_after", None)
|
||
if value is not None:
|
||
try:
|
||
return max(1.0, float(value))
|
||
except (TypeError, ValueError):
|
||
return None
|
||
response = getattr(exc, "response", None)
|
||
headers = getattr(response, "headers", None)
|
||
if headers:
|
||
for key in ("Retry-After", "X-RateLimit-Reset-After"):
|
||
try:
|
||
raw = headers.get(key)
|
||
except Exception:
|
||
raw = None
|
||
if raw is None:
|
||
continue
|
||
try:
|
||
return max(1.0, float(raw))
|
||
except (TypeError, ValueError):
|
||
continue
|
||
return None
|
||
|
||
@staticmethod
|
||
def _is_discord_rate_limit(exc: BaseException) -> bool:
|
||
"""True only for exceptions that look like Discord 429 rate limits.
|
||
|
||
Narrower than ``hasattr(exc, 'retry_after')``: discord.py's own
|
||
``RateLimited`` exception and any HTTPException with status 429
|
||
qualify. This prevents suppressing unrelated failures that happen
|
||
to expose a ``retry_after`` attribute."""
|
||
# discord.py emits RateLimited / HTTPException subclasses for 429s.
|
||
# Guard with isinstance-of-class so a mocked ``discord`` module
|
||
# (where attrs are MagicMocks, not types) doesn't trip isinstance.
|
||
if DISCORD_AVAILABLE and discord is not None:
|
||
for attr_name in ("RateLimited", "HTTPException"):
|
||
cls = getattr(discord, attr_name, None)
|
||
if not isinstance(cls, type):
|
||
continue
|
||
if isinstance(exc, cls):
|
||
if attr_name == "RateLimited":
|
||
return True
|
||
status = getattr(exc, "status", None)
|
||
if status == 429:
|
||
return True
|
||
# Fallback duck-type: something named like a rate-limit with a
|
||
# numeric retry_after. Covers mocked clients in tests and exotic
|
||
# transports, without swallowing arbitrary exceptions.
|
||
name = type(exc).__name__.lower()
|
||
if ("ratelimit" in name or "rate_limit" in name) and getattr(exc, "retry_after", None) is not None:
|
||
return True
|
||
response = getattr(exc, "response", None)
|
||
status = getattr(response, "status", None) or getattr(response, "status_code", None)
|
||
if status == 429:
|
||
return True
|
||
return False
|
||
|
||
def _command_sync_mutation_interval_seconds(self) -> float:
|
||
return _DISCORD_COMMAND_SYNC_MUTATION_INTERVAL_SECONDS
|
||
|
||
async def _sleep_between_command_sync_mutations(self) -> None:
|
||
interval = self._command_sync_mutation_interval_seconds()
|
||
if interval > 0:
|
||
await asyncio.sleep(interval)
|
||
|
||
async def _run_post_connect_initialization(self) -> None:
|
||
"""Finish non-critical startup work after Discord is connected."""
|
||
if not self._client:
|
||
return
|
||
try:
|
||
sync_policy = self._get_discord_command_sync_policy()
|
||
if sync_policy == "off":
|
||
logger.info("[%s] Skipping Discord slash command sync (policy=off)", self.name)
|
||
return
|
||
|
||
if sync_policy == "bulk":
|
||
synced = await asyncio.wait_for(self._client.tree.sync(), timeout=30)
|
||
logger.info("[%s] Synced %d slash command(s) via bulk tree sync", self.name, len(synced))
|
||
return
|
||
|
||
app_id = getattr(self._client, "application_id", None) or getattr(getattr(self._client, "user", None), "id", None)
|
||
fingerprint = self._desired_command_sync_fingerprint()
|
||
skip_reason = self._command_sync_skip_reason(app_id, fingerprint)
|
||
if skip_reason:
|
||
logger.info("[%s] Skipping Discord slash command sync: %s", self.name, skip_reason)
|
||
return
|
||
self._record_command_sync_attempt(app_id, fingerprint)
|
||
|
||
http = getattr(self._client, "http", None)
|
||
has_ratelimit_timeout = http is not None and hasattr(http, "max_ratelimit_timeout")
|
||
previous_ratelimit_timeout = getattr(http, "max_ratelimit_timeout", None) if has_ratelimit_timeout else None
|
||
if has_ratelimit_timeout:
|
||
http.max_ratelimit_timeout = _DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS
|
||
|
||
try:
|
||
# Discord's per-app command-management bucket is small, and
|
||
# discord.py can otherwise sit inside one long retry sleep
|
||
# before surfacing the 429. Keep the whole sync bounded and
|
||
# persist Discord's retry-after when it refuses the batch.
|
||
summary = await asyncio.wait_for(self._safe_sync_slash_commands(), timeout=600)
|
||
except Exception as e:
|
||
if not self._is_discord_rate_limit(e):
|
||
raise
|
||
retry_after = self._extract_discord_retry_after(e)
|
||
if retry_after is None:
|
||
# Rate-limited but no retry-after signal — back off for a
|
||
# conservative default so we don't slam the bucket again.
|
||
retry_after = _DISCORD_COMMAND_SYNC_MAX_RATE_LIMIT_SLEEP_SECONDS
|
||
self._record_command_sync_rate_limit(app_id, fingerprint, retry_after)
|
||
logger.warning(
|
||
"[%s] Discord rate-limited slash command sync; retrying after %.0fs",
|
||
self.name,
|
||
retry_after,
|
||
)
|
||
return
|
||
finally:
|
||
if has_ratelimit_timeout:
|
||
http.max_ratelimit_timeout = previous_ratelimit_timeout
|
||
|
||
self._record_command_sync_success(app_id, fingerprint, summary)
|
||
logger.info(
|
||
"[%s] Safely reconciled %d slash command(s): unchanged=%d updated=%d recreated=%d created=%d deleted=%d",
|
||
self.name,
|
||
summary["total"],
|
||
summary["unchanged"],
|
||
summary["updated"],
|
||
summary["recreated"],
|
||
summary["created"],
|
||
summary["deleted"],
|
||
)
|
||
except asyncio.TimeoutError:
|
||
logger.warning(
|
||
"[%s] Slash command sync timed out — Discord rate-limit bucket "
|
||
"may be saturated; will retry on next reconnect",
|
||
self.name,
|
||
)
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.warning("[%s] Slash command sync failed: %s", self.name, e, exc_info=True)
|
||
|
||
def _get_discord_command_sync_policy(self) -> str:
|
||
raw = str(os.getenv("DISCORD_COMMAND_SYNC_POLICY", "safe") or "").strip().lower()
|
||
if raw in _DISCORD_COMMAND_SYNC_POLICIES:
|
||
return raw
|
||
if raw:
|
||
logger.warning(
|
||
"[%s] Invalid DISCORD_COMMAND_SYNC_POLICY=%r; falling back to 'safe'",
|
||
self.name,
|
||
raw,
|
||
)
|
||
return "safe"
|
||
|
||
def _canonicalize_app_command_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""Reduce command payloads to the semantic fields Hermes manages."""
|
||
contexts = payload.get("contexts")
|
||
integration_types = payload.get("integration_types")
|
||
return {
|
||
"type": int(payload.get("type", 1) or 1),
|
||
"name": str(payload.get("name", "") or ""),
|
||
"description": str(payload.get("description", "") or ""),
|
||
"default_member_permissions": self._normalize_permissions(
|
||
payload.get("default_member_permissions")
|
||
),
|
||
"dm_permission": bool(payload.get("dm_permission", True)),
|
||
"nsfw": bool(payload.get("nsfw", False)),
|
||
"contexts": sorted(int(c) for c in contexts) if contexts else None,
|
||
"integration_types": (
|
||
sorted(int(i) for i in integration_types) if integration_types else None
|
||
),
|
||
"options": [
|
||
self._canonicalize_app_command_option(item)
|
||
for item in payload.get("options", []) or []
|
||
if isinstance(item, dict)
|
||
],
|
||
}
|
||
|
||
@staticmethod
|
||
def _normalize_permissions(value: Any) -> Optional[str]:
|
||
"""Discord emits default_member_permissions as str server-side but discord.py
|
||
sets it as int locally. Normalize to str-or-None so the comparison is stable."""
|
||
if value is None:
|
||
return None
|
||
return str(value)
|
||
|
||
def _existing_command_to_payload(self, command: Any) -> Dict[str, Any]:
|
||
"""Build a canonical-ready dict from an AppCommand.
|
||
|
||
discord.py's AppCommand.to_dict() does NOT include nsfw,
|
||
dm_permission, or default_member_permissions (they live only on the
|
||
attributes). Pull them from the attributes so the canonicalizer sees
|
||
the real server-side values instead of defaults — otherwise any
|
||
command using non-default permissions would diff on every startup.
|
||
"""
|
||
payload = dict(command.to_dict())
|
||
nsfw = getattr(command, "nsfw", None)
|
||
if nsfw is not None:
|
||
payload["nsfw"] = bool(nsfw)
|
||
guild_only = getattr(command, "guild_only", None)
|
||
if guild_only is not None:
|
||
payload["dm_permission"] = not bool(guild_only)
|
||
default_permissions = getattr(command, "default_member_permissions", None)
|
||
if default_permissions is not None:
|
||
payload["default_member_permissions"] = getattr(
|
||
default_permissions, "value", default_permissions
|
||
)
|
||
return payload
|
||
|
||
def _canonicalize_app_command_option(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||
return {
|
||
"type": int(payload.get("type", 0) or 0),
|
||
"name": str(payload.get("name", "") or ""),
|
||
"description": str(payload.get("description", "") or ""),
|
||
"required": bool(payload.get("required", False)),
|
||
"autocomplete": bool(payload.get("autocomplete", False)),
|
||
"choices": [
|
||
{
|
||
"name": str(choice.get("name", "") or ""),
|
||
"value": choice.get("value"),
|
||
}
|
||
for choice in payload.get("choices", []) or []
|
||
if isinstance(choice, dict)
|
||
],
|
||
"channel_types": list(payload.get("channel_types", []) or []),
|
||
"min_value": payload.get("min_value"),
|
||
"max_value": payload.get("max_value"),
|
||
"min_length": payload.get("min_length"),
|
||
"max_length": payload.get("max_length"),
|
||
"options": [
|
||
self._canonicalize_app_command_option(item)
|
||
for item in payload.get("options", []) or []
|
||
if isinstance(item, dict)
|
||
],
|
||
}
|
||
|
||
def _patchable_app_command_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""Fields supported by discord.py's edit_global_command route."""
|
||
canonical = self._canonicalize_app_command_payload(payload)
|
||
return {
|
||
"name": canonical["name"],
|
||
"description": canonical["description"],
|
||
"options": canonical["options"],
|
||
}
|
||
|
||
async def _safe_sync_slash_commands(self) -> Dict[str, int]:
|
||
"""Diff existing global commands and only mutate the commands that changed."""
|
||
if not self._client:
|
||
return {
|
||
"total": 0,
|
||
"unchanged": 0,
|
||
"updated": 0,
|
||
"recreated": 0,
|
||
"created": 0,
|
||
"deleted": 0,
|
||
}
|
||
|
||
tree = self._client.tree
|
||
app_id = getattr(self._client, "application_id", None) or getattr(getattr(self._client, "user", None), "id", None)
|
||
if not app_id:
|
||
raise RuntimeError("Discord application ID is unavailable for slash command sync")
|
||
|
||
desired_payloads = [command.to_dict(tree) for command in tree.get_commands()]
|
||
desired_by_key = {
|
||
(int(payload.get("type", 1) or 1), str(payload.get("name", "") or "").lower()): payload
|
||
for payload in desired_payloads
|
||
}
|
||
existing_commands = await tree.fetch_commands()
|
||
existing_by_key = {
|
||
(
|
||
int(getattr(getattr(command, "type", None), "value", getattr(command, "type", 1)) or 1),
|
||
str(command.name or "").lower(),
|
||
): command
|
||
for command in existing_commands
|
||
}
|
||
|
||
unchanged = 0
|
||
updated = 0
|
||
recreated = 0
|
||
created = 0
|
||
deleted = 0
|
||
http = self._client.http
|
||
mutation_count = 0
|
||
|
||
async def mutate(call, *args):
|
||
nonlocal mutation_count
|
||
if mutation_count:
|
||
await self._sleep_between_command_sync_mutations()
|
||
result = await call(*args)
|
||
mutation_count += 1
|
||
return result
|
||
|
||
# Delete obsolete commands FIRST to stay under Discord's 100-command
|
||
# limit. Discord rejects an upsert that would push the live total over
|
||
# 100 (error 30032), which silently breaks ALL slash commands. If a new
|
||
# command is created before the obsolete ones are removed, an app that
|
||
# is already at the cap momentarily exceeds it and the whole sync fails.
|
||
# Removing the no-longer-desired commands up front guarantees the live
|
||
# total never rises above the cap mid-sync.
|
||
obsolete_keys = set(existing_by_key.keys()) - set(desired_by_key.keys())
|
||
for key in obsolete_keys:
|
||
current = existing_by_key.pop(key)
|
||
await mutate(http.delete_global_command, app_id, current.id)
|
||
deleted += 1
|
||
|
||
for key, desired in desired_by_key.items():
|
||
current = existing_by_key.pop(key, None)
|
||
if current is None:
|
||
await mutate(http.upsert_global_command, app_id, desired)
|
||
created += 1
|
||
continue
|
||
|
||
current_existing_payload = self._existing_command_to_payload(current)
|
||
current_payload = self._canonicalize_app_command_payload(current_existing_payload)
|
||
desired_payload = self._canonicalize_app_command_payload(desired)
|
||
if current_payload == desired_payload:
|
||
unchanged += 1
|
||
continue
|
||
|
||
if self._patchable_app_command_payload(current_existing_payload) == self._patchable_app_command_payload(desired):
|
||
await mutate(http.delete_global_command, app_id, current.id)
|
||
await mutate(http.upsert_global_command, app_id, desired)
|
||
recreated += 1
|
||
continue
|
||
|
||
await mutate(http.edit_global_command, app_id, current.id, desired)
|
||
updated += 1
|
||
|
||
return {
|
||
"total": len(desired_payloads),
|
||
"unchanged": unchanged,
|
||
"updated": updated,
|
||
"recreated": recreated,
|
||
"created": created,
|
||
"deleted": deleted,
|
||
}
|
||
|
||
async def _add_reaction(self, message: Any, emoji: str) -> bool:
|
||
"""Add an emoji reaction to a Discord message."""
|
||
if not message or not hasattr(message, "add_reaction"):
|
||
return False
|
||
try:
|
||
await message.add_reaction(emoji)
|
||
return True
|
||
except Exception as e:
|
||
logger.debug("[%s] add_reaction failed (%s): %s", self.name, emoji, e)
|
||
return False
|
||
|
||
async def _remove_reaction(self, message: Any, emoji: str) -> bool:
|
||
"""Remove the bot's own emoji reaction from a Discord message."""
|
||
if not message or not hasattr(message, "remove_reaction") or not self._client or not self._client.user:
|
||
return False
|
||
try:
|
||
await message.remove_reaction(emoji, self._client.user)
|
||
return True
|
||
except Exception as e:
|
||
logger.debug("[%s] remove_reaction failed (%s): %s", self.name, emoji, e)
|
||
return False
|
||
|
||
def _reactions_enabled(self) -> bool:
|
||
"""Check if message reactions are enabled via config/env."""
|
||
return os.getenv("DISCORD_REACTIONS", "true").lower() not in {"false", "0", "no"}
|
||
|
||
async def on_processing_start(self, event: MessageEvent) -> None:
|
||
"""Add an in-progress reaction for normal Discord message events."""
|
||
if not self._reactions_enabled():
|
||
return
|
||
message = event.raw_message
|
||
if hasattr(message, "add_reaction"):
|
||
await self._add_reaction(message, "👀")
|
||
|
||
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
|
||
"""Swap the in-progress reaction for a final success/failure reaction."""
|
||
if not self._reactions_enabled():
|
||
return
|
||
message = event.raw_message
|
||
if hasattr(message, "add_reaction"):
|
||
await self._remove_reaction(message, "👀")
|
||
if outcome == ProcessingOutcome.SUCCESS:
|
||
await self._add_reaction(message, "✅")
|
||
elif outcome == ProcessingOutcome.FAILURE:
|
||
await self._add_reaction(message, "❌")
|
||
|
||
async def send(
|
||
self,
|
||
chat_id: str,
|
||
content: str,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None
|
||
) -> SendResult:
|
||
"""Send a message to a Discord channel or thread.
|
||
|
||
When metadata contains a thread_id, the message is sent to that
|
||
thread instead of the parent channel identified by chat_id.
|
||
|
||
Forum channels (type 15) reject direct messages — a thread post is
|
||
created automatically.
|
||
"""
|
||
if not self._client:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
try:
|
||
# Determine target channel: thread_id in metadata takes precedence.
|
||
thread_id = None
|
||
if metadata and metadata.get("thread_id"):
|
||
thread_id = metadata["thread_id"]
|
||
nonconversational = _metadata_marks_nonconversational(metadata)
|
||
|
||
if thread_id:
|
||
# Fetch the thread directly — threads are addressed by their own ID.
|
||
channel = self._client.get_channel(int(thread_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(thread_id))
|
||
if not channel:
|
||
return SendResult(success=False, error=f"Thread {thread_id} not found")
|
||
else:
|
||
# Get the parent channel
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
if not channel:
|
||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||
|
||
# Forum channels reject channel.send() — create a thread post instead.
|
||
if self._is_forum_parent(channel):
|
||
return await self._send_to_forum(channel, content)
|
||
|
||
# Format and split message if needed
|
||
formatted = self.format_message(content)
|
||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||
|
||
message_ids = []
|
||
reference = None
|
||
|
||
if reply_to and self._reply_to_mode != "off":
|
||
try:
|
||
ref_msg = await channel.fetch_message(int(reply_to))
|
||
if hasattr(ref_msg, "to_reference"):
|
||
reference = ref_msg.to_reference(fail_if_not_exists=False)
|
||
else:
|
||
reference = ref_msg
|
||
except Exception as e:
|
||
logger.debug("Could not fetch reply-to message: %s", e)
|
||
|
||
for i, chunk in enumerate(chunks):
|
||
if self._reply_to_mode == "all":
|
||
chunk_reference = reference
|
||
else: # "first" (default) or "off"
|
||
chunk_reference = reference if i == 0 else None
|
||
try:
|
||
msg = await channel.send(
|
||
content=chunk,
|
||
reference=chunk_reference,
|
||
)
|
||
except Exception as e:
|
||
err_text = str(e)
|
||
if (
|
||
chunk_reference is not None
|
||
and (
|
||
(
|
||
"error code: 50035" in err_text
|
||
and "Cannot reply to a system message" in err_text
|
||
)
|
||
or "error code: 10008" in err_text
|
||
)
|
||
):
|
||
logger.warning(
|
||
"[%s] Reply target %s rejected the reply reference; retrying send without reply reference",
|
||
self.name,
|
||
reply_to,
|
||
)
|
||
reference = None
|
||
msg = await channel.send(
|
||
content=chunk,
|
||
reference=None,
|
||
)
|
||
else:
|
||
raise
|
||
message_ids.append(str(msg.id))
|
||
|
||
# Track the last message we sent in this channel for history
|
||
# backfill — avoids a full channel.history() scan on hot paths.
|
||
if message_ids:
|
||
_target_id = thread_id or chat_id
|
||
if nonconversational:
|
||
self._nonconversational_messages.mark_many(message_ids)
|
||
elif not _looks_like_nonconversational_history_message(content):
|
||
self._last_self_message_id[_target_id] = message_ids[-1]
|
||
|
||
return SendResult(
|
||
success=True,
|
||
message_id=message_ids[0] if message_ids else None,
|
||
raw_response={"message_ids": message_ids}
|
||
)
|
||
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
async def _send_to_forum(self, forum_channel: Any, content: str) -> SendResult:
|
||
"""Create a thread post in a forum channel with the message as starter content.
|
||
|
||
Forum channels (type 15) don't support direct messages. Instead we
|
||
POST to /channels/{forum_id}/threads with a thread name derived from
|
||
the first line of the message. Any follow-up chunk failures are
|
||
reported in ``raw_response['warnings']`` so the caller can surface
|
||
partial-send issues.
|
||
"""
|
||
# _derive_forum_thread_name is defined further down in this same
|
||
# module — no cross-module import needed.
|
||
|
||
formatted = self.format_message(content)
|
||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||
|
||
thread_name = _derive_forum_thread_name(content)
|
||
|
||
starter_content = chunks[0] if chunks else thread_name
|
||
|
||
try:
|
||
thread = await forum_channel.create_thread(
|
||
name=thread_name,
|
||
content=starter_content,
|
||
)
|
||
except Exception as e:
|
||
logger.error("[%s] Failed to create forum thread in %s: %s", self.name, forum_channel.id, e)
|
||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||
|
||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||
starter_msg = getattr(thread, "message", None)
|
||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||
|
||
# Send remaining chunks into the newly created thread. Track any
|
||
# per-chunk failures so the caller sees partial-send outcomes.
|
||
message_ids = [message_id]
|
||
warnings: list[str] = []
|
||
for chunk in chunks[1:]:
|
||
try:
|
||
msg = await thread_channel.send(content=chunk)
|
||
message_ids.append(str(msg.id))
|
||
except Exception as e:
|
||
warning = f"Failed to send follow-up chunk to forum thread {thread_id}: {e}"
|
||
logger.warning("[%s] %s", self.name, warning)
|
||
warnings.append(warning)
|
||
|
||
raw_response: Dict[str, Any] = {"message_ids": message_ids, "thread_id": thread_id}
|
||
if warnings:
|
||
raw_response["warnings"] = warnings
|
||
|
||
return SendResult(
|
||
success=True,
|
||
message_id=message_ids[0],
|
||
raw_response=raw_response,
|
||
)
|
||
|
||
async def _forum_post_file(
|
||
self,
|
||
forum_channel: Any,
|
||
*,
|
||
thread_name: Optional[str] = None,
|
||
content: str = "",
|
||
file: Any = None,
|
||
files: Optional[list] = None,
|
||
) -> SendResult:
|
||
"""Create a forum thread whose starter message carries file attachments.
|
||
|
||
Used by the send_voice / send_image_file / send_document paths when
|
||
the target channel is a forum (type 15). ``create_thread`` on a
|
||
ForumChannel accepts the same file/files/content kwargs as
|
||
``channel.send``, creating the thread and starter message atomically.
|
||
"""
|
||
# _derive_forum_thread_name is defined further down in this same
|
||
# module — no cross-module import needed.
|
||
|
||
if not thread_name:
|
||
# Prefer the text content, fall back to the first attached
|
||
# filename, fall back to the generic default.
|
||
hint = content or ""
|
||
if not hint.strip():
|
||
if file is not None:
|
||
hint = getattr(file, "filename", "") or ""
|
||
elif files:
|
||
hint = getattr(files[0], "filename", "") or ""
|
||
thread_name = _derive_forum_thread_name(hint) if hint.strip() else "New Post"
|
||
|
||
kwargs: Dict[str, Any] = {"name": thread_name}
|
||
if content:
|
||
kwargs["content"] = content
|
||
if file is not None:
|
||
kwargs["file"] = file
|
||
if files:
|
||
kwargs["files"] = files
|
||
|
||
try:
|
||
thread = await forum_channel.create_thread(**kwargs)
|
||
except Exception as e:
|
||
logger.error(
|
||
"[%s] Failed to create forum thread with file in %s: %s",
|
||
self.name,
|
||
getattr(forum_channel, "id", "?"),
|
||
e,
|
||
)
|
||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||
|
||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||
starter_msg = getattr(thread, "message", None)
|
||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||
|
||
return SendResult(
|
||
success=True,
|
||
message_id=message_id,
|
||
raw_response={"thread_id": thread_id},
|
||
)
|
||
|
||
async def edit_message(
|
||
self,
|
||
chat_id: str,
|
||
message_id: str,
|
||
content: str,
|
||
*,
|
||
finalize: bool = False,
|
||
) -> SendResult:
|
||
"""Edit a previously sent Discord message."""
|
||
if not self._client:
|
||
return SendResult(success=False, error="Not connected")
|
||
try:
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
msg = await channel.fetch_message(int(message_id))
|
||
formatted = self.format_message(content)
|
||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||
formatted = formatted[:self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||
await msg.edit(content=formatted)
|
||
return SendResult(success=True, message_id=message_id)
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to edit Discord message %s: %s", self.name, message_id, e, exc_info=True)
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
async def _send_file_attachment(
|
||
self,
|
||
chat_id: str,
|
||
file_path: str,
|
||
caption: Optional[str] = None,
|
||
file_name: Optional[str] = None,
|
||
) -> SendResult:
|
||
"""Send a local file as a Discord attachment.
|
||
|
||
Forum channels (type 15) get a new thread whose starter message
|
||
carries the file — they reject direct POST /messages.
|
||
"""
|
||
if not self._client:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
if not channel:
|
||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||
|
||
filename = file_name or os.path.basename(file_path)
|
||
with open(file_path, "rb") as fh:
|
||
file = discord.File(fh, filename=filename)
|
||
if self._is_forum_parent(channel):
|
||
return await self._forum_post_file(
|
||
channel,
|
||
content=(caption or "").strip(),
|
||
file=file,
|
||
)
|
||
msg = await channel.send(content=caption if caption else None, file=file)
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
|
||
async def send_multiple_images(
|
||
self,
|
||
chat_id: str,
|
||
images: List[Tuple[str, str]],
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
human_delay: float = 0.0,
|
||
) -> None:
|
||
"""Send a batch of images as a single Discord message with multiple attachments.
|
||
|
||
Discord permits up to 10 file attachments per message. Batches are
|
||
chunked accordingly. URL images are downloaded into memory and
|
||
uploaded as inline attachments (same pattern as ``send_image`` so
|
||
they render inline, not as bare links). Local files are opened
|
||
directly. On per-chunk failure the remaining images in that chunk
|
||
fall back to the base per-image loop.
|
||
"""
|
||
if not self._client:
|
||
return
|
||
if not images:
|
||
return
|
||
|
||
try:
|
||
import discord as _discord_mod
|
||
import io as _io
|
||
from urllib.parse import unquote as _unquote
|
||
except Exception: # pragma: no cover
|
||
await super().send_multiple_images(chat_id, images, metadata, human_delay)
|
||
return
|
||
|
||
try:
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
if not channel:
|
||
logger.warning("[%s] Channel %s not found for multi-image send", self.name, chat_id)
|
||
return
|
||
except Exception as e:
|
||
logger.warning("[%s] Failed to resolve channel for multi-image send: %s", self.name, e)
|
||
await super().send_multiple_images(chat_id, images, metadata, human_delay)
|
||
return
|
||
|
||
CHUNK = 10
|
||
chunks = [images[i:i + CHUNK] for i in range(0, len(images), CHUNK)]
|
||
|
||
for chunk_idx, chunk in enumerate(chunks):
|
||
if human_delay > 0 and chunk_idx > 0:
|
||
await asyncio.sleep(human_delay)
|
||
|
||
files: List[Any] = []
|
||
captions: List[str] = []
|
||
aiohttp_session = None
|
||
try:
|
||
for image_url, alt_text in chunk:
|
||
if alt_text:
|
||
captions.append(alt_text)
|
||
if image_url.startswith("file://"):
|
||
local_path = _unquote(image_url[7:])
|
||
if not os.path.exists(local_path):
|
||
logger.warning("[%s] Skipping missing image: %s", self.name, local_path)
|
||
continue
|
||
files.append(_discord_mod.File(local_path, filename=os.path.basename(local_path)))
|
||
else:
|
||
if not is_safe_url(image_url):
|
||
logger.warning("[%s] Blocked unsafe image URL in batch", self.name)
|
||
continue
|
||
# Download to BytesIO so it renders inline
|
||
try:
|
||
import aiohttp as _aiohttp
|
||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||
if aiohttp_session is None:
|
||
aiohttp_session = _aiohttp.ClientSession(**_sess_kw)
|
||
async with aiohttp_session.get(
|
||
image_url, timeout=_aiohttp.ClientTimeout(total=30), **_req_kw,
|
||
) as resp:
|
||
if resp.status != 200:
|
||
logger.warning(
|
||
"[%s] Failed to download image (HTTP %d) in batch: %s",
|
||
self.name, resp.status, image_url[:80],
|
||
)
|
||
continue
|
||
data = await resp.read()
|
||
ct = resp.headers.get("content-type", "image/png")
|
||
ext = "png"
|
||
if "jpeg" in ct or "jpg" in ct:
|
||
ext = "jpg"
|
||
elif "gif" in ct:
|
||
ext = "gif"
|
||
elif "webp" in ct:
|
||
ext = "webp"
|
||
files.append(_discord_mod.File(_io.BytesIO(data), filename=f"image_{len(files)}.{ext}"))
|
||
except Exception as dl_err:
|
||
logger.warning("[%s] Download failed for %s: %s", self.name, image_url[:80], dl_err)
|
||
continue
|
||
|
||
if not files:
|
||
continue
|
||
|
||
# Use the first caption if any (Discord only has one message body for the group)
|
||
content = captions[0] if captions else None
|
||
logger.info(
|
||
"[%s] Sending %d image(s) as single Discord message (chunk %d/%d)",
|
||
self.name, len(files), chunk_idx + 1, len(chunks),
|
||
)
|
||
|
||
if self._is_forum_parent(channel):
|
||
await self._forum_post_file(
|
||
channel,
|
||
content=(content or "").strip(),
|
||
files=files,
|
||
)
|
||
else:
|
||
await channel.send(content=content, files=files)
|
||
except Exception as e:
|
||
logger.warning(
|
||
"[%s] Multi-image Discord send failed (chunk %d/%d), falling back to per-image: %s",
|
||
self.name, chunk_idx + 1, len(chunks), e,
|
||
exc_info=True,
|
||
)
|
||
await super().send_multiple_images(chat_id, chunk, metadata, human_delay=human_delay)
|
||
finally:
|
||
if aiohttp_session is not None:
|
||
try:
|
||
await aiohttp_session.close()
|
||
except Exception:
|
||
pass
|
||
|
||
async def play_tts(
|
||
self,
|
||
chat_id: str,
|
||
audio_path: str,
|
||
**kwargs,
|
||
) -> SendResult:
|
||
"""Play auto-TTS audio.
|
||
|
||
When the bot is in a voice channel for this chat's guild, play
|
||
directly in the VC instead of sending as a file attachment.
|
||
"""
|
||
for gid, text_ch_id in self._voice_text_channels.items():
|
||
if str(text_ch_id) == str(chat_id) and self.is_in_voice_channel(gid):
|
||
logger.info("[%s] Playing TTS in voice channel (guild=%d)", self.name, gid)
|
||
success = await self.play_in_voice_channel(gid, audio_path)
|
||
return SendResult(success=success)
|
||
return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
||
|
||
async def send_voice(
|
||
self,
|
||
chat_id: str,
|
||
audio_path: str,
|
||
caption: Optional[str] = None,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
**kwargs,
|
||
) -> SendResult:
|
||
"""Send audio as a Discord file attachment."""
|
||
try:
|
||
import io
|
||
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
if not channel:
|
||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||
|
||
if not os.path.exists(audio_path):
|
||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||
|
||
filename = os.path.basename(audio_path)
|
||
|
||
with open(audio_path, "rb") as f:
|
||
file_data = f.read()
|
||
|
||
# Forum channels (type 15) reject direct POST /messages — the
|
||
# native voice flag path also targets /messages so it would fail
|
||
# too. Create a thread post with the audio as the starter
|
||
# attachment instead.
|
||
if self._is_forum_parent(channel):
|
||
forum_file = discord.File(io.BytesIO(file_data), filename=filename)
|
||
return await self._forum_post_file(
|
||
channel,
|
||
content=(caption or "").strip(),
|
||
file=forum_file,
|
||
)
|
||
|
||
# Try sending as a native voice message via raw API (flags=8192).
|
||
try:
|
||
import base64
|
||
|
||
duration_secs = 5.0
|
||
try:
|
||
from mutagen.oggopus import OggOpus
|
||
info = OggOpus(audio_path)
|
||
duration_secs = info.info.length
|
||
except Exception:
|
||
duration_secs = max(1.0, len(file_data) / 2000.0)
|
||
|
||
waveform_bytes = bytes([128] * 256)
|
||
waveform_b64 = base64.b64encode(waveform_bytes).decode()
|
||
|
||
import json as _json
|
||
payload = _json.dumps({
|
||
"flags": 8192,
|
||
"attachments": [{
|
||
"id": "0",
|
||
"filename": "voice-message.ogg",
|
||
"duration_secs": round(duration_secs, 2),
|
||
"waveform": waveform_b64,
|
||
}],
|
||
})
|
||
form = [
|
||
{"name": "payload_json", "value": payload},
|
||
{
|
||
"name": "files[0]",
|
||
"value": file_data,
|
||
"filename": "voice-message.ogg",
|
||
"content_type": "audio/ogg",
|
||
},
|
||
]
|
||
msg_data = await self._client.http.request(
|
||
discord.http.Route("POST", "/channels/{channel_id}/messages", channel_id=channel.id),
|
||
form=form,
|
||
)
|
||
return SendResult(success=True, message_id=str(msg_data["id"]))
|
||
except Exception as voice_err:
|
||
logger.debug("Voice message flag failed, falling back to file: %s", voice_err)
|
||
file = discord.File(io.BytesIO(file_data), filename=filename)
|
||
msg = await channel.send(file=file)
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to send audio, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||
return await super().send_voice(chat_id, audio_path, caption, reply_to, metadata=metadata)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Voice channel methods (join / leave / play)
|
||
# ------------------------------------------------------------------
|
||
|
||
def _load_voice_fx_config(self) -> Dict[str, Any]:
|
||
"""Read voice mixer / ambient / ack settings from config.yaml.
|
||
|
||
All settings live under ``discord.voice_fx`` in config.yaml (NOT the
|
||
.env file — these are behavioral, not secrets). The feature is OFF by
|
||
default; users opt in with ``discord.voice_fx.enabled: true``.
|
||
|
||
Returns a dict with safe defaults so callers never KeyError.
|
||
"""
|
||
defaults: Dict[str, Any] = {
|
||
"enabled": False, # master switch for the mixer subsystem
|
||
"ambient_enabled": True, # idle "thinking" bed while tools run
|
||
"ambient_path": "", # optional custom loop file; "" = synthesised
|
||
"ambient_gain": 0.18, # idle bed loudness (0..1)
|
||
"duck_gain": 0.06, # ambient loudness while speech plays
|
||
"speech_gain": 1.0, # TTS / ack loudness
|
||
"ack_enabled": True, # speak a short phrase before tool calls
|
||
"ack_phrases": [
|
||
"Let me look into that.",
|
||
"One moment.",
|
||
"Checking on that now.",
|
||
"Give me a sec.",
|
||
"On it.",
|
||
],
|
||
}
|
||
try:
|
||
from hermes_cli.config import read_raw_config
|
||
cfg = read_raw_config() or {}
|
||
fx = ((cfg.get("discord") or {}).get("voice_fx") or {})
|
||
if isinstance(fx, dict):
|
||
for k, v in fx.items():
|
||
if k in defaults and v is not None:
|
||
defaults[k] = v
|
||
except Exception as e:
|
||
logger.debug("Could not load discord.voice_fx config: %s", e)
|
||
return defaults
|
||
|
||
def _get_ambient_pcm(self) -> Optional[bytes]:
|
||
"""Return decoded 48k/stereo/s16le PCM for the ambient idle bed.
|
||
|
||
Uses a custom file when ``ambient_path`` is set and decodable, else a
|
||
synthesised pad. Cached after first build.
|
||
"""
|
||
if self._ambient_pcm_cache is not None:
|
||
return self._ambient_pcm_cache
|
||
if not self._voice_fx_cfg.get("ambient_enabled"):
|
||
return None
|
||
try:
|
||
from voice_mixer import decode_to_pcm, synth_ambient_pcm
|
||
except ImportError:
|
||
from .voice_mixer import decode_to_pcm, synth_ambient_pcm
|
||
|
||
pcm: Optional[bytes] = None
|
||
path = (self._voice_fx_cfg.get("ambient_path") or "").strip()
|
||
if path and os.path.isfile(path):
|
||
pcm = decode_to_pcm(path)
|
||
if not pcm:
|
||
logger.warning("Ambient file %s failed to decode; using synth bed", path)
|
||
if not pcm:
|
||
pcm = synth_ambient_pcm()
|
||
self._ambient_pcm_cache = pcm
|
||
return pcm
|
||
|
||
async def _install_voice_mixer(self, guild_id: int, vc) -> None:
|
||
"""Create a VoiceMixer, start the ambient bed, and play it on the VC.
|
||
|
||
The mixer runs continuously for the life of the connection: one
|
||
``vc.play(mixer)`` call, never stopped until leave.
|
||
"""
|
||
try:
|
||
from voice_mixer import VoiceMixer
|
||
except ImportError:
|
||
from .voice_mixer import VoiceMixer
|
||
|
||
mixer = VoiceMixer(
|
||
ambient_gain=float(self._voice_fx_cfg.get("ambient_gain", 0.18)),
|
||
duck_gain=float(self._voice_fx_cfg.get("duck_gain", 0.06)),
|
||
speech_gain=float(self._voice_fx_cfg.get("speech_gain", 1.0)),
|
||
)
|
||
ambient = await asyncio.to_thread(self._get_ambient_pcm)
|
||
if ambient:
|
||
mixer.set_ambient(ambient)
|
||
|
||
def _after(error):
|
||
if error:
|
||
logger.error("Voice mixer stream error (guild=%d): %s", guild_id, error)
|
||
|
||
if vc.is_playing():
|
||
vc.stop()
|
||
vc.play(mixer, after=_after)
|
||
self._voice_mixers[guild_id] = mixer
|
||
logger.info("Voice mixer installed (guild=%d, ambient=%s)", guild_id, bool(ambient))
|
||
|
||
async def play_ack_in_voice(self, guild_id: int, phrase: Optional[str] = None) -> bool:
|
||
"""Speak a short acknowledgement over the ambient bed.
|
||
|
||
Called from the gateway's tool-progress hook on the first tool call of
|
||
a turn, so the user hears "let me look into that" before the bot goes
|
||
quiet to work. No-op unless the mixer is installed and acks enabled.
|
||
"""
|
||
if not self._voice_fx_cfg.get("ack_enabled"):
|
||
return False
|
||
mixer = self._voice_mixers.get(guild_id)
|
||
if mixer is None:
|
||
return False
|
||
if phrase is None:
|
||
import random
|
||
phrases = self._voice_fx_cfg.get("ack_phrases") or ["One moment."]
|
||
phrase = random.choice(phrases)
|
||
|
||
# Synthesise the ack via the configured TTS provider, then layer it.
|
||
import uuid as _uuid
|
||
audio_path = os.path.join(
|
||
tempfile.gettempdir(), "hermes_voice",
|
||
f"ack_{_uuid.uuid4().hex[:12]}.mp3",
|
||
)
|
||
os.makedirs(os.path.dirname(audio_path), exist_ok=True)
|
||
try:
|
||
from tools.tts_tool import text_to_speech_tool
|
||
result_json = await asyncio.to_thread(
|
||
text_to_speech_tool, text=phrase, output_path=audio_path
|
||
)
|
||
result = json.loads(result_json)
|
||
actual = result.get("file_path", audio_path)
|
||
if not result.get("success") or not os.path.isfile(actual):
|
||
return False
|
||
try:
|
||
from voice_mixer import decode_to_pcm
|
||
except ImportError:
|
||
from .voice_mixer import decode_to_pcm
|
||
pcm = await asyncio.to_thread(decode_to_pcm, actual)
|
||
if not pcm:
|
||
return False
|
||
mixer.play_speech(
|
||
pcm, gain=float(self._voice_fx_cfg.get("speech_gain", 1.0))
|
||
)
|
||
self._reset_voice_timeout(guild_id)
|
||
return True
|
||
except Exception as e:
|
||
logger.debug("play_ack_in_voice failed: %s", e)
|
||
return False
|
||
finally:
|
||
for p in {audio_path, locals().get("actual")}:
|
||
if p and os.path.isfile(p):
|
||
try:
|
||
os.unlink(p)
|
||
except OSError:
|
||
pass
|
||
|
||
def voice_mixer_active(self, guild_id: int) -> bool:
|
||
"""True when a continuous mixer is installed for this guild."""
|
||
mixers = getattr(self, "_voice_mixers", None)
|
||
return bool(mixers) and mixers.get(guild_id) is not None
|
||
|
||
async def join_voice_channel(self, channel) -> bool:
|
||
"""Join a Discord voice channel. Returns True on success."""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return False
|
||
guild_id = channel.guild.id
|
||
|
||
async with self._voice_locks.setdefault(guild_id, asyncio.Lock()):
|
||
# Already connected in this guild?
|
||
existing = self._voice_clients.get(guild_id)
|
||
if existing and existing.is_connected():
|
||
if existing.channel.id == channel.id:
|
||
self._reset_voice_timeout(guild_id)
|
||
return True
|
||
await existing.move_to(channel)
|
||
self._reset_voice_timeout(guild_id)
|
||
return True
|
||
|
||
vc = await channel.connect()
|
||
self._voice_clients[guild_id] = vc
|
||
self._reset_voice_timeout(guild_id)
|
||
|
||
# Start voice receiver (Phase 2: listen to users)
|
||
try:
|
||
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
|
||
receiver.start()
|
||
self._voice_receivers[guild_id] = receiver
|
||
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||
self._voice_listen_loop(guild_id)
|
||
)
|
||
except Exception as e:
|
||
logger.warning("Voice receiver failed to start: %s", e)
|
||
|
||
# Phase 3: install the continuous mixer (ambient bed + ducked
|
||
# speech). Best-effort — if it fails we fall back to the legacy
|
||
# one-shot FFmpegPCMAudio playback path in play_in_voice_channel.
|
||
if getattr(self, "_voice_fx_cfg", {}).get("enabled"):
|
||
try:
|
||
await self._install_voice_mixer(guild_id, vc)
|
||
except Exception as e:
|
||
logger.warning("Voice mixer failed to start: %s", e)
|
||
|
||
return True
|
||
|
||
async def leave_voice_channel(self, guild_id: int) -> None:
|
||
"""Disconnect from the voice channel in a guild."""
|
||
async with self._voice_locks.setdefault(guild_id, asyncio.Lock()):
|
||
# Stop voice receiver first
|
||
receiver = self._voice_receivers.pop(guild_id, None)
|
||
if receiver:
|
||
receiver.stop()
|
||
listen_task = self._voice_listen_tasks.pop(guild_id, None)
|
||
if listen_task:
|
||
listen_task.cancel()
|
||
|
||
# Tear down the mixer (stops the continuous outgoing stream).
|
||
if getattr(self, "_voice_mixers", None) is not None:
|
||
self._voice_mixers.pop(guild_id, None)
|
||
|
||
vc = self._voice_clients.pop(guild_id, None)
|
||
if vc and vc.is_connected():
|
||
try:
|
||
if vc.is_playing():
|
||
vc.stop()
|
||
except Exception:
|
||
pass
|
||
await vc.disconnect()
|
||
task = self._voice_timeout_tasks.pop(guild_id, None)
|
||
if task:
|
||
task.cancel()
|
||
self._voice_text_channels.pop(guild_id, None)
|
||
self._voice_sources.pop(guild_id, None)
|
||
|
||
# Maximum seconds to wait for voice playback before giving up
|
||
PLAYBACK_TIMEOUT = 120
|
||
|
||
async def play_in_voice_channel(self, guild_id: int, audio_path: str) -> bool:
|
||
"""Play an audio file in the connected voice channel.
|
||
|
||
When the continuous mixer is installed for this guild, the clip is
|
||
decoded to PCM and layered over the ambient bed (ducking it) so the
|
||
reply can overlap the idle "thinking" loop seamlessly. Otherwise we
|
||
fall back to the legacy one-shot FFmpegPCMAudio path.
|
||
"""
|
||
vc = self._voice_clients.get(guild_id)
|
||
if not vc or not vc.is_connected():
|
||
return False
|
||
|
||
# ── Mixer path (overlap + ducking) ──────────────────────────────
|
||
mixer = getattr(self, "_voice_mixers", {}).get(guild_id) if getattr(self, "_voice_mixers", None) else None
|
||
if mixer is not None:
|
||
try:
|
||
from voice_mixer import decode_to_pcm
|
||
except ImportError:
|
||
from .voice_mixer import decode_to_pcm
|
||
pcm = await asyncio.to_thread(decode_to_pcm, audio_path)
|
||
if pcm:
|
||
speech_gain = float(self._voice_fx_cfg.get("speech_gain", 1.0))
|
||
mixer.play_speech(pcm, gain=speech_gain)
|
||
# Block until the speech child drains so callers serialise
|
||
# replies (mirrors legacy semantics) but the ambient keeps
|
||
# playing underneath the whole time.
|
||
wait_start = time.monotonic()
|
||
while mixer.speech_active:
|
||
if time.monotonic() - wait_start > self.PLAYBACK_TIMEOUT:
|
||
logger.warning("Mixer speech playback timed out after %ds", self.PLAYBACK_TIMEOUT)
|
||
mixer.stop_speech()
|
||
break
|
||
await asyncio.sleep(0.05)
|
||
self._reset_voice_timeout(guild_id)
|
||
return True
|
||
logger.warning("Mixer decode failed for %s; falling back to legacy playback", audio_path)
|
||
|
||
# ── Legacy one-shot path (no mixer) ─────────────────────────────
|
||
# Pause voice receiver while playing (echo prevention)
|
||
receiver = self._voice_receivers.get(guild_id)
|
||
if receiver:
|
||
receiver.pause()
|
||
|
||
try:
|
||
# Wait for current playback to finish (with timeout)
|
||
wait_start = time.monotonic()
|
||
while vc.is_playing():
|
||
if time.monotonic() - wait_start > self.PLAYBACK_TIMEOUT:
|
||
logger.warning("Timed out waiting for previous playback to finish")
|
||
vc.stop()
|
||
break
|
||
await asyncio.sleep(0.1)
|
||
|
||
done = asyncio.Event()
|
||
loop = asyncio.get_running_loop()
|
||
|
||
def _after(error):
|
||
if error:
|
||
logger.error("Voice playback error: %s", error)
|
||
loop.call_soon_threadsafe(done.set)
|
||
|
||
source = discord.FFmpegPCMAudio(audio_path)
|
||
source = discord.PCMVolumeTransformer(source, volume=1.0)
|
||
vc.play(source, after=_after)
|
||
try:
|
||
await asyncio.wait_for(done.wait(), timeout=self.PLAYBACK_TIMEOUT)
|
||
except asyncio.TimeoutError:
|
||
logger.warning("Voice playback timed out after %ds", self.PLAYBACK_TIMEOUT)
|
||
vc.stop()
|
||
self._reset_voice_timeout(guild_id)
|
||
return True
|
||
finally:
|
||
if receiver:
|
||
receiver.resume()
|
||
|
||
async def get_user_voice_channel(self, guild_id: int, user_id: str):
|
||
"""Return the voice channel the user is currently in, or None."""
|
||
if not self._client:
|
||
return None
|
||
guild = self._client.get_guild(guild_id)
|
||
if not guild:
|
||
return None
|
||
member = guild.get_member(int(user_id))
|
||
if not member or not member.voice:
|
||
return None
|
||
return member.voice.channel
|
||
|
||
def _reset_voice_timeout(self, guild_id: int) -> None:
|
||
"""Reset the auto-disconnect inactivity timer."""
|
||
task = self._voice_timeout_tasks.pop(guild_id, None)
|
||
if task:
|
||
task.cancel()
|
||
self._voice_timeout_tasks[guild_id] = asyncio.ensure_future(
|
||
self._voice_timeout_handler(guild_id)
|
||
)
|
||
|
||
async def _voice_timeout_handler(self, guild_id: int) -> None:
|
||
"""Auto-disconnect after VOICE_TIMEOUT seconds of inactivity."""
|
||
try:
|
||
await asyncio.sleep(self.VOICE_TIMEOUT)
|
||
except asyncio.CancelledError:
|
||
return
|
||
text_ch_id = self._voice_text_channels.get(guild_id)
|
||
# ``/voice off`` mutes spoken replies but deliberately keeps the bot in
|
||
# the channel (leaving is ``/voice leave``). The inactivity timer only
|
||
# counts the bot's OWN audio as activity, so under voice-off mode it
|
||
# fires every VOICE_TIMEOUT seconds, yanks the bot out, and spams the
|
||
# text channel with "Left voice channel (inactivity timeout)." Honor the
|
||
# user's choice: skip the auto-disconnect while voice replies are off.
|
||
# (The timer re-arms when the bot next speaks or hears a user.)
|
||
_mode_getter = getattr(self, "_voice_mode_getter", None)
|
||
if text_ch_id is not None and _mode_getter is not None:
|
||
try:
|
||
if _mode_getter(str(text_ch_id)) == "off":
|
||
return
|
||
except Exception:
|
||
pass
|
||
await self.leave_voice_channel(guild_id)
|
||
# Notify the runner so it can clean up voice_mode state
|
||
if self._on_voice_disconnect and text_ch_id:
|
||
try:
|
||
self._on_voice_disconnect(str(text_ch_id))
|
||
except Exception:
|
||
pass
|
||
if text_ch_id and self._client:
|
||
ch = self._client.get_channel(text_ch_id)
|
||
if ch:
|
||
try:
|
||
await ch.send("Left voice channel (inactivity timeout).")
|
||
except Exception:
|
||
pass
|
||
|
||
def is_in_voice_channel(self, guild_id: int) -> bool:
|
||
"""Check if the bot is connected to a voice channel in this guild."""
|
||
vc = self._voice_clients.get(guild_id)
|
||
return vc is not None and vc.is_connected()
|
||
|
||
def get_voice_channel_info(self, guild_id: int) -> Optional[Dict[str, Any]]:
|
||
"""Return voice channel awareness info for the given guild.
|
||
|
||
Returns None if the bot is not in a voice channel. Otherwise
|
||
returns a dict with channel name, member list, count, and
|
||
currently-speaking user IDs (from SSRC mapping).
|
||
"""
|
||
vc = self._voice_clients.get(guild_id)
|
||
if not vc or not vc.is_connected():
|
||
return None
|
||
|
||
channel = vc.channel
|
||
if not channel:
|
||
return None
|
||
|
||
# Members currently in the voice channel (includes bot)
|
||
members_info = []
|
||
bot_user = self._client.user if self._client else None
|
||
for m in channel.members:
|
||
if bot_user and m.id == bot_user.id:
|
||
continue # skip the bot itself
|
||
members_info.append({
|
||
"user_id": m.id,
|
||
"display_name": m.display_name,
|
||
"is_bot": m.bot,
|
||
})
|
||
|
||
# Currently speaking users (from SSRC mapping + active buffers)
|
||
speaking_user_ids: set = set()
|
||
receiver = self._voice_receivers.get(guild_id)
|
||
if receiver:
|
||
now = time.monotonic()
|
||
with receiver._lock:
|
||
for ssrc, last_t in receiver._last_packet_time.items():
|
||
# Consider "speaking" if audio received within last 2 seconds
|
||
if now - last_t < 2.0:
|
||
uid = receiver._ssrc_to_user.get(ssrc)
|
||
if uid:
|
||
speaking_user_ids.add(uid)
|
||
|
||
# Tag speaking status on members
|
||
for info in members_info:
|
||
info["is_speaking"] = info["user_id"] in speaking_user_ids
|
||
|
||
return {
|
||
"channel_name": channel.name,
|
||
"member_count": len(members_info),
|
||
"members": members_info,
|
||
"speaking_count": len(speaking_user_ids),
|
||
}
|
||
|
||
def get_voice_channel_context(self, guild_id: int) -> str:
|
||
"""Return a human-readable voice channel context string.
|
||
|
||
Suitable for injection into the system/ephemeral prompt so the
|
||
agent is always aware of voice channel state.
|
||
"""
|
||
info = self.get_voice_channel_info(guild_id)
|
||
if not info:
|
||
return ""
|
||
|
||
parts = [f"[Voice channel: #{info['channel_name']} — {info['member_count']} participant(s)]"]
|
||
for m in info["members"]:
|
||
status = " (speaking)" if m["is_speaking"] else ""
|
||
parts.append(f" - {m['display_name']}{status}")
|
||
|
||
return "\n".join(parts)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Voice listening (Phase 2)
|
||
# ------------------------------------------------------------------
|
||
|
||
# UDP keepalive interval in seconds — prevents Discord from dropping
|
||
# the UDP route after ~60s of silence.
|
||
_KEEPALIVE_INTERVAL = 15
|
||
|
||
async def _voice_listen_loop(self, guild_id: int):
|
||
"""Periodically check for completed utterances and process them."""
|
||
receiver = self._voice_receivers.get(guild_id)
|
||
if not receiver:
|
||
return
|
||
last_keepalive = time.monotonic()
|
||
try:
|
||
while receiver._running:
|
||
await asyncio.sleep(0.2)
|
||
|
||
# Send periodic UDP keepalive to prevent Discord from
|
||
# dropping the UDP session after ~60s of silence.
|
||
now = time.monotonic()
|
||
if now - last_keepalive >= self._KEEPALIVE_INTERVAL:
|
||
last_keepalive = now
|
||
try:
|
||
vc = self._voice_clients.get(guild_id)
|
||
if vc and vc.is_connected():
|
||
vc._connection.send_packet(b'\xf8\xff\xfe')
|
||
except Exception:
|
||
pass
|
||
|
||
completed = receiver.check_silence()
|
||
# Voice inputs always originate from a specific guild
|
||
# (guild_id is in scope). Pass it so role checks are
|
||
# guild-scoped and not cross-guild.
|
||
_vc_guild = self._client.get_guild(guild_id) if self._client is not None else None
|
||
for user_id, pcm_data in completed:
|
||
if not self._is_allowed_user(
|
||
str(user_id),
|
||
guild=_vc_guild,
|
||
is_dm=False,
|
||
):
|
||
continue
|
||
# A user speaking to the bot is activity too — not just the
|
||
# bot's own playback. Reset the inactivity timer so an active
|
||
# listener isn't disconnected mid-conversation (this also
|
||
# covers voice-on text-only sessions that never play audio).
|
||
self._reset_voice_timeout(guild_id)
|
||
await self._process_voice_input(guild_id, user_id, pcm_data)
|
||
except asyncio.CancelledError:
|
||
pass
|
||
except Exception as e:
|
||
logger.error("Voice listen loop error: %s", e, exc_info=True)
|
||
|
||
async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes):
|
||
"""Convert PCM -> WAV -> STT -> callback."""
|
||
from tools.voice_mode import is_whisper_hallucination
|
||
|
||
tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
|
||
wav_path = tmp_f.name
|
||
tmp_f.close()
|
||
try:
|
||
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
|
||
|
||
from tools.transcription_tools import transcribe_audio
|
||
result = await asyncio.to_thread(transcribe_audio, wav_path)
|
||
|
||
if not result.get("success"):
|
||
return
|
||
transcript = result.get("transcript", "").strip()
|
||
if not transcript or is_whisper_hallucination(transcript):
|
||
return
|
||
|
||
logger.info("Voice input from user %d: %s", user_id, transcript[:100])
|
||
|
||
if self._voice_input_callback:
|
||
await self._voice_input_callback(
|
||
guild_id=guild_id,
|
||
user_id=user_id,
|
||
transcript=transcript,
|
||
)
|
||
except Exception as e:
|
||
logger.warning("Voice input processing failed: %s", e, exc_info=True)
|
||
finally:
|
||
try:
|
||
os.unlink(wav_path)
|
||
except OSError:
|
||
pass
|
||
|
||
def _is_allowed_user(
|
||
self,
|
||
user_id: str,
|
||
author=None,
|
||
*,
|
||
guild=None,
|
||
is_dm: bool = False,
|
||
) -> bool:
|
||
"""Check if user is allowed via DISCORD_ALLOWED_USERS or DISCORD_ALLOWED_ROLES.
|
||
|
||
Uses OR semantics: if the user matches EITHER allowlist, they're allowed.
|
||
If both allowlists are empty, everyone is allowed (backwards compatible).
|
||
|
||
Role checks are **scoped to the guild the message originated from**.
|
||
For DMs (no guild context), role-based auth is disabled by default and
|
||
only user-ID allowlist applies. Set ``discord.dm_role_auth_guild``
|
||
in config.yaml to a specific guild ID to opt-in: role membership in
|
||
that one guild will authorize DMs. This prevents cross-guild
|
||
privilege escalation where a user with the configured role in any
|
||
shared public server could DM the bot and pass the allowlist.
|
||
|
||
Args:
|
||
user_id: Author ID as a string.
|
||
author: Optional Member/User object for in-guild role lookup.
|
||
guild: The guild the message arrived in (None for DMs).
|
||
is_dm: True if the message came from a DM channel.
|
||
"""
|
||
# ``getattr`` fallbacks here guard against test fixtures that build
|
||
# an adapter via ``object.__new__(DiscordAdapter)`` and skip __init__
|
||
# (see AGENTS.md pitfall #17 — same pattern as gateway.run).
|
||
allowed_users = getattr(self, "_allowed_user_ids", set())
|
||
allowed_roles = getattr(self, "_allowed_role_ids", set())
|
||
has_users = bool(allowed_users)
|
||
has_roles = bool(allowed_roles)
|
||
if not has_users and not has_roles:
|
||
return True
|
||
# Check user ID allowlist (works for both DMs and guild messages)
|
||
if has_users and user_id in allowed_users:
|
||
return True
|
||
# Role allowlist is only consulted when configured.
|
||
if not has_roles:
|
||
return False
|
||
|
||
# DM path: roles require explicit opt-in via
|
||
# ``discord.dm_role_auth_guild`` in config.yaml. Without this, a
|
||
# user with the configured role in ANY mutual guild could DM the
|
||
# bot and bypass the allowlist (cross-guild leakage).
|
||
if is_dm or guild is None:
|
||
dm_guild_id = _read_dm_role_auth_guild()
|
||
if dm_guild_id is None:
|
||
return False
|
||
if self._client is None:
|
||
return False
|
||
dm_guild = self._client.get_guild(dm_guild_id)
|
||
if dm_guild is None:
|
||
return False
|
||
try:
|
||
uid_int = int(user_id)
|
||
except (TypeError, ValueError):
|
||
return False
|
||
m = dm_guild.get_member(uid_int)
|
||
if m is None:
|
||
return False
|
||
m_roles = getattr(m, "roles", None) or []
|
||
return any(getattr(r, "id", None) in allowed_roles for r in m_roles)
|
||
|
||
# Guild path: role check is scoped to THIS guild only.
|
||
# 1) Prefer the direct Member object passed in (correct guild by construction).
|
||
direct_roles = getattr(author, "roles", None) if author is not None else None
|
||
author_guild = getattr(author, "guild", None)
|
||
if direct_roles and (author_guild is None or author_guild.id == guild.id):
|
||
if any(getattr(r, "id", None) in allowed_roles for r in direct_roles):
|
||
return True
|
||
# 2) Fallback: resolve the Member in the message's guild only — NEVER
|
||
# scan other mutual guilds (that is the cross-guild bypass bug).
|
||
try:
|
||
uid_int = int(user_id)
|
||
except (TypeError, ValueError):
|
||
return False
|
||
m = guild.get_member(uid_int)
|
||
if m is None:
|
||
return False
|
||
m_roles = getattr(m, "roles", None) or []
|
||
return any(getattr(r, "id", None) in allowed_roles for r in m_roles)
|
||
|
||
# ── Slash command authorization ─────────────────────────────────────
|
||
# Slash commands (``_run_simple_slash`` and ``_handle_thread_create_slash``)
|
||
# are a separate Discord interaction surface from regular messages and
|
||
# historically ran with NO authorization check — bypassing every gate
|
||
# ``on_message`` enforces (DISCORD_ALLOWED_USERS, DISCORD_ALLOWED_ROLES,
|
||
# DISCORD_ALLOWED_CHANNELS, DISCORD_IGNORED_CHANNELS). Any guild member
|
||
# could invoke ``/background``, ``/restart``, ``/sethome``, etc. as the
|
||
# operator. ``_check_slash_authorization`` mirrors the on_message gates
|
||
# one-for-one so the slash surface honors the same trust boundary.
|
||
#
|
||
# By design, this is a no-op for deployments with no allowlist env vars
|
||
# set — ``_is_allowed_user`` returns True and the channel checks early-out
|
||
# — preserving the existing "single-tenant, all guild members trusted"
|
||
# default. Deployments that DO set any DISCORD_ALLOWED_* var get slash
|
||
# parity with on_message.
|
||
|
||
def _evaluate_slash_authorization(
|
||
self, interaction: "discord.Interaction",
|
||
) -> Tuple[bool, Optional[str]]:
|
||
"""Evaluate slash authorization without producing any response.
|
||
|
||
Returns ``(allowed, reason)``. ``reason`` is populated only when
|
||
``allowed`` is False. This is the shared core used by both the
|
||
responding wrapper (``_check_slash_authorization``) and side-effect-
|
||
free callers like the ``/skill`` autocomplete callback, which must
|
||
return an empty list for unauthorized users instead of leaking an
|
||
ephemeral rejection per-keystroke.
|
||
|
||
Fail-closed semantics for malformed payloads: when an allowlist is
|
||
configured but the interaction is missing the data needed to
|
||
evaluate it (no channel id with channel policy active, no user
|
||
with user/role policy active), the gate REJECTS rather than
|
||
falling through. Without these guards a guild interaction that
|
||
happens to deserialize without a channel id would silently bypass
|
||
``DISCORD_ALLOWED_CHANNELS`` and a payload missing ``user`` would
|
||
raise ``AttributeError`` in the user check below, surfacing as
|
||
an opaque interaction failure rather than a clean rejection.
|
||
"""
|
||
chan_obj = getattr(interaction, "channel", None)
|
||
in_dm = isinstance(chan_obj, discord.DMChannel) if chan_obj is not None else False
|
||
|
||
# ── Channel scope (mirrors on_message lines 3374-3388) ──
|
||
# DMs aren't channel-gated — DMs follow on_message's DM lockdown
|
||
# path which has its own user-allowlist enforcement.
|
||
if not in_dm:
|
||
chan_id_raw = getattr(interaction, "channel_id", None) or getattr(
|
||
chan_obj, "id", None,
|
||
)
|
||
channel_ids: set = set()
|
||
if chan_id_raw is not None:
|
||
channel_ids.add(str(chan_id_raw))
|
||
# Mirror on_message: also test the parent channel for threads
|
||
# so per-channel allow/deny lists work consistently.
|
||
if isinstance(chan_obj, discord.Thread):
|
||
parent_id = self._get_parent_channel_id(chan_obj)
|
||
if parent_id:
|
||
channel_ids.add(str(parent_id))
|
||
|
||
allowed_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
|
||
if allowed_raw:
|
||
allowed = {c.strip() for c in allowed_raw.split(",") if c.strip()}
|
||
if "*" not in allowed:
|
||
if not channel_ids:
|
||
# Channel policy is configured but the interaction
|
||
# has no resolvable channel id. Fail closed.
|
||
return (
|
||
False,
|
||
"channel id missing with DISCORD_ALLOWED_CHANNELS configured",
|
||
)
|
||
if not (channel_ids & allowed):
|
||
return (False, "channel not in DISCORD_ALLOWED_CHANNELS")
|
||
|
||
# Ignored beats allowed: even when a thread's parent channel
|
||
# is on the allowlist, an explicit DISCORD_IGNORED_CHANNELS
|
||
# entry on the thread or its parent rejects the interaction.
|
||
ignored_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||
if ignored_raw and channel_ids:
|
||
ignored = {c.strip() for c in ignored_raw.split(",") if c.strip()}
|
||
if "*" in ignored or (channel_ids & ignored):
|
||
return (False, "channel in DISCORD_IGNORED_CHANNELS")
|
||
|
||
# ── User / role allowlist (mirrors on_message line 681) ──
|
||
user = getattr(interaction, "user", None)
|
||
allowed_users = getattr(self, "_allowed_user_ids", set()) or set()
|
||
allowed_roles = getattr(self, "_allowed_role_ids", set()) or set()
|
||
if user is None or getattr(user, "id", None) is None:
|
||
# No identifiable user. With any user/role allowlist
|
||
# configured, fail closed rather than raise AttributeError
|
||
# on ``interaction.user.id`` below. With no allowlist this
|
||
# is the existing "no allowlist = everyone" backwards-compat.
|
||
if allowed_users or allowed_roles:
|
||
return (False, "missing interaction.user with allowlist configured")
|
||
return (True, None)
|
||
|
||
user_id = str(user.id)
|
||
# Pass guild + is_dm so role check is scoped to the originating
|
||
# guild and cross-guild DM bypass (#12136) can't land via the
|
||
# slash surface either.
|
||
interaction_guild = getattr(interaction, "guild", None)
|
||
if not self._is_allowed_user(
|
||
user_id,
|
||
author=user,
|
||
guild=interaction_guild,
|
||
is_dm=in_dm,
|
||
):
|
||
return (
|
||
False,
|
||
"user not in DISCORD_ALLOWED_USERS / DISCORD_ALLOWED_ROLES",
|
||
)
|
||
|
||
return (True, None)
|
||
|
||
async def _check_slash_authorization(
|
||
self, interaction: "discord.Interaction", command_text: str,
|
||
) -> bool:
|
||
"""Mirror on_message's user/role/channel gates onto a slash invocation.
|
||
|
||
Returns True to proceed. Returns False *after* sending an ephemeral
|
||
rejection, logging a warning, and scheduling a cross-platform admin
|
||
alert — the caller must stop on False (the interaction has already
|
||
been responded to).
|
||
"""
|
||
allowed, reason = self._evaluate_slash_authorization(interaction)
|
||
if allowed:
|
||
return True
|
||
return await self._reject_slash(
|
||
interaction, command_text, reason=reason or "unauthorized",
|
||
)
|
||
|
||
async def _reject_slash(
|
||
self, interaction: "discord.Interaction", command_text: str, *, reason: str,
|
||
) -> bool:
|
||
"""Send ephemeral reject + log warning + schedule admin alert. Returns False.
|
||
|
||
Tolerates a missing ``interaction.user`` -- the fail-closed branch
|
||
in ``_evaluate_slash_authorization`` deliberately routes here for
|
||
malformed payloads (no user) when an allowlist is configured, and
|
||
``str(interaction.user.id)`` would raise AttributeError before the
|
||
ephemeral rejection could be sent.
|
||
"""
|
||
user = getattr(interaction, "user", None)
|
||
if user is not None:
|
||
user_id = str(getattr(user, "id", "?"))
|
||
user_name = getattr(user, "name", "?")
|
||
else:
|
||
user_id = "?"
|
||
user_name = "?"
|
||
chan_id = getattr(interaction, "channel_id", None) or getattr(
|
||
getattr(interaction, "channel", None), "id", None,
|
||
)
|
||
guild_id = getattr(interaction, "guild_id", None)
|
||
|
||
logger.warning(
|
||
"[Discord] Unauthorized slash attempt: user=%s id=%s channel=%s "
|
||
"guild=%s cmd=%r reason=%r",
|
||
user_name, user_id, chan_id, guild_id, command_text, reason,
|
||
)
|
||
|
||
try:
|
||
await interaction.response.send_message(
|
||
"You're not authorized to use this command.",
|
||
ephemeral=True,
|
||
)
|
||
except Exception as e:
|
||
# Interaction may already be responded to (e.g. caller deferred
|
||
# before the auth check, or Discord retried). Best-effort only.
|
||
logger.debug("[Discord] Could not send unauthorized ephemeral: %s", e)
|
||
|
||
# Fire-and-forget: don't block the interaction handler on Telegram I/O.
|
||
try:
|
||
asyncio.create_task(self._notify_unauthorized_slash(
|
||
user_name, user_id, chan_id, guild_id, command_text, reason,
|
||
))
|
||
except Exception as e:
|
||
logger.debug("[Discord] Could not schedule admin notify task: %s", e)
|
||
|
||
return False
|
||
|
||
async def _notify_unauthorized_slash(
|
||
self, user_name: str, user_id: str, chan_id, guild_id,
|
||
command_text: str, reason: str,
|
||
) -> None:
|
||
"""Best-effort cross-platform alert to the gateway operator.
|
||
|
||
Tries TELEGRAM first (most operators set TELEGRAM_HOME_CHANNEL),
|
||
then SLACK. Silently no-ops if no other platform is configured
|
||
with a home channel.
|
||
|
||
A soft send failure -- adapter.send() returning a result with
|
||
``success=False`` rather than raising -- continues the fallback
|
||
chain. Treating a SendResult(success=False) as delivered would
|
||
mean a Telegram outage that the adapter politely surfaces (e.g.
|
||
rate-limit, auth failure) silently swallows the alert without
|
||
attempting Slack. Hard exceptions still take the same path via
|
||
the except branch below.
|
||
"""
|
||
runner = getattr(self, "gateway_runner", None)
|
||
if not runner:
|
||
return
|
||
for target in (Platform.TELEGRAM, Platform.SLACK):
|
||
try:
|
||
adapter = runner.adapters.get(target)
|
||
if not adapter:
|
||
continue
|
||
home = runner.config.get_home_channel(target)
|
||
if not home or not getattr(home, "chat_id", None):
|
||
continue
|
||
msg = (
|
||
"⚠️ Unauthorized Discord slash attempt\n"
|
||
f"User: {user_name} ({user_id})\n"
|
||
f"Channel: {chan_id} (guild {guild_id})\n"
|
||
f"Command: {command_text}\n"
|
||
f"Reason: {reason}"
|
||
)
|
||
result = await adapter.send(str(home.chat_id), msg)
|
||
# Only return on confirmed delivery. SendResult(success=False)
|
||
# -> continue to the next platform.
|
||
if getattr(result, "success", None) is False:
|
||
logger.debug(
|
||
"[Discord] Admin notify via %s returned success=False"
|
||
" (error=%r); falling through",
|
||
target, getattr(result, "error", None),
|
||
)
|
||
continue
|
||
return
|
||
except Exception as e:
|
||
logger.debug("[Discord] Admin notify via %s failed: %s", target, e)
|
||
|
||
async def send_image_file(
|
||
self,
|
||
chat_id: str,
|
||
image_path: str,
|
||
caption: Optional[str] = None,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send a local image file natively as a Discord file attachment."""
|
||
try:
|
||
return await self._send_file_attachment(chat_id, image_path, caption)
|
||
except FileNotFoundError:
|
||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to send local image, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||
return await super().send_image_file(chat_id, image_path, caption, reply_to, metadata=metadata)
|
||
|
||
async def send_image(
|
||
self,
|
||
chat_id: str,
|
||
image_url: str,
|
||
caption: Optional[str] = None,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send an image natively as a Discord file attachment."""
|
||
if not self._client:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
if not is_safe_url(image_url):
|
||
logger.warning("[%s] Blocked unsafe image URL during Discord send_image", self.name)
|
||
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
|
||
|
||
try:
|
||
import aiohttp
|
||
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
if not channel:
|
||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||
|
||
# Download the image and send as a Discord file attachment
|
||
# (Discord renders attachments inline, unlike plain URLs)
|
||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30), **_req_kw) as resp:
|
||
if resp.status != 200:
|
||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||
|
||
image_data = await resp.read()
|
||
|
||
# Determine filename from URL or content type
|
||
content_type = resp.headers.get("content-type", "image/png")
|
||
ext = "png"
|
||
if "jpeg" in content_type or "jpg" in content_type:
|
||
ext = "jpg"
|
||
elif "gif" in content_type:
|
||
ext = "gif"
|
||
elif "webp" in content_type:
|
||
ext = "webp"
|
||
|
||
import io
|
||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||
|
||
if self._is_forum_parent(channel):
|
||
return await self._forum_post_file(
|
||
channel,
|
||
content=(caption or "").strip(),
|
||
file=file,
|
||
)
|
||
|
||
msg = await channel.send(
|
||
content=caption if caption else None,
|
||
file=file,
|
||
)
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
|
||
except ImportError:
|
||
logger.warning(
|
||
"[%s] aiohttp not installed, falling back to URL. Run: pip install aiohttp",
|
||
self.name,
|
||
exc_info=True,
|
||
)
|
||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error(
|
||
"[%s] Failed to send image attachment, falling back to URL: %s",
|
||
self.name,
|
||
e,
|
||
exc_info=True,
|
||
)
|
||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||
|
||
async def send_animation(
|
||
self,
|
||
chat_id: str,
|
||
animation_url: str,
|
||
caption: Optional[str] = None,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send an animated GIF natively as a Discord file attachment."""
|
||
if not self._client:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
if not is_safe_url(animation_url):
|
||
logger.warning("[%s] Blocked unsafe animation URL during Discord send_animation", self.name)
|
||
return await super().send_animation(chat_id, animation_url, caption, reply_to, metadata=metadata)
|
||
|
||
try:
|
||
import aiohttp
|
||
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
if not channel:
|
||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||
|
||
# Download the GIF and send as a Discord file attachment
|
||
# (Discord renders .gif attachments as auto-playing animations inline)
|
||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||
async with session.get(animation_url, timeout=aiohttp.ClientTimeout(total=30), **_req_kw) as resp:
|
||
if resp.status != 200:
|
||
raise Exception(f"Failed to download animation: HTTP {resp.status}")
|
||
|
||
animation_data = await resp.read()
|
||
|
||
import io
|
||
file = discord.File(io.BytesIO(animation_data), filename="animation.gif")
|
||
|
||
if self._is_forum_parent(channel):
|
||
return await self._forum_post_file(
|
||
channel,
|
||
content=(caption or "").strip(),
|
||
file=file,
|
||
)
|
||
|
||
msg = await channel.send(
|
||
content=caption if caption else None,
|
||
file=file,
|
||
)
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
|
||
except ImportError:
|
||
logger.warning(
|
||
"[%s] aiohttp not installed, falling back to URL. Run: pip install aiohttp",
|
||
self.name,
|
||
exc_info=True,
|
||
)
|
||
return await super().send_animation(chat_id, animation_url, caption, reply_to, metadata=metadata)
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error(
|
||
"[%s] Failed to send animation attachment, falling back to URL: %s",
|
||
self.name,
|
||
e,
|
||
exc_info=True,
|
||
)
|
||
return await super().send_animation(chat_id, animation_url, caption, reply_to, metadata=metadata)
|
||
|
||
async def send_video(
|
||
self,
|
||
chat_id: str,
|
||
video_path: str,
|
||
caption: Optional[str] = None,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send a local video file natively as a Discord attachment."""
|
||
try:
|
||
return await self._send_file_attachment(chat_id, video_path, caption)
|
||
except FileNotFoundError:
|
||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to send local video, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||
return await super().send_video(chat_id, video_path, caption, reply_to, metadata=metadata)
|
||
|
||
async def send_document(
|
||
self,
|
||
chat_id: str,
|
||
file_path: str,
|
||
caption: Optional[str] = None,
|
||
file_name: Optional[str] = None,
|
||
reply_to: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send an arbitrary file natively as a Discord attachment."""
|
||
try:
|
||
return await self._send_file_attachment(chat_id, file_path, caption, file_name=file_name)
|
||
except FileNotFoundError:
|
||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to send document, falling back to base adapter: %s", self.name, e, exc_info=True)
|
||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
|
||
|
||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||
"""Start a persistent typing indicator for a channel.
|
||
|
||
Discord's TYPING_START gateway event is unreliable in DMs for bots.
|
||
Instead, start a background loop that hits the typing endpoint every
|
||
12 seconds (typing indicator lasts ~10s). The loop is cancelled when
|
||
stop_typing() is called (after the response is sent).
|
||
|
||
Rate-limit handling: if a 429 is encountered, the loop logs a
|
||
warning, sleeps for the ``retry_after`` duration (or a sensible
|
||
default), and continues — it does NOT die on a single rate-limit
|
||
hit. Only CancelledError (from stop_typing) stops the loop.
|
||
"""
|
||
if not self._client:
|
||
return
|
||
# Don't start a duplicate loop
|
||
if chat_id in self._typing_tasks:
|
||
return
|
||
|
||
async def _typing_loop() -> None:
|
||
try:
|
||
while True:
|
||
try:
|
||
route = discord.http.Route(
|
||
"POST", "/channels/{channel_id}/typing",
|
||
channel_id=chat_id,
|
||
)
|
||
await self._client.http.request(route)
|
||
except asyncio.CancelledError:
|
||
return
|
||
except Exception as e:
|
||
# Don't die on 429 — backoff and continue
|
||
retry_after = self._extract_discord_retry_after(e)
|
||
if retry_after is not None:
|
||
logger.warning(
|
||
"Typing indicator rate-limited for %s; retrying in %.1fs",
|
||
chat_id, retry_after,
|
||
)
|
||
else:
|
||
logger.debug(
|
||
"Discord typing indicator failed for %s: %s",
|
||
chat_id, e,
|
||
)
|
||
return
|
||
await asyncio.sleep(retry_after)
|
||
continue
|
||
await asyncio.sleep(12)
|
||
except asyncio.CancelledError:
|
||
pass
|
||
finally:
|
||
self._typing_tasks.pop(chat_id, None)
|
||
|
||
self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop())
|
||
|
||
async def stop_typing(self, chat_id: str) -> None:
|
||
"""Stop the persistent typing indicator for a channel."""
|
||
task = self._typing_tasks.pop(chat_id, None)
|
||
if task:
|
||
task.cancel()
|
||
try:
|
||
await task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
|
||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||
"""Get information about a Discord channel."""
|
||
if not self._client:
|
||
return {"name": "Unknown", "type": "dm"}
|
||
|
||
try:
|
||
channel = self._client.get_channel(int(chat_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(chat_id))
|
||
|
||
if not channel:
|
||
return {"name": str(chat_id), "type": "dm"}
|
||
|
||
# Determine channel type
|
||
if isinstance(channel, discord.DMChannel):
|
||
chat_type = "dm"
|
||
name = channel.recipient.name if channel.recipient else str(chat_id)
|
||
elif isinstance(channel, discord.Thread):
|
||
chat_type = "thread"
|
||
name = channel.name
|
||
elif isinstance(channel, discord.TextChannel):
|
||
chat_type = "channel"
|
||
name = f"#{channel.name}"
|
||
if channel.guild:
|
||
name = f"{channel.guild.name} / {name}"
|
||
else:
|
||
chat_type = "channel"
|
||
name = getattr(channel, "name", str(chat_id))
|
||
|
||
return {
|
||
"name": name,
|
||
"type": chat_type,
|
||
"guild_id": str(channel.guild.id) if hasattr(channel, "guild") and channel.guild else None,
|
||
"guild_name": channel.guild.name if hasattr(channel, "guild") and channel.guild else None,
|
||
}
|
||
except Exception as e: # pragma: no cover - defensive logging
|
||
logger.error("[%s] Failed to get chat info for %s: %s", self.name, chat_id, e, exc_info=True)
|
||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||
|
||
async def _resolve_allowed_usernames(self) -> None:
|
||
"""
|
||
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
|
||
|
||
Users can specify usernames (e.g. "teknium") or display names instead of
|
||
raw numeric IDs. After resolution, the env var and internal set are updated
|
||
so authorization checks work with IDs only.
|
||
"""
|
||
if not self._allowed_user_ids or not self._client:
|
||
return
|
||
|
||
numeric_ids = set()
|
||
to_resolve = set()
|
||
|
||
for entry in self._allowed_user_ids:
|
||
if entry.isdigit():
|
||
numeric_ids.add(entry)
|
||
else:
|
||
to_resolve.add(entry.lower())
|
||
|
||
if not to_resolve:
|
||
return
|
||
|
||
print(f"[{self.name}] Resolving {len(to_resolve)} username(s): {', '.join(to_resolve)}")
|
||
resolved_count = 0
|
||
|
||
for guild in self._client.guilds:
|
||
# Fetch full member list (requires members intent)
|
||
try:
|
||
members = guild.members
|
||
if len(members) < guild.member_count:
|
||
members = [m async for m in guild.fetch_members(limit=None)]
|
||
except Exception as e:
|
||
logger.warning("Failed to fetch members for guild %s: %s", guild.name, e)
|
||
continue
|
||
|
||
for member in members:
|
||
name_lower = member.name.lower()
|
||
display_lower = member.display_name.lower()
|
||
global_lower = (member.global_name or "").lower()
|
||
|
||
matched = name_lower in to_resolve or display_lower in to_resolve or global_lower in to_resolve
|
||
if matched:
|
||
uid = str(member.id)
|
||
numeric_ids.add(uid)
|
||
resolved_count += 1
|
||
matched_name = name_lower if name_lower in to_resolve else (
|
||
display_lower if display_lower in to_resolve else global_lower
|
||
)
|
||
to_resolve.discard(matched_name)
|
||
print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})")
|
||
|
||
if not to_resolve:
|
||
break
|
||
|
||
if to_resolve:
|
||
print(f"[{self.name}] Could not resolve usernames: {', '.join(to_resolve)}")
|
||
|
||
# Update internal set and env var so gateway auth checks use IDs
|
||
self._allowed_user_ids = numeric_ids
|
||
os.environ["DISCORD_ALLOWED_USERS"] = ",".join(sorted(numeric_ids))
|
||
if resolved_count:
|
||
print(f"[{self.name}] Updated DISCORD_ALLOWED_USERS with {resolved_count} resolved ID(s)")
|
||
|
||
def format_message(self, content: str) -> str:
|
||
"""
|
||
Format message for Discord.
|
||
|
||
Discord uses its own markdown variant.
|
||
"""
|
||
# Discord markdown is fairly standard, no special escaping needed
|
||
return content
|
||
|
||
async def _run_simple_slash(
|
||
self,
|
||
interaction: discord.Interaction,
|
||
command_text: str,
|
||
followup_msg: str | None = None,
|
||
) -> None:
|
||
"""Common handler for simple slash commands that dispatch a command string.
|
||
|
||
Defers the interaction (shows "thinking..."), dispatches the command,
|
||
then cleans up the deferred response. If *followup_msg* is provided
|
||
the "thinking..." indicator is replaced with that text; otherwise it
|
||
is deleted so the channel isn't cluttered.
|
||
"""
|
||
# Log the invoker so ghost-command reports can be triaged. Discord
|
||
# native slash invocations are always user-initiated (no bot can fire
|
||
# them), but mobile autocomplete / keyboard shortcuts / other users
|
||
# in the same channel are easy to miss in post-mortems.
|
||
try:
|
||
_user = interaction.user
|
||
_chan_id = getattr(interaction.channel, "id", None) or getattr(interaction, "channel_id", None)
|
||
logger.info(
|
||
"[Discord] slash '%s' invoked by user=%s id=%s channel=%s guild=%s",
|
||
command_text,
|
||
getattr(_user, "name", "?"),
|
||
getattr(_user, "id", "?"),
|
||
_chan_id,
|
||
getattr(interaction, "guild_id", None),
|
||
)
|
||
except Exception:
|
||
pass # logging must never block command dispatch
|
||
|
||
# Auth gate — must run before defer() so an ephemeral rejection can
|
||
# be delivered on the still-unresponded interaction.
|
||
if not await self._check_slash_authorization(interaction, command_text):
|
||
return
|
||
|
||
await interaction.response.defer(ephemeral=True)
|
||
event = self._build_slash_event(interaction, command_text)
|
||
await self.handle_message(event)
|
||
try:
|
||
if followup_msg:
|
||
await interaction.edit_original_response(content=followup_msg)
|
||
else:
|
||
await interaction.delete_original_response()
|
||
except Exception as e:
|
||
logger.debug("Discord interaction cleanup failed: %s", e)
|
||
|
||
def _register_slash_commands(self) -> None:
|
||
"""Register Discord slash commands on the command tree."""
|
||
if not self._client:
|
||
return
|
||
|
||
tree = self._client.tree
|
||
|
||
@tree.command(name="new", description="Start a new conversation")
|
||
async def slash_new(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/reset", "New conversation started~")
|
||
|
||
@tree.command(name="reset", description="Reset your Hermes session")
|
||
async def slash_reset(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/reset", "Session reset~")
|
||
|
||
@tree.command(name="model", description="Show or change the model")
|
||
@discord.app_commands.describe(name="Model name (e.g. anthropic/claude-sonnet-4). Leave empty to see current.")
|
||
async def slash_model(interaction: discord.Interaction, name: str = ""):
|
||
await self._run_simple_slash(interaction, f"/model {name}".strip())
|
||
|
||
@tree.command(name="reasoning", description="Show or change reasoning effort")
|
||
@discord.app_commands.describe(effort="Reasoning effort: none, minimal, low, medium, high, or xhigh.")
|
||
async def slash_reasoning(interaction: discord.Interaction, effort: str = ""):
|
||
await self._run_simple_slash(interaction, f"/reasoning {effort}".strip())
|
||
|
||
@tree.command(name="personality", description="Set a personality")
|
||
@discord.app_commands.describe(name="Personality name. Leave empty to list available.")
|
||
async def slash_personality(interaction: discord.Interaction, name: str = ""):
|
||
await self._run_simple_slash(interaction, f"/personality {name}".strip())
|
||
|
||
@tree.command(name="retry", description="Retry your last message")
|
||
async def slash_retry(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/retry", "Retrying~")
|
||
|
||
@tree.command(name="undo", description="Remove the last exchange")
|
||
async def slash_undo(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/undo")
|
||
|
||
@tree.command(name="status", description="Show Hermes session status")
|
||
async def slash_status(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/status", "Status sent~")
|
||
|
||
@tree.command(name="sethome", description="Set this chat as the home channel")
|
||
async def slash_sethome(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/sethome")
|
||
|
||
@tree.command(name="stop", description="Stop the running Hermes agent")
|
||
async def slash_stop(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/stop", "Stop requested~")
|
||
|
||
@tree.command(name="steer", description="Inject a message after the next tool call (no interrupt)")
|
||
@discord.app_commands.describe(prompt="Text to inject into the agent's next tool result")
|
||
async def slash_steer(interaction: discord.Interaction, prompt: str):
|
||
await self._run_simple_slash(interaction, f"/steer {prompt}".strip())
|
||
|
||
@tree.command(name="compress", description="Compress conversation context")
|
||
async def slash_compress(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/compress")
|
||
|
||
@tree.command(name="title", description="Set or show the session title")
|
||
@discord.app_commands.describe(name="Session title. Leave empty to show current.")
|
||
async def slash_title(interaction: discord.Interaction, name: str = ""):
|
||
await self._run_simple_slash(interaction, f"/title {name}".strip())
|
||
|
||
@tree.command(name="resume", description="Resume a previously-named session")
|
||
@discord.app_commands.describe(name="Session name to resume. Leave empty to list sessions.")
|
||
async def slash_resume(interaction: discord.Interaction, name: str = ""):
|
||
await self._run_simple_slash(interaction, f"/resume {name}".strip())
|
||
|
||
@tree.command(name="usage", description="Show token usage for this session")
|
||
async def slash_usage(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/usage")
|
||
|
||
@tree.command(name="help", description="Show available commands")
|
||
async def slash_help(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/help")
|
||
|
||
@tree.command(name="insights", description="Show usage insights and analytics")
|
||
@discord.app_commands.describe(days="Number of days to analyze (default: 7)")
|
||
async def slash_insights(interaction: discord.Interaction, days: int = 7):
|
||
await self._run_simple_slash(interaction, f"/insights {days}")
|
||
|
||
@tree.command(name="reload-mcp", description="Reload MCP servers from config")
|
||
async def slash_reload_mcp(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/reload-mcp")
|
||
|
||
@tree.command(name="reload-skills", description="Re-scan ~/.hermes/skills/ for new or removed skills")
|
||
async def slash_reload_skills(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/reload-skills")
|
||
|
||
@tree.command(name="voice", description="Toggle voice reply mode")
|
||
@discord.app_commands.describe(mode="Voice mode: join, channel, leave, on, tts, off, or status")
|
||
@discord.app_commands.choices(mode=[
|
||
# `join` and `channel` both route to _handle_voice_channel_join in
|
||
# gateway/run.py — expose both in the slash UI so autocomplete
|
||
# matches what the docs advertise and what the runner accepts when
|
||
# the command is typed as plain text.
|
||
discord.app_commands.Choice(name="join — join your voice channel", value="join"),
|
||
discord.app_commands.Choice(name="channel — join your voice channel (alias)", value="channel"),
|
||
discord.app_commands.Choice(name="leave — leave voice channel", value="leave"),
|
||
discord.app_commands.Choice(name="on — voice reply to voice messages", value="on"),
|
||
discord.app_commands.Choice(name="tts — voice reply to all messages", value="tts"),
|
||
discord.app_commands.Choice(name="off — text only", value="off"),
|
||
discord.app_commands.Choice(name="status — show current mode", value="status"),
|
||
])
|
||
async def slash_voice(interaction: discord.Interaction, mode: str = ""):
|
||
await self._run_simple_slash(interaction, f"/voice {mode}".strip())
|
||
|
||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||
async def slash_update(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/update", "Update initiated~")
|
||
|
||
@tree.command(name="restart", description="Gracefully restart the Hermes gateway")
|
||
async def slash_restart(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, "/restart", "Restart requested~")
|
||
|
||
@tree.command(name="approve", description="Approve a pending dangerous command")
|
||
@discord.app_commands.describe(scope="Optional: 'all', 'session', 'always', 'all session', 'all always'")
|
||
async def slash_approve(interaction: discord.Interaction, scope: str = ""):
|
||
await self._run_simple_slash(interaction, f"/approve {scope}".strip())
|
||
|
||
@tree.command(name="deny", description="Deny a pending dangerous command")
|
||
@discord.app_commands.describe(scope="Optional: 'all' to deny all pending commands")
|
||
async def slash_deny(interaction: discord.Interaction, scope: str = ""):
|
||
await self._run_simple_slash(interaction, f"/deny {scope}".strip())
|
||
|
||
@tree.command(name="thread", description="Create a new thread and start a Hermes session in it")
|
||
@discord.app_commands.describe(
|
||
name="Thread name",
|
||
message="Optional first message to send to Hermes in the thread",
|
||
auto_archive_duration="Auto-archive in minutes (60, 1440, 4320, 10080)",
|
||
)
|
||
async def slash_thread(
|
||
interaction: discord.Interaction,
|
||
name: str,
|
||
message: str = "",
|
||
auto_archive_duration: int = 1440,
|
||
):
|
||
# defer() is performed inside the handler *after* the auth gate
|
||
# so a rejected invoker can receive an ephemeral rejection.
|
||
await self._handle_thread_create_slash(interaction, name, message, auto_archive_duration)
|
||
|
||
@tree.command(name="queue", description="Queue a prompt for the next turn (doesn't interrupt)")
|
||
@discord.app_commands.describe(prompt="The prompt to queue")
|
||
async def slash_queue(interaction: discord.Interaction, prompt: str):
|
||
await self._run_simple_slash(interaction, f"/queue {prompt}", "Queued for the next turn.")
|
||
|
||
@tree.command(name="background", description="Run a prompt in the background")
|
||
@discord.app_commands.describe(prompt="The prompt to run in the background")
|
||
async def slash_background(interaction: discord.Interaction, prompt: str):
|
||
await self._run_simple_slash(interaction, f"/background {prompt}", "Background task started~")
|
||
|
||
# ── Auto-register any gateway-available commands not yet on the tree ──
|
||
# This ensures new commands added to COMMAND_REGISTRY in
|
||
# hermes_cli/commands.py automatically appear as Discord slash
|
||
# commands without needing a manual entry here.
|
||
def _build_auto_slash_command(_name: str, _description: str, _args_hint: str = ""):
|
||
"""Build a discord.app_commands.Command that proxies to _run_simple_slash."""
|
||
discord_name = _name.lower()[:32]
|
||
desc = (_description or f"Run /{_name}")[:100]
|
||
has_args = bool(_args_hint)
|
||
|
||
if has_args:
|
||
def _make_args_handler(__name: str, __hint: str):
|
||
@discord.app_commands.describe(args=f"Arguments: {__hint}"[:100])
|
||
async def _handler(interaction: discord.Interaction, args: str = ""):
|
||
await self._run_simple_slash(
|
||
interaction, f"/{__name} {args}".strip()
|
||
)
|
||
_handler.__name__ = f"auto_slash_{__name.replace('-', '_')}"
|
||
return _handler
|
||
|
||
handler = _make_args_handler(_name, _args_hint)
|
||
else:
|
||
def _make_simple_handler(__name: str):
|
||
async def _handler(interaction: discord.Interaction):
|
||
await self._run_simple_slash(interaction, f"/{__name}")
|
||
_handler.__name__ = f"auto_slash_{__name.replace('-', '_')}"
|
||
return _handler
|
||
|
||
handler = _make_simple_handler(_name)
|
||
|
||
return discord.app_commands.Command(
|
||
name=discord_name,
|
||
description=desc,
|
||
callback=handler,
|
||
)
|
||
|
||
already_registered: set[str] = set()
|
||
# Native commands above are registered first and are the highest
|
||
# priority, so they always survive the 100-command cap. Reserve one
|
||
# slot for the consolidated ``/skill`` group registered further below.
|
||
slot_cap = _DISCORD_MAX_APP_COMMANDS - 1
|
||
dropped_over_cap = 0
|
||
try:
|
||
from hermes_cli.commands import COMMAND_REGISTRY, _is_gateway_available, _resolve_config_gates
|
||
|
||
try:
|
||
already_registered = {cmd.name for cmd in tree.get_commands()}
|
||
except Exception:
|
||
pass
|
||
|
||
config_overrides = _resolve_config_gates()
|
||
|
||
for cmd_def in COMMAND_REGISTRY:
|
||
if not _is_gateway_available(cmd_def, config_overrides):
|
||
continue
|
||
# Discord command names: lowercase, hyphens OK, max 32 chars.
|
||
discord_name = cmd_def.name.lower()[:32]
|
||
if discord_name in already_registered:
|
||
continue
|
||
if len(already_registered) >= slot_cap:
|
||
dropped_over_cap += 1
|
||
continue
|
||
auto_cmd = _build_auto_slash_command(
|
||
cmd_def.name,
|
||
cmd_def.description,
|
||
cmd_def.args_hint,
|
||
)
|
||
try:
|
||
tree.add_command(auto_cmd)
|
||
already_registered.add(discord_name)
|
||
except Exception:
|
||
# Silently skip commands that fail registration (e.g.
|
||
# name conflict with a subcommand group).
|
||
pass
|
||
|
||
logger.debug(
|
||
"Discord auto-registered %d commands from COMMAND_REGISTRY",
|
||
len(already_registered),
|
||
)
|
||
except Exception as e:
|
||
logger.warning("Discord auto-register from COMMAND_REGISTRY failed: %s", e)
|
||
|
||
# ── Plugin-registered slash commands ──
|
||
# Plugins register via PluginContext.register_command(); we mirror
|
||
# those into Discord's native slash picker so users get the same
|
||
# autocomplete UX as for built-in commands. No per-platform plugin
|
||
# API needed — plugin commands are platform-agnostic.
|
||
try:
|
||
from hermes_cli.commands import _iter_plugin_command_entries
|
||
|
||
for plugin_name, plugin_desc, plugin_args_hint in _iter_plugin_command_entries():
|
||
discord_name = plugin_name.lower()[:32]
|
||
if discord_name in already_registered:
|
||
continue
|
||
if len(already_registered) >= slot_cap:
|
||
dropped_over_cap += 1
|
||
continue
|
||
auto_cmd = _build_auto_slash_command(
|
||
plugin_name,
|
||
plugin_desc,
|
||
plugin_args_hint,
|
||
)
|
||
try:
|
||
tree.add_command(auto_cmd)
|
||
already_registered.add(discord_name)
|
||
except Exception:
|
||
# Silently skip commands that fail registration (e.g.
|
||
# name conflict with a subcommand group).
|
||
pass
|
||
except Exception as e:
|
||
logger.warning(
|
||
"Discord auto-register from plugin commands failed: %s", e
|
||
)
|
||
|
||
# Register skills under a single /skill command group with category
|
||
# subcommand groups. This uses 1 top-level slot instead of N,
|
||
# supporting up to 25 categories × 25 skills = 625 skills.
|
||
self._register_skill_group(tree)
|
||
|
||
if dropped_over_cap:
|
||
# Staying under the cap keeps the whole sync succeeding; without
|
||
# this guard a single over-limit command makes Discord reject the
|
||
# entire batch (error 30032), breaking every slash command.
|
||
logger.warning(
|
||
"[%s] Reached Discord's limit of %d slash commands; skipped %d "
|
||
"lower-priority command(s) to keep the command sync working. "
|
||
"Disable slash commands you don't need or trim installed plugins "
|
||
"to surface them all.",
|
||
self.name,
|
||
_DISCORD_MAX_APP_COMMANDS,
|
||
dropped_over_cap,
|
||
)
|
||
|
||
# Optional defense-in-depth: hide every slash command from non-admin
|
||
# guild members in Discord's slash picker. Server-side authorization
|
||
# (``_check_slash_authorization``) is the actual gate; this is purely
|
||
# UX so users don't see commands they can't invoke. Off by default
|
||
# to preserve the slash UX for deployments that intentionally allow
|
||
# everyone in the guild.
|
||
if os.getenv("DISCORD_HIDE_SLASH_COMMANDS", "false").strip().lower() in {
|
||
"true", "1", "yes", "on",
|
||
}:
|
||
self._apply_owner_only_visibility(tree)
|
||
|
||
def _apply_owner_only_visibility(self, tree) -> None:
|
||
"""Set default_member_permissions=0 on every registered slash command.
|
||
|
||
Discord interprets ``Permissions(0)`` as "requires no permissions",
|
||
which paradoxically means the command is hidden from every guild
|
||
member except those with the Administrator permission. Server admins
|
||
can re-grant per user/role via Server Settings → Integrations →
|
||
<bot> → Permissions.
|
||
|
||
Authoritative gate is ``_check_slash_authorization`` on every
|
||
invocation, which catches stale clients, role grants made by
|
||
mistake, and direct API calls bypassing Discord's UI hide.
|
||
"""
|
||
try:
|
||
no_perms = discord.Permissions(0)
|
||
except Exception as e:
|
||
logger.warning(
|
||
"[Discord] _apply_owner_only_visibility: cannot build Permissions(0): %s",
|
||
e,
|
||
)
|
||
return
|
||
applied = 0
|
||
for cmd in tree.get_commands():
|
||
try:
|
||
cmd.default_permissions = no_perms
|
||
applied += 1
|
||
except Exception as e:
|
||
logger.debug(
|
||
"[Discord] Could not set default_permissions on %r: %s",
|
||
getattr(cmd, "name", "?"), e,
|
||
)
|
||
logger.info(
|
||
"[Discord] Hid %d slash command(s) from non-admin guild members "
|
||
"(opt-in defense in depth via DISCORD_HIDE_SLASH_COMMANDS).",
|
||
applied,
|
||
)
|
||
|
||
def _register_skill_group(self, tree) -> None:
|
||
"""Register a single ``/skill`` command with autocomplete on the name.
|
||
|
||
Discord enforces an ~8000-byte per-command payload limit. The older
|
||
nested layout (``/skill <category> <name>``) registered one giant
|
||
command whose serialized payload grew linearly with the skill
|
||
catalog — with the default ~75 skills the payload was ~14 KB and
|
||
``tree.sync()`` rejected the entire slash-command batch (issues
|
||
#11321, #10259, #11385, #10261, #10214).
|
||
|
||
Autocomplete options are fetched dynamically by Discord when the
|
||
user types — they do NOT count against the per-command registration
|
||
budget. So we register ONE flat ``/skill`` command with
|
||
``name: str`` (autocompleted) and ``args: str = ""``. This scales
|
||
to thousands of skills with no size math, no splitting, and no
|
||
hidden skills. The slash picker also becomes more discoverable —
|
||
Discord live-filters by the user's typed prefix against both the
|
||
skill name and its description.
|
||
|
||
The entries list and lookup dict are stored on ``self`` rather
|
||
than captured in closure variables so :meth:`refresh_skill_group`
|
||
can repopulate them when the user runs ``/reload-skills`` without
|
||
needing to touch the Discord slash-command tree or trigger a
|
||
``tree.sync()`` call.
|
||
"""
|
||
try:
|
||
existing_names = set()
|
||
try:
|
||
existing_names = {cmd.name for cmd in tree.get_commands()}
|
||
except Exception:
|
||
pass
|
||
|
||
# Populate the instance-level entries/lookup so the
|
||
# autocomplete + handler callbacks below always read the
|
||
# freshest state. refresh_skill_group() re-runs the same
|
||
# collector and mutates these two attributes in place.
|
||
self._skill_entries: list[tuple[str, str, str]] = []
|
||
self._skill_lookup: dict[str, tuple[str, str]] = {}
|
||
self._skill_group_reserved_names: set[str] = set(existing_names)
|
||
self._refresh_skill_catalog_state()
|
||
|
||
if not self._skill_entries:
|
||
return
|
||
|
||
async def _autocomplete_name(
|
||
interaction: "discord.Interaction", current: str,
|
||
) -> list:
|
||
"""Filter skills by the user's typed prefix.
|
||
|
||
Matches both the skill name and its description so
|
||
"/skill pdf" surfaces skills whose description mentions
|
||
PDFs even if the name doesn't. Discord caps this list at
|
||
25 entries per query.
|
||
|
||
Authorization: a quiet pre-check evaluates the slash
|
||
allowlists and returns ``[]`` for unauthorized users so
|
||
the installed skill catalog is not leaked to anyone who
|
||
can see the command in the picker. Returning a generic
|
||
empty list here is intentional — sending a per-keystroke
|
||
ephemeral rejection would produce a barrage of error
|
||
popups during typing.
|
||
|
||
Reads ``self._skill_entries`` so a ``/reload-skills`` run
|
||
since process start shows up on the very next keystroke.
|
||
"""
|
||
try:
|
||
allowed, _reason = self._evaluate_slash_authorization(interaction)
|
||
except Exception:
|
||
# Defensive: never raise from autocomplete. Fail
|
||
# closed by returning an empty suggestion list.
|
||
return []
|
||
if not allowed:
|
||
return []
|
||
q = (current or "").strip().lower()
|
||
choices: list = []
|
||
for name, desc, _key in self._skill_entries:
|
||
if not q or q in name.lower() or (desc and q in desc.lower()):
|
||
if desc:
|
||
label = f"{name} — {desc}"
|
||
else:
|
||
label = name
|
||
# Discord's Choice.name is capped at 100 chars.
|
||
if len(label) > 100:
|
||
label = label[:97] + "..."
|
||
choices.append(
|
||
discord.app_commands.Choice(name=label, value=name)
|
||
)
|
||
if len(choices) >= 25:
|
||
break
|
||
return choices
|
||
|
||
@discord.app_commands.describe(
|
||
name="Which skill to run",
|
||
args="Optional arguments for the skill",
|
||
)
|
||
@discord.app_commands.autocomplete(name=_autocomplete_name)
|
||
async def _skill_handler(
|
||
interaction: "discord.Interaction", name: str, args: str = "",
|
||
):
|
||
# Authorize BEFORE any skill lookup so that known and
|
||
# unknown skill names produce identical rejections for
|
||
# unauthorized users (no probing the installed catalog
|
||
# via "Unknown skill: <name>" responses).
|
||
if not await self._check_slash_authorization(interaction, "/skill"):
|
||
return
|
||
entry = self._skill_lookup.get(name)
|
||
if not entry:
|
||
await interaction.response.send_message(
|
||
f"Unknown skill: `{name}`. Start typing for "
|
||
f"autocomplete suggestions.",
|
||
ephemeral=True,
|
||
)
|
||
return
|
||
_desc, cmd_key = entry
|
||
await self._run_simple_slash(
|
||
interaction, f"{cmd_key} {args}".strip()
|
||
)
|
||
|
||
cmd = discord.app_commands.Command(
|
||
name="skill",
|
||
description="Run a Hermes skill",
|
||
callback=_skill_handler,
|
||
)
|
||
tree.add_command(cmd)
|
||
|
||
logger.info(
|
||
"[%s] Registered /skill command with %d skill(s) via autocomplete",
|
||
self.name, len(self._skill_entries),
|
||
)
|
||
if self._skill_group_hidden_count:
|
||
logger.info(
|
||
"[%s] %d skill(s) filtered out of /skill (name clamp / reserved)",
|
||
self.name, self._skill_group_hidden_count,
|
||
)
|
||
except Exception as exc:
|
||
logger.warning("[%s] Failed to register /skill command: %s", self.name, exc)
|
||
|
||
def _refresh_skill_catalog_state(self) -> None:
|
||
"""Re-scan disk for skills and repopulate ``self._skill_entries``.
|
||
|
||
Called once from :meth:`_register_skill_group` at startup and
|
||
again from :meth:`refresh_skill_group` whenever the user runs
|
||
``/reload-skills``. No Discord API calls are made — autocomplete
|
||
and the handler both read from these instance attributes
|
||
directly, so an in-place mutation is sufficient.
|
||
"""
|
||
from hermes_cli.commands import discord_skill_commands_by_category
|
||
|
||
reserved = getattr(self, "_skill_group_reserved_names", set())
|
||
categories, uncategorized, hidden = discord_skill_commands_by_category(
|
||
reserved_names=set(reserved),
|
||
)
|
||
entries: list[tuple[str, str, str]] = list(uncategorized)
|
||
for cat_skills in categories.values():
|
||
entries.extend(cat_skills)
|
||
# Stable alphabetical order so the autocomplete suggestion
|
||
# list is predictable across restarts.
|
||
entries.sort(key=lambda t: t[0])
|
||
|
||
self._skill_entries = entries
|
||
self._skill_lookup = {n: (d, k) for n, d, k in entries}
|
||
self._skill_group_hidden_count = hidden
|
||
|
||
def refresh_skill_group(self) -> tuple[int, int]:
|
||
"""Rescan skills and update the live ``/skill`` autocomplete state.
|
||
|
||
Invoked by :meth:`gateway.run.GatewayOrchestrator._handle_reload_skills_command`
|
||
after :func:`agent.skill_commands.reload_skills` has refreshed
|
||
the in-process skill-command registry. Without this call, the
|
||
``/skill`` autocomplete dropdown keeps showing the list captured
|
||
at process start — new skills stay invisible and deleted skills
|
||
return an "Unknown skill" error when clicked.
|
||
|
||
Because autocomplete options are fetched dynamically by Discord,
|
||
we only need to mutate the entries/lookup attributes read by the
|
||
callbacks — no ``tree.sync()`` is required.
|
||
|
||
Returns ``(new_count, hidden_count)``.
|
||
"""
|
||
try:
|
||
self._refresh_skill_catalog_state()
|
||
except Exception as exc:
|
||
logger.warning(
|
||
"[%s] Failed to refresh /skill autocomplete after reload: %s",
|
||
self.name, exc,
|
||
)
|
||
return (len(getattr(self, "_skill_entries", [])), 0)
|
||
logger.info(
|
||
"[%s] Refreshed /skill autocomplete: %d skill(s) available (%d filtered)",
|
||
self.name,
|
||
len(self._skill_entries),
|
||
self._skill_group_hidden_count,
|
||
)
|
||
return (len(self._skill_entries), self._skill_group_hidden_count)
|
||
|
||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||
is_dm = isinstance(interaction.channel, discord.DMChannel)
|
||
is_thread = isinstance(interaction.channel, discord.Thread)
|
||
thread_id = None
|
||
|
||
if is_dm:
|
||
chat_type = "dm"
|
||
elif is_thread:
|
||
chat_type = "thread"
|
||
thread_id = str(interaction.channel_id)
|
||
else:
|
||
chat_type = "group"
|
||
|
||
chat_name = ""
|
||
if not is_dm and hasattr(interaction.channel, "name"):
|
||
chat_name = interaction.channel.name
|
||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||
|
||
# Get channel topic (if available).
|
||
# For forum threads, inherit the parent forum's topic.
|
||
chat_topic = self._get_effective_topic(interaction.channel, is_thread=is_thread)
|
||
|
||
source = self.build_source(
|
||
chat_id=str(interaction.channel_id),
|
||
chat_name=chat_name,
|
||
chat_type=chat_type,
|
||
user_id=str(interaction.user.id),
|
||
user_name=interaction.user.display_name,
|
||
thread_id=thread_id,
|
||
chat_topic=chat_topic,
|
||
)
|
||
|
||
msg_type = MessageType.COMMAND if text.startswith("/") else MessageType.TEXT
|
||
channel_id = str(interaction.channel_id)
|
||
parent_id = str(getattr(getattr(interaction, "channel", None), "parent_id", "") or "")
|
||
return MessageEvent(
|
||
text=text,
|
||
message_type=msg_type,
|
||
source=source,
|
||
raw_message=interaction,
|
||
channel_prompt=self._resolve_channel_prompt(channel_id, parent_id or None),
|
||
)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Thread creation helpers
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _handle_thread_create_slash(
|
||
self,
|
||
interaction: discord.Interaction,
|
||
name: str,
|
||
message: str = "",
|
||
auto_archive_duration: int = 1440,
|
||
) -> None:
|
||
"""Create a Discord thread from a slash command and start a session in it."""
|
||
if not await self._check_slash_authorization(interaction, "/thread"):
|
||
return
|
||
await interaction.response.defer(ephemeral=True)
|
||
result = await self._create_thread(
|
||
interaction,
|
||
name=name,
|
||
message=message,
|
||
auto_archive_duration=auto_archive_duration,
|
||
)
|
||
|
||
if not result.get("success"):
|
||
error = result.get("error", "unknown error")
|
||
await interaction.followup.send(f"Failed to create thread: {error}", ephemeral=True)
|
||
return
|
||
|
||
thread_id = result.get("thread_id")
|
||
thread_name = result.get("thread_name") or name
|
||
|
||
# Tell the user where the thread is
|
||
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
|
||
await interaction.followup.send(f"Created thread {link}", ephemeral=True)
|
||
|
||
# Track thread participation so follow-ups don't require @mention
|
||
if thread_id:
|
||
self._threads.mark(thread_id)
|
||
|
||
# If a message was provided, kick off a new Hermes session in the thread
|
||
starter = (message or "").strip()
|
||
if starter and thread_id:
|
||
await self._dispatch_thread_session(interaction, thread_id, thread_name, starter)
|
||
|
||
async def _dispatch_thread_session(
|
||
self,
|
||
interaction: discord.Interaction,
|
||
thread_id: str,
|
||
thread_name: str,
|
||
text: str,
|
||
) -> None:
|
||
"""Build a MessageEvent pointing at a thread and send it through handle_message."""
|
||
guild_name = ""
|
||
if hasattr(interaction, "guild") and interaction.guild:
|
||
guild_name = interaction.guild.name
|
||
|
||
chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name
|
||
|
||
# Inherit forum topic when the thread was created inside a forum channel.
|
||
_chan = getattr(interaction, "channel", None)
|
||
chat_topic = self._get_effective_topic(_chan, is_thread=True) if _chan else None
|
||
|
||
source = self.build_source(
|
||
chat_id=thread_id,
|
||
chat_name=chat_name,
|
||
chat_type="thread",
|
||
user_id=str(interaction.user.id),
|
||
user_name=interaction.user.display_name,
|
||
thread_id=thread_id,
|
||
chat_topic=chat_topic,
|
||
)
|
||
|
||
_parent_channel = self._thread_parent_channel(getattr(interaction, "channel", None))
|
||
_parent_id = str(getattr(_parent_channel, "id", "") or "")
|
||
_skills = self._resolve_channel_skills(thread_id, _parent_id or None)
|
||
_channel_prompt = self._resolve_channel_prompt(thread_id, _parent_id or None)
|
||
event = MessageEvent(
|
||
text=text,
|
||
message_type=MessageType.TEXT,
|
||
source=source,
|
||
raw_message=interaction,
|
||
auto_skill=_skills,
|
||
channel_prompt=_channel_prompt,
|
||
)
|
||
await self.handle_message(event)
|
||
|
||
def _resolve_channel_skills(self, channel_id: str, parent_id: str | None = None) -> list[str] | None:
|
||
"""Look up auto-skill bindings for a Discord channel/forum thread.
|
||
|
||
Config format (in platform extra):
|
||
channel_skill_bindings:
|
||
- id: "123456"
|
||
skills: ["skill-a", "skill-b"]
|
||
Also checks parent_id so forum threads inherit the forum's bindings.
|
||
"""
|
||
from gateway.platforms.base import resolve_channel_skills
|
||
return resolve_channel_skills(self.config.extra, channel_id, parent_id)
|
||
|
||
def _resolve_channel_prompt(self, channel_id: str, parent_id: str | None = None) -> str | None:
|
||
"""Resolve a Discord per-channel prompt, preferring the exact channel over its parent."""
|
||
from gateway.platforms.base import resolve_channel_prompt
|
||
return resolve_channel_prompt(self.config.extra, channel_id, parent_id)
|
||
|
||
def _discord_require_mention(self) -> bool:
|
||
"""Return whether Discord channel messages require a bot mention."""
|
||
configured = self.config.extra.get("require_mention")
|
||
if configured is not None:
|
||
if isinstance(configured, str):
|
||
return configured.lower() not in {"false", "0", "no", "off"}
|
||
return bool(configured)
|
||
return os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in {"false", "0", "no", "off"}
|
||
|
||
def _discord_allow_any_attachment(self) -> bool:
|
||
"""Return whether Discord attachments bypass the SUPPORTED_DOCUMENT_TYPES allowlist.
|
||
|
||
When True, any uploaded file is cached to disk and surfaced to the
|
||
agent as a local path so it can be inspected via terminal / read_file
|
||
/ ffprobe / etc. Default False preserves the historical behaviour of
|
||
dropping unsupported types with a warning log.
|
||
"""
|
||
configured = self.config.extra.get("allow_any_attachment")
|
||
if configured is not None:
|
||
if isinstance(configured, str):
|
||
return configured.lower() not in {"false", "0", "no", "off", ""}
|
||
return bool(configured)
|
||
return os.getenv("DISCORD_ALLOW_ANY_ATTACHMENT", "false").lower() in {"true", "1", "yes", "on"}
|
||
|
||
def _discord_max_attachment_bytes(self) -> int:
|
||
"""Return the per-attachment byte cap. 0 means unlimited.
|
||
|
||
The whole attachment is held in memory while being written to the
|
||
cache, so unlimited carries a real memory cost. Default 32 MiB
|
||
matches the historical hardcoded value.
|
||
"""
|
||
configured = self.config.extra.get("max_attachment_bytes")
|
||
if configured is None:
|
||
configured = os.getenv("DISCORD_MAX_ATTACHMENT_BYTES")
|
||
if configured is None or configured == "":
|
||
return 32 * 1024 * 1024
|
||
try:
|
||
value = int(configured)
|
||
except (TypeError, ValueError):
|
||
logger.warning(
|
||
"[Discord] Invalid max_attachment_bytes value %r, falling back to 32 MiB",
|
||
configured,
|
||
)
|
||
return 32 * 1024 * 1024
|
||
return max(0, value)
|
||
|
||
@staticmethod
|
||
def _is_discord_voice_message_attachment(att: Any) -> bool:
|
||
"""Return True when a Discord audio attachment is a native voice note."""
|
||
marker = getattr(att, "is_voice_message", None)
|
||
if marker is not None:
|
||
if callable(marker):
|
||
try:
|
||
return bool(marker())
|
||
except Exception as exc:
|
||
logger.debug("[Discord] is_voice_message() failed for attachment: %s", exc)
|
||
return False
|
||
return bool(marker)
|
||
|
||
return (
|
||
getattr(att, "duration", None) is not None
|
||
and getattr(att, "waveform", None) is not None
|
||
)
|
||
|
||
def _discord_free_response_channels(self) -> set:
|
||
"""Return Discord channel IDs where no bot mention is required.
|
||
|
||
A single ``"*"`` entry (either from a list or a comma-separated
|
||
string) is preserved in the returned set so callers can short-circuit
|
||
on wildcard membership, consistent with ``allowed_channels``.
|
||
"""
|
||
raw = self.config.extra.get("free_response_channels")
|
||
if raw is None:
|
||
raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||
if isinstance(raw, list):
|
||
return {str(part).strip() for part in raw if str(part).strip()}
|
||
# Coerce non-list scalars (str/int/float) to str before splitting.
|
||
# YAML parses a bare numeric value such as
|
||
# `free_response_channels: 1491973769726791812` as int, which was
|
||
# previously falling through the isinstance(str) branch and silently
|
||
# returning an empty set. str() here accepts whatever scalar the YAML
|
||
# loader hands us without changing existing string/CSV semantics.
|
||
s = str(raw).strip() if raw is not None else ""
|
||
if s:
|
||
return {part.strip() for part in s.split(",") if part.strip()}
|
||
return set()
|
||
|
||
def _discord_thread_require_mention(self) -> bool:
|
||
"""Return whether thread participation requires @mention to follow up.
|
||
|
||
When ``False`` (default), once the bot has participated in a thread it
|
||
keeps responding to every message in that thread without needing to be
|
||
mentioned again — useful for one-on-one conversations.
|
||
|
||
When ``True``, the @mention requirement is enforced inside threads as
|
||
well. Set this when multiple bots share a thread and you want each
|
||
one to only fire on explicit @mention, avoiding bot-to-bot loops or
|
||
unwanted cross-replies.
|
||
"""
|
||
configured = self.config.extra.get("thread_require_mention")
|
||
if configured is not None:
|
||
if isinstance(configured, str):
|
||
return configured.lower() not in {"false", "0", "no", "off"}
|
||
return bool(configured)
|
||
return os.getenv("DISCORD_THREAD_REQUIRE_MENTION", "false").lower() in {"true", "1", "yes", "on"}
|
||
|
||
def _discord_history_backfill(self) -> bool:
|
||
"""Return whether history backfill is enabled for shared sessions."""
|
||
configured = self.config.extra.get("history_backfill")
|
||
if configured is not None:
|
||
if isinstance(configured, str):
|
||
return configured.lower() not in {"false", "0", "no", "off"}
|
||
return bool(configured)
|
||
return os.getenv("DISCORD_HISTORY_BACKFILL", "true").lower() in {"true", "1", "yes"}
|
||
|
||
def _discord_history_backfill_limit(self) -> int:
|
||
"""Return the max number of messages to scan backwards for context.
|
||
|
||
In practice the scan usually stops much earlier — at the bot's own
|
||
last message in the channel (the natural partition point). This
|
||
limit is a safety cap for cold starts and long gaps where no prior
|
||
bot message exists in recent history.
|
||
"""
|
||
configured = self.config.extra.get("history_backfill_limit")
|
||
if configured is not None:
|
||
try:
|
||
return int(configured)
|
||
except (ValueError, TypeError):
|
||
pass
|
||
raw = os.getenv("DISCORD_HISTORY_BACKFILL_LIMIT", "50")
|
||
try:
|
||
return int(raw)
|
||
except (ValueError, TypeError):
|
||
return 50
|
||
|
||
async def _fetch_channel_context(
|
||
self,
|
||
channel: Any,
|
||
before: "DiscordMessage",
|
||
reply_target: Optional[Any] = None,
|
||
) -> str:
|
||
"""Fetch recent channel messages for conversational context.
|
||
|
||
Scans backwards from *before* and collects messages until it hits
|
||
a message sent by this bot (the natural partition point between
|
||
bot turns) or reaches ``history_backfill_limit``.
|
||
|
||
When ``reply_target`` is provided (the user replied to a specific
|
||
message), a second backward scan is run ending at that target so the
|
||
agent sees the conversation surrounding what the user pointed at —
|
||
even when the reply target sits *before* the most recent bot turn and
|
||
would otherwise be cut off by the self-message partition. The two
|
||
windows are merged chronologically and de-duplicated by message ID.
|
||
|
||
Returns a formatted block like::
|
||
|
||
[Recent channel messages]
|
||
[Alice] some message
|
||
[Bob [bot]] another message
|
||
|
||
Returns an empty string if no context is available.
|
||
"""
|
||
limit = self._discord_history_backfill_limit()
|
||
if limit <= 0:
|
||
return ""
|
||
|
||
# Determine which bot messages to include in context
|
||
allow_bots_raw = os.getenv("DISCORD_ALLOW_BOTS", "none").lower().strip()
|
||
include_other_bots = allow_bots_raw != "none"
|
||
|
||
# Use the in-memory cache to narrow the fetch window on hot paths.
|
||
# If we know our last message ID in this channel, pass it as `after`
|
||
# to avoid scanning the full limit. Falls back to scanning on cache
|
||
# miss (cold start / restart).
|
||
# Guard: only use the cache when it's chronologically before the
|
||
# trigger — Discord snowflake IDs are monotonically increasing, so
|
||
# a simple int comparison suffices.
|
||
channel_id = str(getattr(channel, "id", ""))
|
||
_cached_id = self._last_self_message_id.get(channel_id)
|
||
_after_obj = None
|
||
try:
|
||
if _cached_id and int(_cached_id) < int(before.id):
|
||
_after_obj = discord.Object(id=int(_cached_id))
|
||
except (ValueError, TypeError):
|
||
pass # Malformed cache entry — fall back to cold-start scan
|
||
|
||
try:
|
||
def _keep(msg) -> Optional[str]:
|
||
"""Return a formatted ``[name] content`` line, or None to skip.
|
||
|
||
Encapsulates the system-message / non-conversational / other-bot
|
||
filtering so both the primary and reply-anchored scans apply
|
||
identical rules. Does NOT enforce the self-message partition —
|
||
callers decide where to stop.
|
||
"""
|
||
if msg.type not in {discord.MessageType.default, discord.MessageType.reply}:
|
||
return None
|
||
content = getattr(msg, "clean_content", msg.content) or ""
|
||
if (
|
||
str(getattr(msg, "id", "")) in self._nonconversational_messages
|
||
or _looks_like_nonconversational_history_message(content)
|
||
):
|
||
return None
|
||
# Respect DISCORD_ALLOW_BOTS for other bots. For history
|
||
# context, "mentions" is treated as "all" — we are deciding
|
||
# what context to show, not whether to respond.
|
||
if (
|
||
getattr(msg.author, "bot", False)
|
||
and msg.author != self._client.user
|
||
and not include_other_bots
|
||
):
|
||
return None
|
||
if not content and msg.attachments:
|
||
content = "(attachment)"
|
||
if not content:
|
||
return None
|
||
name = (
|
||
getattr(msg.author, "display_name", None)
|
||
or getattr(msg.author, "name", None)
|
||
or "unknown"
|
||
)
|
||
if getattr(msg.author, "bot", False):
|
||
name = f"{name} [bot]"
|
||
return f"[{name}] {content}"
|
||
|
||
# ── Primary window: recent channel activity since the last bot turn ──
|
||
collected: List[Tuple[str, str]] = [] # (message_id, line)
|
||
seen_ids: set = set()
|
||
# IMPORTANT: pass oldest_first=False explicitly. discord.py 2.x
|
||
# silently flips the default to True when `after=` is supplied,
|
||
# which would select the *earliest* N messages after our last
|
||
# response instead of the *latest* N before the trigger. In
|
||
# high-traffic windows that returns stale tool traces and drops
|
||
# the actual final answer. See the regression test
|
||
# `test_fetch_channel_context_cache_uses_latest_window_when_after_set`.
|
||
async for msg in channel.history(
|
||
limit=limit,
|
||
before=before,
|
||
after=_after_obj,
|
||
oldest_first=False,
|
||
):
|
||
# Non-conversational lifecycle/status bumps (self-improvement
|
||
# reviews, background-process notices, restart banners) must be
|
||
# skipped BEFORE the partition check — otherwise a delayed
|
||
# status bump authored by us would be mistaken for the real
|
||
# last bot turn and hide messages that came after it.
|
||
_content = getattr(msg, "clean_content", msg.content) or ""
|
||
if (
|
||
str(getattr(msg, "id", "")) in self._nonconversational_messages
|
||
or _looks_like_nonconversational_history_message(_content)
|
||
):
|
||
continue
|
||
# Stop at our own (conversational) message — this is the
|
||
# partition point. Everything before this is already in the
|
||
# session transcript. (Redundant when _after_obj is set, but
|
||
# needed for cold start.)
|
||
if msg.author == self._client.user:
|
||
break
|
||
line = _keep(msg)
|
||
if line is None:
|
||
continue
|
||
mid = str(getattr(msg, "id", ""))
|
||
collected.append((mid, line))
|
||
if mid:
|
||
seen_ids.add(mid)
|
||
|
||
# ── Reply window: context around the message the user pointed at ──
|
||
# When the user replied to a specific message that sits BEFORE the
|
||
# primary window's partition point, the surrounding exchange isn't
|
||
# captured above. Fetch a small window ending just after the reply
|
||
# target so the agent sees what it was referencing. This window is
|
||
# NOT partitioned on the self-message boundary — the whole point is
|
||
# to surface older context the transcript lacks.
|
||
reply_collected: List[Tuple[str, str]] = []
|
||
reply_target_id = str(getattr(reply_target, "id", "")) if reply_target else ""
|
||
if reply_target is not None and reply_target_id and reply_target_id not in seen_ids:
|
||
# Reuse the same cap as the primary scan but keep the reply
|
||
# window modest — it's anchored context, not a full backfill.
|
||
reply_limit = max(1, min(limit, 10))
|
||
# `before` is exclusive in discord.py, so to *include* the
|
||
# target we anchor at target_id + 1. Use a minimal snowflake
|
||
# shim (any object exposing ``.id`` satisfies discord.py's
|
||
# Snowflake protocol) rather than discord.Object, so this path
|
||
# works under test doubles that stub the discord module too.
|
||
try:
|
||
_before_obj = _Snowflake(int(reply_target_id) + 1)
|
||
except (ValueError, TypeError):
|
||
_before_obj = before
|
||
async for msg in channel.history(
|
||
limit=reply_limit,
|
||
before=_before_obj,
|
||
oldest_first=False,
|
||
):
|
||
line = _keep(msg)
|
||
if line is None:
|
||
continue
|
||
mid = str(getattr(msg, "id", ""))
|
||
if mid and mid in seen_ids:
|
||
continue
|
||
reply_collected.append((mid, line))
|
||
if mid:
|
||
seen_ids.add(mid)
|
||
|
||
if not collected and not reply_collected:
|
||
return ""
|
||
|
||
# channel.history returns newest-first; reverse each window for
|
||
# chronological order, then present reply context first (it is
|
||
# older) followed by the recent activity.
|
||
collected.reverse()
|
||
reply_collected.reverse()
|
||
|
||
blocks: List[str] = []
|
||
if reply_collected:
|
||
blocks.append(
|
||
"[Context around the replied-to message]\n"
|
||
+ "\n".join(line for _id, line in reply_collected)
|
||
)
|
||
if collected:
|
||
blocks.append(
|
||
"[Recent channel messages]\n"
|
||
+ "\n".join(line for _id, line in collected)
|
||
)
|
||
return "\n\n".join(blocks)
|
||
|
||
except discord.Forbidden:
|
||
logger.debug("[%s] Missing permissions to fetch channel history", self.name)
|
||
return ""
|
||
except Exception as e:
|
||
logger.warning("[%s] Failed to fetch channel history: %s", self.name, e)
|
||
return ""
|
||
|
||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||
"""Return the parent text channel when invoked from a thread."""
|
||
return getattr(channel, "parent", None) or channel
|
||
|
||
async def _resolve_interaction_channel(self, interaction: discord.Interaction) -> Optional[Any]:
|
||
"""Return the interaction channel, fetching it if the payload is partial."""
|
||
channel = getattr(interaction, "channel", None)
|
||
if channel is not None:
|
||
return channel
|
||
if not self._client:
|
||
return None
|
||
channel_id = getattr(interaction, "channel_id", None)
|
||
if channel_id is None:
|
||
return None
|
||
channel = self._client.get_channel(int(channel_id))
|
||
if channel is not None:
|
||
return channel
|
||
try:
|
||
return await self._client.fetch_channel(int(channel_id))
|
||
except Exception:
|
||
return None
|
||
|
||
async def _create_thread(
|
||
self,
|
||
interaction: discord.Interaction,
|
||
*,
|
||
name: str,
|
||
message: str = "",
|
||
auto_archive_duration: int = 1440,
|
||
) -> Dict[str, Any]:
|
||
"""Create a thread in the current Discord channel.
|
||
|
||
Tries ``parent_channel.create_thread()`` first. If Discord rejects
|
||
that (e.g. permission issues), falls back to sending a seed message
|
||
and creating the thread from it.
|
||
"""
|
||
name = (name or "").strip()
|
||
if not name:
|
||
return {"error": "Thread name is required."}
|
||
|
||
if auto_archive_duration not in VALID_THREAD_AUTO_ARCHIVE_MINUTES:
|
||
allowed = ", ".join(str(v) for v in sorted(VALID_THREAD_AUTO_ARCHIVE_MINUTES))
|
||
return {"error": f"auto_archive_duration must be one of: {allowed}."}
|
||
|
||
channel = await self._resolve_interaction_channel(interaction)
|
||
if channel is None:
|
||
return {"error": "Could not resolve the current Discord channel."}
|
||
if isinstance(channel, discord.DMChannel):
|
||
return {"error": "Discord threads can only be created inside server text channels, not DMs."}
|
||
|
||
parent_channel = self._thread_parent_channel(channel)
|
||
if parent_channel is None:
|
||
return {"error": "Could not determine a parent text channel for the new thread."}
|
||
|
||
display_name = getattr(getattr(interaction, "user", None), "display_name", None) or "unknown user"
|
||
reason = f"Requested by {display_name} via /thread"
|
||
starter_message = (message or "").strip()
|
||
|
||
try:
|
||
thread = await parent_channel.create_thread(
|
||
name=name,
|
||
auto_archive_duration=auto_archive_duration,
|
||
reason=reason,
|
||
)
|
||
if starter_message:
|
||
await thread.send(starter_message)
|
||
return {
|
||
"success": True,
|
||
"thread_id": str(thread.id),
|
||
"thread_name": getattr(thread, "name", None) or name,
|
||
}
|
||
except Exception as direct_error:
|
||
try:
|
||
seed_content = starter_message or f"\U0001f9f5 Thread created by Hermes: **{name}**"
|
||
seed_msg = await parent_channel.send(seed_content)
|
||
thread = await seed_msg.create_thread(
|
||
name=name,
|
||
auto_archive_duration=auto_archive_duration,
|
||
reason=reason,
|
||
)
|
||
return {
|
||
"success": True,
|
||
"thread_id": str(thread.id),
|
||
"thread_name": getattr(thread, "name", None) or name,
|
||
}
|
||
except Exception as fallback_error:
|
||
return {
|
||
"error": (
|
||
"Discord rejected direct thread creation and the fallback also failed. "
|
||
f"Direct error: {direct_error}. Fallback error: {fallback_error}"
|
||
)
|
||
}
|
||
|
||
# ------------------------------------------------------------------
|
||
# Auto-thread helpers
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _auto_create_thread(self, message: 'DiscordMessage') -> Optional[Any]:
|
||
"""Create a thread from a user message for auto-threading.
|
||
|
||
Returns the created thread object, or ``None`` on failure.
|
||
"""
|
||
# Build a short thread name from the message. Strip Discord mention
|
||
# syntax (users / roles / channels) so thread titles don't end up
|
||
# showing raw <@id>, <@&id>, or <#id> markers — the ID isn't
|
||
# meaningful to humans glancing at the thread list (#6336).
|
||
content = (message.content or "").strip()
|
||
# <@123>, <@!123>, <@&123>, <#123> — collapse to empty; normalize spaces.
|
||
content = re.sub(r"<@[!&]?\d+>", "", content)
|
||
content = re.sub(r"<#\d+>", "", content)
|
||
content = re.sub(r"\s+", " ", content).strip()
|
||
thread_name = content[:80] if content else "Hermes"
|
||
if len(content) > 80:
|
||
thread_name = thread_name[:77] + "..."
|
||
|
||
try:
|
||
thread = await message.create_thread(name=thread_name, auto_archive_duration=1440)
|
||
return thread
|
||
except Exception as direct_error:
|
||
display_name = getattr(getattr(message, "author", None), "display_name", None) or "unknown user"
|
||
reason = f"Auto-threaded from mention by {display_name}"
|
||
try:
|
||
seed_msg = await message.channel.send(f"\U0001f9f5 Thread created by Hermes: **{thread_name}**")
|
||
thread = await seed_msg.create_thread(
|
||
name=thread_name,
|
||
auto_archive_duration=1440,
|
||
reason=reason,
|
||
)
|
||
return thread
|
||
except Exception as fallback_error:
|
||
logger.warning(
|
||
"[%s] Auto-thread creation failed. Direct error: %s. Fallback error: %s",
|
||
self.name,
|
||
direct_error,
|
||
fallback_error,
|
||
)
|
||
return None
|
||
|
||
async def create_handoff_thread(
|
||
self,
|
||
parent_chat_id: str,
|
||
name: str,
|
||
) -> Optional[str]:
|
||
"""Create a Discord thread under a text channel for a handoff.
|
||
|
||
Falls back to a seed-message + ``message.create_thread`` path if
|
||
``parent.create_thread`` is rejected (some channel types or
|
||
permission setups). Returns the new thread id as a string, or
|
||
``None`` on failure or when the parent isn't a text channel
|
||
(DMs, voice channels, threads themselves can't host threads).
|
||
"""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return None
|
||
|
||
try:
|
||
parent_id = int(parent_chat_id)
|
||
except (TypeError, ValueError):
|
||
return None
|
||
|
||
try:
|
||
parent = self._client.get_channel(parent_id)
|
||
if parent is None:
|
||
parent = await self._client.fetch_channel(parent_id)
|
||
except Exception as exc:
|
||
logger.warning(
|
||
"[%s] Handoff thread: cannot resolve parent %s: %s",
|
||
self.name, parent_chat_id, exc,
|
||
)
|
||
return None
|
||
|
||
# DMs, voice channels, and existing threads can't host child threads.
|
||
if isinstance(parent, getattr(discord, "DMChannel", ())):
|
||
logger.info(
|
||
"[%s] Handoff thread: parent %s is a DM; threads not supported here",
|
||
self.name, parent_chat_id,
|
||
)
|
||
return None
|
||
|
||
thread_name = (name or "handoff").strip()[:80] or "handoff"
|
||
reason = "Hermes session handoff"
|
||
|
||
# First try: create a thread directly on the channel.
|
||
try:
|
||
create = getattr(parent, "create_thread", None)
|
||
if create is not None:
|
||
thread = await create(
|
||
name=thread_name,
|
||
auto_archive_duration=1440,
|
||
reason=reason,
|
||
)
|
||
return str(thread.id)
|
||
except Exception as direct_error:
|
||
logger.debug(
|
||
"[%s] Handoff thread: direct create failed (%s); trying seed-message fallback",
|
||
self.name, direct_error,
|
||
)
|
||
|
||
# Fallback: post a seed message and create the thread from it.
|
||
try:
|
||
send = getattr(parent, "send", None)
|
||
if send is None:
|
||
return None
|
||
seed_msg = await send(f"\U0001f9f5 Hermes handoff: **{thread_name}**")
|
||
thread = await seed_msg.create_thread(
|
||
name=thread_name,
|
||
auto_archive_duration=1440,
|
||
reason=reason,
|
||
)
|
||
return str(thread.id)
|
||
except Exception as fallback_error:
|
||
logger.warning(
|
||
"[%s] Handoff thread: both create paths failed for parent %s: %s",
|
||
self.name, parent_chat_id, fallback_error,
|
||
)
|
||
return None
|
||
|
||
async def send_exec_approval(
|
||
self, chat_id: str, command: str, session_key: str,
|
||
description: str = "dangerous command",
|
||
metadata: Optional[dict] = None,
|
||
) -> SendResult:
|
||
"""
|
||
Send a button-based exec approval prompt for a dangerous command.
|
||
|
||
The buttons call ``resolve_gateway_approval()`` to unblock the waiting
|
||
agent thread — this replaces the text-based ``/approve`` flow on Discord.
|
||
"""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
try:
|
||
# Resolve channel — use thread_id from metadata if present
|
||
target_id = chat_id
|
||
if metadata and metadata.get("thread_id"):
|
||
target_id = metadata["thread_id"]
|
||
|
||
channel = self._client.get_channel(int(target_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(target_id))
|
||
|
||
# Discord embed description limit is 4096; show full command up to that
|
||
max_desc = 4088
|
||
cmd_display = command if len(command) <= max_desc else command[: max_desc - 3] + "..."
|
||
embed = discord.Embed(
|
||
title="⚠️ Command Approval Required",
|
||
description=f"```\n{cmd_display}\n```",
|
||
color=discord.Color.orange(),
|
||
)
|
||
embed.add_field(name="Reason", value=description, inline=False)
|
||
|
||
view = ExecApprovalView(
|
||
session_key=session_key,
|
||
allowed_user_ids=self._allowed_user_ids,
|
||
allowed_role_ids=self._allowed_role_ids,
|
||
)
|
||
|
||
msg = await channel.send(embed=embed, view=view)
|
||
view._message = msg # store for on_timeout expiration editing
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
|
||
except Exception as e:
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
async def send_slash_confirm(
|
||
self, chat_id: str, title: str, message: str, session_key: str,
|
||
confirm_id: str, metadata: Optional[dict] = None,
|
||
) -> SendResult:
|
||
"""Send a three-button slash-command confirmation prompt."""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
try:
|
||
target_id = chat_id
|
||
if metadata and metadata.get("thread_id"):
|
||
target_id = metadata["thread_id"]
|
||
|
||
channel = self._client.get_channel(int(target_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(target_id))
|
||
|
||
# Embed description limit is 4096; message usually fits easily.
|
||
max_desc = 4088
|
||
body = message if len(message) <= max_desc else message[: max_desc - 3] + "..."
|
||
embed = discord.Embed(
|
||
title=title or "Confirm",
|
||
description=body,
|
||
color=discord.Color.orange(),
|
||
)
|
||
|
||
view = SlashConfirmView(
|
||
session_key=session_key,
|
||
confirm_id=confirm_id,
|
||
allowed_user_ids=self._allowed_user_ids,
|
||
allowed_role_ids=self._allowed_role_ids,
|
||
)
|
||
|
||
msg = await channel.send(embed=embed, view=view)
|
||
view._message = msg # store for on_timeout expiration editing
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
except Exception as e:
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
async def send_clarify(
|
||
self,
|
||
chat_id: str,
|
||
question: str,
|
||
choices: Optional[list],
|
||
clarify_id: str,
|
||
session_key: str,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Render a clarify prompt with one Discord button per choice.
|
||
|
||
Multi-choice mode (``choices`` non-empty): renders a button per option
|
||
plus a final "✏️ Other (type answer)" button. Picking "Other" flips
|
||
the clarify entry into text-capture mode so the next user message in
|
||
the session becomes the response. Numeric clicks resolve immediately
|
||
via ``resolve_gateway_clarify(clarify_id, choice_text)``.
|
||
|
||
Open-ended mode (``choices`` empty/None): renders the question as
|
||
plain embed text — no buttons. The gateway's text-intercept captures
|
||
the next message in this session and resolves the clarify.
|
||
|
||
Choice normalisation: ``choices`` may contain bare strings OR dicts
|
||
(LLMs sometimes emit ``[{"description": "..."}]`` instead of bare
|
||
strings, which would otherwise render as raw Python repr on the
|
||
button label). Dict choices are unwrapped against the canonical
|
||
LLM tool-call keys ``label``, ``description``, ``text``, ``title``
|
||
in that order. Dicts with none of those keys are dropped.
|
||
"""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
try:
|
||
target_id = chat_id
|
||
if metadata and metadata.get("thread_id"):
|
||
target_id = metadata["thread_id"]
|
||
|
||
channel = self._client.get_channel(int(target_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(target_id))
|
||
|
||
# Discord embed description limit is 4096; trim conservatively.
|
||
max_desc = 4088
|
||
body = str(question or "").strip()
|
||
if len(body) > max_desc:
|
||
body = body[: max_desc - 3] + "..."
|
||
|
||
embed = discord.Embed(
|
||
title="❓ Hermes needs your input",
|
||
description=body,
|
||
color=discord.Color.orange(),
|
||
)
|
||
|
||
# Normalise choices: LLMs sometimes emit `[{"description": "..."}]`
|
||
# instead of bare strings, which would render as raw Python repr on
|
||
# the button label. Unwrap the common shapes, then stringify.
|
||
def _flatten_choice(c):
|
||
if c is None:
|
||
return ""
|
||
if isinstance(c, str):
|
||
return c.strip()
|
||
if isinstance(c, dict):
|
||
# Prefer the canonical LLM tool-call user-facing keys
|
||
# in the order the LLM is most likely to emit them.
|
||
# 'name' and 'value' are deliberately NOT here: they're
|
||
# Discord-component-shaped fields that could appear in
|
||
# dicts that aren't meant to be choices (e.g., a
|
||
# developer-error wiring that passes a Button-shaped
|
||
# object). Picking them would leak raw enum values
|
||
# or 4-char model identifiers onto user-facing buttons.
|
||
# If a dict has none of the canonical keys, drop it
|
||
# rather than picking some random field — a garbage
|
||
# button label is worse than no button at all.
|
||
for key in ("label", "description", "text", "title"):
|
||
v = c.get(key)
|
||
if isinstance(v, str) and v.strip():
|
||
return v.strip()
|
||
return ""
|
||
if isinstance(c, (list, tuple)):
|
||
return " ".join(_flatten_choice(x) for x in c).strip()
|
||
return str(c).strip()
|
||
|
||
clean_choices = [
|
||
s for s in (_flatten_choice(c) for c in (choices or [])) if s
|
||
]
|
||
# Discord allows up to 5 buttons per row, 5 rows per view = 25.
|
||
# We reserve one slot for the "Other" button, so cap at 24 choices.
|
||
clean_choices = clean_choices[:24]
|
||
|
||
if clean_choices:
|
||
embed.add_field(
|
||
name="Choices",
|
||
value="Pick one below, or click ✏️ Other to type a custom answer.",
|
||
inline=False,
|
||
)
|
||
view = ClarifyChoiceView(
|
||
choices=clean_choices,
|
||
clarify_id=clarify_id,
|
||
allowed_user_ids=self._allowed_user_ids,
|
||
allowed_role_ids=self._allowed_role_ids,
|
||
)
|
||
else:
|
||
embed.add_field(
|
||
name="Reply",
|
||
value="Reply in this channel with your answer.",
|
||
inline=False,
|
||
)
|
||
view = None
|
||
|
||
msg = await channel.send(embed=embed, view=view) if view else await channel.send(embed=embed)
|
||
if view:
|
||
view._message = msg # store for on_timeout expiration editing
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
except Exception as e:
|
||
logger.warning("[%s] send_clarify failed: %s", self.name, e)
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
async def send_update_prompt(
|
||
self, chat_id: str, prompt: str, default: str = "",
|
||
session_key: str = "",
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send an interactive button-based update prompt (Yes / No).
|
||
|
||
Used by the gateway ``/update`` watcher when ``hermes update --gateway``
|
||
needs user input (stash restore, config migration).
|
||
"""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return SendResult(success=False, error="Not connected")
|
||
try:
|
||
target_id = metadata.get("thread_id") if metadata and metadata.get("thread_id") else chat_id
|
||
channel = self._client.get_channel(int(target_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(target_id))
|
||
|
||
default_hint = f" (default: {default})" if default else ""
|
||
embed = discord.Embed(
|
||
title="⚕ Update Needs Your Input",
|
||
description=f"{prompt}{default_hint}",
|
||
color=discord.Color.gold(),
|
||
)
|
||
view = UpdatePromptView(
|
||
session_key=session_key,
|
||
allowed_user_ids=self._allowed_user_ids,
|
||
allowed_role_ids=self._allowed_role_ids,
|
||
)
|
||
msg = await channel.send(embed=embed, view=view)
|
||
view._message = msg # store for on_timeout expiration editing
|
||
if _metadata_marks_nonconversational(metadata):
|
||
self._nonconversational_messages.mark_many([str(msg.id)])
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
except Exception as e:
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
async def send_model_picker(
|
||
self,
|
||
chat_id: str,
|
||
providers: list,
|
||
current_model: str,
|
||
current_provider: str,
|
||
session_key: str,
|
||
on_model_selected,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> SendResult:
|
||
"""Send an interactive select-menu model picker.
|
||
|
||
Two-step drill-down: provider dropdown → model dropdown.
|
||
Uses Discord embeds + Select menus via ``ModelPickerView``.
|
||
"""
|
||
if not self._client or not DISCORD_AVAILABLE:
|
||
return SendResult(success=False, error="Not connected")
|
||
|
||
try:
|
||
# Resolve target channel (use thread_id if present)
|
||
target_id = chat_id
|
||
if metadata and metadata.get("thread_id"):
|
||
target_id = metadata["thread_id"]
|
||
|
||
channel = self._client.get_channel(int(target_id))
|
||
if not channel:
|
||
channel = await self._client.fetch_channel(int(target_id))
|
||
|
||
try:
|
||
from hermes_cli.providers import get_label
|
||
provider_label = get_label(current_provider)
|
||
except Exception:
|
||
provider_label = current_provider
|
||
|
||
embed = discord.Embed(
|
||
title="⚙ Model Configuration",
|
||
description=(
|
||
f"Current model: `{current_model or 'unknown'}`\n"
|
||
f"Provider: {provider_label}\n\n"
|
||
f"Select a provider:"
|
||
),
|
||
color=discord.Color.blue(),
|
||
)
|
||
|
||
view = ModelPickerView(
|
||
providers=providers,
|
||
current_model=current_model,
|
||
current_provider=current_provider,
|
||
session_key=session_key,
|
||
on_model_selected=on_model_selected,
|
||
allowed_user_ids=self._allowed_user_ids,
|
||
allowed_role_ids=self._allowed_role_ids,
|
||
)
|
||
|
||
msg = await channel.send(embed=embed, view=view)
|
||
view._message = msg # store for on_timeout expiration editing
|
||
return SendResult(success=True, message_id=str(msg.id))
|
||
|
||
except Exception as e:
|
||
logger.warning("[%s] send_model_picker failed: %s", self.name, e)
|
||
return SendResult(success=False, error=str(e))
|
||
|
||
def _get_parent_channel_id(self, channel: Any) -> Optional[str]:
|
||
"""Return the parent channel ID for a Discord thread-like channel, if present."""
|
||
parent = getattr(channel, "parent", None)
|
||
if parent is not None and getattr(parent, "id", None) is not None:
|
||
return str(parent.id)
|
||
parent_id = getattr(channel, "parent_id", None)
|
||
if parent_id is not None:
|
||
return str(parent_id)
|
||
return None
|
||
|
||
def _is_forum_parent(self, channel: Any) -> bool:
|
||
"""Best-effort check for whether a Discord channel is a forum channel."""
|
||
if channel is None:
|
||
return False
|
||
forum_cls = getattr(discord, "ForumChannel", None)
|
||
if forum_cls and isinstance(channel, forum_cls):
|
||
return True
|
||
channel_type = getattr(channel, "type", None)
|
||
if channel_type is not None:
|
||
type_value = getattr(channel_type, "value", channel_type)
|
||
if type_value == 15:
|
||
return True
|
||
return False
|
||
|
||
def _get_effective_topic(self, channel: Any, is_thread: bool = False) -> Optional[str]:
|
||
"""Return the channel topic, falling back to the parent forum's topic for forum threads."""
|
||
topic = getattr(channel, "topic", None)
|
||
if not topic and is_thread:
|
||
parent = getattr(channel, "parent", None)
|
||
if parent and self._is_forum_parent(parent):
|
||
topic = getattr(parent, "topic", None)
|
||
return topic
|
||
|
||
def _format_thread_chat_name(self, thread: Any) -> str:
|
||
"""Build a readable chat name for thread-like Discord channels, including forum context when available."""
|
||
thread_name = getattr(thread, "name", None) or str(getattr(thread, "id", "thread"))
|
||
parent = getattr(thread, "parent", None)
|
||
guild = getattr(thread, "guild", None) or getattr(parent, "guild", None)
|
||
guild_name = getattr(guild, "name", None)
|
||
parent_name = getattr(parent, "name", None)
|
||
|
||
if self._is_forum_parent(parent) and guild_name and parent_name:
|
||
return f"{guild_name} / {parent_name} / {thread_name}"
|
||
if parent_name and guild_name:
|
||
return f"{guild_name} / #{parent_name} / {thread_name}"
|
||
if parent_name:
|
||
return f"{parent_name} / {thread_name}"
|
||
return thread_name
|
||
|
||
# ------------------------------------------------------------------
|
||
# Attachment download helpers
|
||
#
|
||
# Discord attachments (images / audio / documents) are fetched via the
|
||
# authenticated bot session whenever the Attachment object exposes
|
||
# ``read()``. That sidesteps two classes of bug that hit the older
|
||
# plain-HTTP path:
|
||
#
|
||
# 1. ``cdn.discordapp.com`` URLs increasingly require bot auth on
|
||
# download — unauthenticated httpx sees 403 Forbidden.
|
||
# (issue #8242)
|
||
# 2. Some user environments (VPNs, corporate DNS, tunnels) resolve
|
||
# ``cdn.discordapp.com`` to private-looking IPs that our
|
||
# ``is_safe_url`` guard classifies as SSRF risks. Routing the
|
||
# fetch through discord.py's own HTTP client handles DNS
|
||
# internally so our guard isn't consulted for the attachment
|
||
# path. (issue #6587)
|
||
#
|
||
# If ``att.read()`` is unavailable (unexpected object shape / test
|
||
# stub) or the bot session fetch fails, we fall back to the existing
|
||
# SSRF-gated URL downloaders. The fallback keeps defense-in-depth
|
||
# against any future Discord payload-schema drift that could slip a
|
||
# non-CDN URL into the ``att.url`` field. (issue #11345)
|
||
# ------------------------------------------------------------------
|
||
|
||
async def _read_attachment_bytes(
|
||
self,
|
||
att,
|
||
*,
|
||
media_type: str = "media",
|
||
) -> Optional[bytes]:
|
||
"""Read an attachment via discord.py's authenticated bot session.
|
||
|
||
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
|
||
expose a callable ``read()`` or the read itself fails. Callers
|
||
should treat ``None`` as a signal to fall back to the URL-based
|
||
downloaders.
|
||
|
||
Oversized attachments (per ``gateway.max_inbound_media_bytes``) raise
|
||
``ValueError`` BEFORE the bytes are pulled into memory when Discord
|
||
reports the size up front, so a hostile upload can't OOM the gateway.
|
||
"""
|
||
attachment_size = getattr(att, "size", None)
|
||
if attachment_size:
|
||
validate_inbound_media_size(int(attachment_size), media_type=media_type)
|
||
|
||
reader = getattr(att, "read", None)
|
||
if reader is None or not callable(reader):
|
||
return None
|
||
try:
|
||
raw_bytes = await reader()
|
||
except Exception as e:
|
||
logger.warning(
|
||
"[Discord] Authenticated attachment read failed for %s: %s",
|
||
getattr(att, "filename", None) or getattr(att, "url", "<unknown>"),
|
||
e,
|
||
)
|
||
return None
|
||
validate_inbound_media_size(len(raw_bytes), media_type=media_type)
|
||
return raw_bytes
|
||
|
||
async def _cache_discord_image(self, att, ext: str) -> str:
|
||
"""Cache a Discord image attachment to local disk.
|
||
|
||
Primary path: ``att.read()`` + ``cache_image_from_bytes``
|
||
(authenticated, no SSRF gate).
|
||
|
||
Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated).
|
||
"""
|
||
raw_bytes = await self._read_attachment_bytes(att, media_type="image")
|
||
if raw_bytes is not None:
|
||
try:
|
||
return cache_image_from_bytes(raw_bytes, ext=ext)
|
||
except Exception as e:
|
||
logger.debug(
|
||
"[Discord] cache_image_from_bytes rejected att.read() data; falling back to URL: %s",
|
||
e,
|
||
)
|
||
return await cache_image_from_url(att.url, ext=ext)
|
||
|
||
async def _cache_discord_audio(self, att, ext: str) -> str:
|
||
"""Cache a Discord audio attachment to local disk.
|
||
|
||
Primary path: ``att.read()`` + ``cache_audio_from_bytes``
|
||
(authenticated, no SSRF gate).
|
||
|
||
Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated).
|
||
"""
|
||
raw_bytes = await self._read_attachment_bytes(att, media_type="audio")
|
||
if raw_bytes is not None:
|
||
try:
|
||
return cache_audio_from_bytes(raw_bytes, ext=ext)
|
||
except Exception as e:
|
||
logger.debug(
|
||
"[Discord] cache_audio_from_bytes failed; falling back to URL: %s",
|
||
e,
|
||
)
|
||
return await cache_audio_from_url(att.url, ext=ext)
|
||
|
||
async def _cache_discord_document(self, att, ext: str) -> bytes:
|
||
"""Download a Discord document attachment and return the raw bytes.
|
||
|
||
Primary path: ``att.read()`` (authenticated, no SSRF gate).
|
||
|
||
Fallback: SSRF-gated ``aiohttp`` download. This closes the gap
|
||
where the old document path made raw ``aiohttp.ClientSession``
|
||
requests with no safety check (#11345). The caller is responsible
|
||
for passing the returned bytes to ``cache_document_from_bytes``
|
||
(and, where applicable, for injecting text content).
|
||
"""
|
||
raw_bytes = await self._read_attachment_bytes(att, media_type="document")
|
||
if raw_bytes is not None:
|
||
return raw_bytes
|
||
|
||
# Fallback: SSRF-gated URL download.
|
||
if not is_safe_url(att.url):
|
||
raise ValueError(
|
||
f"Blocked unsafe attachment URL (SSRF protection): {att.url}"
|
||
)
|
||
import aiohttp
|
||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||
async with session.get(
|
||
att.url,
|
||
timeout=aiohttp.ClientTimeout(total=30),
|
||
**_req_kw,
|
||
) as resp:
|
||
if resp.status != 200:
|
||
raise Exception(f"HTTP {resp.status}")
|
||
return await resp.read()
|
||
|
||
async def _handle_message(self, message: DiscordMessage, role_authorized: bool = False) -> None:
|
||
"""Handle incoming Discord messages."""
|
||
# In server channels (not DMs), require the bot to be @mentioned
|
||
# UNLESS the channel is in the free-response list or the message is
|
||
# in a thread where the bot has already participated.
|
||
#
|
||
# Config (all settable via discord.* in config.yaml or DISCORD_* env vars):
|
||
# discord.require_mention: Require @mention in server channels (default: true)
|
||
# discord.free_response_channels: Channel IDs where bot responds without mention
|
||
# discord.ignored_channels: Channel IDs where bot NEVER responds (even when mentioned)
|
||
# discord.allowed_channels: If set, bot ONLY responds in these channels (whitelist)
|
||
# discord.no_thread_channels: Channel IDs where bot responds directly without creating thread
|
||
# discord.auto_thread: Auto-create thread on @mention in channels (default: true)
|
||
|
||
thread_id = None
|
||
parent_channel_id = None
|
||
is_thread = isinstance(message.channel, discord.Thread)
|
||
if is_thread:
|
||
thread_id = str(message.channel.id)
|
||
parent_channel_id = self._get_parent_channel_id(message.channel)
|
||
|
||
is_voice_linked_channel = False
|
||
|
||
# Save mention-stripped text before auto-threading since create_thread()
|
||
# can clobber message.content, breaking /command detection in channels.
|
||
raw_content = message.content.strip()
|
||
normalized_content = raw_content
|
||
mention_prefix = False
|
||
|
||
snapshot_attachments = []
|
||
if hasattr(message, "message_snapshots") and message.message_snapshots:
|
||
snapshot_text_parts = []
|
||
for snap in message.message_snapshots:
|
||
if getattr(snap, "content", None):
|
||
snapshot_text_parts.append(snap.content.strip())
|
||
snapshot_attachments.extend(getattr(snap, "attachments", []) or [])
|
||
if snapshot_text_parts and not raw_content:
|
||
raw_content = "\n".join(snapshot_text_parts)
|
||
normalized_content = raw_content
|
||
if self._client.user and self._client.user in message.mentions:
|
||
mention_prefix = True
|
||
normalized_content = normalized_content.replace(f"<@{self._client.user.id}>", "").strip()
|
||
normalized_content = normalized_content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||
message.content = normalized_content
|
||
if not isinstance(message.channel, discord.DMChannel):
|
||
channel_ids = {str(message.channel.id)}
|
||
if parent_channel_id:
|
||
channel_ids.add(parent_channel_id)
|
||
|
||
# Check allowed channels - if set, only respond in these channels
|
||
allowed_channels_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
|
||
if allowed_channels_raw:
|
||
allowed_channels = {ch.strip() for ch in allowed_channels_raw.split(",") if ch.strip()}
|
||
if "*" not in allowed_channels and not (channel_ids & allowed_channels):
|
||
logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_ids)
|
||
return
|
||
|
||
# Check ignored channels - never respond even when mentioned
|
||
ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||
ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()}
|
||
if "*" in ignored_channels or (channel_ids & ignored_channels):
|
||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
|
||
return
|
||
|
||
free_channels = self._discord_free_response_channels()
|
||
if parent_channel_id:
|
||
channel_ids.add(parent_channel_id)
|
||
|
||
require_mention = self._discord_require_mention()
|
||
# Voice-linked text channels act as free-response while voice is active.
|
||
# Only the exact bound channel gets the exemption, not sibling threads.
|
||
voice_linked_ids = {str(ch_id) for ch_id in self._voice_text_channels.values()}
|
||
current_channel_id = str(message.channel.id)
|
||
is_voice_linked_channel = current_channel_id in voice_linked_ids
|
||
is_free_channel = (
|
||
"*" in free_channels
|
||
or bool(channel_ids & free_channels)
|
||
or is_voice_linked_channel
|
||
)
|
||
|
||
# Skip the mention check if the message is in a thread where
|
||
# the bot has previously participated (auto-created or replied in)
|
||
# — UNLESS thread_require_mention is enabled, in which case threads
|
||
# are gated the same as channels. Useful when multiple bots share
|
||
# a thread.
|
||
in_bot_thread = (
|
||
is_thread
|
||
and thread_id in self._threads
|
||
and not self._discord_thread_require_mention()
|
||
)
|
||
|
||
if require_mention and not is_free_channel and not in_bot_thread:
|
||
if self._client.user not in message.mentions and not mention_prefix:
|
||
return
|
||
# Auto-thread: when enabled, automatically create a thread for every
|
||
# @mention in a text channel so each conversation is isolated (like Slack).
|
||
# Messages already inside threads or DMs are unaffected.
|
||
# no_thread_channels: channels where bot responds directly without thread.
|
||
auto_threaded_channel = None
|
||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||
no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "")
|
||
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
|
||
skip_thread = bool(channel_ids & no_thread_channels) or is_free_channel
|
||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in {"true", "1", "yes"}
|
||
is_reply_message = getattr(message, "type", None) == discord.MessageType.reply
|
||
if auto_thread and not skip_thread and not is_voice_linked_channel and not is_reply_message:
|
||
thread = await self._auto_create_thread(message)
|
||
if thread:
|
||
parent_channel_id = str(message.channel.id)
|
||
is_thread = True
|
||
thread_id = str(thread.id)
|
||
auto_threaded_channel = thread
|
||
self._threads.mark(thread_id)
|
||
|
||
referenced_attachments = []
|
||
reference = getattr(message, "reference", None)
|
||
resolved_reference = getattr(reference, "resolved", None) if reference else None
|
||
if resolved_reference is not None:
|
||
referenced_attachments = list(getattr(resolved_reference, "attachments", []) or [])
|
||
|
||
all_attachments = list(message.attachments) + snapshot_attachments + referenced_attachments
|
||
|
||
# Determine message type
|
||
msg_type = MessageType.TEXT
|
||
if normalized_content.startswith("/"):
|
||
msg_type = MessageType.COMMAND
|
||
elif all_attachments:
|
||
# Check attachment types. Any non-media attachment is treated as a
|
||
# DOCUMENT regardless of extension — authorization to message the
|
||
# agent is the gate, not the file type.
|
||
for att in all_attachments:
|
||
if att.content_type:
|
||
if att.content_type.startswith("image/"):
|
||
msg_type = MessageType.PHOTO
|
||
elif att.content_type.startswith("video/"):
|
||
msg_type = MessageType.VIDEO
|
||
elif att.content_type.startswith("audio/"):
|
||
if self._is_discord_voice_message_attachment(att):
|
||
msg_type = MessageType.VOICE
|
||
else:
|
||
msg_type = MessageType.AUDIO
|
||
else:
|
||
msg_type = MessageType.DOCUMENT
|
||
break
|
||
else:
|
||
# No content_type at all (rare — discord usually fills it
|
||
# in). Treat as a document so downstream pipelines surface
|
||
# the path to the agent.
|
||
msg_type = MessageType.DOCUMENT
|
||
break
|
||
|
||
# When auto-threading kicked in, route responses to the new thread
|
||
effective_channel = auto_threaded_channel or message.channel
|
||
|
||
# Determine chat type
|
||
if isinstance(message.channel, discord.DMChannel):
|
||
chat_type = "dm"
|
||
chat_name = message.author.name
|
||
elif is_thread:
|
||
chat_type = "thread"
|
||
chat_name = self._format_thread_chat_name(effective_channel)
|
||
else:
|
||
chat_type = "group"
|
||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||
|
||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't).
|
||
# For threads whose parent is a forum channel, inherit the parent's topic
|
||
# so forum descriptions (e.g. project instructions) appear in the session context.
|
||
chat_topic = self._get_effective_topic(message.channel, is_thread=is_thread)
|
||
|
||
# Build source
|
||
guild = getattr(message, "guild", None)
|
||
source = self.build_source(
|
||
chat_id=str(effective_channel.id),
|
||
chat_name=chat_name,
|
||
chat_type=chat_type,
|
||
user_id=str(message.author.id),
|
||
user_name=message.author.display_name,
|
||
thread_id=thread_id,
|
||
chat_topic=chat_topic,
|
||
is_bot=getattr(message.author, "bot", False),
|
||
guild_id=str(guild.id) if guild else None,
|
||
parent_chat_id=parent_channel_id,
|
||
message_id=str(message.id),
|
||
role_authorized=role_authorized,
|
||
)
|
||
|
||
# Build media URLs -- download image attachments to local cache so the
|
||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||
media_urls = []
|
||
media_types = []
|
||
pending_text_injection: Optional[str] = None
|
||
for att in all_attachments:
|
||
content_type = att.content_type or "unknown"
|
||
if content_type.startswith("image/"):
|
||
try:
|
||
# Determine extension from content type (image/png -> .png)
|
||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||
if ext not in {".jpg", ".jpeg", ".png", ".gif", ".webp"}:
|
||
ext = ".jpg"
|
||
cached_path = await self._cache_discord_image(att, ext)
|
||
media_urls.append(cached_path)
|
||
media_types.append(content_type)
|
||
print(f"[Discord] Cached user image: {cached_path}", flush=True)
|
||
except Exception as e:
|
||
print(f"[Discord] Failed to cache image attachment: {e}", flush=True)
|
||
# Fall back to the CDN URL if caching fails
|
||
media_urls.append(att.url)
|
||
media_types.append(content_type)
|
||
elif content_type.startswith("audio/"):
|
||
try:
|
||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||
if ext not in {".ogg", ".mp3", ".wav", ".webm", ".m4a"}:
|
||
ext = ".ogg"
|
||
cached_path = await self._cache_discord_audio(att, ext)
|
||
media_urls.append(cached_path)
|
||
media_types.append(content_type)
|
||
print(f"[Discord] Cached user audio: {cached_path}", flush=True)
|
||
except Exception as e:
|
||
print(f"[Discord] Failed to cache audio attachment: {e}", flush=True)
|
||
media_urls.append(att.url)
|
||
media_types.append(content_type)
|
||
else:
|
||
# Document attachments: download, cache, and optionally inject text
|
||
ext = ""
|
||
if att.filename:
|
||
_, ext = os.path.splitext(att.filename)
|
||
ext = ext.lower()
|
||
if not ext and content_type:
|
||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||
ext = mime_to_ext.get(content_type, "")
|
||
in_allowlist = ext in SUPPORTED_DOCUMENT_TYPES
|
||
# Any file type is accepted — authorization to message the agent
|
||
# is the gate, not the file extension. Known types keep their
|
||
# precise MIME; unknown types fall back to the source content_type
|
||
# or octet-stream so the agent reaches for terminal tools.
|
||
max_doc_bytes = self._discord_max_attachment_bytes()
|
||
if max_doc_bytes and att.size and att.size > max_doc_bytes:
|
||
logger.warning(
|
||
"[Discord] Document too large (%s bytes > cap %s), skipping: %s",
|
||
att.size, max_doc_bytes, att.filename,
|
||
)
|
||
else:
|
||
try:
|
||
raw_bytes = await self._cache_discord_document(att, ext)
|
||
cached_path = cache_document_from_bytes(
|
||
raw_bytes, att.filename or f"document{ext or '.bin'}"
|
||
)
|
||
if in_allowlist:
|
||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||
else:
|
||
# Untyped file. Use the source content_type if
|
||
# discord gave us one, otherwise fall back to
|
||
# octet-stream so the agent knows it's binary and
|
||
# reaches for terminal tools.
|
||
doc_mime = (
|
||
content_type
|
||
if content_type and content_type != "unknown"
|
||
else "application/octet-stream"
|
||
)
|
||
media_urls.append(cached_path)
|
||
media_types.append(doc_mime)
|
||
logger.info(
|
||
"[Discord] Cached user %s: %s",
|
||
"document" if in_allowlist else "attachment",
|
||
cached_path,
|
||
)
|
||
# Inject text content for any text-readable document
|
||
# Inject text content for text-readable documents
|
||
# (capped at 100 KB). Gate on a text-like extension/MIME
|
||
# — NOT a blind UTF-8 decode, since binary formats like
|
||
# PDF/zip/docx can have decodable ASCII headers. Unknown
|
||
# but clearly-textual types (text/* MIME or a known text
|
||
# extension) are inlined too; everything else relies on
|
||
# ``gateway/run.py`` to emit a path-pointing context note.
|
||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||
_is_text = (
|
||
ext in _TEXT_INJECT_EXTENSIONS
|
||
or (content_type or "").startswith("text/")
|
||
)
|
||
if _is_text and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||
try:
|
||
text_content = raw_bytes.decode("utf-8")
|
||
display_name = att.filename or f"document{ext or '.txt'}"
|
||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||
if pending_text_injection:
|
||
pending_text_injection = f"{pending_text_injection}\n\n{injection}"
|
||
else:
|
||
pending_text_injection = injection
|
||
except UnicodeDecodeError:
|
||
pass
|
||
# NOTE: for the untyped-attachment path we deliberately
|
||
# do NOT inject a path string here. ``gateway/run.py``
|
||
# already detects DOCUMENT-typed events with
|
||
# ``application/octet-stream`` MIME and emits a context
|
||
# note with the sandbox-translated cache path via
|
||
# ``to_agent_visible_cache_path()`` (important for
|
||
# Docker/Modal terminal backends).
|
||
except Exception as e:
|
||
logger.warning(
|
||
"[Discord] Failed to cache document %s: %s",
|
||
att.filename, e, exc_info=True,
|
||
)
|
||
|
||
# Use normalized_content (saved before auto-threading) instead of message.content,
|
||
# to detect /slash commands in channel messages.
|
||
event_text = normalized_content
|
||
if pending_text_injection:
|
||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||
|
||
# ── History backfill ─────────────────────────────────────────
|
||
# When require_mention is active, the bot only processes messages
|
||
# that @mention it. Messages in the channel between bot turns are
|
||
# invisible to the session transcript. To recover that context,
|
||
# fetch recent channel history and prepend it to the user message.
|
||
#
|
||
# The fetch window is: everything after the bot's last message in
|
||
# the channel up to (but not including) the current trigger. On
|
||
# cold start (no prior bot message found), fetch the last N messages
|
||
# and stop at the first self-message encountered.
|
||
#
|
||
# Threads naturally scope to thread-only history (channel.history()
|
||
# on a thread returns only that thread's messages). DMs are skipped
|
||
# because every DM message triggers the bot — there's no mention gap
|
||
# to fill; the session transcript already has everything.
|
||
#
|
||
# Per-user sessions also benefit: Alice's session is missing the
|
||
# other-channel-participants' context, and her own messages from
|
||
# before she mentioned the bot. Backfill fills that gap.
|
||
#
|
||
# Messages that arrive while the bot is processing (between trigger
|
||
# and response) are not captured — this is an accepted simplification
|
||
# to keep the partition rule clean.
|
||
_channel_context = None
|
||
_is_dm = isinstance(message.channel, discord.DMChannel)
|
||
if not _is_dm and self._discord_history_backfill():
|
||
# Run backfill when there's a real gap to fill:
|
||
# - mention-gated channels with no free-response override
|
||
# (messages between bot turns aren't in the transcript)
|
||
# - any thread (in_bot_thread bypasses the mention check, but
|
||
# processing-window gaps and post-restart context still need
|
||
# recovery)
|
||
# - any reply (the user pointed at a specific message; hydrate
|
||
# the context around it even in a free-response channel where
|
||
# no mention gap exists — otherwise replies get only the short
|
||
# "[Replying to: ...]" snippet with no surrounding context)
|
||
# DMs skip entirely because every DM message triggers the bot,
|
||
# so the session transcript already has everything.
|
||
# Auto-threaded messages also skip — we just created the thread,
|
||
# there's nothing prior to backfill.
|
||
_has_mention_gap = require_mention and not is_free_channel and not in_bot_thread
|
||
_is_reply = message.reference is not None
|
||
|
||
# Resolve the replied-to message into an object exposing ``.id``.
|
||
# discord.py may give us a full Message (resolved), a
|
||
# DeletedReferencedMessage, or nothing. Duck-type on ``.id``
|
||
# rather than isinstance(discord.Message) — under test doubles the
|
||
# discord module (and thus discord.Message) can be a mock, which is
|
||
# not a valid isinstance() second argument. Any object with an int
|
||
# id works as a scan anchor; otherwise fall back to a bare snowflake
|
||
# built from the reference's message_id.
|
||
_reply_target = None
|
||
if _is_reply:
|
||
_resolved = getattr(message.reference, "resolved", None)
|
||
_resolved_id = getattr(_resolved, "id", None) if _resolved is not None else None
|
||
if _resolved_id is not None:
|
||
_reply_target = _resolved
|
||
else:
|
||
_ref_mid = getattr(message.reference, "message_id", None)
|
||
if _ref_mid is not None:
|
||
with suppress(ValueError, TypeError):
|
||
_reply_target = _Snowflake(int(_ref_mid))
|
||
|
||
if (_has_mention_gap or is_thread or _is_reply) and auto_threaded_channel is None:
|
||
_backfill_text = await self._fetch_channel_context(
|
||
message.channel, before=message, reply_target=_reply_target,
|
||
)
|
||
if _backfill_text:
|
||
_channel_context = _backfill_text
|
||
|
||
# Defense-in-depth: prevent empty user messages from entering session
|
||
# (can happen when user sends @mention-only with no other text).
|
||
# When channel_context is present, a bare mention means "catch me up"
|
||
# — the context IS the message, so skip the placeholder.
|
||
if (not event_text or not event_text.strip()) and not _channel_context:
|
||
event_text = "(The user sent a message with no text content)"
|
||
|
||
_chan = message.channel
|
||
_parent_id = str(getattr(_chan, "parent_id", "") or "")
|
||
_chan_id = str(getattr(_chan, "id", ""))
|
||
_skills = self._resolve_channel_skills(_chan_id, _parent_id or None)
|
||
_channel_prompt = self._resolve_channel_prompt(_chan_id, _parent_id or None)
|
||
|
||
reply_to_id = None
|
||
reply_to_text = None
|
||
if message.reference:
|
||
reply_to_id = str(message.reference.message_id)
|
||
if message.reference.resolved:
|
||
reply_to_text = getattr(message.reference.resolved, "content", None) or None
|
||
|
||
event = MessageEvent(
|
||
text=event_text,
|
||
message_type=msg_type,
|
||
source=source,
|
||
raw_message=message,
|
||
message_id=str(message.id),
|
||
media_urls=media_urls,
|
||
media_types=media_types,
|
||
reply_to_message_id=reply_to_id,
|
||
reply_to_text=reply_to_text,
|
||
timestamp=message.created_at,
|
||
auto_skill=_skills,
|
||
channel_prompt=_channel_prompt,
|
||
channel_context=_channel_context,
|
||
)
|
||
|
||
# Track thread participation so the bot won't require @mention for
|
||
# follow-up messages in threads it has already engaged in.
|
||
if thread_id:
|
||
self._threads.mark(thread_id)
|
||
|
||
# Only batch plain text messages — commands, media, etc. dispatch
|
||
# immediately since they won't be split by the Discord client.
|
||
if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0:
|
||
self._enqueue_text_event(event)
|
||
else:
|
||
await self.handle_message(event)
|
||
|
||
# ------------------------------------------------------------------
|
||
# Text message aggregation (handles Discord client-side splits)
|
||
# ------------------------------------------------------------------
|
||
|
||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||
"""Session-scoped key for text message batching."""
|
||
from gateway.session import build_session_key
|
||
return build_session_key(
|
||
event.source,
|
||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
|
||
)
|
||
|
||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||
"""Buffer a text event and reset the flush timer.
|
||
|
||
When Discord splits a long user message at 2000 chars, the chunks
|
||
arrive within a few hundred milliseconds. This merges them into
|
||
a single event before dispatching.
|
||
"""
|
||
key = self._text_batch_key(event)
|
||
existing = self._pending_text_batches.get(key)
|
||
chunk_len = len(event.text or "")
|
||
if existing is None:
|
||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||
self._pending_text_batches[key] = event
|
||
else:
|
||
if event.text:
|
||
existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
|
||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||
if event.media_urls:
|
||
existing.media_urls.extend(event.media_urls)
|
||
existing.media_types.extend(event.media_types)
|
||
|
||
prior_task = self._pending_text_batch_tasks.get(key)
|
||
if prior_task and not prior_task.done():
|
||
prior_task.cancel()
|
||
self._pending_text_batch_tasks[key] = asyncio.create_task(
|
||
self._flush_text_batch(key)
|
||
)
|
||
|
||
async def _flush_text_batch(self, key: str) -> None:
|
||
"""Wait for the quiet period then dispatch the aggregated text.
|
||
|
||
Uses a longer delay when the latest chunk is near Discord's 2000-char
|
||
split point, since a continuation chunk is almost certain.
|
||
"""
|
||
current_task = asyncio.current_task()
|
||
try:
|
||
pending = self._pending_text_batches.get(key)
|
||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||
if last_len >= self._SPLIT_THRESHOLD:
|
||
delay = self._text_batch_split_delay_seconds
|
||
else:
|
||
delay = self._text_batch_delay_seconds
|
||
await asyncio.sleep(delay)
|
||
event = self._pending_text_batches.pop(key, None)
|
||
if not event:
|
||
return
|
||
logger.info(
|
||
"[Discord] Flushing text batch %s (%d chars)",
|
||
key, len(event.text or ""),
|
||
)
|
||
# Shield the downstream dispatch so that a subsequent chunk
|
||
# arriving while handle_message is mid-flight cannot cancel
|
||
# the running agent turn. _enqueue_text_event always cancels
|
||
# the prior flush task when a new chunk lands; without this
|
||
# shield, CancelledError would propagate from our task down
|
||
# into handle_message → the agent's streaming request,
|
||
# aborting the response the user was waiting on. The new
|
||
# chunk is handled by the fresh flush task regardless.
|
||
await asyncio.shield(self.handle_message(event))
|
||
except asyncio.CancelledError:
|
||
# Only reached if cancel landed before the pop — the shielded
|
||
# handle_message is unaffected either way. Let the task exit
|
||
# cleanly so the finally block cleans up.
|
||
pass
|
||
finally:
|
||
if self._pending_text_batch_tasks.get(key) is current_task:
|
||
self._pending_text_batch_tasks.pop(key, None)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Discord UI Components (outside the adapter class)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _component_check_auth(
|
||
interaction,
|
||
allowed_user_ids: Optional[set],
|
||
allowed_role_ids: Optional[set],
|
||
) -> bool:
|
||
"""Shared user-or-role OR semantics for component view button clicks.
|
||
|
||
Mirrors the gateway's external-surface authorization model: component
|
||
button clicks must be explicitly authorized by a Discord user/role
|
||
allowlist, a global user allowlist, or an explicit allow-all flag.
|
||
|
||
Behavior:
|
||
|
||
- DISCORD_ALLOW_ALL_USERS or GATEWAY_ALLOW_ALL_USERS -> allow
|
||
- user is in DISCORD_ALLOWED_USERS or GATEWAY_ALLOWED_USERS -> allow
|
||
- role allowlist set + user has a role in it -> allow
|
||
- role allowlist set + interaction.user has no resolvable
|
||
``roles`` attribute (e.g. DM context with a role policy active)
|
||
-> reject (fail closed)
|
||
- otherwise -> reject
|
||
"""
|
||
if os.getenv("DISCORD_ALLOW_ALL_USERS", "").strip().lower() in {"true", "1", "yes"}:
|
||
return True
|
||
if os.getenv("GATEWAY_ALLOW_ALL_USERS", "").strip().lower() in {"true", "1", "yes"}:
|
||
return True
|
||
|
||
user_set = {str(uid).strip() for uid in (allowed_user_ids or set()) if str(uid).strip()}
|
||
global_allowed = {
|
||
uid.strip()
|
||
for uid in os.getenv("GATEWAY_ALLOWED_USERS", "").split(",")
|
||
if uid.strip()
|
||
}
|
||
user_set.update(global_allowed)
|
||
role_set = set(allowed_role_ids or set())
|
||
has_users = bool(user_set)
|
||
has_roles = bool(role_set)
|
||
user = getattr(interaction, "user", None)
|
||
if user is None:
|
||
return False
|
||
|
||
if has_users:
|
||
try:
|
||
uid = str(user.id)
|
||
except AttributeError:
|
||
uid = ""
|
||
if "*" in user_set or (uid and uid in user_set):
|
||
return True
|
||
|
||
if has_roles:
|
||
roles_attr = getattr(user, "roles", None)
|
||
if roles_attr is None:
|
||
# Role policy is configured but the interaction doesn't
|
||
# carry role data (DM-context Member, raw User payload).
|
||
# Fail closed: a user without a resolvable role list cannot
|
||
# satisfy a role allowlist.
|
||
return False
|
||
try:
|
||
user_role_ids = {getattr(r, "id", None) for r in roles_attr}
|
||
except TypeError:
|
||
return False
|
||
if user_role_ids & role_set:
|
||
return True
|
||
|
||
return False
|
||
|
||
|
||
def _define_discord_view_classes() -> None:
|
||
"""Register Discord UI view classes as module globals.
|
||
|
||
Called at module load (when discord.py is pre-installed) and also from
|
||
check_discord_requirements() after a lazy install, so view classes are
|
||
always defined whenever DISCORD_AVAILABLE is True. Without this,
|
||
ExecApprovalView and siblings are only defined at import time; a later
|
||
lazy install sets DISCORD_AVAILABLE=True but leaves the classes
|
||
undefined, causing NameError on the first button interaction.
|
||
"""
|
||
global ExecApprovalView, SlashConfirmView, UpdatePromptView, ModelPickerView, ClarifyChoiceView
|
||
|
||
class ExecApprovalView(discord.ui.View):
|
||
"""
|
||
Interactive button view for exec approval of dangerous commands.
|
||
|
||
Shows four buttons: Allow Once, Allow Session, Always Allow, Deny.
|
||
Clicking a button calls ``resolve_gateway_approval()`` to unblock the
|
||
waiting agent thread — the same mechanism as the text ``/approve`` flow.
|
||
Only users in the allowed list can click. Times out after 5 minutes.
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
session_key: str,
|
||
allowed_user_ids: set,
|
||
allowed_role_ids: Optional[set] = None,
|
||
):
|
||
super().__init__(timeout=300) # 5-minute timeout
|
||
self.session_key = session_key
|
||
self.allowed_user_ids = allowed_user_ids
|
||
self.allowed_role_ids = allowed_role_ids or set()
|
||
self.resolved = False
|
||
|
||
def _check_auth(self, interaction: discord.Interaction) -> bool:
|
||
"""Verify the user clicking is authorized."""
|
||
return _component_check_auth(
|
||
interaction, self.allowed_user_ids, self.allowed_role_ids,
|
||
)
|
||
|
||
async def _resolve(
|
||
self, interaction: discord.Interaction, choice: str,
|
||
color: discord.Color, label: str,
|
||
):
|
||
"""Resolve the approval via the gateway approval queue and update the embed."""
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"This approval has already been resolved~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized to approve commands~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
self.resolved = True
|
||
|
||
# Update the embed with the decision
|
||
embed = interaction.message.embeds[0] if interaction.message.embeds else None
|
||
if embed:
|
||
embed.color = color
|
||
embed.set_footer(text=f"{label} by {interaction.user.display_name}")
|
||
|
||
# Disable all buttons
|
||
for child in self.children:
|
||
child.disabled = True
|
||
|
||
await interaction.response.edit_message(embed=embed, view=self)
|
||
|
||
# Unblock the waiting agent thread via the gateway approval queue
|
||
try:
|
||
from tools.approval import resolve_gateway_approval
|
||
count = resolve_gateway_approval(self.session_key, choice)
|
||
logger.info(
|
||
"Discord button resolved %d approval(s) for session %s (choice=%s, user=%s)",
|
||
count, self.session_key, choice, interaction.user.display_name,
|
||
)
|
||
except Exception as exc:
|
||
logger.error("Failed to resolve gateway approval from button: %s", exc)
|
||
|
||
@discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green)
|
||
async def allow_once(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||
):
|
||
await self._resolve(interaction, "once", discord.Color.green(), "Approved once")
|
||
|
||
@discord.ui.button(label="Allow Session", style=discord.ButtonStyle.grey)
|
||
async def allow_session(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||
):
|
||
await self._resolve(interaction, "session", discord.Color.blue(), "Approved for session")
|
||
|
||
@discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple)
|
||
async def allow_always(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||
):
|
||
await self._resolve(interaction, "always", discord.Color.purple(), "Approved permanently")
|
||
|
||
@discord.ui.button(label="Deny", style=discord.ButtonStyle.red)
|
||
async def deny(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||
):
|
||
await self._resolve(interaction, "deny", discord.Color.red(), "Denied")
|
||
|
||
async def on_timeout(self):
|
||
"""Handle view timeout -- disable buttons and mark as expired."""
|
||
self.resolved = True
|
||
for child in self.children:
|
||
child.disabled = True
|
||
# Visually update the Discord message so buttons appear disabled.
|
||
msg = getattr(self, '_message', None)
|
||
if msg:
|
||
try:
|
||
embed = msg.embeds[0] if msg.embeds else None
|
||
if embed:
|
||
embed.color = discord.Color.greyple()
|
||
embed.set_footer(text="⏱ Prompt expired — no action taken")
|
||
await msg.edit(embed=embed, view=self)
|
||
except Exception:
|
||
pass # message deleted or too old to edit
|
||
|
||
class SlashConfirmView(discord.ui.View):
|
||
"""Three-button view for generic slash-command confirmations.
|
||
|
||
Used by ``/reload-mcp`` and any future slash command routed through
|
||
``GatewayRunner._request_slash_confirm``. Buttons map to the
|
||
gateway's three choices:
|
||
|
||
* "Approve Once" → ``choice="once"``
|
||
* "Always Approve" → ``choice="always"``
|
||
* "Cancel" → ``choice="cancel"``
|
||
|
||
Clicking calls the module-level
|
||
``tools.slash_confirm.resolve(session_key, confirm_id, choice)``
|
||
which runs the handler the runner stored for this ``session_key``.
|
||
Only users in the adapter's allowlist can click. Times out after
|
||
5 minutes (matches the gateway primitive's timeout).
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
session_key: str,
|
||
confirm_id: str,
|
||
allowed_user_ids: set,
|
||
allowed_role_ids: Optional[set] = None,
|
||
):
|
||
super().__init__(timeout=300)
|
||
self.session_key = session_key
|
||
self.confirm_id = confirm_id
|
||
self.allowed_user_ids = allowed_user_ids
|
||
self.allowed_role_ids = allowed_role_ids or set()
|
||
self.resolved = False
|
||
|
||
def _check_auth(self, interaction: discord.Interaction) -> bool:
|
||
return _component_check_auth(
|
||
interaction, self.allowed_user_ids, self.allowed_role_ids,
|
||
)
|
||
|
||
async def _resolve(
|
||
self, interaction: discord.Interaction, choice: str,
|
||
color: discord.Color, label: str,
|
||
):
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"This prompt has already been resolved~", ephemeral=True,
|
||
)
|
||
return
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized to answer this prompt~", ephemeral=True,
|
||
)
|
||
return
|
||
|
||
self.resolved = True
|
||
|
||
embed = interaction.message.embeds[0] if interaction.message.embeds else None
|
||
if embed:
|
||
embed.color = color
|
||
embed.set_footer(text=f"{label} by {interaction.user.display_name}")
|
||
|
||
for child in self.children:
|
||
child.disabled = True
|
||
|
||
await interaction.response.edit_message(embed=embed, view=self)
|
||
|
||
# Resolve via the module-level primitive. If the handler
|
||
# returns a follow-up message, post it in the same channel.
|
||
try:
|
||
from tools import slash_confirm as _slash_confirm_mod
|
||
result_text = await _slash_confirm_mod.resolve(
|
||
self.session_key, self.confirm_id, choice,
|
||
)
|
||
if result_text:
|
||
await interaction.followup.send(result_text)
|
||
logger.info(
|
||
"Discord button resolved slash-confirm for session %s "
|
||
"(choice=%s, user=%s)",
|
||
self.session_key, choice, interaction.user.display_name,
|
||
)
|
||
except Exception as exc:
|
||
logger.error("Discord slash-confirm resolve failed: %s", exc, exc_info=True)
|
||
|
||
@discord.ui.button(label="Approve Once", style=discord.ButtonStyle.green)
|
||
async def approve_once(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button,
|
||
):
|
||
await self._resolve(interaction, "once", discord.Color.green(), "Approved once")
|
||
|
||
@discord.ui.button(label="Always Approve", style=discord.ButtonStyle.blurple)
|
||
async def approve_always(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button,
|
||
):
|
||
await self._resolve(interaction, "always", discord.Color.purple(), "Always approved")
|
||
|
||
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.red)
|
||
async def cancel(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button,
|
||
):
|
||
await self._resolve(interaction, "cancel", discord.Color.greyple(), "Cancelled")
|
||
|
||
async def on_timeout(self):
|
||
self.resolved = True
|
||
for child in self.children:
|
||
child.disabled = True
|
||
# Visually update the Discord message so buttons appear disabled.
|
||
msg = getattr(self, '_message', None)
|
||
if msg:
|
||
try:
|
||
embed = msg.embeds[0] if msg.embeds else None
|
||
if embed:
|
||
embed.color = discord.Color.greyple()
|
||
embed.set_footer(text="⏱ Prompt expired — no action taken")
|
||
await msg.edit(embed=embed, view=self)
|
||
except Exception:
|
||
pass
|
||
|
||
class UpdatePromptView(discord.ui.View):
|
||
"""Interactive Yes/No buttons for ``hermes update`` prompts.
|
||
|
||
Clicking a button writes the answer to ``.update_response`` so the
|
||
detached update process can pick it up. Only authorized users can
|
||
click. Times out after 5 minutes (the update process also has a
|
||
5-minute timeout on its side).
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
session_key: str,
|
||
allowed_user_ids: set,
|
||
allowed_role_ids: Optional[set] = None,
|
||
):
|
||
super().__init__(timeout=300)
|
||
self.session_key = session_key
|
||
self.allowed_user_ids = allowed_user_ids
|
||
self.allowed_role_ids = allowed_role_ids or set()
|
||
self.resolved = False
|
||
|
||
def _check_auth(self, interaction: discord.Interaction) -> bool:
|
||
return _component_check_auth(
|
||
interaction, self.allowed_user_ids, self.allowed_role_ids,
|
||
)
|
||
|
||
async def _respond(
|
||
self, interaction: discord.Interaction, answer: str,
|
||
color: discord.Color, label: str,
|
||
):
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"Already answered~", ephemeral=True
|
||
)
|
||
return
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
self.resolved = True
|
||
|
||
# Update embed
|
||
embed = interaction.message.embeds[0] if interaction.message.embeds else None
|
||
if embed:
|
||
embed.color = color
|
||
embed.set_footer(text=f"{label} by {interaction.user.display_name}")
|
||
|
||
for child in self.children:
|
||
child.disabled = True
|
||
await interaction.response.edit_message(embed=embed, view=self)
|
||
|
||
# Write response file
|
||
try:
|
||
from hermes_constants import get_hermes_home
|
||
home = get_hermes_home()
|
||
response_path = home / ".update_response"
|
||
tmp = response_path.with_suffix(".tmp")
|
||
tmp.write_text(answer)
|
||
tmp.replace(response_path)
|
||
logger.info(
|
||
"Discord update prompt answered '%s' by %s",
|
||
answer, interaction.user.display_name,
|
||
)
|
||
except Exception as exc:
|
||
logger.error("Failed to write update response: %s", exc)
|
||
|
||
@discord.ui.button(label="Yes", style=discord.ButtonStyle.green, emoji="✓")
|
||
async def yes_btn(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||
):
|
||
await self._respond(interaction, "y", discord.Color.green(), "Yes")
|
||
|
||
@discord.ui.button(label="No", style=discord.ButtonStyle.red, emoji="✗")
|
||
async def no_btn(
|
||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||
):
|
||
await self._respond(interaction, "n", discord.Color.red(), "No")
|
||
|
||
async def on_timeout(self):
|
||
self.resolved = True
|
||
for child in self.children:
|
||
child.disabled = True
|
||
# Visually update the Discord message so buttons appear disabled.
|
||
msg = getattr(self, '_message', None)
|
||
if msg:
|
||
try:
|
||
embed = msg.embeds[0] if msg.embeds else None
|
||
if embed:
|
||
embed.color = discord.Color.greyple()
|
||
embed.set_footer(text="⏱ Prompt expired — no action taken")
|
||
await msg.edit(embed=embed, view=self)
|
||
except Exception:
|
||
pass
|
||
|
||
class ModelPickerView(discord.ui.View):
|
||
"""Interactive select-menu view for model switching.
|
||
|
||
Two-step drill-down: provider dropdown → model dropdown.
|
||
Edits the original message in-place as the user navigates.
|
||
Times out after 2 minutes.
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
providers: list,
|
||
current_model: str,
|
||
current_provider: str,
|
||
session_key: str,
|
||
on_model_selected,
|
||
allowed_user_ids: set,
|
||
allowed_role_ids: Optional[set] = None,
|
||
):
|
||
super().__init__(timeout=120)
|
||
self.providers = providers
|
||
self.current_model = current_model
|
||
self.current_provider = current_provider
|
||
self.session_key = session_key
|
||
self.on_model_selected = on_model_selected
|
||
self.allowed_user_ids = allowed_user_ids
|
||
self.allowed_role_ids = allowed_role_ids or set()
|
||
self.resolved = False
|
||
self._selected_provider: str = ""
|
||
self._pending_expensive_model: str = ""
|
||
|
||
self._build_provider_select()
|
||
|
||
def _check_auth(self, interaction: discord.Interaction) -> bool:
|
||
return _component_check_auth(
|
||
interaction, self.allowed_user_ids, self.allowed_role_ids,
|
||
)
|
||
|
||
def _build_provider_select(self):
|
||
"""Build the provider dropdown menu."""
|
||
self.clear_items()
|
||
options = []
|
||
for p in self.providers:
|
||
count = p.get("total_models", len(p.get("models", [])))
|
||
label = f"{p['name']} ({count} models)"
|
||
desc = "current" if p.get("is_current") else None
|
||
options.append(
|
||
discord.SelectOption(
|
||
label=label[:100],
|
||
value=p["slug"],
|
||
description=desc,
|
||
)
|
||
)
|
||
if not options:
|
||
return
|
||
|
||
select = discord.ui.Select(
|
||
placeholder="Choose a provider...",
|
||
options=options[:25],
|
||
custom_id="model_provider_select",
|
||
)
|
||
select.callback = self._on_provider_selected
|
||
self.add_item(select)
|
||
|
||
cancel_btn = discord.ui.Button(
|
||
label="Cancel", style=discord.ButtonStyle.red, custom_id="model_cancel"
|
||
)
|
||
cancel_btn.callback = self._on_cancel
|
||
self.add_item(cancel_btn)
|
||
|
||
def _build_model_select(self, provider_slug: str):
|
||
"""Build the model dropdown for a specific provider."""
|
||
self.clear_items()
|
||
provider = next(
|
||
(p for p in self.providers if p["slug"] == provider_slug), None
|
||
)
|
||
if not provider:
|
||
return
|
||
|
||
models = provider.get("models", [])
|
||
options = []
|
||
for model_id in models[:25]:
|
||
short = model_id.split("/")[-1] if "/" in model_id else model_id
|
||
options.append(
|
||
discord.SelectOption(
|
||
label=short[:100],
|
||
value=model_id[:100],
|
||
)
|
||
)
|
||
if not options:
|
||
return
|
||
|
||
select = discord.ui.Select(
|
||
placeholder=f"Choose a model from {provider.get('name', provider_slug)}...",
|
||
options=options,
|
||
custom_id="model_model_select",
|
||
)
|
||
select.callback = self._on_model_selected
|
||
self.add_item(select)
|
||
|
||
back_btn = discord.ui.Button(
|
||
label="◀ Back", style=discord.ButtonStyle.grey, custom_id="model_back"
|
||
)
|
||
back_btn.callback = self._on_back
|
||
self.add_item(back_btn)
|
||
|
||
cancel_btn = discord.ui.Button(
|
||
label="Cancel", style=discord.ButtonStyle.red, custom_id="model_cancel2"
|
||
)
|
||
cancel_btn.callback = self._on_cancel
|
||
self.add_item(cancel_btn)
|
||
|
||
def _build_expensive_confirm(self, model_id: str):
|
||
"""Build confirmation buttons for unusually expensive models."""
|
||
self.clear_items()
|
||
self._pending_expensive_model = model_id
|
||
|
||
confirm_btn = discord.ui.Button(
|
||
label="Switch anyway",
|
||
style=discord.ButtonStyle.red,
|
||
custom_id="model_expensive_confirm",
|
||
)
|
||
confirm_btn.callback = self._on_expensive_confirm
|
||
self.add_item(confirm_btn)
|
||
|
||
cancel_btn = discord.ui.Button(
|
||
label="Cancel",
|
||
style=discord.ButtonStyle.grey,
|
||
custom_id="model_expensive_cancel",
|
||
)
|
||
cancel_btn.callback = self._on_cancel
|
||
self.add_item(cancel_btn)
|
||
|
||
async def _expensive_warning_for(self, model_id: str):
|
||
try:
|
||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||
|
||
# Pricing lookup can hit models.dev / a /models endpoint on a
|
||
# cache miss — keep it off the event loop.
|
||
return await asyncio.to_thread(
|
||
expensive_model_warning,
|
||
model_id,
|
||
provider=self._selected_provider,
|
||
)
|
||
except Exception:
|
||
return None
|
||
|
||
async def _on_provider_selected(self, interaction: discord.Interaction):
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
provider_slug = interaction.data["values"][0]
|
||
self._selected_provider = provider_slug
|
||
provider = next(
|
||
(p for p in self.providers if p["slug"] == provider_slug), None
|
||
)
|
||
pname = provider.get("name", provider_slug) if provider else provider_slug
|
||
|
||
self._build_model_select(provider_slug)
|
||
|
||
total = provider.get("total_models", 0) if provider else 0
|
||
shown = min(len(provider.get("models", [])), 25) if provider else 0
|
||
extra = f"\n*{total - shown} more available — type `/model <name>` directly*" if total > shown else ""
|
||
|
||
await interaction.response.edit_message(
|
||
embed=discord.Embed(
|
||
title="⚙ Model Configuration",
|
||
description=f"Provider: **{pname}**\nSelect a model:{extra}",
|
||
color=discord.Color.blue(),
|
||
),
|
||
view=self,
|
||
)
|
||
|
||
async def _switch_selected_model(
|
||
self,
|
||
interaction: discord.Interaction,
|
||
model_id: str,
|
||
):
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"Already resolved~", ephemeral=True
|
||
)
|
||
return
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
self.resolved = True
|
||
self.clear_items()
|
||
await interaction.response.edit_message(
|
||
embed=discord.Embed(
|
||
title="⚙ Switching Model",
|
||
description=f"Switching to `{model_id}`...",
|
||
color=discord.Color.blue(),
|
||
),
|
||
view=None,
|
||
)
|
||
|
||
try:
|
||
result_text = await self.on_model_selected(
|
||
str(interaction.channel_id),
|
||
model_id,
|
||
self._selected_provider,
|
||
)
|
||
except Exception as exc:
|
||
result_text = f"Error switching model: {exc}"
|
||
|
||
await interaction.edit_original_response(
|
||
embed=discord.Embed(
|
||
title="⚙ Model Switched",
|
||
description=result_text,
|
||
color=discord.Color.green(),
|
||
),
|
||
view=None,
|
||
)
|
||
|
||
async def _on_model_selected(self, interaction: discord.Interaction):
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"Already resolved~", ephemeral=True
|
||
)
|
||
return
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
model_id = interaction.data["values"][0]
|
||
warning = await self._expensive_warning_for(model_id)
|
||
if warning is not None:
|
||
self._build_expensive_confirm(model_id)
|
||
await interaction.response.edit_message(
|
||
embed=discord.Embed(
|
||
title="⚠ Expensive Model Warning",
|
||
description=warning.message,
|
||
color=discord.Color.red(),
|
||
),
|
||
view=self,
|
||
)
|
||
return
|
||
|
||
await self._switch_selected_model(interaction, model_id)
|
||
|
||
async def _on_expensive_confirm(self, interaction: discord.Interaction):
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized~", ephemeral=True
|
||
)
|
||
return
|
||
if not self._pending_expensive_model:
|
||
await interaction.response.send_message(
|
||
"Model selection expired.", ephemeral=True
|
||
)
|
||
return
|
||
await self._switch_selected_model(
|
||
interaction,
|
||
self._pending_expensive_model,
|
||
)
|
||
|
||
async def _on_back(self, interaction: discord.Interaction):
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized~", ephemeral=True
|
||
)
|
||
return
|
||
|
||
self._build_provider_select()
|
||
|
||
try:
|
||
from hermes_cli.providers import get_label
|
||
provider_label = get_label(self.current_provider)
|
||
except Exception:
|
||
provider_label = self.current_provider
|
||
|
||
await interaction.response.edit_message(
|
||
embed=discord.Embed(
|
||
title="⚙ Model Configuration",
|
||
description=(
|
||
f"Current model: `{self.current_model or 'unknown'}`\n"
|
||
f"Provider: {provider_label}\n\n"
|
||
f"Select a provider:"
|
||
),
|
||
color=discord.Color.blue(),
|
||
),
|
||
view=self,
|
||
)
|
||
|
||
async def _on_cancel(self, interaction: discord.Interaction):
|
||
self.resolved = True
|
||
self.clear_items()
|
||
await interaction.response.edit_message(
|
||
embed=discord.Embed(
|
||
title="⚙ Model Configuration",
|
||
description="Model selection cancelled.",
|
||
color=discord.Color.greyple(),
|
||
),
|
||
view=self,
|
||
)
|
||
|
||
async def on_timeout(self):
|
||
self.resolved = True
|
||
self.clear_items()
|
||
# Visually update the Discord message so it appears expired.
|
||
msg = getattr(self, '_message', None)
|
||
if msg:
|
||
try:
|
||
embed = discord.Embed(
|
||
title="⚙ Model Configuration",
|
||
description="⏱ Selection expired — no model change.",
|
||
color=discord.Color.greyple(),
|
||
)
|
||
await msg.edit(embed=embed, view=self)
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
class ClarifyChoiceView(discord.ui.View):
|
||
"""Interactive button view for the clarify tool's multiple-choice prompts.
|
||
|
||
Renders one button per choice (max 24) plus a final ``✏️ Other`` button.
|
||
Picking a numeric choice resolves the gateway clarify entry immediately;
|
||
picking ``Other`` flips the entry into text-capture mode so the next
|
||
user message in the session becomes the response (the gateway's
|
||
text-intercept handles the resolution).
|
||
|
||
Auth gating mirrors ``ExecApprovalView`` — only users/roles in the
|
||
Discord adapter's allowlist may answer. Single-use: after the first
|
||
valid click all buttons disable and the embed updates to show who
|
||
answered and what they chose.
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
choices: List[str],
|
||
clarify_id: str,
|
||
allowed_user_ids: set,
|
||
allowed_role_ids: Optional[set] = None,
|
||
):
|
||
super().__init__(timeout=300) # 5-minute timeout
|
||
self.choices = list(choices)[:24]
|
||
self.clarify_id = clarify_id
|
||
self.allowed_user_ids = allowed_user_ids
|
||
self.allowed_role_ids = allowed_role_ids or set()
|
||
self.resolved = False
|
||
|
||
for index, choice in enumerate(self.choices):
|
||
# Discord button labels are capped at 80 chars. On mobile the
|
||
# visible width is much narrower (often <40 chars before it
|
||
# wraps to 2 lines and the second line gets cut off), so we
|
||
# cap aggressively and cut at a word boundary when possible
|
||
# to keep the trailing text readable.
|
||
#
|
||
# Cut strategy (most-preferred to least-preferred):
|
||
# 1. Last space in the trailing half of the budget
|
||
# (cleanest word boundary)
|
||
# 2. Last soft boundary in the trailing half of the
|
||
# budget (hyphen, comma, period, paren)
|
||
# 3. Hard cut at the budget limit (last resort)
|
||
prefix = f"{index + 1}. "
|
||
budget = 80 - len(prefix)
|
||
if len(choice) <= budget:
|
||
label_body = choice
|
||
else:
|
||
truncated = choice[: budget - 1].rstrip()
|
||
cut_at = -1
|
||
# 1. Last space in the trailing half of the budget.
|
||
space = truncated.rfind(" ")
|
||
if space >= budget // 2:
|
||
cut_at = space
|
||
# 2. Soft boundary — only if no word boundary found.
|
||
# Find the latest soft boundary in the trailing half
|
||
# of the budget; that maximizes preserved text length.
|
||
# Cut AT the soft boundary (inclusive) so the label
|
||
# ends on the soft char (e.g. "-" or ",") rather than
|
||
# on the alpha char that followed it.
|
||
if cut_at < 0:
|
||
latest_soft = max(
|
||
(truncated.rfind(s) for s in ("-", ",", ".", ")")),
|
||
default=-1,
|
||
)
|
||
if latest_soft >= budget // 2:
|
||
cut_at = latest_soft + 1
|
||
if cut_at > 0:
|
||
truncated = truncated[:cut_at]
|
||
label_body = truncated.rstrip() + "…"
|
||
button = discord.ui.Button(
|
||
label=f"{prefix}{label_body}",
|
||
style=discord.ButtonStyle.primary,
|
||
custom_id=f"clarify:{clarify_id}:{index}",
|
||
)
|
||
button.callback = self._make_choice_callback(index, choice)
|
||
self.add_item(button)
|
||
|
||
other_btn = discord.ui.Button(
|
||
label="✏️ Other (type answer)",
|
||
style=discord.ButtonStyle.secondary,
|
||
custom_id=f"clarify:{clarify_id}:other",
|
||
)
|
||
other_btn.callback = self._on_other
|
||
self.add_item(other_btn)
|
||
|
||
def _check_auth(self, interaction: "discord.Interaction") -> bool:
|
||
return _component_check_auth(
|
||
interaction, self.allowed_user_ids, self.allowed_role_ids,
|
||
)
|
||
|
||
def _make_choice_callback(self, index: int, choice: str):
|
||
async def _callback(interaction: "discord.Interaction"):
|
||
await self._resolve_choice(interaction, index, choice)
|
||
return _callback
|
||
|
||
async def _resolve_choice(
|
||
self,
|
||
interaction: "discord.Interaction",
|
||
index: int,
|
||
choice: str,
|
||
) -> None:
|
||
"""Resolve the clarify with a chosen option."""
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"This prompt has already been answered~", ephemeral=True,
|
||
)
|
||
return
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized to answer this prompt~", ephemeral=True,
|
||
)
|
||
return
|
||
|
||
self.resolved = True
|
||
for child in self.children:
|
||
child.disabled = True
|
||
|
||
embed = interaction.message.embeds[0] if (
|
||
interaction.message and interaction.message.embeds
|
||
) else None
|
||
if embed:
|
||
user = getattr(interaction, "user", None)
|
||
display_name = getattr(user, "display_name", "user")
|
||
embed.color = discord.Color.green()
|
||
embed.set_footer(text=f"Answered by {display_name}: {choice}")
|
||
|
||
try:
|
||
await interaction.response.edit_message(embed=embed, view=self)
|
||
except Exception:
|
||
logger.debug(
|
||
"Discord clarify edit_message failed for %s",
|
||
self.clarify_id,
|
||
exc_info=True,
|
||
)
|
||
try:
|
||
await interaction.response.defer()
|
||
except Exception:
|
||
pass
|
||
|
||
# Resolve via the gateway clarify primitive — same mechanism as
|
||
# Telegram. Look up the canonical choice text from the entry so
|
||
# we round-trip the original value, not a button-label variant.
|
||
resolved_text: Optional[str] = None
|
||
try:
|
||
from tools.clarify_gateway import _entries as _clarify_entries # type: ignore
|
||
entry = _clarify_entries.get(self.clarify_id)
|
||
if entry and entry.choices and 0 <= index < len(entry.choices):
|
||
resolved_text = entry.choices[index]
|
||
except Exception:
|
||
resolved_text = None
|
||
if resolved_text is None:
|
||
resolved_text = choice
|
||
|
||
try:
|
||
from tools.clarify_gateway import resolve_gateway_clarify
|
||
resolved = resolve_gateway_clarify(self.clarify_id, resolved_text)
|
||
logger.info(
|
||
"Discord clarify button resolved (id=%s, choice=%r, user=%s, ok=%s)",
|
||
self.clarify_id, resolved_text,
|
||
getattr(getattr(interaction, "user", None), "display_name", "?"),
|
||
resolved,
|
||
)
|
||
except Exception as exc:
|
||
logger.error(
|
||
"Discord clarify resolve_gateway_clarify failed (id=%s): %s",
|
||
self.clarify_id, exc,
|
||
)
|
||
|
||
async def _on_other(self, interaction: "discord.Interaction") -> None:
|
||
"""Flip the clarify entry into text-capture mode."""
|
||
if self.resolved:
|
||
await interaction.response.send_message(
|
||
"This prompt has already been answered~", ephemeral=True,
|
||
)
|
||
return
|
||
if not self._check_auth(interaction):
|
||
await interaction.response.send_message(
|
||
"You're not authorized to answer this prompt~", ephemeral=True,
|
||
)
|
||
return
|
||
|
||
# Don't pop the entry — the gateway's text-intercept needs it
|
||
# until the user actually types. Just mark it as awaiting text
|
||
# and disable the buttons so the user can't double-click.
|
||
try:
|
||
from tools.clarify_gateway import mark_awaiting_text
|
||
mark_awaiting_text(self.clarify_id)
|
||
except Exception as exc:
|
||
logger.warning(
|
||
"Discord clarify mark_awaiting_text failed (id=%s): %s",
|
||
self.clarify_id, exc,
|
||
)
|
||
|
||
self.resolved = True
|
||
for child in self.children:
|
||
child.disabled = True
|
||
|
||
embed = interaction.message.embeds[0] if (
|
||
interaction.message and interaction.message.embeds
|
||
) else None
|
||
if embed:
|
||
user = getattr(interaction, "user", None)
|
||
display_name = getattr(user, "display_name", "user")
|
||
embed.color = discord.Color.blue()
|
||
embed.set_footer(
|
||
text=f"Awaiting typed response from {display_name}…",
|
||
)
|
||
|
||
try:
|
||
await interaction.response.edit_message(embed=embed, view=self)
|
||
except Exception:
|
||
try:
|
||
await interaction.response.defer()
|
||
except Exception:
|
||
pass
|
||
|
||
async def on_timeout(self):
|
||
self.resolved = True
|
||
for child in self.children:
|
||
child.disabled = True
|
||
# Visually update the Discord message so buttons appear disabled.
|
||
msg = getattr(self, '_message', None)
|
||
if msg:
|
||
try:
|
||
embed = msg.embeds[0] if msg.embeds else None
|
||
if embed:
|
||
embed.color = discord.Color.greyple()
|
||
embed.set_footer(text="⏱ Prompt expired — no action taken")
|
||
await msg.edit(embed=embed, view=self)
|
||
except Exception:
|
||
pass
|
||
if DISCORD_AVAILABLE:
|
||
_define_discord_view_classes()
|
||
|
||
|
||
# ── Standalone (out-of-process) sender ────────────────────────────────────────
|
||
# Used by ``tools/send_message_tool._send_via_adapter`` when the gateway runner
|
||
# is not in this process (e.g. ``hermes cron`` running standalone) and no live
|
||
# DiscordAdapter instance is available. Implements the same forum/thread/
|
||
# multipart logic the live adapter would use, via Discord's REST API directly.
|
||
#
|
||
# This block was previously hosted in ``tools/send_message_tool.py`` as
|
||
# ``_send_discord``. It moved into the plugin so all Discord-specific HTTP
|
||
# logic lives next to the adapter — same shape as Teams' ``_standalone_send``.
|
||
|
||
# Process-local cache for Discord channel-type probes. Avoids re-probing the
|
||
# same channel on every send when the directory cache has no entry (e.g. fresh
|
||
# install, or channel created after the last directory build).
|
||
_DISCORD_CHANNEL_TYPE_PROBE_CACHE: Dict[str, bool] = {}
|
||
|
||
|
||
def _remember_channel_is_forum(chat_id: str, is_forum: bool) -> None:
|
||
_DISCORD_CHANNEL_TYPE_PROBE_CACHE[str(chat_id)] = bool(is_forum)
|
||
|
||
|
||
def _probe_is_forum_cached(chat_id: str) -> Optional[bool]:
|
||
return _DISCORD_CHANNEL_TYPE_PROBE_CACHE.get(str(chat_id))
|
||
|
||
|
||
def _derive_forum_thread_name(message: str) -> str:
|
||
"""Derive a thread name from the first line of the message, capped at 100 chars."""
|
||
first_line = message.strip().split("\n", 1)[0].strip()
|
||
# Strip common markdown heading prefixes
|
||
first_line = first_line.lstrip("#").strip()
|
||
if not first_line:
|
||
first_line = "New Post"
|
||
return first_line[:100]
|
||
|
||
|
||
def _standalone_sanitize_error(text) -> str:
|
||
"""Local copy of tools.send_message_tool._sanitize_error_text — strips bot
|
||
tokens from any error payload before bubbling it up. Inlined so the
|
||
plugin doesn't introduce a hard dependency on send_message_tool internals.
|
||
"""
|
||
s = str(text)
|
||
# Mask anything that looks like a Bot token in an Authorization header.
|
||
import re as _re_san
|
||
return _re_san.sub(
|
||
r"(Authorization:\s*Bot\s+)\S+",
|
||
r"\1***",
|
||
s,
|
||
flags=_re_san.IGNORECASE,
|
||
)
|
||
|
||
|
||
async def _standalone_send(
|
||
pconfig,
|
||
chat_id: str,
|
||
message: str,
|
||
*,
|
||
thread_id: Optional[str] = None,
|
||
media_files: Optional[list] = None,
|
||
force_document: bool = False,
|
||
) -> Dict[str, Any]:
|
||
"""Send via Discord REST API without a live gateway adapter.
|
||
|
||
Used by ``tools/send_message_tool._send_via_adapter`` when the gateway
|
||
runner is not in this process. Reads ``DISCORD_BOT_TOKEN`` from
|
||
``pconfig.token`` (set by the gateway config loader from env) and falls
|
||
back to the ``DISCORD_BOT_TOKEN`` env var.
|
||
|
||
Forum channels (type 15) reject ``POST /messages`` — a thread post is
|
||
created automatically via ``POST /channels/{id}/threads``. Media files
|
||
are uploaded as multipart attachments on the starter message of the new
|
||
thread. Channel type is resolved from the channel directory first, then
|
||
a process-local probe cache, and only as a last resort with a live
|
||
``GET /channels/{id}`` probe (whose result is memoized).
|
||
|
||
``force_document`` is accepted for signature parity but unused — Discord
|
||
treats every uploaded file as a generic attachment.
|
||
"""
|
||
try:
|
||
import aiohttp
|
||
except ImportError:
|
||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||
|
||
token = (getattr(pconfig, "token", None) or os.getenv("DISCORD_BOT_TOKEN", "")).strip()
|
||
if not token:
|
||
return {"error": "Discord standalone send: DISCORD_BOT_TOKEN is not set"}
|
||
|
||
try:
|
||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||
auth_headers = {"Authorization": f"Bot {token}"}
|
||
json_headers = {**auth_headers, "Content-Type": "application/json"}
|
||
media_files = media_files or []
|
||
last_data = None
|
||
warnings = []
|
||
|
||
# Thread endpoint: Discord threads are channels; send directly to the thread ID.
|
||
if thread_id:
|
||
url = f"https://discord.com/api/v10/channels/{thread_id}/messages"
|
||
else:
|
||
# Check if the target channel is a forum channel (type 15).
|
||
# Forum channels reject POST /messages — create a thread post instead.
|
||
# Three-layer detection: directory cache → process-local probe
|
||
# cache → GET /channels/{id} probe (with result memoized).
|
||
_channel_type = None
|
||
try:
|
||
from gateway.channel_directory import lookup_channel_type
|
||
_channel_type = lookup_channel_type("discord", chat_id)
|
||
except Exception:
|
||
pass
|
||
|
||
if _channel_type == "forum":
|
||
is_forum = True
|
||
elif _channel_type is not None:
|
||
is_forum = False
|
||
else:
|
||
cached = _probe_is_forum_cached(chat_id)
|
||
if cached is not None:
|
||
is_forum = cached
|
||
else:
|
||
is_forum = False
|
||
try:
|
||
info_url = f"https://discord.com/api/v10/channels/{chat_id}"
|
||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15), **_sess_kw) as info_sess:
|
||
async with info_sess.get(info_url, headers=json_headers, **_req_kw) as info_resp:
|
||
if info_resp.status == 200:
|
||
info = await info_resp.json()
|
||
is_forum = info.get("type") == 15
|
||
_remember_channel_is_forum(chat_id, is_forum)
|
||
except Exception:
|
||
logger.debug("Failed to probe channel type for %s", chat_id, exc_info=True)
|
||
|
||
if is_forum:
|
||
thread_name = _derive_forum_thread_name(message)
|
||
thread_url = f"https://discord.com/api/v10/channels/{chat_id}/threads"
|
||
|
||
# Filter to readable media files up front so we can pick the
|
||
# right code path (JSON vs multipart) before opening a session.
|
||
valid_media = []
|
||
for media_path, _is_voice in media_files:
|
||
if not os.path.exists(media_path):
|
||
warning = f"Media file not found, skipping: {media_path}"
|
||
logger.warning(warning)
|
||
warnings.append(warning)
|
||
continue
|
||
valid_media.append(media_path)
|
||
|
||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60), **_sess_kw) as session:
|
||
if valid_media:
|
||
# Multipart: payload_json + files[N] creates a forum
|
||
# thread with the starter message plus attachments in
|
||
# a single API call.
|
||
attachments_meta = [
|
||
{"id": str(idx), "filename": os.path.basename(path)}
|
||
for idx, path in enumerate(valid_media)
|
||
]
|
||
starter_message = {"content": message, "attachments": attachments_meta}
|
||
payload_json = json.dumps({"name": thread_name, "message": starter_message})
|
||
|
||
form = aiohttp.FormData()
|
||
form.add_field("payload_json", payload_json, content_type="application/json")
|
||
|
||
try:
|
||
for idx, media_path in enumerate(valid_media):
|
||
with open(media_path, "rb") as fh:
|
||
form.add_field(
|
||
f"files[{idx}]",
|
||
fh.read(),
|
||
filename=os.path.basename(media_path),
|
||
)
|
||
async with session.post(thread_url, headers=auth_headers, data=form, **_req_kw) as resp:
|
||
if resp.status not in {200, 201}:
|
||
body = await resp.text()
|
||
return {"error": f"Discord forum thread creation error ({resp.status}): {body}"}
|
||
data = await resp.json()
|
||
except Exception as e:
|
||
return {"error": _standalone_sanitize_error(f"Discord forum thread upload failed: {e}")}
|
||
else:
|
||
# No media — simple JSON POST creates the thread with
|
||
# just the text starter.
|
||
async with session.post(
|
||
thread_url,
|
||
headers=json_headers,
|
||
json={
|
||
"name": thread_name,
|
||
"message": {"content": message},
|
||
},
|
||
**_req_kw,
|
||
) as resp:
|
||
if resp.status not in {200, 201}:
|
||
body = await resp.text()
|
||
return {"error": f"Discord forum thread creation error ({resp.status}): {body}"}
|
||
data = await resp.json()
|
||
|
||
thread_id_created = data.get("id")
|
||
starter_msg_id = (data.get("message") or {}).get("id", thread_id_created)
|
||
result = {
|
||
"success": True,
|
||
"platform": "discord",
|
||
"chat_id": chat_id,
|
||
"thread_id": thread_id_created,
|
||
"message_id": starter_msg_id,
|
||
}
|
||
if warnings:
|
||
result["warnings"] = warnings
|
||
return result
|
||
|
||
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
|
||
|
||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session:
|
||
# Send text message (skip if empty and media is present)
|
||
if message.strip() or not media_files:
|
||
async with session.post(url, headers=json_headers, json={"content": message}, **_req_kw) as resp:
|
||
if resp.status not in {200, 201}:
|
||
body = await resp.text()
|
||
return {"error": f"Discord API error ({resp.status}): {body}"}
|
||
last_data = await resp.json()
|
||
|
||
# Send each media file as a separate multipart upload
|
||
for media_path, _is_voice in media_files:
|
||
if not os.path.exists(media_path):
|
||
warning = f"Media file not found, skipping: {media_path}"
|
||
logger.warning(warning)
|
||
warnings.append(warning)
|
||
continue
|
||
try:
|
||
form = aiohttp.FormData()
|
||
filename = os.path.basename(media_path)
|
||
with open(media_path, "rb") as f:
|
||
form.add_field("files[0]", f, filename=filename)
|
||
async with session.post(url, headers=auth_headers, data=form, **_req_kw) as resp:
|
||
if resp.status not in {200, 201}:
|
||
body = await resp.text()
|
||
warning = _standalone_sanitize_error(f"Failed to send media {media_path}: Discord API error ({resp.status}): {body}")
|
||
logger.error(warning)
|
||
warnings.append(warning)
|
||
continue
|
||
last_data = await resp.json()
|
||
except Exception as e:
|
||
warning = _standalone_sanitize_error(f"Failed to send media {media_path}: {e}")
|
||
logger.error(warning)
|
||
warnings.append(warning)
|
||
|
||
if last_data is None:
|
||
error = "No deliverable text or media remained after processing"
|
||
if warnings:
|
||
return {"error": error, "warnings": warnings}
|
||
return {"error": error}
|
||
|
||
result = {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": last_data.get("id")}
|
||
if warnings:
|
||
result["warnings"] = warnings
|
||
return result
|
||
except Exception as e:
|
||
return {"error": _standalone_sanitize_error(f"Discord send failed: {e}")}
|
||
|
||
|
||
# ── Plugin entry point ────────────────────────────────────────────────────────
|
||
|
||
|
||
def _clean_discord_user_ids(raw: str) -> list:
|
||
"""Strip common Discord mention prefixes from a comma-separated ID string."""
|
||
cleaned = []
|
||
for uid in raw.replace(" ", "").split(","):
|
||
uid = uid.strip()
|
||
if uid.startswith("<@") and uid.endswith(">"):
|
||
uid = uid.lstrip("<@!").rstrip(">")
|
||
if uid.lower().startswith("user:"):
|
||
uid = uid[5:]
|
||
if uid:
|
||
cleaned.append(uid)
|
||
return cleaned
|
||
|
||
|
||
def interactive_setup() -> None:
|
||
"""Guide the user through Discord bot setup.
|
||
|
||
Mirrors Teams' ``interactive_setup`` shape: lazy-imports CLI helpers so
|
||
the plugin's import surface stays small, prompts for the bot token,
|
||
captures an allowlist, and offers to set a home channel.
|
||
"""
|
||
from hermes_cli.config import get_env_value, save_env_value
|
||
from hermes_cli.cli_output import (
|
||
prompt,
|
||
prompt_yes_no,
|
||
print_header,
|
||
print_info,
|
||
print_success,
|
||
)
|
||
|
||
print_header("Discord")
|
||
existing = get_env_value("DISCORD_BOT_TOKEN")
|
||
if existing:
|
||
print_info("Discord: already configured")
|
||
if not prompt_yes_no("Reconfigure Discord?", False):
|
||
if not get_env_value("DISCORD_ALLOWED_USERS"):
|
||
print_info("⚠️ Discord has no user allowlist - anyone can use your bot!")
|
||
if prompt_yes_no("Add allowed users now?", True):
|
||
print_info(" To find Discord ID: Enable Developer Mode, right-click name → Copy ID")
|
||
allowed_users = prompt("Allowed user IDs (comma-separated)")
|
||
if allowed_users:
|
||
cleaned_ids = _clean_discord_user_ids(allowed_users)
|
||
save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids))
|
||
print_success("Discord allowlist configured")
|
||
return
|
||
|
||
print_info("Create a bot at https://discord.com/developers/applications")
|
||
token = prompt("Discord bot token", password=True)
|
||
if not token:
|
||
return
|
||
save_env_value("DISCORD_BOT_TOKEN", token)
|
||
print_success("Discord token saved")
|
||
|
||
print()
|
||
print_info("🔒 Security: Restrict who can use your bot")
|
||
print_info(" To find your Discord user ID:")
|
||
print_info(" 1. Enable Developer Mode in Discord settings")
|
||
print_info(" 2. Right-click your name → Copy ID")
|
||
print()
|
||
print_info(" You can also use Discord usernames (resolved on gateway start).")
|
||
print()
|
||
allowed_users = prompt(
|
||
"Allowed user IDs or usernames (comma-separated, leave empty for open access)"
|
||
)
|
||
if allowed_users:
|
||
cleaned_ids = _clean_discord_user_ids(allowed_users)
|
||
save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids))
|
||
print_success("Discord allowlist configured")
|
||
else:
|
||
print_info("⚠️ No allowlist set - anyone in servers with your bot can use it!")
|
||
|
||
print()
|
||
print_info("📬 Home Channel: where Hermes delivers cron job results,")
|
||
print_info(" cross-platform messages, and notifications.")
|
||
print_info(" To get a channel ID: right-click a channel → Copy Channel ID")
|
||
print_info(" (requires Developer Mode in Discord settings)")
|
||
print_info(" You can also set this later by typing /set-home in a Discord channel.")
|
||
home_channel = prompt("Home channel ID (leave empty to set later with /set-home)")
|
||
if home_channel:
|
||
save_env_value("DISCORD_HOME_CHANNEL", home_channel)
|
||
|
||
|
||
def _apply_yaml_config(yaml_cfg: dict, discord_cfg: dict) -> dict | None:
|
||
"""Translate ``config.yaml`` ``discord:`` keys into env vars.
|
||
|
||
Implements the ``apply_yaml_config_fn`` contract (#24836). Mirrors the
|
||
legacy ``discord_cfg`` block that used to live in
|
||
``gateway/config.py::load_gateway_config()`` before this migration.
|
||
|
||
The DiscordAdapter reads its runtime configuration via ``os.getenv()``
|
||
throughout the connect / handle code paths (``DISCORD_ALLOWED_USERS``,
|
||
``DISCORD_REQUIRE_MENTION``, ``DISCORD_FREE_RESPONSE_CHANNELS``,
|
||
``DISCORD_AUTO_THREAD``, ``DISCORD_REACTIONS``,
|
||
``DISCORD_IGNORED_CHANNELS``, ``DISCORD_ALLOWED_CHANNELS``,
|
||
``DISCORD_NO_THREAD_CHANNELS``, ``DISCORD_HISTORY_BACKFILL``,
|
||
``DISCORD_HISTORY_BACKFILL_LIMIT``, ``DISCORD_ALLOW_MENTION_*``,
|
||
``DISCORD_REPLY_TO_MODE``, ``DISCORD_THREAD_REQUIRE_MENTION``).
|
||
Rather than rewrite ~50 call sites inside the adapter to read from
|
||
``PlatformConfig.extra`` instead, this hook keeps the existing
|
||
env-driven model and merely owns the YAML→env translation here, next to
|
||
the adapter that consumes it.
|
||
|
||
Env vars take precedence over YAML — every assignment is guarded by
|
||
``not os.getenv(...)`` so explicit env vars survive a config.yaml
|
||
update. Returns ``None`` because no extras are seeded into
|
||
``PlatformConfig.extra`` directly (everything flows through env).
|
||
"""
|
||
if "require_mention" in discord_cfg and not os.getenv("DISCORD_REQUIRE_MENTION"):
|
||
os.environ["DISCORD_REQUIRE_MENTION"] = str(discord_cfg["require_mention"]).lower()
|
||
if "thread_require_mention" in discord_cfg and not os.getenv("DISCORD_THREAD_REQUIRE_MENTION"):
|
||
os.environ["DISCORD_THREAD_REQUIRE_MENTION"] = str(discord_cfg["thread_require_mention"]).lower()
|
||
platforms_cfg = yaml_cfg.get("platforms")
|
||
platform_extra_cfg = {}
|
||
if isinstance(platforms_cfg, dict):
|
||
discord_platform_cfg = platforms_cfg.get("discord")
|
||
if isinstance(discord_platform_cfg, dict):
|
||
candidate_extra = discord_platform_cfg.get("extra")
|
||
if isinstance(candidate_extra, dict):
|
||
platform_extra_cfg = candidate_extra
|
||
allowed_users_cfg = (
|
||
discord_cfg["allow_from"] if "allow_from" in discord_cfg
|
||
else platform_extra_cfg.get("allow_from")
|
||
)
|
||
if allowed_users_cfg is not None and not os.getenv("DISCORD_ALLOWED_USERS"):
|
||
if isinstance(allowed_users_cfg, list):
|
||
allowed_users_cfg = ",".join(str(v) for v in allowed_users_cfg)
|
||
os.environ["DISCORD_ALLOWED_USERS"] = str(allowed_users_cfg)
|
||
frc = discord_cfg.get("free_response_channels")
|
||
if frc is not None and not os.getenv("DISCORD_FREE_RESPONSE_CHANNELS"):
|
||
if isinstance(frc, list):
|
||
frc = ",".join(str(v) for v in frc)
|
||
os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||
if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"):
|
||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||
if "reactions" in discord_cfg and not os.getenv("DISCORD_REACTIONS"):
|
||
os.environ["DISCORD_REACTIONS"] = str(discord_cfg["reactions"]).lower()
|
||
# ignored_channels: channels where bot never responds (even when mentioned)
|
||
ic = discord_cfg.get("ignored_channels")
|
||
if ic is not None and not os.getenv("DISCORD_IGNORED_CHANNELS"):
|
||
if isinstance(ic, list):
|
||
ic = ",".join(str(v) for v in ic)
|
||
os.environ["DISCORD_IGNORED_CHANNELS"] = str(ic)
|
||
# allowed_channels: if set, bot ONLY responds in these channels (whitelist)
|
||
ac = discord_cfg.get("allowed_channels")
|
||
if ac is not None and not os.getenv("DISCORD_ALLOWED_CHANNELS"):
|
||
if isinstance(ac, list):
|
||
ac = ",".join(str(v) for v in ac)
|
||
os.environ["DISCORD_ALLOWED_CHANNELS"] = str(ac)
|
||
# no_thread_channels: channels where bot responds directly without creating thread
|
||
ntc = discord_cfg.get("no_thread_channels")
|
||
if ntc is not None and not os.getenv("DISCORD_NO_THREAD_CHANNELS"):
|
||
if isinstance(ntc, list):
|
||
ntc = ",".join(str(v) for v in ntc)
|
||
os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc)
|
||
# history_backfill: recover missed channel messages for shared sessions
|
||
# when require_mention is active. Fetches messages between bot turns
|
||
# and prepends them to the user message for context.
|
||
if "history_backfill" in discord_cfg and not os.getenv("DISCORD_HISTORY_BACKFILL"):
|
||
os.environ["DISCORD_HISTORY_BACKFILL"] = str(discord_cfg["history_backfill"]).lower()
|
||
hbl = discord_cfg.get("history_backfill_limit")
|
||
if hbl is not None and not os.getenv("DISCORD_HISTORY_BACKFILL_LIMIT"):
|
||
os.environ["DISCORD_HISTORY_BACKFILL_LIMIT"] = str(hbl)
|
||
# allow_mentions: granular control over what the bot can ping.
|
||
# Safe defaults (no @everyone/roles) are applied in the adapter;
|
||
# these YAML keys only override when set and let users opt back
|
||
# into unsafe modes (e.g. roles=true) if they actually want it.
|
||
allow_mentions_cfg = discord_cfg.get("allow_mentions")
|
||
if isinstance(allow_mentions_cfg, dict):
|
||
for yaml_key, env_key in (
|
||
("everyone", "DISCORD_ALLOW_MENTION_EVERYONE"),
|
||
("roles", "DISCORD_ALLOW_MENTION_ROLES"),
|
||
("users", "DISCORD_ALLOW_MENTION_USERS"),
|
||
("replied_user", "DISCORD_ALLOW_MENTION_REPLIED_USER"),
|
||
):
|
||
if yaml_key in allow_mentions_cfg and not os.getenv(env_key):
|
||
os.environ[env_key] = str(allow_mentions_cfg[yaml_key]).lower()
|
||
# reply_to_mode: top-level preferred, falls back to extra.reply_to_mode.
|
||
# YAML 1.1 parses bare 'off' as boolean False — coerce to string "off".
|
||
_discord_extra = discord_cfg.get("extra") if isinstance(discord_cfg.get("extra"), dict) else {}
|
||
_discord_rtm = (
|
||
discord_cfg["reply_to_mode"] if "reply_to_mode" in discord_cfg
|
||
else _discord_extra.get("reply_to_mode")
|
||
)
|
||
if _discord_rtm is not None and not os.getenv("DISCORD_REPLY_TO_MODE"):
|
||
_rtm_str = "off" if _discord_rtm is False else str(_discord_rtm).lower()
|
||
os.environ["DISCORD_REPLY_TO_MODE"] = _rtm_str
|
||
return None # all settings flow through env; nothing to merge into extras
|
||
|
||
|
||
def _is_connected(config) -> bool:
|
||
"""Discord is considered connected when DISCORD_BOT_TOKEN is set.
|
||
|
||
Looks up via ``hermes_cli.gateway.get_env_value`` at call time (not via
|
||
the plugin's own bound import) so tests that patch ``gateway_mod.get_env_value``
|
||
— including ``test_setup_openclaw_migration`` — can suppress ambient
|
||
``DISCORD_BOT_TOKEN`` env vars. Matches what the legacy
|
||
``_PLATFORMS["discord"]`` dispatch did before this migration.
|
||
"""
|
||
import hermes_cli.gateway as gateway_mod
|
||
return bool((gateway_mod.get_env_value("DISCORD_BOT_TOKEN") or "").strip())
|
||
|
||
|
||
def _build_adapter(config):
|
||
"""Factory wrapper that constructs DiscordAdapter from a PlatformConfig."""
|
||
return DiscordAdapter(config)
|
||
|
||
|
||
def register(ctx) -> None:
|
||
"""Plugin entry point — called by the Hermes plugin system."""
|
||
ctx.register_platform(
|
||
name="discord",
|
||
label="Discord",
|
||
adapter_factory=_build_adapter,
|
||
check_fn=check_discord_requirements,
|
||
is_connected=_is_connected,
|
||
required_env=["DISCORD_BOT_TOKEN"],
|
||
install_hint="pip install 'hermes-agent[messaging]'",
|
||
# Interactive setup wizard — replaces the central
|
||
# hermes_cli/setup.py::_setup_discord function. Same shape as Teams.
|
||
setup_fn=interactive_setup,
|
||
# YAML→env config bridge — owns the translation of ``config.yaml``
|
||
# ``discord:`` keys (require_mention, free_response_channels,
|
||
# auto_thread, reactions, ignored_channels, allowed_channels,
|
||
# no_thread_channels, allow_mentions.*, reply_to_mode,
|
||
# thread_require_mention) into ``DISCORD_*`` env vars that the
|
||
# adapter reads via ``os.getenv()``. Replaces the hardcoded block
|
||
# that used to live in ``gateway/config.py``. Hook contract: #24836.
|
||
apply_yaml_config_fn=_apply_yaml_config,
|
||
# Auth env vars for _is_user_authorized() integration
|
||
allowed_users_env="DISCORD_ALLOWED_USERS",
|
||
allow_all_env="DISCORD_ALLOW_ALL_USERS",
|
||
# Cron home-channel delivery
|
||
cron_deliver_env_var="DISCORD_HOME_CHANNEL",
|
||
# Out-of-process cron delivery via Discord REST API. Without this
|
||
# hook, ``deliver=discord`` cron jobs fail with "No live adapter"
|
||
# when cron runs separately from the gateway. Mirrors Teams pattern.
|
||
standalone_sender_fn=_standalone_send,
|
||
# Discord hard limit per message
|
||
max_message_length=2000,
|
||
# Display
|
||
emoji="🎮",
|
||
allow_update_command=True,
|
||
)
|