From 04c1c5d53f555b9798d180802a288f0566f9acd7 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 11 Apr 2026 13:59:52 -0700 Subject: [PATCH] refactor: extract shared helpers to deduplicate repeated code patterns (#7917) * refactor: add shared helper modules for code deduplication New modules: - gateway/platforms/helpers.py: MessageDeduplicator, TextBatchAggregator, strip_markdown, ThreadParticipationTracker, redact_phone - hermes_cli/cli_output.py: print_info/success/warning/error, prompt helpers - tools/path_security.py: validate_within_dir, has_traversal_component - utils.py additions: safe_json_loads, read_json_file, read_jsonl, append_jsonl, env_str/lower/int/bool helpers - hermes_constants.py additions: get_config_path, get_skills_dir, get_logs_dir, get_env_path * refactor: migrate gateway adapters to shared helpers - MessageDeduplicator: discord, slack, dingtalk, wecom, weixin, mattermost - strip_markdown: bluebubbles, feishu, sms - redact_phone: sms, signal - ThreadParticipationTracker: discord, matrix - _acquire/_release_platform_lock: telegram, discord, slack, whatsapp, signal, weixin Net -316 lines across 19 files. * refactor: migrate CLI modules to shared helpers - tools_config.py: use cli_output print/prompt + curses_radiolist (-117 lines) - setup.py: use cli_output print helpers + curses_radiolist (-101 lines) - mcp_config.py: use cli_output prompt (-15 lines) - memory_setup.py: use curses_radiolist (-86 lines) Net -263 lines across 5 files. * refactor: migrate to shared utility helpers - safe_json_loads: agent/display.py (4 sites) - get_config_path: skill_utils.py, hermes_logging.py, hermes_time.py - get_skills_dir: skill_utils.py, prompt_builder.py - Token estimation dedup: skills_tool.py imports from model_metadata - Path security: skills_tool, cronjob_tools, skill_manager_tool, credential_files - Non-atomic YAML writes: doctor.py, config.py now use atomic_yaml_write - Platform dict: new platforms.py, skills_config + tools_config derive from it - Anthropic key: new get_anthropic_key() in auth.py, used by doctor/status/config/main * test: update tests for shared helper migrations - test_dingtalk: use _dedup.is_duplicate() instead of _is_duplicate() - test_mattermost: use _dedup instead of _seen_posts/_prune_seen - test_signal: import redact_phone from helpers instead of signal - test_discord_connect: _platform_lock_identity instead of _token_lock_identity - test_telegram_conflict: updated lock error message format - test_skill_manager_tool: 'escapes' instead of 'boundary' in error msgs --- agent/display.py | 25 +- agent/prompt_builder.py | 5 +- agent/skill_utils.py | 12 +- gateway/platforms/base.py | 31 ++- gateway/platforms/bluebubbles.py | 18 +- gateway/platforms/dingtalk.py | 25 +- gateway/platforms/discord.py | 111 +------- gateway/platforms/feishu.py | 16 +- gateway/platforms/helpers.py | 261 ++++++++++++++++++ gateway/platforms/matrix.py | 54 +--- gateway/platforms/mattermost.py | 23 +- gateway/platforms/signal.py | 45 +-- gateway/platforms/slack.py | 41 +-- gateway/platforms/sms.py | 37 +-- gateway/platforms/telegram.py | 34 +-- gateway/platforms/wecom.py | 26 +- gateway/platforms/weixin.py | 41 +-- gateway/platforms/whatsapp.py | 34 +-- hermes_cli/auth.py | 22 ++ hermes_cli/cli_output.py | 79 ++++++ hermes_cli/config.py | 7 +- hermes_cli/doctor.py | 7 +- hermes_cli/main.py | 9 +- hermes_cli/mcp_config.py | 15 +- hermes_cli/memory_setup.py | 86 +----- hermes_cli/platforms.py | 45 +++ hermes_cli/setup.py | 101 +------ hermes_cli/skills_config.py | 23 +- hermes_cli/status.py | 7 +- hermes_cli/tools_config.py | 142 ++-------- hermes_constants.py | 27 ++ hermes_logging.py | 4 +- hermes_time.py | 5 +- tests/e2e/conftest.py | 3 +- tests/gateway/test_dingtalk.py | 23 +- tests/gateway/test_discord_connect.py | 2 +- tests/gateway/test_discord_free_response.py | 6 +- .../test_discord_thread_persistence.py | 37 +-- tests/gateway/test_matrix_mention.py | 62 ++--- tests/gateway/test_mattermost.py | 22 +- tests/gateway/test_signal.py | 12 +- tests/gateway/test_telegram_conflict.py | 14 +- tests/tools/test_skill_manager_tool.py | 6 +- tools/credential_files.py | 36 ++- tools/cronjob_tools.py | 8 +- tools/path_security.py | 43 +++ tools/skill_manager_tool.py | 17 +- tools/skills_tool.py | 37 +-- utils.py | 90 +++++- 49 files changed, 887 insertions(+), 949 deletions(-) create mode 100644 gateway/platforms/helpers.py create mode 100644 hermes_cli/cli_output.py create mode 100644 hermes_cli/platforms.py create mode 100644 tools/path_security.py diff --git a/agent/display.py b/agent/display.py index 604b7a298..182064576 100644 --- a/agent/display.py +++ b/agent/display.py @@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency. Used by AIAgent._execute_tool_calls for CLI feedback. """ -import json import logging import os import sys @@ -14,6 +13,8 @@ from dataclasses import dataclass, field from difflib import unified_diff from pathlib import Path +from utils import safe_json_loads + # ANSI escape codes for coloring tool failure indicators _RED = "\033[31m" _RESET = "\033[0m" @@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool: """Conservatively detect whether a tool result represents success.""" if not result: return False - try: - data = json.loads(result) - except (json.JSONDecodeError, TypeError): + data = safe_json_loads(result) + if data is None: return False if not isinstance(data, dict): return False @@ -423,10 +423,7 @@ def extract_edit_diff( ) -> str | None: """Extract a unified diff from a file-edit tool result.""" if tool_name == "patch" and result: - try: - data = json.loads(result) - except (json.JSONDecodeError, TypeError): - data = None + data = safe_json_loads(result) if isinstance(data, dict): diff = data.get("diff") if isinstance(diff, str) and diff.strip(): @@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str] return False, "" if tool_name == "terminal": - try: - data = json.loads(result) + data = safe_json_loads(result) + if isinstance(data, dict): exit_code = data.get("exit_code") if exit_code is not None and exit_code != 0: return True, f" [exit {exit_code}]" - except (json.JSONDecodeError, TypeError, AttributeError): - logger.debug("Could not parse terminal result as JSON for exit code check") return False, "" # Memory-specific: distinguish "full" from real errors if tool_name == "memory": - try: - data = json.loads(result) + data = safe_json_loads(result) + if isinstance(data, dict): if data.get("success") is False and "exceed the limit" in data.get("error", ""): return True, " [full]" - except (json.JSONDecodeError, TypeError, AttributeError): - logger.debug("Could not parse memory result as JSON for capacity check") # Generic heuristic for non-terminal tools lower = result[:500].lower() diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 08b8fe0a6..26d913a02 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -12,7 +12,7 @@ import threading from collections import OrderedDict from pathlib import Path -from hermes_constants import get_hermes_home +from hermes_constants import get_hermes_home, get_skills_dir from typing import Optional from agent.skill_utils import ( @@ -548,8 +548,7 @@ def build_skills_system_prompt( are read-only — they appear in the index but new skills are always created in the local dir. Local skills take precedence when names collide. """ - hermes_home = get_hermes_home() - skills_dir = hermes_home / "skills" + skills_dir = get_skills_dir() external_dirs = get_all_skills_dirs()[1:] # skip local (index 0) if not skills_dir.exists() and not external_dirs: diff --git a/agent/skill_utils.py b/agent/skill_utils.py index ba606b358..97ba92b73 100644 --- a/agent/skill_utils.py +++ b/agent/skill_utils.py @@ -12,7 +12,7 @@ import sys from pathlib import Path from typing import Any, Dict, List, Set, Tuple -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path, get_skills_dir logger = logging.getLogger(__name__) @@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]: Reads the config file directly (no CLI config imports) to stay lightweight. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if not config_path.exists(): return set() try: @@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]: path. Only directories that actually exist are returned. Duplicates and paths that resolve to the local ``~/.hermes/skills/`` are silently skipped. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if not config_path.exists(): return [] try: @@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]: if not isinstance(raw_dirs, list): return [] - local_skills = (get_hermes_home() / "skills").resolve() + local_skills = get_skills_dir().resolve() seen: Set[Path] = set() result: List[Path] = [] @@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]: The local dir is always first (and always included even if it doesn't exist yet — callers handle that). External dirs follow in config order. """ - dirs = [get_hermes_home() / "skills"] + dirs = [get_skills_dir()] dirs.extend(get_external_skills_dirs()) return dirs @@ -384,7 +384,7 @@ def resolve_skill_config_values( current values (or the declared default if the key isn't set). Path values are expanded via ``os.path.expanduser``. """ - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() config: Dict[str, Any] = {} if config_path.exists(): try: diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 04f0c1deb..352aecb33 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC): result = handler(self) if asyncio.iscoroutine(result): await result - + + def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool: + """Acquire a scoped lock for this adapter. Returns True on success.""" + from gateway.status import acquire_scoped_lock + self._platform_lock_scope = scope + self._platform_lock_identity = identity + acquired, existing = acquire_scoped_lock( + scope, identity, metadata={'platform': self.platform.value} + ) + if acquired: + return True + owner_pid = existing.get('pid') if isinstance(existing, dict) else None + message = ( + f'{resource_desc} already in use' + + (f' (PID {owner_pid})' if owner_pid else '') + + '. Stop the other gateway first.' + ) + logger.error('[%s] %s', self.name, message) + self._set_fatal_error(f'{scope}_lock', message, retryable=False) + return False + + def _release_platform_lock(self) -> None: + """Release the scoped lock acquired by _acquire_platform_lock.""" + identity = getattr(self, '_platform_lock_identity', None) + if not identity: + return + from gateway.status import release_scoped_lock + release_scoped_lock(self._platform_lock_scope, identity) + self._platform_lock_identity = None + @property def name(self) -> str: """Human-readable name for this adapter.""" diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py index f50cd9503..115000996 100644 --- a/gateway/platforms/bluebubbles.py +++ b/gateway/platforms/bluebubbles.py @@ -30,6 +30,7 @@ from gateway.platforms.base import ( cache_audio_from_bytes, cache_document_from_bytes, ) +from gateway.platforms.helpers import strip_markdown logger = logging.getLogger(__name__) @@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str: return value.rstrip("/") -def _strip_markdown(text: str) -> str: - """Strip common markdown formatting for iMessage plain-text delivery.""" - text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL) - text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL) - text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text) - text = re.sub(r"`(.+?)`", r"\1", text) - text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) - text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text) - text = re.sub(r"\n{3,}", "\n\n", text) - return text.strip() + # --------------------------------------------------------------------------- @@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: - text = _strip_markdown(content or "") + text = strip_markdown(content or "") if not text: return SendResult(success=False, error="BlueBubbles send requires text") chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH) @@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): return info def format_message(self, content: str) -> str: - return _strip_markdown(content) + return strip_markdown(content) # ------------------------------------------------------------------ # Inbound attachment downloading (from #4588) diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py index e83b902df..5d50deca5 100644 --- a/gateway/platforms/dingtalk.py +++ b/gateway/platforms/dingtalk.py @@ -42,6 +42,7 @@ except ImportError: httpx = None # type: ignore[assignment] from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -52,8 +53,6 @@ from gateway.platforms.base import ( logger = logging.getLogger(__name__) MAX_MESSAGE_LENGTH = 20000 -DEDUP_WINDOW_SECONDS = 300 -DEDUP_MAX_SIZE = 1000 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] _SESSION_WEBHOOKS_MAX = 500 _DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/') @@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter): self._stream_task: Optional[asyncio.Task] = None self._http_client: Optional["httpx.AsyncClient"] = None - # Message deduplication: msg_id -> timestamp - self._seen_messages: Dict[str, float] = {} + # Message deduplication + self._dedup = MessageDeduplicator(max_size=1000) # Map chat_id -> session_webhook for reply routing self._session_webhooks: Dict[str, str] = {} @@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter): self._stream_client = None self._session_webhooks.clear() - self._seen_messages.clear() + self._dedup.clear() logger.info("[%s] Disconnected", self.name) # -- Inbound message processing ----------------------------------------- @@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter): async def _on_message(self, message: "ChatbotMessage") -> None: """Process an incoming DingTalk chatbot message.""" msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex - if self._is_duplicate(msg_id): + if self._dedup.is_duplicate(msg_id): logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id) return @@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter): content = " ".join(parts).strip() return content - # -- Deduplication ------------------------------------------------------ - - def _is_duplicate(self, msg_id: str) -> bool: - """Check and record a message ID. Returns True if already seen.""" - now = time.time() - if len(self._seen_messages) > DEDUP_MAX_SIZE: - cutoff = now - DEDUP_WINDOW_SECONDS - self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff} - - if msg_id in self._seen_messages: - return True - self._seen_messages[msg_id] = now - return False - # -- Outbound messaging ------------------------------------------------- async def send( diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index dcf05a162..b1d07e5d6 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig import re +from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter): # Track threads where the bot has participated so follow-up messages # in those threads don't require @mention. Persisted to disk so the # set survives gateway restarts. - self._bot_participated_threads: set = self._load_participated_threads() + self._threads = ThreadParticipationTracker("discord") # Persistent typing indicator loops per channel (DMs don't reliably # show the standard typing gateway event for bots) self._typing_tasks: Dict[str, asyncio.Task] = {} self._bot_task: Optional[asyncio.Task] = None - # Cap to prevent unbounded growth (Discord threads get archived). - self._MAX_TRACKED_THREADS = 500 - # Dedup cache: message_id → timestamp. Prevents duplicate bot - # responses when Discord RESUME replays events after reconnects. - self._seen_messages: Dict[str, float] = {} - self._SEEN_TTL = 300 # 5 minutes - self._SEEN_MAX = 2000 # prune threshold + # Dedup cache: prevents duplicate bot responses when Discord + # RESUME replays events after reconnects. + self._dedup = MessageDeduplicator() # Reply threading mode: "off" (no replies), "first" (reply on first # chunk only, default), "all" (reply-reference on every chunk). self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first' @@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter): return False try: - # Acquire scoped lock to prevent duplicate bot token usage - from gateway.status import acquire_scoped_lock - self._token_lock_identity = self.config.token - acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'}) - if not acquired: - owner_pid = existing.get('pid') if isinstance(existing, dict) else None - message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' - logger.error('[%s] %s', self.name, message) - self._set_fatal_error('discord_token_lock', message, retryable=False) + if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'): return False - # Parse allowed user entries (may contain usernames or IDs) allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "") if allowed_env: @@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter): @self._client.event async def on_message(message: DiscordMessage): # Dedup: Discord RESUME replays events after reconnects (#4777) - msg_id = str(message.id) - now = time.time() - if msg_id in adapter_self._seen_messages: + if adapter_self._dedup.is_duplicate(str(message.id)): return - adapter_self._seen_messages[msg_id] = now - if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX: - cutoff = now - adapter_self._SEEN_TTL - adapter_self._seen_messages = { - k: v for k, v in adapter_self._seen_messages.items() - if v > cutoff - } # Always ignore our own messages if message.author == self._client.user: @@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter): except asyncio.TimeoutError: logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() return False except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() return False async def disconnect(self) -> None: @@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter): self._client = None self._ready_event.clear() - # Release the token lock - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('discord-bot-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() logger.info("[%s] Disconnected", self.name) @@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so follow-ups don't require @mention if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) # If a message was provided, kick off a new Hermes session in the thread starter = (message or "").strip() @@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter): return f"{parent_name} / {thread_name}" return thread_name - # ------------------------------------------------------------------ - # Thread participation persistence - # ------------------------------------------------------------------ - - @staticmethod - def _thread_state_path() -> Path: - """Path to the persisted thread participation set.""" - from hermes_cli.config import get_hermes_home - return get_hermes_home() / "discord_threads.json" - - @classmethod - def _load_participated_threads(cls) -> set: - """Load persisted thread IDs from disk.""" - path = cls._thread_state_path() - try: - if path.exists(): - data = json.loads(path.read_text(encoding="utf-8")) - if isinstance(data, list): - return set(data) - except Exception as e: - logger.debug("Could not load discord thread state: %s", e) - return set() - - def _save_participated_threads(self) -> None: - """Persist the current thread set to disk (best-effort).""" - path = self._thread_state_path() - try: - # Trim to most recent entries if over cap - thread_list = list(self._bot_participated_threads) - if len(thread_list) > self._MAX_TRACKED_THREADS: - thread_list = thread_list[-self._MAX_TRACKED_THREADS:] - self._bot_participated_threads = set(thread_list) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(thread_list), encoding="utf-8") - except Exception as e: - logger.debug("Could not save discord thread state: %s", e) - - def _track_thread(self, thread_id: str) -> None: - """Add a thread to the participation set and persist.""" - if thread_id not in self._bot_participated_threads: - self._bot_participated_threads.add(thread_id) - self._save_participated_threads() - async def _handle_message(self, message: DiscordMessage) -> None: """Handle incoming Discord messages.""" # In server channels (not DMs), require the bot to be @mentioned @@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter): # Skip the mention check if the message is in a thread where # the bot has previously participated (auto-created or replied in). - in_bot_thread = is_thread and thread_id in self._bot_participated_threads + in_bot_thread = is_thread and thread_id in self._threads if require_mention and not is_free_channel and not in_bot_thread: if self._client.user not in message.mentions: @@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter): is_thread = True thread_id = str(thread.id) auto_threaded_channel = thread - self._track_thread(thread_id) + self._threads.mark(thread_id) # Determine message type msg_type = MessageType.TEXT @@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter): # Track thread participation so the bot won't require @mention for # follow-up messages in threads it has already engaged in. if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) # Only batch plain text messages — commands, media, etc. dispatch # immediately since they won't be split by the Discord client. diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py index a88c7e52b..16f5467b2 100644 --- a/gateway/platforms/feishu.py +++ b/gateway/platforms/feishu.py @@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str: def _strip_markdown_to_plain_text(text: str) -> str: + """Strip markdown formatting to plain text for Feishu text fallbacks. + + Delegates common markdown stripping to the shared helper and adds + Feishu-specific patterns (blockquotes, strikethrough, underline tags, + horizontal rules, \\r\\n normalisation). + """ + from gateway.platforms.helpers import strip_markdown plain = text.replace("\r\n", "\n") plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain) - plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE) plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE) plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE) - plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain) - plain = re.sub(r"`([^`\n]+)`", r"\1", plain) - plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain) - plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain) plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain) plain = re.sub(r"([\s\S]*?)", r"\1", plain) - plain = re.sub(r"\n{3,}", "\n\n", plain) - return plain.strip() + plain = strip_markdown(plain) + return plain def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]: diff --git a/gateway/platforms/helpers.py b/gateway/platforms/helpers.py new file mode 100644 index 000000000..c834dd89c --- /dev/null +++ b/gateway/platforms/helpers.py @@ -0,0 +1,261 @@ +"""Shared helper classes for gateway platform adapters. + +Extracts common patterns that were duplicated across 5-7 adapters: +message deduplication, text batch aggregation, markdown stripping, +and thread participation tracking. +""" + +import asyncio +import json +import logging +import re +import time +from pathlib import Path +from typing import TYPE_CHECKING, Dict, Optional + +if TYPE_CHECKING: + from gateway.platforms.base import BasePlatformAdapter, MessageEvent + +logger = logging.getLogger(__name__) + + +# ─── Message Deduplication ──────────────────────────────────────────────────── + + +class MessageDeduplicator: + """TTL-based message deduplication cache. + + Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern + previously duplicated in discord, slack, dingtalk, wecom, weixin, + mattermost, and feishu adapters. + + Usage:: + + self._dedup = MessageDeduplicator() + + # In message handler: + if self._dedup.is_duplicate(msg_id): + return + """ + + def __init__(self, max_size: int = 2000, ttl_seconds: float = 300): + self._seen: Dict[str, float] = {} + self._max_size = max_size + self._ttl = ttl_seconds + + def is_duplicate(self, msg_id: str) -> bool: + """Return True if *msg_id* was already seen within the TTL window.""" + if not msg_id: + return False + now = time.time() + if msg_id in self._seen: + return True + self._seen[msg_id] = now + if len(self._seen) > self._max_size: + cutoff = now - self._ttl + self._seen = {k: v for k, v in self._seen.items() if v > cutoff} + return False + + def clear(self): + """Clear all tracked messages.""" + self._seen.clear() + + +# ─── Text Batch Aggregation ────────────────────────────────────────────────── + + +class TextBatchAggregator: + """Aggregates rapid-fire text events into single messages. + + Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern + previously duplicated in telegram, discord, matrix, wecom, and feishu. + + Usage:: + + self._text_batcher = TextBatchAggregator( + handler=self._message_handler, + batch_delay=0.6, + split_threshold=1900, + ) + + # In message dispatch: + if msg_type == MessageType.TEXT and self._text_batcher.is_enabled(): + self._text_batcher.enqueue(event, session_key) + return + """ + + def __init__( + self, + handler, + *, + batch_delay: float = 0.6, + split_delay: float = 2.0, + split_threshold: int = 4000, + ): + self._handler = handler + self._batch_delay = batch_delay + self._split_delay = split_delay + self._split_threshold = split_threshold + self._pending: Dict[str, "MessageEvent"] = {} + self._pending_tasks: Dict[str, asyncio.Task] = {} + + def is_enabled(self) -> bool: + """Return True if batching is active (delay > 0).""" + return self._batch_delay > 0 + + def enqueue(self, event: "MessageEvent", key: str) -> None: + """Add *event* to the pending batch for *key*.""" + chunk_len = len(event.text or "") + existing = self._pending.get(key) + if not existing: + event._last_chunk_len = chunk_len # type: ignore[attr-defined] + self._pending[key] = event + else: + existing.text = f"{existing.text}\n{event.text}" + existing._last_chunk_len = chunk_len # type: ignore[attr-defined] + + # Cancel prior flush timer, start a new one + prior = self._pending_tasks.get(key) + if prior and not prior.done(): + prior.cancel() + self._pending_tasks[key] = asyncio.create_task(self._flush(key)) + + async def _flush(self, key: str) -> None: + """Wait then dispatch the batched event for *key*.""" + current_task = self._pending_tasks.get(key) + pending = self._pending.get(key) + last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 + + # Use longer delay when the last chunk looks like a split message + delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay + await asyncio.sleep(delay) + + event = self._pending.pop(key, None) + if event: + try: + await self._handler(event) + except Exception: + logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key) + + if self._pending_tasks.get(key) is current_task: + self._pending_tasks.pop(key, None) + + def cancel_all(self) -> None: + """Cancel all pending flush tasks.""" + for task in self._pending_tasks.values(): + if not task.done(): + task.cancel() + self._pending_tasks.clear() + self._pending.clear() + + +# ─── Markdown Stripping ────────────────────────────────────────────────────── + +# Pre-compiled regexes for performance +_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL) +_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL) +_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL) +_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL) +_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?") +_RE_INLINE_CODE = re.compile(r"`(.+?)`") +_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE) +_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)") +_RE_MULTI_NEWLINE = re.compile(r"\n{3,}") + + +def strip_markdown(text: str) -> str: + """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.). + + Replaces the identical ``_strip_markdown()`` functions previously + duplicated in sms.py, bluebubbles.py, and feishu.py. + """ + text = _RE_BOLD.sub(r"\1", text) + text = _RE_ITALIC_STAR.sub(r"\1", text) + text = _RE_BOLD_UNDER.sub(r"\1", text) + text = _RE_ITALIC_UNDER.sub(r"\1", text) + text = _RE_CODE_BLOCK.sub("", text) + text = _RE_INLINE_CODE.sub(r"\1", text) + text = _RE_HEADING.sub("", text) + text = _RE_LINK.sub(r"\1", text) + text = _RE_MULTI_NEWLINE.sub("\n\n", text) + return text.strip() + + +# ─── Thread Participation Tracking ─────────────────────────────────────────── + + +class ThreadParticipationTracker: + """Persistent tracking of threads the bot has participated in. + + Replaces the identical ``_load/_save_participated_threads`` + + ``_mark_thread_participated`` pattern previously duplicated in + discord.py and matrix.py. + + Usage:: + + self._threads = ThreadParticipationTracker("discord") + + # Check membership: + if thread_id in self._threads: + ... + + # Mark participation: + self._threads.mark(thread_id) + """ + + _MAX_TRACKED = 500 + + def __init__(self, platform_name: str, max_tracked: int = 500): + self._platform = platform_name + self._max_tracked = max_tracked + self._threads: set = self._load() + + def _state_path(self) -> Path: + from hermes_constants import get_hermes_home + return get_hermes_home() / f"{self._platform}_threads.json" + + def _load(self) -> set: + path = self._state_path() + if path.exists(): + try: + return set(json.loads(path.read_text(encoding="utf-8"))) + except Exception: + pass + return set() + + def _save(self) -> None: + path = self._state_path() + path.parent.mkdir(parents=True, exist_ok=True) + thread_list = list(self._threads) + if len(thread_list) > self._max_tracked: + thread_list = thread_list[-self._max_tracked:] + self._threads = set(thread_list) + path.write_text(json.dumps(thread_list), encoding="utf-8") + + def mark(self, thread_id: str) -> None: + """Mark *thread_id* as participated and persist.""" + if thread_id not in self._threads: + self._threads.add(thread_id) + self._save() + + def __contains__(self, thread_id: str) -> bool: + return thread_id in self._threads + + def clear(self) -> None: + self._threads.clear() + + +# ─── Phone Number Redaction ────────────────────────────────────────────────── + + +def redact_phone(phone: str) -> str: + """Redact a phone number for logging, preserving country code and last 4. + + Replaces the identical ``_redact_phone()`` functions in signal.py, + sms.py, and bluebubbles.py. + """ + if not phone: + return "" + if len(phone) <= 8: + return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" + return phone[:4] + "****" + phone[-4:] diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 7daf2e70e..349f962d2 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -92,6 +92,7 @@ from gateway.platforms.base import ( ProcessingOutcome, SendResult, ) +from gateway.platforms.helpers import ThreadParticipationTracker logger = logging.getLogger(__name__) @@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter): self._pending_megolm: list = [] # Thread participation tracking (for require_mention bypass) - self._bot_participated_threads: set = self._load_participated_threads() - self._MAX_TRACKED_THREADS = 500 + self._threads = ThreadParticipationTracker("matrix") # Mention/thread gating — parsed once from env vars. self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") @@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter): # Require-mention gating. if not is_dm: is_free_room = room_id in self._free_rooms - in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) + in_bot_thread = bool(thread_id and thread_id in self._threads) if self._require_mention and not is_free_room and not in_bot_thread: if not is_mentioned: return None @@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter): # DM mention-thread. if is_dm and not thread_id and self._dm_mention_threads and is_mentioned: thread_id = event_id - self._track_thread(thread_id) + self._threads.mark(thread_id) # Strip mention from body. if is_mentioned: @@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter): # Auto-thread. if not is_dm and not thread_id and self._auto_thread: thread_id = event_id - self._track_thread(thread_id) + self._threads.mark(thread_id) display_name = await self._get_display_name(room_id, sender) source = self.build_source( @@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter): ) if thread_id: - self._track_thread(thread_id) + self._threads.mark(thread_id) self._background_read_receipt(room_id, event_id) @@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter): for rid in self._joined_rooms } - # ------------------------------------------------------------------ - # Thread participation tracking - # ------------------------------------------------------------------ - - @staticmethod - def _thread_state_path() -> Path: - """Path to the persisted thread participation set.""" - from hermes_cli.config import get_hermes_home - return get_hermes_home() / "matrix_threads.json" - - @classmethod - def _load_participated_threads(cls) -> set: - """Load persisted thread IDs from disk.""" - path = cls._thread_state_path() - try: - if path.exists(): - data = json.loads(path.read_text(encoding="utf-8")) - if isinstance(data, list): - return set(data) - except Exception as e: - logger.debug("Could not load matrix thread state: %s", e) - return set() - - def _save_participated_threads(self) -> None: - """Persist the current thread set to disk (best-effort).""" - path = self._thread_state_path() - try: - thread_list = list(self._bot_participated_threads) - if len(thread_list) > self._MAX_TRACKED_THREADS: - thread_list = thread_list[-self._MAX_TRACKED_THREADS:] - self._bot_participated_threads = set(thread_list) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(thread_list), encoding="utf-8") - except Exception as e: - logger.debug("Could not save matrix thread state: %s", e) - - def _track_thread(self, thread_id: str) -> None: - """Add a thread to the participation set and persist.""" - if thread_id not in self._bot_participated_threads: - self._bot_participated_threads.add(thread_id) - self._save_participated_threads() - # ------------------------------------------------------------------ # Mention detection helpers # ------------------------------------------------------------------ diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py index 56f29e876..23a86f02b 100644 --- a/gateway/platforms/mattermost.py +++ b/gateway/platforms/mattermost.py @@ -18,11 +18,11 @@ import json import logging import os import re -import time from pathlib import Path from typing import Any, Dict, List, Optional from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter): or os.getenv("MATTERMOST_REPLY_MODE", "off") ).lower() - # Dedup cache: post_id → timestamp (prevent reprocessing) - self._seen_posts: Dict[str, float] = {} - self._SEEN_MAX = 2000 - self._SEEN_TTL = 300 # 5 minutes + # Dedup cache (prevent reprocessing) + self._dedup = MessageDeduplicator() # ------------------------------------------------------------------ # HTTP helpers @@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter): post_id = post.get("id", "") # Dedup. - self._prune_seen() - if post_id in self._seen_posts: + if self._dedup.is_duplicate(post_id): return - self._seen_posts[post_id] = time.time() # Build message event. channel_id = post.get("channel_id", "") @@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter): await self.handle_message(msg_event) - def _prune_seen(self) -> None: - """Remove expired entries from the dedup cache.""" - if len(self._seen_posts) < self._SEEN_MAX: - return - now = time.time() - self._seen_posts = { - pid: ts - for pid, ts in self._seen_posts.items() - if now - ts < self._SEEN_TTL - } + diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 08b62f2a6..8ef7bd0d6 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -37,6 +37,7 @@ from gateway.platforms.base import ( cache_document_from_bytes, cache_image_from_url, ) +from gateway.platforms.helpers import redact_phone logger = logging.getLogger(__name__) @@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0 HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern -# E.164 phone number pattern for redaction -_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") - - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +155****4567.""" - if not phone: - return "" - if len(phone) <= 8: - return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" - return phone[:4] + "****" + phone[-4:] - def _parse_comma_list(value: str) -> List[str]: """Split a comma-separated string into a list, stripping whitespace.""" @@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter): self._recent_sent_timestamps: set = set() self._max_recent_timestamps = 50 - self._phone_lock_identity: Optional[str] = None - logger.info("Signal adapter initialized: url=%s account=%s groups=%s", - self.http_url, _redact_phone(self.account), + self.http_url, redact_phone(self.account), "enabled" if self.group_allow_from else "disabled") # ------------------------------------------------------------------ @@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter): # Acquire scoped lock to prevent duplicate Signal listeners for the same phone try: - from gateway.status import acquire_scoped_lock - - self._phone_lock_identity = self.account - acquired, existing = acquire_scoped_lock( - "signal-phone", - self._phone_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Signal account" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Signal listener." - ) - logger.error("Signal: %s", message) - self._set_fatal_error("signal_phone_lock", message, retryable=False) + if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'): return False except Exception as e: logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e) @@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter): await self.client.aclose() self.client = None - if self._phone_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("signal-phone", self._phone_lock_identity) - except Exception as e: - logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True) - self._phone_lock_identity = None + self._release_platform_lock() logger.info("Signal: disconnected") @@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter): ) logger.debug("Signal: message from %s in %s: %s", - _redact_phone(sender), chat_id[:20], (text or "")[:50]) + redact_phone(sender), chat_id[:20], (text or "")[:50]) await self.handle_message(event) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 361f74882..8f9934cf7 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -33,6 +33,7 @@ from pathlib import Path as _Path sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter): self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id self._channel_team: Dict[str, str] = {} # channel_id → team_id - # Dedup cache: event_ts → timestamp. Prevents duplicate bot - # responses when Socket Mode reconnects redeliver events. - self._seen_messages: Dict[str, float] = {} - self._SEEN_TTL = 300 # 5 minutes - self._SEEN_MAX = 2000 # prune threshold + # Dedup cache: prevents duplicate bot responses when Socket Mode + # reconnects redeliver events. + self._dedup = MessageDeduplicator() # Track pending approval message_ts → resolved flag to prevent # double-clicks on approval buttons. self._approval_resolved: Dict[str, bool] = {} @@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter): logger.warning("[Slack] Failed to read %s: %s", tokens_file, e) try: - # Acquire scoped lock to prevent duplicate app token usage - from gateway.status import acquire_scoped_lock - self._token_lock_identity = app_token - acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'}) - if not acquired: - owner_pid = existing.get('pid') if isinstance(existing, dict) else None - message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' - logger.error('[%s] %s', self.name, message) - self._set_fatal_error('slack_token_lock', message, retryable=False) + if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'): return False # First token is the primary — used for AsyncApp / Socket Mode @@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter): logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True) self._running = False - # Release the token lock (use stored identity, not re-read env) - try: - from gateway.status import release_scoped_lock - if getattr(self, '_token_lock_identity', None): - release_scoped_lock('slack-app-token', self._token_lock_identity) - self._token_lock_identity = None - except Exception: - pass + self._release_platform_lock() logger.info("[Slack] Disconnected") @@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter): """Handle an incoming Slack message event.""" # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) event_ts = event.get("ts", "") - if event_ts: - now = time.time() - if event_ts in self._seen_messages: - return - self._seen_messages[event_ts] = now - if len(self._seen_messages) > self._SEEN_MAX: - cutoff = now - self._SEEN_TTL - self._seen_messages = { - k: v for k, v in self._seen_messages.items() - if v > cutoff - } + if event_ts and self._dedup.is_duplicate(event_ts): + return # Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots): # "none" — ignore all bot messages (default, backward-compatible) diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index a0760199b..953ec5c5e 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -19,7 +19,6 @@ import asyncio import base64 import logging import os -import re import urllib.parse from typing import Any, Dict, Optional @@ -30,6 +29,7 @@ from gateway.platforms.base import ( MessageType, SendResult, ) +from gateway.platforms.helpers import redact_phone, strip_markdown logger = logging.getLogger(__name__) @@ -37,18 +37,6 @@ TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts" MAX_SMS_LENGTH = 1600 # ~10 SMS segments DEFAULT_WEBHOOK_PORT = 8080 -# E.164 phone number pattern for redaction -_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") - - -def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +1555***4567.""" - if not phone: - return "" - if len(phone) <= 8: - return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****" - return phone[:5] + "***" + phone[-4:] - def check_sms_requirements() -> bool: """Check if SMS adapter dependencies are available.""" @@ -114,7 +102,7 @@ class SmsAdapter(BasePlatformAdapter): logger.info( "[sms] Twilio webhook server listening on port %d, from: %s", self._webhook_port, - _redact_phone(self._from_number), + redact_phone(self._from_number), ) return True @@ -163,7 +151,7 @@ class SmsAdapter(BasePlatformAdapter): error_msg = body.get("message", str(body)) logger.error( "[sms] send failed to %s: %s %s", - _redact_phone(chat_id), + redact_phone(chat_id), resp.status, error_msg, ) @@ -174,7 +162,7 @@ class SmsAdapter(BasePlatformAdapter): msg_sid = body.get("sid", "") last_result = SendResult(success=True, message_id=msg_sid) except Exception as e: - logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e) + logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e) return SendResult(success=False, error=str(e)) finally: # Close session only if we created a fallback (no persistent session) @@ -192,16 +180,7 @@ class SmsAdapter(BasePlatformAdapter): def format_message(self, content: str) -> str: """Strip markdown — SMS renders it as literal characters.""" - content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL) - content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL) - content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL) - content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL) - content = re.sub(r"```[a-z]*\n?", "", content) - content = re.sub(r"`(.+?)`", r"\1", content) - content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE) - content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content) - content = re.sub(r"\n{3,}", "\n\n", content) - return content.strip() + return strip_markdown(content) # ------------------------------------------------------------------ # Twilio webhook handler @@ -236,7 +215,7 @@ class SmsAdapter(BasePlatformAdapter): # Ignore messages from our own number (echo prevention) if from_number == self._from_number: - logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number)) + logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number)) return web.Response( text='', content_type="application/xml", @@ -244,8 +223,8 @@ class SmsAdapter(BasePlatformAdapter): logger.info( "[sms] inbound from %s -> %s: %s", - _redact_phone(from_number), - _redact_phone(to_number), + redact_phone(from_number), + redact_phone(to_number), text[:80], ) diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 8b4e43514..884ef9c45 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter): self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) self._pending_text_batches: Dict[str, MessageEvent] = {} self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} - self._token_lock_identity: Optional[str] = None self._polling_error_task: Optional[asyncio.Task] = None self._polling_conflict_count: int = 0 self._polling_network_error_count: int = 0 @@ -497,23 +496,7 @@ class TelegramAdapter(BasePlatformAdapter): return False try: - from gateway.status import acquire_scoped_lock - - self._token_lock_identity = self.config.token - acquired, existing = acquire_scoped_lock( - "telegram-bot-token", - self._token_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Telegram bot token" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Telegram poller." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("telegram_token_lock", message, retryable=False) + if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'): return False # Build the application @@ -737,12 +720,7 @@ class TelegramAdapter(BasePlatformAdapter): return True except Exception as e: - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("telegram-bot-token", self._token_lock_identity) - except Exception: - pass + self._release_platform_lock() message = f"Telegram startup failed: {e}" self._set_fatal_error("telegram_connect_error", message, retryable=True) logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True) @@ -768,12 +746,7 @@ class TelegramAdapter(BasePlatformAdapter): await self._app.shutdown() except Exception as e: logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True) - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("telegram-bot-token", self._token_lock_identity) - except Exception as e: - logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True) + self._release_platform_lock() for task in self._pending_photo_batch_tasks.values(): if task and not task.done(): @@ -784,7 +757,6 @@ class TelegramAdapter(BasePlatformAdapter): self._mark_disconnected() self._app = None self._bot = None - self._token_lock_identity = None logger.info("[%s] Disconnected from Telegram", self.name) def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool: diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py index aa07dc6a9..a0e71e01b 100644 --- a/gateway/platforms/wecom.py +++ b/gateway/platforms/wecom.py @@ -59,6 +59,7 @@ except ImportError: httpx = None # type: ignore[assignment] from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0 HEARTBEAT_INTERVAL_SECONDS = 30.0 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] -DEDUP_WINDOW_SECONDS = 300 DEDUP_MAX_SIZE = 1000 IMAGE_MAX_BYTES = 10 * 1024 * 1024 @@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter): self._listen_task: Optional[asyncio.Task] = None self._heartbeat_task: Optional[asyncio.Task] = None self._pending_responses: Dict[str, asyncio.Future] = {} - self._seen_messages: Dict[str, float] = {} + self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE) self._reply_req_ids: Dict[str, str] = {} # Text batching: merge rapid successive messages (Telegram-style). @@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter): await self._http_client.aclose() self._http_client = None - self._seen_messages.clear() + self._dedup.clear() logger.info("[%s] Disconnected", self.name) async def _cleanup_ws(self) -> None: @@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter): return msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex) - if self._is_duplicate(msg_id): + if self._dedup.is_duplicate(msg_id): logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id) return self._remember_reply_req_id(msg_id, self._payload_req_id(payload)) @@ -839,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter): wildcard = self._groups.get("*") return wildcard if isinstance(wildcard, dict) else {} - def _is_duplicate(self, msg_id: str) -> bool: - now = time.time() - if len(self._seen_messages) > DEDUP_MAX_SIZE: - cutoff = now - DEDUP_WINDOW_SECONDS - self._seen_messages = { - key: ts for key, ts in self._seen_messages.items() if ts > cutoff - } - if self._reply_req_ids: - self._reply_req_ids = { - key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages - } - - if msg_id in self._seen_messages: - return True - - self._seen_messages[msg_id] = now - return False - def _remember_reply_req_id(self, message_id: str, req_id: str) -> None: normalized_message_id = str(message_id or "").strip() normalized_req_id = str(req_id or "").strip() diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index 5e0208c77..3a4a80540 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate CRYPTO_AVAILABLE = False from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -1008,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter): self._typing_cache = TypingTicketCache() self._session: Optional[aiohttp.ClientSession] = None self._poll_task: Optional[asyncio.Task] = None - self._seen_messages: Dict[str, float] = {} - self._token_lock_identity: Optional[str] = None + self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS) self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip() self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip() @@ -1067,23 +1067,7 @@ class WeixinAdapter(BasePlatformAdapter): return False try: - from gateway.status import acquire_scoped_lock - - self._token_lock_identity = self._token - acquired, existing = acquire_scoped_lock( - "weixin-bot-token", - self._token_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this Weixin token" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second Weixin poller." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("weixin_token_lock", message, retryable=False) + if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'): return False except Exception as exc: logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc) @@ -1107,12 +1091,7 @@ class WeixinAdapter(BasePlatformAdapter): if self._session and not self._session.closed: await self._session.close() self._session = None - if self._token_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("weixin-bot-token", self._token_lock_identity) - except Exception as exc: - logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True) + self._release_platform_lock() self._mark_disconnected() logger.info("[%s] Disconnected", self.name) @@ -1190,16 +1169,8 @@ class WeixinAdapter(BasePlatformAdapter): return message_id = str(message.get("message_id") or "").strip() - if message_id: - now = time.time() - self._seen_messages = { - key: value - for key, value in self._seen_messages.items() - if now - value < MESSAGE_DEDUP_TTL_SECONDS - } - if message_id in self._seen_messages: - return - self._seen_messages[message_id] = now + if message_id and self._dedup.is_duplicate(message_id): + return chat_type, effective_chat_id = _guess_chat_type(message, self._account_id) if chat_type == "group": diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index a6475dcb8..c616f7244 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter): self._bridge_log: Optional[Path] = None self._poll_task: Optional[asyncio.Task] = None self._http_session: Optional["aiohttp.ClientSession"] = None - self._session_lock_identity: Optional[str] = None def _whatsapp_require_mention(self) -> bool: configured = self.config.extra.get("require_mention") @@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter): # Acquire scoped lock to prevent duplicate sessions try: - from gateway.status import acquire_scoped_lock - - self._session_lock_identity = str(self._session_path) - acquired, existing = acquire_scoped_lock( - "whatsapp-session", - self._session_lock_identity, - metadata={"platform": self.platform.value}, - ) - if not acquired: - owner_pid = existing.get("pid") if isinstance(existing, dict) else None - message = ( - "Another local Hermes gateway is already using this WhatsApp session" - + (f" (PID {owner_pid})." if owner_pid else ".") - + " Stop the other gateway before starting a second WhatsApp bridge." - ) - logger.error("[%s] %s", self.name, message) - self._set_fatal_error("whatsapp_session_lock", message, retryable=False) + if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'): return False except Exception as e: logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e) @@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter): return True except Exception as e: - if self._session_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("whatsapp-session", self._session_lock_identity) - except Exception: - pass + self._release_platform_lock() logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True) self._close_bridge_log() return False @@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter): await self._http_session.close() self._http_session = None - if self._session_lock_identity: - try: - from gateway.status import release_scoped_lock - release_scoped_lock("whatsapp-session", self._session_lock_identity) - except Exception as e: - logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True) + self._release_platform_lock() self._mark_disconnected() self._bridge_process = None self._close_bridge_log() - self._session_lock_identity = None print(f"[{self.name}] Disconnected") async def send( diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index fcb7c2dc5..56b9fb63c 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { } +# ============================================================================= +# Anthropic Key Helper +# ============================================================================= + +def get_anthropic_key() -> str: + """Return the first usable Anthropic credential, or ``""``. + + Checks both the ``.env`` file (via ``get_env_value``) and the process + environment (``os.getenv``). The fallback order mirrors the + ``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple: + + ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN + """ + from hermes_cli.config import get_env_value + + for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars: + value = get_env_value(var) or os.getenv(var, "") + if value: + return value + return "" + + # ============================================================================= # Kimi Code Endpoint Detection # ============================================================================= diff --git a/hermes_cli/cli_output.py b/hermes_cli/cli_output.py new file mode 100644 index 000000000..3d454eb30 --- /dev/null +++ b/hermes_cli/cli_output.py @@ -0,0 +1,79 @@ +"""Shared CLI output helpers for Hermes CLI modules. + +Extracts the identical ``print_info/success/warning/error`` and ``prompt()`` +functions previously duplicated across setup.py, tools_config.py, +mcp_config.py, and memory_setup.py. +""" + +import getpass +import sys + +from hermes_cli.colors import Colors, color + + +# ─── Print Helpers ──────────────────────────────────────────────────────────── + + +def print_info(text: str) -> None: + """Print a dim informational message.""" + print(color(f" {text}", Colors.DIM)) + + +def print_success(text: str) -> None: + """Print a green success message with ✓ prefix.""" + print(color(f"✓ {text}", Colors.GREEN)) + + +def print_warning(text: str) -> None: + """Print a yellow warning message with ⚠ prefix.""" + print(color(f"⚠ {text}", Colors.YELLOW)) + + +def print_error(text: str) -> None: + """Print a red error message with ✗ prefix.""" + print(color(f"✗ {text}", Colors.RED)) + + +def print_header(text: str) -> None: + """Print a bold yellow header.""" + print(color(f"\n {text}", Colors.YELLOW)) + + +# ─── Input Prompts ──────────────────────────────────────────────────────────── + + +def prompt( + question: str, + default: str | None = None, + password: bool = False, +) -> str: + """Prompt the user for input with optional default and password masking. + + Replaces the four independent ``_prompt()`` / ``prompt()`` implementations + in setup.py, tools_config.py, mcp_config.py, and memory_setup.py. + + Returns the user's input (stripped), or *default* if the user presses Enter. + Returns empty string on Ctrl-C or EOF. + """ + suffix = f" [{default}]" if default else "" + display = color(f" {question}{suffix}: ", Colors.YELLOW) + + try: + if password: + value = getpass.getpass(display) + else: + value = input(display) + value = value.strip() + return value if value else (default or "") + except (KeyboardInterrupt, EOFError): + print() + return "" + + +def prompt_yes_no(question: str, default: bool = True) -> bool: + """Prompt for a yes/no answer. Returns bool.""" + hint = "Y/n" if default else "y/N" + answer = prompt(f"{question} ({hint})") + if not answer: + return default + return answer.lower().startswith("y") diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 4661455d1..c3cf0456e 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -2582,7 +2582,8 @@ def show_config(): for env_key, name in keys: value = get_env_value(env_key) print(f" {name:<14} {redact_key(value)}") - anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY") + from hermes_cli.auth import get_anthropic_key + anthropic_value = get_anthropic_key() print(f" {'Anthropic':<14} {redact_key(anthropic_value)}") # Model settings @@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str): # Write only user config back (not the full merged defaults) ensure_hermes_home() - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + from utils import atomic_yaml_write + atomic_yaml_write(config_path, user_config, sort_keys=False) # Keep .env in sync for keys that terminal_tool reads directly from env vars. # config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc. diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index f5f8a228a..13c904692 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -336,8 +336,8 @@ def run_doctor(args): model_section[k] = raw_config.pop(k) else: raw_config.pop(k) - with open(config_path, "w") as f: - yaml.dump(raw_config, f, default_flow_style=False) + from utils import atomic_yaml_write + atomic_yaml_write(config_path, raw_config) check_ok("Migrated stale root-level keys into model section") fixed_count += 1 else: @@ -686,7 +686,8 @@ def run_doctor(args): else: check_warn("OpenRouter API", "(not configured)") - anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY") + from hermes_cli.auth import get_anthropic_key + anthropic_key = get_anthropic_key() if anthropic_key: print(" Checking Anthropic API...", end="", flush=True) try: diff --git a/hermes_cli/main.py b/hermes_cli/main.py index e004a6e93..4b7dd600b 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -2549,13 +2549,8 @@ def _model_flow_anthropic(config, current_model=""): from hermes_cli.models import _PROVIDER_MODELS # Check ALL credential sources - existing_key = ( - get_env_value("ANTHROPIC_TOKEN") - or os.getenv("ANTHROPIC_TOKEN", "") - or get_env_value("ANTHROPIC_API_KEY") - or os.getenv("ANTHROPIC_API_KEY", "") - or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "") - ) + from hermes_cli.auth import get_anthropic_key + existing_key = get_anthropic_key() cc_available = False try: from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index 9154ed50a..cf2dde089 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool: def _prompt(question: str, *, password: bool = False, default: str = "") -> str: - display = f" {question}" - if default: - display += f" [{default}]" - display += ": " - try: - if password: - value = getpass.getpass(color(display, Colors.YELLOW)) - else: - value = input(color(display, Colors.YELLOW)) - return value.strip() or default - except (KeyboardInterrupt, EOFError): - print() - return default + from hermes_cli.cli_output import prompt as _shared_prompt + return _shared_prompt(question, default=default, password=password) # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/memory_setup.py b/hermes_cli/memory_setup.py index 2843f4f44..1aa431367 100644 --- a/hermes_cli/memory_setup.py +++ b/hermes_cli/memory_setup.py @@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) - items: list of (label, description) tuples. Returns selected index, or default on escape/quit. """ - try: - import curses - result = [default] - - def _menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - curses.init_pair(3, curses.COLOR_CYAN, -1) - cursor = default - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - - # Title - try: - stdscr.addnstr(0, 0, title, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) - stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1, - curses.color_pair(3) if curses.has_colors() else curses.A_DIM) - except curses.error: - pass - - for i, (label, desc) in enumerate(items): - y = i + 3 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {label}" - if desc: - line += f" {desc}" - - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - - if key in (curses.KEY_UP, ord('k')): - cursor = (cursor - 1) % len(items) - elif key in (curses.KEY_DOWN, ord('j')): - cursor = (cursor + 1) % len(items) - elif key in (curses.KEY_ENTER, 10, 13): - result[0] = cursor - return - elif key in (27, ord('q')): - return - - curses.wrapper(_menu) - return result[0] - - except Exception: - # Fallback: numbered input - print(f"\n {title}\n") - for i, (label, desc) in enumerate(items): - marker = "→" if i == default else " " - d = f" {desc}" if desc else "" - print(f" {marker} {i + 1}. {label}{d}") - while True: - try: - val = input(f"\n Select [1-{len(items)}] ({default + 1}): ") - if not val: - return default - idx = int(val) - 1 - if 0 <= idx < len(items): - return idx - except (ValueError, EOFError): - return default + from hermes_cli.curses_ui import curses_radiolist + # Format (label, desc) tuples into display strings + display_items = [ + f"{label} {desc}" if desc else label + for label, desc in items + ] + return curses_radiolist(title, display_items, selected=default, cancel_returns=default) def _prompt(label: str, default: str | None = None, secret: bool = False) -> str: diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py new file mode 100644 index 000000000..18307912b --- /dev/null +++ b/hermes_cli/platforms.py @@ -0,0 +1,45 @@ +""" +Shared platform registry for Hermes Agent. + +Single source of truth for platform metadata consumed by both +skills_config (label display) and tools_config (default toolset +resolution). Import ``PLATFORMS`` from here instead of maintaining +duplicate dicts in each module. +""" + +from collections import OrderedDict +from typing import NamedTuple + + +class PlatformInfo(NamedTuple): + """Metadata for a single platform entry.""" + label: str + default_toolset: str + + +# Ordered so that TUI menus are deterministic. +PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([ + ("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")), + ("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")), + ("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")), + ("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")), + ("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")), + ("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")), + ("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")), + ("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")), + ("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")), + ("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")), + ("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")), + ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")), + ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")), + ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")), + ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")), + ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")), + ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")), +]) + + +def platform_label(key: str, default: str = "") -> str: + """Return the display label for a platform key, or *default*.""" + info = PLATFORMS.get(key) + return info.label if info is not None else default diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index ca877606f..fb70d9081 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -197,24 +197,12 @@ def print_header(title: str): print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD)) -def print_info(text: str): - """Print info text.""" - print(color(f" {text}", Colors.DIM)) - - -def print_success(text: str): - """Print success message.""" - print(color(f"✓ {text}", Colors.GREEN)) - - -def print_warning(text: str): - """Print warning message.""" - print(color(f"⚠ {text}", Colors.YELLOW)) - - -def print_error(text: str): - """Print error message.""" - print(color(f"✗ {text}", Colors.RED)) +from hermes_cli.cli_output import ( # noqa: E402 + print_error, + print_info, + print_success, + print_warning, +) def is_interactive_stdin() -> bool: @@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str: def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int: - """Single-select menu using curses to avoid simple_term_menu rendering bugs.""" - try: - import curses - result_holder = [default] - - def _curses_menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - cursor = default - scroll_offset = 0 - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - - # Rows available for list items: rows 2..(max_y-2) inclusive. - visible = max(1, max_y - 3) - - # Scroll the viewport so the cursor is always visible. - if cursor < scroll_offset: - scroll_offset = cursor - elif cursor >= scroll_offset + visible: - scroll_offset = cursor - visible + 1 - scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible))) - - try: - stdscr.addnstr( - 0, - 0, - question, - max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0), - ) - except curses.error: - pass - - for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))): - y = row + 2 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {choices[i]}" - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line, max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - if key in (curses.KEY_UP, ord("k")): - cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord("j")): - cursor = (cursor + 1) % len(choices) - elif key in (curses.KEY_ENTER, 10, 13): - result_holder[0] = cursor - return - elif key in (27, ord("q")): - return - - curses.wrapper(_curses_menu) - from hermes_cli.curses_ui import flush_stdin - flush_stdin() - return result_holder[0] - except Exception: - return -1 + """Single-select menu using curses. Delegates to curses_radiolist.""" + from hermes_cli.curses_ui import curses_radiolist + return curses_radiolist(question, choices, selected=default, cancel_returns=-1) diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index b017361fe..92424a0ca 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -15,25 +15,12 @@ from typing import List, Optional, Set from hermes_cli.config import load_config, save_config from hermes_cli.colors import Colors, color +from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label -PLATFORMS = { - "cli": "🖥️ CLI", - "telegram": "📱 Telegram", - "discord": "💬 Discord", - "slack": "💼 Slack", - "whatsapp": "📱 WhatsApp", - "signal": "📡 Signal", - "bluebubbles": "💬 BlueBubbles", - "email": "📧 Email", - "homeassistant": "🏠 Home Assistant", - "mattermost": "💬 Mattermost", - "matrix": "💬 Matrix", - "dingtalk": "💬 DingTalk", - "feishu": "🪽 Feishu", - "wecom": "💬 WeCom", - "weixin": "💬 Weixin", - "webhook": "🔗 Webhook", -} +# Backward-compatible view: {key: label_string} so existing code that +# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps +# working without changes to every call site. +PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"} # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/status.py b/hermes_cli/status.py index baba4f359..7a7a9c645 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -141,11 +141,8 @@ def show_status(args): display = redact_key(value) if not show_all else value print(f" {name:<12} {check_mark(has_key)} {display}") - anthropic_value = ( - get_env_value("ANTHROPIC_TOKEN") - or get_env_value("ANTHROPIC_API_KEY") - or "" - ) + from hermes_cli.auth import get_anthropic_key + anthropic_value = get_anthropic_key() anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}") diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 91c41dce5..343007cab 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve() # ─── UI Helpers (shared with setup.py) ──────────────────────────────────────── -def _print_info(text: str): - print(color(f" {text}", Colors.DIM)) - -def _print_success(text: str): - print(color(f"✓ {text}", Colors.GREEN)) - -def _print_warning(text: str): - print(color(f"⚠ {text}", Colors.YELLOW)) - -def _print_error(text: str): - print(color(f"✗ {text}", Colors.RED)) - -def _prompt(question: str, default: str = None, password: bool = False) -> str: - if default: - display = f"{question} [{default}]: " - else: - display = f"{question}: " - try: - if password: - import getpass - value = getpass.getpass(color(display, Colors.YELLOW)) - else: - value = input(color(display, Colors.YELLOW)) - return value.strip() or default or "" - except (KeyboardInterrupt, EOFError): - print() - return default or "" +from hermes_cli.cli_output import ( # noqa: E402 — late import block + print_error as _print_error, + print_info as _print_info, + print_success as _print_success, + print_warning as _print_warning, + prompt as _prompt, +) # ─── Toolset Registry ───────────────────────────────────────────────────────── @@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set: except Exception: return set() -# Platform display config +# Platform display config — derived from the canonical registry so every +# module shares the same data. Kept as dict-of-dicts for backward +# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns. +from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY + PLATFORMS = { - "cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"}, - "telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"}, - "discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"}, - "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"}, - "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, - "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, - "bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"}, - "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"}, - "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, - "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"}, - "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"}, - "feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"}, - "wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"}, - "weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"}, - "api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"}, - "mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"}, - "webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"}, + k: {"label": info.label, "default_toolset": info.default_toolset} + for k, info in _PLATFORMS_REGISTRY.items() } @@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool: # ─── Menu Helpers ───────────────────────────────────────────────────────────── def _prompt_choice(question: str, choices: list, default: int = 0) -> int: - """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu - rendering bugs in tmux, iTerm, and other non-standard terminals.""" - - # Curses-based single-select — works in tmux, iTerm, and standard terminals - try: - import curses - result_holder = [default] - - def _curses_menu(stdscr): - curses.curs_set(0) - if curses.has_colors(): - curses.start_color() - curses.use_default_colors() - curses.init_pair(1, curses.COLOR_GREEN, -1) - curses.init_pair(2, curses.COLOR_YELLOW, -1) - cursor = default - - while True: - stdscr.clear() - max_y, max_x = stdscr.getmaxyx() - try: - stdscr.addnstr(0, 0, question, max_x - 1, - curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0)) - except curses.error: - pass - - for i, c in enumerate(choices): - y = i + 2 - if y >= max_y - 1: - break - arrow = "→" if i == cursor else " " - line = f" {arrow} {c}" - attr = curses.A_NORMAL - if i == cursor: - attr = curses.A_BOLD - if curses.has_colors(): - attr |= curses.color_pair(1) - try: - stdscr.addnstr(y, 0, line, max_x - 1, attr) - except curses.error: - pass - - stdscr.refresh() - key = stdscr.getch() - - if key in (curses.KEY_UP, ord('k')): - cursor = (cursor - 1) % len(choices) - elif key in (curses.KEY_DOWN, ord('j')): - cursor = (cursor + 1) % len(choices) - elif key in (curses.KEY_ENTER, 10, 13): - result_holder[0] = cursor - return - elif key in (27, ord('q')): - return - - curses.wrapper(_curses_menu) - from hermes_cli.curses_ui import flush_stdin - flush_stdin() - return result_holder[0] - - except Exception: - pass - - # Fallback: numbered input (Windows without curses, etc.) - print(color(question, Colors.YELLOW)) - for i, c in enumerate(choices): - marker = "●" if i == default else "○" - style = Colors.GREEN if i == default else "" - print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}") - while True: - try: - val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM)) - if not val: - return default - idx = int(val) - 1 - if 0 <= idx < len(choices): - return idx - except (ValueError, KeyboardInterrupt, EOFError): - print() - return default + """Single-select menu (arrow keys). Delegates to curses_radiolist.""" + from hermes_cli.curses_ui import curses_radiolist + return curses_radiolist(question, choices, selected=default, cancel_returns=default) # ─── Token Estimation ──────────────────────────────────────────────────────── diff --git a/hermes_constants.py b/hermes_constants.py index 7d149f404..85955d548 100644 --- a/hermes_constants.py +++ b/hermes_constants.py @@ -189,6 +189,33 @@ def is_wsl() -> bool: return _wsl_detected +# ─── Well-Known Paths ───────────────────────────────────────────────────────── + + +def get_config_path() -> Path: + """Return the path to ``config.yaml`` under HERMES_HOME. + + Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated + in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.). + """ + return get_hermes_home() / "config.yaml" + + +def get_skills_dir() -> Path: + """Return the path to the skills directory under HERMES_HOME.""" + return get_hermes_home() / "skills" + + +def get_logs_dir() -> Path: + """Return the path to the logs directory under HERMES_HOME.""" + return get_hermes_home() / "logs" + + +def get_env_path() -> Path: + """Return the path to the ``.env`` file under HERMES_HOME.""" + return get_hermes_home() / ".env" + + OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models" diff --git a/hermes_logging.py b/hermes_logging.py index 5d71590c3..b765e9464 100644 --- a/hermes_logging.py +++ b/hermes_logging.py @@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler from pathlib import Path from typing import Optional -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path, get_hermes_home # Sentinel to track whether setup_logging() has already run. The function # is idempotent — calling it twice is safe but the second call is a no-op @@ -246,7 +246,7 @@ def _read_logging_config(): """ try: import yaml - config_path = get_hermes_home() / "config.yaml" + config_path = get_config_path() if config_path.exists(): with open(config_path, "r", encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} diff --git a/hermes_time.py b/hermes_time.py index f7d085544..9f172d28f 100644 --- a/hermes_time.py +++ b/hermes_time.py @@ -16,7 +16,7 @@ crashes due to a bad timezone string. import logging import os from datetime import datetime -from hermes_constants import get_hermes_home +from hermes_constants import get_config_path from typing import Optional logger = logging.getLogger(__name__) @@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str: # 2. config.yaml ``timezone`` key try: import yaml - hermes_home = get_hermes_home() - config_path = hermes_home / "config.yaml" + config_path = get_config_path() if config_path.exists(): with open(config_path) as f: cfg = yaml.safe_load(f) or {} diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index ef17af10b..d9ca627c4 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None): config = PlatformConfig(enabled=True, token="e2e-test-token") if platform == Platform.DISCORD: - with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()): + from gateway.platforms.helpers import ThreadParticipationTracker + with patch.object(ThreadParticipationTracker, "_load", return_value=set()): adapter = DiscordAdapter(config) platform_key = Platform.DISCORD elif platform == Platform.SLACK: diff --git a/tests/gateway/test_dingtalk.py b/tests/gateway/test_dingtalk.py index 5c73253fb..527113650 100644 --- a/tests/gateway/test_dingtalk.py +++ b/tests/gateway/test_dingtalk.py @@ -119,28 +119,29 @@ class TestDeduplication: def test_first_message_not_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - assert adapter._is_duplicate("msg-1") is False + assert adapter._dedup.is_duplicate("msg-1") is False def test_second_same_message_is_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - adapter._is_duplicate("msg-1") - assert adapter._is_duplicate("msg-1") is True + adapter._dedup.is_duplicate("msg-1") + assert adapter._dedup.is_duplicate("msg-1") is True def test_different_messages_not_duplicate(self): from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) - adapter._is_duplicate("msg-1") - assert adapter._is_duplicate("msg-2") is False + adapter._dedup.is_duplicate("msg-1") + assert adapter._dedup.is_duplicate("msg-2") is False def test_cache_cleanup_on_overflow(self): - from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE + from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + max_size = adapter._dedup._max_size # Fill beyond max - for i in range(DEDUP_MAX_SIZE + 10): - adapter._is_duplicate(f"msg-{i}") + for i in range(max_size + 10): + adapter._dedup.is_duplicate(f"msg-{i}") # Cache should have been pruned - assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10 + assert len(adapter._dedup._seen) <= max_size + 10 # --------------------------------------------------------------------------- @@ -253,13 +254,13 @@ class TestConnect: from gateway.platforms.dingtalk import DingTalkAdapter adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter._session_webhooks["a"] = "http://x" - adapter._seen_messages["b"] = 1.0 + adapter._dedup._seen["b"] = 1.0 adapter._http_client = AsyncMock() adapter._stream_task = None await adapter.disconnect() assert len(adapter._session_webhooks) == 0 - assert len(adapter._seen_messages) == 0 + assert len(adapter._dedup._seen) == 0 assert adapter._http_client is None diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index dd594cf7e..9f094dd0d 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch): assert ok is False assert released == [("discord-bot-token", "test-token")] - assert adapter._token_lock_identity is None + assert adapter._platform_lock_identity is None diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index bc63c14f5..29f65efc6 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") # Simulate bot having previously participated in thread 456 - adapter._bot_participated_threads.add("456") + adapter._threads.mark("456") thread = FakeThread(channel_id=456, name="existing thread") message = make_message(channel=thread, content="follow-up without mention") @@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch): await adapter._handle_message(message) - assert "555" in adapter._bot_participated_threads + assert "555" in adapter._threads @pytest.mark.asyncio @@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp await adapter._handle_message(message) - assert "777" in adapter._bot_participated_threads + assert "777" in adapter._threads diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py index 0288b620d..083f61ac7 100644 --- a/tests/gateway/test_discord_thread_persistence.py +++ b/tests/gateway/test_discord_thread_persistence.py @@ -1,6 +1,6 @@ """Tests for Discord thread participation persistence. -Verifies that _bot_participated_threads survives adapter restarts by +Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by being persisted to ~/.hermes/discord_threads.json. """ @@ -25,13 +25,13 @@ class TestDiscordThreadPersistence: def test_starts_empty_when_no_state_file(self, tmp_path): adapter = self._make_adapter(tmp_path) - assert adapter._bot_participated_threads == set() + assert "$nonexistent" not in adapter._threads def test_track_thread_persists_to_disk(self, tmp_path): adapter = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter._track_thread("111") - adapter._track_thread("222") + adapter._threads.mark("111") + adapter._threads.mark("222") state_file = tmp_path / "discord_threads.json" assert state_file.exists() @@ -42,42 +42,43 @@ class TestDiscordThreadPersistence: """Threads tracked by one adapter instance are visible to the next.""" adapter1 = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter1._track_thread("aaa") - adapter1._track_thread("bbb") + adapter1._threads.mark("aaa") + adapter1._threads.mark("bbb") adapter2 = self._make_adapter(tmp_path) - assert "aaa" in adapter2._bot_participated_threads - assert "bbb" in adapter2._bot_participated_threads + assert "aaa" in adapter2._threads + assert "bbb" in adapter2._threads def test_duplicate_track_does_not_double_save(self, tmp_path): adapter = self._make_adapter(tmp_path) with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): - adapter._track_thread("111") - adapter._track_thread("111") # no-op + adapter._threads.mark("111") + adapter._threads.mark("111") # no-op saved = json.loads((tmp_path / "discord_threads.json").read_text()) assert saved.count("111") == 1 def test_caps_at_max_tracked_threads(self, tmp_path): adapter = self._make_adapter(tmp_path) - adapter._MAX_TRACKED_THREADS = 5 + adapter._threads._max_tracked = 5 with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): for i in range(10): - adapter._track_thread(str(i)) + adapter._threads.mark(str(i)) - assert len(adapter._bot_participated_threads) == 5 + saved = json.loads((tmp_path / "discord_threads.json").read_text()) + assert len(saved) == 5 def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): state_file = tmp_path / "discord_threads.json" state_file.write_text("not valid json{{{") adapter = self._make_adapter(tmp_path) - assert adapter._bot_participated_threads == set() + assert "$nonexistent" not in adapter._threads def test_missing_hermes_home_does_not_crash(self, tmp_path): """Load/save tolerate missing directories.""" fake_home = tmp_path / "nonexistent" / "deep" with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}): - from gateway.platforms.discord import DiscordAdapter - # _load should return empty set, not crash - threads = DiscordAdapter._load_participated_threads() - assert threads == set() + from gateway.platforms.helpers import ThreadParticipationTracker + # ThreadParticipationTracker should return empty set, not crash + tracker = ThreadParticipationTracker("discord") + assert "$test" not in tracker diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py index d36c2b765..873b873c2 100644 --- a/tests/gateway/test_matrix_mention.py +++ b/tests/gateway/test_matrix_mention.py @@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch): monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() - adapter._bot_participated_threads.add("$thread1") + adapter._threads.mark("$thread1") event = _make_event("hello without mention", thread_id="$thread1") @@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch): monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() - adapter._bot_participated_threads.add("$thread_root") + adapter._threads.mark("$thread_root") event = _make_event("reply in thread", thread_id="$thread_root") await adapter._on_room_message(event) @@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch): @pytest.mark.asyncio async def test_auto_thread_tracks_participation(monkeypatch): - """Auto-created threads are tracked in _bot_participated_threads.""" + """Auto-created threads are tracked in _threads.""" monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false") monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) adapter = _make_adapter() event = _make_event("hello", event_id="$msg1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) - assert "$msg1" in adapter._bot_participated_threads + assert "$msg1" in adapter._threads # --------------------------------------------------------------------------- @@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch): class TestThreadPersistence: def test_empty_state_file(self, tmp_path, monkeypatch): """No state file → empty set.""" - from gateway.platforms.matrix import MatrixAdapter + from gateway.platforms.helpers import ThreadParticipationTracker monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: tmp_path / "matrix_threads.json"), + ThreadParticipationTracker, "_state_path", + lambda self: tmp_path / "matrix_threads.json", ) adapter = _make_adapter() - loaded = adapter._load_participated_threads() - assert loaded == set() + assert "$nonexistent" not in adapter._threads def test_track_thread_persists(self, tmp_path, monkeypatch): - """_track_thread writes to disk.""" - from gateway.platforms.matrix import MatrixAdapter + """mark() writes to disk.""" + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - adapter._track_thread("$thread_abc") + adapter._threads.mark("$thread_abc") data = json.loads(state_path.read_text()) assert "$thread_abc" in data def test_threads_survive_reload(self, tmp_path, monkeypatch): """Persisted threads are loaded by a new adapter instance.""" - from gateway.platforms.matrix import MatrixAdapter + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" state_path.write_text(json.dumps(["$t1", "$t2"])) monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - assert "$t1" in adapter._bot_participated_threads - assert "$t2" in adapter._bot_participated_threads + assert "$t1" in adapter._threads + assert "$t2" in adapter._threads def test_cap_max_tracked_threads(self, tmp_path, monkeypatch): - """Thread set is trimmed to _MAX_TRACKED_THREADS.""" - from gateway.platforms.matrix import MatrixAdapter + """Thread set is trimmed to max_tracked.""" + from gateway.platforms.helpers import ThreadParticipationTracker state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - MatrixAdapter, "_thread_state_path", - staticmethod(lambda: state_path), + ThreadParticipationTracker, "_state_path", + lambda self: state_path, ) adapter = _make_adapter() - adapter._MAX_TRACKED_THREADS = 5 + adapter._threads._max_tracked = 5 for i in range(10): - adapter._bot_participated_threads.add(f"$t{i}") - adapter._save_participated_threads() + adapter._threads.mark(f"$t{i}") data = json.loads(state_path.read_text()) assert len(data) == 5 @@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch): _set_dm(adapter) event = _make_event("@hermes:example.org help me", event_id="$dm1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) adapter.handle_message.assert_awaited_once() @@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): adapter = _make_adapter() _set_dm(adapter) - adapter._bot_participated_threads.add("$existing_thread") + adapter._threads.mark("$existing_thread") event = _make_event("@hermes:example.org help me", thread_id="$existing_thread") await adapter._on_room_message(event) @@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch): @pytest.mark.asyncio async def test_dm_mention_thread_tracks_participation(monkeypatch): - """DM mention-thread tracks the thread in _bot_participated_threads.""" + """DM mention-thread tracks the thread in _threads.""" monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true") monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") @@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch): _set_dm(adapter) event = _make_event("@hermes:example.org help", event_id="$dm1") - with patch.object(adapter, "_save_participated_threads"): + with patch.object(adapter._threads, "_save"): await adapter._on_room_message(event) - assert "$dm1" in adapter._bot_participated_threads + assert "$dm1" in adapter._threads # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py index 7d47c0a3e..56e46f636 100644 --- a/tests/gateway/test_mattermost.py +++ b/tests/gateway/test_mattermost.py @@ -614,25 +614,27 @@ class TestMattermostDedup: assert self.adapter.handle_message.call_count == 2 def test_prune_seen_clears_expired(self): - """_prune_seen should remove entries older than _SEEN_TTL.""" + """Dedup cache should remove entries older than TTL on overflow.""" now = time.time() + dedup = self.adapter._dedup # Fill with enough expired entries to trigger pruning - for i in range(self.adapter._SEEN_MAX + 10): - self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago + for i in range(dedup._max_size + 10): + dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL) # Add a fresh one - self.adapter._seen_posts["fresh"] = now + dedup._seen["fresh"] = now - self.adapter._prune_seen() + # Trigger pruning by calling is_duplicate with a new entry (over max_size) + dedup.is_duplicate("trigger_prune") # Old entries should be pruned, fresh one kept - assert "fresh" in self.adapter._seen_posts - assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX + assert "fresh" in dedup._seen + assert len(dedup._seen) < dedup._max_size + 10 def test_seen_cache_tracks_post_ids(self): - """Posts are tracked in _seen_posts dict.""" - self.adapter._seen_posts["test_post"] = time.time() - assert "test_post" in self.adapter._seen_posts + """Posts are tracked in the dedup cache.""" + self.adapter._dedup._seen["test_post"] = time.time() + assert "test_post" in self.adapter._dedup._seen # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index ae985300d..265f9be78 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -114,16 +114,16 @@ class TestSignalAdapterInit: class TestSignalHelpers: def test_redact_phone_long(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("+15551234567") == "+155****4567" + from gateway.platforms.helpers import redact_phone + assert redact_phone("+155****4567") == "+155****4567" def test_redact_phone_short(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("+12345") == "+1****45" + from gateway.platforms.helpers import redact_phone + assert redact_phone("+12345") == "+1****45" def test_redact_phone_empty(self): - from gateway.platforms.signal import _redact_phone - assert _redact_phone("") == "" + from gateway.platforms.helpers import redact_phone + assert redact_phone("") == "" def test_parse_comma_list(self): from gateway.platforms.signal import _parse_comma_list diff --git a/tests/gateway/test_telegram_conflict.py b/tests/gateway/test_telegram_conflict.py index 47a67f229..dcf311688 100644 --- a/tests/gateway/test_telegram_conflict.py +++ b/tests/gateway/test_telegram_conflict.py @@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch): async def _noop(): return [] monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop) + # Mock HTTPXRequest so the builder chain doesn't fail + monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock()) @pytest.mark.asyncio @@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch): ok = await adapter.connect() assert ok is False - assert adapter.fatal_error_code == "telegram_token_lock" + assert adapter.fatal_error_code == "telegram-bot-token_lock" assert adapter.has_fatal_error is True - assert "already using this Telegram bot token" in adapter.fatal_error_message + assert "already in use" in adapter.fatal_error_message @pytest.mark.asyncio @@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) @@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) @@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder app = SimpleNamespace( bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()), updater=SimpleNamespace(), @@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch): ) builder = MagicMock() builder.token.return_value = builder + builder.request.return_value = builder + builder.get_updates_request.return_value = builder builder.build.return_value = app monkeypatch.setattr( "gateway.platforms.telegram.Application", diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py index 7b9e49d4f..dd0ae17f8 100644 --- a/tests/tools/test_skill_manager_tool.py +++ b/tests/tools/test_skill_manager_tool.py @@ -348,7 +348,7 @@ word word result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert outside_file.read_text() == "old text here" @@ -412,7 +412,7 @@ class TestWriteFile: result = _write_file("my-skill", "references/escape/owned.md", "malicious") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert not (outside_dir / "owned.md").exists() @@ -449,7 +449,7 @@ class TestRemoveFile: result = _remove_file("my-skill", "references/escape/keep.txt") assert result["success"] is False - assert "boundary" in result["error"].lower() + assert "escapes" in result["error"].lower() assert outside_file.exists() diff --git a/tools/credential_files.py b/tools/credential_files.py index 6ddcd0770..7998321e6 100644 --- a/tools/credential_files.py +++ b/tools/credential_files.py @@ -80,20 +80,18 @@ def register_credential_file( # Resolve symlinks and normalise ``..`` before the containment check so # that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME. - try: - resolved = host_path.resolve() - hermes_home_resolved = hermes_home.resolve() - resolved.relative_to(hermes_home_resolved) # raises ValueError if outside - except ValueError: + from tools.path_security import validate_within_dir + + containment_error = validate_within_dir(host_path, hermes_home) + if containment_error: logger.warning( - "credential_files: rejected path traversal %r " - "(resolves to %s, outside HERMES_HOME %s)", + "credential_files: rejected path traversal %r (%s)", relative_path, - resolved, - hermes_home_resolved, + containment_error, ) return False + resolved = host_path.resolve() if not resolved.is_file(): logger.debug("credential_files: skipping %s (not found)", resolved) return False @@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]: cfg = read_raw_config() cred_files = cfg.get("terminal", {}).get("credential_files") if isinstance(cred_files, list): - hermes_home_resolved = hermes_home.resolve() + from tools.path_security import validate_within_dir + for item in cred_files: if isinstance(item, str) and item.strip(): rel = item.strip() @@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]: "credential_files: rejected absolute config path %r", rel, ) continue - host_path = (hermes_home / rel).resolve() - try: - host_path.relative_to(hermes_home_resolved) - except ValueError: + host_path = hermes_home / rel + containment_error = validate_within_dir(host_path, hermes_home) + if containment_error: logger.warning( - "credential_files: rejected config path traversal %r " - "(resolves to %s, outside HERMES_HOME %s)", - rel, host_path, hermes_home_resolved, + "credential_files: rejected config path traversal %r (%s)", + rel, containment_error, ) continue - if host_path.is_file(): + resolved_path = host_path.resolve() + if resolved_path.is_file(): container_path = f"/root/.hermes/{rel}" result.append({ - "host_path": str(host_path), + "host_path": str(resolved_path), "container_path": container_path, }) except Exception as e: diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 3018b8731..e2db93381 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]: ) # Validate containment after resolution + from tools.path_security import validate_within_dir + scripts_dir = get_hermes_home() / "scripts" scripts_dir.mkdir(parents=True, exist_ok=True) - resolved = (scripts_dir / raw).resolve() - try: - resolved.relative_to(scripts_dir.resolve()) - except ValueError: + containment_error = validate_within_dir(scripts_dir / raw, scripts_dir) + if containment_error: return ( f"Script path escapes the scripts directory via traversal: {raw!r}" ) diff --git a/tools/path_security.py b/tools/path_security.py new file mode 100644 index 000000000..828011e5d --- /dev/null +++ b/tools/path_security.py @@ -0,0 +1,43 @@ +"""Shared path validation helpers for tool implementations. + +Extracts the ``resolve() + relative_to()`` and ``..`` traversal check +patterns previously duplicated across skill_manager_tool, skills_tool, +skills_hub, cronjob_tools, and credential_files. +""" + +import logging +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + + +def validate_within_dir(path: Path, root: Path) -> Optional[str]: + """Ensure *path* resolves to a location within *root*. + + Returns an error message string if validation fails, or ``None`` if the + path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize + ``..`` components. + + Usage:: + + error = validate_within_dir(user_path, allowed_root) + if error: + return json.dumps({"error": error}) + """ + try: + resolved = path.resolve() + root_resolved = root.resolve() + resolved.relative_to(root_resolved) + except (ValueError, OSError) as exc: + return f"Path escapes allowed directory: {exc}" + return None + + +def has_traversal_component(path_str: str) -> bool: + """Return True if *path_str* contains ``..`` traversal components. + + Quick check for obvious traversal attempts before doing full resolution. + """ + parts = Path(path_str).parts + return ".." in parts diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 2273d75fa..2b2625fa0 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]: Validate a file path for write_file/remove_file. Must be under an allowed subdirectory and not escape the skill dir. """ + from tools.path_security import has_traversal_component + if not file_path: return "file_path is required." normalized = Path(file_path) # Prevent path traversal - if ".." in normalized.parts: + if has_traversal_component(file_path): return "Path traversal ('..') is not allowed." # Must be under an allowed subdirectory @@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]: def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]: """Resolve a supporting-file path and ensure it stays within the skill directory.""" + from tools.path_security import validate_within_dir + target = skill_dir / file_path - try: - resolved = target.resolve(strict=False) - skill_dir_resolved = skill_dir.resolve() - resolved.relative_to(skill_dir_resolved) - except ValueError: - return None, "Path escapes skill directory boundary." - except OSError as e: - return None, f"Invalid file path '{file_path}': {e}" + error = validate_within_dir(target, skill_dir) + if error: + return None, error return target, None diff --git a/tools/skills_tool.py b/tools/skills_tool.py index 085ed0055..94b7c235b 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]: return None -def _estimate_tokens(content: str) -> int: - """ - Rough token estimate (4 chars per token average). - - Args: - content: Text content - - Returns: - Estimated token count - """ - return len(content) // 4 +# Token estimation — use the shared implementation from model_metadata. +from agent.model_metadata import estimate_tokens_rough as _estimate_tokens def _parse_tags(tags_value) -> List[str]: @@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: # If a specific file path is requested, read that instead if file_path and skill_dir: + from tools.path_security import validate_within_dir, has_traversal_component + # Security: Prevent path traversal attacks - normalized_path = Path(file_path) - if ".." in normalized_path.parts: + if has_traversal_component(file_path): return json.dumps( { "success": False, @@ -962,24 +954,13 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: target_file = skill_dir / file_path # Security: Verify resolved path is still within skill directory - try: - resolved = target_file.resolve() - skill_dir_resolved = skill_dir.resolve() - if not resolved.is_relative_to(skill_dir_resolved): - return json.dumps( - { - "success": False, - "error": "Path escapes skill directory boundary.", - "hint": "Use a relative path within the skill directory", - }, - ensure_ascii=False, - ) - except (OSError, ValueError): + traversal_error = validate_within_dir(target_file, skill_dir) + if traversal_error: return json.dumps( { "success": False, - "error": f"Invalid file path: '{file_path}'", - "hint": "Use a valid relative path within the skill directory", + "error": traversal_error, + "hint": "Use a relative path within the skill directory", }, ensure_ascii=False, ) diff --git a/utils.py b/utils.py index 9a2105d54..bd2a6b70f 100644 --- a/utils.py +++ b/utils.py @@ -1,13 +1,16 @@ """Shared utility functions for hermes-agent.""" import json +import logging import os import tempfile from pathlib import Path -from typing import Any, Union +from typing import Any, List, Optional, Union import yaml +logger = logging.getLogger(__name__) + TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"}) @@ -124,3 +127,88 @@ def atomic_yaml_write( except OSError: pass raise + + +# ─── JSON Helpers ───────────────────────────────────────────────────────────── + + +def safe_json_loads(text: str, default: Any = None) -> Any: + """Parse JSON, returning *default* on any parse error. + + Replaces the ``try: json.loads(x) except (JSONDecodeError, TypeError)`` + pattern duplicated across display.py, anthropic_adapter.py, + auxiliary_client.py, and others. + """ + try: + return json.loads(text) + except (json.JSONDecodeError, TypeError, ValueError): + return default + + +def read_json_file(path: Path, default: Any = None) -> Any: + """Read and parse a JSON file, returning *default* on any error. + + Replaces the repeated ``try: json.loads(path.read_text()) except ...`` + pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py, + and skill_utils.py. + """ + try: + return json.loads(Path(path).read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError, IOError, ValueError) as exc: + logger.debug("Failed to read %s: %s", path, exc) + return default + + +def read_jsonl(path: Path) -> List[dict]: + """Read a JSONL file (one JSON object per line). + + Returns a list of parsed objects, skipping blank lines. + """ + entries = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + entries.append(json.loads(line)) + return entries + + +def append_jsonl(path: Path, entry: dict) -> None: + """Append a single JSON object as a new line to a JSONL file.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + +# ─── Environment Variable Helpers ───────────────────────────────────────────── + + +def env_str(key: str, default: str = "") -> str: + """Read an environment variable, stripped of whitespace. + + Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times + across runtime_provider.py, anthropic_adapter.py, models.py, etc. + """ + return os.getenv(key, default).strip() + + +def env_lower(key: str, default: str = "") -> str: + """Read an environment variable, stripped and lowercased.""" + return os.getenv(key, default).strip().lower() + + +def env_int(key: str, default: int = 0) -> int: + """Read an environment variable as an integer, with fallback.""" + raw = os.getenv(key, "").strip() + if not raw: + return default + try: + return int(raw) + except (ValueError, TypeError): + return default + + +def env_bool(key: str, default: bool = False) -> bool: + """Read an environment variable as a boolean.""" + return is_truthy_value(os.getenv(key, ""), default=default)