diff --git a/gateway/config.py b/gateway/config.py index 34ef31d7b0..3023b1b500 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -46,7 +46,13 @@ def _normalize_unauthorized_dm_behavior(value: Any, default: str = "pair") -> st class Platform(Enum): - """Supported messaging platforms.""" + """Supported messaging platforms. + + Built-in platforms have explicit members. Plugin platforms use dynamic + members created on-demand by ``_missing_()`` so that + ``Platform("irc")`` works without modifying this enum. Dynamic members + are cached in ``_value2member_map_`` for identity-stable comparisons. + """ LOCAL = "local" TELEGRAM = "telegram" DISCORD = "discord" @@ -66,6 +72,28 @@ class Platform(Enum): WEIXIN = "weixin" BLUEBUBBLES = "bluebubbles" + @classmethod + def _missing_(cls, value): + """Accept unknown platform names for plugin-registered adapters. + + Creates a pseudo-member cached in ``_value2member_map_`` so that + ``Platform("irc") is Platform("irc")`` holds True (identity-stable). + """ + if not isinstance(value, str) or not value.strip(): + return None + # Normalise to lowercase to avoid case mismatches in config + value = value.strip().lower() + # Check cache first (another call may have created it already) + if value in cls._value2member_map_: + return cls._value2member_map_[value] + pseudo = object.__new__(cls) + pseudo._value_ = value + pseudo._name_ = value.upper().replace("-", "_").replace(" ", "_") + # Cache so future lookups return the same object + cls._value2member_map_[value] = pseudo + cls._member_map_[pseudo._name_] = pseudo + return pseudo + @dataclass class HomeChannel: @@ -297,6 +325,17 @@ class GatewayConfig: # BlueBubbles uses extra dict for local server config elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"): connected.append(platform) + else: + # Plugin-registered platform — delegate validation to the + # registry entry's validate_config if available. + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform.value) + if entry: + if entry.validate_config is None or entry.validate_config(config): + connected.append(platform) + except Exception: + pass # Registry not yet initialised during early import return connected def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: diff --git a/gateway/platform_registry.py b/gateway/platform_registry.py new file mode 100644 index 0000000000..1279b61a7b --- /dev/null +++ b/gateway/platform_registry.py @@ -0,0 +1,169 @@ +""" +Platform Adapter Registry + +Allows platform adapters (built-in and plugin) to self-register so the gateway +can discover and instantiate them without hardcoded if/elif chains. + +Built-in adapters continue to use the existing if/elif in _create_adapter() +for now. Plugin adapters register here via PluginContext.register_platform() +and are looked up first -- if nothing is found the gateway falls through to +the legacy code path. + +Usage (plugin side): + + from gateway.platform_registry import platform_registry, PlatformEntry + + platform_registry.register(PlatformEntry( + name="irc", + label="IRC", + adapter_factory=lambda cfg: IRCAdapter(cfg), + check_fn=check_requirements, + validate_config=lambda cfg: bool(cfg.extra.get("server")), + required_env=["IRC_SERVER"], + install_hint="pip install irc", + )) + +Usage (gateway side): + + adapter = platform_registry.create_adapter("irc", platform_config) +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class PlatformEntry: + """Metadata and factory for a single platform adapter.""" + + # Identifier used in config.yaml (e.g. "irc", "viber"). + name: str + + # Human-readable label (e.g. "IRC", "Viber"). + label: str + + # Factory callable: receives a PlatformConfig, returns an adapter instance. + # Using a factory instead of a bare class lets plugins do custom init + # (e.g. passing extra kwargs, wrapping in try/except). + adapter_factory: Callable[[Any], Any] + + # Returns True when the platform's dependencies are available. + check_fn: Callable[[], bool] + + # Optional: given a PlatformConfig, is it properly configured? + # If None, the registry skips config validation and lets the adapter + # fail at connect() time with a descriptive error. + validate_config: Optional[Callable[[Any], bool]] = None + + # Env vars this platform needs (for ``hermes setup`` display). + required_env: list = field(default_factory=list) + + # Hint shown when check_fn returns False. + install_hint: str = "" + + # "builtin" or "plugin" + source: str = "plugin" + + +class PlatformRegistry: + """Central registry of platform adapters. + + Thread-safe for reads (dict lookups are atomic under GIL). + Writes happen at startup during sequential discovery. + """ + + def __init__(self) -> None: + self._entries: dict[str, PlatformEntry] = {} + + def register(self, entry: PlatformEntry) -> None: + """Register a platform adapter entry. + + If an entry with the same name exists, it is replaced (last writer + wins -- this lets plugins override built-in adapters if desired). + """ + if entry.name in self._entries: + prev = self._entries[entry.name] + logger.info( + "Platform '%s' re-registered (was %s, now %s)", + entry.name, + prev.source, + entry.source, + ) + self._entries[entry.name] = entry + logger.debug("Registered platform adapter: %s (%s)", entry.name, entry.source) + + def unregister(self, name: str) -> bool: + """Remove a platform entry. Returns True if it existed.""" + return self._entries.pop(name, None) is not None + + def get(self, name: str) -> Optional[PlatformEntry]: + """Look up a platform entry by name.""" + return self._entries.get(name) + + def all_entries(self) -> list[PlatformEntry]: + """Return all registered platform entries.""" + return list(self._entries.values()) + + def plugin_entries(self) -> list[PlatformEntry]: + """Return only plugin-registered platform entries.""" + return [e for e in self._entries.values() if e.source == "plugin"] + + def is_registered(self, name: str) -> bool: + return name in self._entries + + def create_adapter(self, name: str, config: Any) -> Optional[Any]: + """Create an adapter instance for the given platform name. + + Returns None if: + - No entry registered for *name* + - check_fn() returns False (missing deps) + - validate_config() returns False (misconfigured) + - The factory raises an exception + """ + entry = self._entries.get(name) + if entry is None: + return None + + if not entry.check_fn(): + hint = f" ({entry.install_hint})" if entry.install_hint else "" + logger.warning( + "Platform '%s' requirements not met%s", + entry.label, + hint, + ) + return None + + if entry.validate_config is not None: + try: + if not entry.validate_config(config): + logger.warning( + "Platform '%s' config validation failed", + entry.label, + ) + return None + except Exception as e: + logger.warning( + "Platform '%s' config validation error: %s", + entry.label, + e, + ) + return None + + try: + adapter = entry.adapter_factory(config) + return adapter + except Exception as e: + logger.error( + "Failed to create adapter for platform '%s': %s", + entry.label, + e, + exc_info=True, + ) + return None + + +# Module-level singleton +platform_registry = PlatformRegistry() diff --git a/gateway/run.py b/gateway/run.py index 469abe9ec0..9902f3641e 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1967,7 +1967,11 @@ class GatewayRunner: platform: Platform, config: Any ) -> Optional[BasePlatformAdapter]: - """Create the appropriate adapter for a platform.""" + """Create the appropriate adapter for a platform. + + Checks the platform_registry first (plugin adapters), then falls + through to the built-in if/elif chain for core platforms. + """ if hasattr(config, "extra") and isinstance(config.extra, dict): config.extra.setdefault( "group_sessions_per_user", @@ -1978,6 +1982,16 @@ class GatewayRunner: getattr(self.config, "thread_sessions_per_user", False), ) + # ── Plugin-registered platforms (checked first) ────────────── + try: + from gateway.platform_registry import platform_registry + adapter = platform_registry.create_adapter(platform.value, config) + if adapter is not None: + return adapter + except Exception as e: + logger.debug("Platform registry lookup for '%s' failed: %s", platform.value, e) + # Fall through to built-in adapters below + if platform == Platform.TELEGRAM: from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements if not check_telegram_requirements(): diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 1545d15aad..a8866fef56 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -42,6 +42,8 @@ _EXTRA_ENV_KEYS = frozenset({ "WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY", "WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS", "BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD", + "IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", + "IRC_USE_TLS", "IRC_SERVER_PASSWORD", "IRC_NICKSERV_PASSWORD", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", "MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE", @@ -1234,6 +1236,43 @@ OPTIONAL_ENV_VARS = { "password": False, "category": "messaging", }, + "IRC_SERVER": { + "description": "IRC server hostname (e.g. irc.libera.chat)", + "prompt": "IRC server", + "url": None, + "password": False, + "category": "messaging", + }, + "IRC_CHANNEL": { + "description": "IRC channel to join (e.g. #hermes)", + "prompt": "IRC channel", + "url": None, + "password": False, + "category": "messaging", + }, + "IRC_NICKNAME": { + "description": "Bot nickname on IRC (default: hermes-bot)", + "prompt": "IRC nickname", + "url": None, + "password": False, + "category": "messaging", + }, + "IRC_SERVER_PASSWORD": { + "description": "IRC server password (if required)", + "prompt": "IRC server password", + "url": None, + "password": True, + "category": "messaging", + "advanced": True, + }, + "IRC_NICKSERV_PASSWORD": { + "description": "NickServ password for nick identification", + "prompt": "NickServ password", + "url": None, + "password": True, + "category": "messaging", + "advanced": True, + }, "GATEWAY_ALLOW_ALL_USERS": { "description": "Allow all users to interact with messaging bots (true/false). Default: false.", "prompt": "Allow all users (true/false)", diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 94ec20836d..3e2cf5ed5c 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -244,6 +244,53 @@ class PluginContext: self.manifest.name, engine.name, ) + # -- platform adapter registration --------------------------------------- + + def register_platform( + self, + name: str, + label: str, + adapter_factory: Callable, + check_fn: Callable, + validate_config: Callable | None = None, + required_env: list | None = None, + install_hint: str = "", + ) -> None: + """Register a gateway platform adapter. + + The adapter_factory receives a ``PlatformConfig`` and returns a + ``BasePlatformAdapter`` subclass instance. The gateway calls + ``check_fn()`` before instantiation to verify dependencies. + + Example:: + + ctx.register_platform( + name="irc", + label="IRC", + adapter_factory=lambda cfg: IRCAdapter(cfg), + check_fn=lambda: True, + ) + """ + from gateway.platform_registry import platform_registry, PlatformEntry + + entry = PlatformEntry( + name=name, + label=label, + adapter_factory=adapter_factory, + check_fn=check_fn, + validate_config=validate_config, + required_env=required_env or [], + install_hint=install_hint, + source="plugin", + ) + platform_registry.register(entry) + self._manager._plugin_platform_names.add(name) + logger.debug( + "Plugin %s registered platform: %s", + self.manifest.name, + name, + ) + # -- hook registration -------------------------------------------------- def register_hook(self, hook_name: str, callback: Callable) -> None: @@ -275,6 +322,7 @@ class PluginManager: self._plugins: Dict[str, LoadedPlugin] = {} self._hooks: Dict[str, List[Callable]] = {} self._plugin_tool_names: Set[str] = set() + self._plugin_platform_names: Set[str] = set() self._cli_commands: Dict[str, dict] = {} self._context_engine = None # Set by a plugin via register_context_engine() self._discovered: bool = False diff --git a/plugins/platforms/irc/PLUGIN.yaml b/plugins/platforms/irc/PLUGIN.yaml new file mode 100644 index 0000000000..632d5b1746 --- /dev/null +++ b/plugins/platforms/irc/PLUGIN.yaml @@ -0,0 +1,12 @@ +name: irc-platform +version: 1.0.0 +description: > + IRC gateway adapter for Hermes Agent. + Connects to an IRC server and relays messages between an IRC channel + (or DMs) and the Hermes agent. No external dependencies — uses + Python's stdlib asyncio for the IRC protocol. +author: Nous Research +requires_env: + - IRC_SERVER + - IRC_CHANNEL + - IRC_NICKNAME diff --git a/plugins/platforms/irc/adapter.py b/plugins/platforms/irc/adapter.py new file mode 100644 index 0000000000..20a7bf7a5d --- /dev/null +++ b/plugins/platforms/irc/adapter.py @@ -0,0 +1,493 @@ +""" +IRC Platform Adapter for Hermes Agent. + +A plugin-based gateway adapter that connects to an IRC server and relays +messages to/from the Hermes agent. Zero external dependencies — uses +Python's stdlib asyncio for the IRC protocol. + +Configuration in config.yaml:: + + gateway: + platforms: + irc: + enabled: true + extra: + server: irc.libera.chat + port: 6697 + nickname: hermes-bot + channel: "#hermes" + use_tls: true + server_password: "" # optional server password + nickserv_password: "" # optional NickServ identification + allowed_users: [] # empty = allow all, or list of nicks + max_message_length: 450 # IRC line limit (safe default) + +Or via environment variables (overrides config.yaml): + IRC_SERVER, IRC_PORT, IRC_NICKNAME, IRC_CHANNEL, IRC_USE_TLS, + IRC_SERVER_PASSWORD, IRC_NICKSERV_PASSWORD +""" + +import asyncio +import logging +import os +import re +import ssl +import time +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Lazy import: BasePlatformAdapter and friends live in the main repo. +# We import at function/class level to avoid import errors when the plugin +# is discovered but the gateway hasn't been fully initialised yet. +# --------------------------------------------------------------------------- + +from gateway.platforms.base import ( + BasePlatformAdapter, + SendResult, + MessageEvent, + MessageType, +) +from gateway.session import SessionSource +from gateway.config import PlatformConfig, Platform + + +def _ensure_imports(): + """No-op — kept for backward compatibility with any call sites.""" + pass + + +# --------------------------------------------------------------------------- +# IRC protocol helpers +# --------------------------------------------------------------------------- + +def _parse_irc_message(raw: str) -> dict: + """Parse a raw IRC protocol line into components. + + Returns dict with keys: prefix, command, params. + """ + prefix = "" + trailing = "" + + if raw.startswith(":"): + prefix, raw = raw[1:].split(" ", 1) + + if " :" in raw: + raw, trailing = raw.split(" :", 1) + + parts = raw.split() + command = parts[0] if parts else "" + params = parts[1:] if len(parts) > 1 else [] + if trailing: + params.append(trailing) + + return {"prefix": prefix, "command": command, "params": params} + + +def _extract_nick(prefix: str) -> str: + """Extract nickname from IRC prefix (nick!user@host).""" + return prefix.split("!")[0] if "!" in prefix else prefix + + +# --------------------------------------------------------------------------- +# IRC Adapter +# --------------------------------------------------------------------------- + +class IRCAdapter(BasePlatformAdapter): + """Async IRC adapter implementing the BasePlatformAdapter interface. + + This class is instantiated by the adapter_factory passed to + register_platform(). + """ + + def __init__(self, config, **kwargs): + platform = Platform("irc") + super().__init__(config=config, platform=platform) + + extra = getattr(config, "extra", {}) or {} + + # Connection settings (env vars override config.yaml) + self.server = os.getenv("IRC_SERVER") or extra.get("server", "") + self.port = int(os.getenv("IRC_PORT") or extra.get("port", 6697)) + self.nickname = os.getenv("IRC_NICKNAME") or extra.get("nickname", "hermes-bot") + self.channel = os.getenv("IRC_CHANNEL") or extra.get("channel", "") + self.use_tls = ( + os.getenv("IRC_USE_TLS", "").lower() in ("1", "true", "yes") + if os.getenv("IRC_USE_TLS") + else extra.get("use_tls", True) + ) + self.server_password = os.getenv("IRC_SERVER_PASSWORD") or extra.get("server_password", "") + self.nickserv_password = os.getenv("IRC_NICKSERV_PASSWORD") or extra.get("nickserv_password", "") + + # Auth + self.allowed_users: list = extra.get("allowed_users", []) + + # IRC limits + self.max_message_length = int(extra.get("max_message_length", 450)) + + # Runtime state + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._recv_task: Optional[asyncio.Task] = None + self._current_nick = self.nickname + self._registered = False # IRC registration complete + self._registration_event = asyncio.Event() + + @property + def name(self) -> str: + return "IRC" + + # ── Connection lifecycle ────────────────────────────────────────────── + + async def connect(self) -> bool: + """Connect to the IRC server, register, and join the channel.""" + if not self.server or not self.channel: + logger.error("IRC: server and channel must be configured") + self._set_fatal_error( + "config_missing", + "IRC_SERVER and IRC_CHANNEL must be set", + retryable=False, + ) + return False + + try: + ssl_ctx = None + if self.use_tls: + ssl_ctx = ssl.create_default_context() + + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(self.server, self.port, ssl=ssl_ctx), + timeout=30.0, + ) + except Exception as e: + logger.error("IRC: failed to connect to %s:%s — %s", self.server, self.port, e) + self._set_fatal_error("connect_failed", str(e), retryable=True) + return False + + # IRC registration sequence + if self.server_password: + await self._send_raw(f"PASS {self.server_password}") + await self._send_raw(f"NICK {self.nickname}") + await self._send_raw(f"USER {self.nickname} 0 * :Hermes Agent") + + # Start receive loop + self._recv_task = asyncio.create_task(self._receive_loop()) + + # Wait for registration (001 RPL_WELCOME) with timeout + try: + await asyncio.wait_for(self._registration_event.wait(), timeout=30.0) + except asyncio.TimeoutError: + logger.error("IRC: registration timed out") + await self.disconnect() + self._set_fatal_error("registration_timeout", "IRC server did not send RPL_WELCOME", retryable=True) + return False + + # NickServ identification + if self.nickserv_password: + await self._send_raw(f"PRIVMSG NickServ :IDENTIFY {self.nickserv_password}") + await asyncio.sleep(2) # Give NickServ time to process + + # Join channel + await self._send_raw(f"JOIN {self.channel}") + + self._mark_connected() + logger.info("IRC: connected to %s:%s as %s, joined %s", self.server, self.port, self._current_nick, self.channel) + return True + + async def disconnect(self) -> None: + """Quit and close the connection.""" + self._mark_disconnected() + if self._writer and not self._writer.is_closing(): + try: + await self._send_raw("QUIT :Hermes Agent shutting down") + await asyncio.sleep(0.5) + except Exception: + pass + try: + self._writer.close() + await self._writer.wait_closed() + except Exception: + pass + + if self._recv_task and not self._recv_task.done(): + self._recv_task.cancel() + try: + await self._recv_task + except asyncio.CancelledError: + pass + + self._reader = None + self._writer = None + self._registered = False + self._registration_event.clear() + + # ── Sending ─────────────────────────────────────────────────────────── + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + if not self._writer or self._writer.is_closing(): + return SendResult(success=False, error="Not connected") + + target = chat_id # channel name or nick for DMs + lines = self._split_message(content, target) + + for line in lines: + try: + await self._send_raw(f"PRIVMSG {target} :{line}") + # Basic rate limiting to avoid excess flood + await asyncio.sleep(0.3) + except Exception as e: + return SendResult(success=False, error=str(e)) + + return SendResult(success=True, message_id=str(int(time.time() * 1000))) + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """IRC has no typing indicator — no-op.""" + pass + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + is_channel = chat_id.startswith("#") or chat_id.startswith("&") + return { + "name": chat_id, + "type": "group" if is_channel else "dm", + } + + # ── Message splitting ───────────────────────────────────────────────── + + def _split_message(self, content: str, target: str) -> List[str]: + """Split a long message into IRC-safe chunks. + + IRC has a ~512 byte line limit. After accounting for protocol + overhead (``PRIVMSG :``), we split content into chunks. + """ + # Strip markdown formatting that doesn't render in IRC + content = self._strip_markdown(content) + + overhead = len(f"PRIVMSG {target} :".encode("utf-8")) + 2 # +2 for \r\n + max_bytes = 510 - overhead + max_chars = min(self.max_message_length, max_bytes) + + lines: List[str] = [] + for paragraph in content.split("\n"): + if not paragraph.strip(): + continue + while len(paragraph) > max_chars: + # Find a space to break at + split_at = paragraph.rfind(" ", 0, max_chars) + if split_at < max_chars // 3: + split_at = max_chars + lines.append(paragraph[:split_at]) + paragraph = paragraph[split_at:].lstrip() + if paragraph.strip(): + lines.append(paragraph) + + return lines if lines else [""] + + @staticmethod + def _strip_markdown(text: str) -> str: + """Convert basic markdown to plain text for IRC.""" + # Bold: **text** or __text__ → text + text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) + text = re.sub(r"__(.+?)__", r"\1", text) + # Italic: *text* or _text_ → text + text = re.sub(r"\*(.+?)\*", r"\1", text) + text = re.sub(r"(? None: + """Send a raw IRC protocol line.""" + if not self._writer or self._writer.is_closing(): + return + encoded = (line + "\r\n").encode("utf-8") + self._writer.write(encoded) + await self._writer.drain() + + async def _receive_loop(self) -> None: + """Main receive loop — reads lines and dispatches them.""" + buffer = b"" + try: + while self._reader and not self._reader.at_eof(): + data = await self._reader.read(4096) + if not data: + break + buffer += data + while b"\r\n" in buffer: + line, buffer = buffer.split(b"\r\n", 1) + try: + decoded = line.decode("utf-8", errors="replace") + await self._handle_line(decoded) + except Exception as e: + logger.warning("IRC: error handling line: %s", e) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error("IRC: receive loop error: %s", e) + finally: + if self.is_connected: + logger.warning("IRC: connection lost, marking disconnected") + self._set_fatal_error("connection_lost", "IRC connection closed unexpectedly", retryable=True) + await self._notify_fatal_error() + + async def _handle_line(self, raw: str) -> None: + """Dispatch a single IRC protocol line.""" + msg = _parse_irc_message(raw) + command = msg["command"] + params = msg["params"] + + # PING/PONG keepalive + if command == "PING": + payload = params[0] if params else "" + await self._send_raw(f"PONG :{payload}") + return + + # RPL_WELCOME (001) — registration complete + if command == "001": + self._registered = True + self._registration_event.set() + if params: + # Server may confirm our nick in the first param + self._current_nick = params[0] + return + + # ERR_NICKNAMEINUSE (433) — nick collision during registration + if command == "433": + self._current_nick = self.nickname + "_" + await self._send_raw(f"NICK {self._current_nick}") + return + + # PRIVMSG — incoming message (channel or DM) + if command == "PRIVMSG" and len(params) >= 2: + sender_nick = _extract_nick(msg["prefix"]) + target = params[0] + text = params[1] + + # Ignore our own messages + if sender_nick.lower() == self._current_nick.lower(): + return + + # CTCP ACTION (/me) — convert to text + if text.startswith("\x01ACTION ") and text.endswith("\x01"): + text = f"* {sender_nick} {text[8:-1]}" + + # Ignore other CTCP + if text.startswith("\x01"): + return + + # Determine if this is a channel message or DM + is_channel = target.startswith("#") or target.startswith("&") + chat_id = target if is_channel else sender_nick + chat_type = "group" if is_channel else "dm" + + # In channels, only respond if addressed (nick: or nick,) + if is_channel: + addressed = False + for prefix in (f"{self._current_nick}:", f"{self._current_nick},", + f"{self._current_nick} "): + if text.lower().startswith(prefix.lower()): + text = text[len(prefix):].strip() + addressed = True + break + if not addressed: + return # Ignore unaddressed channel messages + + # Auth check + if self.allowed_users and sender_nick not in self.allowed_users: + logger.debug("IRC: ignoring message from unauthorized user %s", sender_nick) + return + + await self._dispatch_message( + text=text, + chat_id=chat_id, + chat_type=chat_type, + user_id=sender_nick, + user_name=sender_nick, + ) + + # NICK — track our own nick changes + if command == "NICK" and _extract_nick(msg["prefix"]).lower() == self._current_nick.lower(): + if params: + self._current_nick = params[0] + + async def _dispatch_message( + self, + text: str, + chat_id: str, + chat_type: str, + user_id: str, + user_name: str, + ) -> None: + """Build a MessageEvent and hand it to the base class handler.""" + if not self._message_handler: + return + + source = self.build_source( + chat_id=chat_id, + chat_name=chat_id, + chat_type=chat_type, + user_id=user_id, + user_name=user_name, + ) + + event = MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=source, + message_id=str(int(time.time() * 1000)), + timestamp=__import__("datetime").datetime.now(), + ) + + await self.handle_message(event) + + +# --------------------------------------------------------------------------- +# Plugin registration +# --------------------------------------------------------------------------- + +def check_requirements() -> bool: + """Check if IRC is configured. + + Only requires the server and channel — no external pip packages needed. + """ + server = os.getenv("IRC_SERVER", "") + channel = os.getenv("IRC_CHANNEL", "") + # Also accept config.yaml-only configuration (no env vars). + # The gateway passes PlatformConfig; we just check env for the + # hermes setup / requirements check path. + return bool(server and channel) + + +def validate_config(config) -> bool: + """Validate that the platform config has enough info to connect.""" + extra = getattr(config, "extra", {}) or {} + server = os.getenv("IRC_SERVER") or extra.get("server", "") + channel = os.getenv("IRC_CHANNEL") or extra.get("channel", "") + return bool(server and channel) + + +def register(ctx): + """Plugin entry point — called by the Hermes plugin system.""" + ctx.register_platform( + name="irc", + label="IRC", + adapter_factory=lambda cfg: IRCAdapter(cfg), + check_fn=check_requirements, + validate_config=validate_config, + required_env=["IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"], + install_hint="No extra packages needed (stdlib only)", + ) diff --git a/tests/gateway/test_irc_adapter.py b/tests/gateway/test_irc_adapter.py new file mode 100644 index 0000000000..40f3e6a892 --- /dev/null +++ b/tests/gateway/test_irc_adapter.py @@ -0,0 +1,380 @@ +"""Tests for the IRC platform adapter plugin.""" + +import asyncio +import os +import sys +import pytest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +# Ensure the plugins directory is on sys.path for direct import +_REPO_ROOT = Path(__file__).resolve().parents[2] +_IRC_PLUGIN_DIR = _REPO_ROOT / "plugins" / "platforms" / "irc" +if str(_IRC_PLUGIN_DIR) not in sys.path: + sys.path.insert(0, str(_IRC_PLUGIN_DIR)) + + +# ── IRC protocol helpers ───────────────────────────────────────────────── + +from adapter import _parse_irc_message, _extract_nick + + +class TestIRCProtocolHelpers: + + def test_parse_simple_command(self): + msg = _parse_irc_message("PING :server.example.com") + assert msg["command"] == "PING" + assert msg["params"] == ["server.example.com"] + assert msg["prefix"] == "" + + def test_parse_prefixed_message(self): + msg = _parse_irc_message(":nick!user@host PRIVMSG #channel :Hello world") + assert msg["prefix"] == "nick!user@host" + assert msg["command"] == "PRIVMSG" + assert msg["params"] == ["#channel", "Hello world"] + + def test_parse_numeric_reply(self): + msg = _parse_irc_message(":server 001 hermes-bot :Welcome to IRC") + assert msg["prefix"] == "server" + assert msg["command"] == "001" + assert msg["params"] == ["hermes-bot", "Welcome to IRC"] + + def test_parse_nick_collision(self): + msg = _parse_irc_message(":server 433 * hermes-bot :Nickname is already in use") + assert msg["command"] == "433" + + def test_extract_nick_full_prefix(self): + assert _extract_nick("nick!user@host") == "nick" + + def test_extract_nick_bare(self): + assert _extract_nick("server.example.com") == "server.example.com" + + +# ── IRC Adapter ────────────────────────────────────────────────────────── + +from adapter import IRCAdapter, check_requirements, validate_config + + +class TestIRCAdapterInit: + + def test_init_from_env(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.setenv("IRC_PORT", "6667") + monkeypatch.setenv("IRC_NICKNAME", "testbot") + monkeypatch.setenv("IRC_CHANNEL", "#test") + monkeypatch.setenv("IRC_USE_TLS", "false") + + from gateway.config import PlatformConfig + cfg = PlatformConfig(enabled=True) + adapter = IRCAdapter(cfg) + + assert adapter.server == "irc.test.net" + assert adapter.port == 6667 + assert adapter.nickname == "testbot" + assert adapter.channel == "#test" + assert adapter.use_tls is False + + def test_init_from_config_extra(self, monkeypatch): + # Clear any env vars + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "irc.libera.chat", + "port": 6697, + "nickname": "hermes", + "channel": "#hermes-dev", + "use_tls": True, + }, + ) + adapter = IRCAdapter(cfg) + + assert adapter.server == "irc.libera.chat" + assert adapter.port == 6697 + assert adapter.nickname == "hermes" + assert adapter.channel == "#hermes-dev" + assert adapter.use_tls is True + + def test_env_overrides_config(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "env-server.net") + + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={"server": "config-server.net", "channel": "#ch"}, + ) + adapter = IRCAdapter(cfg) + assert adapter.server == "env-server.net" + + +class TestIRCAdapterSend: + + @pytest.fixture + def adapter(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "localhost", + "port": 6667, + "nickname": "testbot", + "channel": "#test", + "use_tls": False, + }, + ) + return IRCAdapter(cfg) + + @pytest.mark.asyncio + async def test_send_not_connected(self, adapter): + result = await adapter.send("#test", "hello") + assert result.success is False + assert "Not connected" in result.error + + @pytest.mark.asyncio + async def test_send_success(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + result = await adapter.send("#test", "hello world") + assert result.success is True + assert result.message_id is not None + # Verify PRIVMSG was sent + writer.write.assert_called() + sent_data = writer.write.call_args[0][0] + assert b"PRIVMSG #test :hello world" in sent_data + + @pytest.mark.asyncio + async def test_send_splits_long_messages(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + long_msg = "x" * 1000 + result = await adapter.send("#test", long_msg) + assert result.success is True + # Should have been split into multiple PRIVMSG calls + assert writer.write.call_count > 1 + + +class TestIRCAdapterMessageParsing: + + @pytest.fixture + def adapter(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "localhost", + "port": 6667, + "nickname": "hermes", + "channel": "#test", + "use_tls": False, + }, + ) + a = IRCAdapter(cfg) + a._current_nick = "hermes" + a._registered = True + return a + + @pytest.mark.asyncio + async def test_handle_ping(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + await adapter._handle_line("PING :test-server") + sent = writer.write.call_args[0][0] + assert b"PONG :test-server" in sent + + @pytest.mark.asyncio + async def test_handle_welcome(self, adapter): + adapter._registered = False + adapter._registration_event = asyncio.Event() + + await adapter._handle_line(":server 001 hermes :Welcome to IRC") + assert adapter._registered is True + assert adapter._registration_event.is_set() + + @pytest.mark.asyncio + async def test_handle_nick_collision(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + await adapter._handle_line(":server 433 * hermes :Nickname in use") + assert adapter._current_nick == "hermes_" + sent = writer.write.call_args[0][0] + assert b"NICK hermes_" in sent + + @pytest.mark.asyncio + async def test_handle_addressed_channel_message(self, adapter): + """Messages addressed to the bot (nick: msg) should be dispatched.""" + handler = AsyncMock(return_value="response") + adapter._message_handler = handler + + # Mock handle_message to capture the event + dispatched = [] + original_dispatch = adapter._dispatch_message + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + + await adapter._handle_line(":user!u@host PRIVMSG #test :hermes: hello there") + assert len(dispatched) == 1 + assert dispatched[0]["text"] == "hello there" + assert dispatched[0]["chat_id"] == "#test" + + @pytest.mark.asyncio + async def test_ignores_unaddressed_channel_message(self, adapter): + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":user!u@host PRIVMSG #test :just talking") + assert len(dispatched) == 0 + + @pytest.mark.asyncio + async def test_handle_dm(self, adapter): + """DMs (target == bot nick) should always be dispatched.""" + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":user!u@host PRIVMSG hermes :private message") + assert len(dispatched) == 1 + assert dispatched[0]["text"] == "private message" + assert dispatched[0]["chat_type"] == "dm" + assert dispatched[0]["chat_id"] == "user" + + @pytest.mark.asyncio + async def test_ignores_own_messages(self, adapter): + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":hermes!bot@host PRIVMSG #test :my own msg") + assert len(dispatched) == 0 + + @pytest.mark.asyncio + async def test_ctcp_action_converted(self, adapter): + """CTCP ACTION (/me) should be converted to text.""" + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":user!u@host PRIVMSG hermes :\x01ACTION waves\x01") + assert len(dispatched) == 1 + assert dispatched[0]["text"] == "* user waves" + + +class TestIRCAdapterMarkdown: + + def test_strip_bold(self): + assert IRCAdapter._strip_markdown("**bold**") == "bold" + + def test_strip_italic(self): + assert IRCAdapter._strip_markdown("*italic*") == "italic" + + def test_strip_code(self): + assert IRCAdapter._strip_markdown("`code`") == "code" + + def test_strip_link(self): + result = IRCAdapter._strip_markdown("[click here](https://example.com)") + assert result == "click here (https://example.com)" + + def test_strip_image(self): + result = IRCAdapter._strip_markdown("![alt](https://example.com/img.png)") + assert result == "https://example.com/img.png" + + +# ── Requirements / validation ──────────────────────────────────────────── + + +class TestIRCRequirements: + + def test_check_requirements_with_env(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.setenv("IRC_CHANNEL", "#test") + assert check_requirements() is True + + def test_check_requirements_missing_server(self, monkeypatch): + monkeypatch.delenv("IRC_SERVER", raising=False) + monkeypatch.setenv("IRC_CHANNEL", "#test") + assert check_requirements() is False + + def test_check_requirements_missing_channel(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.delenv("IRC_CHANNEL", raising=False) + assert check_requirements() is False + + def test_validate_config_from_extra(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_CHANNEL"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig(extra={"server": "irc.test.net", "channel": "#test"}) + assert validate_config(cfg) is True + + def test_validate_config_missing(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_CHANNEL"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig(extra={}) + assert validate_config(cfg) is False + + +# ── Plugin registration ────────────────────────────────────────────────── + + +class TestIRCPluginRegistration: + """Test the register() entry point.""" + + def test_register_adds_to_registry(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.setenv("IRC_CHANNEL", "#test") + + from gateway.platform_registry import platform_registry + + # Clean up if already registered + platform_registry.unregister("irc") + + from adapter import register + + ctx = MagicMock() + register(ctx) + ctx.register_platform.assert_called_once() + call_kwargs = ctx.register_platform.call_args + assert call_kwargs[1]["name"] == "irc" or call_kwargs[0][0] == "irc" if call_kwargs[0] else call_kwargs[1]["name"] == "irc" diff --git a/tests/gateway/test_platform_registry.py b/tests/gateway/test_platform_registry.py new file mode 100644 index 0000000000..08451ae7cb --- /dev/null +++ b/tests/gateway/test_platform_registry.py @@ -0,0 +1,267 @@ +"""Tests for the platform adapter registry and dynamic Platform enum.""" + +import os +import pytest +from unittest.mock import MagicMock, patch +from dataclasses import dataclass + +from gateway.platform_registry import PlatformRegistry, PlatformEntry, platform_registry +from gateway.config import Platform, PlatformConfig, GatewayConfig + + +# ── Platform enum dynamic members ───────────────────────────────────────── + + +class TestPlatformEnumDynamic: + """Test that Platform enum accepts unknown values for plugin platforms.""" + + def test_builtin_members_still_work(self): + assert Platform.TELEGRAM.value == "telegram" + assert Platform("telegram") is Platform.TELEGRAM + + def test_dynamic_member_created(self): + p = Platform("irc") + assert p.value == "irc" + assert p.name == "IRC" + + def test_dynamic_member_identity_stable(self): + """Same value returns same object (cached).""" + a = Platform("irc") + b = Platform("irc") + assert a is b + + def test_dynamic_member_case_normalised(self): + """Mixed case normalised to lowercase.""" + a = Platform("IRC") + b = Platform("irc") + assert a is b + assert a.value == "irc" + + def test_dynamic_member_with_hyphens(self): + p = Platform("my-platform") + assert p.value == "my-platform" + assert p.name == "MY_PLATFORM" + + def test_dynamic_member_rejects_non_string(self): + with pytest.raises(ValueError): + Platform(123) + + def test_dynamic_member_rejects_empty(self): + with pytest.raises(ValueError): + Platform("") + + def test_dynamic_member_rejects_whitespace_only(self): + with pytest.raises(ValueError): + Platform(" ") + + +# ── PlatformRegistry ────────────────────────────────────────────────────── + + +class TestPlatformRegistry: + """Test the PlatformRegistry itself.""" + + def _make_entry(self, name="test", check_ok=True, validate_ok=True, factory_ok=True): + adapter_mock = MagicMock() + return PlatformEntry( + name=name, + label=name.title(), + adapter_factory=lambda cfg, _m=adapter_mock: _m if factory_ok else (_ for _ in ()).throw(RuntimeError("factory error")), + check_fn=lambda: check_ok, + validate_config=lambda cfg: validate_ok, + required_env=[], + source="plugin", + ), adapter_mock + + def test_register_and_get(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("alpha") + reg.register(entry) + assert reg.get("alpha") is entry + assert reg.is_registered("alpha") + + def test_get_unknown_returns_none(self): + reg = PlatformRegistry() + assert reg.get("nonexistent") is None + + def test_unregister(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("beta") + reg.register(entry) + assert reg.unregister("beta") is True + assert reg.get("beta") is None + assert reg.unregister("beta") is False # already gone + + def test_create_adapter_success(self): + reg = PlatformRegistry() + entry, mock_adapter = self._make_entry("gamma") + reg.register(entry) + result = reg.create_adapter("gamma", MagicMock()) + assert result is mock_adapter + + def test_create_adapter_unknown_name(self): + reg = PlatformRegistry() + assert reg.create_adapter("unknown", MagicMock()) is None + + def test_create_adapter_check_fails(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("delta", check_ok=False) + reg.register(entry) + assert reg.create_adapter("delta", MagicMock()) is None + + def test_create_adapter_validate_fails(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("epsilon", validate_ok=False) + reg.register(entry) + assert reg.create_adapter("epsilon", MagicMock()) is None + + def test_create_adapter_factory_exception(self): + reg = PlatformRegistry() + entry = PlatformEntry( + name="broken", + label="Broken", + adapter_factory=lambda cfg: (_ for _ in ()).throw(RuntimeError("boom")), + check_fn=lambda: True, + validate_config=None, + source="plugin", + ) + reg.register(entry) + # factory raises → create_adapter returns None instead of propagating + assert reg.create_adapter("broken", MagicMock()) is None + + def test_create_adapter_no_validate(self): + """When validate_config is None, skip validation.""" + reg = PlatformRegistry() + mock_adapter = MagicMock() + entry = PlatformEntry( + name="novalidate", + label="NoValidate", + adapter_factory=lambda cfg: mock_adapter, + check_fn=lambda: True, + validate_config=None, + source="plugin", + ) + reg.register(entry) + assert reg.create_adapter("novalidate", MagicMock()) is mock_adapter + + def test_all_entries(self): + reg = PlatformRegistry() + e1, _ = self._make_entry("one") + e2, _ = self._make_entry("two") + reg.register(e1) + reg.register(e2) + names = {e.name for e in reg.all_entries()} + assert names == {"one", "two"} + + def test_plugin_entries(self): + reg = PlatformRegistry() + plugin_entry, _ = self._make_entry("plugged") + builtin_entry = PlatformEntry( + name="core", + label="Core", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + source="builtin", + ) + reg.register(plugin_entry) + reg.register(builtin_entry) + plugin_names = {e.name for e in reg.plugin_entries()} + assert plugin_names == {"plugged"} + + def test_re_register_replaces(self): + reg = PlatformRegistry() + entry1, mock1 = self._make_entry("dup") + entry2 = PlatformEntry( + name="dup", + label="Dup v2", + adapter_factory=lambda cfg: "v2", + check_fn=lambda: True, + source="plugin", + ) + reg.register(entry1) + reg.register(entry2) + assert reg.get("dup").label == "Dup v2" + + +# ── GatewayConfig integration ──────────────────────────────────────────── + + +class TestGatewayConfigPluginPlatform: + """Test that GatewayConfig parses and validates plugin platforms.""" + + def test_from_dict_accepts_plugin_platform(self): + data = { + "platforms": { + "telegram": {"enabled": True, "token": "test-token"}, + "irc": {"enabled": True, "extra": {"server": "irc.libera.chat"}}, + } + } + cfg = GatewayConfig.from_dict(data) + platform_values = {p.value for p in cfg.platforms} + assert "telegram" in platform_values + assert "irc" in platform_values + + def test_get_connected_platforms_includes_registered_plugin(self): + """Plugin platform with registry entry passes get_connected_platforms.""" + # Register a fake plugin platform + from gateway.platform_registry import platform_registry as _reg + + test_entry = PlatformEntry( + name="testplat", + label="TestPlat", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + validate_config=lambda cfg: bool(cfg.extra.get("token")), + source="plugin", + ) + _reg.register(test_entry) + try: + data = { + "platforms": { + "testplat": {"enabled": True, "extra": {"token": "abc"}}, + } + } + cfg = GatewayConfig.from_dict(data) + connected = cfg.get_connected_platforms() + connected_values = {p.value for p in connected} + assert "testplat" in connected_values + finally: + _reg.unregister("testplat") + + def test_get_connected_platforms_excludes_unregistered_plugin(self): + """Plugin platform without registry entry is excluded.""" + data = { + "platforms": { + "unknown_plugin": {"enabled": True, "extra": {"token": "abc"}}, + } + } + cfg = GatewayConfig.from_dict(data) + connected = cfg.get_connected_platforms() + connected_values = {p.value for p in connected} + assert "unknown_plugin" not in connected_values + + def test_get_connected_platforms_excludes_invalid_config(self): + """Plugin platform with failing validate_config is excluded.""" + from gateway.platform_registry import platform_registry as _reg + + test_entry = PlatformEntry( + name="badconfig", + label="BadConfig", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + validate_config=lambda cfg: False, # always fails + source="plugin", + ) + _reg.register(test_entry) + try: + data = { + "platforms": { + "badconfig": {"enabled": True, "extra": {}}, + } + } + cfg = GatewayConfig.from_dict(data) + connected = cfg.get_connected_platforms() + connected_values = {p.value for p in connected} + assert "badconfig" not in connected_values + finally: + _reg.unregister("badconfig") diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index b86d18575d..bb770e1a75 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -83,9 +83,12 @@ class TestSessionSourceRoundtrip: assert restored.chat_topic is None assert restored.chat_type == "dm" - def test_invalid_platform_raises(self): - with pytest.raises((ValueError, KeyError)): - SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"}) + def test_unknown_platform_accepted_for_plugins(self): + """Unknown platform names are now accepted (dynamic enum members for + plugin platforms), so from_dict should succeed rather than raise.""" + source = SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"}) + assert source.platform.value == "nonexistent" + assert source.chat_id == "1" class TestSessionSourceDescription: