mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Add a second WeCom integration mode for regular enterprise self-built applications. Unlike the existing bot/websocket adapter (wecom.py), this handles WeCom's standard callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the adapter decrypts, queues for the agent, and immediately acknowledges. The agent's reply is delivered proactively via the message/send API. Key design choice: always acknowledge immediately and use proactive send — agent sessions take 3-30 minutes, so the 5-second inline reply window is never useful. The original PR's Future/pending-reply machinery was removed in favour of this simpler architecture. Features: - AES-CBC encrypt/decrypt (BizMsgCrypt-compatible) - Multi-app routing scoped by corp_id:user_id - Legacy bare user_id fallback for backward compat - Access-token management with auto-refresh - WECOM_CALLBACK_* env var overrides - Port-in-use pre-check before binding - Health endpoint at /health Salvaged from PR #7774 by @chqchshj. Simplified by removing the inline reply Future system and fixing: secrets.choice for nonce generation, immediate plain-text acknowledgment (not encrypted XML containing 'success'), and initial token refresh error handling.
387 lines
15 KiB
Python
387 lines
15 KiB
Python
"""WeCom callback-mode adapter for self-built enterprise applications.
|
|
|
|
Unlike the bot/websocket adapter in ``wecom.py``, this handles the standard
|
|
WeCom callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the
|
|
adapter decrypts it, queues the message for the agent, and immediately
|
|
acknowledges. The agent's reply is delivered later via the proactive
|
|
``message/send`` API using an access-token.
|
|
|
|
Supports multiple self-built apps under one gateway instance, scoped by
|
|
``corp_id:user_id`` to avoid cross-corp collisions.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import socket as _socket
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
from xml.etree import ElementTree as ET
|
|
|
|
try:
|
|
from aiohttp import web
|
|
|
|
AIOHTTP_AVAILABLE = True
|
|
except ImportError:
|
|
web = None # type: ignore[assignment]
|
|
AIOHTTP_AVAILABLE = False
|
|
|
|
try:
|
|
import httpx
|
|
|
|
HTTPX_AVAILABLE = True
|
|
except ImportError:
|
|
httpx = None # type: ignore[assignment]
|
|
HTTPX_AVAILABLE = False
|
|
|
|
from gateway.config import Platform, PlatformConfig
|
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
|
from gateway.platforms.wecom_crypto import WXBizMsgCrypt, WeComCryptoError
|
|
|
|
logger = logging.getLogger(__name__)

# Server defaults used by WecomCallbackAdapter when the platform config's
# ``extra`` mapping omits ``host`` / ``port`` / ``path``.
DEFAULT_HOST = "0.0.0.0"  # bind on every IPv4 interface
DEFAULT_PORT = 8645
DEFAULT_PATH = "/wecom/callback"

# Fallback lifetime (seconds) assumed for an access token when the gettoken
# response does not carry ``expires_in`` — two hours.
ACCESS_TOKEN_TTL_SECONDS = 2 * 60 * 60

# Time-to-live (seconds) for entries in the inbound message de-duplication
# cache — five minutes.
MESSAGE_DEDUP_TTL_SECONDS = 5 * 60
|
|
|
|
|
|
def check_wecom_callback_requirements() -> bool:
    """Tell whether both optional HTTP dependencies imported successfully.

    The adapter needs aiohttp (inbound callback server) and httpx (outbound
    WeCom API calls); each flag is set by the guarded imports at module top.
    """
    if not AIOHTTP_AVAILABLE:
        return False
    return HTTPX_AVAILABLE
|
|
|
|
|
|
class WecomCallbackAdapter(BasePlatformAdapter):
    """Platform adapter for WeCom self-built apps using the HTTP callback flow.

    An aiohttp server answers WeCom's URL-verification GET and receives
    encrypted message POSTs on the configured path. Each POST is decrypted with
    whichever configured app's credentials succeed, converted into a
    ``MessageEvent``, queued for the agent, and acknowledged immediately with a
    plain-text ``success``. Replies are delivered later through the
    ``message/send`` API using a cached per-app access token.

    Several apps may be configured via ``extra["apps"]``; each user is
    remembered under a ``corp_id:user_id`` key so replies go out through the
    app that received the user's message.
    """

    # Event names (lower-cased) acknowledged without starting an agent turn:
    # lifecycle notifications, periodic location reports, and system-generated
    # contact / batch-job notifications.
    _IGNORED_EVENTS = frozenset(
        {
            "enter_agent",
            "subscribe",
            "unsubscribe",
            "location",
            "change_contact",
            "batch_job_result",
        }
    )
    # message/send errcodes meaning the access token is invalid (40014) or
    # expired (42001); WeCom can invalidate a token before expires_in elapses.
    _TOKEN_INVALID_ERRCODES = frozenset({40014, 42001})
    # message/send limit for text content, measured in UTF-8 bytes.
    _MAX_TEXT_BYTES = 2048

    def __init__(self, config: PlatformConfig):
        super().__init__(config, Platform.WECOM_CALLBACK)
        extra = config.extra or {}
        self._host = str(extra.get("host") or DEFAULT_HOST)
        self._port = int(extra.get("port") or DEFAULT_PORT)
        self._path = str(extra.get("path") or DEFAULT_PATH)
        self._apps: List[Dict[str, Any]] = self._normalize_apps(extra)
        self._runner: Optional[web.AppRunner] = None
        self._site: Optional[web.TCPSite] = None
        self._app: Optional[web.Application] = None
        self._http_client: Optional[httpx.AsyncClient] = None
        self._message_queue: asyncio.Queue[MessageEvent] = asyncio.Queue()
        self._poll_task: Optional[asyncio.Task] = None
        # "<app name>:<message id>" -> time.time() when first accepted.
        self._seen_messages: Dict[str, float] = {}
        # corp-scoped (or bare) user key -> name of the app that last heard them.
        self._user_app_map: Dict[str, str] = {}
        # app name -> {"token": str, "expires_at": epoch seconds}
        self._access_tokens: Dict[str, Dict[str, Any]] = {}

    # ------------------------------------------------------------------
    # App normalisation
    # ------------------------------------------------------------------

    @staticmethod
    def _user_app_key(corp_id: str, user_id: str) -> str:
        """Return the routing key ``corp_id:user_id`` (bare *user_id* without a corp)."""
        return f"{corp_id}:{user_id}" if corp_id else user_id

    @staticmethod
    def _app_name(app: Dict[str, Any]) -> str:
        """Return *app*'s name, treating a missing/empty name as ``"default"``."""
        return str(app.get("name") or "default")

    @staticmethod
    def _normalize_apps(extra: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Build the list of app configs from the platform ``extra`` mapping.

        A non-empty ``extra["apps"]`` list wins; its dict entries are copied
        and any entry without a name is given one ("default" for the first
        kept entry, then "app<N>") because names key token caching and user
        routing. Otherwise a single app is built from top-level keys when
        ``corp_id`` is set. Returns ``[]`` when nothing is configured.
        """
        apps = extra.get("apps")
        if isinstance(apps, list) and apps:
            normalized: List[Dict[str, Any]] = []
            for raw in apps:
                if not isinstance(raw, dict):
                    continue
                entry = dict(raw)
                if not entry.get("name"):
                    index = len(normalized)
                    entry["name"] = "default" if index == 0 else f"app{index}"
                normalized.append(entry)
            return normalized
        if extra.get("corp_id"):
            return [
                {
                    "name": extra.get("name") or "default",
                    "corp_id": extra.get("corp_id", ""),
                    "corp_secret": extra.get("corp_secret", ""),
                    "agent_id": str(extra.get("agent_id", "")),
                    "token": extra.get("token", ""),
                    "encoding_aes_key": extra.get("encoding_aes_key", ""),
                }
            ]
        return []

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def _port_in_use(self) -> bool:
        """Return True if a TCP connect to 127.0.0.1:<port> succeeds within 1s."""
        try:
            with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as sock:
                sock.settimeout(1)
                sock.connect(("127.0.0.1", self._port))
        except OSError:  # refused / timed out: nothing is listening
            return False
        return True

    async def connect(self) -> bool:
        """Start the callback HTTP server and prime each app's access token.

        Returns False when no app is configured, aiohttp/httpx are missing,
        the port is already taken, or the server fails to start. Initial
        token-refresh failures are only logged; tokens are fetched again on
        first send.
        """
        if not self._apps:
            logger.warning("[WecomCallback] No callback apps configured")
            return False
        if not check_wecom_callback_requirements():
            logger.warning("[WecomCallback] aiohttp/httpx not installed")
            return False
        if self._port_in_use():
            logger.error("[WecomCallback] Port %d already in use", self._port)
            return False

        try:
            self._http_client = httpx.AsyncClient(timeout=20.0)
            self._app = web.Application()
            self._app.router.add_get("/health", self._handle_health)
            self._app.router.add_get(self._path, self._handle_verify)
            self._app.router.add_post(self._path, self._handle_callback)
            self._runner = web.AppRunner(self._app)
            await self._runner.setup()
            self._site = web.TCPSite(self._runner, self._host, self._port)
            await self._site.start()
            self._poll_task = asyncio.create_task(self._poll_loop())
            self._mark_connected()
            logger.info(
                "[WecomCallback] HTTP server listening on %s:%s%s",
                self._host, self._port, self._path,
            )
        except Exception:
            # Log first so the original error survives a failing cleanup.
            logger.exception("[WecomCallback] Failed to start")
            await self._stop_poll_task()
            await self._cleanup()
            return False

        for app in self._apps:
            try:
                await self._refresh_access_token(app)
            except Exception as exc:
                logger.warning(
                    "[WecomCallback] Initial token refresh failed for app '%s': %s",
                    app.get("name", "default"), exc,
                )
        return True

    async def disconnect(self) -> None:
        """Stop dispatching, shut down the HTTP server and close the API client."""
        self._running = False
        await self._stop_poll_task()
        await self._cleanup()
        self._mark_disconnected()
        logger.info("[WecomCallback] Disconnected")

    async def _stop_poll_task(self) -> None:
        """Cancel and await the queue-draining task, if one is running."""
        task, self._poll_task = self._poll_task, None
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _cleanup(self) -> None:
        """Tear down the aiohttp runner and the httpx client (safe to repeat)."""
        self._site = None
        if self._runner:
            await self._runner.cleanup()
            self._runner = None
        self._app = None
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None

    # ------------------------------------------------------------------
    # Outbound: proactive send via access-token API
    # ------------------------------------------------------------------

    async def send(
        self,
        chat_id: str,
        content: str,
        reply_to: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> SendResult:
        """Deliver *content* to *chat_id* through WeCom's ``message/send`` API.

        *chat_id* is ``corp_id:user_id`` or a bare user id. Content longer than
        the 2048-byte text limit is split into several messages sent in order.
        ``reply_to`` and ``metadata`` are accepted for interface compatibility
        and not used.

        Returns a successful SendResult carrying the last message's ``msgid``.
        On the first failure it returns ``success=False``; earlier chunks may
        already have been delivered. Errors are reported in the result, never
        raised.
        """
        if not self._apps:
            return SendResult(success=False, error="no WeCom callback apps configured")
        if self._http_client is None:
            return SendResult(success=False, error="WeCom callback adapter is not connected")
        chunks = self._split_text_by_bytes(content or "", self._MAX_TEXT_BYTES)
        if not chunks:
            return SendResult(success=False, error="message content is empty")

        app = self._resolve_app_for_chat(chat_id)
        touser = chat_id.split(":", 1)[1] if ":" in chat_id else chat_id
        try:
            agent_id = int(str(app.get("agent_id") or 0))
            data: Dict[str, Any] = {}
            for chunk in chunks:
                data = await self._send_text(app, touser, agent_id, chunk)
                if data.get("errcode") != 0:
                    return SendResult(success=False, error=str(data))
            return SendResult(
                success=True,
                message_id=str(data.get("msgid", "")),
                raw_response=data,
            )
        except Exception as exc:
            return SendResult(success=False, error=str(exc))

    async def _send_text(
        self, app: Dict[str, Any], touser: str, agent_id: int, text: str,
    ) -> Dict[str, Any]:
        """POST one text message and return WeCom's decoded JSON reply.

        When WeCom reports the cached token invalid or expired, the cache entry
        is dropped, a fresh token is fetched and the message is retried once.
        """
        payload = {
            "touser": touser,
            "msgtype": "text",
            "agentid": agent_id,
            "text": {"content": text},
            "safe": 0,
        }
        token = await self._get_access_token(app)
        data = await self._post_message(token, payload)
        if data.get("errcode") in self._TOKEN_INVALID_ERRCODES:
            self._access_tokens.pop(self._app_name(app), None)
            token = await self._refresh_access_token(app)
            data = await self._post_message(token, payload)
        return data

    async def _post_message(self, token: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to message/send with *token*; return the JSON body."""
        resp = await self._http_client.post(
            f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}",
            json=payload,
        )
        return resp.json()

    @staticmethod
    def _split_text_by_bytes(text: str, max_bytes: int) -> List[str]:
        """Split *text* into pieces whose UTF-8 size is at most *max_bytes*.

        Cuts only on character boundaries and, when a piece must be cut,
        prefers to end it right after its last newline (if that newline is not
        the piece's first character). A single character larger than the
        budget becomes its own piece. Returns ``[]`` for empty text.
        """
        chunks: List[str] = []
        start = 0
        length = len(text)
        while start < length:
            size = 0
            end = start
            last_newline = -1
            while end < length:
                char_bytes = len(text[end].encode("utf-8", "surrogatepass"))
                if size + char_bytes > max_bytes:
                    break
                size += char_bytes
                if text[end] == "\n":
                    last_newline = end
                end += 1
            if end == start:
                end = start + 1  # budget smaller than one character
            elif end < length and last_newline > start:
                end = last_newline + 1
            chunks.append(text[start:end])
            start = end
        return chunks

    def _resolve_app_for_chat(self, chat_id: str) -> Dict[str, Any]:
        """Pick the app used to reply to *chat_id*.

        Tries, in order: the user->app map entry for *chat_id*; for a bare user
        id, the single mapped ``corp:user`` key with that user; for a scoped
        id, the first app whose ``corp_id`` equals the prefix; finally the
        first configured app. Requires at least one configured app.
        """
        app_name = self._user_app_map.get(chat_id)
        if not app_name:
            if ":" in chat_id:
                corp_id = chat_id.split(":", 1)[0]
                for app in self._apps:
                    if str(app.get("corp_id") or "") == corp_id:
                        return app
            else:
                # Legacy bare user_id — use it only if it maps unambiguously.
                matching = [k for k in self._user_app_map if k.endswith(f":{chat_id}")]
                if len(matching) == 1:
                    app_name = self._user_app_map.get(matching[0])
        app = self._get_app_by_name(app_name) if app_name else None
        return app or self._apps[0]

    async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
        """Return minimal chat info; callback-mode chats are always DMs."""
        return {"name": chat_id, "type": "dm"}

    # ------------------------------------------------------------------
    # Inbound: HTTP callback handlers
    # ------------------------------------------------------------------

    async def _handle_health(self, request: web.Request) -> web.Response:
        """GET /health — liveness probe."""
        return web.json_response({"status": "ok", "platform": "wecom_callback"})

    async def _handle_verify(self, request: web.Request) -> web.Response:
        """GET endpoint — WeCom URL verification handshake.

        Returns the decrypted ``echostr`` for the first app whose credentials
        verify the request, otherwise 403.
        """
        msg_signature = request.query.get("msg_signature", "")
        timestamp = request.query.get("timestamp", "")
        nonce = request.query.get("nonce", "")
        echostr = request.query.get("echostr", "")
        for app in self._apps:
            try:
                crypt = self._crypt_for_app(app)
                plain = crypt.verify_url(msg_signature, timestamp, nonce, echostr)
                return web.Response(text=plain, content_type="text/plain")
            except Exception:
                continue
        return web.Response(status=403, text="signature verification failed")

    async def _handle_callback(self, request: web.Request) -> web.Response:
        """POST endpoint — receive an encrypted message callback.

        Tries each app's credentials until one decrypts the body. A valid
        message is queued (unless already seen within the dedup window) and
        acknowledged with ``success`` right away; the agent's reply is sent
        later via message/send. Returns 400 when no app can decrypt the
        payload or it cannot be parsed.
        """
        msg_signature = request.query.get("msg_signature", "")
        timestamp = request.query.get("timestamp", "")
        nonce = request.query.get("nonce", "")
        body = await request.text()

        for app in self._apps:
            try:
                decrypted = self._decrypt_request(
                    app, body, msg_signature, timestamp, nonce,
                )
                event = self._build_event(app, decrypted)
                if event is not None:
                    dedup_key = f"{self._app_name(app)}:{event.message_id}"
                    if self._is_duplicate(dedup_key):
                        # WeCom re-delivers callbacks it considers unanswered.
                        logger.debug("[WecomCallback] Dropping duplicate %s", dedup_key)
                    else:
                        # Record which app this user belongs to.
                        if event.source and event.source.user_id:
                            map_key = self._user_app_key(
                                str(app.get("corp_id") or ""), event.source.user_id,
                            )
                            self._user_app_map[map_key] = self._app_name(app)
                        await self._message_queue.put(event)
                # Immediately acknowledge — the agent's reply will arrive
                # later via the proactive message/send API.
                return web.Response(text="success", content_type="text/plain")
            except WeComCryptoError:
                continue
            except ET.ParseError as exc:
                # Malformed XML does not depend on which app is tried.
                logger.warning("[WecomCallback] Malformed callback XML: %s", exc)
                break
            except Exception:
                logger.exception("[WecomCallback] Error handling message")
                break
        return web.Response(status=400, text="invalid callback payload")

    def _is_duplicate(self, key: str) -> bool:
        """Record *key* and return True if it was already seen within the TTL.

        Entries older than ``MESSAGE_DEDUP_TTL_SECONDS`` are pruned on every
        call so the cache only holds recent traffic.
        """
        now = time.time()
        expired = [
            k for k, seen_at in self._seen_messages.items()
            if now - seen_at >= MESSAGE_DEDUP_TTL_SECONDS
        ]
        for k in expired:
            del self._seen_messages[k]
        if key in self._seen_messages:
            return True
        self._seen_messages[key] = now
        return False

    async def _poll_loop(self) -> None:
        """Drain the message queue and dispatch to the gateway runner."""
        while True:
            event = await self._message_queue.get()
            try:
                task = asyncio.create_task(self.handle_message(event))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            except Exception:
                logger.exception("[WecomCallback] Failed to enqueue event")

    # ------------------------------------------------------------------
    # XML / crypto helpers
    # ------------------------------------------------------------------

    def _decrypt_request(
        self, app: Dict[str, Any], body: str,
        msg_signature: str, timestamp: str, nonce: str,
    ) -> str:
        """Extract ``<Encrypt>`` from the POST body and decrypt it with *app*'s keys.

        NOTE(review): *body* is unauthenticated until decrypt() verifies the
        signature; it is parsed with the stdlib ElementTree parser.
        """
        root = ET.fromstring(body)
        encrypt = root.findtext("Encrypt", default="")
        crypt = self._crypt_for_app(app)
        return crypt.decrypt(msg_signature, timestamp, nonce, encrypt).decode("utf-8")

    def _build_event(self, app: Dict[str, Any], xml_text: str) -> Optional[MessageEvent]:
        """Turn decrypted callback XML into a MessageEvent, or None to ignore it.

        Only ``text`` messages and events outside ``_IGNORED_EVENTS`` produce an
        event; an event without ``Content`` is delivered as ``/start``. The
        message id is ``MsgId`` or, when absent, ``<user>:<CreateTime>``.
        """
        root = ET.fromstring(xml_text)
        msg_type = (root.findtext("MsgType") or "").lower()
        if msg_type not in {"text", "event"}:
            return None
        if msg_type == "event":
            event_name = (root.findtext("Event") or "").lower()
            if event_name in self._IGNORED_EVENTS:
                return None

        user_id = root.findtext("FromUserName", default="")
        corp_id = root.findtext("ToUserName", default=app.get("corp_id", ""))
        scoped_chat_id = self._user_app_key(corp_id, user_id)
        content = root.findtext("Content", default="").strip()
        if not content and msg_type == "event":
            content = "/start"
        msg_id = (
            root.findtext("MsgId")
            or f"{user_id}:{root.findtext('CreateTime', default='0')}"
        )
        source = self.build_source(
            chat_id=scoped_chat_id,
            chat_name=user_id,
            chat_type="dm",
            user_id=user_id,
            user_name=user_id,
        )
        return MessageEvent(
            text=content,
            message_type=MessageType.TEXT,
            source=source,
            raw_message=xml_text,
            message_id=msg_id,
        )

    def _crypt_for_app(self, app: Dict[str, Any]) -> WXBizMsgCrypt:
        """Build the BizMsgCrypt helper from *app*'s token, AES key and corp id."""
        return WXBizMsgCrypt(
            token=str(app.get("token") or ""),
            encoding_aes_key=str(app.get("encoding_aes_key") or ""),
            receive_id=str(app.get("corp_id") or ""),
        )

    def _get_app_by_name(self, name: Optional[str]) -> Optional[Dict[str, Any]]:
        """Return the configured app named *name* (unnamed apps count as "default")."""
        if not name:
            return None
        return next((app for app in self._apps if self._app_name(app) == name), None)

    # ------------------------------------------------------------------
    # Access-token management
    # ------------------------------------------------------------------

    async def _get_access_token(self, app: Dict[str, Any]) -> str:
        """Return *app*'s cached token, refreshing it within 60s of expiry."""
        cached = self._access_tokens.get(self._app_name(app))
        now = time.time()
        if cached and cached.get("expires_at", 0) > now + 60:
            return cached["token"]
        return await self._refresh_access_token(app)

    async def _refresh_access_token(self, app: Dict[str, Any]) -> str:
        """Fetch a new token from ``gettoken``, cache it under the app name, return it.

        Raises RuntimeError when WeCom answers with a non-zero errcode.
        """
        resp = await self._http_client.get(
            "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
            params={
                "corpid": app.get("corp_id"),
                "corpsecret": app.get("corp_secret"),
            },
        )
        data = resp.json()
        if data.get("errcode") != 0:
            raise RuntimeError(f"WeCom token refresh failed: {data}")
        token = data["access_token"]
        expires_in = int(data.get("expires_in", ACCESS_TOKEN_TTL_SECONDS))
        self._access_tokens[self._app_name(app)] = {
            "token": token,
            "expires_at": time.time() + expires_in,
        }
        logger.info(
            "[WecomCallback] Token refreshed for app '%s' (corp=%s), expires in %ss",
            app.get("name", "default"),
            app.get("corp_id", ""),
            expires_in,
        )
        return token