mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-17 04:31:55 +00:00
feat: devex help, add Makefile, ruff, pre-commit, and modernize CI
This commit is contained in:
parent
172a38c344
commit
f4d7e6a29e
111 changed files with 11655 additions and 10200 deletions
|
|
@ -303,8 +303,8 @@ Optional but valuable:
|
|||
After implementing everything, verify with:
|
||||
|
||||
```bash
|
||||
# All tests pass
|
||||
python -m pytest tests/ -q
|
||||
# All checks pass (lint + test)
|
||||
make check
|
||||
|
||||
# Grep for your platform name to find any missed integration points
|
||||
grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \
|
||||
|
|
|
|||
|
|
@ -13,20 +13,20 @@ import uuid
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from pathlib import Path as _Path
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
|
|
@ -251,6 +251,7 @@ def cleanup_document_cache(max_age_hours: int = 24) -> int:
|
|||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
|
||||
TEXT = "text"
|
||||
LOCATION = "location"
|
||||
PHOTO = "photo"
|
||||
|
|
@ -266,42 +267,43 @@ class MessageType(Enum):
|
|||
class MessageEvent:
|
||||
"""
|
||||
Incoming message from a platform.
|
||||
|
||||
|
||||
Normalized representation that all adapters produce.
|
||||
"""
|
||||
|
||||
# Message content
|
||||
text: str
|
||||
message_type: MessageType = MessageType.TEXT
|
||||
|
||||
|
||||
# Source information
|
||||
source: SessionSource = None
|
||||
|
||||
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: Optional[str] = None
|
||||
|
||||
message_id: str | None = None
|
||||
|
||||
# Media attachments
|
||||
media_urls: List[str] = field(default_factory=list)
|
||||
media_types: List[str] = field(default_factory=list)
|
||||
|
||||
media_urls: list[str] = field(default_factory=list)
|
||||
media_types: list[str] = field(default_factory=list)
|
||||
|
||||
# Reply context
|
||||
reply_to_message_id: Optional[str] = None
|
||||
|
||||
reply_to_message_id: str | None = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
def is_command(self) -> bool:
|
||||
"""Check if this is a command message (e.g., /new, /reset)."""
|
||||
return self.text.startswith("/")
|
||||
|
||||
def get_command(self) -> Optional[str]:
|
||||
|
||||
def get_command(self) -> str | None:
|
||||
"""Extract command name if this is a command message."""
|
||||
if not self.is_command():
|
||||
return None
|
||||
# Split on space and get first word, strip the /
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[0][1:].lower() if parts else None
|
||||
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
"""Get the arguments after a command."""
|
||||
if not self.is_command():
|
||||
|
|
@ -310,91 +312,88 @@ class MessageEvent:
|
|||
return parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class SendResult:
|
||||
"""Result of sending a message."""
|
||||
|
||||
success: bool
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
message_id: str | None = None
|
||||
error: str | None = None
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[str | None]]
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
Base class for platform adapters.
|
||||
|
||||
|
||||
Subclasses implement platform-specific logic for:
|
||||
- Connecting and authenticating
|
||||
- Receiving messages
|
||||
- Sending messages/responses
|
||||
- Handling media
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig, platform: Platform):
|
||||
self.config = config
|
||||
self.platform = platform
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
self._message_handler: MessageHandler | None = None
|
||||
self._running = False
|
||||
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
|
||||
self._active_sessions: dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: dict[str, MessageEvent] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
return self.platform.value.title()
|
||||
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if adapter is currently connected."""
|
||||
return self._running
|
||||
|
||||
|
||||
def set_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
Set the handler for incoming messages.
|
||||
|
||||
|
||||
The handler receives a MessageEvent and should return
|
||||
an optional response string.
|
||||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Connect to the platform and start receiving messages.
|
||||
|
||||
|
||||
Returns True if connection was successful.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the platform."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
content: Message content (may be markdown)
|
||||
reply_to: Optional message ID to reply to
|
||||
metadata: Additional platform-specific options
|
||||
|
||||
|
||||
Returns:
|
||||
SendResult with success status and message ID
|
||||
"""
|
||||
|
|
@ -416,21 +415,21 @@ class BasePlatformAdapter(ABC):
|
|||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""
|
||||
Send a typing indicator.
|
||||
|
||||
|
||||
Override in subclasses if the platform supports it.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send images as proper attachments
|
||||
instead of plain-text URLs. Default falls back to sending the
|
||||
URL as a text message.
|
||||
|
|
@ -438,87 +437,91 @@ class BasePlatformAdapter(ABC):
|
|||
# Fallback: send URL as text (subclasses override for native images)
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an animated GIF natively via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send GIFs as proper animations
|
||||
(e.g., Telegram send_animation) so they auto-play inline.
|
||||
Default falls back to send_image.
|
||||
"""
|
||||
return await self.send_image(chat_id=chat_id, image_url=animation_url, caption=caption, reply_to=reply_to)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _is_animation_url(url: str) -> bool:
|
||||
"""Check if a URL points to an animated GIF (vs a static image)."""
|
||||
lower = url.lower().split('?')[0] # Strip query params
|
||||
return lower.endswith('.gif')
|
||||
lower = url.lower().split("?")[0] # Strip query params
|
||||
return lower.endswith(".gif")
|
||||
|
||||
@staticmethod
|
||||
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
|
||||
def extract_images(content: str) -> tuple[list[tuple[str, str]], str]:
|
||||
"""
|
||||
Extract image URLs from markdown and HTML image tags in a response.
|
||||
|
||||
|
||||
Finds patterns like:
|
||||
- 
|
||||
- <img src="https://example.com/image.png">
|
||||
- <img src="https://example.com/image.png"></img>
|
||||
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
||||
"""
|
||||
images = []
|
||||
cleaned = content
|
||||
|
||||
|
||||
# Match markdown images: 
|
||||
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
|
||||
md_pattern = r"!\[([^\]]*)\]\((https?://[^\s\)]+)\)"
|
||||
for match in re.finditer(md_pattern, content):
|
||||
alt_text = match.group(1)
|
||||
url = match.group(2)
|
||||
# Only extract URLs that look like actual images
|
||||
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
|
||||
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
|
||||
if any(
|
||||
url.lower().endswith(ext) or ext in url.lower()
|
||||
for ext in [".png", ".jpg", ".jpeg", ".gif", ".webp", "fal.media", "fal-cdn", "replicate.delivery"]
|
||||
):
|
||||
images.append((url, alt_text))
|
||||
|
||||
|
||||
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
||||
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
||||
for match in re.finditer(html_pattern, content):
|
||||
url = match.group(1)
|
||||
images.append((url, ""))
|
||||
|
||||
|
||||
# Remove only the matched image tags from content (not all markdown images)
|
||||
if images:
|
||||
extracted_urls = {url for url, _ in images}
|
||||
|
||||
def _remove_if_extracted(match):
|
||||
url = match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
return '' if url in extracted_urls else match.group(0)
|
||||
return "" if url in extracted_urls else match.group(0)
|
||||
|
||||
cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned)
|
||||
cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned)
|
||||
# Clean up leftover blank lines
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
return images, cleaned
|
||||
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an audio file as a native voice message via the platform API.
|
||||
|
||||
|
||||
Override in subclasses to send audio as voice bubbles (Telegram)
|
||||
or file attachments (Discord). Default falls back to sending the
|
||||
file path as text.
|
||||
|
|
@ -532,8 +535,8 @@ class BasePlatformAdapter(ABC):
|
|||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a video natively via the platform API.
|
||||
|
|
@ -550,9 +553,9 @@ class BasePlatformAdapter(ABC):
|
|||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a document/file natively via the platform API.
|
||||
|
|
@ -569,8 +572,8 @@ class BasePlatformAdapter(ABC):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a local image file natively via the platform API.
|
||||
|
|
@ -585,45 +588,45 @@ class BasePlatformAdapter(ABC):
|
|||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
|
||||
def extract_media(content: str) -> tuple[list[tuple[str, bool]], str]:
|
||||
"""
|
||||
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
||||
|
||||
|
||||
The TTS tool returns responses like:
|
||||
[[audio_as_voice]]
|
||||
MEDIA:/path/to/audio.ogg
|
||||
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
||||
"""
|
||||
media = []
|
||||
cleaned = content
|
||||
|
||||
|
||||
# Check for [[audio_as_voice]] directive
|
||||
has_voice_tag = "[[audio_as_voice]]" in content
|
||||
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
||||
|
||||
|
||||
# Extract MEDIA:<path> tags (path may contain spaces)
|
||||
media_pattern = r'MEDIA:(\S+)'
|
||||
media_pattern = r"MEDIA:(\S+)"
|
||||
for match in re.finditer(media_pattern, content):
|
||||
path = match.group(1).strip()
|
||||
if path:
|
||||
media.append((path, has_voice_tag))
|
||||
|
||||
|
||||
# Remove MEDIA tags from content
|
||||
if media:
|
||||
cleaned = re.sub(media_pattern, '', cleaned)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
cleaned = re.sub(media_pattern, "", cleaned)
|
||||
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
|
||||
|
||||
Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
|
||||
to recover quickly after progress messages interrupt it.
|
||||
"""
|
||||
|
|
@ -633,20 +636,20 @@ class BasePlatformAdapter(ABC):
|
|||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
|
||||
|
||||
This method returns quickly by spawning background tasks.
|
||||
This allows new messages to be processed even while an agent is running,
|
||||
enabling interruption support.
|
||||
"""
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
|
||||
session_key = event.source.chat_id
|
||||
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
|
|
@ -655,10 +658,10 @@ class BasePlatformAdapter(ABC):
|
|||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_human_delay() -> float:
|
||||
"""
|
||||
|
|
@ -685,35 +688,40 @@ class BasePlatformAdapter(ABC):
|
|||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
|
||||
# Send response if any
|
||||
if not response:
|
||||
logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
if images:
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
logger.info(
|
||||
"[%s] extract_images found %d image(s) in response (%d chars)",
|
||||
self.name,
|
||||
len(images),
|
||||
len(response),
|
||||
)
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id
|
||||
logger.info(
|
||||
"[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id
|
||||
)
|
||||
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
|
|
@ -721,14 +729,14 @@ class BasePlatformAdapter(ABC):
|
|||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id
|
||||
reply_to=event.message_id,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
|
|
@ -736,7 +744,12 @@ class BasePlatformAdapter(ABC):
|
|||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
|
||||
logger.info(
|
||||
"[%s] Sending image: %s (alt=%s)",
|
||||
self.name,
|
||||
image_url[:80],
|
||||
alt_text[:30] if alt_text else "",
|
||||
)
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
if self._is_animation_url(image_url):
|
||||
img_result = await self.send_animation(
|
||||
|
|
@ -754,11 +767,11 @@ class BasePlatformAdapter(ABC):
|
|||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
|
||||
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
if human_delay > 0:
|
||||
|
|
@ -790,7 +803,7 @@ class BasePlatformAdapter(ABC):
|
|||
print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}")
|
||||
except Exception as media_err:
|
||||
print(f"[{self.name}] Error sending media: {media_err}")
|
||||
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
|
|
@ -806,10 +819,11 @@ class BasePlatformAdapter(ABC):
|
|||
# Process pending message in new background task
|
||||
await self._process_message_background(pending_event, session_key)
|
||||
return # Already cleaned up
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Stop typing indicator
|
||||
|
|
@ -821,26 +835,26 @@ class BasePlatformAdapter(ABC):
|
|||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||
|
||||
def get_pending_message(self, session_key: str) -> MessageEvent | None:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.pop(session_key, None)
|
||||
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_name: Optional[str] = None,
|
||||
chat_name: str | None = None,
|
||||
chat_type: str = "dm",
|
||||
user_id: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
chat_topic: Optional[str] = None,
|
||||
user_id_alt: Optional[str] = None,
|
||||
chat_id_alt: Optional[str] = None,
|
||||
user_id: str | None = None,
|
||||
user_name: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
chat_topic: str | None = None,
|
||||
user_id_alt: str | None = None,
|
||||
chat_id_alt: str | None = None,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
|
|
@ -858,30 +872,30 @@ class BasePlatformAdapter(ABC):
|
|||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about a chat/channel.
|
||||
|
||||
|
||||
Returns dict with at least:
|
||||
- name: Chat name
|
||||
- type: "dm", "group", "channel"
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format a message for this platform.
|
||||
|
||||
|
||||
Override in subclasses to handle platform-specific formatting
|
||||
(e.g., Telegram MarkdownV2, Discord markdown).
|
||||
|
||||
|
||||
Default implementation returns content as-is.
|
||||
"""
|
||||
return content
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> list[str]:
|
||||
"""
|
||||
Split a long message into chunks, preserving code block boundaries.
|
||||
|
||||
|
|
@ -900,14 +914,14 @@ class BasePlatformAdapter(ABC):
|
|||
if len(content) <= max_length:
|
||||
return [content]
|
||||
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
INDICATOR_RESERVE = 10 # room for " (XX/XX)"
|
||||
FENCE_CLOSE = "\n```"
|
||||
|
||||
chunks: List[str] = []
|
||||
chunks: list[str] = []
|
||||
remaining = content
|
||||
# When the previous chunk ended mid-code-block, this holds the
|
||||
# language tag (possibly "") so we can reopen the fence.
|
||||
carry_lang: Optional[str] = None
|
||||
carry_lang: str | None = None
|
||||
|
||||
while remaining:
|
||||
# If we're continuing a code block from the previous chunk,
|
||||
|
|
@ -965,8 +979,6 @@ class BasePlatformAdapter(ABC):
|
|||
# Append chunk indicators when the response spans multiple messages
|
||||
if len(chunks) > 1:
|
||||
total = len(chunks)
|
||||
chunks = [
|
||||
f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)
|
||||
]
|
||||
chunks = [f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)]
|
||||
|
||||
return chunks
|
||||
|
|
|
|||
|
|
@ -10,14 +10,16 @@ Uses discord.py library for:
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import discord
|
||||
from discord import Message as DiscordMessage, Intents
|
||||
from discord import Intents
|
||||
from discord import Message as DiscordMessage
|
||||
from discord.ext import commands
|
||||
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
|
|
@ -28,6 +30,7 @@ except ImportError:
|
|||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
|
@ -36,8 +39,8 @@ from gateway.platforms.base import (
|
|||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -49,7 +52,7 @@ def check_discord_requirements() -> bool:
|
|||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses with Discord markdown
|
||||
|
|
@ -59,26 +62,26 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
- Auto-threading for long conversations
|
||||
- Reaction-based feedback
|
||||
"""
|
||||
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.DISCORD)
|
||||
self._client: Optional[commands.Bot] = None
|
||||
self._client: commands.Bot | None = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._allowed_user_ids: set = set() # For button approval authorization
|
||||
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
print(f"[{self.name}] discord.py not installed. Run: pip install discord.py")
|
||||
return False
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Set up intents -- members intent needed for username-to-ID resolution
|
||||
intents = Intents.default()
|
||||
|
|
@ -86,30 +89,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
self._allowed_user_ids = {
|
||||
uid.strip() for uid in allowed_env.split(",") if uid.strip()
|
||||
}
|
||||
|
||||
self._allowed_user_ids = {uid.strip() for uid in allowed_env.split(",") if uid.strip()}
|
||||
|
||||
adapter_self = self # capture for closure
|
||||
|
||||
|
||||
# Register event handlers
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
print(f"[{adapter_self.name}] Connected as {adapter_self._client.user}")
|
||||
|
||||
|
||||
# Resolve any usernames in the allowed list to numeric IDs
|
||||
await adapter_self._resolve_allowed_usernames()
|
||||
|
||||
|
||||
# Sync slash commands with Discord
|
||||
try:
|
||||
synced = await adapter_self._client.tree.sync()
|
||||
|
|
@ -117,33 +118,33 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
except Exception as e:
|
||||
print(f"[{adapter_self.name}] Slash command sync failed: {e}")
|
||||
adapter_self._ready_event.set()
|
||||
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Ignore bot's own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
await self._handle_message(message)
|
||||
|
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
|
||||
# Start the bot in background
|
||||
asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
|
||||
|
||||
self._running = True
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
except TimeoutError:
|
||||
print(f"[{self.name}] Timeout waiting for connection")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
if self._client:
|
||||
|
|
@ -151,59 +152,55 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
await self._client.close()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Discord channel."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
|
||||
message_ids = []
|
||||
reference = None
|
||||
|
||||
|
||||
if reply_to:
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
except Exception as e:
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=reference if i == 0 else None,
|
||||
)
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
raw_response={"message_ids": message_ids},
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
|
|
@ -223,7 +220,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
msg = await channel.fetch_message(int(message_id))
|
||||
formatted = self.format_message(content)
|
||||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||||
formatted = formatted[:self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
formatted = formatted[: self.MAX_MESSAGE_LENGTH - 3] + "..."
|
||||
await msg.edit(content=formatted)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
except Exception as e:
|
||||
|
|
@ -233,28 +230,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
|
||||
# Determine filename from path
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
|
|
@ -262,36 +259,36 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
|
|
@ -299,7 +296,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
|
@ -308,31 +305,31 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
|
|
@ -342,23 +339,24 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
|
||||
import io
|
||||
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
|
|
@ -368,20 +366,20 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
if not self._client:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
|
||||
if not channel:
|
||||
return {"name": str(chat_id), "type": "dm"}
|
||||
|
||||
|
||||
# Determine channel type
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
|
|
@ -397,7 +395,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
else:
|
||||
chat_type = "channel"
|
||||
name = getattr(channel, "name", str(chat_id))
|
||||
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": chat_type,
|
||||
|
|
@ -406,7 +404,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
async def _resolve_allowed_usernames(self) -> None:
|
||||
"""
|
||||
Resolve non-numeric entries in DISCORD_ALLOWED_USERS to Discord user IDs.
|
||||
|
|
@ -453,8 +451,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
uid = str(member.id)
|
||||
numeric_ids.add(uid)
|
||||
resolved_count += 1
|
||||
matched_name = name_lower if name_lower in to_resolve else (
|
||||
display_lower if display_lower in to_resolve else global_lower
|
||||
matched_name = (
|
||||
name_lower
|
||||
if name_lower in to_resolve
|
||||
else (display_lower if display_lower in to_resolve else global_lower)
|
||||
)
|
||||
to_resolve.discard(matched_name)
|
||||
print(f"[{self.name}] Resolved '{matched_name}' -> {uid} ({member.name}#{member.discriminator})")
|
||||
|
|
@ -474,12 +474,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format message for Discord.
|
||||
|
||||
|
||||
Discord uses its own markdown variant.
|
||||
"""
|
||||
# Discord markdown is fairly standard, no special escaping needed
|
||||
return content
|
||||
|
||||
|
||||
def _register_slash_commands(self) -> None:
|
||||
"""Register Discord slash commands on the command tree."""
|
||||
if not self._client:
|
||||
|
|
@ -694,7 +694,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
chat_name = interaction.channel.name
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
|
|
@ -715,9 +715,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
raw_message=interaction,
|
||||
)
|
||||
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, approval_id: str
|
||||
) -> SendResult:
|
||||
async def send_exec_approval(self, chat_id: str, command: str, approval_id: str) -> SendResult:
|
||||
"""
|
||||
Send a button-based exec approval prompt for a dangerous command.
|
||||
|
||||
|
|
@ -759,28 +757,28 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check if this channel is in the free-response list
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
|
||||
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
||||
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Must be @mentioned to respond
|
||||
if self._client.user not in message.mentions:
|
||||
return # Silently ignore messages that don't mention the bot
|
||||
|
||||
|
||||
# Strip the bot mention from the message text so the agent sees clean input
|
||||
if self._client.user and self._client.user in message.mentions:
|
||||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
|
|
@ -798,7 +796,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
|
||||
# Determine chat type
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
|
|
@ -811,15 +809,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
|
||||
# Get thread ID if in a thread
|
||||
thread_id = None
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
thread_id = str(message.channel.id)
|
||||
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
|
|
@ -830,7 +828,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
|
|
@ -869,7 +867,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
message_type=msg_type,
|
||||
|
|
@ -881,7 +879,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
|
|
@ -911,20 +909,14 @@ if DISCORD_AVAILABLE:
|
|||
return True # No allowlist = anyone can approve
|
||||
return str(interaction.user.id) in self.allowed_user_ids
|
||||
|
||||
async def _resolve(
|
||||
self, interaction: discord.Interaction, action: str, color: discord.Color
|
||||
):
|
||||
async def _resolve(self, interaction: discord.Interaction, action: str, color: discord.Color):
|
||||
"""Resolve the approval and update the message."""
|
||||
if self.resolved:
|
||||
await interaction.response.send_message(
|
||||
"This approval has already been resolved~", ephemeral=True
|
||||
)
|
||||
await interaction.response.send_message("This approval has already been resolved~", ephemeral=True)
|
||||
return
|
||||
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message(
|
||||
"You're not authorized to approve commands~", ephemeral=True
|
||||
)
|
||||
await interaction.response.send_message("You're not authorized to approve commands~", ephemeral=True)
|
||||
return
|
||||
|
||||
self.resolved = True
|
||||
|
|
@ -944,6 +936,7 @@ if DISCORD_AVAILABLE:
|
|||
# Store the approval decision
|
||||
try:
|
||||
from tools.approval import approve_permanent
|
||||
|
||||
if action == "allow_once":
|
||||
pass # One-time approval handled by gateway
|
||||
elif action == "allow_always":
|
||||
|
|
@ -952,21 +945,15 @@ if DISCORD_AVAILABLE:
|
|||
pass
|
||||
|
||||
@discord.ui.button(label="Allow Once", style=discord.ButtonStyle.green)
|
||||
async def allow_once(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
async def allow_once(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "allow_once", discord.Color.green())
|
||||
|
||||
@discord.ui.button(label="Always Allow", style=discord.ButtonStyle.blurple)
|
||||
async def allow_always(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
async def allow_always(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "allow_always", discord.Color.blue())
|
||||
|
||||
@discord.ui.button(label="Deny", style=discord.ButtonStyle.red)
|
||||
async def deny(
|
||||
self, interaction: discord.Interaction, button: discord.ui.Button
|
||||
):
|
||||
async def deny(self, interaction: discord.Interaction, button: discord.ui.Button):
|
||||
await self._resolve(interaction, "deny", discord.Color.red())
|
||||
|
||||
async def on_timeout(self):
|
||||
|
|
|
|||
|
|
@ -19,10 +19,11 @@ import os
|
|||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
|
@ -66,10 +67,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
super().__init__(config, Platform.HOMEASSISTANT)
|
||||
|
||||
# Connection state
|
||||
self._session: Optional["aiohttp.ClientSession"] = None
|
||||
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
|
||||
self._rest_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._rest_session: aiohttp.ClientSession | None = None
|
||||
self._listen_task: asyncio.Task | None = None
|
||||
self._msg_id: int = 0
|
||||
|
||||
# Configuration from extra
|
||||
|
|
@ -80,13 +81,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
self._hass_token: str = token
|
||||
|
||||
# Event filtering
|
||||
self._watch_domains: Set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: Set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: Set[str] = set(extra.get("ignore_entities", []))
|
||||
self._watch_domains: set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: set[str] = set(extra.get("ignore_entities", []))
|
||||
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
|
||||
|
||||
# Cooldown tracking: entity_id -> last_event_timestamp
|
||||
self._last_event_time: Dict[str, float] = {}
|
||||
self._last_event_time: dict[str, float] = {}
|
||||
|
||||
def _next_id(self) -> int:
|
||||
"""Return the next WebSocket message ID."""
|
||||
|
|
@ -141,10 +142,12 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
# Step 2: Send auth
|
||||
await self._ws.send_json({
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
})
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
}
|
||||
)
|
||||
|
||||
# Step 3: Wait for auth_ok
|
||||
msg = await self._ws.receive_json()
|
||||
|
|
@ -155,11 +158,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
|
||||
# Step 4: Subscribe to state_changed events
|
||||
sub_id = self._next_id()
|
||||
await self._ws.send_json({
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
})
|
||||
await self._ws.send_json(
|
||||
{
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
}
|
||||
)
|
||||
|
||||
# Verify subscription acknowledgement
|
||||
msg = await self._ws.receive_json()
|
||||
|
|
@ -245,7 +250,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
|
||||
break
|
||||
|
||||
async def _handle_ha_event(self, event: Dict[str, Any]) -> None:
|
||||
async def _handle_ha_event(self, event: dict[str, Any]) -> None:
|
||||
"""Process a state_changed event from Home Assistant."""
|
||||
event_data = event.get("data", {})
|
||||
entity_id: str = event_data.get("entity_id", "")
|
||||
|
|
@ -302,9 +307,9 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
@staticmethod
|
||||
def _format_state_change(
|
||||
entity_id: str,
|
||||
old_state: Dict[str, Any],
|
||||
new_state: Dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
old_state: dict[str, Any],
|
||||
new_state: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""Convert a state_changed event into a human-readable description."""
|
||||
if not new_state:
|
||||
return None
|
||||
|
|
@ -331,10 +336,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
|
||||
if domain == "sensor":
|
||||
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: changed from "
|
||||
f"{old_val}{unit} to {new_val}{unit}"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name}: changed from {old_val}{unit} to {new_val}{unit}"
|
||||
|
||||
if domain == "binary_sensor":
|
||||
return (
|
||||
|
|
@ -344,22 +346,13 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
if domain in ("light", "switch", "fan"):
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: turned "
|
||||
f"{'on' if new_val == 'on' else 'off'}"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name}: turned {'on' if new_val == 'on' else 'off'}"
|
||||
|
||||
if domain == "alarm_control_panel":
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: alarm state changed from "
|
||||
f"'{old_val}' to '{new_val}'"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name}: alarm state changed from '{old_val}' to '{new_val}'"
|
||||
|
||||
# Generic fallback
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name} ({entity_id}): "
|
||||
f"changed from '{old_val}' to '{new_val}'"
|
||||
)
|
||||
return f"[Home Assistant] {friendly_name} ({entity_id}): changed from '{old_val}' to '{new_val}'"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound messaging
|
||||
|
|
@ -369,8 +362,8 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a notification via HA REST API (persistent_notification.create).
|
||||
|
||||
|
|
@ -384,7 +377,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
}
|
||||
payload = {
|
||||
"title": "Hermes Agent",
|
||||
"message": content[:self.MAX_MESSAGE_LENGTH],
|
||||
"message": content[: self.MAX_MESSAGE_LENGTH],
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -401,20 +394,22 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
else:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
) as resp,
|
||||
):
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
except TimeoutError:
|
||||
return SendResult(success=False, error="Timeout sending notification to HA")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
|
@ -423,7 +418,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
|
|||
"""No typing indicator for Home Assistant."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Return basic info about the HA event channel."""
|
||||
return {
|
||||
"name": "Home Assistant Events",
|
||||
|
|
|
|||
|
|
@ -19,9 +19,9 @@ import os
|
|||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
|
|
@ -32,9 +32,9 @@ from gateway.platforms.base import (
|
|||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
|
@ -59,6 +59,7 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
|||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
|
|
@ -68,7 +69,7 @@ def _redact_phone(phone: str) -> str:
|
|||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
def _parse_comma_list(value: str) -> list[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
return [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
|
|
@ -110,7 +111,7 @@ def _render_mentions(text: str, mentions: list) -> str:
|
|||
Signal encodes @mentions as the Unicode object replacement character
|
||||
with out-of-band metadata containing the mentioned user's UUID/number.
|
||||
"""
|
||||
if not mentions or "\uFFFC" not in text:
|
||||
if not mentions or "\ufffc" not in text:
|
||||
return text
|
||||
# Sort mentions by start position (reverse) to replace from end to start
|
||||
# so indices don't shift as we replace
|
||||
|
|
@ -121,7 +122,7 @@ def _render_mentions(text: str, mentions: list) -> str:
|
|||
# Use the mention's number or UUID as the replacement
|
||||
identifier = mention.get("number") or mention.get("uuid") or "user"
|
||||
replacement = f"@{identifier}"
|
||||
text = text[:start] + replacement + text[start + length:]
|
||||
text = text[:start] + replacement + text[start + length :]
|
||||
return text
|
||||
|
||||
|
||||
|
|
@ -134,6 +135,7 @@ def check_signal_requirements() -> bool:
|
|||
# Signal Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SignalAdapter(BasePlatformAdapter):
|
||||
"""Signal messenger adapter using signal-cli HTTP daemon."""
|
||||
|
||||
|
|
@ -152,22 +154,25 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
self.group_allow_from = set(_parse_comma_list(group_allowed_str))
|
||||
|
||||
# HTTP client
|
||||
self.client: Optional[httpx.AsyncClient] = None
|
||||
self.client: httpx.AsyncClient | None = None
|
||||
|
||||
# Background tasks
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._health_monitor_task: Optional[asyncio.Task] = None
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._sse_task: asyncio.Task | None = None
|
||||
self._health_monitor_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: Optional[httpx.Response] = None
|
||||
self._sse_response: httpx.Response | None = None
|
||||
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
logger.info(
|
||||
"Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url,
|
||||
_redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
|
|
@ -241,7 +246,8 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
logger.debug("Signal SSE: connecting to %s", url)
|
||||
async with self.client.stream(
|
||||
"GET", url,
|
||||
"GET",
|
||||
url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=None,
|
||||
) as response:
|
||||
|
|
@ -306,9 +312,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
if elapsed > HEALTH_CHECK_STALE_THRESHOLD:
|
||||
logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed)
|
||||
try:
|
||||
resp = await self.client.get(
|
||||
f"{self.http_url}/api/v1/check", timeout=10.0
|
||||
)
|
||||
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
# Daemon is alive but SSE is idle — update activity to
|
||||
# avoid repeated warnings (connection may just be quiet)
|
||||
|
|
@ -345,11 +349,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = (
|
||||
envelope_data.get("sourceNumber")
|
||||
or envelope_data.get("sourceUuid")
|
||||
or envelope_data.get("source")
|
||||
)
|
||||
sender = envelope_data.get("sourceNumber") or envelope_data.get("sourceUuid") or envelope_data.get("source")
|
||||
sender_name = envelope_data.get("sourceName", "")
|
||||
sender_uuid = envelope_data.get("sourceUuid", "")
|
||||
|
||||
|
|
@ -367,10 +367,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
|
||||
# Get data message — also check editMessage (edited messages contain
|
||||
# their updated dataMessage inside editMessage.dataMessage)
|
||||
data_message = (
|
||||
envelope_data.get("dataMessage")
|
||||
or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
)
|
||||
data_message = envelope_data.get("dataMessage") or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
if not data_message:
|
||||
return
|
||||
|
||||
|
|
@ -451,11 +448,11 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
if ts_ms:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=UTC)
|
||||
except (ValueError, OSError):
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
else:
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
timestamp = datetime.now(tz=UTC)
|
||||
|
||||
# Build and dispatch event
|
||||
event = MessageEvent(
|
||||
|
|
@ -468,8 +465,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
logger.debug("Signal: message from %s in %s: %s", _redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
|
@ -479,10 +475,13 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
|
||||
async def _fetch_attachment(self, attachment_id: str) -> tuple:
|
||||
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
|
||||
result = await self._rpc("getAttachment", {
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
})
|
||||
result = await self._rpc(
|
||||
"getAttachment",
|
||||
{
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
},
|
||||
)
|
||||
|
||||
if not result:
|
||||
return None, ""
|
||||
|
|
@ -547,13 +546,13 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to_message_id: Optional[str] = None,
|
||||
reply_to_message_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a text message."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": text,
|
||||
}
|
||||
|
|
@ -571,7 +570,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
|
||||
|
|
@ -586,7 +585,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an image. Supports http(s):// and file:// URLs."""
|
||||
|
|
@ -611,7 +610,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
if file_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
return SendResult(success=False, error=f"Image too large ({file_size} bytes)")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
|
|
@ -631,8 +630,8 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
filename: str | None = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment."""
|
||||
|
|
@ -641,7 +640,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
if not Path(file_path).exists():
|
||||
return SendResult(success=False, error="File not found")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
params: dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
|
|
@ -690,7 +689,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a chat/contact."""
|
||||
if chat_id.startswith("group:"):
|
||||
return {
|
||||
|
|
@ -700,10 +699,13 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
}
|
||||
|
||||
# Try to resolve contact name
|
||||
result = await self._rpc("getContact", {
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
})
|
||||
result = await self._rpc(
|
||||
"getContact",
|
||||
{
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
},
|
||||
)
|
||||
|
||||
name = chat_id
|
||||
if result and isinstance(result, dict):
|
||||
|
|
|
|||
|
|
@ -11,12 +11,13 @@ Uses slack-bolt (Python) with Socket Mode for:
|
|||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
||||
from slack_bolt.async_app import AsyncApp
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
SLACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
SLACK_AVAILABLE = False
|
||||
|
|
@ -26,18 +27,17 @@ except ImportError:
|
|||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -66,9 +66,9 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SLACK)
|
||||
self._app: Optional[AsyncApp] = None
|
||||
self._handler: Optional[AsyncSocketModeHandler] = None
|
||||
self._bot_user_id: Optional[str] = None
|
||||
self._app: AsyncApp | None = None
|
||||
self._handler: AsyncSocketModeHandler | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Slack via Socket Mode."""
|
||||
|
|
@ -135,8 +135,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a message to a Slack channel or DM."""
|
||||
if not self._app:
|
||||
|
|
@ -193,8 +193,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file to Slack by uploading it."""
|
||||
if not self._app:
|
||||
|
|
@ -202,6 +202,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
|
@ -222,8 +223,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image to Slack by uploading the URL as a file."""
|
||||
if not self._app:
|
||||
|
|
@ -247,7 +248,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Fall back to sending the URL as text
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
|
@ -256,8 +257,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an audio file to Slack."""
|
||||
if not self._app:
|
||||
|
|
@ -280,8 +281,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video file to Slack."""
|
||||
if not self._app:
|
||||
|
|
@ -308,9 +309,9 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment to Slack."""
|
||||
if not self._app:
|
||||
|
|
@ -335,7 +336,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
print(f"[{self.name}] Failed to send document: {e}")
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to)
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Slack channel."""
|
||||
if not self._app:
|
||||
return {"name": chat_id, "type": "unknown"}
|
||||
|
|
@ -442,9 +443,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
# Download and cache
|
||||
raw_bytes = await self._download_slack_file_bytes(url)
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, original_filename or f"document{ext}"
|
||||
)
|
||||
cached_path = cache_document_from_bytes(raw_bytes, original_filename or f"document{ext}")
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
|
|
@ -457,7 +456,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if text:
|
||||
text = f"{injection}\n\n{text}"
|
||||
|
|
@ -499,16 +498,20 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
# Map subcommands to gateway commands
|
||||
subcommand_map = {
|
||||
"new": "/reset", "reset": "/reset",
|
||||
"status": "/status", "stop": "/stop",
|
||||
"new": "/reset",
|
||||
"reset": "/reset",
|
||||
"status": "/status",
|
||||
"stop": "/stop",
|
||||
"help": "/help",
|
||||
"model": "/model", "personality": "/personality",
|
||||
"retry": "/retry", "undo": "/undo",
|
||||
"model": "/model",
|
||||
"personality": "/personality",
|
||||
"retry": "/retry",
|
||||
"undo": "/undo",
|
||||
}
|
||||
first_word = text.split()[0] if text else ""
|
||||
if first_word in subcommand_map:
|
||||
# Preserve arguments after the subcommand
|
||||
rest = text[len(first_word):].strip()
|
||||
rest = text[len(first_word) :].strip()
|
||||
text = f"{subcommand_map[first_word]} {rest}".strip() if rest else subcommand_map[first_word]
|
||||
elif text:
|
||||
pass # Treat as a regular question
|
||||
|
|
@ -544,9 +547,11 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
|
|
|
|||
|
|
@ -7,24 +7,26 @@ Uses python-telegram-bot library for:
|
|||
- Handling media and commands
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from telegram import Update, Bot, Message
|
||||
from telegram import Bot, Message, Update
|
||||
from telegram.constants import ChatType, ParseMode
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
from telegram.ext import (
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
)
|
||||
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
|
|
@ -42,22 +44,24 @@ except ImportError:
|
|||
# don't crash during class definition when the library isn't installed.
|
||||
class _MockContextTypes:
|
||||
DEFAULT_TYPE = Any
|
||||
|
||||
ContextTypes = _MockContextTypes
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
cache_image_from_bytes,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -68,12 +72,12 @@ def check_telegram_requirements() -> bool:
|
|||
|
||||
# Matches every character that MarkdownV2 requires to be backslash-escaped
|
||||
# when it appears outside a code span or fenced code block.
|
||||
_MDV2_ESCAPE_RE = re.compile(r'([_*\[\]()~`>#\+\-=|{}.!\\])')
|
||||
_MDV2_ESCAPE_RE = re.compile(r"([_*\[\]()~`>#\+\-=|{}.!\\])")
|
||||
|
||||
|
||||
def _escape_mdv2(text: str) -> str:
|
||||
"""Escape Telegram MarkdownV2 special characters with a preceding backslash."""
|
||||
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
|
||||
return _MDV2_ESCAPE_RE.sub(r"\\\1", text)
|
||||
|
||||
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
|
|
@ -83,103 +87,108 @@ def _strip_mdv2(text: str) -> str:
|
|||
doesn't show stray asterisks from header/bold conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
|
||||
cleaned = re.sub(r"\\([_*\[\]()~`>#\+\-=|{}.!\\])", r"\1", text)
|
||||
# Remove MarkdownV2 bold markers that format_message converted from **bold**
|
||||
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
|
||||
cleaned = re.sub(r"\*([^*]+)\*", r"\1", cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
|
||||
|
||||
Handles:
|
||||
- Receiving messages from users and groups
|
||||
- Sending responses with Telegram markdown
|
||||
- Forum topics (thread_id support)
|
||||
- Media messages
|
||||
"""
|
||||
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
|
||||
self._app: Application | None = None
|
||||
self._bot: Bot | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot")
|
||||
return False
|
||||
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
|
||||
# Register handlers
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.TEXT & ~filters.COMMAND,
|
||||
self._handle_text_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.COMMAND,
|
||||
self._handle_command
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION),
|
||||
self._handle_location_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL | filters.Sticker.ALL,
|
||||
self._handle_media_message
|
||||
))
|
||||
|
||||
self._app.add_handler(TelegramMessageHandler(filters.TEXT & ~filters.COMMAND, self._handle_text_message))
|
||||
self._app.add_handler(TelegramMessageHandler(filters.COMMAND, self._handle_command))
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.LOCATION | getattr(filters, "VENUE", filters.LOCATION), self._handle_location_message
|
||||
)
|
||||
)
|
||||
self._app.add_handler(
|
||||
TelegramMessageHandler(
|
||||
filters.PHOTO
|
||||
| filters.VIDEO
|
||||
| filters.AUDIO
|
||||
| filters.VOICE
|
||||
| filters.Document.ALL
|
||||
| filters.Sticker.ALL,
|
||||
self._handle_media_message,
|
||||
)
|
||||
)
|
||||
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
await self._bot.set_my_commands([
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
|
||||
await self._bot.set_my_commands(
|
||||
[
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("model", "Show or change the model"),
|
||||
BotCommand("personality", "Set a personality"),
|
||||
BotCommand("retry", "Retry your last message"),
|
||||
BotCommand("undo", "Remove the last exchange"),
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Could not register command menu: {e}")
|
||||
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Connected and polling for updates")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop polling and disconnect."""
|
||||
if self._app:
|
||||
|
|
@ -189,31 +198,27 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
|
||||
self._running = False
|
||||
self._app = None
|
||||
self._bot = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Telegram chat."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
|
|
@ -227,7 +232,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
logger.warning(
|
||||
"[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error
|
||||
)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
|
|
@ -241,13 +248,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
raw_response={"message_ids": message_ids},
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
|
|
@ -284,18 +291,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
# .ogg files -> send as voice (round playable bubble)
|
||||
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
|
||||
|
|
@ -317,23 +325,24 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
|
|
@ -350,17 +359,17 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo.
|
||||
|
||||
|
||||
Tries URL-based send first (fast, works for <5MB images).
|
||||
Falls back to downloading and uploading as file (supports up to 10MB).
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
msg = await self._bot.send_photo(
|
||||
|
|
@ -375,11 +384,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Fallback: download and upload as file (supports up to 10MB)
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
image_data = resp.content
|
||||
|
||||
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_data,
|
||||
|
|
@ -391,18 +401,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
|
||||
# Final fallback: send URL as text
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
chat_id: str,
|
||||
animation_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send an animated GIF natively as a Telegram animation (auto-plays inline)."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
msg = await self._bot.send_animation(
|
||||
chat_id=int(chat_id),
|
||||
|
|
@ -420,21 +430,18 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
try:
|
||||
await self._bot.send_chat_action(
|
||||
chat_id=int(chat_id),
|
||||
action="typing"
|
||||
)
|
||||
await self._bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a Telegram chat."""
|
||||
if not self._bot:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
chat = await self._bot.get_chat(int(chat_id))
|
||||
|
||||
|
||||
chat_type = "dm"
|
||||
if chat.type == ChatType.GROUP:
|
||||
chat_type = "group"
|
||||
|
|
@ -444,7 +451,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
chat_type = "forum"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
return {
|
||||
"name": chat.title or chat.full_name or str(chat_id),
|
||||
"type": chat_type,
|
||||
|
|
@ -453,7 +460,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Convert standard markdown to Telegram MarkdownV2 format.
|
||||
|
|
@ -480,38 +487,36 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
text = re.sub(
|
||||
r'(```(?:[^\n]*\n)?[\s\S]*?```)',
|
||||
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
|
||||
lambda m: _ph(m.group(0)),
|
||||
text,
|
||||
)
|
||||
|
||||
# 2) Protect inline code (`...`)
|
||||
text = re.sub(r'(`[^`]+`)', lambda m: _ph(m.group(0)), text)
|
||||
text = re.sub(r"(`[^`]+`)", lambda m: _ph(m.group(0)), text)
|
||||
|
||||
# 3) Convert markdown links – escape the display text; inside the URL
|
||||
# only ')' and '\' need escaping per the MarkdownV2 spec.
|
||||
def _convert_link(m):
|
||||
display = _escape_mdv2(m.group(1))
|
||||
url = m.group(2).replace('\\', '\\\\').replace(')', '\\)')
|
||||
return _ph(f'[{display}]({url})')
|
||||
url = m.group(2).replace("\\", "\\\\").replace(")", "\\)")
|
||||
return _ph(f"[{display}]({url})")
|
||||
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', _convert_link, text)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", _convert_link, text)
|
||||
|
||||
# 4) Convert markdown headers (## Title) → bold *Title*
|
||||
def _convert_header(m):
|
||||
inner = m.group(1).strip()
|
||||
# Strip redundant bold markers that may appear inside a header
|
||||
inner = re.sub(r'\*\*(.+?)\*\*', r'\1', inner)
|
||||
return _ph(f'*{_escape_mdv2(inner)}*')
|
||||
inner = re.sub(r"\*\*(.+?)\*\*", r"\1", inner)
|
||||
return _ph(f"*{_escape_mdv2(inner)}*")
|
||||
|
||||
text = re.sub(
|
||||
r'^#{1,6}\s+(.+)$', _convert_header, text, flags=re.MULTILINE
|
||||
)
|
||||
text = re.sub(r"^#{1,6}\s+(.+)$", _convert_header, text, flags=re.MULTILINE)
|
||||
|
||||
# 5) Convert bold: **text** → *text* (MarkdownV2 bold)
|
||||
text = re.sub(
|
||||
r'\*\*(.+?)\*\*',
|
||||
lambda m: _ph(f'*{_escape_mdv2(m.group(1))}*'),
|
||||
r"\*\*(.+?)\*\*",
|
||||
lambda m: _ph(f"*{_escape_mdv2(m.group(1))}*"),
|
||||
text,
|
||||
)
|
||||
|
||||
|
|
@ -519,8 +524,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# [^*\n]+ prevents matching across newlines (which would corrupt
|
||||
# bullet lists using * markers and multi-line content).
|
||||
text = re.sub(
|
||||
r'\*([^*\n]+)\*',
|
||||
lambda m: _ph(f'_{_escape_mdv2(m.group(1))}_'),
|
||||
r"\*([^*\n]+)\*",
|
||||
lambda m: _ph(f"_{_escape_mdv2(m.group(1))}_"),
|
||||
text,
|
||||
)
|
||||
|
||||
|
|
@ -533,23 +538,23 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
text = text.replace(key, placeholders[key])
|
||||
|
||||
return text
|
||||
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming text messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming command messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming location/venue pin messages."""
|
||||
if not update.message:
|
||||
|
|
@ -589,9 +594,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
|
||||
msg = update.message
|
||||
|
||||
|
||||
# Determine media type
|
||||
if msg.sticker:
|
||||
msg_type = MessageType.STICKER
|
||||
|
|
@ -607,19 +612,19 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
msg_type = MessageType.DOCUMENT
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
event.text = msg.caption
|
||||
|
||||
|
||||
# Handle stickers: describe via vision tool with caching
|
||||
if msg.sticker:
|
||||
await self._handle_sticker(msg, event)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
|
||||
# Download photo to local image cache so the vision tool can access it
|
||||
# even after Telegram's ephemeral file URLs expire (~1 hour).
|
||||
if msg.photo:
|
||||
|
|
@ -643,7 +648,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
print(f"[Telegram] Cached user photo: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache photo: {e}", flush=True)
|
||||
|
||||
|
||||
# Download voice/audio messages to cache for STT transcription
|
||||
if msg.voice:
|
||||
try:
|
||||
|
|
@ -685,10 +690,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Check if supported
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
supported_list = ", ".join(sorted(SUPPORTED_DOCUMENT_TYPES.keys()))
|
||||
event.text = (
|
||||
f"Unsupported document type '{ext or 'unknown'}'. "
|
||||
f"Supported types: {supported_list}"
|
||||
)
|
||||
event.text = f"Unsupported document type '{ext or 'unknown'}'. Supported types: {supported_list}"
|
||||
print(f"[Telegram] Unsupported document type: {ext or 'unknown'}", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
|
@ -696,10 +698,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Check file size (Telegram Bot API limit: 20 MB)
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if not doc.file_size or doc.file_size > MAX_DOC_BYTES:
|
||||
event.text = (
|
||||
"The document is too large or its size could not be verified. "
|
||||
"Maximum: 20 MB."
|
||||
)
|
||||
event.text = "The document is too large or its size could not be verified. Maximum: 20 MB."
|
||||
print(f"[Telegram] Document too large: {doc.file_size} bytes", flush=True)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
|
@ -720,20 +719,20 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = original_filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
display_name = re.sub(r"[^\w.\- ]", "_", display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if event.text:
|
||||
event.text = f"{injection}\n\n{event.text}"
|
||||
else:
|
||||
event.text = injection
|
||||
except UnicodeDecodeError:
|
||||
print(f"[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
print("[Telegram] Could not decode text file as UTF-8, skipping content injection", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache document: {e}", flush=True)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
||||
async def _handle_sticker(self, msg: Message, event: "MessageEvent") -> None:
|
||||
"""
|
||||
Describe a Telegram sticker via vision analysis, with caching.
|
||||
|
|
@ -743,11 +742,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
a placeholder noting the emoji.
|
||||
"""
|
||||
from gateway.sticker_cache import (
|
||||
get_cached_description,
|
||||
cache_sticker_description,
|
||||
build_sticker_injection,
|
||||
build_animated_sticker_injection,
|
||||
STICKER_VISION_PROMPT,
|
||||
build_animated_sticker_injection,
|
||||
build_sticker_injection,
|
||||
cache_sticker_description,
|
||||
get_cached_description,
|
||||
)
|
||||
|
||||
sticker = msg.sticker
|
||||
|
|
@ -775,9 +774,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=".webp")
|
||||
print(f"[Telegram] Analyzing sticker: {cached_path}", flush=True)
|
||||
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
import json as _json
|
||||
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
|
||||
result_json = await vision_analyze_tool(
|
||||
image_url=cached_path,
|
||||
user_prompt=STICKER_VISION_PROMPT,
|
||||
|
|
@ -792,27 +792,29 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Vision failed -- use emoji as fallback
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji, set_name,
|
||||
emoji,
|
||||
set_name,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Sticker analysis error: {e}", flush=True)
|
||||
event.text = build_sticker_injection(
|
||||
f"a sticker with emoji {emoji}" if emoji else "a sticker",
|
||||
emoji, set_name,
|
||||
emoji,
|
||||
set_name,
|
||||
)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
|
||||
# Determine chat type
|
||||
chat_type = "dm"
|
||||
if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP):
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(chat.id),
|
||||
|
|
@ -822,7 +824,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
user_name=user.full_name if user else None,
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
)
|
||||
|
||||
|
||||
return MessageEvent(
|
||||
text=message.text or "",
|
||||
message_type=msg_type,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ with different backends via a bridge pattern.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
|
|
@ -24,7 +23,7 @@ import subprocess
|
|||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -36,7 +35,9 @@ def _kill_port_process(port: int) -> None:
|
|||
# Use netstat to find the PID bound to this port, then taskkill
|
||||
result = subprocess.run(
|
||||
["netstat", "-ano", "-p", "TCP"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
for line in result.stdout.splitlines():
|
||||
parts = line.split()
|
||||
|
|
@ -46,24 +47,29 @@ def _kill_port_process(port: int) -> None:
|
|||
try:
|
||||
subprocess.run(
|
||||
["taskkill", "/PID", parts[4], "/F"],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
except subprocess.SubprocessError:
|
||||
pass
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["fuser", f"{port}/tcp"],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
subprocess.run(
|
||||
["fuser", "-k", f"{port}/tcp"],
|
||||
capture_output=True, timeout=5,
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
|
@ -72,25 +78,20 @@ from gateway.platforms.base import (
|
|||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
def check_whatsapp_requirements() -> bool:
|
||||
"""
|
||||
Check if WhatsApp dependencies are available.
|
||||
|
||||
|
||||
WhatsApp requires a Node.js bridge for most implementations.
|
||||
"""
|
||||
# Check for Node.js
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["node", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
result = subprocess.run(["node", "--version"], capture_output=True, text=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -99,62 +100,61 @@ def check_whatsapp_requirements() -> bool:
|
|||
class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
WhatsApp adapter.
|
||||
|
||||
|
||||
This implementation uses a simple HTTP bridge pattern where:
|
||||
1. A Node.js process runs the WhatsApp Web client
|
||||
2. Messages are forwarded via HTTP/IPC to this Python adapter
|
||||
3. Responses are sent back through the bridge
|
||||
|
||||
|
||||
The actual Node.js bridge implementation can vary:
|
||||
- whatsapp-web.js based
|
||||
- Baileys based
|
||||
- Business API based
|
||||
|
||||
|
||||
Configuration:
|
||||
- bridge_script: Path to the Node.js bridge script
|
||||
- bridge_port: Port for HTTP communication (default: 3000)
|
||||
- session_path: Path to store WhatsApp session data
|
||||
"""
|
||||
|
||||
|
||||
# WhatsApp message limits
|
||||
MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages
|
||||
|
||||
|
||||
# Default bridge location relative to the hermes-agent install
|
||||
_DEFAULT_BRIDGE_DIR = Path(__file__).resolve().parents[2] / "scripts" / "whatsapp-bridge"
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WHATSAPP)
|
||||
self._bridge_process: Optional[subprocess.Popen] = None
|
||||
self._bridge_process: subprocess.Popen | None = None
|
||||
self._bridge_port: int = config.extra.get("bridge_port", 3000)
|
||||
self._bridge_script: Optional[str] = config.extra.get(
|
||||
self._bridge_script: str | None = config.extra.get(
|
||||
"bridge_script",
|
||||
str(self._DEFAULT_BRIDGE_DIR / "bridge.js"),
|
||||
)
|
||||
self._session_path: Path = Path(config.extra.get(
|
||||
"session_path",
|
||||
Path.home() / ".hermes" / "whatsapp" / "session"
|
||||
))
|
||||
self._session_path: Path = Path(
|
||||
config.extra.get("session_path", Path.home() / ".hermes" / "whatsapp" / "session")
|
||||
)
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._bridge_log_fh = None
|
||||
self._bridge_log: Optional[Path] = None
|
||||
|
||||
self._bridge_log: Path | None = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Start the WhatsApp bridge.
|
||||
|
||||
|
||||
This launches the Node.js bridge process and waits for it to be ready.
|
||||
"""
|
||||
if not check_whatsapp_requirements():
|
||||
logger.warning("[%s] Node.js not found. WhatsApp requires Node.js.", self.name)
|
||||
return False
|
||||
|
||||
|
||||
bridge_path = Path(self._bridge_script)
|
||||
if not bridge_path.exists():
|
||||
logger.warning("[%s] Bridge script not found: %s", self.name, bridge_path)
|
||||
return False
|
||||
|
||||
|
||||
logger.info("[%s] Bridge found at %s", self.name, bridge_path)
|
||||
|
||||
|
||||
# Auto-install npm dependencies if node_modules doesn't exist
|
||||
bridge_dir = bridge_path.parent
|
||||
if not (bridge_dir / "node_modules").exists():
|
||||
|
|
@ -174,16 +174,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to install dependencies: {e}")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Kill any orphaned bridge from a previous gateway run
|
||||
_kill_port_process(self._bridge_port)
|
||||
import time
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
# Start the bridge process in its own process group.
|
||||
# Route output to a log file so QR codes, errors, and reconnection
|
||||
# messages are preserved for troubleshooting.
|
||||
|
|
@ -195,19 +196,23 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
[
|
||||
"node",
|
||||
str(bridge_path),
|
||||
"--port", str(self._bridge_port),
|
||||
"--session", str(self._session_path),
|
||||
"--mode", whatsapp_mode,
|
||||
"--port",
|
||||
str(self._bridge_port),
|
||||
"--session",
|
||||
str(self._session_path),
|
||||
"--mode",
|
||||
whatsapp_mode,
|
||||
],
|
||||
stdout=bridge_log_fh,
|
||||
stderr=bridge_log_fh,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
|
||||
# Wait for the bridge to connect to WhatsApp.
|
||||
# Phase 1: wait for the HTTP server to come up (up to 15s).
|
||||
# Phase 2: wait for WhatsApp status: connected (up to 15s more).
|
||||
import aiohttp
|
||||
|
||||
http_ready = False
|
||||
data = {}
|
||||
for attempt in range(15):
|
||||
|
|
@ -218,17 +223,18 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
http_ready = True
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
|
@ -237,7 +243,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
print(f"[{self.name}] Check log: {self._bridge_log}")
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
|
||||
# Phase 2: HTTP is up but WhatsApp may still be connecting.
|
||||
# Give it more time to authenticate with saved credentials.
|
||||
if data.get("status") != "connected":
|
||||
|
|
@ -250,16 +256,17 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self._close_bridge_log()
|
||||
return False
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/health",
|
||||
timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/health", timeout=aiohttp.ClientTimeout(total=2)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
if data.get("status") == "connected":
|
||||
print(f"[{self.name}] Bridge ready (status: connected)")
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
|
|
@ -268,19 +275,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
print(f"[{self.name}] ⚠ WhatsApp not connected after 30s")
|
||||
print(f"[{self.name}] Bridge log: {self._bridge_log}")
|
||||
print(f"[{self.name}] If session expired, re-pair: hermes whatsapp")
|
||||
|
||||
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
||||
|
||||
def _close_bridge_log(self) -> None:
|
||||
"""Close the bridge log file handle if open."""
|
||||
if self._bridge_log_fh:
|
||||
|
|
@ -296,6 +303,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
# Kill the entire process group so child node processes die too
|
||||
import signal
|
||||
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.terminate()
|
||||
|
|
@ -314,29 +322,25 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
|
||||
|
||||
# Also kill any orphaned bridge processes on our port
|
||||
_kill_port_process(self._bridge_port)
|
||||
|
||||
|
||||
self._running = False
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
self, chat_id: str, content: str, reply_to: str | None = None, metadata: dict[str, Any] | None = None
|
||||
) -> SendResult:
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
|
|
@ -344,28 +348,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
f"http://localhost:{self._bridge_port}/send", json=payload, timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data
|
||||
)
|
||||
return SendResult(success=True, message_id=data.get("messageId"), raw_response=data)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
|
||||
except ImportError:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="aiohttp not installed. Run: pip install aiohttp"
|
||||
)
|
||||
return SendResult(success=False, error="aiohttp not installed. Run: pip install aiohttp")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
|
|
@ -380,21 +375,24 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
f"http://localhost:{self._bridge_port}/edit",
|
||||
json={
|
||||
"chatId": chat_id,
|
||||
"messageId": message_id,
|
||||
"message": content,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
|
|
@ -403,8 +401,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
chat_id: str,
|
||||
file_path: str,
|
||||
media_type: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
|
|
@ -415,7 +413,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
if not os.path.exists(file_path):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
payload: dict[str, Any] = {
|
||||
"chatId": chat_id,
|
||||
"filePath": file_path,
|
||||
"mediaType": media_type,
|
||||
|
|
@ -425,22 +423,24 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
if file_name:
|
||||
payload["fileName"] = file_name
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
f"http://localhost:{self._bridge_port}/send-media",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=120),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data,
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
|
@ -449,8 +449,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Download image URL to cache, send natively via bridge."""
|
||||
try:
|
||||
|
|
@ -463,8 +463,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively via bridge."""
|
||||
return await self._send_media_to_bridge(chat_id, image_path, "image", caption)
|
||||
|
|
@ -473,8 +473,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a video natively via bridge — plays inline in WhatsApp."""
|
||||
return await self._send_media_to_bridge(chat_id, video_path, "video", caption)
|
||||
|
|
@ -483,13 +483,16 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
caption: str | None = None,
|
||||
file_name: str | None = None,
|
||||
reply_to: str | None = None,
|
||||
) -> SendResult:
|
||||
"""Send a document/file as a downloadable attachment via bridge."""
|
||||
return await self._send_media_to_bridge(
|
||||
chat_id, file_path, "document", caption,
|
||||
chat_id,
|
||||
file_path,
|
||||
"document",
|
||||
caption,
|
||||
file_name or os.path.basename(file_path),
|
||||
)
|
||||
|
||||
|
|
@ -497,44 +500,45 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> dict[str, Any]:
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}", timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e)
|
||||
|
||||
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
|
||||
async def _poll_messages(self) -> None:
|
||||
"""Poll the bridge for incoming messages."""
|
||||
try:
|
||||
|
|
@ -542,29 +546,30 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, message polling disabled")
|
||||
return
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages", timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp,
|
||||
):
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
await asyncio.sleep(1) # Poll interval
|
||||
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
|
||||
async def _build_message_event(self, data: dict[str, Any]) -> MessageEvent | None:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
# Determine message type
|
||||
|
|
@ -579,11 +584,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
msg_type = MessageType.VOICE
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
|
||||
# Determine chat type
|
||||
is_group = data.get("isGroup", False)
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=data.get("chatId", ""),
|
||||
|
|
@ -592,7 +597,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
user_id=data.get("senderId"),
|
||||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
|
||||
# Download image media URLs to the local cache so the vision tool
|
||||
# can access them reliably regardless of URL expiration.
|
||||
raw_urls = data.get("mediaUrls", [])
|
||||
|
|
@ -622,7 +627,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
else:
|
||||
cached_urls.append(url)
|
||||
media_types.append("unknown")
|
||||
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
message_type=msg_type,
|
||||
|
|
@ -635,4 +640,3 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue