refactor: extract shared helpers to deduplicate repeated code patterns (#7917)

* refactor: add shared helper modules for code deduplication

New modules:
- gateway/platforms/helpers.py: MessageDeduplicator, TextBatchAggregator,
  strip_markdown, ThreadParticipationTracker, redact_phone
- hermes_cli/cli_output.py: print_info/success/warning/error, prompt helpers
- tools/path_security.py: validate_within_dir, has_traversal_component
- utils.py additions: safe_json_loads, read_json_file, read_jsonl,
  append_jsonl, env_str/lower/int/bool helpers
- hermes_constants.py additions: get_config_path, get_skills_dir,
  get_logs_dir, get_env_path

* refactor: migrate gateway adapters to shared helpers

- MessageDeduplicator: discord, slack, dingtalk, wecom, weixin, mattermost
- strip_markdown: bluebubbles, feishu, sms
- redact_phone: sms, signal
- ThreadParticipationTracker: discord, matrix
- _acquire/_release_platform_lock: telegram, discord, slack, whatsapp,
  signal, weixin

Net -316 lines across 19 files.

* refactor: migrate CLI modules to shared helpers

- tools_config.py: use cli_output print/prompt + curses_radiolist (-117 lines)
- setup.py: use cli_output print helpers + curses_radiolist (-101 lines)
- mcp_config.py: use cli_output prompt (-15 lines)
- memory_setup.py: use curses_radiolist (-86 lines)

Net -263 lines across 5 files.

* refactor: migrate to shared utility helpers

- safe_json_loads: agent/display.py (4 sites)
- get_config_path: skill_utils.py, hermes_logging.py, hermes_time.py
- get_skills_dir: skill_utils.py, prompt_builder.py
- Token estimation dedup: skills_tool.py imports from model_metadata
- Path security: skills_tool, cronjob_tools, skill_manager_tool, credential_files
- Non-atomic YAML writes: doctor.py, config.py now use atomic_yaml_write
- Platform dict: new platforms.py, skills_config + tools_config derive from it
- Anthropic key: new get_anthropic_key() in auth.py, used by doctor/status/config/main

* test: update tests for shared helper migrations

- test_dingtalk: use _dedup.is_duplicate() instead of _is_duplicate()
- test_mattermost: use _dedup instead of _seen_posts/_prune_seen
- test_signal: import redact_phone from helpers instead of signal
- test_discord_connect: _platform_lock_identity instead of _token_lock_identity
- test_telegram_conflict: updated lock error message format
- test_skill_manager_tool: 'escapes' instead of 'boundary' in error msgs
This commit is contained in:
Teknium 2026-04-11 13:59:52 -07:00 committed by GitHub
parent cf53e2676b
commit 04c1c5d53f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
49 changed files with 887 additions and 949 deletions

View file

@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency.
Used by AIAgent._execute_tool_calls for CLI feedback. Used by AIAgent._execute_tool_calls for CLI feedback.
""" """
import json
import logging import logging
import os import os
import sys import sys
@ -14,6 +13,8 @@ from dataclasses import dataclass, field
from difflib import unified_diff from difflib import unified_diff
from pathlib import Path from pathlib import Path
from utils import safe_json_loads
# ANSI escape codes for coloring tool failure indicators # ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m" _RED = "\033[31m"
_RESET = "\033[0m" _RESET = "\033[0m"
@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool:
"""Conservatively detect whether a tool result represents success.""" """Conservatively detect whether a tool result represents success."""
if not result: if not result:
return False return False
try: data = safe_json_loads(result)
data = json.loads(result) if data is None:
except (json.JSONDecodeError, TypeError):
return False return False
if not isinstance(data, dict): if not isinstance(data, dict):
return False return False
@ -423,10 +423,7 @@ def extract_edit_diff(
) -> str | None: ) -> str | None:
"""Extract a unified diff from a file-edit tool result.""" """Extract a unified diff from a file-edit tool result."""
if tool_name == "patch" and result: if tool_name == "patch" and result:
try: data = safe_json_loads(result)
data = json.loads(result)
except (json.JSONDecodeError, TypeError):
data = None
if isinstance(data, dict): if isinstance(data, dict):
diff = data.get("diff") diff = data.get("diff")
if isinstance(diff, str) and diff.strip(): if isinstance(diff, str) and diff.strip():
@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
return False, "" return False, ""
if tool_name == "terminal": if tool_name == "terminal":
try: data = safe_json_loads(result)
data = json.loads(result) if isinstance(data, dict):
exit_code = data.get("exit_code") exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0: if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]" return True, f" [exit {exit_code}]"
except (json.JSONDecodeError, TypeError, AttributeError):
logger.debug("Could not parse terminal result as JSON for exit code check")
return False, "" return False, ""
# Memory-specific: distinguish "full" from real errors # Memory-specific: distinguish "full" from real errors
if tool_name == "memory": if tool_name == "memory":
try: data = safe_json_loads(result)
data = json.loads(result) if isinstance(data, dict):
if data.get("success") is False and "exceed the limit" in data.get("error", ""): if data.get("success") is False and "exceed the limit" in data.get("error", ""):
return True, " [full]" return True, " [full]"
except (json.JSONDecodeError, TypeError, AttributeError):
logger.debug("Could not parse memory result as JSON for capacity check")
# Generic heuristic for non-terminal tools # Generic heuristic for non-terminal tools
lower = result[:500].lower() lower = result[:500].lower()

View file

@ -12,7 +12,7 @@ import threading
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from hermes_constants import get_hermes_home from hermes_constants import get_hermes_home, get_skills_dir
from typing import Optional from typing import Optional
from agent.skill_utils import ( from agent.skill_utils import (
@ -548,8 +548,7 @@ def build_skills_system_prompt(
are read-only they appear in the index but new skills are always created are read-only they appear in the index but new skills are always created
in the local dir. Local skills take precedence when names collide. in the local dir. Local skills take precedence when names collide.
""" """
hermes_home = get_hermes_home() skills_dir = get_skills_dir()
skills_dir = hermes_home / "skills"
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0) external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
if not skills_dir.exists() and not external_dirs: if not skills_dir.exists() and not external_dirs:

View file

@ -12,7 +12,7 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Set, Tuple from typing import Any, Dict, List, Set, Tuple
from hermes_constants import get_hermes_home from hermes_constants import get_config_path, get_skills_dir
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
Reads the config file directly (no CLI config imports) to stay Reads the config file directly (no CLI config imports) to stay
lightweight. lightweight.
""" """
config_path = get_hermes_home() / "config.yaml" config_path = get_config_path()
if not config_path.exists(): if not config_path.exists():
return set() return set()
try: try:
@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]:
path. Only directories that actually exist are returned. Duplicates and path. Only directories that actually exist are returned. Duplicates and
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped. paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
""" """
config_path = get_hermes_home() / "config.yaml" config_path = get_config_path()
if not config_path.exists(): if not config_path.exists():
return [] return []
try: try:
@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]:
if not isinstance(raw_dirs, list): if not isinstance(raw_dirs, list):
return [] return []
local_skills = (get_hermes_home() / "skills").resolve() local_skills = get_skills_dir().resolve()
seen: Set[Path] = set() seen: Set[Path] = set()
result: List[Path] = [] result: List[Path] = []
@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
The local dir is always first (and always included even if it doesn't exist The local dir is always first (and always included even if it doesn't exist
yet callers handle that). External dirs follow in config order. yet callers handle that). External dirs follow in config order.
""" """
dirs = [get_hermes_home() / "skills"] dirs = [get_skills_dir()]
dirs.extend(get_external_skills_dirs()) dirs.extend(get_external_skills_dirs())
return dirs return dirs
@ -384,7 +384,7 @@ def resolve_skill_config_values(
current values (or the declared default if the key isn't set). current values (or the declared default if the key isn't set).
Path values are expanded via ``os.path.expanduser``. Path values are expanded via ``os.path.expanduser``.
""" """
config_path = get_hermes_home() / "config.yaml" config_path = get_config_path()
config: Dict[str, Any] = {} config: Dict[str, Any] = {}
if config_path.exists(): if config_path.exists():
try: try:

View file

@ -824,6 +824,35 @@ class BasePlatformAdapter(ABC):
if asyncio.iscoroutine(result): if asyncio.iscoroutine(result):
await result await result
def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
    """Try to take the scoped gateway lock for this adapter.

    Records *scope*/*identity* on ``self`` so ``_release_platform_lock``
    can undo the acquisition later. On contention, logs the holder's PID
    when known, records a non-retryable fatal error, and returns False.
    """
    from gateway.status import acquire_scoped_lock

    self._platform_lock_scope = scope
    self._platform_lock_identity = identity
    ok, holder = acquire_scoped_lock(
        scope, identity, metadata={'platform': self.platform.value}
    )
    if ok:
        return True
    # Lock is held elsewhere — surface who owns it when the metadata says.
    pid = holder.get('pid') if isinstance(holder, dict) else None
    parts = [f'{resource_desc} already in use']
    if pid:
        parts.append(f' (PID {pid})')
    parts.append('. Stop the other gateway first.')
    message = ''.join(parts)
    logger.error('[%s] %s', self.name, message)
    self._set_fatal_error(f'{scope}_lock', message, retryable=False)
    return False
def _release_platform_lock(self) -> None:
    """Release the lock taken by ``_acquire_platform_lock``, if any held.

    No-op when no identity was recorded (lock never acquired or already
    released); clears the stored identity after releasing.
    """
    held = getattr(self, '_platform_lock_identity', None)
    if held:
        from gateway.status import release_scoped_lock
        release_scoped_lock(self._platform_lock_scope, held)
        self._platform_lock_identity = None
@property @property
def name(self) -> str: def name(self) -> str:
"""Human-readable name for this adapter.""" """Human-readable name for this adapter."""

View file

@ -30,6 +30,7 @@ from gateway.platforms.base import (
cache_audio_from_bytes, cache_audio_from_bytes,
cache_document_from_bytes, cache_document_from_bytes,
) )
from gateway.platforms.helpers import strip_markdown
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str:
return value.rstrip("/") return value.rstrip("/")
def _strip_markdown(text: str) -> str:
"""Strip common markdown formatting for iMessage plain-text delivery."""
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
text = re.sub(r"`(.+?)`", r"\1", text)
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
text = re.sub(r"\n{3,}", "\n\n", text)
return text.strip()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
reply_to: Optional[str] = None, reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
) -> SendResult: ) -> SendResult:
text = _strip_markdown(content or "") text = strip_markdown(content or "")
if not text: if not text:
return SendResult(success=False, error="BlueBubbles send requires text") return SendResult(success=False, error="BlueBubbles send requires text")
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH) chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
return info return info
def format_message(self, content: str) -> str: def format_message(self, content: str) -> str:
return _strip_markdown(content) return strip_markdown(content)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Inbound attachment downloading (from #4588) # Inbound attachment downloading (from #4588)

View file

@ -42,6 +42,7 @@ except ImportError:
httpx = None # type: ignore[assignment] httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -52,8 +53,6 @@ from gateway.platforms.base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_MESSAGE_LENGTH = 20000 MAX_MESSAGE_LENGTH = 20000
DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000
RECONNECT_BACKOFF = [2, 5, 10, 30, 60] RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
_SESSION_WEBHOOKS_MAX = 500 _SESSION_WEBHOOKS_MAX = 500
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/') _DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter):
self._stream_task: Optional[asyncio.Task] = None self._stream_task: Optional[asyncio.Task] = None
self._http_client: Optional["httpx.AsyncClient"] = None self._http_client: Optional["httpx.AsyncClient"] = None
# Message deduplication: msg_id -> timestamp # Message deduplication
self._seen_messages: Dict[str, float] = {} self._dedup = MessageDeduplicator(max_size=1000)
# Map chat_id -> session_webhook for reply routing # Map chat_id -> session_webhook for reply routing
self._session_webhooks: Dict[str, str] = {} self._session_webhooks: Dict[str, str] = {}
@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter):
self._stream_client = None self._stream_client = None
self._session_webhooks.clear() self._session_webhooks.clear()
self._seen_messages.clear() self._dedup.clear()
logger.info("[%s] Disconnected", self.name) logger.info("[%s] Disconnected", self.name)
# -- Inbound message processing ----------------------------------------- # -- Inbound message processing -----------------------------------------
@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter):
async def _on_message(self, message: "ChatbotMessage") -> None: async def _on_message(self, message: "ChatbotMessage") -> None:
"""Process an incoming DingTalk chatbot message.""" """Process an incoming DingTalk chatbot message."""
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
if self._is_duplicate(msg_id): if self._dedup.is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id) logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
return return
@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter):
content = " ".join(parts).strip() content = " ".join(parts).strip()
return content return content
# -- Deduplication ------------------------------------------------------
def _is_duplicate(self, msg_id: str) -> bool:
"""Check and record a message ID. Returns True if already seen."""
now = time.time()
if len(self._seen_messages) > DEDUP_MAX_SIZE:
cutoff = now - DEDUP_WINDOW_SECONDS
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
if msg_id in self._seen_messages:
return True
self._seen_messages[msg_id] = now
return False
# -- Outbound messaging ------------------------------------------------- # -- Outbound messaging -------------------------------------------------
async def send( async def send(

View file

@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
import re import re
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter):
# Track threads where the bot has participated so follow-up messages # Track threads where the bot has participated so follow-up messages
# in those threads don't require @mention. Persisted to disk so the # in those threads don't require @mention. Persisted to disk so the
# set survives gateway restarts. # set survives gateway restarts.
self._bot_participated_threads: set = self._load_participated_threads() self._threads = ThreadParticipationTracker("discord")
# Persistent typing indicator loops per channel (DMs don't reliably # Persistent typing indicator loops per channel (DMs don't reliably
# show the standard typing gateway event for bots) # show the standard typing gateway event for bots)
self._typing_tasks: Dict[str, asyncio.Task] = {} self._typing_tasks: Dict[str, asyncio.Task] = {}
self._bot_task: Optional[asyncio.Task] = None self._bot_task: Optional[asyncio.Task] = None
# Cap to prevent unbounded growth (Discord threads get archived). # Dedup cache: prevents duplicate bot responses when Discord
self._MAX_TRACKED_THREADS = 500 # RESUME replays events after reconnects.
# Dedup cache: message_id → timestamp. Prevents duplicate bot self._dedup = MessageDeduplicator()
# responses when Discord RESUME replays events after reconnects.
self._seen_messages: Dict[str, float] = {}
self._SEEN_TTL = 300 # 5 minutes
self._SEEN_MAX = 2000 # prune threshold
# Reply threading mode: "off" (no replies), "first" (reply on first # Reply threading mode: "off" (no replies), "first" (reply on first
# chunk only, default), "all" (reply-reference on every chunk). # chunk only, default), "all" (reply-reference on every chunk).
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first' self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter):
return False return False
try: try:
# Acquire scoped lock to prevent duplicate bot token usage if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
from gateway.status import acquire_scoped_lock
self._token_lock_identity = self.config.token
acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
if not acquired:
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
logger.error('[%s] %s', self.name, message)
self._set_fatal_error('discord_token_lock', message, retryable=False)
return False return False
# Parse allowed user entries (may contain usernames or IDs) # Parse allowed user entries (may contain usernames or IDs)
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "") allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
if allowed_env: if allowed_env:
@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter):
@self._client.event @self._client.event
async def on_message(message: DiscordMessage): async def on_message(message: DiscordMessage):
# Dedup: Discord RESUME replays events after reconnects (#4777) # Dedup: Discord RESUME replays events after reconnects (#4777)
msg_id = str(message.id) if adapter_self._dedup.is_duplicate(str(message.id)):
now = time.time()
if msg_id in adapter_self._seen_messages:
return return
adapter_self._seen_messages[msg_id] = now
if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX:
cutoff = now - adapter_self._SEEN_TTL
adapter_self._seen_messages = {
k: v for k, v in adapter_self._seen_messages.items()
if v > cutoff
}
# Always ignore our own messages # Always ignore our own messages
if message.author == self._client.user: if message.author == self._client.user:
@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter):
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True) logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
try: self._release_platform_lock()
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('discord-bot-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
return False return False
except Exception as e: # pragma: no cover - defensive logging except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True) logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
try: self._release_platform_lock()
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('discord-bot-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
return False return False
async def disconnect(self) -> None: async def disconnect(self) -> None:
@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter):
self._client = None self._client = None
self._ready_event.clear() self._ready_event.clear()
# Release the token lock self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('discord-bot-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
logger.info("[%s] Disconnected", self.name) logger.info("[%s] Disconnected", self.name)
@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so follow-ups don't require @mention # Track thread participation so follow-ups don't require @mention
if thread_id: if thread_id:
self._track_thread(thread_id) self._threads.mark(thread_id)
# If a message was provided, kick off a new Hermes session in the thread # If a message was provided, kick off a new Hermes session in the thread
starter = (message or "").strip() starter = (message or "").strip()
@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter):
return f"{parent_name} / {thread_name}" return f"{parent_name} / {thread_name}"
return thread_name return thread_name
# ------------------------------------------------------------------
# Thread participation persistence
# ------------------------------------------------------------------
@staticmethod
def _thread_state_path() -> Path:
"""Path to the persisted thread participation set."""
from hermes_cli.config import get_hermes_home
return get_hermes_home() / "discord_threads.json"
@classmethod
def _load_participated_threads(cls) -> set:
"""Load persisted thread IDs from disk."""
path = cls._thread_state_path()
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return set(data)
except Exception as e:
logger.debug("Could not load discord thread state: %s", e)
return set()
def _save_participated_threads(self) -> None:
"""Persist the current thread set to disk (best-effort)."""
path = self._thread_state_path()
try:
# Trim to most recent entries if over cap
thread_list = list(self._bot_participated_threads)
if len(thread_list) > self._MAX_TRACKED_THREADS:
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
self._bot_participated_threads = set(thread_list)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(thread_list), encoding="utf-8")
except Exception as e:
logger.debug("Could not save discord thread state: %s", e)
def _track_thread(self, thread_id: str) -> None:
"""Add a thread to the participation set and persist."""
if thread_id not in self._bot_participated_threads:
self._bot_participated_threads.add(thread_id)
self._save_participated_threads()
async def _handle_message(self, message: DiscordMessage) -> None: async def _handle_message(self, message: DiscordMessage) -> None:
"""Handle incoming Discord messages.""" """Handle incoming Discord messages."""
# In server channels (not DMs), require the bot to be @mentioned # In server channels (not DMs), require the bot to be @mentioned
@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Skip the mention check if the message is in a thread where # Skip the mention check if the message is in a thread where
# the bot has previously participated (auto-created or replied in). # the bot has previously participated (auto-created or replied in).
in_bot_thread = is_thread and thread_id in self._bot_participated_threads in_bot_thread = is_thread and thread_id in self._threads
if require_mention and not is_free_channel and not in_bot_thread: if require_mention and not is_free_channel and not in_bot_thread:
if self._client.user not in message.mentions: if self._client.user not in message.mentions:
@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter):
is_thread = True is_thread = True
thread_id = str(thread.id) thread_id = str(thread.id)
auto_threaded_channel = thread auto_threaded_channel = thread
self._track_thread(thread_id) self._threads.mark(thread_id)
# Determine message type # Determine message type
msg_type = MessageType.TEXT msg_type = MessageType.TEXT
@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so the bot won't require @mention for # Track thread participation so the bot won't require @mention for
# follow-up messages in threads it has already engaged in. # follow-up messages in threads it has already engaged in.
if thread_id: if thread_id:
self._track_thread(thread_id) self._threads.mark(thread_id)
# Only batch plain text messages — commands, media, etc. dispatch # Only batch plain text messages — commands, media, etc. dispatch
# immediately since they won't be split by the Discord client. # immediately since they won't be split by the Discord client.

View file

@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str:
def _strip_markdown_to_plain_text(text: str) -> str: def _strip_markdown_to_plain_text(text: str) -> str:
"""Strip markdown formatting to plain text for Feishu text fallbacks.
Delegates common markdown stripping to the shared helper and adds
Feishu-specific patterns (blockquotes, strikethrough, underline tags,
horizontal rules, \\r\\n normalisation).
"""
from gateway.platforms.helpers import strip_markdown
plain = text.replace("\r\n", "\n") plain = text.replace("\r\n", "\n")
plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain) plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain)
plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE)
plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE) plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE)
plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE) plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE)
plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain)
plain = re.sub(r"`([^`\n]+)`", r"\1", plain)
plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain)
plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain)
plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain) plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain)
plain = re.sub(r"<u>([\s\S]*?)</u>", r"\1", plain) plain = re.sub(r"<u>([\s\S]*?)</u>", r"\1", plain)
plain = re.sub(r"\n{3,}", "\n\n", plain) plain = strip_markdown(plain)
return plain.strip() return plain
def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]: def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]:

View file

@ -0,0 +1,261 @@
"""Shared helper classes for gateway platform adapters.
Extracts common patterns that were duplicated across 5-7 adapters:
message deduplication, text batch aggregation, markdown stripping,
and thread participation tracking.
"""
import asyncio
import json
import logging
import re
import time
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional
if TYPE_CHECKING:
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
logger = logging.getLogger(__name__)
# ─── Message Deduplication ────────────────────────────────────────────────────
class MessageDeduplicator:
    """TTL-based cache of recently seen message IDs.

    Shared replacement for the identical ``_seen_messages`` /
    ``_is_duplicate()`` pattern previously copy-pasted into the discord,
    slack, dingtalk, wecom, weixin, mattermost, and feishu adapters.

    Usage::

        self._dedup = MessageDeduplicator()

        # In message handler:
        if self._dedup.is_duplicate(msg_id):
            return
    """

    def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
        # id -> timestamp of first sighting
        self._stamps: Dict[str, float] = {}
        self._capacity = max_size
        self._window = ttl_seconds

    def is_duplicate(self, msg_id: str) -> bool:
        """Record *msg_id* and report whether it was seen before.

        Falsy IDs are never considered duplicates. Entries older than the
        TTL are pruned lazily, only once the cache grows past its cap.
        """
        if not msg_id:
            return False
        stamp = time.time()
        if msg_id in self._stamps:
            return True
        self._stamps[msg_id] = stamp
        if len(self._stamps) > self._capacity:
            oldest_allowed = stamp - self._window
            self._stamps = {
                key: seen_at
                for key, seen_at in self._stamps.items()
                if seen_at > oldest_allowed
            }
        return False

    def clear(self):
        """Forget every tracked message ID."""
        self._stamps.clear()
# ─── Text Batch Aggregation ──────────────────────────────────────────────────
class TextBatchAggregator:
    """Coalesces bursts of text events into one combined message.

    Shared replacement for the ``_enqueue_text_event`` /
    ``_flush_text_batch`` pattern previously duplicated in telegram,
    discord, matrix, wecom, and feishu.

    Usage::

        self._text_batcher = TextBatchAggregator(
            handler=self._message_handler,
            batch_delay=0.6,
            split_threshold=1900,
        )

        # In message dispatch:
        if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
            self._text_batcher.enqueue(event, session_key)
            return
    """

    def __init__(
        self,
        handler,
        *,
        batch_delay: float = 0.6,
        split_delay: float = 2.0,
        split_threshold: int = 4000,
    ):
        # Async callable invoked with the merged MessageEvent.
        self._handler = handler
        self._batch_delay = batch_delay
        self._split_delay = split_delay
        self._split_threshold = split_threshold
        # key -> accumulated event awaiting dispatch
        self._pending: Dict[str, "MessageEvent"] = {}
        # key -> flush timer task
        self._pending_tasks: Dict[str, asyncio.Task] = {}

    def is_enabled(self) -> bool:
        """True when batching is in effect (positive batch delay)."""
        return self._batch_delay > 0

    def enqueue(self, event: "MessageEvent", key: str) -> None:
        """Merge *event* into the batch for *key* and (re)arm its timer."""
        tail_len = len(event.text or "")
        queued = self._pending.get(key)
        if queued:
            # Append onto the accumulating event's text.
            queued.text = f"{queued.text}\n{event.text}"
            queued._last_chunk_len = tail_len  # type: ignore[attr-defined]
        else:
            event._last_chunk_len = tail_len  # type: ignore[attr-defined]
            self._pending[key] = event
        # Restart the flush countdown for this key.
        stale = self._pending_tasks.get(key)
        if stale and not stale.done():
            stale.cancel()
        self._pending_tasks[key] = asyncio.create_task(self._flush(key))

    async def _flush(self, key: str) -> None:
        """Sleep out the debounce window, then dispatch *key*'s batch."""
        my_task = self._pending_tasks.get(key)
        queued = self._pending.get(key)
        tail_len = getattr(queued, "_last_chunk_len", 0) if queued else 0
        # A final chunk at/above the split threshold suggests the sender's
        # client split one long message — wait longer for the remainder.
        if tail_len >= self._split_threshold:
            wait = self._split_delay
        else:
            wait = self._batch_delay
        await asyncio.sleep(wait)
        batched = self._pending.pop(key, None)
        if batched:
            try:
                await self._handler(batched)
            except Exception:
                logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
        # Only drop the task entry if it is still ours (a newer enqueue may
        # have replaced it while we slept).
        if self._pending_tasks.get(key) is my_task:
            self._pending_tasks.pop(key, None)

    def cancel_all(self) -> None:
        """Abort every pending flush and drop all queued events."""
        for timer in self._pending_tasks.values():
            if not timer.done():
                timer.cancel()
        self._pending_tasks.clear()
        self._pending.clear()
# ─── Markdown Stripping ──────────────────────────────────────────────────────
# Pre-compiled regexes for performance
_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL)
_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL)
_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL)
_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL)
_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
_RE_INLINE_CODE = re.compile(r"`(.+?)`")
_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")

# Substitutions applied in order: emphasis first, then code fences/spans,
# headings, links, and finally newline collapsing.
_MARKDOWN_SUBS = (
    (_RE_BOLD, r"\1"),
    (_RE_ITALIC_STAR, r"\1"),
    (_RE_BOLD_UNDER, r"\1"),
    (_RE_ITALIC_UNDER, r"\1"),
    (_RE_CODE_BLOCK, ""),
    (_RE_INLINE_CODE, r"\1"),
    (_RE_HEADING, ""),
    (_RE_LINK, r"\1"),
    (_RE_MULTI_NEWLINE, "\n\n"),
)


def strip_markdown(text: str) -> str:
    """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).

    Replaces the identical ``_strip_markdown()`` functions previously
    duplicated in sms.py, bluebubbles.py, and feishu.py.
    """
    for pattern, replacement in _MARKDOWN_SUBS:
        text = pattern.sub(replacement, text)
    return text.strip()
# ─── Thread Participation Tracking ───────────────────────────────────────────
class ThreadParticipationTracker:
    """Persistent tracking of threads the bot has participated in.

    Replaces the identical ``_load/_save_participated_threads`` +
    ``_mark_thread_participated`` pattern previously duplicated in
    discord.py and matrix.py.

    Persistence is best-effort: load/save failures are logged at debug
    level and never propagate, matching the original adapter code — a
    disk error must not break message handling.

    Usage::

        self._threads = ThreadParticipationTracker("discord")

        # Check membership:
        if thread_id in self._threads:
            ...

        # Mark participation:
        self._threads.mark(thread_id)
    """

    # Kept for backward compatibility; the effective cap is the
    # ``max_tracked`` constructor argument.
    _MAX_TRACKED = 500

    def __init__(self, platform_name: str, max_tracked: int = 500):
        self._platform = platform_name
        self._max_tracked = max_tracked
        self._threads: set = self._load()

    def _state_path(self) -> Path:
        """Return the JSON file path that persists this platform's thread set."""
        from hermes_constants import get_hermes_home
        return get_hermes_home() / f"{self._platform}_threads.json"

    def _load(self) -> set:
        """Load persisted thread IDs; return an empty set on any error."""
        try:
            path = self._state_path()
            if path.exists():
                return set(json.loads(path.read_text(encoding="utf-8")))
        except Exception as e:
            logging.getLogger(__name__).debug(
                "Could not load %s thread state: %s", self._platform, e
            )
        return set()

    def _save(self) -> None:
        """Persist the current thread set to disk (best-effort).

        Swallows all errors: the original per-adapter implementations were
        deliberately best-effort, and ``mark()`` is called from the message
        path where a disk error must not crash handling.
        """
        try:
            path = self._state_path()
            path.parent.mkdir(parents=True, exist_ok=True)
            thread_list = list(self._threads)
            if len(thread_list) > self._max_tracked:
                # NOTE(review): set ordering is arbitrary, so the trim drops
                # an arbitrary subset — same as the original adapter code.
                thread_list = thread_list[-self._max_tracked:]
                self._threads = set(thread_list)
            path.write_text(json.dumps(thread_list), encoding="utf-8")
        except Exception as e:
            logging.getLogger(__name__).debug(
                "Could not save %s thread state: %s", self._platform, e
            )

    def mark(self, thread_id: str) -> None:
        """Mark *thread_id* as participated and persist."""
        if thread_id not in self._threads:
            self._threads.add(thread_id)
            self._save()

    def __contains__(self, thread_id: str) -> bool:
        return thread_id in self._threads

    def clear(self) -> None:
        """Forget all tracked threads (in-memory only; not persisted)."""
        self._threads.clear()
# ─── Phone Number Redaction ──────────────────────────────────────────────────
def redact_phone(phone: str) -> str:
    """Redact a phone number for logging, preserving country code and last 4.

    Replaces the identical ``_redact_phone()`` functions in signal.py,
    sms.py, and bluebubbles.py.
    """
    if not phone:
        return "<none>"
    length = len(phone)
    if length > 8:
        # Long enough to keep the country-code prefix and last four digits.
        return f"{phone[:4]}****{phone[-4:]}"
    if length > 4:
        # Short number: keep only two characters on each end.
        return f"{phone[:2]}****{phone[-2:]}"
    # Too short to reveal anything.
    return "****"

View file

@ -92,6 +92,7 @@ from gateway.platforms.base import (
ProcessingOutcome, ProcessingOutcome,
SendResult, SendResult,
) )
from gateway.platforms.helpers import ThreadParticipationTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter):
self._pending_megolm: list = [] self._pending_megolm: list = []
# Thread participation tracking (for require_mention bypass) # Thread participation tracking (for require_mention bypass)
self._bot_participated_threads: set = self._load_participated_threads() self._threads = ThreadParticipationTracker("matrix")
self._MAX_TRACKED_THREADS = 500
# Mention/thread gating — parsed once from env vars. # Mention/thread gating — parsed once from env vars.
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter):
# Require-mention gating. # Require-mention gating.
if not is_dm: if not is_dm:
is_free_room = room_id in self._free_rooms is_free_room = room_id in self._free_rooms
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) in_bot_thread = bool(thread_id and thread_id in self._threads)
if self._require_mention and not is_free_room and not in_bot_thread: if self._require_mention and not is_free_room and not in_bot_thread:
if not is_mentioned: if not is_mentioned:
return None return None
@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter):
# DM mention-thread. # DM mention-thread.
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned: if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
thread_id = event_id thread_id = event_id
self._track_thread(thread_id) self._threads.mark(thread_id)
# Strip mention from body. # Strip mention from body.
if is_mentioned: if is_mentioned:
@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter):
# Auto-thread. # Auto-thread.
if not is_dm and not thread_id and self._auto_thread: if not is_dm and not thread_id and self._auto_thread:
thread_id = event_id thread_id = event_id
self._track_thread(thread_id) self._threads.mark(thread_id)
display_name = await self._get_display_name(room_id, sender) display_name = await self._get_display_name(room_id, sender)
source = self.build_source( source = self.build_source(
@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter):
) )
if thread_id: if thread_id:
self._track_thread(thread_id) self._threads.mark(thread_id)
self._background_read_receipt(room_id, event_id) self._background_read_receipt(room_id, event_id)
@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter):
for rid in self._joined_rooms for rid in self._joined_rooms
} }
# ------------------------------------------------------------------
# Thread participation tracking
# ------------------------------------------------------------------
@staticmethod
def _thread_state_path() -> Path:
"""Path to the persisted thread participation set."""
from hermes_cli.config import get_hermes_home
return get_hermes_home() / "matrix_threads.json"
@classmethod
def _load_participated_threads(cls) -> set:
"""Load persisted thread IDs from disk."""
path = cls._thread_state_path()
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return set(data)
except Exception as e:
logger.debug("Could not load matrix thread state: %s", e)
return set()
def _save_participated_threads(self) -> None:
"""Persist the current thread set to disk (best-effort)."""
path = self._thread_state_path()
try:
thread_list = list(self._bot_participated_threads)
if len(thread_list) > self._MAX_TRACKED_THREADS:
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
self._bot_participated_threads = set(thread_list)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(thread_list), encoding="utf-8")
except Exception as e:
logger.debug("Could not save matrix thread state: %s", e)
def _track_thread(self, thread_id: str) -> None:
"""Add a thread to the participation set and persist."""
if thread_id not in self._bot_participated_threads:
self._bot_participated_threads.add(thread_id)
self._save_participated_threads()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Mention detection helpers # Mention detection helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------

View file

@ -18,11 +18,11 @@ import json
import logging import logging
import os import os
import re import re
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter):
or os.getenv("MATTERMOST_REPLY_MODE", "off") or os.getenv("MATTERMOST_REPLY_MODE", "off")
).lower() ).lower()
# Dedup cache: post_id → timestamp (prevent reprocessing) # Dedup cache (prevent reprocessing)
self._seen_posts: Dict[str, float] = {} self._dedup = MessageDeduplicator()
self._SEEN_MAX = 2000
self._SEEN_TTL = 300 # 5 minutes
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# HTTP helpers # HTTP helpers
@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter):
post_id = post.get("id", "") post_id = post.get("id", "")
# Dedup. # Dedup.
self._prune_seen() if self._dedup.is_duplicate(post_id):
if post_id in self._seen_posts:
return return
self._seen_posts[post_id] = time.time()
# Build message event. # Build message event.
channel_id = post.get("channel_id", "") channel_id = post.get("channel_id", "")
@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter):
await self.handle_message(msg_event) await self.handle_message(msg_event)
def _prune_seen(self) -> None:
"""Remove expired entries from the dedup cache."""
if len(self._seen_posts) < self._SEEN_MAX:
return
now = time.time()
self._seen_posts = {
pid: ts
for pid, ts in self._seen_posts.items()
if now - ts < self._SEEN_TTL
}

View file

@ -37,6 +37,7 @@ from gateway.platforms.base import (
cache_document_from_bytes, cache_document_from_bytes,
cache_image_from_url, cache_image_from_url,
) )
from gateway.platforms.helpers import redact_phone
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
# E.164 phone number pattern for redaction
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _redact_phone(phone: str) -> str:
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
if not phone:
return "<none>"
if len(phone) <= 8:
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
return phone[:4] + "****" + phone[-4:]
def _parse_comma_list(value: str) -> List[str]: def _parse_comma_list(value: str) -> List[str]:
"""Split a comma-separated string into a list, stripping whitespace.""" """Split a comma-separated string into a list, stripping whitespace."""
@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter):
self._recent_sent_timestamps: set = set() self._recent_sent_timestamps: set = set()
self._max_recent_timestamps = 50 self._max_recent_timestamps = 50
self._phone_lock_identity: Optional[str] = None
logger.info("Signal adapter initialized: url=%s account=%s groups=%s", logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
self.http_url, _redact_phone(self.account), self.http_url, redact_phone(self.account),
"enabled" if self.group_allow_from else "disabled") "enabled" if self.group_allow_from else "disabled")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter):
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone # Acquire scoped lock to prevent duplicate Signal listeners for the same phone
try: try:
from gateway.status import acquire_scoped_lock if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'):
self._phone_lock_identity = self.account
acquired, existing = acquire_scoped_lock(
"signal-phone",
self._phone_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this Signal account"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second Signal listener."
)
logger.error("Signal: %s", message)
self._set_fatal_error("signal_phone_lock", message, retryable=False)
return False return False
except Exception as e: except Exception as e:
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e) logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
await self.client.aclose() await self.client.aclose()
self.client = None self.client = None
if self._phone_lock_identity: self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
release_scoped_lock("signal-phone", self._phone_lock_identity)
except Exception as e:
logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
self._phone_lock_identity = None
logger.info("Signal: disconnected") logger.info("Signal: disconnected")
@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter):
) )
logger.debug("Signal: message from %s in %s: %s", logger.debug("Signal: message from %s in %s: %s",
_redact_phone(sender), chat_id[:20], (text or "")[:50]) redact_phone(sender), chat_id[:20], (text or "")[:50])
await self.handle_message(event) await self.handle_message(event)

View file

@ -33,6 +33,7 @@ from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter):
self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient
self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id
self._channel_team: Dict[str, str] = {} # channel_id → team_id self._channel_team: Dict[str, str] = {} # channel_id → team_id
# Dedup cache: event_ts → timestamp. Prevents duplicate bot # Dedup cache: prevents duplicate bot responses when Socket Mode
# responses when Socket Mode reconnects redeliver events. # reconnects redeliver events.
self._seen_messages: Dict[str, float] = {} self._dedup = MessageDeduplicator()
self._SEEN_TTL = 300 # 5 minutes
self._SEEN_MAX = 2000 # prune threshold
# Track pending approval message_ts → resolved flag to prevent # Track pending approval message_ts → resolved flag to prevent
# double-clicks on approval buttons. # double-clicks on approval buttons.
self._approval_resolved: Dict[str, bool] = {} self._approval_resolved: Dict[str, bool] = {}
@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter):
logger.warning("[Slack] Failed to read %s: %s", tokens_file, e) logger.warning("[Slack] Failed to read %s: %s", tokens_file, e)
try: try:
# Acquire scoped lock to prevent duplicate app token usage if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'):
from gateway.status import acquire_scoped_lock
self._token_lock_identity = app_token
acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
if not acquired:
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
logger.error('[%s] %s', self.name, message)
self._set_fatal_error('slack_token_lock', message, retryable=False)
return False return False
# First token is the primary — used for AsyncApp / Socket Mode # First token is the primary — used for AsyncApp / Socket Mode
@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter):
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True) logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
self._running = False self._running = False
# Release the token lock (use stored identity, not re-read env) self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('slack-app-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
logger.info("[Slack] Disconnected") logger.info("[Slack] Disconnected")
@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter):
"""Handle an incoming Slack message event.""" """Handle an incoming Slack message event."""
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777) # Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
event_ts = event.get("ts", "") event_ts = event.get("ts", "")
if event_ts: if event_ts and self._dedup.is_duplicate(event_ts):
now = time.time()
if event_ts in self._seen_messages:
return return
self._seen_messages[event_ts] = now
if len(self._seen_messages) > self._SEEN_MAX:
cutoff = now - self._SEEN_TTL
self._seen_messages = {
k: v for k, v in self._seen_messages.items()
if v > cutoff
}
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots): # Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
# "none" — ignore all bot messages (default, backward-compatible) # "none" — ignore all bot messages (default, backward-compatible)

View file

@ -19,7 +19,6 @@ import asyncio
import base64 import base64
import logging import logging
import os import os
import re
import urllib.parse import urllib.parse
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@ -30,6 +29,7 @@ from gateway.platforms.base import (
MessageType, MessageType,
SendResult, SendResult,
) )
from gateway.platforms.helpers import redact_phone, strip_markdown
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,18 +37,6 @@ TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
MAX_SMS_LENGTH = 1600 # ~10 SMS segments MAX_SMS_LENGTH = 1600 # ~10 SMS segments
DEFAULT_WEBHOOK_PORT = 8080 DEFAULT_WEBHOOK_PORT = 8080
# E.164 phone number pattern for redaction
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
def _redact_phone(phone: str) -> str:
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
if not phone:
return "<none>"
if len(phone) <= 8:
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
return phone[:5] + "***" + phone[-4:]
def check_sms_requirements() -> bool: def check_sms_requirements() -> bool:
"""Check if SMS adapter dependencies are available.""" """Check if SMS adapter dependencies are available."""
@ -114,7 +102,7 @@ class SmsAdapter(BasePlatformAdapter):
logger.info( logger.info(
"[sms] Twilio webhook server listening on port %d, from: %s", "[sms] Twilio webhook server listening on port %d, from: %s",
self._webhook_port, self._webhook_port,
_redact_phone(self._from_number), redact_phone(self._from_number),
) )
return True return True
@ -163,7 +151,7 @@ class SmsAdapter(BasePlatformAdapter):
error_msg = body.get("message", str(body)) error_msg = body.get("message", str(body))
logger.error( logger.error(
"[sms] send failed to %s: %s %s", "[sms] send failed to %s: %s %s",
_redact_phone(chat_id), redact_phone(chat_id),
resp.status, resp.status,
error_msg, error_msg,
) )
@ -174,7 +162,7 @@ class SmsAdapter(BasePlatformAdapter):
msg_sid = body.get("sid", "") msg_sid = body.get("sid", "")
last_result = SendResult(success=True, message_id=msg_sid) last_result = SendResult(success=True, message_id=msg_sid)
except Exception as e: except Exception as e:
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e) logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e)
return SendResult(success=False, error=str(e)) return SendResult(success=False, error=str(e))
finally: finally:
# Close session only if we created a fallback (no persistent session) # Close session only if we created a fallback (no persistent session)
@ -192,16 +180,7 @@ class SmsAdapter(BasePlatformAdapter):
def format_message(self, content: str) -> str: def format_message(self, content: str) -> str:
"""Strip markdown — SMS renders it as literal characters.""" """Strip markdown — SMS renders it as literal characters."""
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL) return strip_markdown(content)
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
content = re.sub(r"```[a-z]*\n?", "", content)
content = re.sub(r"`(.+?)`", r"\1", content)
content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
content = re.sub(r"\n{3,}", "\n\n", content)
return content.strip()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Twilio webhook handler # Twilio webhook handler
@ -236,7 +215,7 @@ class SmsAdapter(BasePlatformAdapter):
# Ignore messages from our own number (echo prevention) # Ignore messages from our own number (echo prevention)
if from_number == self._from_number: if from_number == self._from_number:
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number)) logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number))
return web.Response( return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>', text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml", content_type="application/xml",
@ -244,8 +223,8 @@ class SmsAdapter(BasePlatformAdapter):
logger.info( logger.info(
"[sms] inbound from %s -> %s: %s", "[sms] inbound from %s -> %s: %s",
_redact_phone(from_number), redact_phone(from_number),
_redact_phone(to_number), redact_phone(to_number),
text[:80], text[:80],
) )

View file

@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter):
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
self._pending_text_batches: Dict[str, MessageEvent] = {} self._pending_text_batches: Dict[str, MessageEvent] = {}
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
self._token_lock_identity: Optional[str] = None
self._polling_error_task: Optional[asyncio.Task] = None self._polling_error_task: Optional[asyncio.Task] = None
self._polling_conflict_count: int = 0 self._polling_conflict_count: int = 0
self._polling_network_error_count: int = 0 self._polling_network_error_count: int = 0
@ -497,23 +496,7 @@ class TelegramAdapter(BasePlatformAdapter):
return False return False
try: try:
from gateway.status import acquire_scoped_lock if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'):
self._token_lock_identity = self.config.token
acquired, existing = acquire_scoped_lock(
"telegram-bot-token",
self._token_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this Telegram bot token"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second Telegram poller."
)
logger.error("[%s] %s", self.name, message)
self._set_fatal_error("telegram_token_lock", message, retryable=False)
return False return False
# Build the application # Build the application
@ -737,12 +720,7 @@ class TelegramAdapter(BasePlatformAdapter):
return True return True
except Exception as e: except Exception as e:
if self._token_lock_identity: self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
except Exception:
pass
message = f"Telegram startup failed: {e}" message = f"Telegram startup failed: {e}"
self._set_fatal_error("telegram_connect_error", message, retryable=True) self._set_fatal_error("telegram_connect_error", message, retryable=True)
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True) logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
@ -768,12 +746,7 @@ class TelegramAdapter(BasePlatformAdapter):
await self._app.shutdown() await self._app.shutdown()
except Exception as e: except Exception as e:
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True) logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
if self._token_lock_identity: self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
except Exception as e:
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
for task in self._pending_photo_batch_tasks.values(): for task in self._pending_photo_batch_tasks.values():
if task and not task.done(): if task and not task.done():
@ -784,7 +757,6 @@ class TelegramAdapter(BasePlatformAdapter):
self._mark_disconnected() self._mark_disconnected()
self._app = None self._app = None
self._bot = None self._bot = None
self._token_lock_identity = None
logger.info("[%s] Disconnected from Telegram", self.name) logger.info("[%s] Disconnected from Telegram", self.name)
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool: def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:

View file

@ -59,6 +59,7 @@ except ImportError:
httpx = None # type: ignore[assignment] httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0
HEARTBEAT_INTERVAL_SECONDS = 30.0 HEARTBEAT_INTERVAL_SECONDS = 30.0
RECONNECT_BACKOFF = [2, 5, 10, 30, 60] RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000 DEDUP_MAX_SIZE = 1000
IMAGE_MAX_BYTES = 10 * 1024 * 1024 IMAGE_MAX_BYTES = 10 * 1024 * 1024
@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter):
self._listen_task: Optional[asyncio.Task] = None self._listen_task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None self._heartbeat_task: Optional[asyncio.Task] = None
self._pending_responses: Dict[str, asyncio.Future] = {} self._pending_responses: Dict[str, asyncio.Future] = {}
self._seen_messages: Dict[str, float] = {} self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
self._reply_req_ids: Dict[str, str] = {} self._reply_req_ids: Dict[str, str] = {}
# Text batching: merge rapid successive messages (Telegram-style). # Text batching: merge rapid successive messages (Telegram-style).
@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter):
await self._http_client.aclose() await self._http_client.aclose()
self._http_client = None self._http_client = None
self._seen_messages.clear() self._dedup.clear()
logger.info("[%s] Disconnected", self.name) logger.info("[%s] Disconnected", self.name)
async def _cleanup_ws(self) -> None: async def _cleanup_ws(self) -> None:
@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter):
return return
msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex) msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex)
if self._is_duplicate(msg_id): if self._dedup.is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id) logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
return return
self._remember_reply_req_id(msg_id, self._payload_req_id(payload)) self._remember_reply_req_id(msg_id, self._payload_req_id(payload))
@ -839,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter):
wildcard = self._groups.get("*") wildcard = self._groups.get("*")
return wildcard if isinstance(wildcard, dict) else {} return wildcard if isinstance(wildcard, dict) else {}
def _is_duplicate(self, msg_id: str) -> bool:
now = time.time()
if len(self._seen_messages) > DEDUP_MAX_SIZE:
cutoff = now - DEDUP_WINDOW_SECONDS
self._seen_messages = {
key: ts for key, ts in self._seen_messages.items() if ts > cutoff
}
if self._reply_req_ids:
self._reply_req_ids = {
key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages
}
if msg_id in self._seen_messages:
return True
self._seen_messages[msg_id] = now
return False
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None: def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
normalized_message_id = str(message_id or "").strip() normalized_message_id = str(message_id or "").strip()
normalized_req_id = str(req_id or "").strip() normalized_req_id = str(req_id or "").strip()

View file

@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate
CRYPTO_AVAILABLE = False CRYPTO_AVAILABLE = False
from gateway.config import Platform, PlatformConfig from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import ( from gateway.platforms.base import (
BasePlatformAdapter, BasePlatformAdapter,
MessageEvent, MessageEvent,
@ -1008,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter):
self._typing_cache = TypingTicketCache() self._typing_cache = TypingTicketCache()
self._session: Optional[aiohttp.ClientSession] = None self._session: Optional[aiohttp.ClientSession] = None
self._poll_task: Optional[asyncio.Task] = None self._poll_task: Optional[asyncio.Task] = None
self._seen_messages: Dict[str, float] = {} self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
self._token_lock_identity: Optional[str] = None
self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip() self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip()
self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip() self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip()
@ -1067,23 +1067,7 @@ class WeixinAdapter(BasePlatformAdapter):
return False return False
try: try:
from gateway.status import acquire_scoped_lock if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'):
self._token_lock_identity = self._token
acquired, existing = acquire_scoped_lock(
"weixin-bot-token",
self._token_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this Weixin token"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second Weixin poller."
)
logger.error("[%s] %s", self.name, message)
self._set_fatal_error("weixin_token_lock", message, retryable=False)
return False return False
except Exception as exc: except Exception as exc:
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc) logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
@ -1107,12 +1091,7 @@ class WeixinAdapter(BasePlatformAdapter):
if self._session and not self._session.closed: if self._session and not self._session.closed:
await self._session.close() await self._session.close()
self._session = None self._session = None
if self._token_lock_identity: self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
release_scoped_lock("weixin-bot-token", self._token_lock_identity)
except Exception as exc:
logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True)
self._mark_disconnected() self._mark_disconnected()
logger.info("[%s] Disconnected", self.name) logger.info("[%s] Disconnected", self.name)
@ -1190,16 +1169,8 @@ class WeixinAdapter(BasePlatformAdapter):
return return
message_id = str(message.get("message_id") or "").strip() message_id = str(message.get("message_id") or "").strip()
if message_id: if message_id and self._dedup.is_duplicate(message_id):
now = time.time()
self._seen_messages = {
key: value
for key, value in self._seen_messages.items()
if now - value < MESSAGE_DEDUP_TTL_SECONDS
}
if message_id in self._seen_messages:
return return
self._seen_messages[message_id] = now
chat_type, effective_chat_id = _guess_chat_type(message, self._account_id) chat_type, effective_chat_id = _guess_chat_type(message, self._account_id)
if chat_type == "group": if chat_type == "group":

View file

@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._bridge_log: Optional[Path] = None self._bridge_log: Optional[Path] = None
self._poll_task: Optional[asyncio.Task] = None self._poll_task: Optional[asyncio.Task] = None
self._http_session: Optional["aiohttp.ClientSession"] = None self._http_session: Optional["aiohttp.ClientSession"] = None
self._session_lock_identity: Optional[str] = None
def _whatsapp_require_mention(self) -> bool: def _whatsapp_require_mention(self) -> bool:
configured = self.config.extra.get("require_mention") configured = self.config.extra.get("require_mention")
@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
# Acquire scoped lock to prevent duplicate sessions # Acquire scoped lock to prevent duplicate sessions
try: try:
from gateway.status import acquire_scoped_lock if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'):
self._session_lock_identity = str(self._session_path)
acquired, existing = acquire_scoped_lock(
"whatsapp-session",
self._session_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this WhatsApp session"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second WhatsApp bridge."
)
logger.error("[%s] %s", self.name, message)
self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
return False return False
except Exception as e: except Exception as e:
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e) logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
return True return True
except Exception as e: except Exception as e:
if self._session_lock_identity: self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
release_scoped_lock("whatsapp-session", self._session_lock_identity)
except Exception:
pass
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True) logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
self._close_bridge_log() self._close_bridge_log()
return False return False
@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
await self._http_session.close() await self._http_session.close()
self._http_session = None self._http_session = None
if self._session_lock_identity: self._release_platform_lock()
try:
from gateway.status import release_scoped_lock
release_scoped_lock("whatsapp-session", self._session_lock_identity)
except Exception as e:
logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
self._mark_disconnected() self._mark_disconnected()
self._bridge_process = None self._bridge_process = None
self._close_bridge_log() self._close_bridge_log()
self._session_lock_identity = None
print(f"[{self.name}] Disconnected") print(f"[{self.name}] Disconnected")
async def send( async def send(

View file

@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
} }
# =============================================================================
# Anthropic Key Helper
# =============================================================================
def get_anthropic_key() -> str:
    """Return the first usable Anthropic credential, or ``""``.

    Each candidate variable is checked first in the ``.env`` file (via
    ``get_env_value``) and then in the process environment (``os.getenv``).
    The candidate order is exactly the
    ``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple:
    ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN
    """
    from hermes_cli.config import get_env_value

    # First non-empty value wins; the walrus keeps the lookup single-pass.
    return next(
        (
            found
            for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars
            if (found := get_env_value(var) or os.getenv(var, ""))
        ),
        "",
    )
# ============================================================================= # =============================================================================
# Kimi Code Endpoint Detection # Kimi Code Endpoint Detection
# ============================================================================= # =============================================================================

79
hermes_cli/cli_output.py Normal file
View file

@ -0,0 +1,79 @@
"""Shared CLI output helpers for Hermes CLI modules.
Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
functions previously duplicated across setup.py, tools_config.py,
mcp_config.py, and memory_setup.py.
"""
import getpass
import sys
from hermes_cli.colors import Colors, color
# ─── Print Helpers ────────────────────────────────────────────────────────────
def print_info(text: str) -> None:
    """Print a dim informational message."""
    message = f" {text}"
    print(color(message, Colors.DIM))
def print_success(text: str) -> None:
    """Print a green success message with ✓ prefix."""
    # NOTE(review): the docstring mentions a ✓ prefix but the format string
    # below contains only {text} — the glyph may have been lost in
    # extraction; confirm against the canonical source.
    print(color(f"{text}", Colors.GREEN))
def print_warning(text: str) -> None:
    """Print a yellow warning message with ⚠ prefix."""
    # NOTE(review): the docstring mentions a ⚠ prefix but the format string
    # below contains only {text} — possibly lost in extraction; confirm.
    print(color(f"{text}", Colors.YELLOW))
def print_error(text: str) -> None:
    """Print a red error message with ✗ prefix."""
    # NOTE(review): the docstring mentions a ✗ prefix but the format string
    # below contains only {text} — possibly lost in extraction; confirm.
    print(color(f"{text}", Colors.RED))
def print_header(text: str) -> None:
    """Print a yellow header line, preceded by a blank line."""
    header = "\n " + text
    print(color(header, Colors.YELLOW))
# ─── Input Prompts ────────────────────────────────────────────────────────────
def prompt(
    question: str,
    default: str | None = None,
    password: bool = False,
) -> str:
    """Ask the user a question and return the (stripped) answer.

    Shared replacement for the four independent ``_prompt()`` / ``prompt()``
    implementations previously living in setup.py, tools_config.py,
    mcp_config.py, and memory_setup.py.

    Pressing Enter on an empty line yields *default* (or ``""`` when no
    default was given); Ctrl-C / EOF yields ``""``.
    """
    hint = f" [{default}]" if default else ""
    label = color(f" {question}{hint}: ", Colors.YELLOW)
    # getpass masks the echo for secrets; plain input() otherwise.
    reader = getpass.getpass if password else input
    try:
        answer = reader(label).strip()
    except (KeyboardInterrupt, EOFError):
        print()
        return ""
    return answer or (default or "")
def prompt_yes_no(question: str, default: bool = True) -> bool:
    """Ask a yes/no question; an empty reply accepts *default*."""
    # Capitalized letter in the hint marks the default choice.
    choices = "Y/n" if default else "y/N"
    reply = prompt(f"{question} ({choices})")
    return reply.lower().startswith("y") if reply else default

View file

@ -2582,7 +2582,8 @@ def show_config():
for env_key, name in keys: for env_key, name in keys:
value = get_env_value(env_key) value = get_env_value(env_key)
print(f" {name:<14} {redact_key(value)}") print(f" {name:<14} {redact_key(value)}")
anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY") from hermes_cli.auth import get_anthropic_key
anthropic_value = get_anthropic_key()
print(f" {'Anthropic':<14} {redact_key(anthropic_value)}") print(f" {'Anthropic':<14} {redact_key(anthropic_value)}")
# Model settings # Model settings
@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str):
# Write only user config back (not the full merged defaults) # Write only user config back (not the full merged defaults)
ensure_hermes_home() ensure_hermes_home()
with open(config_path, 'w', encoding="utf-8") as f: from utils import atomic_yaml_write
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) atomic_yaml_write(config_path, user_config, sort_keys=False)
# Keep .env in sync for keys that terminal_tool reads directly from env vars. # Keep .env in sync for keys that terminal_tool reads directly from env vars.
# config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc. # config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc.

View file

@ -336,8 +336,8 @@ def run_doctor(args):
model_section[k] = raw_config.pop(k) model_section[k] = raw_config.pop(k)
else: else:
raw_config.pop(k) raw_config.pop(k)
with open(config_path, "w") as f: from utils import atomic_yaml_write
yaml.dump(raw_config, f, default_flow_style=False) atomic_yaml_write(config_path, raw_config)
check_ok("Migrated stale root-level keys into model section") check_ok("Migrated stale root-level keys into model section")
fixed_count += 1 fixed_count += 1
else: else:
@ -686,7 +686,8 @@ def run_doctor(args):
else: else:
check_warn("OpenRouter API", "(not configured)") check_warn("OpenRouter API", "(not configured)")
anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY") from hermes_cli.auth import get_anthropic_key
anthropic_key = get_anthropic_key()
if anthropic_key: if anthropic_key:
print(" Checking Anthropic API...", end="", flush=True) print(" Checking Anthropic API...", end="", flush=True)
try: try:

View file

@ -2549,13 +2549,8 @@ def _model_flow_anthropic(config, current_model=""):
from hermes_cli.models import _PROVIDER_MODELS from hermes_cli.models import _PROVIDER_MODELS
# Check ALL credential sources # Check ALL credential sources
existing_key = ( from hermes_cli.auth import get_anthropic_key
get_env_value("ANTHROPIC_TOKEN") existing_key = get_anthropic_key()
or os.getenv("ANTHROPIC_TOKEN", "")
or get_env_value("ANTHROPIC_API_KEY")
or os.getenv("ANTHROPIC_API_KEY", "")
or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
)
cc_available = False cc_available = False
try: try:
from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid

View file

@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool:
def _prompt(question: str, *, password: bool = False, default: str = "") -> str: def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
display = f" {question}" from hermes_cli.cli_output import prompt as _shared_prompt
if default: return _shared_prompt(question, default=default, password=password)
display += f" [{default}]"
display += ": "
try:
if password:
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
return value.strip() or default
except (KeyboardInterrupt, EOFError):
print()
return default
# ─── Config Helpers ─────────────────────────────────────────────────────────── # ─── Config Helpers ───────────────────────────────────────────────────────────

View file

@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -
items: list of (label, description) tuples. items: list of (label, description) tuples.
Returns selected index, or default on escape/quit. Returns selected index, or default on escape/quit.
""" """
try: from hermes_cli.curses_ui import curses_radiolist
import curses # Format (label, desc) tuples into display strings
result = [default] display_items = [
f"{label} {desc}" if desc else label
def _menu(stdscr): for label, desc in items
curses.curs_set(0) ]
if curses.has_colors(): return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
curses.init_pair(3, curses.COLOR_CYAN, -1)
cursor = default
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Title
try:
stdscr.addnstr(0, 0, title, max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
except curses.error:
pass
for i, (label, desc) in enumerate(items):
y = i + 3
if y >= max_y - 1:
break
arrow = "" if i == cursor else " "
line = f" {arrow} {label}"
if desc:
line += f" {desc}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord('k')):
cursor = (cursor - 1) % len(items)
elif key in (curses.KEY_DOWN, ord('j')):
cursor = (cursor + 1) % len(items)
elif key in (curses.KEY_ENTER, 10, 13):
result[0] = cursor
return
elif key in (27, ord('q')):
return
curses.wrapper(_menu)
return result[0]
except Exception:
# Fallback: numbered input
print(f"\n {title}\n")
for i, (label, desc) in enumerate(items):
marker = "" if i == default else " "
d = f" {desc}" if desc else ""
print(f" {marker} {i + 1}. {label}{d}")
while True:
try:
val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
if not val:
return default
idx = int(val) - 1
if 0 <= idx < len(items):
return idx
except (ValueError, EOFError):
return default
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str: def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:

45
hermes_cli/platforms.py Normal file
View file

@ -0,0 +1,45 @@
"""
Shared platform registry for Hermes Agent.
Single source of truth for platform metadata consumed by both
skills_config (label display) and tools_config (default toolset
resolution). Import ``PLATFORMS`` from here instead of maintaining
duplicate dicts in each module.
"""
from collections import OrderedDict
from typing import NamedTuple
class PlatformInfo(NamedTuple):
    """Metadata for a single platform entry."""

    # Human-readable menu label, prefixed with an emoji glyph.
    label: str
    # Toolset key used as this platform's default (consumed by tools_config
    # for default-toolset resolution).
    default_toolset: str
# Ordered so that TUI menus are deterministic.
# NOTE(review): downstream modules re-shape this registry (skills_config
# builds a {key: label} view that drops "api_server"; tools_config builds a
# {key: {"label", "default_toolset"}} view) — keep keys and ordering stable.
PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
    ("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")),
    ("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")),
    ("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")),
    ("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")),
    ("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")),
    ("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")),
    ("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")),
    ("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")),
    ("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")),
    ("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")),
    ("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")),
    ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")),
    ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")),
    ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
    ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
    ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
    ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
])
def platform_label(key: str, default: str = "") -> str:
    """Return the display label for a platform key, or *default*."""
    try:
        return PLATFORMS[key].label
    except KeyError:
        return default

View file

@ -197,24 +197,12 @@ def print_header(title: str):
print(color(f"{title}", Colors.CYAN, Colors.BOLD)) print(color(f"{title}", Colors.CYAN, Colors.BOLD))
def print_info(text: str): from hermes_cli.cli_output import ( # noqa: E402
"""Print info text.""" print_error,
print(color(f" {text}", Colors.DIM)) print_info,
print_success,
print_warning,
def print_success(text: str): )
"""Print success message."""
print(color(f"{text}", Colors.GREEN))
def print_warning(text: str):
"""Print warning message."""
print(color(f"{text}", Colors.YELLOW))
def print_error(text: str):
"""Print error message."""
print(color(f"{text}", Colors.RED))
def is_interactive_stdin() -> bool: def is_interactive_stdin() -> bool:
@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int: def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu using curses to avoid simple_term_menu rendering bugs.""" """Single-select menu using curses. Delegates to curses_radiolist."""
try: from hermes_cli.curses_ui import curses_radiolist
import curses return curses_radiolist(question, choices, selected=default, cancel_returns=-1)
result_holder = [default]
def _curses_menu(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
cursor = default
scroll_offset = 0
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Rows available for list items: rows 2..(max_y-2) inclusive.
visible = max(1, max_y - 3)
# Scroll the viewport so the cursor is always visible.
if cursor < scroll_offset:
scroll_offset = cursor
elif cursor >= scroll_offset + visible:
scroll_offset = cursor - visible + 1
scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
try:
stdscr.addnstr(
0,
0,
question,
max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
)
except curses.error:
pass
for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
y = row + 2
if y >= max_y - 1:
break
arrow = "" if i == cursor else " "
line = f" {arrow} {choices[i]}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line, max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord("k")):
cursor = (cursor - 1) % len(choices)
elif key in (curses.KEY_DOWN, ord("j")):
cursor = (cursor + 1) % len(choices)
elif key in (curses.KEY_ENTER, 10, 13):
result_holder[0] = cursor
return
elif key in (27, ord("q")):
return
curses.wrapper(_curses_menu)
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
return result_holder[0]
except Exception:
return -1

View file

@ -15,25 +15,12 @@ from typing import List, Optional, Set
from hermes_cli.config import load_config, save_config from hermes_cli.config import load_config, save_config
from hermes_cli.colors import Colors, color from hermes_cli.colors import Colors, color
from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label
PLATFORMS = { # Backward-compatible view: {key: label_string} so existing code that
"cli": "🖥️ CLI", # iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps
"telegram": "📱 Telegram", # working without changes to every call site.
"discord": "💬 Discord", PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"}
"slack": "💼 Slack",
"whatsapp": "📱 WhatsApp",
"signal": "📡 Signal",
"bluebubbles": "💬 BlueBubbles",
"email": "📧 Email",
"homeassistant": "🏠 Home Assistant",
"mattermost": "💬 Mattermost",
"matrix": "💬 Matrix",
"dingtalk": "💬 DingTalk",
"feishu": "🪽 Feishu",
"wecom": "💬 WeCom",
"weixin": "💬 Weixin",
"webhook": "🔗 Webhook",
}
# ─── Config Helpers ─────────────────────────────────────────────────────────── # ─── Config Helpers ───────────────────────────────────────────────────────────

View file

@ -141,11 +141,8 @@ def show_status(args):
display = redact_key(value) if not show_all else value display = redact_key(value) if not show_all else value
print(f" {name:<12} {check_mark(has_key)} {display}") print(f" {name:<12} {check_mark(has_key)} {display}")
anthropic_value = ( from hermes_cli.auth import get_anthropic_key
get_env_value("ANTHROPIC_TOKEN") anthropic_value = get_anthropic_key()
or get_env_value("ANTHROPIC_API_KEY")
or ""
)
anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value
print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}") print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}")

View file

@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve()
# ─── UI Helpers (shared with setup.py) ──────────────────────────────────────── # ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
def _print_info(text: str): from hermes_cli.cli_output import ( # noqa: E402 — late import block
print(color(f" {text}", Colors.DIM)) print_error as _print_error,
print_info as _print_info,
def _print_success(text: str): print_success as _print_success,
print(color(f"{text}", Colors.GREEN)) print_warning as _print_warning,
prompt as _prompt,
def _print_warning(text: str): )
print(color(f"{text}", Colors.YELLOW))
def _print_error(text: str):
print(color(f"{text}", Colors.RED))
def _prompt(question: str, default: str = None, password: bool = False) -> str:
if default:
display = f"{question} [{default}]: "
else:
display = f"{question}: "
try:
if password:
import getpass
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
return value.strip() or default or ""
except (KeyboardInterrupt, EOFError):
print()
return default or ""
# ─── Toolset Registry ───────────────────────────────────────────────────────── # ─── Toolset Registry ─────────────────────────────────────────────────────────
@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set:
except Exception: except Exception:
return set() return set()
# Platform display config # Platform display config — derived from the canonical registry so every
# module shares the same data. Kept as dict-of-dicts for backward
# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns.
from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY
PLATFORMS = { PLATFORMS = {
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"}, k: {"label": info.label, "default_toolset": info.default_toolset}
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"}, for k, info in _PLATFORMS_REGISTRY.items()
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
} }
@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool:
# ─── Menu Helpers ───────────────────────────────────────────────────────────── # ─── Menu Helpers ─────────────────────────────────────────────────────────────
def _prompt_choice(question: str, choices: list, default: int = 0) -> int: def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu """Single-select menu (arrow keys). Delegates to curses_radiolist."""
rendering bugs in tmux, iTerm, and other non-standard terminals.""" from hermes_cli.curses_ui import curses_radiolist
return curses_radiolist(question, choices, selected=default, cancel_returns=default)
# Curses-based single-select — works in tmux, iTerm, and standard terminals
try:
import curses
result_holder = [default]
def _curses_menu(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
cursor = default
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
try:
stdscr.addnstr(0, 0, question, max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
except curses.error:
pass
for i, c in enumerate(choices):
y = i + 2
if y >= max_y - 1:
break
arrow = "" if i == cursor else " "
line = f" {arrow} {c}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line, max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord('k')):
cursor = (cursor - 1) % len(choices)
elif key in (curses.KEY_DOWN, ord('j')):
cursor = (cursor + 1) % len(choices)
elif key in (curses.KEY_ENTER, 10, 13):
result_holder[0] = cursor
return
elif key in (27, ord('q')):
return
curses.wrapper(_curses_menu)
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
return result_holder[0]
except Exception:
pass
# Fallback: numbered input (Windows without curses, etc.)
print(color(question, Colors.YELLOW))
for i, c in enumerate(choices):
marker = "" if i == default else ""
style = Colors.GREEN if i == default else ""
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
while True:
try:
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
if not val:
return default
idx = int(val) - 1
if 0 <= idx < len(choices):
return idx
except (ValueError, KeyboardInterrupt, EOFError):
print()
return default
# ─── Token Estimation ──────────────────────────────────────────────────────── # ─── Token Estimation ────────────────────────────────────────────────────────

View file

@ -189,6 +189,33 @@ def is_wsl() -> bool:
return _wsl_detected return _wsl_detected
# ─── Well-Known Paths ─────────────────────────────────────────────────────────
def get_config_path() -> Path:
    """Path to ``config.yaml`` inside HERMES_HOME.

    Centralizes the ``get_hermes_home() / "config.yaml"`` expression that
    was previously repeated across modules (skill_utils.py,
    hermes_logging.py, hermes_time.py, etc.).
    """
    return get_hermes_home().joinpath("config.yaml")
def get_skills_dir() -> Path:
    """Path to the ``skills`` directory inside HERMES_HOME."""
    return get_hermes_home().joinpath("skills")
def get_logs_dir() -> Path:
    """Path to the ``logs`` directory inside HERMES_HOME."""
    return get_hermes_home().joinpath("logs")
def get_env_path() -> Path:
    """Path to the ``.env`` file inside HERMES_HOME."""
    return get_hermes_home().joinpath(".env")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models" OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"

View file

@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from hermes_constants import get_hermes_home from hermes_constants import get_config_path, get_hermes_home
# Sentinel to track whether setup_logging() has already run. The function # Sentinel to track whether setup_logging() has already run. The function
# is idempotent — calling it twice is safe but the second call is a no-op # is idempotent — calling it twice is safe but the second call is a no-op
@ -246,7 +246,7 @@ def _read_logging_config():
""" """
try: try:
import yaml import yaml
config_path = get_hermes_home() / "config.yaml" config_path = get_config_path()
if config_path.exists(): if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {} cfg = yaml.safe_load(f) or {}

View file

@ -16,7 +16,7 @@ crashes due to a bad timezone string.
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
from hermes_constants import get_hermes_home from hermes_constants import get_config_path
from typing import Optional from typing import Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str:
# 2. config.yaml ``timezone`` key # 2. config.yaml ``timezone`` key
try: try:
import yaml import yaml
hermes_home = get_hermes_home() config_path = get_config_path()
config_path = hermes_home / "config.yaml"
if config_path.exists(): if config_path.exists():
with open(config_path) as f: with open(config_path) as f:
cfg = yaml.safe_load(f) or {} cfg = yaml.safe_load(f) or {}

View file

@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None):
config = PlatformConfig(enabled=True, token="e2e-test-token") config = PlatformConfig(enabled=True, token="e2e-test-token")
if platform == Platform.DISCORD: if platform == Platform.DISCORD:
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()): from gateway.platforms.helpers import ThreadParticipationTracker
with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
adapter = DiscordAdapter(config) adapter = DiscordAdapter(config)
platform_key = Platform.DISCORD platform_key = Platform.DISCORD
elif platform == Platform.SLACK: elif platform == Platform.SLACK:

View file

@ -119,28 +119,29 @@ class TestDeduplication:
def test_first_message_not_duplicate(self): def test_first_message_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter = DingTalkAdapter(PlatformConfig(enabled=True))
assert adapter._is_duplicate("msg-1") is False assert adapter._dedup.is_duplicate("msg-1") is False
def test_second_same_message_is_duplicate(self): def test_second_same_message_is_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._is_duplicate("msg-1") adapter._dedup.is_duplicate("msg-1")
assert adapter._is_duplicate("msg-1") is True assert adapter._dedup.is_duplicate("msg-1") is True
def test_different_messages_not_duplicate(self): def test_different_messages_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._is_duplicate("msg-1") adapter._dedup.is_duplicate("msg-1")
assert adapter._is_duplicate("msg-2") is False assert adapter._dedup.is_duplicate("msg-2") is False
def test_cache_cleanup_on_overflow(self): def test_cache_cleanup_on_overflow(self):
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter = DingTalkAdapter(PlatformConfig(enabled=True))
max_size = adapter._dedup._max_size
# Fill beyond max # Fill beyond max
for i in range(DEDUP_MAX_SIZE + 10): for i in range(max_size + 10):
adapter._is_duplicate(f"msg-{i}") adapter._dedup.is_duplicate(f"msg-{i}")
# Cache should have been pruned # Cache should have been pruned
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10 assert len(adapter._dedup._seen) <= max_size + 10
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -253,13 +254,13 @@ class TestConnect:
from gateway.platforms.dingtalk import DingTalkAdapter from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True)) adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._session_webhooks["a"] = "http://x" adapter._session_webhooks["a"] = "http://x"
adapter._seen_messages["b"] = 1.0 adapter._dedup._seen["b"] = 1.0
adapter._http_client = AsyncMock() adapter._http_client = AsyncMock()
adapter._stream_task = None adapter._stream_task = None
await adapter.disconnect() await adapter.disconnect()
assert len(adapter._session_webhooks) == 0 assert len(adapter._session_webhooks) == 0
assert len(adapter._seen_messages) == 0 assert len(adapter._dedup._seen) == 0
assert adapter._http_client is None assert adapter._http_client is None

View file

@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
assert ok is False assert ok is False
assert released == [("discord-bot-token", "test-token")] assert released == [("discord-bot-token", "test-token")]
assert adapter._token_lock_identity is None assert adapter._platform_lock_identity is None

View file

@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
# Simulate bot having previously participated in thread 456 # Simulate bot having previously participated in thread 456
adapter._bot_participated_threads.add("456") adapter._threads.mark("456")
thread = FakeThread(channel_id=456, name="existing thread") thread = FakeThread(channel_id=456, name="existing thread")
message = make_message(channel=thread, content="follow-up without mention") message = make_message(channel=thread, content="follow-up without mention")
@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
await adapter._handle_message(message) await adapter._handle_message(message)
assert "555" in adapter._bot_participated_threads assert "555" in adapter._threads
@pytest.mark.asyncio @pytest.mark.asyncio
@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
await adapter._handle_message(message) await adapter._handle_message(message)
assert "777" in adapter._bot_participated_threads assert "777" in adapter._threads

View file

@ -1,6 +1,6 @@
"""Tests for Discord thread participation persistence. """Tests for Discord thread participation persistence.
Verifies that _bot_participated_threads survives adapter restarts by Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by
being persisted to ~/.hermes/discord_threads.json. being persisted to ~/.hermes/discord_threads.json.
""" """
@ -25,13 +25,13 @@ class TestDiscordThreadPersistence:
def test_starts_empty_when_no_state_file(self, tmp_path): def test_starts_empty_when_no_state_file(self, tmp_path):
adapter = self._make_adapter(tmp_path) adapter = self._make_adapter(tmp_path)
assert adapter._bot_participated_threads == set() assert "$nonexistent" not in adapter._threads
def test_track_thread_persists_to_disk(self, tmp_path): def test_track_thread_persists_to_disk(self, tmp_path):
adapter = self._make_adapter(tmp_path) adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._track_thread("111") adapter._threads.mark("111")
adapter._track_thread("222") adapter._threads.mark("222")
state_file = tmp_path / "discord_threads.json" state_file = tmp_path / "discord_threads.json"
assert state_file.exists() assert state_file.exists()
@ -42,42 +42,43 @@ class TestDiscordThreadPersistence:
"""Threads tracked by one adapter instance are visible to the next.""" """Threads tracked by one adapter instance are visible to the next."""
adapter1 = self._make_adapter(tmp_path) adapter1 = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter1._track_thread("aaa") adapter1._threads.mark("aaa")
adapter1._track_thread("bbb") adapter1._threads.mark("bbb")
adapter2 = self._make_adapter(tmp_path) adapter2 = self._make_adapter(tmp_path)
assert "aaa" in adapter2._bot_participated_threads assert "aaa" in adapter2._threads
assert "bbb" in adapter2._bot_participated_threads assert "bbb" in adapter2._threads
def test_duplicate_track_does_not_double_save(self, tmp_path): def test_duplicate_track_does_not_double_save(self, tmp_path):
adapter = self._make_adapter(tmp_path) adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._track_thread("111") adapter._threads.mark("111")
adapter._track_thread("111") # no-op adapter._threads.mark("111") # no-op
saved = json.loads((tmp_path / "discord_threads.json").read_text()) saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert saved.count("111") == 1 assert saved.count("111") == 1
def test_caps_at_max_tracked_threads(self, tmp_path): def test_caps_at_max_tracked_threads(self, tmp_path):
adapter = self._make_adapter(tmp_path) adapter = self._make_adapter(tmp_path)
adapter._MAX_TRACKED_THREADS = 5 adapter._threads._max_tracked = 5
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
for i in range(10): for i in range(10):
adapter._track_thread(str(i)) adapter._threads.mark(str(i))
assert len(adapter._bot_participated_threads) == 5 saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert len(saved) == 5
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
state_file = tmp_path / "discord_threads.json" state_file = tmp_path / "discord_threads.json"
state_file.write_text("not valid json{{{") state_file.write_text("not valid json{{{")
adapter = self._make_adapter(tmp_path) adapter = self._make_adapter(tmp_path)
assert adapter._bot_participated_threads == set() assert "$nonexistent" not in adapter._threads
def test_missing_hermes_home_does_not_crash(self, tmp_path): def test_missing_hermes_home_does_not_crash(self, tmp_path):
"""Load/save tolerate missing directories.""" """Load/save tolerate missing directories."""
fake_home = tmp_path / "nonexistent" / "deep" fake_home = tmp_path / "nonexistent" / "deep"
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}): with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
from gateway.platforms.discord import DiscordAdapter from gateway.platforms.helpers import ThreadParticipationTracker
# _load should return empty set, not crash # ThreadParticipationTracker should return empty set, not crash
threads = DiscordAdapter._load_participated_threads() tracker = ThreadParticipationTracker("discord")
assert threads == set() assert "$test" not in tracker

View file

@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch):
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter() adapter = _make_adapter()
adapter._bot_participated_threads.add("$thread1") adapter._threads.mark("$thread1")
event = _make_event("hello without mention", thread_id="$thread1") event = _make_event("hello without mention", thread_id="$thread1")
@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch):
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
adapter = _make_adapter() adapter = _make_adapter()
adapter._bot_participated_threads.add("$thread_root") adapter._threads.mark("$thread_root")
event = _make_event("reply in thread", thread_id="$thread_root") event = _make_event("reply in thread", thread_id="$thread_root")
await adapter._on_room_message(event) await adapter._on_room_message(event)
@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auto_thread_tracks_participation(monkeypatch): async def test_auto_thread_tracks_participation(monkeypatch):
"""Auto-created threads are tracked in _bot_participated_threads.""" """Auto-created threads are tracked in _threads."""
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false") monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false")
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False) monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
adapter = _make_adapter() adapter = _make_adapter()
event = _make_event("hello", event_id="$msg1") event = _make_event("hello", event_id="$msg1")
with patch.object(adapter, "_save_participated_threads"): with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event) await adapter._on_room_message(event)
assert "$msg1" in adapter._bot_participated_threads assert "$msg1" in adapter._threads
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch):
class TestThreadPersistence: class TestThreadPersistence:
def test_empty_state_file(self, tmp_path, monkeypatch): def test_empty_state_file(self, tmp_path, monkeypatch):
"""No state file → empty set.""" """No state file → empty set."""
from gateway.platforms.matrix import MatrixAdapter from gateway.platforms.helpers import ThreadParticipationTracker
monkeypatch.setattr( monkeypatch.setattr(
MatrixAdapter, "_thread_state_path", ThreadParticipationTracker, "_state_path",
staticmethod(lambda: tmp_path / "matrix_threads.json"), lambda self: tmp_path / "matrix_threads.json",
) )
adapter = _make_adapter() adapter = _make_adapter()
loaded = adapter._load_participated_threads() assert "$nonexistent" not in adapter._threads
assert loaded == set()
def test_track_thread_persists(self, tmp_path, monkeypatch): def test_track_thread_persists(self, tmp_path, monkeypatch):
"""_track_thread writes to disk.""" """mark() writes to disk."""
from gateway.platforms.matrix import MatrixAdapter from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json" state_path = tmp_path / "matrix_threads.json"
monkeypatch.setattr( monkeypatch.setattr(
MatrixAdapter, "_thread_state_path", ThreadParticipationTracker, "_state_path",
staticmethod(lambda: state_path), lambda self: state_path,
) )
adapter = _make_adapter() adapter = _make_adapter()
adapter._track_thread("$thread_abc") adapter._threads.mark("$thread_abc")
data = json.loads(state_path.read_text()) data = json.loads(state_path.read_text())
assert "$thread_abc" in data assert "$thread_abc" in data
def test_threads_survive_reload(self, tmp_path, monkeypatch): def test_threads_survive_reload(self, tmp_path, monkeypatch):
"""Persisted threads are loaded by a new adapter instance.""" """Persisted threads are loaded by a new adapter instance."""
from gateway.platforms.matrix import MatrixAdapter from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json" state_path = tmp_path / "matrix_threads.json"
state_path.write_text(json.dumps(["$t1", "$t2"])) state_path.write_text(json.dumps(["$t1", "$t2"]))
monkeypatch.setattr( monkeypatch.setattr(
MatrixAdapter, "_thread_state_path", ThreadParticipationTracker, "_state_path",
staticmethod(lambda: state_path), lambda self: state_path,
) )
adapter = _make_adapter() adapter = _make_adapter()
assert "$t1" in adapter._bot_participated_threads assert "$t1" in adapter._threads
assert "$t2" in adapter._bot_participated_threads assert "$t2" in adapter._threads
def test_cap_max_tracked_threads(self, tmp_path, monkeypatch): def test_cap_max_tracked_threads(self, tmp_path, monkeypatch):
"""Thread set is trimmed to _MAX_TRACKED_THREADS.""" """Thread set is trimmed to max_tracked."""
from gateway.platforms.matrix import MatrixAdapter from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json" state_path = tmp_path / "matrix_threads.json"
monkeypatch.setattr( monkeypatch.setattr(
MatrixAdapter, "_thread_state_path", ThreadParticipationTracker, "_state_path",
staticmethod(lambda: state_path), lambda self: state_path,
) )
adapter = _make_adapter() adapter = _make_adapter()
adapter._MAX_TRACKED_THREADS = 5 adapter._threads._max_tracked = 5
for i in range(10): for i in range(10):
adapter._bot_participated_threads.add(f"$t{i}") adapter._threads.mark(f"$t{i}")
adapter._save_participated_threads()
data = json.loads(state_path.read_text()) data = json.loads(state_path.read_text())
assert len(data) == 5 assert len(data) == 5
@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch):
_set_dm(adapter) _set_dm(adapter)
event = _make_event("@hermes:example.org help me", event_id="$dm1") event = _make_event("@hermes:example.org help me", event_id="$dm1")
with patch.object(adapter, "_save_participated_threads"): with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event) await adapter._on_room_message(event)
adapter.handle_message.assert_awaited_once() adapter.handle_message.assert_awaited_once()
@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
adapter = _make_adapter() adapter = _make_adapter()
_set_dm(adapter) _set_dm(adapter)
adapter._bot_participated_threads.add("$existing_thread") adapter._threads.mark("$existing_thread")
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread") event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
await adapter._on_room_message(event) await adapter._on_room_message(event)
@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dm_mention_thread_tracks_participation(monkeypatch): async def test_dm_mention_thread_tracks_participation(monkeypatch):
"""DM mention-thread tracks the thread in _bot_participated_threads.""" """DM mention-thread tracks the thread in _threads."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true") monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch):
_set_dm(adapter) _set_dm(adapter)
event = _make_event("@hermes:example.org help", event_id="$dm1") event = _make_event("@hermes:example.org help", event_id="$dm1")
with patch.object(adapter, "_save_participated_threads"): with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event) await adapter._on_room_message(event)
assert "$dm1" in adapter._bot_participated_threads assert "$dm1" in adapter._threads
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -614,25 +614,27 @@ class TestMattermostDedup:
assert self.adapter.handle_message.call_count == 2 assert self.adapter.handle_message.call_count == 2
def test_prune_seen_clears_expired(self): def test_prune_seen_clears_expired(self):
"""_prune_seen should remove entries older than _SEEN_TTL.""" """Dedup cache should remove entries older than TTL on overflow."""
now = time.time() now = time.time()
dedup = self.adapter._dedup
# Fill with enough expired entries to trigger pruning # Fill with enough expired entries to trigger pruning
for i in range(self.adapter._SEEN_MAX + 10): for i in range(dedup._max_size + 10):
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL)
# Add a fresh one # Add a fresh one
self.adapter._seen_posts["fresh"] = now dedup._seen["fresh"] = now
self.adapter._prune_seen() # Trigger pruning by calling is_duplicate with a new entry (over max_size)
dedup.is_duplicate("trigger_prune")
# Old entries should be pruned, fresh one kept # Old entries should be pruned, fresh one kept
assert "fresh" in self.adapter._seen_posts assert "fresh" in dedup._seen
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX assert len(dedup._seen) < dedup._max_size + 10
def test_seen_cache_tracks_post_ids(self): def test_seen_cache_tracks_post_ids(self):
"""Posts are tracked in _seen_posts dict.""" """Posts are tracked in the dedup cache."""
self.adapter._seen_posts["test_post"] = time.time() self.adapter._dedup._seen["test_post"] = time.time()
assert "test_post" in self.adapter._seen_posts assert "test_post" in self.adapter._dedup._seen
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -114,16 +114,16 @@ class TestSignalAdapterInit:
class TestSignalHelpers: class TestSignalHelpers:
def test_redact_phone_long(self): def test_redact_phone_long(self):
from gateway.platforms.signal import _redact_phone from gateway.platforms.helpers import redact_phone
assert _redact_phone("+15551234567") == "+155****4567" assert redact_phone("+155****4567") == "+155****4567"
def test_redact_phone_short(self): def test_redact_phone_short(self):
from gateway.platforms.signal import _redact_phone from gateway.platforms.helpers import redact_phone
assert _redact_phone("+12345") == "+1****45" assert redact_phone("+12345") == "+1****45"
def test_redact_phone_empty(self): def test_redact_phone_empty(self):
from gateway.platforms.signal import _redact_phone from gateway.platforms.helpers import redact_phone
assert _redact_phone("") == "<none>" assert redact_phone("") == "<none>"
def test_parse_comma_list(self): def test_parse_comma_list(self):
from gateway.platforms.signal import _parse_comma_list from gateway.platforms.signal import _parse_comma_list

View file

@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch):
async def _noop(): async def _noop():
return [] return []
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop) monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
# Mock HTTPXRequest so the builder chain doesn't fail
monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock())
@pytest.mark.asyncio @pytest.mark.asyncio
@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch):
ok = await adapter.connect() ok = await adapter.connect()
assert ok is False assert ok is False
assert adapter.fatal_error_code == "telegram_token_lock" assert adapter.fatal_error_code == "telegram-bot-token_lock"
assert adapter.has_fatal_error is True assert adapter.has_fatal_error is True
assert "already using this Telegram bot token" in adapter.fatal_error_message assert "already in use" in adapter.fatal_error_message
@pytest.mark.asyncio @pytest.mark.asyncio
@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch):
) )
builder = MagicMock() builder = MagicMock()
builder.token.return_value = builder builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
builder.build.return_value = app builder.build.return_value = app
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
) )
builder = MagicMock() builder = MagicMock()
builder.token.return_value = builder builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
builder.build.return_value = app builder.build.return_value = app
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder))) monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m
builder = MagicMock() builder = MagicMock()
builder.token.return_value = builder builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
app = SimpleNamespace( app = SimpleNamespace(
bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()), bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()),
updater=SimpleNamespace(), updater=SimpleNamespace(),
@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch):
) )
builder = MagicMock() builder = MagicMock()
builder.token.return_value = builder builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
builder.build.return_value = app builder.build.return_value = app
monkeypatch.setattr( monkeypatch.setattr(
"gateway.platforms.telegram.Application", "gateway.platforms.telegram.Application",

View file

@ -348,7 +348,7 @@ word word
result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md") result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md")
assert result["success"] is False assert result["success"] is False
assert "boundary" in result["error"].lower() assert "escapes" in result["error"].lower()
assert outside_file.read_text() == "old text here" assert outside_file.read_text() == "old text here"
@ -412,7 +412,7 @@ class TestWriteFile:
result = _write_file("my-skill", "references/escape/owned.md", "malicious") result = _write_file("my-skill", "references/escape/owned.md", "malicious")
assert result["success"] is False assert result["success"] is False
assert "boundary" in result["error"].lower() assert "escapes" in result["error"].lower()
assert not (outside_dir / "owned.md").exists() assert not (outside_dir / "owned.md").exists()
@ -449,7 +449,7 @@ class TestRemoveFile:
result = _remove_file("my-skill", "references/escape/keep.txt") result = _remove_file("my-skill", "references/escape/keep.txt")
assert result["success"] is False assert result["success"] is False
assert "boundary" in result["error"].lower() assert "escapes" in result["error"].lower()
assert outside_file.exists() assert outside_file.exists()

View file

@ -80,20 +80,18 @@ def register_credential_file(
# Resolve symlinks and normalise ``..`` before the containment check so # Resolve symlinks and normalise ``..`` before the containment check so
# that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME. # that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME.
try: from tools.path_security import validate_within_dir
resolved = host_path.resolve()
hermes_home_resolved = hermes_home.resolve() containment_error = validate_within_dir(host_path, hermes_home)
resolved.relative_to(hermes_home_resolved) # raises ValueError if outside if containment_error:
except ValueError:
logger.warning( logger.warning(
"credential_files: rejected path traversal %r " "credential_files: rejected path traversal %r (%s)",
"(resolves to %s, outside HERMES_HOME %s)",
relative_path, relative_path,
resolved, containment_error,
hermes_home_resolved,
) )
return False return False
resolved = host_path.resolve()
if not resolved.is_file(): if not resolved.is_file():
logger.debug("credential_files: skipping %s (not found)", resolved) logger.debug("credential_files: skipping %s (not found)", resolved)
return False return False
@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]:
cfg = read_raw_config() cfg = read_raw_config()
cred_files = cfg.get("terminal", {}).get("credential_files") cred_files = cfg.get("terminal", {}).get("credential_files")
if isinstance(cred_files, list): if isinstance(cred_files, list):
hermes_home_resolved = hermes_home.resolve() from tools.path_security import validate_within_dir
for item in cred_files: for item in cred_files:
if isinstance(item, str) and item.strip(): if isinstance(item, str) and item.strip():
rel = item.strip() rel = item.strip()
@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]:
"credential_files: rejected absolute config path %r", rel, "credential_files: rejected absolute config path %r", rel,
) )
continue continue
host_path = (hermes_home / rel).resolve() host_path = hermes_home / rel
try: containment_error = validate_within_dir(host_path, hermes_home)
host_path.relative_to(hermes_home_resolved) if containment_error:
except ValueError:
logger.warning( logger.warning(
"credential_files: rejected config path traversal %r " "credential_files: rejected config path traversal %r (%s)",
"(resolves to %s, outside HERMES_HOME %s)", rel, containment_error,
rel, host_path, hermes_home_resolved,
) )
continue continue
if host_path.is_file(): resolved_path = host_path.resolve()
if resolved_path.is_file():
container_path = f"/root/.hermes/{rel}" container_path = f"/root/.hermes/{rel}"
result.append({ result.append({
"host_path": str(host_path), "host_path": str(resolved_path),
"container_path": container_path, "container_path": container_path,
}) })
except Exception as e: except Exception as e:

View file

@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]:
) )
# Validate containment after resolution # Validate containment after resolution
from tools.path_security import validate_within_dir
scripts_dir = get_hermes_home() / "scripts" scripts_dir = get_hermes_home() / "scripts"
scripts_dir.mkdir(parents=True, exist_ok=True) scripts_dir.mkdir(parents=True, exist_ok=True)
resolved = (scripts_dir / raw).resolve() containment_error = validate_within_dir(scripts_dir / raw, scripts_dir)
try: if containment_error:
resolved.relative_to(scripts_dir.resolve())
except ValueError:
return ( return (
f"Script path escapes the scripts directory via traversal: {raw!r}" f"Script path escapes the scripts directory via traversal: {raw!r}"
) )

43
tools/path_security.py Normal file
View file

@ -0,0 +1,43 @@
"""Shared path validation helpers for tool implementations.
Extracts the ``resolve() + relative_to()`` and ``..`` traversal check
patterns previously duplicated across skill_manager_tool, skills_tool,
skills_hub, cronjob_tools, and credential_files.
"""
import logging
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
def validate_within_dir(path: Path, root: Path) -> Optional[str]:
"""Ensure *path* resolves to a location within *root*.
Returns an error message string if validation fails, or ``None`` if the
path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize
``..`` components.
Usage::
error = validate_within_dir(user_path, allowed_root)
if error:
return json.dumps({"error": error})
"""
try:
resolved = path.resolve()
root_resolved = root.resolve()
resolved.relative_to(root_resolved)
except (ValueError, OSError) as exc:
return f"Path escapes allowed directory: {exc}"
return None
def has_traversal_component(path_str: str) -> bool:
"""Return True if *path_str* contains ``..`` traversal components.
Quick check for obvious traversal attempts before doing full resolution.
"""
parts = Path(path_str).parts
return ".." in parts

View file

@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]:
Validate a file path for write_file/remove_file. Validate a file path for write_file/remove_file.
Must be under an allowed subdirectory and not escape the skill dir. Must be under an allowed subdirectory and not escape the skill dir.
""" """
from tools.path_security import has_traversal_component
if not file_path: if not file_path:
return "file_path is required." return "file_path is required."
normalized = Path(file_path) normalized = Path(file_path)
# Prevent path traversal # Prevent path traversal
if ".." in normalized.parts: if has_traversal_component(file_path):
return "Path traversal ('..') is not allowed." return "Path traversal ('..') is not allowed."
# Must be under an allowed subdirectory # Must be under an allowed subdirectory
@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]:
def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]: def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]:
"""Resolve a supporting-file path and ensure it stays within the skill directory.""" """Resolve a supporting-file path and ensure it stays within the skill directory."""
from tools.path_security import validate_within_dir
target = skill_dir / file_path target = skill_dir / file_path
try: error = validate_within_dir(target, skill_dir)
resolved = target.resolve(strict=False) if error:
skill_dir_resolved = skill_dir.resolve() return None, error
resolved.relative_to(skill_dir_resolved)
except ValueError:
return None, "Path escapes skill directory boundary."
except OSError as e:
return None, f"Invalid file path '{file_path}': {e}"
return target, None return target, None

View file

@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]:
return None return None
def _estimate_tokens(content: str) -> int: # Token estimation — use the shared implementation from model_metadata.
""" from agent.model_metadata import estimate_tokens_rough as _estimate_tokens
Rough token estimate (4 chars per token average).
Args:
content: Text content
Returns:
Estimated token count
"""
return len(content) // 4
def _parse_tags(tags_value) -> List[str]: def _parse_tags(tags_value) -> List[str]:
@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
# If a specific file path is requested, read that instead # If a specific file path is requested, read that instead
if file_path and skill_dir: if file_path and skill_dir:
from tools.path_security import validate_within_dir, has_traversal_component
# Security: Prevent path traversal attacks # Security: Prevent path traversal attacks
normalized_path = Path(file_path) if has_traversal_component(file_path):
if ".." in normalized_path.parts:
return json.dumps( return json.dumps(
{ {
"success": False, "success": False,
@ -962,27 +954,16 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
target_file = skill_dir / file_path target_file = skill_dir / file_path
# Security: Verify resolved path is still within skill directory # Security: Verify resolved path is still within skill directory
try: traversal_error = validate_within_dir(target_file, skill_dir)
resolved = target_file.resolve() if traversal_error:
skill_dir_resolved = skill_dir.resolve()
if not resolved.is_relative_to(skill_dir_resolved):
return json.dumps( return json.dumps(
{ {
"success": False, "success": False,
"error": "Path escapes skill directory boundary.", "error": traversal_error,
"hint": "Use a relative path within the skill directory", "hint": "Use a relative path within the skill directory",
}, },
ensure_ascii=False, ensure_ascii=False,
) )
except (OSError, ValueError):
return json.dumps(
{
"success": False,
"error": f"Invalid file path: '{file_path}'",
"hint": "Use a valid relative path within the skill directory",
},
ensure_ascii=False,
)
if not target_file.exists(): if not target_file.exists():
# List available files in the skill directory, organized by type # List available files in the skill directory, organized by type
available_files = { available_files = {

View file

@ -1,13 +1,16 @@
"""Shared utility functions for hermes-agent.""" """Shared utility functions for hermes-agent."""
import json import json
import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any, List, Optional, Union
import yaml import yaml
logger = logging.getLogger(__name__)
TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"}) TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})
@ -124,3 +127,88 @@ def atomic_yaml_write(
except OSError: except OSError:
pass pass
raise raise
# ─── JSON Helpers ─────────────────────────────────────────────────────────────
def safe_json_loads(text: str, default: Any = None) -> Any:
    """Parse *text* as JSON, falling back to *default* on any parse failure.

    Centralises the ``try: json.loads(x) except (JSONDecodeError, TypeError)``
    pattern duplicated across display.py, anthropic_adapter.py,
    auxiliary_client.py, and others.
    """
    try:
        parsed = json.loads(text)
    except (json.JSONDecodeError, TypeError, ValueError):
        # Non-string input raises TypeError; malformed JSON raises
        # JSONDecodeError (a ValueError subclass). All map to *default*.
        return default
    return parsed
def read_json_file(path: Path, default: Any = None) -> Any:
    """Read and parse a JSON file, returning *default* on any error.

    Replaces the repeated ``try: json.loads(path.read_text()) except ...``
    pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py,
    and skill_utils.py.

    Args:
        path: File to read (``str`` or ``Path`` accepted — coerced to Path).
        default: Value returned when the file is missing, unreadable, or
            contains invalid JSON.
    """
    try:
        return json.loads(Path(path).read_text(encoding="utf-8"))
    # On Python 3, IOError is an alias of OSError, so OSError alone covers
    # I/O failures; ValueError covers json.JSONDecodeError as well, but the
    # decode error is kept explicit for readability.
    except (json.JSONDecodeError, OSError, ValueError) as exc:
        logger.debug("Failed to read %s: %s", path, exc)
        return default
def read_jsonl(path: Path) -> List[dict]:
    """Read a JSONL file (one JSON object per line).

    Blank and whitespace-only lines are skipped. Any non-blank line that is
    not valid JSON lets ``json.JSONDecodeError`` propagate to the caller.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in map(str.strip, handle) if line]
def append_jsonl(path: Path, entry: dict) -> None:
    """Append a single JSON object as a new line to a JSONL file.

    Parent directories are created as needed. The entry is serialised with
    ``ensure_ascii=False`` so non-ASCII text is stored verbatim.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(entry, ensure_ascii=False)
    with open(target, "a", encoding="utf-8") as handle:
        handle.write(serialized + "\n")
# ─── Environment Variable Helpers ─────────────────────────────────────────────
def env_str(key: str, default: str = "") -> str:
    """Read an environment variable, stripped of whitespace.

    Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times
    across runtime_provider.py, anthropic_adapter.py, models.py, etc.
    Note that *default* is stripped too, matching the inlined pattern.
    """
    value = os.getenv(key, default)
    return value.strip()
def env_lower(key: str, default: str = "") -> str:
    """Read an environment variable, stripped and lowercased.

    *default* goes through the same strip/lower normalisation as a set value.
    """
    value = os.getenv(key, default)
    return value.strip().lower()
def env_int(key: str, default: int = 0) -> int:
    """Read an environment variable as an integer, with fallback.

    Unset, empty/whitespace-only, and non-numeric values all yield *default*.
    """
    candidate = os.getenv(key, "").strip()
    try:
        return int(candidate) if candidate else default
    except (ValueError, TypeError):
        return default
def env_bool(key: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean flag.

    Delegates to ``is_truthy_value``; presumably the TRUTHY_STRINGS spellings
    ("1", "true", "yes", "on") count as true — confirm against its definition.
    An unset variable yields the empty string, which falls back to *default*.
    """
    raw = os.getenv(key, "")
    return is_truthy_value(raw, default=default)