refactor(matrix): simplify adapter after code review

- Extract _resolve_message_context() to deduplicate ~40 lines of
  mention/thread/DM gating logic between text and media handlers
- Move mautrix.types imports to module level (16 scattered local
  imports consolidated)
- Parse mention/thread env vars once in __init__ instead of per-message
- Cache _is_bot_mentioned() result instead of calling 3x per event
- Consolidate send_emote/send_notice into shared _send_simple_message()
- Use _is_dm_room() in get_chat_info() instead of inline duplication
- Add _CRYPTO_PICKLE_PATH constant (was duplicated in 2 locations)
- Fix fragile event_ts extraction (double getattr, None safety)
- Clean up leaked aiohttp session on auth failure paths
- Remove redundant trailing _track_thread() calls
This commit is contained in:
alt-glitch 2026-04-11 07:38:50 +05:30 committed by Teknium
parent 1f3f120042
commit bc8b93812c

View file

@ -35,6 +35,19 @@ from typing import Any, Dict, Optional, Set
from html import escape as _html_escape
from mautrix.types import (
ContentURI,
EventID,
EventType,
PaginationDirection,
PresenceState,
RoomCreatePreset,
RoomID,
SyncToken,
TrustState,
UserID,
)
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
@ -54,6 +67,7 @@ MAX_MESSAGE_LENGTH = 4000
# Uses get_hermes_home() so each profile gets its own Matrix store.
from hermes_constants import get_hermes_dir as _get_hermes_dir
_STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store")
_CRYPTO_PICKLE_PATH = _STORE_DIR / "crypto_store.pickle"
# Grace period: ignore messages older than this many seconds before startup.
_STARTUP_GRACE_SECONDS = 5
@ -169,12 +183,17 @@ class MatrixAdapter(BasePlatformAdapter):
self._bot_participated_threads: set = self._load_participated_threads()
self._MAX_TRACKED_THREADS = 500
# Mention/thread gating — parsed once from env vars.
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "")
self._free_rooms: Set[str] = {r.strip() for r in free_rooms_raw.split(",") if r.strip()}
self._auto_thread: bool = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
self._dm_mention_threads: bool = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
# Reactions: configurable via MATRIX_REACTIONS (default: true).
self._reactions_enabled: bool = os.getenv(
"MATRIX_REACTIONS", "true"
).lower() not in ("false", "0", "no")
# Tracks the reaction event_id for in-progress (eyes) reactions.
# Key: (room_id, message_event_id) → reaction_event_id (for the eyes reaction).
self._pending_reactions: dict[tuple[str, str], str] = {}
# Text batching: merge rapid successive messages (Telegram-style).
@ -206,7 +225,6 @@ class MatrixAdapter(BasePlatformAdapter):
from mautrix.api import HTTPAPI
from mautrix.client import Client
from mautrix.client.state_store import MemoryStateStore, MemorySyncStore
from mautrix.types import EventType, UserID
if not self._homeserver:
logger.error("Matrix: homeserver URL not configured")
@ -262,6 +280,7 @@ class MatrixAdapter(BasePlatformAdapter):
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER: %s",
exc,
)
await api.session.close()
return False
elif self._password and self._user_id:
try:
@ -271,15 +290,16 @@ class MatrixAdapter(BasePlatformAdapter):
device_name="Hermes Agent",
device_id=self._device_id or None,
)
# login() stores the token automatically.
if resp and hasattr(resp, "device_id"):
client.device_id = resp.device_id
logger.info("Matrix: logged in as %s", self._user_id)
except Exception as exc:
logger.error("Matrix: login failed — %s", exc)
await api.session.close()
return False
else:
logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD")
await api.session.close()
return False
# Set up E2EE if requested.
@ -298,7 +318,7 @@ class MatrixAdapter(BasePlatformAdapter):
crypto_store = MemoryCryptoStore()
# Restore persisted crypto state from a previous run.
pickle_path = _STORE_DIR / "crypto_store.pickle"
pickle_path = _CRYPTO_PICKLE_PATH
if pickle_path.exists():
try:
import pickle
@ -314,7 +334,6 @@ class MatrixAdapter(BasePlatformAdapter):
# Set trust policy: accept unverified devices so senders
# share Megolm session keys with us automatically.
from mautrix.types import TrustState
olm.share_keys_min_trust = TrustState.UNVERIFIED
olm.send_keys_min_trust = TrustState.UNVERIFIED
@ -392,7 +411,7 @@ class MatrixAdapter(BasePlatformAdapter):
import pickle
crypto_store = self._client.crypto.crypto_store
_STORE_DIR.mkdir(parents=True, exist_ok=True)
pickle_path = _STORE_DIR / "crypto_store.pickle"
pickle_path = _CRYPTO_PICKLE_PATH
with open(pickle_path, "wb") as f:
pickle.dump(crypto_store, f)
logger.info("Matrix: persisted E2EE crypto store to %s", pickle_path)
@ -416,7 +435,6 @@ class MatrixAdapter(BasePlatformAdapter):
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send a message to a Matrix room."""
from mautrix.types import EventType, RoomID
if not content:
return SendResult(success=True)
@ -492,30 +510,12 @@ class MatrixAdapter(BasePlatformAdapter):
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
"""Return room name and type (dm/group)."""
name = chat_id
chat_type = "group"
chat_type = "dm" if await self._is_dm_room(chat_id) else "group"
if self._client:
# Try state store for member count.
state_store = getattr(self._client, "state_store", None)
if state_store:
try:
members = await state_store.get_members(
chat_id,
)
if members and len(members) == 2:
chat_type = "dm"
except Exception:
pass
# Use DM cache.
if self._dm_rooms.get(chat_id, False):
chat_type = "dm"
# Try to get room name from state.
try:
from mautrix.types import EventType as ET, RoomID
name_evt = await self._client.get_state_event(
RoomID(chat_id), ET.ROOM_NAME,
RoomID(chat_id), EventType.ROOM_NAME,
)
if name_evt and hasattr(name_evt, "name") and name_evt.name:
name = name_evt.name
@ -534,7 +534,6 @@ class MatrixAdapter(BasePlatformAdapter):
"""Send a typing indicator."""
if self._client:
try:
from mautrix.types import RoomID
await self._client.set_typing(RoomID(chat_id), timeout=30000)
except Exception:
pass
@ -543,7 +542,6 @@ class MatrixAdapter(BasePlatformAdapter):
self, chat_id: str, message_id: str, content: str
) -> SendResult:
"""Edit an existing message (via m.replace)."""
from mautrix.types import EventType, RoomID
formatted = self.format_message(content)
msg_content: Dict[str, Any] = {
@ -683,7 +681,6 @@ class MatrixAdapter(BasePlatformAdapter):
is_voice: bool = False,
) -> SendResult:
"""Upload bytes to Matrix and send as a media message."""
from mautrix.types import EventType, RoomID
# Upload to homeserver.
try:
@ -866,10 +863,8 @@ class MatrixAdapter(BasePlatformAdapter):
return
# Startup grace: ignore old messages from initial sync.
event_ts = getattr(event, "timestamp", 0) / 1000.0 if getattr(event, "timestamp", 0) else 0
# Also check server_timestamp for compatibility.
if not event_ts:
event_ts = getattr(event, "server_timestamp", 0) / 1000.0 if getattr(event, "server_timestamp", 0) else 0
raw_ts = getattr(event, "timestamp", None) or getattr(event, "server_timestamp", None) or 0
event_ts = raw_ts / 1000.0 if raw_ts else 0.0
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
return
@ -907,6 +902,68 @@ class MatrixAdapter(BasePlatformAdapter):
elif msgtype in ("m.text", "m.notice"):
await self._handle_text_message(room_id, sender, event_id, event_ts, source_content, relates_to)
async def _resolve_message_context(
self,
room_id: str,
sender: str,
event_id: str,
body: str,
source_content: dict,
relates_to: dict,
) -> Optional[tuple]:
"""Shared mention/thread/DM gating for text and media handlers.
Returns (body, is_dm, chat_type, thread_id, display_name, source)
or None if the message should be dropped (mention gating).
"""
is_dm = await self._is_dm_room(room_id)
chat_type = "dm" if is_dm else "group"
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
formatted_body = source_content.get("formatted_body")
is_mentioned = self._is_bot_mentioned(body, formatted_body)
# Require-mention gating.
if not is_dm:
is_free_room = room_id in self._free_rooms
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
if self._require_mention and not is_free_room and not in_bot_thread:
if not is_mentioned:
return None
# DM mention-thread.
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
thread_id = event_id
self._track_thread(thread_id)
# Strip mention from body.
if is_mentioned:
body = self._strip_mention(body)
# Auto-thread.
if not is_dm and not thread_id and self._auto_thread:
thread_id = event_id
self._track_thread(thread_id)
display_name = await self._get_display_name(room_id, sender)
source = self.build_source(
chat_id=room_id,
chat_type=chat_type,
user_id=sender,
user_name=display_name,
thread_id=thread_id,
)
if thread_id:
self._track_thread(thread_id)
self._background_read_receipt(room_id, event_id)
return body, is_dm, chat_type, thread_id, display_name, source
async def _handle_text_message(
self,
room_id: str,
@ -921,45 +978,12 @@ class MatrixAdapter(BasePlatformAdapter):
if not body:
return
# Determine chat type.
is_dm = await self._is_dm_room(room_id)
chat_type = "dm" if is_dm else "group"
# Thread support.
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
# Require-mention gating.
if not is_dm:
free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "")
free_rooms = {r.strip() for r in free_rooms_raw.split(",") if r.strip()}
require_mention = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
is_free_room = room_id in free_rooms
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
formatted_body = source_content.get("formatted_body")
if require_mention and not is_free_room and not in_bot_thread:
if not self._is_bot_mentioned(body, formatted_body):
return
# DM mention-thread.
if is_dm and not thread_id:
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
thread_id = event_id
self._track_thread(thread_id)
# Strip mention from body.
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
body = self._strip_mention(body)
# Auto-thread.
if not is_dm and not thread_id:
auto_thread = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
if auto_thread:
thread_id = event_id
self._track_thread(thread_id)
ctx = await self._resolve_message_context(
room_id, sender, event_id, body, source_content, relates_to,
)
if ctx is None:
return
body, is_dm, chat_type, thread_id, display_name, source = ctx
# Reply-to detection.
reply_to = None
@ -983,20 +1007,10 @@ class MatrixAdapter(BasePlatformAdapter):
stripped.append(line)
body = "\n".join(stripped) if stripped else body
# Message type.
msg_type = MessageType.TEXT
if body.startswith(("!", "/")):
msg_type = MessageType.COMMAND
display_name = await self._get_display_name(room_id, sender)
source = self.build_source(
chat_id=room_id,
chat_type=chat_type,
user_id=sender,
user_name=display_name,
thread_id=thread_id,
)
msg_event = MessageEvent(
text=body,
message_type=msg_type,
@ -1006,13 +1020,6 @@ class MatrixAdapter(BasePlatformAdapter):
reply_to_message_id=reply_to,
)
if thread_id:
self._track_thread(thread_id)
# Acknowledge receipt (fire-and-forget).
self._background_read_receipt(room_id, event_id)
# Batch plain text messages — commands dispatch immediately.
if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0:
self._enqueue_text_event(msg_event)
else:
@ -1079,7 +1086,6 @@ class MatrixAdapter(BasePlatformAdapter):
)
if should_cache_locally and url:
try:
from mautrix.types import ContentURI
file_bytes = await self._client.download_media(ContentURI(url))
if file_bytes is not None:
if is_encrypted_media:
@ -1131,53 +1137,12 @@ class MatrixAdapter(BasePlatformAdapter):
except Exception as e:
logger.warning("[Matrix] Failed to cache media: %s", e)
is_dm = await self._is_dm_room(room_id)
chat_type = "dm" if is_dm else "group"
# Thread/reply detection.
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
# Require-mention gating (media messages).
if not is_dm:
free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "")
free_rooms = {r.strip() for r in free_rooms_raw.split(",") if r.strip()}
require_mention = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
is_free_room = room_id in free_rooms
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
if require_mention and not is_free_room and not in_bot_thread:
formatted_body = source_content.get("formatted_body")
if not self._is_bot_mentioned(body, formatted_body):
return
# DM mention-thread.
if is_dm and not thread_id:
dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes")
if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")):
thread_id = event_id
self._track_thread(thread_id)
# Strip mention from body.
if self._is_bot_mentioned(body, source_content.get("formatted_body")):
body = self._strip_mention(body)
# Auto-thread.
if not is_dm and not thread_id:
auto_thread = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
if auto_thread:
thread_id = event_id
self._track_thread(thread_id)
display_name = await self._get_display_name(room_id, sender)
source = self.build_source(
chat_id=room_id,
chat_type=chat_type,
user_id=sender,
user_name=display_name,
thread_id=thread_id,
ctx = await self._resolve_message_context(
room_id, sender, event_id, body, source_content, relates_to,
)
if ctx is None:
return
body, is_dm, chat_type, thread_id, display_name, source = ctx
allow_http_fallback = bool(http_url) and not is_encrypted_media
media_urls = [cached_path] if cached_path else ([http_url] if allow_http_fallback else None)
@ -1193,11 +1158,6 @@ class MatrixAdapter(BasePlatformAdapter):
media_types=media_types,
)
if thread_id:
self._track_thread(thread_id)
self._background_read_receipt(room_id, event_id)
await self.handle_message(msg_event)
async def _on_encrypted_event(self, event: Any) -> None:
@ -1219,7 +1179,6 @@ class MatrixAdapter(BasePlatformAdapter):
async def _on_invite(self, event: Any) -> None:
"""Auto-join rooms when invited."""
from mautrix.types import RoomID
room_id = str(getattr(event, "room_id", ""))
@ -1245,7 +1204,6 @@ class MatrixAdapter(BasePlatformAdapter):
"""Send an emoji reaction to a message in a room.
Returns the reaction event_id on success, None on failure.
"""
from mautrix.types import EventType, RoomID
if not self._client:
return None
@ -1409,7 +1367,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._client:
return False
try:
from mautrix.types import EventID, RoomID
await self._client.set_read_markers(
RoomID(room_id),
fully_read_event=EventID(event_id),
@ -1432,7 +1389,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._client:
return False
try:
from mautrix.types import EventID, RoomID
await self._client.redact(
RoomID(room_id), EventID(event_id), reason=reason or None,
)
@ -1456,7 +1412,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._client:
return []
try:
from mautrix.types import PaginationDirection, RoomID, SyncToken
resp = await self._client.get_messages(
RoomID(room_id),
direction=PaginationDirection.BACKWARD,
@ -1505,7 +1460,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._client:
return None
try:
from mautrix.types import RoomCreatePreset, UserID
preset_enum = {
"private_chat": RoomCreatePreset.PRIVATE,
"public_chat": RoomCreatePreset.PUBLIC,
@ -1532,7 +1486,6 @@ class MatrixAdapter(BasePlatformAdapter):
if not self._client:
return False
try:
from mautrix.types import RoomID, UserID
await self._client.invite_user(RoomID(room_id), UserID(user_id))
logger.info("Matrix: invited %s to %s", user_id, room_id)
return True
@ -1554,7 +1507,6 @@ class MatrixAdapter(BasePlatformAdapter):
logger.warning("Matrix: invalid presence state %r", state)
return False
try:
from mautrix.types import PresenceState
presence_map = {
"online": PresenceState.ONLINE,
"offline": PresenceState.OFFLINE,
@ -1574,19 +1526,14 @@ class MatrixAdapter(BasePlatformAdapter):
# Emote & notice message types
# ------------------------------------------------------------------
async def send_emote(
self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None,
async def _send_simple_message(
self, chat_id: str, text: str, msgtype: str,
) -> SendResult:
"""Send an emote message (/me style action)."""
from mautrix.types import EventType, RoomID
"""Send a simple message (emote, notice) with optional HTML formatting."""
if not self._client or not text:
return SendResult(success=False, error="No client or empty text")
msg_content: Dict[str, Any] = {
"msgtype": "m.emote",
"body": text,
}
msg_content: Dict[str, Any] = {"msgtype": msgtype, "body": text}
html = self._markdown_to_html(text)
if html and html != text:
msg_content["format"] = "org.matrix.custom.html"
@ -1600,31 +1547,17 @@ class MatrixAdapter(BasePlatformAdapter):
except Exception as exc:
return SendResult(success=False, error=str(exc))
async def send_emote(
self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send an emote message (/me style action)."""
return await self._send_simple_message(chat_id, text, "m.emote")
async def send_notice(
self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send a notice message (bot-appropriate, non-alerting)."""
from mautrix.types import EventType, RoomID
if not self._client or not text:
return SendResult(success=False, error="No client or empty text")
msg_content: Dict[str, Any] = {
"msgtype": "m.notice",
"body": text,
}
html = self._markdown_to_html(text)
if html and html != text:
msg_content["format"] = "org.matrix.custom.html"
msg_content["formatted_body"] = html
try:
event_id = await self._client.send_message_event(
RoomID(chat_id), EventType.ROOM_MESSAGE, msg_content,
)
return SendResult(success=True, message_id=str(event_id))
except Exception as exc:
return SendResult(success=False, error=str(exc))
return await self._send_simple_message(chat_id, text, "m.notice")
# ------------------------------------------------------------------
# Helpers