mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(gateway): add WeCom callback-mode adapter for self-built apps
Add a second WeCom integration mode for regular enterprise self-built applications. Unlike the existing bot/websocket adapter (wecom.py), this handles WeCom's standard callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the adapter decrypts, queues for the agent, and immediately acknowledges. The agent's reply is delivered proactively via the message/send API. Key design choice: always acknowledge immediately and use proactive send — agent sessions take 3-30 minutes, so the 5-second inline reply window is never useful. The original PR's Future/pending-reply machinery was removed in favour of this simpler architecture. Features: - AES-CBC encrypt/decrypt (BizMsgCrypt-compatible) - Multi-app routing scoped by corp_id:user_id - Legacy bare user_id fallback for backward compat - Access-token management with auto-refresh - WECOM_CALLBACK_* env var overrides - Port-in-use pre-check before binding - Health endpoint at /health Salvaged from PR #7774 by @chqchshj. Simplified by removing the inline reply Future system and fixing: secrets.choice for nonce generation, immediate plain-text acknowledgment (not encrypted XML containing 'success'), and initial token refresh error handling.
This commit is contained in:
parent
90352b2adf
commit
5f0caf54d6
13 changed files with 800 additions and 5 deletions
387
gateway/platforms/wecom_callback.py
Normal file
387
gateway/platforms/wecom_callback.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""WeCom callback-mode adapter for self-built enterprise applications.
|
||||
|
||||
Unlike the bot/websocket adapter in ``wecom.py``, this handles the standard
|
||||
WeCom callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the
|
||||
adapter decrypts it, queues the message for the agent, and immediately
|
||||
acknowledges. The agent's reply is delivered later via the proactive
|
||||
``message/send`` API using an access-token.
|
||||
|
||||
Supports multiple self-built apps under one gateway instance, scoped by
|
||||
``corp_id:user_id`` to avoid cross-corp collisions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket as _socket
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
web = None # type: ignore[assignment]
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import httpx
|
||||
|
||||
HTTPX_AVAILABLE = True
|
||||
except ImportError:
|
||||
httpx = None # type: ignore[assignment]
|
||||
HTTPX_AVAILABLE = False
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.platforms.wecom_crypto import WXBizMsgCrypt, WeComCryptoError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HOST = "0.0.0.0"
|
||||
DEFAULT_PORT = 8645
|
||||
DEFAULT_PATH = "/wecom/callback"
|
||||
ACCESS_TOKEN_TTL_SECONDS = 7200
|
||||
MESSAGE_DEDUP_TTL_SECONDS = 300
|
||||
|
||||
|
||||
def check_wecom_callback_requirements() -> bool:
|
||||
return AIOHTTP_AVAILABLE and HTTPX_AVAILABLE
|
||||
|
||||
|
||||
class WecomCallbackAdapter(BasePlatformAdapter):
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WECOM_CALLBACK)
|
||||
extra = config.extra or {}
|
||||
self._host = str(extra.get("host") or DEFAULT_HOST)
|
||||
self._port = int(extra.get("port") or DEFAULT_PORT)
|
||||
self._path = str(extra.get("path") or DEFAULT_PATH)
|
||||
self._apps: List[Dict[str, Any]] = self._normalize_apps(extra)
|
||||
self._runner: Optional[web.AppRunner] = None
|
||||
self._site: Optional[web.TCPSite] = None
|
||||
self._app: Optional[web.Application] = None
|
||||
self._http_client: Optional[httpx.AsyncClient] = None
|
||||
self._message_queue: asyncio.Queue[MessageEvent] = asyncio.Queue()
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._user_app_map: Dict[str, str] = {}
|
||||
self._access_tokens: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# App normalisation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _user_app_key(corp_id: str, user_id: str) -> str:
|
||||
return f"{corp_id}:{user_id}" if corp_id else user_id
|
||||
|
||||
@staticmethod
|
||||
def _normalize_apps(extra: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
apps = extra.get("apps")
|
||||
if isinstance(apps, list) and apps:
|
||||
return [dict(app) for app in apps if isinstance(app, dict)]
|
||||
if extra.get("corp_id"):
|
||||
return [
|
||||
{
|
||||
"name": extra.get("name") or "default",
|
||||
"corp_id": extra.get("corp_id", ""),
|
||||
"corp_secret": extra.get("corp_secret", ""),
|
||||
"agent_id": str(extra.get("agent_id", "")),
|
||||
"token": extra.get("token", ""),
|
||||
"encoding_aes_key": extra.get("encoding_aes_key", ""),
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
if not self._apps:
|
||||
logger.warning("[WecomCallback] No callback apps configured")
|
||||
return False
|
||||
if not check_wecom_callback_requirements():
|
||||
logger.warning("[WecomCallback] aiohttp/httpx not installed")
|
||||
return False
|
||||
|
||||
# Quick port-in-use check.
|
||||
try:
|
||||
with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(1)
|
||||
sock.connect(("127.0.0.1", self._port))
|
||||
logger.error("[WecomCallback] Port %d already in use", self._port)
|
||||
return False
|
||||
except (ConnectionRefusedError, OSError):
|
||||
pass
|
||||
|
||||
try:
|
||||
self._http_client = httpx.AsyncClient(timeout=20.0)
|
||||
self._app = web.Application()
|
||||
self._app.router.add_get("/health", self._handle_health)
|
||||
self._app.router.add_get(self._path, self._handle_verify)
|
||||
self._app.router.add_post(self._path, self._handle_callback)
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
self._site = web.TCPSite(self._runner, self._host, self._port)
|
||||
await self._site.start()
|
||||
self._poll_task = asyncio.create_task(self._poll_loop())
|
||||
self._mark_connected()
|
||||
logger.info(
|
||||
"[WecomCallback] HTTP server listening on %s:%s%s",
|
||||
self._host, self._port, self._path,
|
||||
)
|
||||
for app in self._apps:
|
||||
try:
|
||||
await self._refresh_access_token(app)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[WecomCallback] Initial token refresh failed for app '%s': %s",
|
||||
app.get("name", "default"), exc,
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
await self._cleanup()
|
||||
logger.exception("[WecomCallback] Failed to start")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
if self._poll_task:
|
||||
self._poll_task.cancel()
|
||||
try:
|
||||
await self._poll_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._poll_task = None
|
||||
await self._cleanup()
|
||||
self._mark_disconnected()
|
||||
logger.info("[WecomCallback] Disconnected")
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
self._site = None
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._app = None
|
||||
if self._http_client:
|
||||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: proactive send via access-token API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
app = self._resolve_app_for_chat(chat_id)
|
||||
touser = chat_id.split(":", 1)[1] if ":" in chat_id else chat_id
|
||||
try:
|
||||
token = await self._get_access_token(app)
|
||||
payload = {
|
||||
"touser": touser,
|
||||
"msgtype": "text",
|
||||
"agentid": int(str(app.get("agent_id") or 0)),
|
||||
"text": {"content": content[:2048]},
|
||||
"safe": 0,
|
||||
}
|
||||
resp = await self._http_client.post(
|
||||
f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}",
|
||||
json=payload,
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("errcode") != 0:
|
||||
return SendResult(success=False, error=str(data))
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=str(data.get("msgid", "")),
|
||||
raw_response=data,
|
||||
)
|
||||
except Exception as exc:
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
def _resolve_app_for_chat(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Pick the app associated with *chat_id*, falling back sensibly."""
|
||||
app_name = self._user_app_map.get(chat_id)
|
||||
if not app_name and ":" not in chat_id:
|
||||
# Legacy bare user_id — try to find a unique match.
|
||||
matching = [k for k in self._user_app_map if k.endswith(f":{chat_id}")]
|
||||
if len(matching) == 1:
|
||||
app_name = self._user_app_map.get(matching[0])
|
||||
app = self._get_app_by_name(app_name) if app_name else None
|
||||
return app or self._apps[0]
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: HTTP callback handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_health(self, request: web.Request) -> web.Response:
|
||||
return web.json_response({"status": "ok", "platform": "wecom_callback"})
|
||||
|
||||
async def _handle_verify(self, request: web.Request) -> web.Response:
|
||||
"""GET endpoint — WeCom URL verification handshake."""
|
||||
msg_signature = request.query.get("msg_signature", "")
|
||||
timestamp = request.query.get("timestamp", "")
|
||||
nonce = request.query.get("nonce", "")
|
||||
echostr = request.query.get("echostr", "")
|
||||
for app in self._apps:
|
||||
try:
|
||||
crypt = self._crypt_for_app(app)
|
||||
plain = crypt.verify_url(msg_signature, timestamp, nonce, echostr)
|
||||
return web.Response(text=plain, content_type="text/plain")
|
||||
except Exception:
|
||||
continue
|
||||
return web.Response(status=403, text="signature verification failed")
|
||||
|
||||
async def _handle_callback(self, request: web.Request) -> web.Response:
|
||||
"""POST endpoint — receive an encrypted message callback."""
|
||||
msg_signature = request.query.get("msg_signature", "")
|
||||
timestamp = request.query.get("timestamp", "")
|
||||
nonce = request.query.get("nonce", "")
|
||||
body = await request.text()
|
||||
|
||||
for app in self._apps:
|
||||
try:
|
||||
decrypted = self._decrypt_request(
|
||||
app, body, msg_signature, timestamp, nonce,
|
||||
)
|
||||
event = self._build_event(app, decrypted)
|
||||
if event is not None:
|
||||
# Record which app this user belongs to.
|
||||
if event.source and event.source.user_id:
|
||||
map_key = self._user_app_key(
|
||||
str(app.get("corp_id") or ""), event.source.user_id,
|
||||
)
|
||||
self._user_app_map[map_key] = app["name"]
|
||||
await self._message_queue.put(event)
|
||||
# Immediately acknowledge — the agent's reply will arrive
|
||||
# later via the proactive message/send API.
|
||||
return web.Response(text="success", content_type="text/plain")
|
||||
except WeComCryptoError:
|
||||
continue
|
||||
except Exception:
|
||||
logger.exception("[WecomCallback] Error handling message")
|
||||
break
|
||||
return web.Response(status=400, text="invalid callback payload")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""Drain the message queue and dispatch to the gateway runner."""
|
||||
while True:
|
||||
event = await self._message_queue.get()
|
||||
try:
|
||||
task = asyncio.create_task(self.handle_message(event))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
except Exception:
|
||||
logger.exception("[WecomCallback] Failed to enqueue event")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# XML / crypto helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _decrypt_request(
|
||||
self, app: Dict[str, Any], body: str,
|
||||
msg_signature: str, timestamp: str, nonce: str,
|
||||
) -> str:
|
||||
root = ET.fromstring(body)
|
||||
encrypt = root.findtext("Encrypt", default="")
|
||||
crypt = self._crypt_for_app(app)
|
||||
return crypt.decrypt(msg_signature, timestamp, nonce, encrypt).decode("utf-8")
|
||||
|
||||
def _build_event(self, app: Dict[str, Any], xml_text: str) -> Optional[MessageEvent]:
|
||||
root = ET.fromstring(xml_text)
|
||||
msg_type = (root.findtext("MsgType") or "").lower()
|
||||
# Silently acknowledge lifecycle events.
|
||||
if msg_type == "event":
|
||||
event_name = (root.findtext("Event") or "").lower()
|
||||
if event_name in {"enter_agent", "subscribe"}:
|
||||
return None
|
||||
if msg_type not in {"text", "event"}:
|
||||
return None
|
||||
|
||||
user_id = root.findtext("FromUserName", default="")
|
||||
corp_id = root.findtext("ToUserName", default=app.get("corp_id", ""))
|
||||
scoped_chat_id = self._user_app_key(corp_id, user_id)
|
||||
content = root.findtext("Content", default="").strip()
|
||||
if not content and msg_type == "event":
|
||||
content = "/start"
|
||||
msg_id = (
|
||||
root.findtext("MsgId")
|
||||
or f"{user_id}:{root.findtext('CreateTime', default='0')}"
|
||||
)
|
||||
source = self.build_source(
|
||||
chat_id=scoped_chat_id,
|
||||
chat_name=user_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
user_name=user_id,
|
||||
)
|
||||
return MessageEvent(
|
||||
text=content,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=xml_text,
|
||||
message_id=msg_id,
|
||||
)
|
||||
|
||||
def _crypt_for_app(self, app: Dict[str, Any]) -> WXBizMsgCrypt:
|
||||
return WXBizMsgCrypt(
|
||||
token=str(app.get("token") or ""),
|
||||
encoding_aes_key=str(app.get("encoding_aes_key") or ""),
|
||||
receive_id=str(app.get("corp_id") or ""),
|
||||
)
|
||||
|
||||
def _get_app_by_name(self, name: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
if not name:
|
||||
return None
|
||||
for app in self._apps:
|
||||
if app.get("name") == name:
|
||||
return app
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access-token management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_access_token(self, app: Dict[str, Any]) -> str:
|
||||
cached = self._access_tokens.get(app["name"])
|
||||
now = time.time()
|
||||
if cached and cached.get("expires_at", 0) > now + 60:
|
||||
return cached["token"]
|
||||
return await self._refresh_access_token(app)
|
||||
|
||||
async def _refresh_access_token(self, app: Dict[str, Any]) -> str:
|
||||
resp = await self._http_client.get(
|
||||
"https://qyapi.weixin.qq.com/cgi-bin/gettoken",
|
||||
params={
|
||||
"corpid": app.get("corp_id"),
|
||||
"corpsecret": app.get("corp_secret"),
|
||||
},
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("errcode") != 0:
|
||||
raise RuntimeError(f"WeCom token refresh failed: {data}")
|
||||
token = data["access_token"]
|
||||
expires_in = int(data.get("expires_in", ACCESS_TOKEN_TTL_SECONDS))
|
||||
self._access_tokens[app["name"]] = {
|
||||
"token": token,
|
||||
"expires_at": time.time() + expires_in,
|
||||
}
|
||||
logger.info(
|
||||
"[WecomCallback] Token refreshed for app '%s' (corp=%s), expires in %ss",
|
||||
app.get("name", "default"),
|
||||
app.get("corp_id", ""),
|
||||
expires_in,
|
||||
)
|
||||
return token
|
||||
Loading…
Add table
Add a link
Reference in a new issue