From 04c1c5d53f555b9798d180802a288f0566f9acd7 Mon Sep 17 00:00:00 2001
From: Teknium <127238744+teknium1@users.noreply.github.com>
Date: Sat, 11 Apr 2026 13:59:52 -0700
Subject: [PATCH] refactor: extract shared helpers to deduplicate repeated code
patterns (#7917)
* refactor: add shared helper modules for code deduplication
New modules:
- gateway/platforms/helpers.py: MessageDeduplicator, TextBatchAggregator,
strip_markdown, ThreadParticipationTracker, redact_phone
- hermes_cli/cli_output.py: print_info/success/warning/error, prompt helpers
- tools/path_security.py: validate_within_dir, has_traversal_component
- utils.py additions: safe_json_loads, read_json_file, read_jsonl,
append_jsonl, env_str/lower/int/bool helpers
- hermes_constants.py additions: get_config_path, get_skills_dir,
get_logs_dir, get_env_path
* refactor: migrate gateway adapters to shared helpers
- MessageDeduplicator: discord, slack, dingtalk, wecom, weixin, mattermost
- strip_markdown: bluebubbles, feishu, sms
- redact_phone: sms, signal
- ThreadParticipationTracker: discord, matrix
- _acquire/_release_platform_lock: telegram, discord, slack, whatsapp,
signal, weixin
Net -316 lines across 19 files.
* refactor: migrate CLI modules to shared helpers
- tools_config.py: use cli_output print/prompt + curses_radiolist (-117 lines)
- setup.py: use cli_output print helpers + curses_radiolist (-101 lines)
- mcp_config.py: use cli_output prompt (-15 lines)
- memory_setup.py: use curses_radiolist (-86 lines)
Net -263 lines across 5 files.
* refactor: migrate to shared utility helpers
- safe_json_loads: agent/display.py (4 sites)
- get_config_path: skill_utils.py, hermes_logging.py, hermes_time.py
- get_skills_dir: skill_utils.py, prompt_builder.py
- Token estimation dedup: skills_tool.py imports from model_metadata
- Path security: skills_tool, cronjob_tools, skill_manager_tool, credential_files
- Non-atomic YAML writes: doctor.py, config.py now use atomic_yaml_write
- Platform dict: new platforms.py, skills_config + tools_config derive from it
- Anthropic key: new get_anthropic_key() in auth.py, used by doctor/status/config/main
* test: update tests for shared helper migrations
- test_dingtalk: use _dedup.is_duplicate() instead of _is_duplicate()
- test_mattermost: use _dedup instead of _seen_posts/_prune_seen
- test_signal: import redact_phone from helpers instead of signal
- test_discord_connect: _platform_lock_identity instead of _token_lock_identity
- test_telegram_conflict: updated lock error message format
- test_skill_manager_tool: 'escapes' instead of 'boundary' in error msgs
---
agent/display.py | 25 +-
agent/prompt_builder.py | 5 +-
agent/skill_utils.py | 12 +-
gateway/platforms/base.py | 31 ++-
gateway/platforms/bluebubbles.py | 18 +-
gateway/platforms/dingtalk.py | 25 +-
gateway/platforms/discord.py | 111 +-------
gateway/platforms/feishu.py | 16 +-
gateway/platforms/helpers.py | 261 ++++++++++++++++++
gateway/platforms/matrix.py | 54 +---
gateway/platforms/mattermost.py | 23 +-
gateway/platforms/signal.py | 45 +--
gateway/platforms/slack.py | 41 +--
gateway/platforms/sms.py | 37 +--
gateway/platforms/telegram.py | 34 +--
gateway/platforms/wecom.py | 26 +-
gateway/platforms/weixin.py | 41 +--
gateway/platforms/whatsapp.py | 34 +--
hermes_cli/auth.py | 22 ++
hermes_cli/cli_output.py | 79 ++++++
hermes_cli/config.py | 7 +-
hermes_cli/doctor.py | 7 +-
hermes_cli/main.py | 9 +-
hermes_cli/mcp_config.py | 15 +-
hermes_cli/memory_setup.py | 86 +-----
hermes_cli/platforms.py | 45 +++
hermes_cli/setup.py | 101 +------
hermes_cli/skills_config.py | 23 +-
hermes_cli/status.py | 7 +-
hermes_cli/tools_config.py | 142 ++--------
hermes_constants.py | 27 ++
hermes_logging.py | 4 +-
hermes_time.py | 5 +-
tests/e2e/conftest.py | 3 +-
tests/gateway/test_dingtalk.py | 23 +-
tests/gateway/test_discord_connect.py | 2 +-
tests/gateway/test_discord_free_response.py | 6 +-
.../test_discord_thread_persistence.py | 37 +--
tests/gateway/test_matrix_mention.py | 62 ++---
tests/gateway/test_mattermost.py | 22 +-
tests/gateway/test_signal.py | 12 +-
tests/gateway/test_telegram_conflict.py | 14 +-
tests/tools/test_skill_manager_tool.py | 6 +-
tools/credential_files.py | 36 ++-
tools/cronjob_tools.py | 8 +-
tools/path_security.py | 43 +++
tools/skill_manager_tool.py | 17 +-
tools/skills_tool.py | 37 +--
utils.py | 90 +++++-
49 files changed, 887 insertions(+), 949 deletions(-)
create mode 100644 gateway/platforms/helpers.py
create mode 100644 hermes_cli/cli_output.py
create mode 100644 hermes_cli/platforms.py
create mode 100644 tools/path_security.py
diff --git a/agent/display.py b/agent/display.py
index 604b7a298..182064576 100644
--- a/agent/display.py
+++ b/agent/display.py
@@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency.
Used by AIAgent._execute_tool_calls for CLI feedback.
"""
-import json
import logging
import os
import sys
@@ -14,6 +13,8 @@ from dataclasses import dataclass, field
from difflib import unified_diff
from pathlib import Path
+from utils import safe_json_loads
+
# ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m"
_RESET = "\033[0m"
@@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool:
"""Conservatively detect whether a tool result represents success."""
if not result:
return False
- try:
- data = json.loads(result)
- except (json.JSONDecodeError, TypeError):
+ data = safe_json_loads(result)
+ if data is None:
return False
if not isinstance(data, dict):
return False
@@ -423,10 +423,7 @@ def extract_edit_diff(
) -> str | None:
"""Extract a unified diff from a file-edit tool result."""
if tool_name == "patch" and result:
- try:
- data = json.loads(result)
- except (json.JSONDecodeError, TypeError):
- data = None
+ data = safe_json_loads(result)
if isinstance(data, dict):
diff = data.get("diff")
if isinstance(diff, str) and diff.strip():
@@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
return False, ""
if tool_name == "terminal":
- try:
- data = json.loads(result)
+ data = safe_json_loads(result)
+ if isinstance(data, dict):
exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]"
- except (json.JSONDecodeError, TypeError, AttributeError):
- logger.debug("Could not parse terminal result as JSON for exit code check")
return False, ""
# Memory-specific: distinguish "full" from real errors
if tool_name == "memory":
- try:
- data = json.loads(result)
+ data = safe_json_loads(result)
+ if isinstance(data, dict):
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
return True, " [full]"
- except (json.JSONDecodeError, TypeError, AttributeError):
- logger.debug("Could not parse memory result as JSON for capacity check")
# Generic heuristic for non-terminal tools
lower = result[:500].lower()
diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py
index 08b8fe0a6..26d913a02 100644
--- a/agent/prompt_builder.py
+++ b/agent/prompt_builder.py
@@ -12,7 +12,7 @@ import threading
from collections import OrderedDict
from pathlib import Path
-from hermes_constants import get_hermes_home
+from hermes_constants import get_hermes_home, get_skills_dir
from typing import Optional
from agent.skill_utils import (
@@ -548,8 +548,7 @@ def build_skills_system_prompt(
are read-only — they appear in the index but new skills are always created
in the local dir. Local skills take precedence when names collide.
"""
- hermes_home = get_hermes_home()
- skills_dir = hermes_home / "skills"
+ skills_dir = get_skills_dir()
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
if not skills_dir.exists() and not external_dirs:
diff --git a/agent/skill_utils.py b/agent/skill_utils.py
index ba606b358..97ba92b73 100644
--- a/agent/skill_utils.py
+++ b/agent/skill_utils.py
@@ -12,7 +12,7 @@ import sys
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
-from hermes_constants import get_hermes_home
+from hermes_constants import get_config_path, get_skills_dir
logger = logging.getLogger(__name__)
@@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
Reads the config file directly (no CLI config imports) to stay
lightweight.
"""
- config_path = get_hermes_home() / "config.yaml"
+ config_path = get_config_path()
if not config_path.exists():
return set()
try:
@@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]:
path. Only directories that actually exist are returned. Duplicates and
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
"""
- config_path = get_hermes_home() / "config.yaml"
+ config_path = get_config_path()
if not config_path.exists():
return []
try:
@@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]:
if not isinstance(raw_dirs, list):
return []
- local_skills = (get_hermes_home() / "skills").resolve()
+ local_skills = get_skills_dir().resolve()
seen: Set[Path] = set()
result: List[Path] = []
@@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
The local dir is always first (and always included even if it doesn't exist
yet — callers handle that). External dirs follow in config order.
"""
- dirs = [get_hermes_home() / "skills"]
+ dirs = [get_skills_dir()]
dirs.extend(get_external_skills_dirs())
return dirs
@@ -384,7 +384,7 @@ def resolve_skill_config_values(
current values (or the declared default if the key isn't set).
Path values are expanded via ``os.path.expanduser``.
"""
- config_path = get_hermes_home() / "config.yaml"
+ config_path = get_config_path()
config: Dict[str, Any] = {}
if config_path.exists():
try:
diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py
index 04f0c1deb..352aecb33 100644
--- a/gateway/platforms/base.py
+++ b/gateway/platforms/base.py
@@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC):
result = handler(self)
if asyncio.iscoroutine(result):
await result
-
+
+ def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
+ """Acquire a scoped lock for this adapter. Returns True on success."""
+ from gateway.status import acquire_scoped_lock
+ self._platform_lock_scope = scope
+ self._platform_lock_identity = identity
+ acquired, existing = acquire_scoped_lock(
+ scope, identity, metadata={'platform': self.platform.value}
+ )
+ if acquired:
+ return True
+ owner_pid = existing.get('pid') if isinstance(existing, dict) else None
+ message = (
+ f'{resource_desc} already in use'
+ + (f' (PID {owner_pid})' if owner_pid else '')
+ + '. Stop the other gateway first.'
+ )
+ logger.error('[%s] %s', self.name, message)
+ self._set_fatal_error(f'{scope}_lock', message, retryable=False)
+ return False
+
+ def _release_platform_lock(self) -> None:
+ """Release the scoped lock acquired by _acquire_platform_lock."""
+ identity = getattr(self, '_platform_lock_identity', None)
+ if not identity:
+ return
+ from gateway.status import release_scoped_lock
+ release_scoped_lock(self._platform_lock_scope, identity)
+ self._platform_lock_identity = None
+
@property
def name(self) -> str:
"""Human-readable name for this adapter."""
diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py
index f50cd9503..115000996 100644
--- a/gateway/platforms/bluebubbles.py
+++ b/gateway/platforms/bluebubbles.py
@@ -30,6 +30,7 @@ from gateway.platforms.base import (
cache_audio_from_bytes,
cache_document_from_bytes,
)
+from gateway.platforms.helpers import strip_markdown
logger = logging.getLogger(__name__)
@@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str:
return value.rstrip("/")
-def _strip_markdown(text: str) -> str:
- """Strip common markdown formatting for iMessage plain-text delivery."""
- text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
- text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
- text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
- text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
- text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
- text = re.sub(r"`(.+?)`", r"\1", text)
- text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
- text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
- text = re.sub(r"\n{3,}", "\n\n", text)
- return text.strip()
+
# ---------------------------------------------------------------------------
@@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
- text = _strip_markdown(content or "")
+ text = strip_markdown(content or "")
if not text:
return SendResult(success=False, error="BlueBubbles send requires text")
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
@@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
return info
def format_message(self, content: str) -> str:
- return _strip_markdown(content)
+ return strip_markdown(content)
# ------------------------------------------------------------------
# Inbound attachment downloading (from #4588)
diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py
index e83b902df..5d50deca5 100644
--- a/gateway/platforms/dingtalk.py
+++ b/gateway/platforms/dingtalk.py
@@ -42,6 +42,7 @@ except ImportError:
httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig
+from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@@ -52,8 +53,6 @@ from gateway.platforms.base import (
logger = logging.getLogger(__name__)
MAX_MESSAGE_LENGTH = 20000
-DEDUP_WINDOW_SECONDS = 300
-DEDUP_MAX_SIZE = 1000
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
_SESSION_WEBHOOKS_MAX = 500
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
@@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter):
self._stream_task: Optional[asyncio.Task] = None
self._http_client: Optional["httpx.AsyncClient"] = None
- # Message deduplication: msg_id -> timestamp
- self._seen_messages: Dict[str, float] = {}
+ # Message deduplication
+ self._dedup = MessageDeduplicator(max_size=1000)
# Map chat_id -> session_webhook for reply routing
self._session_webhooks: Dict[str, str] = {}
@@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter):
self._stream_client = None
self._session_webhooks.clear()
- self._seen_messages.clear()
+ self._dedup.clear()
logger.info("[%s] Disconnected", self.name)
# -- Inbound message processing -----------------------------------------
@@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter):
async def _on_message(self, message: "ChatbotMessage") -> None:
"""Process an incoming DingTalk chatbot message."""
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
- if self._is_duplicate(msg_id):
+ if self._dedup.is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
return
@@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter):
content = " ".join(parts).strip()
return content
- # -- Deduplication ------------------------------------------------------
-
- def _is_duplicate(self, msg_id: str) -> bool:
- """Check and record a message ID. Returns True if already seen."""
- now = time.time()
- if len(self._seen_messages) > DEDUP_MAX_SIZE:
- cutoff = now - DEDUP_WINDOW_SECONDS
- self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
-
- if msg_id in self._seen_messages:
- return True
- self._seen_messages[msg_id] = now
- return False
-
# -- Outbound messaging -------------------------------------------------
async def send(
diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py
index dcf05a162..b1d07e5d6 100644
--- a/gateway/platforms/discord.py
+++ b/gateway/platforms/discord.py
@@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
import re
+from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter):
# Track threads where the bot has participated so follow-up messages
# in those threads don't require @mention. Persisted to disk so the
# set survives gateway restarts.
- self._bot_participated_threads: set = self._load_participated_threads()
+ self._threads = ThreadParticipationTracker("discord")
# Persistent typing indicator loops per channel (DMs don't reliably
# show the standard typing gateway event for bots)
self._typing_tasks: Dict[str, asyncio.Task] = {}
self._bot_task: Optional[asyncio.Task] = None
- # Cap to prevent unbounded growth (Discord threads get archived).
- self._MAX_TRACKED_THREADS = 500
- # Dedup cache: message_id → timestamp. Prevents duplicate bot
- # responses when Discord RESUME replays events after reconnects.
- self._seen_messages: Dict[str, float] = {}
- self._SEEN_TTL = 300 # 5 minutes
- self._SEEN_MAX = 2000 # prune threshold
+ # Dedup cache: prevents duplicate bot responses when Discord
+ # RESUME replays events after reconnects.
+ self._dedup = MessageDeduplicator()
# Reply threading mode: "off" (no replies), "first" (reply on first
# chunk only, default), "all" (reply-reference on every chunk).
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
@@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter):
return False
try:
- # Acquire scoped lock to prevent duplicate bot token usage
- from gateway.status import acquire_scoped_lock
- self._token_lock_identity = self.config.token
- acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
- if not acquired:
- owner_pid = existing.get('pid') if isinstance(existing, dict) else None
- message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
- logger.error('[%s] %s', self.name, message)
- self._set_fatal_error('discord_token_lock', message, retryable=False)
+ if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
return False
-
# Parse allowed user entries (may contain usernames or IDs)
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
if allowed_env:
@@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter):
@self._client.event
async def on_message(message: DiscordMessage):
# Dedup: Discord RESUME replays events after reconnects (#4777)
- msg_id = str(message.id)
- now = time.time()
- if msg_id in adapter_self._seen_messages:
+ if adapter_self._dedup.is_duplicate(str(message.id)):
return
- adapter_self._seen_messages[msg_id] = now
- if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX:
- cutoff = now - adapter_self._SEEN_TTL
- adapter_self._seen_messages = {
- k: v for k, v in adapter_self._seen_messages.items()
- if v > cutoff
- }
# Always ignore our own messages
if message.author == self._client.user:
@@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter):
except asyncio.TimeoutError:
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
- try:
- from gateway.status import release_scoped_lock
- if getattr(self, '_token_lock_identity', None):
- release_scoped_lock('discord-bot-token', self._token_lock_identity)
- self._token_lock_identity = None
- except Exception:
- pass
+ self._release_platform_lock()
return False
except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
- try:
- from gateway.status import release_scoped_lock
- if getattr(self, '_token_lock_identity', None):
- release_scoped_lock('discord-bot-token', self._token_lock_identity)
- self._token_lock_identity = None
- except Exception:
- pass
+ self._release_platform_lock()
return False
async def disconnect(self) -> None:
@@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter):
self._client = None
self._ready_event.clear()
- # Release the token lock
- try:
- from gateway.status import release_scoped_lock
- if getattr(self, '_token_lock_identity', None):
- release_scoped_lock('discord-bot-token', self._token_lock_identity)
- self._token_lock_identity = None
- except Exception:
- pass
+ self._release_platform_lock()
logger.info("[%s] Disconnected", self.name)
@@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so follow-ups don't require @mention
if thread_id:
- self._track_thread(thread_id)
+ self._threads.mark(thread_id)
# If a message was provided, kick off a new Hermes session in the thread
starter = (message or "").strip()
@@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter):
return f"{parent_name} / {thread_name}"
return thread_name
- # ------------------------------------------------------------------
- # Thread participation persistence
- # ------------------------------------------------------------------
-
- @staticmethod
- def _thread_state_path() -> Path:
- """Path to the persisted thread participation set."""
- from hermes_cli.config import get_hermes_home
- return get_hermes_home() / "discord_threads.json"
-
- @classmethod
- def _load_participated_threads(cls) -> set:
- """Load persisted thread IDs from disk."""
- path = cls._thread_state_path()
- try:
- if path.exists():
- data = json.loads(path.read_text(encoding="utf-8"))
- if isinstance(data, list):
- return set(data)
- except Exception as e:
- logger.debug("Could not load discord thread state: %s", e)
- return set()
-
- def _save_participated_threads(self) -> None:
- """Persist the current thread set to disk (best-effort)."""
- path = self._thread_state_path()
- try:
- # Trim to most recent entries if over cap
- thread_list = list(self._bot_participated_threads)
- if len(thread_list) > self._MAX_TRACKED_THREADS:
- thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
- self._bot_participated_threads = set(thread_list)
- path.parent.mkdir(parents=True, exist_ok=True)
- path.write_text(json.dumps(thread_list), encoding="utf-8")
- except Exception as e:
- logger.debug("Could not save discord thread state: %s", e)
-
- def _track_thread(self, thread_id: str) -> None:
- """Add a thread to the participation set and persist."""
- if thread_id not in self._bot_participated_threads:
- self._bot_participated_threads.add(thread_id)
- self._save_participated_threads()
-
async def _handle_message(self, message: DiscordMessage) -> None:
"""Handle incoming Discord messages."""
# In server channels (not DMs), require the bot to be @mentioned
@@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Skip the mention check if the message is in a thread where
# the bot has previously participated (auto-created or replied in).
- in_bot_thread = is_thread and thread_id in self._bot_participated_threads
+ in_bot_thread = is_thread and thread_id in self._threads
if require_mention and not is_free_channel and not in_bot_thread:
if self._client.user not in message.mentions:
@@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter):
is_thread = True
thread_id = str(thread.id)
auto_threaded_channel = thread
- self._track_thread(thread_id)
+ self._threads.mark(thread_id)
# Determine message type
msg_type = MessageType.TEXT
@@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so the bot won't require @mention for
# follow-up messages in threads it has already engaged in.
if thread_id:
- self._track_thread(thread_id)
+ self._threads.mark(thread_id)
# Only batch plain text messages — commands, media, etc. dispatch
# immediately since they won't be split by the Discord client.
diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py
index a88c7e52b..16f5467b2 100644
--- a/gateway/platforms/feishu.py
+++ b/gateway/platforms/feishu.py
@@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str:
def _strip_markdown_to_plain_text(text: str) -> str:
+ """Strip markdown formatting to plain text for Feishu text fallbacks.
+
+ Delegates common markdown stripping to the shared helper and adds
+ Feishu-specific patterns (blockquotes, strikethrough, underline tags,
+ horizontal rules, \\r\\n normalisation).
+ """
+ from gateway.platforms.helpers import strip_markdown
plain = text.replace("\r\n", "\n")
plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain)
- plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE)
plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE)
plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE)
- plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain)
- plain = re.sub(r"`([^`\n]+)`", r"\1", plain)
- plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain)
- plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain)
plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain)
plain = re.sub(r"([\s\S]*?)", r"\1", plain)
- plain = re.sub(r"\n{3,}", "\n\n", plain)
- return plain.strip()
+ plain = strip_markdown(plain)
+ return plain
def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]:
diff --git a/gateway/platforms/helpers.py b/gateway/platforms/helpers.py
new file mode 100644
index 000000000..c834dd89c
--- /dev/null
+++ b/gateway/platforms/helpers.py
@@ -0,0 +1,261 @@
+"""Shared helper classes for gateway platform adapters.
+
+Extracts common patterns that were duplicated across 5-7 adapters:
+message deduplication, text batch aggregation, markdown stripping,
+and thread participation tracking.
+"""
+
+import asyncio
+import json
+import logging
+import re
+import time
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, Optional
+
+if TYPE_CHECKING:
+ from gateway.platforms.base import BasePlatformAdapter, MessageEvent
+
+logger = logging.getLogger(__name__)
+
+
+# ─── Message Deduplication ────────────────────────────────────────────────────
+
+
+class MessageDeduplicator:
+ """TTL-based message deduplication cache.
+
+ Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
+ previously duplicated in discord, slack, dingtalk, wecom, weixin,
+ mattermost, and feishu adapters.
+
+ Usage::
+
+ self._dedup = MessageDeduplicator()
+
+ # In message handler:
+ if self._dedup.is_duplicate(msg_id):
+ return
+ """
+
+ def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
+ self._seen: Dict[str, float] = {}
+ self._max_size = max_size
+ self._ttl = ttl_seconds
+
+ def is_duplicate(self, msg_id: str) -> bool:
+ """Return True if *msg_id* was already seen within the TTL window."""
+ if not msg_id:
+ return False
+ now = time.time()
+ if msg_id in self._seen:
+ return True
+ self._seen[msg_id] = now
+ if len(self._seen) > self._max_size:
+ cutoff = now - self._ttl
+ self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
+ return False
+
+ def clear(self):
+ """Clear all tracked messages."""
+ self._seen.clear()
+
+
+# ─── Text Batch Aggregation ──────────────────────────────────────────────────
+
+
+class TextBatchAggregator:
+ """Aggregates rapid-fire text events into single messages.
+
+ Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
+ previously duplicated in telegram, discord, matrix, wecom, and feishu.
+
+ Usage::
+
+ self._text_batcher = TextBatchAggregator(
+ handler=self._message_handler,
+ batch_delay=0.6,
+ split_threshold=1900,
+ )
+
+ # In message dispatch:
+ if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
+ self._text_batcher.enqueue(event, session_key)
+ return
+ """
+
+ def __init__(
+ self,
+ handler,
+ *,
+ batch_delay: float = 0.6,
+ split_delay: float = 2.0,
+ split_threshold: int = 4000,
+ ):
+ self._handler = handler
+ self._batch_delay = batch_delay
+ self._split_delay = split_delay
+ self._split_threshold = split_threshold
+ self._pending: Dict[str, "MessageEvent"] = {}
+ self._pending_tasks: Dict[str, asyncio.Task] = {}
+
+ def is_enabled(self) -> bool:
+ """Return True if batching is active (delay > 0)."""
+ return self._batch_delay > 0
+
+ def enqueue(self, event: "MessageEvent", key: str) -> None:
+ """Add *event* to the pending batch for *key*."""
+ chunk_len = len(event.text or "")
+ existing = self._pending.get(key)
+ if not existing:
+ event._last_chunk_len = chunk_len # type: ignore[attr-defined]
+ self._pending[key] = event
+ else:
+ existing.text = f"{existing.text}\n{event.text}"
+ existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
+
+ # Cancel prior flush timer, start a new one
+ prior = self._pending_tasks.get(key)
+ if prior and not prior.done():
+ prior.cancel()
+ self._pending_tasks[key] = asyncio.create_task(self._flush(key))
+
+ async def _flush(self, key: str) -> None:
+ """Wait then dispatch the batched event for *key*."""
+ current_task = self._pending_tasks.get(key)
+ pending = self._pending.get(key)
+ last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
+
+ # Use longer delay when the last chunk looks like a split message
+ delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay
+ await asyncio.sleep(delay)
+
+ event = self._pending.pop(key, None)
+ if event:
+ try:
+ await self._handler(event)
+ except Exception:
+ logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
+
+ if self._pending_tasks.get(key) is current_task:
+ self._pending_tasks.pop(key, None)
+
+ def cancel_all(self) -> None:
+ """Cancel all pending flush tasks."""
+ for task in self._pending_tasks.values():
+ if not task.done():
+ task.cancel()
+ self._pending_tasks.clear()
+ self._pending.clear()
+
+
+# ─── Markdown Stripping ──────────────────────────────────────────────────────
+
+# Pre-compiled regexes for performance
+_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL)
+_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL)
+_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL)
+_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL)
+_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
+_RE_INLINE_CODE = re.compile(r"`(.+?)`")
+_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
+_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
+_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
+
+
+def strip_markdown(text: str) -> str:
+ """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).
+
+ Replaces the identical ``_strip_markdown()`` functions previously
+ duplicated in sms.py, bluebubbles.py, and feishu.py.
+ """
+ text = _RE_BOLD.sub(r"\1", text)
+ text = _RE_ITALIC_STAR.sub(r"\1", text)
+ text = _RE_BOLD_UNDER.sub(r"\1", text)
+ text = _RE_ITALIC_UNDER.sub(r"\1", text)
+ text = _RE_CODE_BLOCK.sub("", text)
+ text = _RE_INLINE_CODE.sub(r"\1", text)
+ text = _RE_HEADING.sub("", text)
+ text = _RE_LINK.sub(r"\1", text)
+ text = _RE_MULTI_NEWLINE.sub("\n\n", text)
+ return text.strip()
+
+
+# ─── Thread Participation Tracking ───────────────────────────────────────────
+
+
+class ThreadParticipationTracker:
+ """Persistent tracking of threads the bot has participated in.
+
+ Replaces the identical ``_load/_save_participated_threads`` +
+ ``_mark_thread_participated`` pattern previously duplicated in
+ discord.py and matrix.py.
+
+ Usage::
+
+ self._threads = ThreadParticipationTracker("discord")
+
+ # Check membership:
+ if thread_id in self._threads:
+ ...
+
+ # Mark participation:
+ self._threads.mark(thread_id)
+ """
+
+ _MAX_TRACKED = 500
+
+ def __init__(self, platform_name: str, max_tracked: int = 500):
+ self._platform = platform_name
+ self._max_tracked = max_tracked
+ self._threads: set = self._load()
+
+ def _state_path(self) -> Path:
+ from hermes_constants import get_hermes_home
+ return get_hermes_home() / f"{self._platform}_threads.json"
+
+ def _load(self) -> set:
+ path = self._state_path()
+ if path.exists():
+ try:
+ return set(json.loads(path.read_text(encoding="utf-8")))
+ except Exception:
+ pass
+ return set()
+
+ def _save(self) -> None:
+ path = self._state_path()
+ path.parent.mkdir(parents=True, exist_ok=True)
+ thread_list = list(self._threads)
+ if len(thread_list) > self._max_tracked:
+ thread_list = thread_list[-self._max_tracked:]
+ self._threads = set(thread_list)
+ path.write_text(json.dumps(thread_list), encoding="utf-8")
+
+ def mark(self, thread_id: str) -> None:
+ """Mark *thread_id* as participated and persist."""
+ if thread_id not in self._threads:
+ self._threads.add(thread_id)
+ self._save()
+
+ def __contains__(self, thread_id: str) -> bool:
+ return thread_id in self._threads
+
+ def clear(self) -> None:
+ self._threads.clear()
+
+
+# ─── Phone Number Redaction ──────────────────────────────────────────────────
+
+
+def redact_phone(phone: str) -> str:
+ """Redact a phone number for logging, preserving country code and last 4.
+
+ Replaces the identical ``_redact_phone()`` functions in signal.py,
+ sms.py, and bluebubbles.py.
+ """
+ if not phone:
+ return ""
+ if len(phone) <= 8:
+ return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
+ return phone[:4] + "****" + phone[-4:]
diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py
index 7daf2e70e..349f962d2 100644
--- a/gateway/platforms/matrix.py
+++ b/gateway/platforms/matrix.py
@@ -92,6 +92,7 @@ from gateway.platforms.base import (
ProcessingOutcome,
SendResult,
)
+from gateway.platforms.helpers import ThreadParticipationTracker
logger = logging.getLogger(__name__)
@@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter):
self._pending_megolm: list = []
# Thread participation tracking (for require_mention bypass)
- self._bot_participated_threads: set = self._load_participated_threads()
- self._MAX_TRACKED_THREADS = 500
+ self._threads = ThreadParticipationTracker("matrix")
# Mention/thread gating — parsed once from env vars.
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
@@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter):
# Require-mention gating.
if not is_dm:
is_free_room = room_id in self._free_rooms
- in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
+ in_bot_thread = bool(thread_id and thread_id in self._threads)
if self._require_mention and not is_free_room and not in_bot_thread:
if not is_mentioned:
return None
@@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter):
# DM mention-thread.
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
thread_id = event_id
- self._track_thread(thread_id)
+ self._threads.mark(thread_id)
# Strip mention from body.
if is_mentioned:
@@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter):
# Auto-thread.
if not is_dm and not thread_id and self._auto_thread:
thread_id = event_id
- self._track_thread(thread_id)
+ self._threads.mark(thread_id)
display_name = await self._get_display_name(room_id, sender)
source = self.build_source(
@@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter):
)
if thread_id:
- self._track_thread(thread_id)
+ self._threads.mark(thread_id)
self._background_read_receipt(room_id, event_id)
@@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter):
for rid in self._joined_rooms
}
- # ------------------------------------------------------------------
- # Thread participation tracking
- # ------------------------------------------------------------------
-
- @staticmethod
- def _thread_state_path() -> Path:
- """Path to the persisted thread participation set."""
- from hermes_cli.config import get_hermes_home
- return get_hermes_home() / "matrix_threads.json"
-
- @classmethod
- def _load_participated_threads(cls) -> set:
- """Load persisted thread IDs from disk."""
- path = cls._thread_state_path()
- try:
- if path.exists():
- data = json.loads(path.read_text(encoding="utf-8"))
- if isinstance(data, list):
- return set(data)
- except Exception as e:
- logger.debug("Could not load matrix thread state: %s", e)
- return set()
-
- def _save_participated_threads(self) -> None:
- """Persist the current thread set to disk (best-effort)."""
- path = self._thread_state_path()
- try:
- thread_list = list(self._bot_participated_threads)
- if len(thread_list) > self._MAX_TRACKED_THREADS:
- thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
- self._bot_participated_threads = set(thread_list)
- path.parent.mkdir(parents=True, exist_ok=True)
- path.write_text(json.dumps(thread_list), encoding="utf-8")
- except Exception as e:
- logger.debug("Could not save matrix thread state: %s", e)
-
- def _track_thread(self, thread_id: str) -> None:
- """Add a thread to the participation set and persist."""
- if thread_id not in self._bot_participated_threads:
- self._bot_participated_threads.add(thread_id)
- self._save_participated_threads()
-
# ------------------------------------------------------------------
# Mention detection helpers
# ------------------------------------------------------------------
diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py
index 56f29e876..23a86f02b 100644
--- a/gateway/platforms/mattermost.py
+++ b/gateway/platforms/mattermost.py
@@ -18,11 +18,11 @@ import json
import logging
import os
import re
-import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from gateway.config import Platform, PlatformConfig
+from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter):
or os.getenv("MATTERMOST_REPLY_MODE", "off")
).lower()
- # Dedup cache: post_id → timestamp (prevent reprocessing)
- self._seen_posts: Dict[str, float] = {}
- self._SEEN_MAX = 2000
- self._SEEN_TTL = 300 # 5 minutes
+ # Dedup cache (prevent reprocessing)
+ self._dedup = MessageDeduplicator()
# ------------------------------------------------------------------
# HTTP helpers
@@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter):
post_id = post.get("id", "")
# Dedup.
- self._prune_seen()
- if post_id in self._seen_posts:
+ if self._dedup.is_duplicate(post_id):
return
- self._seen_posts[post_id] = time.time()
# Build message event.
channel_id = post.get("channel_id", "")
@@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter):
await self.handle_message(msg_event)
- def _prune_seen(self) -> None:
- """Remove expired entries from the dedup cache."""
- if len(self._seen_posts) < self._SEEN_MAX:
- return
- now = time.time()
- self._seen_posts = {
- pid: ts
- for pid, ts in self._seen_posts.items()
- if now - ts < self._SEEN_TTL
- }
+
diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py
index 08b62f2a6..8ef7bd0d6 100644
--- a/gateway/platforms/signal.py
+++ b/gateway/platforms/signal.py
@@ -37,6 +37,7 @@ from gateway.platforms.base import (
cache_document_from_bytes,
cache_image_from_url,
)
+from gateway.platforms.helpers import redact_phone
logger = logging.getLogger(__name__)
@@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
-# E.164 phone number pattern for redaction
-_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
-
-
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
-def _redact_phone(phone: str) -> str:
- """Redact a phone number for logging: +15551234567 -> +155****4567."""
- if not phone:
- return ""
- if len(phone) <= 8:
- return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
- return phone[:4] + "****" + phone[-4:]
-
def _parse_comma_list(value: str) -> List[str]:
"""Split a comma-separated string into a list, stripping whitespace."""
@@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter):
self._recent_sent_timestamps: set = set()
self._max_recent_timestamps = 50
- self._phone_lock_identity: Optional[str] = None
-
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
- self.http_url, _redact_phone(self.account),
+ self.http_url, redact_phone(self.account),
"enabled" if self.group_allow_from else "disabled")
# ------------------------------------------------------------------
@@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter):
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone
try:
- from gateway.status import acquire_scoped_lock
-
- self._phone_lock_identity = self.account
- acquired, existing = acquire_scoped_lock(
- "signal-phone",
- self._phone_lock_identity,
- metadata={"platform": self.platform.value},
- )
- if not acquired:
- owner_pid = existing.get("pid") if isinstance(existing, dict) else None
- message = (
- "Another local Hermes gateway is already using this Signal account"
- + (f" (PID {owner_pid})." if owner_pid else ".")
- + " Stop the other gateway before starting a second Signal listener."
- )
- logger.error("Signal: %s", message)
- self._set_fatal_error("signal_phone_lock", message, retryable=False)
+ if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'):
return False
except Exception as e:
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
@@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
await self.client.aclose()
self.client = None
- if self._phone_lock_identity:
- try:
- from gateway.status import release_scoped_lock
- release_scoped_lock("signal-phone", self._phone_lock_identity)
- except Exception as e:
- logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
- self._phone_lock_identity = None
+ self._release_platform_lock()
logger.info("Signal: disconnected")
@@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter):
)
logger.debug("Signal: message from %s in %s: %s",
- _redact_phone(sender), chat_id[:20], (text or "")[:50])
+ redact_phone(sender), chat_id[:20], (text or "")[:50])
await self.handle_message(event)
diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py
index 361f74882..8f9934cf7 100644
--- a/gateway/platforms/slack.py
+++ b/gateway/platforms/slack.py
@@ -33,6 +33,7 @@ from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
+from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter):
self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient
self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id
self._channel_team: Dict[str, str] = {} # channel_id → team_id
- # Dedup cache: event_ts → timestamp. Prevents duplicate bot
- # responses when Socket Mode reconnects redeliver events.
- self._seen_messages: Dict[str, float] = {}
- self._SEEN_TTL = 300 # 5 minutes
- self._SEEN_MAX = 2000 # prune threshold
+ # Dedup cache: prevents duplicate bot responses when Socket Mode
+ # reconnects redeliver events.
+ self._dedup = MessageDeduplicator()
# Track pending approval message_ts → resolved flag to prevent
# double-clicks on approval buttons.
self._approval_resolved: Dict[str, bool] = {}
@@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter):
logger.warning("[Slack] Failed to read %s: %s", tokens_file, e)
try:
- # Acquire scoped lock to prevent duplicate app token usage
- from gateway.status import acquire_scoped_lock
- self._token_lock_identity = app_token
- acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
- if not acquired:
- owner_pid = existing.get('pid') if isinstance(existing, dict) else None
- message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
- logger.error('[%s] %s', self.name, message)
- self._set_fatal_error('slack_token_lock', message, retryable=False)
+ if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'):
return False
# First token is the primary — used for AsyncApp / Socket Mode
@@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter):
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
self._running = False
- # Release the token lock (use stored identity, not re-read env)
- try:
- from gateway.status import release_scoped_lock
- if getattr(self, '_token_lock_identity', None):
- release_scoped_lock('slack-app-token', self._token_lock_identity)
- self._token_lock_identity = None
- except Exception:
- pass
+ self._release_platform_lock()
logger.info("[Slack] Disconnected")
@@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter):
"""Handle an incoming Slack message event."""
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
event_ts = event.get("ts", "")
- if event_ts:
- now = time.time()
- if event_ts in self._seen_messages:
- return
- self._seen_messages[event_ts] = now
- if len(self._seen_messages) > self._SEEN_MAX:
- cutoff = now - self._SEEN_TTL
- self._seen_messages = {
- k: v for k, v in self._seen_messages.items()
- if v > cutoff
- }
+ if event_ts and self._dedup.is_duplicate(event_ts):
+ return
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
# "none" — ignore all bot messages (default, backward-compatible)
diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py
index a0760199b..953ec5c5e 100644
--- a/gateway/platforms/sms.py
+++ b/gateway/platforms/sms.py
@@ -19,7 +19,6 @@ import asyncio
import base64
import logging
import os
-import re
import urllib.parse
from typing import Any, Dict, Optional
@@ -30,6 +29,7 @@ from gateway.platforms.base import (
MessageType,
SendResult,
)
+from gateway.platforms.helpers import redact_phone, strip_markdown
logger = logging.getLogger(__name__)
@@ -37,18 +37,6 @@ TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
DEFAULT_WEBHOOK_PORT = 8080
-# E.164 phone number pattern for redaction
-_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
-
-
-def _redact_phone(phone: str) -> str:
- """Redact a phone number for logging: +15551234567 -> +1555***4567."""
- if not phone:
- return ""
- if len(phone) <= 8:
- return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
- return phone[:5] + "***" + phone[-4:]
-
def check_sms_requirements() -> bool:
"""Check if SMS adapter dependencies are available."""
@@ -114,7 +102,7 @@ class SmsAdapter(BasePlatformAdapter):
logger.info(
"[sms] Twilio webhook server listening on port %d, from: %s",
self._webhook_port,
- _redact_phone(self._from_number),
+ redact_phone(self._from_number),
)
return True
@@ -163,7 +151,7 @@ class SmsAdapter(BasePlatformAdapter):
error_msg = body.get("message", str(body))
logger.error(
"[sms] send failed to %s: %s %s",
- _redact_phone(chat_id),
+ redact_phone(chat_id),
resp.status,
error_msg,
)
@@ -174,7 +162,7 @@ class SmsAdapter(BasePlatformAdapter):
msg_sid = body.get("sid", "")
last_result = SendResult(success=True, message_id=msg_sid)
except Exception as e:
- logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
+ logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e)
return SendResult(success=False, error=str(e))
finally:
# Close session only if we created a fallback (no persistent session)
@@ -192,16 +180,7 @@ class SmsAdapter(BasePlatformAdapter):
def format_message(self, content: str) -> str:
"""Strip markdown — SMS renders it as literal characters."""
- content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
- content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
- content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
- content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
- content = re.sub(r"```[a-z]*\n?", "", content)
- content = re.sub(r"`(.+?)`", r"\1", content)
- content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
- content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
- content = re.sub(r"\n{3,}", "\n\n", content)
- return content.strip()
+ return strip_markdown(content)
# ------------------------------------------------------------------
# Twilio webhook handler
@@ -236,7 +215,7 @@ class SmsAdapter(BasePlatformAdapter):
# Ignore messages from our own number (echo prevention)
if from_number == self._from_number:
- logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
+ logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number))
return web.Response(
text='',
content_type="application/xml",
@@ -244,8 +223,8 @@ class SmsAdapter(BasePlatformAdapter):
logger.info(
"[sms] inbound from %s -> %s: %s",
- _redact_phone(from_number),
- _redact_phone(to_number),
+ redact_phone(from_number),
+ redact_phone(to_number),
text[:80],
)
diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py
index 8b4e43514..884ef9c45 100644
--- a/gateway/platforms/telegram.py
+++ b/gateway/platforms/telegram.py
@@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter):
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
self._pending_text_batches: Dict[str, MessageEvent] = {}
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
- self._token_lock_identity: Optional[str] = None
self._polling_error_task: Optional[asyncio.Task] = None
self._polling_conflict_count: int = 0
self._polling_network_error_count: int = 0
@@ -497,23 +496,7 @@ class TelegramAdapter(BasePlatformAdapter):
return False
try:
- from gateway.status import acquire_scoped_lock
-
- self._token_lock_identity = self.config.token
- acquired, existing = acquire_scoped_lock(
- "telegram-bot-token",
- self._token_lock_identity,
- metadata={"platform": self.platform.value},
- )
- if not acquired:
- owner_pid = existing.get("pid") if isinstance(existing, dict) else None
- message = (
- "Another local Hermes gateway is already using this Telegram bot token"
- + (f" (PID {owner_pid})." if owner_pid else ".")
- + " Stop the other gateway before starting a second Telegram poller."
- )
- logger.error("[%s] %s", self.name, message)
- self._set_fatal_error("telegram_token_lock", message, retryable=False)
+ if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'):
return False
# Build the application
@@ -737,12 +720,7 @@ class TelegramAdapter(BasePlatformAdapter):
return True
except Exception as e:
- if self._token_lock_identity:
- try:
- from gateway.status import release_scoped_lock
- release_scoped_lock("telegram-bot-token", self._token_lock_identity)
- except Exception:
- pass
+ self._release_platform_lock()
message = f"Telegram startup failed: {e}"
self._set_fatal_error("telegram_connect_error", message, retryable=True)
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
@@ -768,12 +746,7 @@ class TelegramAdapter(BasePlatformAdapter):
await self._app.shutdown()
except Exception as e:
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
- if self._token_lock_identity:
- try:
- from gateway.status import release_scoped_lock
- release_scoped_lock("telegram-bot-token", self._token_lock_identity)
- except Exception as e:
- logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
+ self._release_platform_lock()
for task in self._pending_photo_batch_tasks.values():
if task and not task.done():
@@ -784,7 +757,6 @@ class TelegramAdapter(BasePlatformAdapter):
self._mark_disconnected()
self._app = None
self._bot = None
- self._token_lock_identity = None
logger.info("[%s] Disconnected from Telegram", self.name)
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:
diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py
index aa07dc6a9..a0e71e01b 100644
--- a/gateway/platforms/wecom.py
+++ b/gateway/platforms/wecom.py
@@ -59,6 +59,7 @@ except ImportError:
httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig
+from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0
HEARTBEAT_INTERVAL_SECONDS = 30.0
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
-DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000
IMAGE_MAX_BYTES = 10 * 1024 * 1024
@@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter):
self._listen_task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._pending_responses: Dict[str, asyncio.Future] = {}
- self._seen_messages: Dict[str, float] = {}
+ self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
self._reply_req_ids: Dict[str, str] = {}
# Text batching: merge rapid successive messages (Telegram-style).
@@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter):
await self._http_client.aclose()
self._http_client = None
- self._seen_messages.clear()
+ self._dedup.clear()
logger.info("[%s] Disconnected", self.name)
async def _cleanup_ws(self) -> None:
@@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter):
return
msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex)
- if self._is_duplicate(msg_id):
+ if self._dedup.is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
return
self._remember_reply_req_id(msg_id, self._payload_req_id(payload))
@@ -839,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter):
wildcard = self._groups.get("*")
return wildcard if isinstance(wildcard, dict) else {}
- def _is_duplicate(self, msg_id: str) -> bool:
- now = time.time()
- if len(self._seen_messages) > DEDUP_MAX_SIZE:
- cutoff = now - DEDUP_WINDOW_SECONDS
- self._seen_messages = {
- key: ts for key, ts in self._seen_messages.items() if ts > cutoff
- }
- if self._reply_req_ids:
- self._reply_req_ids = {
- key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages
- }
-
- if msg_id in self._seen_messages:
- return True
-
- self._seen_messages[msg_id] = now
- return False
-
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
normalized_message_id = str(message_id or "").strip()
normalized_req_id = str(req_id or "").strip()
diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py
index 5e0208c77..3a4a80540 100644
--- a/gateway/platforms/weixin.py
+++ b/gateway/platforms/weixin.py
@@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate
CRYPTO_AVAILABLE = False
from gateway.config import Platform, PlatformConfig
+from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@@ -1008,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter):
self._typing_cache = TypingTicketCache()
self._session: Optional[aiohttp.ClientSession] = None
self._poll_task: Optional[asyncio.Task] = None
- self._seen_messages: Dict[str, float] = {}
- self._token_lock_identity: Optional[str] = None
+ self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip()
self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip()
@@ -1067,23 +1067,7 @@ class WeixinAdapter(BasePlatformAdapter):
return False
try:
- from gateway.status import acquire_scoped_lock
-
- self._token_lock_identity = self._token
- acquired, existing = acquire_scoped_lock(
- "weixin-bot-token",
- self._token_lock_identity,
- metadata={"platform": self.platform.value},
- )
- if not acquired:
- owner_pid = existing.get("pid") if isinstance(existing, dict) else None
- message = (
- "Another local Hermes gateway is already using this Weixin token"
- + (f" (PID {owner_pid})." if owner_pid else ".")
- + " Stop the other gateway before starting a second Weixin poller."
- )
- logger.error("[%s] %s", self.name, message)
- self._set_fatal_error("weixin_token_lock", message, retryable=False)
+ if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'):
return False
except Exception as exc:
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
@@ -1107,12 +1091,7 @@ class WeixinAdapter(BasePlatformAdapter):
if self._session and not self._session.closed:
await self._session.close()
self._session = None
- if self._token_lock_identity:
- try:
- from gateway.status import release_scoped_lock
- release_scoped_lock("weixin-bot-token", self._token_lock_identity)
- except Exception as exc:
- logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True)
+ self._release_platform_lock()
self._mark_disconnected()
logger.info("[%s] Disconnected", self.name)
@@ -1190,16 +1169,8 @@ class WeixinAdapter(BasePlatformAdapter):
return
message_id = str(message.get("message_id") or "").strip()
- if message_id:
- now = time.time()
- self._seen_messages = {
- key: value
- for key, value in self._seen_messages.items()
- if now - value < MESSAGE_DEDUP_TTL_SECONDS
- }
- if message_id in self._seen_messages:
- return
- self._seen_messages[message_id] = now
+ if message_id and self._dedup.is_duplicate(message_id):
+ return
chat_type, effective_chat_id = _guess_chat_type(message, self._account_id)
if chat_type == "group":
diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py
index a6475dcb8..c616f7244 100644
--- a/gateway/platforms/whatsapp.py
+++ b/gateway/platforms/whatsapp.py
@@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._bridge_log: Optional[Path] = None
self._poll_task: Optional[asyncio.Task] = None
self._http_session: Optional["aiohttp.ClientSession"] = None
- self._session_lock_identity: Optional[str] = None
def _whatsapp_require_mention(self) -> bool:
configured = self.config.extra.get("require_mention")
@@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
# Acquire scoped lock to prevent duplicate sessions
try:
- from gateway.status import acquire_scoped_lock
-
- self._session_lock_identity = str(self._session_path)
- acquired, existing = acquire_scoped_lock(
- "whatsapp-session",
- self._session_lock_identity,
- metadata={"platform": self.platform.value},
- )
- if not acquired:
- owner_pid = existing.get("pid") if isinstance(existing, dict) else None
- message = (
- "Another local Hermes gateway is already using this WhatsApp session"
- + (f" (PID {owner_pid})." if owner_pid else ".")
- + " Stop the other gateway before starting a second WhatsApp bridge."
- )
- logger.error("[%s] %s", self.name, message)
- self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
+ if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'):
return False
except Exception as e:
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
@@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
return True
except Exception as e:
- if self._session_lock_identity:
- try:
- from gateway.status import release_scoped_lock
- release_scoped_lock("whatsapp-session", self._session_lock_identity)
- except Exception:
- pass
+ self._release_platform_lock()
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
self._close_bridge_log()
return False
@@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
await self._http_session.close()
self._http_session = None
- if self._session_lock_identity:
- try:
- from gateway.status import release_scoped_lock
- release_scoped_lock("whatsapp-session", self._session_lock_identity)
- except Exception as e:
- logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
+ self._release_platform_lock()
self._mark_disconnected()
self._bridge_process = None
self._close_bridge_log()
- self._session_lock_identity = None
print(f"[{self.name}] Disconnected")
async def send(
diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py
index fcb7c2dc5..56b9fb63c 100644
--- a/hermes_cli/auth.py
+++ b/hermes_cli/auth.py
@@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
}
+# =============================================================================
+# Anthropic Key Helper
+# =============================================================================
+
+def get_anthropic_key() -> str:
+ """Return the first usable Anthropic credential, or ``""``.
+
+ Checks both the ``.env`` file (via ``get_env_value``) and the process
+ environment (``os.getenv``). The fallback order mirrors the
+ ``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple:
+
+ ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN
+ """
+ from hermes_cli.config import get_env_value
+
+ for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars:
+ value = get_env_value(var) or os.getenv(var, "")
+ if value:
+ return value
+ return ""
+
+
# =============================================================================
# Kimi Code Endpoint Detection
# =============================================================================
diff --git a/hermes_cli/cli_output.py b/hermes_cli/cli_output.py
new file mode 100644
index 000000000..3d454eb30
--- /dev/null
+++ b/hermes_cli/cli_output.py
@@ -0,0 +1,79 @@
+"""Shared CLI output helpers for Hermes CLI modules.
+
+Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
+functions previously duplicated across setup.py, tools_config.py,
+mcp_config.py, and memory_setup.py.
+"""
+
+import getpass
+import sys
+
+from hermes_cli.colors import Colors, color
+
+
+# ─── Print Helpers ────────────────────────────────────────────────────────────
+
+
+def print_info(text: str) -> None:
+ """Print a dim informational message."""
+ print(color(f" {text}", Colors.DIM))
+
+
+def print_success(text: str) -> None:
+ """Print a green success message with ✓ prefix."""
+ print(color(f"✓ {text}", Colors.GREEN))
+
+
+def print_warning(text: str) -> None:
+ """Print a yellow warning message with ⚠ prefix."""
+ print(color(f"⚠ {text}", Colors.YELLOW))
+
+
+def print_error(text: str) -> None:
+ """Print a red error message with ✗ prefix."""
+ print(color(f"✗ {text}", Colors.RED))
+
+
+def print_header(text: str) -> None:
+ """Print a bold yellow header."""
+ print(color(f"\n {text}", Colors.YELLOW))
+
+
+# ─── Input Prompts ────────────────────────────────────────────────────────────
+
+
+def prompt(
+ question: str,
+ default: str | None = None,
+ password: bool = False,
+) -> str:
+ """Prompt the user for input with optional default and password masking.
+
+ Replaces the four independent ``_prompt()`` / ``prompt()`` implementations
+ in setup.py, tools_config.py, mcp_config.py, and memory_setup.py.
+
+ Returns the user's input (stripped), or *default* if the user presses Enter.
+ Returns empty string on Ctrl-C or EOF.
+ """
+ suffix = f" [{default}]" if default else ""
+ display = color(f" {question}{suffix}: ", Colors.YELLOW)
+
+ try:
+ if password:
+ value = getpass.getpass(display)
+ else:
+ value = input(display)
+ value = value.strip()
+ return value if value else (default or "")
+ except (KeyboardInterrupt, EOFError):
+ print()
+ return ""
+
+
+def prompt_yes_no(question: str, default: bool = True) -> bool:
+ """Prompt for a yes/no answer. Returns bool."""
+ hint = "Y/n" if default else "y/N"
+ answer = prompt(f"{question} ({hint})")
+ if not answer:
+ return default
+ return answer.lower().startswith("y")
diff --git a/hermes_cli/config.py b/hermes_cli/config.py
index 4661455d1..c3cf0456e 100644
--- a/hermes_cli/config.py
+++ b/hermes_cli/config.py
@@ -2582,7 +2582,8 @@ def show_config():
for env_key, name in keys:
value = get_env_value(env_key)
print(f" {name:<14} {redact_key(value)}")
- anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY")
+ from hermes_cli.auth import get_anthropic_key
+ anthropic_value = get_anthropic_key()
print(f" {'Anthropic':<14} {redact_key(anthropic_value)}")
# Model settings
@@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str):
# Write only user config back (not the full merged defaults)
ensure_hermes_home()
- with open(config_path, 'w', encoding="utf-8") as f:
- yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
+ from utils import atomic_yaml_write
+ atomic_yaml_write(config_path, user_config, sort_keys=False)
# Keep .env in sync for keys that terminal_tool reads directly from env vars.
# config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc.
diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py
index f5f8a228a..13c904692 100644
--- a/hermes_cli/doctor.py
+++ b/hermes_cli/doctor.py
@@ -336,8 +336,8 @@ def run_doctor(args):
model_section[k] = raw_config.pop(k)
else:
raw_config.pop(k)
- with open(config_path, "w") as f:
- yaml.dump(raw_config, f, default_flow_style=False)
+ from utils import atomic_yaml_write
+ atomic_yaml_write(config_path, raw_config)
check_ok("Migrated stale root-level keys into model section")
fixed_count += 1
else:
@@ -686,7 +686,8 @@ def run_doctor(args):
else:
check_warn("OpenRouter API", "(not configured)")
- anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY")
+ from hermes_cli.auth import get_anthropic_key
+ anthropic_key = get_anthropic_key()
if anthropic_key:
print(" Checking Anthropic API...", end="", flush=True)
try:
diff --git a/hermes_cli/main.py b/hermes_cli/main.py
index e004a6e93..4b7dd600b 100644
--- a/hermes_cli/main.py
+++ b/hermes_cli/main.py
@@ -2549,13 +2549,8 @@ def _model_flow_anthropic(config, current_model=""):
from hermes_cli.models import _PROVIDER_MODELS
# Check ALL credential sources
- existing_key = (
- get_env_value("ANTHROPIC_TOKEN")
- or os.getenv("ANTHROPIC_TOKEN", "")
- or get_env_value("ANTHROPIC_API_KEY")
- or os.getenv("ANTHROPIC_API_KEY", "")
- or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
- )
+ from hermes_cli.auth import get_anthropic_key
+ existing_key = get_anthropic_key()
cc_available = False
try:
from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid
diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py
index 9154ed50a..cf2dde089 100644
--- a/hermes_cli/mcp_config.py
+++ b/hermes_cli/mcp_config.py
@@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool:
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
- display = f" {question}"
- if default:
- display += f" [{default}]"
- display += ": "
- try:
- if password:
- value = getpass.getpass(color(display, Colors.YELLOW))
- else:
- value = input(color(display, Colors.YELLOW))
- return value.strip() or default
- except (KeyboardInterrupt, EOFError):
- print()
- return default
+ from hermes_cli.cli_output import prompt as _shared_prompt
+ return _shared_prompt(question, default=default, password=password)
# ─── Config Helpers ───────────────────────────────────────────────────────────
diff --git a/hermes_cli/memory_setup.py b/hermes_cli/memory_setup.py
index 2843f4f44..1aa431367 100644
--- a/hermes_cli/memory_setup.py
+++ b/hermes_cli/memory_setup.py
@@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -
items: list of (label, description) tuples.
Returns selected index, or default on escape/quit.
"""
- try:
- import curses
- result = [default]
-
- def _menu(stdscr):
- curses.curs_set(0)
- if curses.has_colors():
- curses.start_color()
- curses.use_default_colors()
- curses.init_pair(1, curses.COLOR_GREEN, -1)
- curses.init_pair(2, curses.COLOR_YELLOW, -1)
- curses.init_pair(3, curses.COLOR_CYAN, -1)
- cursor = default
-
- while True:
- stdscr.clear()
- max_y, max_x = stdscr.getmaxyx()
-
- # Title
- try:
- stdscr.addnstr(0, 0, title, max_x - 1,
- curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
- stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
- curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
- except curses.error:
- pass
-
- for i, (label, desc) in enumerate(items):
- y = i + 3
- if y >= max_y - 1:
- break
- arrow = "→" if i == cursor else " "
- line = f" {arrow} {label}"
- if desc:
- line += f" {desc}"
-
- attr = curses.A_NORMAL
- if i == cursor:
- attr = curses.A_BOLD
- if curses.has_colors():
- attr |= curses.color_pair(1)
- try:
- stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
- except curses.error:
- pass
-
- stdscr.refresh()
- key = stdscr.getch()
-
- if key in (curses.KEY_UP, ord('k')):
- cursor = (cursor - 1) % len(items)
- elif key in (curses.KEY_DOWN, ord('j')):
- cursor = (cursor + 1) % len(items)
- elif key in (curses.KEY_ENTER, 10, 13):
- result[0] = cursor
- return
- elif key in (27, ord('q')):
- return
-
- curses.wrapper(_menu)
- return result[0]
-
- except Exception:
- # Fallback: numbered input
- print(f"\n {title}\n")
- for i, (label, desc) in enumerate(items):
- marker = "→" if i == default else " "
- d = f" {desc}" if desc else ""
- print(f" {marker} {i + 1}. {label}{d}")
- while True:
- try:
- val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
- if not val:
- return default
- idx = int(val) - 1
- if 0 <= idx < len(items):
- return idx
- except (ValueError, EOFError):
- return default
+ from hermes_cli.curses_ui import curses_radiolist
+ # Format (label, desc) tuples into display strings
+ display_items = [
+ f"{label} {desc}" if desc else label
+ for label, desc in items
+ ]
+ return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py
new file mode 100644
index 000000000..18307912b
--- /dev/null
+++ b/hermes_cli/platforms.py
@@ -0,0 +1,45 @@
+"""
+Shared platform registry for Hermes Agent.
+
+Single source of truth for platform metadata consumed by both
+skills_config (label display) and tools_config (default toolset
+resolution). Import ``PLATFORMS`` from here instead of maintaining
+duplicate dicts in each module.
+"""
+
+from collections import OrderedDict
+from typing import NamedTuple
+
+
+class PlatformInfo(NamedTuple):
+ """Metadata for a single platform entry."""
+ label: str
+ default_toolset: str
+
+
+# Ordered so that TUI menus are deterministic.
+PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
+ ("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")),
+ ("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")),
+ ("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")),
+ ("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")),
+ ("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")),
+ ("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")),
+ ("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")),
+ ("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")),
+ ("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")),
+ ("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")),
+ ("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")),
+ ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")),
+ ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")),
+ ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
+ ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
+ ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
+ ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
+])
+
+
+def platform_label(key: str, default: str = "") -> str:
+ """Return the display label for a platform key, or *default*."""
+ info = PLATFORMS.get(key)
+ return info.label if info is not None else default
diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py
index ca877606f..fb70d9081 100644
--- a/hermes_cli/setup.py
+++ b/hermes_cli/setup.py
@@ -197,24 +197,12 @@ def print_header(title: str):
print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD))
-def print_info(text: str):
- """Print info text."""
- print(color(f" {text}", Colors.DIM))
-
-
-def print_success(text: str):
- """Print success message."""
- print(color(f"✓ {text}", Colors.GREEN))
-
-
-def print_warning(text: str):
- """Print warning message."""
- print(color(f"⚠ {text}", Colors.YELLOW))
-
-
-def print_error(text: str):
- """Print error message."""
- print(color(f"✗ {text}", Colors.RED))
+from hermes_cli.cli_output import ( # noqa: E402
+ print_error,
+ print_info,
+ print_success,
+ print_warning,
+)
def is_interactive_stdin() -> bool:
@@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
- """Single-select menu using curses to avoid simple_term_menu rendering bugs."""
- try:
- import curses
- result_holder = [default]
-
- def _curses_menu(stdscr):
- curses.curs_set(0)
- if curses.has_colors():
- curses.start_color()
- curses.use_default_colors()
- curses.init_pair(1, curses.COLOR_GREEN, -1)
- curses.init_pair(2, curses.COLOR_YELLOW, -1)
- cursor = default
- scroll_offset = 0
-
- while True:
- stdscr.clear()
- max_y, max_x = stdscr.getmaxyx()
-
- # Rows available for list items: rows 2..(max_y-2) inclusive.
- visible = max(1, max_y - 3)
-
- # Scroll the viewport so the cursor is always visible.
- if cursor < scroll_offset:
- scroll_offset = cursor
- elif cursor >= scroll_offset + visible:
- scroll_offset = cursor - visible + 1
- scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
-
- try:
- stdscr.addnstr(
- 0,
- 0,
- question,
- max_x - 1,
- curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
- )
- except curses.error:
- pass
-
- for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
- y = row + 2
- if y >= max_y - 1:
- break
- arrow = "→" if i == cursor else " "
- line = f" {arrow} {choices[i]}"
- attr = curses.A_NORMAL
- if i == cursor:
- attr = curses.A_BOLD
- if curses.has_colors():
- attr |= curses.color_pair(1)
- try:
- stdscr.addnstr(y, 0, line, max_x - 1, attr)
- except curses.error:
- pass
-
- stdscr.refresh()
- key = stdscr.getch()
- if key in (curses.KEY_UP, ord("k")):
- cursor = (cursor - 1) % len(choices)
- elif key in (curses.KEY_DOWN, ord("j")):
- cursor = (cursor + 1) % len(choices)
- elif key in (curses.KEY_ENTER, 10, 13):
- result_holder[0] = cursor
- return
- elif key in (27, ord("q")):
- return
-
- curses.wrapper(_curses_menu)
- from hermes_cli.curses_ui import flush_stdin
- flush_stdin()
- return result_holder[0]
- except Exception:
- return -1
+ """Single-select menu using curses. Delegates to curses_radiolist."""
+ from hermes_cli.curses_ui import curses_radiolist
+ return curses_radiolist(question, choices, selected=default, cancel_returns=-1)
diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py
index b017361fe..92424a0ca 100644
--- a/hermes_cli/skills_config.py
+++ b/hermes_cli/skills_config.py
@@ -15,25 +15,12 @@ from typing import List, Optional, Set
from hermes_cli.config import load_config, save_config
from hermes_cli.colors import Colors, color
+from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label
-PLATFORMS = {
- "cli": "🖥️ CLI",
- "telegram": "📱 Telegram",
- "discord": "💬 Discord",
- "slack": "💼 Slack",
- "whatsapp": "📱 WhatsApp",
- "signal": "📡 Signal",
- "bluebubbles": "💬 BlueBubbles",
- "email": "📧 Email",
- "homeassistant": "🏠 Home Assistant",
- "mattermost": "💬 Mattermost",
- "matrix": "💬 Matrix",
- "dingtalk": "💬 DingTalk",
- "feishu": "🪽 Feishu",
- "wecom": "💬 WeCom",
- "weixin": "💬 Weixin",
- "webhook": "🔗 Webhook",
-}
+# Backward-compatible view: {key: label_string} so existing code that
+# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps
+# working without changes to every call site.
+PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"}
# ─── Config Helpers ───────────────────────────────────────────────────────────
diff --git a/hermes_cli/status.py b/hermes_cli/status.py
index baba4f359..7a7a9c645 100644
--- a/hermes_cli/status.py
+++ b/hermes_cli/status.py
@@ -141,11 +141,8 @@ def show_status(args):
display = redact_key(value) if not show_all else value
print(f" {name:<12} {check_mark(has_key)} {display}")
- anthropic_value = (
- get_env_value("ANTHROPIC_TOKEN")
- or get_env_value("ANTHROPIC_API_KEY")
- or ""
- )
+ from hermes_cli.auth import get_anthropic_key
+ anthropic_value = get_anthropic_key()
anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value
print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}")
diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py
index 91c41dce5..343007cab 100644
--- a/hermes_cli/tools_config.py
+++ b/hermes_cli/tools_config.py
@@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve()
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
-def _print_info(text: str):
- print(color(f" {text}", Colors.DIM))
-
-def _print_success(text: str):
- print(color(f"✓ {text}", Colors.GREEN))
-
-def _print_warning(text: str):
- print(color(f"⚠ {text}", Colors.YELLOW))
-
-def _print_error(text: str):
- print(color(f"✗ {text}", Colors.RED))
-
-def _prompt(question: str, default: str = None, password: bool = False) -> str:
- if default:
- display = f"{question} [{default}]: "
- else:
- display = f"{question}: "
- try:
- if password:
- import getpass
- value = getpass.getpass(color(display, Colors.YELLOW))
- else:
- value = input(color(display, Colors.YELLOW))
- return value.strip() or default or ""
- except (KeyboardInterrupt, EOFError):
- print()
- return default or ""
+from hermes_cli.cli_output import ( # noqa: E402 — late import block
+ print_error as _print_error,
+ print_info as _print_info,
+ print_success as _print_success,
+ print_warning as _print_warning,
+ prompt as _prompt,
+)
# ─── Toolset Registry ─────────────────────────────────────────────────────────
@@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set:
except Exception:
return set()
-# Platform display config
+# Platform display config — derived from the canonical registry so every
+# module shares the same data. Kept as dict-of-dicts for backward
+# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns.
+from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY
+
PLATFORMS = {
- "cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
- "telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
- "discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
- "slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
- "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
- "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
- "bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
- "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
- "email": {"label": "📧 Email", "default_toolset": "hermes-email"},
- "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
- "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
- "feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
- "wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
- "weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
- "api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
- "mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
- "webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
+ k: {"label": info.label, "default_toolset": info.default_toolset}
+ for k, info in _PLATFORMS_REGISTRY.items()
}
@@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool:
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
- """Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
- rendering bugs in tmux, iTerm, and other non-standard terminals."""
-
- # Curses-based single-select — works in tmux, iTerm, and standard terminals
- try:
- import curses
- result_holder = [default]
-
- def _curses_menu(stdscr):
- curses.curs_set(0)
- if curses.has_colors():
- curses.start_color()
- curses.use_default_colors()
- curses.init_pair(1, curses.COLOR_GREEN, -1)
- curses.init_pair(2, curses.COLOR_YELLOW, -1)
- cursor = default
-
- while True:
- stdscr.clear()
- max_y, max_x = stdscr.getmaxyx()
- try:
- stdscr.addnstr(0, 0, question, max_x - 1,
- curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
- except curses.error:
- pass
-
- for i, c in enumerate(choices):
- y = i + 2
- if y >= max_y - 1:
- break
- arrow = "→" if i == cursor else " "
- line = f" {arrow} {c}"
- attr = curses.A_NORMAL
- if i == cursor:
- attr = curses.A_BOLD
- if curses.has_colors():
- attr |= curses.color_pair(1)
- try:
- stdscr.addnstr(y, 0, line, max_x - 1, attr)
- except curses.error:
- pass
-
- stdscr.refresh()
- key = stdscr.getch()
-
- if key in (curses.KEY_UP, ord('k')):
- cursor = (cursor - 1) % len(choices)
- elif key in (curses.KEY_DOWN, ord('j')):
- cursor = (cursor + 1) % len(choices)
- elif key in (curses.KEY_ENTER, 10, 13):
- result_holder[0] = cursor
- return
- elif key in (27, ord('q')):
- return
-
- curses.wrapper(_curses_menu)
- from hermes_cli.curses_ui import flush_stdin
- flush_stdin()
- return result_holder[0]
-
- except Exception:
- pass
-
- # Fallback: numbered input (Windows without curses, etc.)
- print(color(question, Colors.YELLOW))
- for i, c in enumerate(choices):
- marker = "●" if i == default else "○"
- style = Colors.GREEN if i == default else ""
- print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
- while True:
- try:
- val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
- if not val:
- return default
- idx = int(val) - 1
- if 0 <= idx < len(choices):
- return idx
- except (ValueError, KeyboardInterrupt, EOFError):
- print()
- return default
+ """Single-select menu (arrow keys). Delegates to curses_radiolist."""
+ from hermes_cli.curses_ui import curses_radiolist
+ return curses_radiolist(question, choices, selected=default, cancel_returns=default)
# ─── Token Estimation ────────────────────────────────────────────────────────
diff --git a/hermes_constants.py b/hermes_constants.py
index 7d149f404..85955d548 100644
--- a/hermes_constants.py
+++ b/hermes_constants.py
@@ -189,6 +189,33 @@ def is_wsl() -> bool:
return _wsl_detected
+# ─── Well-Known Paths ─────────────────────────────────────────────────────────
+
+
+def get_config_path() -> Path:
+ """Return the path to ``config.yaml`` under HERMES_HOME.
+
+ Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated
+ in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.).
+ """
+ return get_hermes_home() / "config.yaml"
+
+
+def get_skills_dir() -> Path:
+ """Return the path to the skills directory under HERMES_HOME."""
+ return get_hermes_home() / "skills"
+
+
+def get_logs_dir() -> Path:
+ """Return the path to the logs directory under HERMES_HOME."""
+ return get_hermes_home() / "logs"
+
+
+def get_env_path() -> Path:
+ """Return the path to the ``.env`` file under HERMES_HOME."""
+ return get_hermes_home() / ".env"
+
+
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
diff --git a/hermes_logging.py b/hermes_logging.py
index 5d71590c3..b765e9464 100644
--- a/hermes_logging.py
+++ b/hermes_logging.py
@@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional
-from hermes_constants import get_hermes_home
+from hermes_constants import get_config_path, get_hermes_home
# Sentinel to track whether setup_logging() has already run. The function
# is idempotent — calling it twice is safe but the second call is a no-op
@@ -246,7 +246,7 @@ def _read_logging_config():
"""
try:
import yaml
- config_path = get_hermes_home() / "config.yaml"
+ config_path = get_config_path()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
diff --git a/hermes_time.py b/hermes_time.py
index f7d085544..9f172d28f 100644
--- a/hermes_time.py
+++ b/hermes_time.py
@@ -16,7 +16,7 @@ crashes due to a bad timezone string.
import logging
import os
from datetime import datetime
-from hermes_constants import get_hermes_home
+from hermes_constants import get_config_path
from typing import Optional
logger = logging.getLogger(__name__)
@@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str:
# 2. config.yaml ``timezone`` key
try:
import yaml
- hermes_home = get_hermes_home()
- config_path = hermes_home / "config.yaml"
+ config_path = get_config_path()
if config_path.exists():
with open(config_path) as f:
cfg = yaml.safe_load(f) or {}
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index ef17af10b..d9ca627c4 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None):
config = PlatformConfig(enabled=True, token="e2e-test-token")
if platform == Platform.DISCORD:
- with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
+ from gateway.platforms.helpers import ThreadParticipationTracker
+ with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
adapter = DiscordAdapter(config)
platform_key = Platform.DISCORD
elif platform == Platform.SLACK:
diff --git a/tests/gateway/test_dingtalk.py b/tests/gateway/test_dingtalk.py
index 5c73253fb..527113650 100644
--- a/tests/gateway/test_dingtalk.py
+++ b/tests/gateway/test_dingtalk.py
@@ -119,28 +119,29 @@ class TestDeduplication:
def test_first_message_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
- assert adapter._is_duplicate("msg-1") is False
+ assert adapter._dedup.is_duplicate("msg-1") is False
def test_second_same_message_is_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
- adapter._is_duplicate("msg-1")
- assert adapter._is_duplicate("msg-1") is True
+ adapter._dedup.is_duplicate("msg-1")
+ assert adapter._dedup.is_duplicate("msg-1") is True
def test_different_messages_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
- adapter._is_duplicate("msg-1")
- assert adapter._is_duplicate("msg-2") is False
+ adapter._dedup.is_duplicate("msg-1")
+ assert adapter._dedup.is_duplicate("msg-2") is False
def test_cache_cleanup_on_overflow(self):
- from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
+ from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
+ max_size = adapter._dedup._max_size
# Fill beyond max
- for i in range(DEDUP_MAX_SIZE + 10):
- adapter._is_duplicate(f"msg-{i}")
+ for i in range(max_size + 10):
+ adapter._dedup.is_duplicate(f"msg-{i}")
# Cache should have been pruned
- assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
+ assert len(adapter._dedup._seen) <= max_size + 10
# ---------------------------------------------------------------------------
@@ -253,13 +254,13 @@ class TestConnect:
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._session_webhooks["a"] = "http://x"
- adapter._seen_messages["b"] = 1.0
+ adapter._dedup._seen["b"] = 1.0
adapter._http_client = AsyncMock()
adapter._stream_task = None
await adapter.disconnect()
assert len(adapter._session_webhooks) == 0
- assert len(adapter._seen_messages) == 0
+ assert len(adapter._dedup._seen) == 0
assert adapter._http_client is None
diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py
index dd594cf7e..9f094dd0d 100644
--- a/tests/gateway/test_discord_connect.py
+++ b/tests/gateway/test_discord_connect.py
@@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
assert ok is False
assert released == [("discord-bot-token", "test-token")]
- assert adapter._token_lock_identity is None
+ assert adapter._platform_lock_identity is None
diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py
index bc63c14f5..29f65efc6 100644
--- a/tests/gateway/test_discord_free_response.py
+++ b/tests/gateway/test_discord_free_response.py
@@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
# Simulate bot having previously participated in thread 456
- adapter._bot_participated_threads.add("456")
+ adapter._threads.mark("456")
thread = FakeThread(channel_id=456, name="existing thread")
message = make_message(channel=thread, content="follow-up without mention")
@@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
await adapter._handle_message(message)
- assert "555" in adapter._bot_participated_threads
+ assert "555" in adapter._threads
@pytest.mark.asyncio
@@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
await adapter._handle_message(message)
- assert "777" in adapter._bot_participated_threads
+ assert "777" in adapter._threads
diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py
index 0288b620d..083f61ac7 100644
--- a/tests/gateway/test_discord_thread_persistence.py
+++ b/tests/gateway/test_discord_thread_persistence.py
@@ -1,6 +1,6 @@
"""Tests for Discord thread participation persistence.
-Verifies that _bot_participated_threads survives adapter restarts by
+Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by
being persisted to ~/.hermes/discord_threads.json.
"""
@@ -25,13 +25,13 @@ class TestDiscordThreadPersistence:
def test_starts_empty_when_no_state_file(self, tmp_path):
adapter = self._make_adapter(tmp_path)
- assert adapter._bot_participated_threads == set()
+ assert "$nonexistent" not in adapter._threads
def test_track_thread_persists_to_disk(self, tmp_path):
adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
- adapter._track_thread("111")
- adapter._track_thread("222")
+ adapter._threads.mark("111")
+ adapter._threads.mark("222")
state_file = tmp_path / "discord_threads.json"
assert state_file.exists()
@@ -42,42 +42,43 @@ class TestDiscordThreadPersistence:
"""Threads tracked by one adapter instance are visible to the next."""
adapter1 = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
- adapter1._track_thread("aaa")
- adapter1._track_thread("bbb")
+ adapter1._threads.mark("aaa")
+ adapter1._threads.mark("bbb")
adapter2 = self._make_adapter(tmp_path)
- assert "aaa" in adapter2._bot_participated_threads
- assert "bbb" in adapter2._bot_participated_threads
+ assert "aaa" in adapter2._threads
+ assert "bbb" in adapter2._threads
def test_duplicate_track_does_not_double_save(self, tmp_path):
adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
- adapter._track_thread("111")
- adapter._track_thread("111") # no-op
+ adapter._threads.mark("111")
+ adapter._threads.mark("111") # no-op
saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert saved.count("111") == 1
def test_caps_at_max_tracked_threads(self, tmp_path):
adapter = self._make_adapter(tmp_path)
- adapter._MAX_TRACKED_THREADS = 5
+ adapter._threads._max_tracked = 5
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
for i in range(10):
- adapter._track_thread(str(i))
+ adapter._threads.mark(str(i))
- assert len(adapter._bot_participated_threads) == 5
+ saved = json.loads((tmp_path / "discord_threads.json").read_text())
+ assert len(saved) == 5
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
state_file = tmp_path / "discord_threads.json"
state_file.write_text("not valid json{{{")
adapter = self._make_adapter(tmp_path)
- assert adapter._bot_participated_threads == set()
+ assert "$nonexistent" not in adapter._threads
def test_missing_hermes_home_does_not_crash(self, tmp_path):
"""Load/save tolerate missing directories."""
fake_home = tmp_path / "nonexistent" / "deep"
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
- from gateway.platforms.discord import DiscordAdapter
- # _load should return empty set, not crash
- threads = DiscordAdapter._load_participated_threads()
- assert threads == set()
+ from gateway.platforms.helpers import ThreadParticipationTracker
+ # ThreadParticipationTracker should return empty set, not crash
+ tracker = ThreadParticipationTracker("discord")
+ assert "$test" not in tracker
diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py
index d36c2b765..873b873c2 100644
--- a/tests/gateway/test_matrix_mention.py
+++ b/tests/gateway/test_matrix_mention.py
@@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch):
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
- adapter._bot_participated_threads.add("$thread1")
+ adapter._threads.mark("$thread1")
event = _make_event("hello without mention", thread_id="$thread1")
@@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch):
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
adapter = _make_adapter()
- adapter._bot_participated_threads.add("$thread_root")
+ adapter._threads.mark("$thread_root")
event = _make_event("reply in thread", thread_id="$thread_root")
await adapter._on_room_message(event)
@@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch):
@pytest.mark.asyncio
async def test_auto_thread_tracks_participation(monkeypatch):
- """Auto-created threads are tracked in _bot_participated_threads."""
+ """Auto-created threads are tracked in _threads."""
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false")
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
adapter = _make_adapter()
event = _make_event("hello", event_id="$msg1")
- with patch.object(adapter, "_save_participated_threads"):
+ with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event)
- assert "$msg1" in adapter._bot_participated_threads
+ assert "$msg1" in adapter._threads
# ---------------------------------------------------------------------------
@@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch):
class TestThreadPersistence:
def test_empty_state_file(self, tmp_path, monkeypatch):
"""No state file → empty set."""
- from gateway.platforms.matrix import MatrixAdapter
+ from gateway.platforms.helpers import ThreadParticipationTracker
monkeypatch.setattr(
- MatrixAdapter, "_thread_state_path",
- staticmethod(lambda: tmp_path / "matrix_threads.json"),
+ ThreadParticipationTracker, "_state_path",
+ lambda self: tmp_path / "matrix_threads.json",
)
adapter = _make_adapter()
- loaded = adapter._load_participated_threads()
- assert loaded == set()
+ assert "$nonexistent" not in adapter._threads
def test_track_thread_persists(self, tmp_path, monkeypatch):
- """_track_thread writes to disk."""
- from gateway.platforms.matrix import MatrixAdapter
+ """mark() writes to disk."""
+ from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json"
monkeypatch.setattr(
- MatrixAdapter, "_thread_state_path",
- staticmethod(lambda: state_path),
+ ThreadParticipationTracker, "_state_path",
+ lambda self: state_path,
)
adapter = _make_adapter()
- adapter._track_thread("$thread_abc")
+ adapter._threads.mark("$thread_abc")
data = json.loads(state_path.read_text())
assert "$thread_abc" in data
def test_threads_survive_reload(self, tmp_path, monkeypatch):
"""Persisted threads are loaded by a new adapter instance."""
- from gateway.platforms.matrix import MatrixAdapter
+ from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json"
state_path.write_text(json.dumps(["$t1", "$t2"]))
monkeypatch.setattr(
- MatrixAdapter, "_thread_state_path",
- staticmethod(lambda: state_path),
+ ThreadParticipationTracker, "_state_path",
+ lambda self: state_path,
)
adapter = _make_adapter()
- assert "$t1" in adapter._bot_participated_threads
- assert "$t2" in adapter._bot_participated_threads
+ assert "$t1" in adapter._threads
+ assert "$t2" in adapter._threads
def test_cap_max_tracked_threads(self, tmp_path, monkeypatch):
- """Thread set is trimmed to _MAX_TRACKED_THREADS."""
- from gateway.platforms.matrix import MatrixAdapter
+ """Thread set is trimmed to max_tracked."""
+ from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json"
monkeypatch.setattr(
- MatrixAdapter, "_thread_state_path",
- staticmethod(lambda: state_path),
+ ThreadParticipationTracker, "_state_path",
+ lambda self: state_path,
)
adapter = _make_adapter()
- adapter._MAX_TRACKED_THREADS = 5
+ adapter._threads._max_tracked = 5
for i in range(10):
- adapter._bot_participated_threads.add(f"$t{i}")
- adapter._save_participated_threads()
+ adapter._threads.mark(f"$t{i}")
data = json.loads(state_path.read_text())
assert len(data) == 5
@@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch):
_set_dm(adapter)
event = _make_event("@hermes:example.org help me", event_id="$dm1")
- with patch.object(adapter, "_save_participated_threads"):
+ with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event)
adapter.handle_message.assert_awaited_once()
@@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
adapter = _make_adapter()
_set_dm(adapter)
- adapter._bot_participated_threads.add("$existing_thread")
+ adapter._threads.mark("$existing_thread")
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
await adapter._on_room_message(event)
@@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
@pytest.mark.asyncio
async def test_dm_mention_thread_tracks_participation(monkeypatch):
- """DM mention-thread tracks the thread in _bot_participated_threads."""
+ """DM mention-thread tracks the thread in _threads."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
@@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch):
_set_dm(adapter)
event = _make_event("@hermes:example.org help", event_id="$dm1")
- with patch.object(adapter, "_save_participated_threads"):
+ with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event)
- assert "$dm1" in adapter._bot_participated_threads
+ assert "$dm1" in adapter._threads
# ---------------------------------------------------------------------------
diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py
index 7d47c0a3e..56e46f636 100644
--- a/tests/gateway/test_mattermost.py
+++ b/tests/gateway/test_mattermost.py
@@ -614,25 +614,27 @@ class TestMattermostDedup:
assert self.adapter.handle_message.call_count == 2
def test_prune_seen_clears_expired(self):
- """_prune_seen should remove entries older than _SEEN_TTL."""
+ """Dedup cache should remove entries older than TTL on overflow."""
now = time.time()
+ dedup = self.adapter._dedup
# Fill with enough expired entries to trigger pruning
- for i in range(self.adapter._SEEN_MAX + 10):
- self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
+ for i in range(dedup._max_size + 10):
+ dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL)
# Add a fresh one
- self.adapter._seen_posts["fresh"] = now
+ dedup._seen["fresh"] = now
- self.adapter._prune_seen()
+ # Trigger pruning by calling is_duplicate with a new entry (over max_size)
+ dedup.is_duplicate("trigger_prune")
# Old entries should be pruned, fresh one kept
- assert "fresh" in self.adapter._seen_posts
- assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
+ assert "fresh" in dedup._seen
+ assert len(dedup._seen) < dedup._max_size + 10
def test_seen_cache_tracks_post_ids(self):
- """Posts are tracked in _seen_posts dict."""
- self.adapter._seen_posts["test_post"] = time.time()
- assert "test_post" in self.adapter._seen_posts
+ """Posts are tracked in the dedup cache."""
+ self.adapter._dedup._seen["test_post"] = time.time()
+ assert "test_post" in self.adapter._dedup._seen
# ---------------------------------------------------------------------------
diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py
index ae985300d..265f9be78 100644
--- a/tests/gateway/test_signal.py
+++ b/tests/gateway/test_signal.py
@@ -114,16 +114,16 @@ class TestSignalAdapterInit:
class TestSignalHelpers:
def test_redact_phone_long(self):
- from gateway.platforms.signal import _redact_phone
- assert _redact_phone("+15551234567") == "+155****4567"
+ from gateway.platforms.helpers import redact_phone
+ assert redact_phone("+155****4567") == "+155****4567"
def test_redact_phone_short(self):
- from gateway.platforms.signal import _redact_phone
- assert _redact_phone("+12345") == "+1****45"
+ from gateway.platforms.helpers import redact_phone
+ assert redact_phone("+12345") == "+1****45"
def test_redact_phone_empty(self):
- from gateway.platforms.signal import _redact_phone
- assert _redact_phone("") == ""
+ from gateway.platforms.helpers import redact_phone
+ assert redact_phone("") == ""
def test_parse_comma_list(self):
from gateway.platforms.signal import _parse_comma_list
diff --git a/tests/gateway/test_telegram_conflict.py b/tests/gateway/test_telegram_conflict.py
index 47a67f229..dcf311688 100644
--- a/tests/gateway/test_telegram_conflict.py
+++ b/tests/gateway/test_telegram_conflict.py
@@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch):
async def _noop():
return []
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
+ # Mock HTTPXRequest so the builder chain doesn't fail
+ monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock())
@pytest.mark.asyncio
@@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch):
ok = await adapter.connect()
assert ok is False
- assert adapter.fatal_error_code == "telegram_token_lock"
+ assert adapter.fatal_error_code == "telegram-bot-token_lock"
assert adapter.has_fatal_error is True
- assert "already using this Telegram bot token" in adapter.fatal_error_message
+ assert "already in use" in adapter.fatal_error_message
@pytest.mark.asyncio
@@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch):
)
builder = MagicMock()
builder.token.return_value = builder
+ builder.request.return_value = builder
+ builder.get_updates_request.return_value = builder
builder.build.return_value = app
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
@@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
)
builder = MagicMock()
builder.token.return_value = builder
+ builder.request.return_value = builder
+ builder.get_updates_request.return_value = builder
builder.build.return_value = app
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
@@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m
builder = MagicMock()
builder.token.return_value = builder
+ builder.request.return_value = builder
+ builder.get_updates_request.return_value = builder
app = SimpleNamespace(
bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()),
updater=SimpleNamespace(),
@@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch):
)
builder = MagicMock()
builder.token.return_value = builder
+ builder.request.return_value = builder
+ builder.get_updates_request.return_value = builder
builder.build.return_value = app
monkeypatch.setattr(
"gateway.platforms.telegram.Application",
diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py
index 7b9e49d4f..dd0ae17f8 100644
--- a/tests/tools/test_skill_manager_tool.py
+++ b/tests/tools/test_skill_manager_tool.py
@@ -348,7 +348,7 @@ word word
result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md")
assert result["success"] is False
- assert "boundary" in result["error"].lower()
+ assert "escapes" in result["error"].lower()
assert outside_file.read_text() == "old text here"
@@ -412,7 +412,7 @@ class TestWriteFile:
result = _write_file("my-skill", "references/escape/owned.md", "malicious")
assert result["success"] is False
- assert "boundary" in result["error"].lower()
+ assert "escapes" in result["error"].lower()
assert not (outside_dir / "owned.md").exists()
@@ -449,7 +449,7 @@ class TestRemoveFile:
result = _remove_file("my-skill", "references/escape/keep.txt")
assert result["success"] is False
- assert "boundary" in result["error"].lower()
+ assert "escapes" in result["error"].lower()
assert outside_file.exists()
diff --git a/tools/credential_files.py b/tools/credential_files.py
index 6ddcd0770..7998321e6 100644
--- a/tools/credential_files.py
+++ b/tools/credential_files.py
@@ -80,20 +80,18 @@ def register_credential_file(
# Resolve symlinks and normalise ``..`` before the containment check so
# that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME.
- try:
- resolved = host_path.resolve()
- hermes_home_resolved = hermes_home.resolve()
- resolved.relative_to(hermes_home_resolved) # raises ValueError if outside
- except ValueError:
+ from tools.path_security import validate_within_dir
+
+ containment_error = validate_within_dir(host_path, hermes_home)
+ if containment_error:
logger.warning(
- "credential_files: rejected path traversal %r "
- "(resolves to %s, outside HERMES_HOME %s)",
+ "credential_files: rejected path traversal %r (%s)",
relative_path,
- resolved,
- hermes_home_resolved,
+ containment_error,
)
return False
+ resolved = host_path.resolve()
if not resolved.is_file():
logger.debug("credential_files: skipping %s (not found)", resolved)
return False
@@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]:
cfg = read_raw_config()
cred_files = cfg.get("terminal", {}).get("credential_files")
if isinstance(cred_files, list):
- hermes_home_resolved = hermes_home.resolve()
+ from tools.path_security import validate_within_dir
+
for item in cred_files:
if isinstance(item, str) and item.strip():
rel = item.strip()
@@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]:
"credential_files: rejected absolute config path %r", rel,
)
continue
- host_path = (hermes_home / rel).resolve()
- try:
- host_path.relative_to(hermes_home_resolved)
- except ValueError:
+ host_path = hermes_home / rel
+ containment_error = validate_within_dir(host_path, hermes_home)
+ if containment_error:
logger.warning(
- "credential_files: rejected config path traversal %r "
- "(resolves to %s, outside HERMES_HOME %s)",
- rel, host_path, hermes_home_resolved,
+ "credential_files: rejected config path traversal %r (%s)",
+ rel, containment_error,
)
continue
- if host_path.is_file():
+ resolved_path = host_path.resolve()
+ if resolved_path.is_file():
container_path = f"/root/.hermes/{rel}"
result.append({
- "host_path": str(host_path),
+ "host_path": str(resolved_path),
"container_path": container_path,
})
except Exception as e:
diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py
index 3018b8731..e2db93381 100644
--- a/tools/cronjob_tools.py
+++ b/tools/cronjob_tools.py
@@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]:
)
# Validate containment after resolution
+ from tools.path_security import validate_within_dir
+
scripts_dir = get_hermes_home() / "scripts"
scripts_dir.mkdir(parents=True, exist_ok=True)
- resolved = (scripts_dir / raw).resolve()
- try:
- resolved.relative_to(scripts_dir.resolve())
- except ValueError:
+ containment_error = validate_within_dir(scripts_dir / raw, scripts_dir)
+ if containment_error:
return (
f"Script path escapes the scripts directory via traversal: {raw!r}"
)
diff --git a/tools/path_security.py b/tools/path_security.py
new file mode 100644
index 000000000..828011e5d
--- /dev/null
+++ b/tools/path_security.py
@@ -0,0 +1,43 @@
+"""Shared path validation helpers for tool implementations.
+
+Extracts the ``resolve() + relative_to()`` and ``..`` traversal check
+patterns previously duplicated across skill_manager_tool, skills_tool,
+skills_hub, cronjob_tools, and credential_files.
+"""
+
+import logging
+from pathlib import Path
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+def validate_within_dir(path: Path, root: Path) -> Optional[str]:
+ """Ensure *path* resolves to a location within *root*.
+
+ Returns an error message string if validation fails, or ``None`` if the
+ path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize
+ ``..`` components.
+
+ Usage::
+
+ error = validate_within_dir(user_path, allowed_root)
+ if error:
+ return json.dumps({"error": error})
+ """
+ try:
+ resolved = path.resolve()
+ root_resolved = root.resolve()
+ resolved.relative_to(root_resolved)
+ except (ValueError, OSError) as exc:
+ return f"Path escapes allowed directory: {exc}"
+ return None
+
+
+def has_traversal_component(path_str: str) -> bool:
+ """Return True if *path_str* contains ``..`` traversal components.
+
+ Quick check for obvious traversal attempts before doing full resolution.
+ """
+ parts = Path(path_str).parts
+ return ".." in parts
diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py
index 2273d75fa..2b2625fa0 100644
--- a/tools/skill_manager_tool.py
+++ b/tools/skill_manager_tool.py
@@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]:
Validate a file path for write_file/remove_file.
Must be under an allowed subdirectory and not escape the skill dir.
"""
+ from tools.path_security import has_traversal_component
+
if not file_path:
return "file_path is required."
normalized = Path(file_path)
# Prevent path traversal
- if ".." in normalized.parts:
+ if has_traversal_component(file_path):
return "Path traversal ('..') is not allowed."
# Must be under an allowed subdirectory
@@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]:
def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]:
"""Resolve a supporting-file path and ensure it stays within the skill directory."""
+ from tools.path_security import validate_within_dir
+
target = skill_dir / file_path
- try:
- resolved = target.resolve(strict=False)
- skill_dir_resolved = skill_dir.resolve()
- resolved.relative_to(skill_dir_resolved)
- except ValueError:
- return None, "Path escapes skill directory boundary."
- except OSError as e:
- return None, f"Invalid file path '{file_path}': {e}"
+ error = validate_within_dir(target, skill_dir)
+ if error:
+ return None, error
return target, None
diff --git a/tools/skills_tool.py b/tools/skills_tool.py
index 085ed0055..94b7c235b 100644
--- a/tools/skills_tool.py
+++ b/tools/skills_tool.py
@@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]:
return None
-def _estimate_tokens(content: str) -> int:
- """
- Rough token estimate (4 chars per token average).
-
- Args:
- content: Text content
-
- Returns:
- Estimated token count
- """
- return len(content) // 4
+# Token estimation — use the shared implementation from model_metadata.
+from agent.model_metadata import estimate_tokens_rough as _estimate_tokens
def _parse_tags(tags_value) -> List[str]:
@@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
# If a specific file path is requested, read that instead
if file_path and skill_dir:
+ from tools.path_security import validate_within_dir, has_traversal_component
+
# Security: Prevent path traversal attacks
- normalized_path = Path(file_path)
- if ".." in normalized_path.parts:
+ if has_traversal_component(file_path):
return json.dumps(
{
"success": False,
@@ -962,24 +954,13 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
target_file = skill_dir / file_path
# Security: Verify resolved path is still within skill directory
- try:
- resolved = target_file.resolve()
- skill_dir_resolved = skill_dir.resolve()
- if not resolved.is_relative_to(skill_dir_resolved):
- return json.dumps(
- {
- "success": False,
- "error": "Path escapes skill directory boundary.",
- "hint": "Use a relative path within the skill directory",
- },
- ensure_ascii=False,
- )
- except (OSError, ValueError):
+ traversal_error = validate_within_dir(target_file, skill_dir)
+ if traversal_error:
return json.dumps(
{
"success": False,
- "error": f"Invalid file path: '{file_path}'",
- "hint": "Use a valid relative path within the skill directory",
+ "error": traversal_error,
+ "hint": "Use a relative path within the skill directory",
},
ensure_ascii=False,
)
diff --git a/utils.py b/utils.py
index 9a2105d54..bd2a6b70f 100644
--- a/utils.py
+++ b/utils.py
@@ -1,13 +1,16 @@
"""Shared utility functions for hermes-agent."""
import json
+import logging
import os
import tempfile
from pathlib import Path
-from typing import Any, Union
+from typing import Any, List, Optional, Union
import yaml
+logger = logging.getLogger(__name__)
+
TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})
@@ -124,3 +127,88 @@ def atomic_yaml_write(
except OSError:
pass
raise
+
+
+# ─── JSON Helpers ─────────────────────────────────────────────────────────────
+
+
+def safe_json_loads(text: str, default: Any = None) -> Any:
+ """Parse JSON, returning *default* on any parse error.
+
+ Replaces the ``try: json.loads(x) except (JSONDecodeError, TypeError)``
+ pattern duplicated across display.py, anthropic_adapter.py,
+ auxiliary_client.py, and others.
+ """
+ try:
+ return json.loads(text)
+ except (json.JSONDecodeError, TypeError, ValueError):
+ return default
+
+
+def read_json_file(path: Path, default: Any = None) -> Any:
+ """Read and parse a JSON file, returning *default* on any error.
+
+ Replaces the repeated ``try: json.loads(path.read_text()) except ...``
+ pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py,
+ and skill_utils.py.
+ """
+ try:
+ return json.loads(Path(path).read_text(encoding="utf-8"))
+ except (json.JSONDecodeError, OSError, IOError, ValueError) as exc:
+ logger.debug("Failed to read %s: %s", path, exc)
+ return default
+
+
+def read_jsonl(path: Path) -> List[dict]:
+ """Read a JSONL file (one JSON object per line).
+
+ Returns a list of parsed objects, skipping blank lines.
+ """
+ entries = []
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ entries.append(json.loads(line))
+ return entries
+
+
+def append_jsonl(path: Path, entry: dict) -> None:
+ """Append a single JSON object as a new line to a JSONL file."""
+ path = Path(path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "a", encoding="utf-8") as f:
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+
+
+# ─── Environment Variable Helpers ─────────────────────────────────────────────
+
+
+def env_str(key: str, default: str = "") -> str:
+ """Read an environment variable, stripped of whitespace.
+
+ Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times
+ across runtime_provider.py, anthropic_adapter.py, models.py, etc.
+ """
+ return os.getenv(key, default).strip()
+
+
+def env_lower(key: str, default: str = "") -> str:
+ """Read an environment variable, stripped and lowercased."""
+ return os.getenv(key, default).strip().lower()
+
+
+def env_int(key: str, default: int = 0) -> int:
+ """Read an environment variable as an integer, with fallback."""
+ raw = os.getenv(key, "").strip()
+ if not raw:
+ return default
+ try:
+ return int(raw)
+ except (ValueError, TypeError):
+ return default
+
+
+def env_bool(key: str, default: bool = False) -> bool:
+ """Read an environment variable as a boolean."""
+ return is_truthy_value(os.getenv(key, ""), default=default)