mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
2284 lines · 90 KiB · Python
"""
|
|
Base platform adapter interface.
|
|
|
|
All platform adapters (Telegram, Discord, WhatsApp) inherit from this
|
|
and implement the required methods.
|
|
"""
|
|
|
|
import asyncio
|
|
import ipaddress
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
import socket as _socket
|
|
import subprocess
|
|
import sys
|
|
import uuid
|
|
from abc import ABC, abstractmethod
|
|
from urllib.parse import urlsplit
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def utf16_len(s: str) -> int:
    """Count UTF-16 code units in *s*.

    Telegram's message-length limit (4 096) is measured in UTF-16 code
    units, **not** Unicode code-points. Characters outside the Basic
    Multilingual Plane (emoji like 😀, CJK Extension B, musical symbols, …)
    are encoded as surrogate pairs and therefore consume **two** UTF-16
    code units each, even though Python's ``len()`` counts them as one.

    Ported from nearai/ironclaw#2304 which discovered the same discrepancy
    in Rust's ``chars().count()``.
    """
    # Every UTF-16 code unit is exactly two bytes in the LE encoding.
    encoded = s.encode("utf-16-le")
    return len(encoded) >> 1
|
|
|
|
|
|
def _prefix_within_utf16_limit(s: str, limit: int) -> str:
    """Return the longest prefix of *s* whose UTF-16 length ≤ *limit*.

    Unlike a plain ``s[:limit]``, this respects surrogate-pair boundaries
    so a multi-code-unit character is never sliced in half.
    """
    if utf16_len(s) <= limit:
        return s
    # Binary-search the largest codepoint count whose prefix still fits.
    low, high = 0, len(s)
    while low < high:
        probe = (low + high + 1) // 2
        if utf16_len(s[:probe]) > limit:
            high = probe - 1
        else:
            low = probe
    return s[:low]
|
|
|
|
|
|
def _custom_unit_to_cp(s: str, budget: int, len_fn) -> int:
|
|
"""Return the largest codepoint offset *n* such that ``len_fn(s[:n]) <= budget``.
|
|
|
|
Used by :meth:`BasePlatformAdapter.truncate_message` when *len_fn* measures
|
|
length in units different from Python codepoints (e.g. UTF-16 code units).
|
|
Falls back to binary search which is O(log n) calls to *len_fn*.
|
|
"""
|
|
if len_fn(s) <= budget:
|
|
return len(s)
|
|
lo, hi = 0, len(s)
|
|
while lo < hi:
|
|
mid = (lo + hi + 1) // 2
|
|
if len_fn(s[:mid]) <= budget:
|
|
lo = mid
|
|
else:
|
|
hi = mid - 1
|
|
return lo
|
|
|
|
|
|
def is_network_accessible(host: str) -> bool:
    """Return True if *host* would expose the server beyond loopback.

    Loopback addresses (127.0.0.1, ::1, IPv4-mapped ::ffff:127.0.0.1)
    are local-only. Unspecified addresses (0.0.0.0, ::) bind all
    interfaces. Hostnames are resolved; DNS failure fails closed.
    """
    try:
        parsed = ipaddress.ip_address(host)
    except ValueError:
        parsed = None  # Not an IP literal — resolve it as a hostname below.

    if parsed is not None:
        if parsed.is_loopback:
            return False
        # Python reports is_loopback=False for IPv4-mapped addresses such as
        # ::ffff:127.0.0.1, so inspect the embedded IPv4 address directly.
        mapped = getattr(parsed, "ipv4_mapped", None)
        if mapped is not None and mapped.is_loopback:
            return False
        return True

    try:
        resolved = _socket.getaddrinfo(
            host, None, _socket.AF_UNSPEC, _socket.SOCK_STREAM,
        )
    except (_socket.gaierror, OSError):
        # DNS failure: fail closed and treat the host as exposed.
        return True
    # The hostname counts as network-accessible when at least one resolved
    # address is non-loopback.
    return any(
        not ipaddress.ip_address(sockaddr[0]).is_loopback
        for _family, _type, _proto, _canonname, sockaddr in resolved
    )
|
|
|
|
|
|
def _detect_macos_system_proxy() -> str | None:
|
|
"""Read the macOS system HTTP(S) proxy via ``scutil --proxy``.
|
|
|
|
Returns an ``http://host:port`` URL string if an HTTP or HTTPS proxy is
|
|
enabled, otherwise *None*. Falls back silently on non-macOS or on any
|
|
subprocess error.
|
|
"""
|
|
if sys.platform != "darwin":
|
|
return None
|
|
try:
|
|
out = subprocess.check_output(
|
|
["scutil", "--proxy"], timeout=3, text=True, stderr=subprocess.DEVNULL,
|
|
)
|
|
except Exception:
|
|
return None
|
|
|
|
props: dict[str, str] = {}
|
|
for line in out.splitlines():
|
|
line = line.strip()
|
|
if " : " in line:
|
|
key, _, val = line.partition(" : ")
|
|
props[key.strip()] = val.strip()
|
|
|
|
# Prefer HTTPS, fall back to HTTP
|
|
for enable_key, host_key, port_key in (
|
|
("HTTPSEnable", "HTTPSProxy", "HTTPSPort"),
|
|
("HTTPEnable", "HTTPProxy", "HTTPPort"),
|
|
):
|
|
if props.get(enable_key) == "1":
|
|
host = props.get(host_key)
|
|
port = props.get(port_key)
|
|
if host and port:
|
|
return f"http://{host}:{port}"
|
|
return None
|
|
|
|
|
|
def resolve_proxy_url(platform_env_var: str | None = None) -> str | None:
|
|
"""Return a proxy URL from env vars, or macOS system proxy.
|
|
|
|
Check order:
|
|
0. *platform_env_var* (e.g. ``DISCORD_PROXY``) — highest priority
|
|
1. HTTPS_PROXY / HTTP_PROXY / ALL_PROXY (and lowercase variants)
|
|
2. macOS system proxy via ``scutil --proxy`` (auto-detect)
|
|
|
|
Returns *None* if no proxy is found.
|
|
"""
|
|
if platform_env_var:
|
|
value = (os.environ.get(platform_env_var) or "").strip()
|
|
if value:
|
|
return value
|
|
for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY",
|
|
"https_proxy", "http_proxy", "all_proxy"):
|
|
value = (os.environ.get(key) or "").strip()
|
|
if value:
|
|
return value
|
|
return _detect_macos_system_proxy()
|
|
|
|
|
|
def proxy_kwargs_for_bot(proxy_url: str | None) -> dict:
|
|
"""Build kwargs for ``commands.Bot()`` / ``discord.Client()`` with proxy.
|
|
|
|
Returns:
|
|
- SOCKS URL → ``{"connector": ProxyConnector(..., rdns=True)}``
|
|
- HTTP URL → ``{"proxy": url}``
|
|
- *None* → ``{}``
|
|
|
|
``rdns=True`` forces remote DNS resolution through the proxy — required
|
|
by many SOCKS implementations (Shadowrocket, Clash) and essential for
|
|
bypassing DNS pollution behind the GFW.
|
|
"""
|
|
if not proxy_url:
|
|
return {}
|
|
if proxy_url.lower().startswith("socks"):
|
|
try:
|
|
from aiohttp_socks import ProxyConnector
|
|
|
|
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
|
return {"connector": connector}
|
|
except ImportError:
|
|
logger.warning(
|
|
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
|
"Run: pip install aiohttp-socks",
|
|
proxy_url,
|
|
)
|
|
return {}
|
|
return {"proxy": proxy_url}
|
|
|
|
|
|
def proxy_kwargs_for_aiohttp(proxy_url: str | None) -> tuple[dict, dict]:
|
|
"""Build kwargs for standalone ``aiohttp.ClientSession`` with proxy.
|
|
|
|
Returns ``(session_kwargs, request_kwargs)`` where:
|
|
- SOCKS → ``({"connector": ProxyConnector(...)}, {})``
|
|
- HTTP → ``({}, {"proxy": url})``
|
|
- None → ``({}, {})``
|
|
|
|
Usage::
|
|
|
|
sess_kw, req_kw = proxy_kwargs_for_aiohttp(proxy_url)
|
|
async with aiohttp.ClientSession(**sess_kw) as session:
|
|
async with session.get(url, **req_kw) as resp:
|
|
...
|
|
"""
|
|
if not proxy_url:
|
|
return {}, {}
|
|
if proxy_url.lower().startswith("socks"):
|
|
try:
|
|
from aiohttp_socks import ProxyConnector
|
|
|
|
connector = ProxyConnector.from_url(proxy_url, rdns=True)
|
|
return {"connector": connector}, {}
|
|
except ImportError:
|
|
logger.warning(
|
|
"aiohttp_socks not installed — SOCKS proxy %s ignored. "
|
|
"Run: pip install aiohttp-socks",
|
|
proxy_url,
|
|
)
|
|
return {}, {}
|
|
return {}, {"proxy": proxy_url}
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
|
from enum import Enum
|
|
|
|
from pathlib import Path as _Path
|
|
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
|
|
|
from gateway.config import Platform, PlatformConfig
|
|
from gateway.session import SessionSource, build_session_key
|
|
from hermes_constants import get_hermes_dir
|
|
|
|
|
|
# User-facing notice returned when a skill requests interactive secret entry
# over a messaging platform, where secure prompting is not possible.
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = (
    "Secure secret entry is not supported over messaging. "
    "Load this skill in the local CLI to be prompted, or add the key to ~/.hermes/.env manually."
)
|
|
|
|
|
|
def safe_url_for_log(url: str, max_len: int = 80) -> str:
    """Return a URL string safe for logs (no query/fragment/userinfo)."""
    if max_len <= 0 or url is None:
        return ""

    raw = str(url)
    if not raw:
        return ""

    try:
        parts = urlsplit(raw)
    except Exception:
        return raw[:max_len]

    if parts.scheme and parts.netloc:
        # Drop embedded credentials (user:pass@host) from the authority.
        host = parts.netloc.rsplit("@", 1)[-1]
        sanitized = f"{parts.scheme}://{host}"
        path = parts.path or ""
        if path and path != "/":
            # Keep only the last path segment; elide the rest.
            tail = path.rsplit("/", 1)[-1]
            sanitized = f"{sanitized}/.../{tail}" if tail else f"{sanitized}/..."
    else:
        sanitized = raw

    if len(sanitized) <= max_len:
        return sanitized
    if max_len <= 3:
        return "." * max_len
    return f"{sanitized[:max_len - 3]}..."
|
|
|
|
|
|
async def _ssrf_redirect_guard(response):
    """Re-validate each redirect target to prevent redirect-based SSRF.

    Without this, an attacker can host a public URL that 302-redirects to
    http://169.254.169.254/ and bypass the pre-flight is_safe_url() check.

    Must be async because httpx.AsyncClient awaits response event hooks.
    """
    if not (response.is_redirect and response.next_request):
        return
    target = str(response.next_request.url)
    from tools.url_safety import is_safe_url
    if not is_safe_url(target):
        raise ValueError(
            f"Blocked redirect to private/internal address: {safe_url_for_log(target)}"
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Image cache utilities
|
|
#
|
|
# When users send images on messaging platforms, we download them to a local
|
|
# cache directory so they can be analyzed by the vision tool (which accepts
|
|
# local file paths). This avoids issues with ephemeral platform URLs
|
|
# (e.g. Telegram file URLs expire after ~1 hour).
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Default location: {HERMES_HOME}/cache/images/ (legacy: image_cache/)
IMAGE_CACHE_DIR = get_hermes_dir("cache/images", "image_cache")
|
|
|
|
|
|
def get_image_cache_dir() -> Path:
    """Return the image cache directory, creating it if it doesn't exist."""
    cache_dir = IMAGE_CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
|
|
|
|
|
def _looks_like_image(data: bytes) -> bool:
|
|
"""Return True if *data* starts with a known image magic-byte sequence."""
|
|
if len(data) < 4:
|
|
return False
|
|
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
|
return True
|
|
if data[:3] == b"\xff\xd8\xff":
|
|
return True
|
|
if data[:6] in (b"GIF87a", b"GIF89a"):
|
|
return True
|
|
if data[:2] == b"BM":
|
|
return True
|
|
if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP":
|
|
return True
|
|
return False
|
|
|
|
|
|
def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
    """Save raw image bytes to the cache and return the absolute file path.

    Args:
        data: Raw image bytes.
        ext: File extension including the dot (e.g. ".jpg", ".png").

    Returns:
        Absolute path to the cached image file as a string.

    Raises:
        ValueError: If *data* does not look like a valid image (e.g. an HTML
            error page returned by the upstream server).
    """
    if not _looks_like_image(data):
        preview = data[:80].decode("utf-8", errors="replace")
        raise ValueError(
            f"Refusing to cache non-image data as {ext} "
            f"(starts with: {preview!r})"
        )
    # Unguessable unique name; caller decides the extension.
    target = get_image_cache_dir() / f"img_{uuid.uuid4().hex[:12]}{ext}"
    target.write_bytes(data)
    return str(target)
|
|
|
|
|
|
async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> str:
    """
    Download an image from a URL and save it to the local cache.

    Retries on transient failures (timeouts, 429, 5xx) with exponential
    backoff so a single slow CDN response doesn't lose the media.

    Args:
        url: The HTTP/HTTPS URL to download from.
        ext: File extension including the dot (e.g. ".jpg", ".png").
        retries: Number of retry attempts on transient failures.

    Returns:
        Absolute path to the cached image file as a string.

    Raises:
        ValueError: If the URL targets a private/internal network (SSRF protection).
    """
    # Pre-flight SSRF check on the initial URL.
    from tools.url_safety import is_safe_url
    if not is_safe_url(url):
        raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")

    import asyncio
    import httpx
    import logging as _logging
    _log = _logging.getLogger(__name__)

    last_exc = None
    # The response hook re-validates every redirect hop, so a 302 cannot
    # smuggle the request to a private address after the check above.
    async with httpx.AsyncClient(
        timeout=30.0,
        follow_redirects=True,
        event_hooks={"response": [_ssrf_redirect_guard]},
    ) as client:
        for attempt in range(retries + 1):
            try:
                response = await client.get(
                    url,
                    headers={
                        "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
                        "Accept": "image/*,*/*;q=0.8",
                    },
                )
                response.raise_for_status()
                # cache_image_from_bytes also validates magic bytes, so an
                # HTML error page is rejected instead of cached as an image.
                return cache_image_from_bytes(response.content, ext)
            except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
                last_exc = exc
                # Status codes below 429 (plain 4xx) are permanent — no retry.
                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
                    raise
                if attempt < retries:
                    # Linear backoff: 1.5s, 3.0s, ...
                    wait = 1.5 * (attempt + 1)
                    _log.debug(
                        "Media cache retry %d/%d for %s (%.1fs): %s",
                        attempt + 1,
                        retries,
                        safe_url_for_log(url),
                        wait,
                        exc,
                    )
                    await asyncio.sleep(wait)
                    continue
                raise
    # Defensive: the final attempt re-raises above, so this is normally
    # unreachable; kept as a safety net.
    raise last_exc
|
|
|
|
|
|
def cleanup_image_cache(max_age_hours: int = 24) -> int:
    """
    Delete cached images older than *max_age_hours*.

    Best-effort: a file that disappears between ``iterdir()`` and the
    ``stat()``/``unlink()`` (e.g. a concurrent cleanup or manual deletion)
    is skipped instead of raising — the original guarded only the unlink,
    so a racing ``stat()`` could abort the whole sweep with OSError.

    Args:
        max_age_hours: Age threshold in hours; older files are removed.

    Returns the number of files removed.
    """
    import time

    cache_dir = get_image_cache_dir()
    cutoff = time.time() - (max_age_hours * 3600)
    removed = 0
    for f in cache_dir.iterdir():
        try:
            # stat() can raise if the file vanished after iterdir().
            if f.is_file() and f.stat().st_mtime < cutoff:
                f.unlink()
                removed += 1
        except OSError:
            pass
    return removed
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Audio cache utilities
|
|
#
|
|
# Same pattern as image cache -- voice messages from platforms are downloaded
|
|
# here so the STT tool (OpenAI Whisper) can transcribe them from local files.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Default location: {HERMES_HOME}/cache/audio/ (legacy: audio_cache/)
AUDIO_CACHE_DIR = get_hermes_dir("cache/audio", "audio_cache")
|
|
|
|
|
|
def get_audio_cache_dir() -> Path:
    """Return the audio cache directory, creating it if it doesn't exist."""
    cache_dir = AUDIO_CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
|
|
|
|
|
def cache_audio_from_bytes(data: bytes, ext: str = ".ogg") -> str:
    """Save raw audio bytes to the cache and return the absolute file path.

    Args:
        data: Raw audio bytes.
        ext: File extension including the dot (e.g. ".ogg", ".mp3").

    Returns:
        Absolute path to the cached audio file as a string.
    """
    # Unguessable unique name; caller decides the extension.
    target = get_audio_cache_dir() / f"audio_{uuid.uuid4().hex[:12]}{ext}"
    target.write_bytes(data)
    return str(target)
|
|
|
|
|
|
async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> str:
    """
    Download an audio file from a URL and save it to the local cache.

    Retries on transient failures (timeouts, 429, 5xx) with exponential
    backoff so a single slow CDN response doesn't lose the media.

    Args:
        url: The HTTP/HTTPS URL to download from.
        ext: File extension including the dot (e.g. ".ogg", ".mp3").
        retries: Number of retry attempts on transient failures.

    Returns:
        Absolute path to the cached audio file as a string.

    Raises:
        ValueError: If the URL targets a private/internal network (SSRF protection).
    """
    # Pre-flight SSRF check on the initial URL.
    from tools.url_safety import is_safe_url
    if not is_safe_url(url):
        raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")

    import asyncio
    import httpx
    import logging as _logging
    _log = _logging.getLogger(__name__)

    last_exc = None
    # The response hook re-validates every redirect hop, so a 302 cannot
    # smuggle the request to a private address after the check above.
    async with httpx.AsyncClient(
        timeout=30.0,
        follow_redirects=True,
        event_hooks={"response": [_ssrf_redirect_guard]},
    ) as client:
        for attempt in range(retries + 1):
            try:
                response = await client.get(
                    url,
                    headers={
                        "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
                        "Accept": "audio/*,*/*;q=0.8",
                    },
                )
                response.raise_for_status()
                return cache_audio_from_bytes(response.content, ext)
            except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
                last_exc = exc
                # Status codes below 429 (plain 4xx) are permanent — no retry.
                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
                    raise
                if attempt < retries:
                    # Linear backoff: 1.5s, 3.0s, ...
                    wait = 1.5 * (attempt + 1)
                    _log.debug(
                        "Audio cache retry %d/%d for %s (%.1fs): %s",
                        attempt + 1,
                        retries,
                        safe_url_for_log(url),
                        wait,
                        exc,
                    )
                    await asyncio.sleep(wait)
                    continue
                raise
    # Defensive: the final attempt re-raises above, so this is normally
    # unreachable; kept as a safety net.
    raise last_exc
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Document cache utilities
|
|
#
|
|
# Same pattern as image/audio cache -- documents from platforms are downloaded
|
|
# here so the agent can reference them by local file path.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Default location: {HERMES_HOME}/cache/documents/ (legacy: document_cache/)
DOCUMENT_CACHE_DIR = get_hermes_dir("cache/documents", "document_cache")

# File extension -> MIME type for document attachments the gateway accepts.
SUPPORTED_DOCUMENT_TYPES = {
    ".pdf": "application/pdf",
    ".md": "text/markdown",
    ".txt": "text/plain",
    ".log": "text/plain",
    ".zip": "application/zip",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
|
|
|
|
|
|
def get_document_cache_dir() -> Path:
    """Return the document cache directory, creating it if it doesn't exist."""
    cache_dir = DOCUMENT_CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
|
|
|
|
|
def cache_document_from_bytes(data: bytes, filename: str) -> str:
    """
    Save raw document bytes to the cache and return the absolute file path.

    The cached filename preserves the original human-readable name with a
    unique prefix: ``doc_{uuid12}_{original_filename}``.

    Args:
        data: Raw document bytes.
        filename: Original filename (e.g. "report.pdf").

    Returns:
        Absolute path to the cached document file as a string.

    Raises:
        ValueError: If the sanitized path escapes the cache directory.
    """
    cache_dir = get_document_cache_dir()
    # Sanitize: keep only the basename, drop NUL bytes and surrounding space.
    base = Path(filename).name if filename else "document"
    base = base.replace("\x00", "").strip()
    if not base or base in (".", ".."):
        base = "document"
    destination = cache_dir / f"doc_{uuid.uuid4().hex[:12]}_{base}"
    # Defense in depth: the resolved path must remain inside the cache dir.
    if not destination.resolve().is_relative_to(cache_dir.resolve()):
        raise ValueError(f"Path traversal rejected: {filename!r}")
    destination.write_bytes(data)
    return str(destination)
|
|
|
|
|
|
def cleanup_document_cache(max_age_hours: int = 24) -> int:
    """
    Delete cached documents older than *max_age_hours*.

    Best-effort: a file that disappears between ``iterdir()`` and the
    ``stat()``/``unlink()`` (e.g. a concurrent cleanup or manual deletion)
    is skipped instead of raising — the original guarded only the unlink,
    so a racing ``stat()`` could abort the whole sweep with OSError.

    Args:
        max_age_hours: Age threshold in hours; older files are removed.

    Returns the number of files removed.
    """
    import time

    cache_dir = get_document_cache_dir()
    cutoff = time.time() - (max_age_hours * 3600)
    removed = 0
    for f in cache_dir.iterdir():
        try:
            # stat() can raise if the file vanished after iterdir().
            if f.is_file() and f.stat().st_mtime < cutoff:
                f.unlink()
                removed += 1
        except OSError:
            pass
    return removed
|
|
|
|
|
|
class MessageType(Enum):
    """Types of incoming messages, normalized across platforms."""
    TEXT = "text"
    LOCATION = "location"
    PHOTO = "photo"
    VIDEO = "video"
    AUDIO = "audio"
    VOICE = "voice"
    DOCUMENT = "document"
    STICKER = "sticker"
    COMMAND = "command"  # /command style
|
|
|
|
|
|
class ProcessingOutcome(Enum):
    """Result classification for message-processing lifecycle hooks."""

    SUCCESS = "success"
    FAILURE = "failure"
    CANCELLED = "cancelled"
|
|
|
|
|
|
@dataclass
class MessageEvent:
    """
    Incoming message from a platform.

    Normalized representation that all adapters produce.
    """
    # Message content
    text: str
    message_type: MessageType = MessageType.TEXT

    # Source information. Annotated Optional to match the None default
    # (the original annotation claimed a bare SessionSource but defaulted
    # to None).
    source: Optional[SessionSource] = None

    # Original platform data
    raw_message: Any = None
    message_id: Optional[str] = None

    # Platform-specific update identifier. For Telegram this is the
    # ``update_id`` from the PTB Update wrapper; other platforms currently
    # ignore it. Used by ``/restart`` to record the triggering update so the
    # new gateway can advance the Telegram offset past it and avoid processing
    # the same ``/restart`` twice if PTB's graceful-shutdown ACK times out
    # ("Error while calling `get_updates` one more time to mark all fetched
    # updates" in gateway.log).
    platform_update_id: Optional[int] = None

    # Media attachments
    # media_urls: local file paths (for vision tool access)
    media_urls: List[str] = field(default_factory=list)
    media_types: List[str] = field(default_factory=list)

    # Reply context
    reply_to_message_id: Optional[str] = None
    reply_to_text: Optional[str] = None  # Text of the replied-to message (for context injection)

    # Auto-loaded skill(s) for topic/channel bindings (e.g., Telegram DM Topics,
    # Discord channel_skill_bindings). A single name or ordered list.
    auto_skill: Optional[str | list[str]] = None

    # Per-channel ephemeral system prompt (e.g. Discord channel_prompts).
    # Applied at API call time and never persisted to transcript history.
    channel_prompt: Optional[str] = None

    # Internal flag — set for synthetic events (e.g. background process
    # completion notifications) that must bypass user authorization checks.
    internal: bool = False

    # Timestamps
    timestamp: datetime = field(default_factory=datetime.now)

    def is_command(self) -> bool:
        """Check if this is a command message (e.g., /new, /reset)."""
        return self.text.startswith("/")

    def get_command(self) -> Optional[str]:
        """Extract command name if this is a command message.

        Handles Telegram-style bot suffixes ("/cmd@MyBot" -> "cmd") and
        rejects slash-containing names ("/usr/bin" -> None) since valid
        command names never contain a path separator.
        """
        if not self.is_command():
            return None
        # Split on space and get first word, strip the /
        parts = self.text.split(maxsplit=1)
        raw = parts[0][1:].lower() if parts else None
        if raw and "@" in raw:
            raw = raw.split("@", 1)[0]
        # Reject file paths: valid command names never contain /
        if raw and "/" in raw:
            return None
        return raw

    def get_command_args(self) -> str:
        """Get the arguments after a command (full text for non-commands)."""
        if not self.is_command():
            return self.text
        parts = self.text.split(maxsplit=1)
        return parts[1] if len(parts) > 1 else ""
|
|
|
|
|
|
@dataclass
class SendResult:
    """Result of sending a message."""
    success: bool
    # Platform message id of the delivered message, when available.
    message_id: Optional[str] = None
    # Human-readable error description on failure.
    error: Optional[str] = None
    # Raw platform response object, kept for debugging.
    raw_response: Any = None
    retryable: bool = False  # True for transient connection errors — base will retry automatically
|
|
|
|
|
|
def merge_pending_message_event(
    pending_messages: Dict[str, MessageEvent],
    session_key: str,
    event: MessageEvent,
    *,
    merge_text: bool = False,
) -> None:
    """Store or merge a pending event for a session.

    Photo bursts/albums often arrive as multiple near-simultaneous PHOTO
    events. Merge those into the existing queued event so the next turn sees
    the whole burst.

    When ``merge_text`` is enabled, rapid follow-up TEXT events are appended
    instead of replacing the pending turn. This is used for Telegram bursty
    follow-ups so a multi-part user thought is not silently truncated to only
    the last queued fragment.

    Mutates *pending_messages* (and the queued event) in place; returns None.
    """
    existing = pending_messages.get(session_key)
    if existing:
        existing_is_photo = getattr(existing, "message_type", None) == MessageType.PHOTO
        incoming_is_photo = event.message_type == MessageType.PHOTO
        existing_has_media = bool(existing.media_urls)
        incoming_has_media = bool(event.media_urls)

        # Case 1: photo + photo — an album burst; accumulate media + captions.
        if existing_is_photo and incoming_is_photo:
            existing.media_urls.extend(event.media_urls)
            existing.media_types.extend(event.media_types)
            if event.text:
                existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
            return

        # Case 2: either side carries media — keep all attachments together.
        if existing_has_media or incoming_has_media:
            if incoming_has_media:
                existing.media_urls.extend(event.media_urls)
                existing.media_types.extend(event.media_types)
            if event.text:
                if existing.text:
                    existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text)
                else:
                    existing.text = event.text
            # Promote the merged event to PHOTO if either side was a photo.
            if existing_is_photo or incoming_is_photo:
                existing.message_type = MessageType.PHOTO
            return

        # Case 3: text + text with merge_text enabled — append the fragment.
        if (
            merge_text
            and getattr(existing, "message_type", None) == MessageType.TEXT
            and event.message_type == MessageType.TEXT
        ):
            if event.text:
                existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text
            return

    # Default: the newest event replaces whatever was queued.
    pending_messages[session_key] = event
|
|
|
|
|
|
# Error substrings that indicate a transient *connection* failure worth retrying.
# Matched case-insensitively against stringified send errors (see SendResult).
# "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally
# excluded: a read/write timeout on a non-idempotent call (e.g. send_message)
# means the request may have reached the server — retrying risks duplicate
# delivery. "connecttimeout" is safe because the connection was never
# established. Platforms that know a timeout is safe to retry should set
# SendResult.retryable = True explicitly.
_RETRYABLE_ERROR_PATTERNS = (
    "connecterror",
    "connectionerror",
    "connectionreset",
    "connectionrefused",
    "connecttimeout",
    "network",
    "broken pipe",
    "remotedisconnected",
    "eoferror",
)
|
|
|
|
|
|
# Type for message handlers: a coroutine that consumes a MessageEvent and
# resolves to an optional response string (None = no direct reply).
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
|
|
|
|
|
def resolve_channel_prompt(
|
|
config_extra: dict,
|
|
channel_id: str,
|
|
parent_id: str | None = None,
|
|
) -> str | None:
|
|
"""Resolve a per-channel ephemeral prompt from platform config.
|
|
|
|
Looks up ``channel_prompts`` in the adapter's ``config.extra`` dict.
|
|
Prefers an exact match on *channel_id*; falls back to *parent_id*
|
|
(useful for forum threads / child channels inheriting a parent prompt).
|
|
|
|
Returns the prompt string, or None if no match is found. Blank/whitespace-
|
|
only prompts are treated as absent.
|
|
"""
|
|
prompts = config_extra.get("channel_prompts") or {}
|
|
if not isinstance(prompts, dict):
|
|
return None
|
|
|
|
for key in (channel_id, parent_id):
|
|
if not key:
|
|
continue
|
|
prompt = prompts.get(key)
|
|
if prompt is None:
|
|
continue
|
|
prompt = str(prompt).strip()
|
|
if prompt:
|
|
return prompt
|
|
return None
|
|
|
|
|
|
class BasePlatformAdapter(ABC):
|
|
"""
|
|
Base class for platform adapters.
|
|
|
|
Subclasses implement platform-specific logic for:
|
|
- Connecting and authenticating
|
|
- Receiving messages
|
|
- Sending messages/responses
|
|
- Handling media
|
|
"""
|
|
|
|
    def __init__(self, config: PlatformConfig, platform: Platform):
        """Initialize shared adapter state.

        Args:
            config: Platform-specific configuration for this adapter.
            platform: Which platform this adapter serves.
        """
        self.config = config
        self.platform = platform
        self._message_handler: Optional[MessageHandler] = None
        self._running = False
        # Fatal-error state, surfaced via the has_fatal_error/... properties.
        self._fatal_error_code: Optional[str] = None
        self._fatal_error_message: Optional[str] = None
        self._fatal_error_retryable = True
        self._fatal_error_handler: Optional[Callable[["BasePlatformAdapter"], Awaitable[None] | None]] = None

        # Track active message handlers per session for interrupt support
        # Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
        self._active_sessions: Dict[str, asyncio.Event] = {}
        # Queued events waiting for the current turn to finish (see
        # merge_pending_message_event for the merge semantics).
        self._pending_messages: Dict[str, MessageEvent] = {}
        # Background message-processing tasks spawned by handle_message().
        # Gateway shutdown cancels these so an old gateway instance doesn't keep
        # working on a task after --replace or manual restarts.
        self._background_tasks: set[asyncio.Task] = set()
        # One-shot callbacks to fire after the main response is delivered.
        # Keyed by session_key. GatewayRunner uses this to defer
        # background-review notifications ("💾 Skill created") until the
        # primary reply has been sent.
        self._post_delivery_callbacks: Dict[str, Callable] = {}
        # Tasks whose cancellation is anticipated — presumably used to suppress
        # error reporting on shutdown; confirm semantics at the call sites.
        self._expected_cancelled_tasks: set[asyncio.Task] = set()
        self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
        # Chats where auto-TTS on voice input is disabled (set by /voice off)
        self._auto_tts_disabled_chats: set = set()
        # Chats where typing indicator is paused (e.g. during approval waits).
        # _keep_typing skips send_typing when the chat_id is in this set.
        self._typing_paused: set = set()
|
|
|
|
    @property
    def has_fatal_error(self) -> bool:
        """True once a fatal error has been recorded via _set_fatal_error."""
        return self._fatal_error_message is not None
|
|
|
|
    @property
    def fatal_error_message(self) -> Optional[str]:
        """Human-readable description of the fatal error, or None."""
        return self._fatal_error_message
|
|
|
|
    @property
    def fatal_error_code(self) -> Optional[str]:
        """Machine-readable identifier of the fatal error, or None."""
        return self._fatal_error_code
|
|
|
|
    @property
    def fatal_error_retryable(self) -> bool:
        """Whether the recorded fatal error is worth retrying (default True)."""
        return self._fatal_error_retryable
|
|
|
|
    def set_fatal_error_handler(self, handler: Callable[["BasePlatformAdapter"], Awaitable[None] | None]) -> None:
        """Register a callback invoked (and awaited, if async) on fatal errors."""
        self._fatal_error_handler = handler
|
|
|
|
    def _mark_connected(self) -> None:
        """Record a successful connection and clear any fatal-error state."""
        self._running = True
        self._fatal_error_code = None
        self._fatal_error_message = None
        self._fatal_error_retryable = True
        try:
            from gateway.status import write_runtime_status
            write_runtime_status(platform=self.platform.value, platform_state="connected", error_code=None, error_message=None)
        except Exception:
            # Status reporting is best-effort — never let it break connecting.
            pass
|
|
|
|
    def _mark_disconnected(self) -> None:
        """Record a disconnect; an existing fatal-error state is preserved."""
        self._running = False
        if self.has_fatal_error:
            # A "fatal" status is more specific than "disconnected" — keep it.
            return
        try:
            from gateway.status import write_runtime_status
            write_runtime_status(platform=self.platform.value, platform_state="disconnected", error_code=None, error_message=None)
        except Exception:
            # Status reporting is best-effort.
            pass
|
|
|
|
    def _set_fatal_error(self, code: str, message: str, *, retryable: bool) -> None:
        """Record a fatal error and publish it to the runtime status file.

        Args:
            code: Machine-readable error identifier (e.g. "telegram_lock").
            message: Human-readable description for operators.
            retryable: Whether a supervisor may attempt to reconnect.
        """
        self._running = False
        self._fatal_error_code = code
        self._fatal_error_message = message
        self._fatal_error_retryable = retryable
        try:
            from gateway.status import write_runtime_status
            write_runtime_status(
                platform=self.platform.value,
                platform_state="fatal",
                error_code=code,
                error_message=message,
            )
        except Exception:
            # Status reporting is best-effort.
            pass
|
|
|
|
async def _notify_fatal_error(self) -> None:
|
|
handler = self._fatal_error_handler
|
|
if not handler:
|
|
return
|
|
result = handler(self)
|
|
if asyncio.iscoroutine(result):
|
|
await result
|
|
|
|
    def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
        """Acquire a scoped lock for this adapter. Returns True on success.

        On contention, records a non-retryable fatal error naming the owner
        PID (when known) so the operator can stop the other gateway.
        """
        from gateway.status import acquire_scoped_lock
        # Remember scope/identity so _release_platform_lock can undo this.
        self._platform_lock_scope = scope
        self._platform_lock_identity = identity
        acquired, existing = acquire_scoped_lock(
            scope, identity, metadata={'platform': self.platform.value}
        )
        if acquired:
            return True
        owner_pid = existing.get('pid') if isinstance(existing, dict) else None
        message = (
            f'{resource_desc} already in use'
            + (f' (PID {owner_pid})' if owner_pid else '')
            + '. Stop the other gateway first.'
        )
        logger.error('[%s] %s', self.name, message)
        self._set_fatal_error(f'{scope}_lock', message, retryable=False)
        return False
|
|
|
|
def _release_platform_lock(self) -> None:
|
|
"""Release the scoped lock acquired by _acquire_platform_lock."""
|
|
identity = getattr(self, '_platform_lock_identity', None)
|
|
if not identity:
|
|
return
|
|
from gateway.status import release_scoped_lock
|
|
release_scoped_lock(self._platform_lock_scope, identity)
|
|
self._platform_lock_identity = None
|
|
|
|
@property
def name(self) -> str:
    """Human-readable adapter name (title-cased platform value)."""
    return self.platform.value.title()
|
|
|
|
@property
def is_connected(self) -> bool:
    """True while the adapter is connected and running."""
    return self._running
|
|
|
|
def set_message_handler(self, handler: MessageHandler) -> None:
    """
    Register the coroutine invoked for each incoming message.

    The handler receives a MessageEvent and returns an optional
    response string to deliver back to the chat.
    """
    self._message_handler = handler
|
|
|
|
def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None:
    """Register (or clear) the handler for messages arriving during active sessions."""
    self._busy_session_handler = handler
|
|
|
|
def set_session_store(self, session_store: Any) -> None:
    """
    Attach the session store used for active-session lookups.

    Adapters that must decide whether an un-mentioned message belongs
    to an ongoing conversation (e.g. Slack thread replies without an
    explicit mention) consult this store before processing.
    """
    self._session_store = session_store
|
|
|
|
@abstractmethod
async def connect(self) -> bool:
    """
    Connect to the platform and start receiving messages.

    Returns:
        True if the connection was established successfully.
    """
    pass
|
|
|
|
@abstractmethod
async def disconnect(self) -> None:
    """Disconnect from the platform and stop receiving messages."""
    pass
|
|
|
|
@abstractmethod
async def send(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> SendResult:
    """
    Send a message to a chat.

    Args:
        chat_id: The chat/channel ID to send to
        content: Message content (may be markdown)
        reply_to: Optional message ID to reply to
        metadata: Additional platform-specific options
            (e.g. thread routing information)

    Returns:
        SendResult with success status and message ID
    """
    pass
|
|
|
|
# Whether edit_message() must be called with ``finalize=True`` to close out
# a streamed message's lifecycle. Default False: the adapter treats
# ``finalize=True`` as a no-op and the stream consumer may skip redundant
# final edits. Subclasses whose surfaces have a distinct "in progress"
# streaming state requiring explicit closure (e.g. rich card / AI
# assistant surfaces such as DingTalk AI Cards) override this to True
# (class attribute or property) so the consumer does not short-circuit.
REQUIRES_EDIT_FINALIZE: bool = False
|
|
|
|
async def edit_message(
    self,
    chat_id: str,
    message_id: str,
    content: str,
    *,
    finalize: bool = False,
) -> SendResult:
    """
    Edit a previously sent message. Optional — platforms that don't
    support editing return success=False and callers fall back to
    sending a new message.

    ``finalize`` signals that this is the last edit in a streaming
    sequence. Most platforms (Telegram, Slack, Discord, Matrix,
    etc.) treat it as a no-op because their edit APIs have no notion
    of message lifecycle state — an edit is an edit. Platforms that
    render streaming updates with a distinct "in progress" state and
    require explicit closure (e.g. rich card / AI assistant surfaces
    such as DingTalk AI Cards) use it to finalize the message and
    transition the UI out of the streaming indicator — those should
    also set ``REQUIRES_EDIT_FINALIZE = True`` so callers route a
    final edit through even when content is unchanged. Callers
    should set ``finalize=True`` on the final edit of a streamed
    response (typically when ``got_done`` fires in the stream
    consumer) and leave it ``False`` on intermediate edits.

    Args:
        chat_id: The chat/channel the message lives in.
        message_id: Platform ID of the message to edit.
        content: Replacement message content.
        finalize: True on the last edit of a streaming sequence.

    Returns:
        SendResult; success=False in this base implementation.
    """
    # Base adapters have no edit capability — callers fall back to send().
    return SendResult(success=False, error="Not supported")
|
|
|
|
async def send_typing(self, chat_id: str, metadata=None) -> None:
    """
    Send a one-shot typing indicator for *chat_id*.

    No-op in the base class; subclasses override when the platform
    supports it. ``metadata`` carries platform-specific context
    (e.g. thread_id for Slack).
    """
|
|
|
|
async def stop_typing(self, chat_id: str) -> None:
    """
    Stop a persistent typing indicator (if the platform uses one).

    Default is a no-op, suitable for platforms whose typing
    indicators are one-shot; subclasses that run background typing
    loops override this to tear them down.
    """
|
|
|
|
async def send_image(
    self,
    chat_id: str,
    image_url: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """
    Send an image natively via the platform API.

    Subclasses override this to attach the image properly; the base
    implementation degrades to a plain-text message containing the
    URL, prefixed by the caption when one is given.
    """
    body = image_url if not caption else f"{caption}\n{image_url}"
    return await self.send(chat_id=chat_id, content=body, reply_to=reply_to)
|
|
|
|
async def send_animation(
    self,
    chat_id: str,
    animation_url: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
    """
    Send an animated GIF natively via the platform API.

    Subclasses with a dedicated animation API (e.g. Telegram's
    send_animation) override this so GIFs auto-play inline; the base
    implementation simply delegates to send_image.
    """
    return await self.send_image(
        chat_id=chat_id,
        image_url=animation_url,
        caption=caption,
        reply_to=reply_to,
        metadata=metadata,
    )
|
|
|
|
@staticmethod
|
|
def _is_animation_url(url: str) -> bool:
|
|
"""Check if a URL points to an animated GIF (vs a static image)."""
|
|
lower = url.lower().split('?')[0] # Strip query params
|
|
return lower.endswith('.gif')
|
|
|
|
@staticmethod
|
|
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
|
|
"""
|
|
Extract image URLs from markdown and HTML image tags in a response.
|
|
|
|
Finds patterns like:
|
|
- 
|
|
- <img src="https://example.com/image.png">
|
|
- <img src="https://example.com/image.png"></img>
|
|
|
|
Args:
|
|
content: The response text to scan.
|
|
|
|
Returns:
|
|
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
|
"""
|
|
images = []
|
|
cleaned = content
|
|
|
|
# Match markdown images: 
|
|
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
|
|
for match in re.finditer(md_pattern, content):
|
|
alt_text = match.group(1)
|
|
url = match.group(2)
|
|
# Only extract URLs that look like actual images
|
|
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
|
|
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
|
|
images.append((url, alt_text))
|
|
|
|
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
|
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
|
for match in re.finditer(html_pattern, content):
|
|
url = match.group(1)
|
|
images.append((url, ""))
|
|
|
|
# Remove only the matched image tags from content (not all markdown images)
|
|
if images:
|
|
extracted_urls = {url for url, _ in images}
|
|
def _remove_if_extracted(match):
|
|
url = match.group(2) if match.lastindex >= 2 else match.group(1)
|
|
return '' if url in extracted_urls else match.group(0)
|
|
cleaned = re.sub(md_pattern, _remove_if_extracted, cleaned)
|
|
cleaned = re.sub(html_pattern, _remove_if_extracted, cleaned)
|
|
# Clean up leftover blank lines
|
|
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
|
|
|
return images, cleaned
|
|
|
|
async def send_voice(
    self,
    chat_id: str,
    audio_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """
    Send an audio file as a native voice message via the platform API.

    Subclasses override this for voice bubbles (Telegram) or file
    attachments (Discord); the base implementation degrades to a text
    message naming the file path.
    """
    parts = [f"🔊 Audio: {audio_path}"]
    if caption:
        parts.insert(0, caption)
    return await self.send(chat_id=chat_id, content="\n".join(parts), reply_to=reply_to)
|
|
|
|
async def play_tts(
    self,
    chat_id: str,
    audio_path: str,
    **kwargs,
) -> SendResult:
    """
    Play auto-TTS audio for voice replies.

    Subclasses with invisible playback (e.g. the Web UI) override
    this; the base implementation delegates to send_voice, which
    shows an audio player.
    """
    return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
|
|
|
async def send_video(
    self,
    chat_id: str,
    video_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """
    Send a video natively via the platform API.

    Subclasses override this for inline playable media; the base
    implementation degrades to a text message naming the file path.
    """
    parts = [f"🎬 Video: {video_path}"]
    if caption:
        parts.insert(0, caption)
    return await self.send(chat_id=chat_id, content="\n".join(parts), reply_to=reply_to)
|
|
|
|
async def send_document(
    self,
    chat_id: str,
    file_path: str,
    caption: Optional[str] = None,
    file_name: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """
    Send a document/file natively via the platform API.

    Subclasses override this for downloadable attachments; the base
    implementation degrades to a text message naming the file path.
    """
    parts = [f"📎 File: {file_path}"]
    if caption:
        parts.insert(0, caption)
    return await self.send(chat_id=chat_id, content="\n".join(parts), reply_to=reply_to)
|
|
|
|
async def send_image_file(
    self,
    chat_id: str,
    image_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    **kwargs,
) -> SendResult:
    """
    Send a local image file natively via the platform API.

    Unlike send_image() — which takes a URL — this takes a local file
    path. Subclasses override it for native photo attachments; the
    base implementation degrades to a text message naming the path.
    """
    parts = [f"🖼️ Image: {image_path}"]
    if caption:
        parts.insert(0, caption)
    return await self.send(chat_id=chat_id, content="\n".join(parts), reply_to=reply_to)
|
|
|
|
@staticmethod
|
|
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
|
|
"""
|
|
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
|
|
|
The TTS tool returns responses like:
|
|
[[audio_as_voice]]
|
|
MEDIA:/path/to/audio.ogg
|
|
|
|
Args:
|
|
content: The response text to scan.
|
|
|
|
Returns:
|
|
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
|
"""
|
|
media = []
|
|
cleaned = content
|
|
|
|
# Check for [[audio_as_voice]] directive
|
|
has_voice_tag = "[[audio_as_voice]]" in content
|
|
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
|
|
|
# Extract MEDIA:<path> tags, allowing optional whitespace after the colon
|
|
# and quoted/backticked paths for LLM-formatted outputs.
|
|
media_pattern = re.compile(
|
|
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|(?:~/|/)\S+(?:[^\S\n]+\S+)*?\.(?:png|jpe?g|gif|webp|mp4|mov|avi|mkv|webm|ogg|opus|mp3|wav|m4a)(?=[\s`"',;:)\]}]|$)|\S+)[`"']?'''
|
|
)
|
|
for match in media_pattern.finditer(content):
|
|
path = match.group("path").strip()
|
|
if len(path) >= 2 and path[0] == path[-1] and path[0] in "`\"'":
|
|
path = path[1:-1].strip()
|
|
path = path.lstrip("`\"'").rstrip("`\"',.;:)}]")
|
|
if path:
|
|
media.append((os.path.expanduser(path), has_voice_tag))
|
|
|
|
# Remove MEDIA tags from content (including surrounding quote/backtick wrappers)
|
|
if media:
|
|
cleaned = media_pattern.sub('', cleaned)
|
|
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
|
|
|
return media, cleaned
|
|
|
|
@staticmethod
|
|
def extract_local_files(content: str) -> Tuple[List[str], str]:
|
|
"""
|
|
Detect bare local file paths in response text for native media delivery.
|
|
|
|
Matches absolute paths (/...) and tilde paths (~/) ending in common
|
|
image or video extensions. Validates each candidate with
|
|
``os.path.isfile()`` to avoid false positives from URLs or
|
|
non-existent paths.
|
|
|
|
Paths inside fenced code blocks (``` ... ```) and inline code
|
|
(`...`) are ignored so that code samples are never mutilated.
|
|
|
|
Returns:
|
|
Tuple of (list of expanded file paths, cleaned text with the
|
|
raw path strings removed).
|
|
"""
|
|
_LOCAL_MEDIA_EXTS = (
|
|
'.png', '.jpg', '.jpeg', '.gif', '.webp',
|
|
'.mp4', '.mov', '.avi', '.mkv', '.webm',
|
|
)
|
|
ext_part = '|'.join(e.lstrip('.') for e in _LOCAL_MEDIA_EXTS)
|
|
|
|
# (?<![/:\w.]) prevents matching inside URLs (e.g. https://…/img.png)
|
|
# and relative paths (./foo.png)
|
|
# (?:~/|/) anchors to absolute or home-relative paths
|
|
path_re = re.compile(
|
|
r'(?<![/:\w.])(?:~/|/)(?:[\w.\-]+/)*[\w.\-]+\.(?:' + ext_part + r')\b',
|
|
re.IGNORECASE,
|
|
)
|
|
|
|
# Build spans covered by fenced code blocks and inline code
|
|
code_spans: list = []
|
|
for m in re.finditer(r'```[^\n]*\n.*?```', content, re.DOTALL):
|
|
code_spans.append((m.start(), m.end()))
|
|
for m in re.finditer(r'`[^`\n]+`', content):
|
|
code_spans.append((m.start(), m.end()))
|
|
|
|
def _in_code(pos: int) -> bool:
|
|
return any(s <= pos < e for s, e in code_spans)
|
|
|
|
found: list = [] # (raw_match_text, expanded_path)
|
|
for match in path_re.finditer(content):
|
|
if _in_code(match.start()):
|
|
continue
|
|
raw = match.group(0)
|
|
expanded = os.path.expanduser(raw)
|
|
if os.path.isfile(expanded):
|
|
found.append((raw, expanded))
|
|
|
|
# Deduplicate by expanded path, preserving discovery order
|
|
seen: set = set()
|
|
unique: list = []
|
|
for raw, expanded in found:
|
|
if expanded not in seen:
|
|
seen.add(expanded)
|
|
unique.append((raw, expanded))
|
|
|
|
paths = [expanded for _, expanded in unique]
|
|
|
|
cleaned = content
|
|
if unique:
|
|
for raw, _exp in unique:
|
|
cleaned = cleaned.replace(raw, '')
|
|
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
|
|
|
return paths, cleaned
|
|
|
|
async def _keep_typing(
    self,
    chat_id: str,
    interval: float = 2.0,
    metadata=None,
    stop_event: asyncio.Event | None = None,
) -> None:
    """
    Continuously send typing indicator until cancelled.

    Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
    to recover quickly after progress messages interrupt it.

    Skips send_typing when the chat is in ``_typing_paused`` (e.g. while
    the agent is waiting for dangerous-command approval). This is critical
    for Slack's Assistant API where ``assistant_threads_setStatus`` disables
    the compose box — pausing lets the user type ``/approve`` or ``/deny``.

    Args:
        chat_id: The chat to keep showing typing for.
        interval: Seconds between refreshes.
        metadata: Optional platform-specific context passed to send_typing.
        stop_event: When provided, the loop exits promptly as soon as it
            is set instead of waiting out a full sleep interval.
    """
    try:
        while True:
            # Exit promptly if an interrupt arrived between iterations.
            if stop_event is not None and stop_event.is_set():
                return
            if chat_id not in self._typing_paused:
                await self.send_typing(chat_id, metadata=metadata)
            # Without a stop event, plain sleep between refreshes.
            if stop_event is None:
                await asyncio.sleep(interval)
                continue
            # With a stop event, wait_for doubles as sleep + early-exit:
            # a timeout means "refresh again"; completion means "stop".
            try:
                await asyncio.wait_for(stop_event.wait(), timeout=interval)
            except asyncio.TimeoutError:
                continue
            return
    except asyncio.CancelledError:
        pass  # Normal cancellation when handler completes
    finally:
        # Ensure the underlying platform typing loop is stopped.
        # _keep_typing may have called send_typing() after an outer
        # stop_typing() cleared the task dict, recreating the loop.
        # Cancelling _keep_typing alone won't clean that up.
        if hasattr(self, "stop_typing"):
            try:
                await self.stop_typing(chat_id)
            except Exception:
                pass
        self._typing_paused.discard(chat_id)
|
|
|
|
def pause_typing_for_chat(self, chat_id: str) -> None:
    """Pause the typing indicator for *chat_id* (e.g. during approval waits).

    Thread-safe under the CPython GIL — may be called from the sync agent
    thread while ``_keep_typing`` runs on the async event loop.
    """
    self._typing_paused.add(chat_id)
|
|
|
|
def resume_typing_for_chat(self, chat_id: str) -> None:
    """Resume the typing indicator for *chat_id* after an approval resolves."""
    self._typing_paused.discard(chat_id)
|
|
|
|
async def interrupt_session_activity(self, session_key: str, chat_id: str) -> None:
    """Signal the active session loop to stop and clear typing immediately."""
    interrupt_event = self._active_sessions.get(session_key) if session_key else None
    if interrupt_event is not None:
        interrupt_event.set()
    # Typing cleanup is best-effort; platform errors must not propagate.
    try:
        await self.stop_typing(chat_id)
    except Exception:
        pass
|
|
|
|
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
|
# Subclasses override these to react to message processing events
|
|
# (e.g. Discord adds 👀/✅/❌ reactions).
|
|
|
|
async def on_processing_start(self, event: MessageEvent) -> None:
    """Lifecycle hook invoked when background processing of *event* begins."""
|
|
|
|
async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None:
    """Lifecycle hook invoked when background processing of *event* completes."""
|
|
|
|
async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
|
|
"""Run a lifecycle hook without letting failures break message flow."""
|
|
hook = getattr(self, hook_name, None)
|
|
if not callable(hook):
|
|
return
|
|
try:
|
|
await hook(*args, **kwargs)
|
|
except Exception as e:
|
|
logger.warning("[%s] %s hook failed: %s", self.name, hook_name, e)
|
|
|
|
@staticmethod
|
|
def _is_retryable_error(error: Optional[str]) -> bool:
|
|
"""Return True if the error string looks like a transient network failure."""
|
|
if not error:
|
|
return False
|
|
lowered = error.lower()
|
|
return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS)
|
|
|
|
@staticmethod
|
|
def _is_timeout_error(error: Optional[str]) -> bool:
|
|
"""Return True if the error string indicates a read/write timeout.
|
|
|
|
Timeout errors are NOT retryable and should NOT trigger plain-text
|
|
fallback — the request may have already been delivered.
|
|
"""
|
|
if not error:
|
|
return False
|
|
lowered = error.lower()
|
|
return "timed out" in lowered or "readtimeout" in lowered or "writetimeout" in lowered
|
|
|
|
async def _send_with_retry(
    self,
    chat_id: str,
    content: str,
    reply_to: Optional[str] = None,
    metadata: Any = None,
    max_retries: int = 2,
    base_delay: float = 2.0,
) -> "SendResult":
    """
    Send a message with automatic retry for transient network errors.

    On permanent failures (e.g. formatting / permission errors) falls back
    to a plain-text version before giving up. If all attempts fail due to
    network errors, sends the user a brief delivery-failure notice so they
    know to retry rather than waiting indefinitely.

    Args:
        chat_id: Destination chat/channel ID.
        content: Message content to deliver.
        reply_to: Optional message ID to reply to.
        metadata: Platform-specific options forwarded to send().
        max_retries: Number of retries for transient (network) errors.
        base_delay: First retry delay; doubles per attempt plus jitter.

    Returns:
        The final SendResult (from the retried send or the plain-text
        fallback), which may still have success=False.
    """

    result = await self.send(
        chat_id=chat_id,
        content=content,
        reply_to=reply_to,
        metadata=metadata,
    )

    if result.success:
        return result

    error_str = result.error or ""
    is_network = result.retryable or self._is_retryable_error(error_str)

    # Timeout errors are not safe to retry (message may have been
    # delivered) and not formatting errors — return the failure as-is.
    if not is_network and self._is_timeout_error(error_str):
        return result

    if is_network:
        # Retry with exponential backoff (plus jitter) for transient errors
        for attempt in range(1, max_retries + 1):
            delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1)
            logger.warning(
                "[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s",
                self.name, attempt, max_retries, delay, error_str,
            )
            await asyncio.sleep(delay)
            result = await self.send(
                chat_id=chat_id,
                content=content,
                reply_to=reply_to,
                metadata=metadata,
            )
            if result.success:
                logger.info("[%s] Send succeeded on retry %d", self.name, attempt)
                return result
            error_str = result.error or ""
            if not (result.retryable or self._is_retryable_error(error_str)):
                break  # error switched to non-transient — fall through to plain-text fallback
        else:
            # for/else: all retries exhausted (loop completed without break) — notify user
            logger.error("[%s] Failed to deliver response after %d retries: %s", self.name, max_retries, error_str)
            notice = (
                "\u26a0\ufe0f Message delivery failed after multiple attempts. "
                "Please try again \u2014 your request was processed but the response could not be sent."
            )
            # Best-effort notice; its own failure is only logged.
            try:
                await self.send(chat_id=chat_id, content=notice, reply_to=reply_to, metadata=metadata)
            except Exception as notify_err:
                logger.debug("[%s] Could not send delivery-failure notice: %s", self.name, notify_err)
            return result

    # Non-network / post-retry formatting failure: try plain text as fallback
    logger.warning("[%s] Send failed: %s — trying plain-text fallback", self.name, error_str)
    fallback_result = await self.send(
        chat_id=chat_id,
        content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}",
        reply_to=reply_to,
        metadata=metadata,
    )
    if not fallback_result.success:
        logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
    return fallback_result
|
|
|
|
@staticmethod
|
|
def _merge_caption(existing_text: Optional[str], new_text: str) -> str:
|
|
"""Merge a new caption into existing text, avoiding duplicates.
|
|
|
|
Uses line-by-line exact match (not substring) to prevent false positives
|
|
where a shorter caption is silently dropped because it appears as a
|
|
substring of a longer one (e.g. "Meeting" inside "Meeting agenda").
|
|
Whitespace is normalised for comparison.
|
|
"""
|
|
if not existing_text:
|
|
return new_text
|
|
existing_captions = [c.strip() for c in existing_text.split("\n\n")]
|
|
if new_text.strip() not in existing_captions:
|
|
return f"{existing_text}\n\n{new_text}".strip()
|
|
return existing_text
|
|
|
|
async def handle_message(self, event: MessageEvent) -> None:
    """
    Process an incoming message.

    This method returns quickly by spawning background tasks.
    This allows new messages to be processed even while an agent is running,
    enabling interruption support.

    Args:
        event: The incoming platform message to route.
    """
    # Without a handler there is nothing meaningful to do.
    if not self._message_handler:
        return

    session_key = build_session_key(
        event.source,
        group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
        thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
    )

    # Check if there's already an active handler for this session
    if session_key in self._active_sessions:
        # Certain commands must bypass the active-session guard and be
        # dispatched directly to the gateway runner. Without this, they
        # are queued as pending messages and either:
        #   - leak into the conversation as user text (/stop, /new), or
        #   - deadlock (/approve, /deny — agent is blocked on Event.wait)
        #
        # Dispatch inline: call the message handler directly and send the
        # response. Do NOT use _process_message_background — it manages
        # session lifecycle and its cleanup races with the running task
        # (see PR #4926).
        cmd = event.get_command()
        from hermes_cli.commands import should_bypass_active_session

        if should_bypass_active_session(cmd):
            logger.debug(
                "[%s] Command '/%s' bypassing active-session guard for %s",
                self.name, cmd, session_key,
            )
            try:
                _thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
                response = await self._message_handler(event)
                if response:
                    await self._send_with_retry(
                        chat_id=event.source.chat_id,
                        content=response,
                        reply_to=event.message_id,
                        metadata=_thread_meta,
                    )
            except Exception as e:
                logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True)
            return

        # Give the optional busy-session handler first refusal; if it
        # reports the event as fully handled, stop here.
        if self._busy_session_handler is not None:
            try:
                if await self._busy_session_handler(event, session_key):
                    return
            except Exception as e:
                logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True)

        # Special case: photo bursts/albums frequently arrive as multiple near-
        # simultaneous messages. Queue them without interrupting the active run,
        # then process them immediately after the current task finishes.
        if event.message_type == MessageType.PHOTO:
            logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key)
            merge_pending_message_event(self._pending_messages, session_key, event)
            return  # Don't interrupt now - will run after current task completes

        # Default behavior for non-photo follow-ups: interrupt the running agent
        logger.debug("[%s] New message while session %s is active — triggering interrupt", self.name, session_key)
        self._pending_messages[session_key] = event
        # Signal the interrupt (the processing task checks this)
        self._active_sessions[session_key].set()
        return  # Don't process now - will be handled after current task finishes

    # Mark session as active BEFORE spawning background task to close
    # the race window where a second message arriving before the task
    # starts would also pass the _active_sessions check and spawn a
    # duplicate task. (grammY sequentialize / aiogram EventIsolation
    # pattern — set the guard synchronously, not inside the task.)
    self._active_sessions[session_key] = asyncio.Event()

    # Spawn background task to process this message; keep a strong
    # reference so the task is not garbage-collected mid-flight.
    task = asyncio.create_task(self._process_message_background(event, session_key))
    try:
        self._background_tasks.add(task)
    except TypeError:
        # Some tests stub create_task() with lightweight sentinels that are not
        # hashable and do not support lifecycle callbacks.
        return
    if hasattr(task, "add_done_callback"):
        task.add_done_callback(self._background_tasks.discard)
        task.add_done_callback(self._expected_cancelled_tasks.discard)
|
|
|
|
@staticmethod
|
|
def _get_human_delay() -> float:
|
|
"""
|
|
Return a random delay in seconds for human-like response pacing.
|
|
|
|
Reads from env vars:
|
|
HERMES_HUMAN_DELAY_MODE: "off" (default) | "natural" | "custom"
|
|
HERMES_HUMAN_DELAY_MIN_MS: minimum delay in ms (default 800, custom mode)
|
|
HERMES_HUMAN_DELAY_MAX_MS: maximum delay in ms (default 2500, custom mode)
|
|
"""
|
|
import random
|
|
|
|
mode = os.getenv("HERMES_HUMAN_DELAY_MODE", "off").lower()
|
|
if mode == "off":
|
|
return 0.0
|
|
min_ms = int(os.getenv("HERMES_HUMAN_DELAY_MIN_MS", "800"))
|
|
max_ms = int(os.getenv("HERMES_HUMAN_DELAY_MAX_MS", "2500"))
|
|
if mode == "natural":
|
|
min_ms, max_ms = 800, 2500
|
|
return random.uniform(min_ms / 1000.0, max_ms / 1000.0)
|
|
|
|
async def _process_message_background(self, event: MessageEvent, session_key: str) -> None:
|
|
"""Background task that actually processes the message."""
|
|
# Track delivery outcomes for the processing-complete hook
|
|
delivery_attempted = False
|
|
delivery_succeeded = False
|
|
|
|
def _record_delivery(result):
|
|
nonlocal delivery_attempted, delivery_succeeded
|
|
if result is None:
|
|
return
|
|
delivery_attempted = True
|
|
if getattr(result, "success", False):
|
|
delivery_succeeded = True
|
|
|
|
# Reuse the interrupt event set by handle_message() (which marks
|
|
# the session active before spawning this task to prevent races).
|
|
# Fall back to a new Event only if the entry was removed externally.
|
|
interrupt_event = self._active_sessions.get(session_key) or asyncio.Event()
|
|
self._active_sessions[session_key] = interrupt_event
|
|
|
|
# Start continuous typing indicator (refreshes every 2 seconds)
|
|
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
|
typing_task = asyncio.create_task(
|
|
self._keep_typing(
|
|
event.source.chat_id,
|
|
metadata=_thread_metadata,
|
|
stop_event=interrupt_event,
|
|
)
|
|
)
|
|
|
|
try:
|
|
await self._run_processing_hook("on_processing_start", event)
|
|
|
|
# Call the handler (this can take a while with tool calls)
|
|
response = await self._message_handler(event)
|
|
|
|
# Send response if any. A None/empty response is normal when
|
|
# streaming already delivered the text (already_sent=True) or
|
|
# when the message was queued behind an active agent. Log at
|
|
# DEBUG to avoid noisy warnings for expected behavior.
|
|
#
|
|
# Suppress stale response when the session was interrupted by a
|
|
# new message that hasn't been consumed yet. The pending message
|
|
# is processed by the pending-message handler below (#8221/#2483).
|
|
if (
|
|
response
|
|
and interrupt_event.is_set()
|
|
and session_key in self._pending_messages
|
|
):
|
|
logger.info(
|
|
"[%s] Suppressing stale response for interrupted session %s",
|
|
self.name,
|
|
session_key,
|
|
)
|
|
response = None
|
|
if not response:
|
|
logger.debug("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id)
|
|
if response:
|
|
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
|
media_files, response = self.extract_media(response)
|
|
|
|
# Extract image URLs and send them as native platform attachments
|
|
images, text_content = self.extract_images(response)
|
|
# Strip any remaining internal directives from message body (fixes #1561)
|
|
text_content = text_content.replace("[[audio_as_voice]]", "").strip()
|
|
text_content = re.sub(r"MEDIA:\s*\S+", "", text_content).strip()
|
|
if images:
|
|
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
|
|
|
# Auto-detect bare local file paths for native media delivery
|
|
# (helps small models that don't use MEDIA: syntax)
|
|
local_files, text_content = self.extract_local_files(text_content)
|
|
if local_files:
|
|
logger.info("[%s] extract_local_files found %d file(s) in response", self.name, len(local_files))
|
|
|
|
# Auto-TTS: if voice message, generate audio FIRST (before sending text)
|
|
# Skipped when the chat has voice mode disabled (/voice off)
|
|
_tts_path = None
|
|
if (event.message_type == MessageType.VOICE
|
|
and text_content
|
|
and not media_files
|
|
and event.source.chat_id not in self._auto_tts_disabled_chats):
|
|
try:
|
|
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
|
if check_tts_requirements():
|
|
import json as _json
|
|
speech_text = re.sub(r'[*_`#\[\]()]', '', text_content)[:4000].strip()
|
|
if not speech_text:
|
|
raise ValueError("Empty text after markdown cleanup")
|
|
tts_result_str = await asyncio.to_thread(
|
|
text_to_speech_tool, text=speech_text
|
|
)
|
|
tts_data = _json.loads(tts_result_str)
|
|
_tts_path = tts_data.get("file_path")
|
|
except Exception as tts_err:
|
|
logger.warning("[%s] Auto-TTS failed: %s", self.name, tts_err)
|
|
|
|
# Play TTS audio before text (voice-first experience)
|
|
if _tts_path and Path(_tts_path).exists():
|
|
try:
|
|
await self.play_tts(
|
|
chat_id=event.source.chat_id,
|
|
audio_path=_tts_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
finally:
|
|
try:
|
|
os.remove(_tts_path)
|
|
except OSError:
|
|
pass
|
|
|
|
# Send the text portion
|
|
if text_content:
|
|
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
|
result = await self._send_with_retry(
|
|
chat_id=event.source.chat_id,
|
|
content=text_content,
|
|
reply_to=event.message_id,
|
|
metadata=_thread_metadata,
|
|
)
|
|
_record_delivery(result)
|
|
|
|
# Human-like pacing delay between text and media
|
|
human_delay = self._get_human_delay()
|
|
|
|
# Send extracted images as native attachments
|
|
if images:
|
|
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
|
for image_url, alt_text in images:
|
|
if human_delay > 0:
|
|
await asyncio.sleep(human_delay)
|
|
try:
|
|
logger.info(
|
|
"[%s] Sending image: %s (alt=%s)",
|
|
self.name,
|
|
safe_url_for_log(image_url),
|
|
alt_text[:30] if alt_text else "",
|
|
)
|
|
# Route animated GIFs through send_animation for proper playback
|
|
if self._is_animation_url(image_url):
|
|
img_result = await self.send_animation(
|
|
chat_id=event.source.chat_id,
|
|
animation_url=image_url,
|
|
caption=alt_text if alt_text else None,
|
|
metadata=_thread_metadata,
|
|
)
|
|
else:
|
|
img_result = await self.send_image(
|
|
chat_id=event.source.chat_id,
|
|
image_url=image_url,
|
|
caption=alt_text if alt_text else None,
|
|
metadata=_thread_metadata,
|
|
)
|
|
if not img_result.success:
|
|
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
|
except Exception as img_err:
|
|
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
|
|
|
# Send extracted media files — route by file type
|
|
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
|
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'}
|
|
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
|
|
|
|
for media_path, is_voice in media_files:
|
|
if human_delay > 0:
|
|
await asyncio.sleep(human_delay)
|
|
try:
|
|
ext = Path(media_path).suffix.lower()
|
|
if ext in _AUDIO_EXTS:
|
|
media_result = await self.send_voice(
|
|
chat_id=event.source.chat_id,
|
|
audio_path=media_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
elif ext in _VIDEO_EXTS:
|
|
media_result = await self.send_video(
|
|
chat_id=event.source.chat_id,
|
|
video_path=media_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
elif ext in _IMAGE_EXTS:
|
|
media_result = await self.send_image_file(
|
|
chat_id=event.source.chat_id,
|
|
image_path=media_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
else:
|
|
media_result = await self.send_document(
|
|
chat_id=event.source.chat_id,
|
|
file_path=media_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
|
|
if not media_result.success:
|
|
logger.warning("[%s] Failed to send media (%s): %s", self.name, ext, media_result.error)
|
|
except Exception as media_err:
|
|
logger.warning("[%s] Error sending media: %s", self.name, media_err)
|
|
|
|
# Send auto-detected local files as native attachments
|
|
for file_path in local_files:
|
|
if human_delay > 0:
|
|
await asyncio.sleep(human_delay)
|
|
try:
|
|
ext = Path(file_path).suffix.lower()
|
|
if ext in _IMAGE_EXTS:
|
|
await self.send_image_file(
|
|
chat_id=event.source.chat_id,
|
|
image_path=file_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
elif ext in _VIDEO_EXTS:
|
|
await self.send_video(
|
|
chat_id=event.source.chat_id,
|
|
video_path=file_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
else:
|
|
await self.send_document(
|
|
chat_id=event.source.chat_id,
|
|
file_path=file_path,
|
|
metadata=_thread_metadata,
|
|
)
|
|
except Exception as file_err:
|
|
logger.error("[%s] Error sending local file %s: %s", self.name, file_path, file_err)
|
|
|
|
# Determine overall success for the processing hook
|
|
processing_ok = delivery_succeeded if delivery_attempted else not bool(response)
|
|
await self._run_processing_hook(
|
|
"on_processing_complete",
|
|
event,
|
|
ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE,
|
|
)
|
|
|
|
# Check if there's a pending message that was queued during our processing
|
|
if session_key in self._pending_messages:
|
|
pending_event = self._pending_messages.pop(session_key)
|
|
logger.debug("[%s] Processing queued message from interrupt", self.name)
|
|
# Keep the _active_sessions entry live across the turn chain
|
|
# and only CLEAR the interrupt Event — do NOT delete the entry.
|
|
# If we deleted here, a concurrent inbound message arriving
|
|
# during the awaits below would pass the Level-1 guard, spawn
|
|
# its own _process_message_background, and run simultaneously
|
|
# with the recursive drain below. Two agents on one
|
|
# session_key = duplicate responses, duplicate tool calls.
|
|
# Clearing the Event keeps the guard live so follow-ups take
|
|
# the busy-handler path (queue + interrupt) as intended.
|
|
_active = self._active_sessions.get(session_key)
|
|
if _active is not None:
|
|
_active.clear()
|
|
typing_task.cancel()
|
|
try:
|
|
await typing_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
# Process pending message in new background task
|
|
await self._process_message_background(pending_event, session_key)
|
|
return # Already cleaned up
|
|
|
|
except asyncio.CancelledError:
|
|
current_task = asyncio.current_task()
|
|
outcome = ProcessingOutcome.CANCELLED
|
|
if current_task is None or current_task not in self._expected_cancelled_tasks:
|
|
outcome = ProcessingOutcome.FAILURE
|
|
await self._run_processing_hook("on_processing_complete", event, outcome)
|
|
raise
|
|
except Exception as e:
|
|
await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE)
|
|
logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True)
|
|
# Send the error to the user so they aren't left with radio silence
|
|
try:
|
|
error_type = type(e).__name__
|
|
error_detail = str(e)[:300] if str(e) else "no details available"
|
|
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
|
await self.send(
|
|
chat_id=event.source.chat_id,
|
|
content=(
|
|
f"Sorry, I encountered an error ({error_type}).\n"
|
|
f"{error_detail}\n"
|
|
"Try again or use /reset to start a fresh session."
|
|
),
|
|
metadata=_thread_metadata,
|
|
)
|
|
except Exception:
|
|
pass # Last resort — don't let error reporting crash the handler
|
|
finally:
|
|
# Fire any one-shot post-delivery callback registered for this
|
|
# session (e.g. deferred background-review notifications).
|
|
_post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None)
|
|
if callable(_post_cb):
|
|
try:
|
|
_post_cb()
|
|
except Exception:
|
|
pass
|
|
# Stop typing indicator
|
|
typing_task.cancel()
|
|
try:
|
|
await typing_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
# Also cancel any platform-level persistent typing tasks (e.g. Discord)
|
|
# that may have been recreated by _keep_typing after the last stop_typing()
|
|
try:
|
|
if hasattr(self, "stop_typing"):
|
|
await self.stop_typing(event.source.chat_id)
|
|
except Exception:
|
|
pass
|
|
# Late-arrival drain: a message may have arrived during the
|
|
# cleanup awaits above (typing_task cancel, stop_typing). Such
|
|
# messages passed the Level-1 guard (entry still live, Event
|
|
# possibly set) and landed in _pending_messages via the
|
|
# busy-handler path. Without this block, we would delete the
|
|
# active-session entry and the queued message would be silently
|
|
# dropped (user never gets a reply).
|
|
late_pending = self._pending_messages.pop(session_key, None)
|
|
if late_pending is not None:
|
|
logger.debug(
|
|
"[%s] Late-arrival pending message during cleanup — spawning drain task",
|
|
self.name,
|
|
)
|
|
_active = self._active_sessions.get(session_key)
|
|
if _active is not None:
|
|
_active.clear()
|
|
drain_task = asyncio.create_task(
|
|
self._process_message_background(late_pending, session_key)
|
|
)
|
|
try:
|
|
self._background_tasks.add(drain_task)
|
|
drain_task.add_done_callback(self._background_tasks.discard)
|
|
except TypeError:
|
|
# Tests stub create_task() with non-hashable sentinels; tolerate.
|
|
pass
|
|
# Leave _active_sessions[session_key] populated — the drain
|
|
# task's own lifecycle will clean it up.
|
|
return
|
|
# Clean up session tracking
|
|
if session_key in self._active_sessions:
|
|
del self._active_sessions[session_key]
|
|
|
|
async def cancel_background_tasks(self) -> None:
    """Cancel any in-flight background message-processing tasks.

    Used during gateway shutdown/replacement so active sessions from the old
    process do not keep running after adapters are being torn down.
    """
    # Drain in rounds. A message arriving during the ``await asyncio.gather``
    # below can spawn a fresh _process_message_background task (registered in
    # self._background_tasks via handle_message). If we cleared the set after
    # a single round, that late task would run untracked against a
    # disconnecting adapter, log send-failures, and linger until it finished
    # on its own. Re-checking until the task set stabilizes (or we hit the
    # round cap) closes that window.
    MAX_DRAIN_ROUNDS = 5
    for _ in range(MAX_DRAIN_ROUNDS):
        pending = [task for task in self._background_tasks if not task.done()]
        if not pending:
            break
        for task in pending:
            # Mark as expected so the task's CancelledError handler records
            # ProcessingOutcome.CANCELLED rather than FAILURE.
            self._expected_cancelled_tasks.add(task)
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        # Loop: late-arrival tasks spawned during the gather above are now
        # in self._background_tasks — re-check on the next round.
    else:
        # Round cap exhausted. Previously this situation was silent and the
        # clear() below dropped live tasks without a trace; surface it so
        # operators can see the leak.
        leftover = [t for t in self._background_tasks if not t.done()]
        if leftover:
            logger.warning(
                "cancel_background_tasks: %d task(s) still pending after %d drain rounds",
                len(leftover),
                MAX_DRAIN_ROUNDS,
            )
    self._background_tasks.clear()
    self._expected_cancelled_tasks.clear()
    self._pending_messages.clear()
    self._active_sessions.clear()
|
|
|
|
def has_pending_interrupt(self, session_key: str) -> bool:
    """Return True when an interrupt has been flagged for *session_key*.

    A session has a pending interrupt when it is still tracked in
    ``_active_sessions`` and its interrupt Event is currently set.
    """
    if session_key not in self._active_sessions:
        return False
    return self._active_sessions[session_key].is_set()
|
|
|
|
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
    """Pop and return the message queued for *session_key*, or None.

    The entry is removed from ``_pending_messages`` as a side effect, so a
    second call for the same key returns None.
    """
    try:
        return self._pending_messages.pop(session_key)
    except KeyError:
        return None
|
|
|
|
def build_source(
    self,
    chat_id: str,
    chat_name: Optional[str] = None,
    chat_type: str = "dm",
    user_id: Optional[str] = None,
    user_name: Optional[str] = None,
    thread_id: Optional[str] = None,
    chat_topic: Optional[str] = None,
    user_id_alt: Optional[str] = None,
    chat_id_alt: Optional[str] = None,
    is_bot: bool = False,
) -> SessionSource:
    """Construct a :class:`SessionSource` describing a message origin on this platform.

    IDs (``chat_id``, ``user_id``, ``thread_id``) are coerced to strings;
    a blank or whitespace-only ``chat_topic`` is normalized to ``None``,
    otherwise it is passed through stripped.
    """
    # None / "" / "   " all collapse to None; anything else is stripped.
    normalized_topic = (chat_topic or "").strip() or None
    return SessionSource(
        platform=self.platform,
        chat_id=str(chat_id),
        chat_name=chat_name,
        chat_type=chat_type,
        user_id=str(user_id) if user_id else None,
        user_name=user_name,
        thread_id=str(thread_id) if thread_id else None,
        chat_topic=normalized_topic,
        user_id_alt=user_id_alt,
        chat_id_alt=chat_id_alt,
        is_bot=is_bot,
    )
|
|
|
|
@abstractmethod
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
    """Fetch platform metadata for a chat/channel.

    Implementations must return a dict containing at least:
      - ``name``: the chat's display name
      - ``type``: one of ``"dm"``, ``"group"``, ``"channel"``
    """
    ...
|
|
|
|
def format_message(self, content: str) -> str:
    """Apply platform-specific formatting to an outgoing message.

    The base implementation is the identity transform. Subclasses override
    this to emit platform dialects such as Telegram MarkdownV2 or Discord
    markdown.
    """
    return content
|
|
|
|
@staticmethod
def truncate_message(
    content: str,
    max_length: int = 4096,
    len_fn: Optional["Callable[[str], int]"] = None,
) -> List[str]:
    """
    Split a long message into chunks, preserving code block boundaries.

    When a split falls inside a triple-backtick code block, the fence is
    closed at the end of the current chunk and reopened (with the original
    language tag) at the start of the next chunk. Multi-chunk responses
    receive indicators like ``(1/3)``.

    Args:
        content: The full message content
        max_length: Maximum length per chunk (platform-specific)
        len_fn: Optional length function for measuring string length.
            Defaults to ``len`` (Unicode code-points). Pass
            ``utf16_len`` for platforms that measure message
            length in UTF-16 code units (e.g. Telegram).

    Returns:
        List of message chunks
    """
    _len = len_fn or len
    # Fast path: the whole message fits in a single chunk.
    if _len(content) <= max_length:
        return [content]

    INDICATOR_RESERVE = 10  # room for " (XX/XX)" — up to " (100/999)" is 10 units
    FENCE_CLOSE = "\n```"

    chunks: List[str] = []
    remaining = content
    # When the previous chunk ended mid-code-block, this holds the
    # language tag (possibly "") so we can reopen the fence.
    carry_lang: Optional[str] = None

    while remaining:
        # If we're continuing a code block from the previous chunk,
        # prepend a new opening fence with the same language tag.
        prefix = f"```{carry_lang}\n" if carry_lang is not None else ""

        # How much body text we can fit after accounting for the prefix,
        # a potential closing fence, and the chunk indicator.
        headroom = max_length - INDICATOR_RESERVE - _len(prefix) - _len(FENCE_CLOSE)
        if headroom < 1:
            # Degenerate tiny max_length; fall back to half the budget.
            headroom = max_length // 2

        # Everything remaining fits in one final chunk
        if _len(prefix) + _len(remaining) <= max_length - INDICATOR_RESERVE:
            chunks.append(prefix + remaining)
            break

        # Find a natural split point (prefer newlines, then spaces).
        # When _len != len (e.g. utf16_len for Telegram), headroom is
        # measured in the custom unit. We need codepoint-based slice
        # positions that stay within the custom-unit budget.
        #
        # _custom_unit_to_cp() maps a custom-unit budget to the largest
        # codepoint offset whose custom length ≤ budget.
        if _len is not len:
            # Map headroom (custom units) → codepoint slice length.
            # NOTE(review): assumes the result is ≥ 1 for sane max_length
            # values; a zero limit would produce an empty chunk — confirm
            # callers never pass max_length small enough to hit this.
            _cp_limit = _custom_unit_to_cp(remaining, headroom, _len)
        else:
            _cp_limit = headroom
        region = remaining[:_cp_limit]
        split_at = region.rfind("\n")
        # A newline in the first half would make the chunk too small;
        # prefer a space further along instead.
        if split_at < _cp_limit // 2:
            split_at = region.rfind(" ")
        if split_at < 1:
            # No acceptable break found — hard-split at the budget.
            split_at = _cp_limit

        # Avoid splitting inside an inline code span (`...`).
        # If the text before split_at has an odd number of unescaped
        # backticks, the split falls inside inline code — the resulting
        # chunk would have an unpaired backtick and any special characters
        # (like parentheses) inside the broken span would be unescaped,
        # causing MarkdownV2 parse errors on Telegram.
        candidate = remaining[:split_at]
        backtick_count = candidate.count("`") - candidate.count("\\`")
        if backtick_count % 2 == 1:
            # Find the last unescaped backtick and split before it.
            last_bt = candidate.rfind("`")
            while last_bt > 0 and candidate[last_bt - 1] == "\\":
                last_bt = candidate.rfind("`", 0, last_bt)
            if last_bt > 0:
                # Try to find a space or newline just before the backtick.
                safe_split = candidate.rfind(" ", 0, last_bt)
                nl_split = candidate.rfind("\n", 0, last_bt)
                safe_split = max(safe_split, nl_split)
                # Only accept if it doesn't shrink the chunk below a
                # quarter of the budget; otherwise keep the original split.
                if safe_split > _cp_limit // 4:
                    split_at = safe_split

        chunk_body = remaining[:split_at]
        # Drop leading whitespace carried over from the split point.
        remaining = remaining[split_at:].lstrip()

        full_chunk = prefix + chunk_body

        # Walk only the chunk_body (not the prefix we prepended) to
        # determine whether we end inside an open code block.
        in_code = carry_lang is not None
        lang = carry_lang or ""
        for line in chunk_body.split("\n"):
            stripped = line.strip()
            if stripped.startswith("```"):
                if in_code:
                    in_code = False
                    lang = ""
                else:
                    in_code = True
                    # Keep only the first word of the info string as the
                    # language tag (e.g. "python" from "python linenos").
                    tag = stripped[3:].strip()
                    lang = tag.split()[0] if tag else ""

        if in_code:
            # Close the orphaned fence so the chunk is valid on its own
            full_chunk += FENCE_CLOSE
            carry_lang = lang
        else:
            carry_lang = None

        chunks.append(full_chunk)

    # Append chunk indicators when the response spans multiple messages
    if len(chunks) > 1:
        total = len(chunks)
        chunks = [
            f"{chunk} ({i + 1}/{total})" for i, chunk in enumerate(chunks)
        ]

    return chunks
|