Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor

This commit is contained in:
Brooklyn Nicholson 2026-04-11 17:15:41 -05:00
commit ec553fdb49
93 changed files with 3531 additions and 1330 deletions

View file

@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency.
Used by AIAgent._execute_tool_calls for CLI feedback.
"""
import json
import logging
import os
import sys
@ -14,6 +13,8 @@ from dataclasses import dataclass, field
from difflib import unified_diff
from pathlib import Path
from utils import safe_json_loads
# ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m"
_RESET = "\033[0m"
@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool:
"""Conservatively detect whether a tool result represents success."""
if not result:
return False
try:
data = json.loads(result)
except (json.JSONDecodeError, TypeError):
data = safe_json_loads(result)
if data is None:
return False
if not isinstance(data, dict):
return False
@ -423,10 +423,7 @@ def extract_edit_diff(
) -> str | None:
"""Extract a unified diff from a file-edit tool result."""
if tool_name == "patch" and result:
try:
data = json.loads(result)
except (json.JSONDecodeError, TypeError):
data = None
data = safe_json_loads(result)
if isinstance(data, dict):
diff = data.get("diff")
if isinstance(diff, str) and diff.strip():
@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
return False, ""
if tool_name == "terminal":
try:
data = json.loads(result)
data = safe_json_loads(result)
if isinstance(data, dict):
exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]"
except (json.JSONDecodeError, TypeError, AttributeError):
logger.debug("Could not parse terminal result as JSON for exit code check")
return False, ""
# Memory-specific: distinguish "full" from real errors
if tool_name == "memory":
try:
data = json.loads(result)
data = safe_json_loads(result)
if isinstance(data, dict):
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
return True, " [full]"
except (json.JSONDecodeError, TypeError, AttributeError):
logger.debug("Could not parse memory result as JSON for capacity check")
# Generic heuristic for non-terminal tools
lower = result[:500].lower()

View file

@ -179,6 +179,12 @@ _MAX_COMPLETION_KEYS = (
# Local server hostnames / address patterns
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
# Docker / Podman / Lima DNS names that resolve to the host machine
_CONTAINER_LOCAL_SUFFIXES = (
".docker.internal",
".containers.internal",
".lima.internal",
)
def _normalize_base_url(base_url: str) -> str:
@ -254,6 +260,9 @@ def is_local_endpoint(base_url: str) -> bool:
return False
if host in _LOCAL_HOSTS:
return True
# Docker / Podman / Lima internal DNS names (e.g. host.docker.internal)
if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES):
return True
# RFC-1918 private ranges and link-local
import ipaddress
try:

View file

@ -12,7 +12,7 @@ import threading
from collections import OrderedDict
from pathlib import Path
from hermes_constants import get_hermes_home
from hermes_constants import get_hermes_home, get_skills_dir
from typing import Optional
from agent.skill_utils import (
@ -548,8 +548,7 @@ def build_skills_system_prompt(
are read-only they appear in the index but new skills are always created
in the local dir. Local skills take precedence when names collide.
"""
hermes_home = get_hermes_home()
skills_dir = hermes_home / "skills"
skills_dir = get_skills_dir()
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
if not skills_dir.exists() and not external_dirs:

View file

@ -12,7 +12,7 @@ import sys
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
from hermes_constants import get_hermes_home
from hermes_constants import get_config_path, get_skills_dir
logger = logging.getLogger(__name__)
@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
Reads the config file directly (no CLI config imports) to stay
lightweight.
"""
config_path = get_hermes_home() / "config.yaml"
config_path = get_config_path()
if not config_path.exists():
return set()
try:
@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]:
path. Only directories that actually exist are returned. Duplicates and
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
"""
config_path = get_hermes_home() / "config.yaml"
config_path = get_config_path()
if not config_path.exists():
return []
try:
@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]:
if not isinstance(raw_dirs, list):
return []
local_skills = (get_hermes_home() / "skills").resolve()
local_skills = get_skills_dir().resolve()
seen: Set[Path] = set()
result: List[Path] = []
@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
The local dir is always first (and always included even if it doesn't exist
yet — callers handle that). External dirs follow in config order.
"""
dirs = [get_hermes_home() / "skills"]
dirs = [get_skills_dir()]
dirs.extend(get_external_skills_dirs())
return dirs
@ -384,7 +384,7 @@ def resolve_skill_config_values(
current values (or the declared default if the key isn't set).
Path values are expanded via ``os.path.expanduser``.
"""
config_path = get_hermes_home() / "config.yaml"
config_path = get_config_path()
config: Dict[str, Any] = {}
if config_path.exists():
try:

9
cli.py
View file

@ -2748,6 +2748,15 @@ class HermesCLI:
self.api_key = api_key
self.base_url = base_url
# When a custom_provider entry carries an explicit `model` field,
# use it as the effective model name. Without this, running
# `hermes chat --model <provider-name>` sends the provider name
# (e.g. "my-provider") as the model string to the API instead of
# the configured model (e.g. "qwen3.6-plus"), causing 400 errors.
runtime_model = runtime.get("model")
if runtime_model and isinstance(runtime_model, str):
self.model = runtime_model
# Normalize model for the resolved provider (e.g. swap non-Codex
# models when provider is openai-codex). Fixes #651.
model_changed = self._normalize_model_for_provider(resolved_provider)

View file

@ -722,6 +722,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
provider_sort=pr.get("sort"),
disabled_toolsets=["cronjob", "messaging", "clarify"],
quiet_mode=True,
skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd
skip_memory=True, # Cron system prompts would corrupt user representations
platform="cron",
session_id=_cron_session_id,

View file

@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8642
MAX_STORED_RESPONSES = 100
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
def check_api_server_requirements() -> bool:
@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter):
"""
import queue as _q
sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"}
sse_headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
}
# CORS middleware can't inject headers into StreamResponse after
# prepare() flushes them, so resolve CORS headers up front.
origin = request.headers.get("Origin", "")
@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter):
await response.prepare(request)
try:
last_activity = time.monotonic()
# Role chunk
role_chunk = {
"id": completion_id, "object": "chat.completion.chunk",
@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter):
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
}
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
last_activity = time.monotonic()
# Helper — route a queue item to the correct SSE event.
async def _emit(item):
@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter):
"choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}],
}
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
return time.monotonic()
# Stream content chunks as they arrive from the agent
loop = asyncio.get_event_loop()
@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter):
delta = stream_q.get_nowait()
if delta is None:
break
await _emit(delta)
last_activity = await _emit(delta)
except _q.Empty:
break
break
if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS:
await response.write(b": keepalive\n\n")
last_activity = time.monotonic()
continue
if delta is None: # End of stream sentinel
break
await _emit(delta)
last_activity = await _emit(delta)
# Get usage from completed agent
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

View file

@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC):
result = handler(self)
if asyncio.iscoroutine(result):
await result
def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
"""Acquire a scoped lock for this adapter. Returns True on success."""
from gateway.status import acquire_scoped_lock
self._platform_lock_scope = scope
self._platform_lock_identity = identity
acquired, existing = acquire_scoped_lock(
scope, identity, metadata={'platform': self.platform.value}
)
if acquired:
return True
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
message = (
f'{resource_desc} already in use'
+ (f' (PID {owner_pid})' if owner_pid else '')
+ '. Stop the other gateway first.'
)
logger.error('[%s] %s', self.name, message)
self._set_fatal_error(f'{scope}_lock', message, retryable=False)
return False
def _release_platform_lock(self) -> None:
"""Release the scoped lock acquired by _acquire_platform_lock."""
identity = getattr(self, '_platform_lock_identity', None)
if not identity:
return
from gateway.status import release_scoped_lock
release_scoped_lock(self._platform_lock_scope, identity)
self._platform_lock_identity = None
@property
def name(self) -> str:
"""Human-readable name for this adapter."""

View file

@ -30,6 +30,7 @@ from gateway.platforms.base import (
cache_audio_from_bytes,
cache_document_from_bytes,
)
from gateway.platforms.helpers import strip_markdown
logger = logging.getLogger(__name__)
@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str:
return value.rstrip("/")
def _strip_markdown(text: str) -> str:
"""Strip common markdown formatting for iMessage plain-text delivery."""
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
text = re.sub(r"`(.+?)`", r"\1", text)
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
text = re.sub(r"\n{3,}", "\n\n", text)
return text.strip()
# ---------------------------------------------------------------------------
@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
text = _strip_markdown(content or "")
text = strip_markdown(content or "")
if not text:
return SendResult(success=False, error="BlueBubbles send requires text")
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
return info
def format_message(self, content: str) -> str:
return _strip_markdown(content)
return strip_markdown(content)
# ------------------------------------------------------------------
# Inbound attachment downloading (from #4588)

View file

@ -42,6 +42,7 @@ except ImportError:
httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -52,8 +53,6 @@ from gateway.platforms.base import (
logger = logging.getLogger(__name__)
MAX_MESSAGE_LENGTH = 20000
DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
_SESSION_WEBHOOKS_MAX = 500
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter):
self._stream_task: Optional[asyncio.Task] = None
self._http_client: Optional["httpx.AsyncClient"] = None
# Message deduplication: msg_id -> timestamp
self._seen_messages: Dict[str, float] = {}
# Message deduplication
self._dedup = MessageDeduplicator(max_size=1000)
# Map chat_id -> session_webhook for reply routing
self._session_webhooks: Dict[str, str] = {}
@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter):
self._stream_client = None
self._session_webhooks.clear()
self._seen_messages.clear()
self._dedup.clear()
logger.info("[%s] Disconnected", self.name)
# -- Inbound message processing -----------------------------------------
@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter):
async def _on_message(self, message: "ChatbotMessage") -> None:
"""Process an incoming DingTalk chatbot message."""
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
if self._is_duplicate(msg_id):
if self._dedup.is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
return
@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter):
content = " ".join(parts).strip()
return content
# -- Deduplication ------------------------------------------------------
def _is_duplicate(self, msg_id: str) -> bool:
"""Check and record a message ID. Returns True if already seen."""
now = time.time()
if len(self._seen_messages) > DEDUP_MAX_SIZE:
cutoff = now - DEDUP_WINDOW_SECONDS
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
if msg_id in self._seen_messages:
return True
self._seen_messages[msg_id] = now
return False
# -- Outbound messaging -------------------------------------------------
async def send(

View file

@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
import re
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter):
# Track threads where the bot has participated so follow-up messages
# in those threads don't require @mention. Persisted to disk so the
# set survives gateway restarts.
self._bot_participated_threads: set = self._load_participated_threads()
self._threads = ThreadParticipationTracker("discord")
# Persistent typing indicator loops per channel (DMs don't reliably
# show the standard typing gateway event for bots)
self._typing_tasks: Dict[str, asyncio.Task] = {}
self._bot_task: Optional[asyncio.Task] = None
# Cap to prevent unbounded growth (Discord threads get archived).
self._MAX_TRACKED_THREADS = 500
# Dedup cache: message_id → timestamp. Prevents duplicate bot
# responses when Discord RESUME replays events after reconnects.
self._seen_messages: Dict[str, float] = {}
self._SEEN_TTL = 300 # 5 minutes
self._SEEN_MAX = 2000 # prune threshold
# Dedup cache: prevents duplicate bot responses when Discord
# RESUME replays events after reconnects.
self._dedup = MessageDeduplicator()
# Reply threading mode: "off" (no replies), "first" (reply on first
# chunk only, default), "all" (reply-reference on every chunk).
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter):
return False
try:
# Acquire scoped lock to prevent duplicate bot token usage
from gateway.status import acquire_scoped_lock
self._token_lock_identity = self.config.token
acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
if not acquired:
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
logger.error('[%s] %s', self.name, message)
self._set_fatal_error('discord_token_lock', message, retryable=False)
if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
return False
# Parse allowed user entries (may contain usernames or IDs)
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
if allowed_env:
@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter):
@self._client.event
async def on_message(message: DiscordMessage):
# Dedup: Discord RESUME replays events after reconnects (#4777)
msg_id = str(message.id)
now = time.time()
if msg_id in adapter_self._seen_messages:
if adapter_self._dedup.is_duplicate(str(message.id)):
return
adapter_self._seen_messages[msg_id] = now
if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX:
cutoff = now - adapter_self._SEEN_TTL
adapter_self._seen_messages = {
k: v for k, v in adapter_self._seen_messages.items()
if v > cutoff
}
# Always ignore our own messages
if message.author == self._client.user:
@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter):
except asyncio.TimeoutError:
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
try:
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('discord-bot-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
self._release_platform_lock()
return False
except Exception as e: # pragma: no cover - defensive logging
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
try:
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('discord-bot-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
self._release_platform_lock()
return False
async def disconnect(self) -> None:
@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter):
self._client = None
self._ready_event.clear()
# Release the token lock
try:
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('discord-bot-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
self._release_platform_lock()
logger.info("[%s] Disconnected", self.name)
@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so follow-ups don't require @mention
if thread_id:
self._track_thread(thread_id)
self._threads.mark(thread_id)
# If a message was provided, kick off a new Hermes session in the thread
starter = (message or "").strip()
@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter):
return f"{parent_name} / {thread_name}"
return thread_name
# ------------------------------------------------------------------
# Thread participation persistence
# ------------------------------------------------------------------
@staticmethod
def _thread_state_path() -> Path:
"""Path to the persisted thread participation set."""
from hermes_cli.config import get_hermes_home
return get_hermes_home() / "discord_threads.json"
@classmethod
def _load_participated_threads(cls) -> set:
"""Load persisted thread IDs from disk."""
path = cls._thread_state_path()
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return set(data)
except Exception as e:
logger.debug("Could not load discord thread state: %s", e)
return set()
def _save_participated_threads(self) -> None:
"""Persist the current thread set to disk (best-effort)."""
path = self._thread_state_path()
try:
# Trim to most recent entries if over cap
thread_list = list(self._bot_participated_threads)
if len(thread_list) > self._MAX_TRACKED_THREADS:
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
self._bot_participated_threads = set(thread_list)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(thread_list), encoding="utf-8")
except Exception as e:
logger.debug("Could not save discord thread state: %s", e)
def _track_thread(self, thread_id: str) -> None:
"""Add a thread to the participation set and persist."""
if thread_id not in self._bot_participated_threads:
self._bot_participated_threads.add(thread_id)
self._save_participated_threads()
async def _handle_message(self, message: DiscordMessage) -> None:
"""Handle incoming Discord messages."""
# In server channels (not DMs), require the bot to be @mentioned
@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Skip the mention check if the message is in a thread where
# the bot has previously participated (auto-created or replied in).
in_bot_thread = is_thread and thread_id in self._bot_participated_threads
in_bot_thread = is_thread and thread_id in self._threads
if require_mention and not is_free_channel and not in_bot_thread:
if self._client.user not in message.mentions:
@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter):
is_thread = True
thread_id = str(thread.id)
auto_threaded_channel = thread
self._track_thread(thread_id)
self._threads.mark(thread_id)
# Determine message type
msg_type = MessageType.TEXT
@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so the bot won't require @mention for
# follow-up messages in threads it has already engaged in.
if thread_id:
self._track_thread(thread_id)
self._threads.mark(thread_id)
# Only batch plain text messages — commands, media, etc. dispatch
# immediately since they won't be split by the Discord client.

View file

@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str:
def _strip_markdown_to_plain_text(text: str) -> str:
"""Strip markdown formatting to plain text for Feishu text fallbacks.
Delegates common markdown stripping to the shared helper and adds
Feishu-specific patterns (blockquotes, strikethrough, underline tags,
horizontal rules, \\r\\n normalisation).
"""
from gateway.platforms.helpers import strip_markdown
plain = text.replace("\r\n", "\n")
plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain)
plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE)
plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE)
plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE)
plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain)
plain = re.sub(r"`([^`\n]+)`", r"\1", plain)
plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain)
plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain)
plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain)
plain = re.sub(r"<u>([\s\S]*?)</u>", r"\1", plain)
plain = re.sub(r"\n{3,}", "\n\n", plain)
return plain.strip()
plain = strip_markdown(plain)
return plain
def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]:

View file

@ -0,0 +1,261 @@
"""Shared helper classes for gateway platform adapters.
Extracts common patterns that were duplicated across 5-7 adapters:
message deduplication, text batch aggregation, markdown stripping,
and thread participation tracking.
"""
import asyncio
import json
import logging
import re
import time
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional
if TYPE_CHECKING:
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
logger = logging.getLogger(__name__)
# ─── Message Deduplication ────────────────────────────────────────────────────
class MessageDeduplicator:
    """TTL-based message deduplication cache.

    Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
    previously duplicated in discord, slack, dingtalk, wecom, weixin,
    mattermost, and feishu adapters.

    Usage::

        self._dedup = MessageDeduplicator()
        # In message handler:
        if self._dedup.is_duplicate(msg_id):
            return
    """

    def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
        # msg_id -> unix timestamp of when it was last recorded
        self._seen: Dict[str, float] = {}
        self._max_size = max_size
        self._ttl = ttl_seconds

    def is_duplicate(self, msg_id: str) -> bool:
        """Return True if *msg_id* was already seen within the TTL window.

        An unseen (or expired) ID is recorded as seen now and reported as
        not-a-duplicate. Previously an entry older than ``ttl_seconds``
        still counted as a duplicate until the size-triggered prune ran;
        the TTL is now enforced on lookup so behavior matches the contract.
        """
        if not msg_id:
            # Missing IDs can't be deduplicated meaningfully.
            return False
        now = time.time()
        seen_at = self._seen.get(msg_id)
        if seen_at is not None and now - seen_at < self._ttl:
            return True
        # Record (or refresh an expired entry), then prune opportunistically
        # once the cache grows past max_size.
        self._seen[msg_id] = now
        if len(self._seen) > self._max_size:
            cutoff = now - self._ttl
            self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
        return False

    def clear(self):
        """Clear all tracked messages."""
        self._seen.clear()
# ─── Text Batch Aggregation ──────────────────────────────────────────────────
class TextBatchAggregator:
    """Aggregates rapid-fire text events into single messages.

    Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
    previously duplicated in telegram, discord, matrix, wecom, and feishu.

    Usage::

        self._text_batcher = TextBatchAggregator(
            handler=self._message_handler,
            batch_delay=0.6,
            split_threshold=1900,
        )
        # In message dispatch:
        if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
            self._text_batcher.enqueue(event, session_key)
            return
    """

    def __init__(
        self,
        handler,
        *,
        batch_delay: float = 0.6,
        split_delay: float = 2.0,
        split_threshold: int = 4000,
    ):
        self._handler = handler
        self._batch_delay = batch_delay
        self._split_delay = split_delay
        self._split_threshold = split_threshold
        self._pending: Dict[str, "MessageEvent"] = {}
        self._pending_tasks: Dict[str, asyncio.Task] = {}

    def is_enabled(self) -> bool:
        """Return True if batching is active (delay > 0)."""
        return self._batch_delay > 0

    def enqueue(self, event: "MessageEvent", key: str) -> None:
        """Add *event* to the pending batch for *key*."""
        incoming_len = len(event.text or "")
        batched = self._pending.get(key)
        if not batched:
            self._pending[key] = event
            batched = event
        else:
            # Merge into the already-queued event, newline-separated.
            batched.text = f"{batched.text}\n{event.text}"
        # Remember how big the latest chunk was so _flush can pick a delay.
        batched._last_chunk_len = incoming_len  # type: ignore[attr-defined]
        # Restart the flush timer: cancel any in-flight one first.
        stale = self._pending_tasks.get(key)
        if stale is not None and not stale.done():
            stale.cancel()
        self._pending_tasks[key] = asyncio.create_task(self._flush(key))

    async def _flush(self, key: str) -> None:
        """Wait then dispatch the batched event for *key*."""
        # Capture our own task handle up front so we only clean up the
        # registry entry if no newer flush has replaced us meanwhile.
        my_task = self._pending_tasks.get(key)
        queued = self._pending.get(key)
        tail_len = getattr(queued, "_last_chunk_len", 0) if queued else 0
        # A near-threshold final chunk looks like a platform-split message,
        # so wait longer for its continuation.
        wait = self._batch_delay if tail_len < self._split_threshold else self._split_delay
        await asyncio.sleep(wait)
        batched = self._pending.pop(key, None)
        if batched:
            try:
                await self._handler(batched)
            except Exception:
                logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
        if self._pending_tasks.get(key) is my_task:
            self._pending_tasks.pop(key, None)

    def cancel_all(self) -> None:
        """Cancel all pending flush tasks."""
        for pending_task in self._pending_tasks.values():
            if not pending_task.done():
                pending_task.cancel()
        self._pending_tasks.clear()
        self._pending.clear()
# ─── Markdown Stripping ──────────────────────────────────────────────────────
# Pre-compiled regexes for performance
# Ordered (pattern, replacement) rules. Order matters: the double-delimiter
# bold forms must run before their single-delimiter italic counterparts so
# ** / __ are not half-stripped, and fences before inline code.
_MARKDOWN_RULES = (
    (re.compile(r"\*\*(.+?)\*\*", re.DOTALL), r"\1"),   # **bold**
    (re.compile(r"\*(.+?)\*", re.DOTALL), r"\1"),       # *italic*
    (re.compile(r"__(.+?)__", re.DOTALL), r"\1"),       # __bold__
    (re.compile(r"_(.+?)_", re.DOTALL), r"\1"),         # _italic_
    (re.compile(r"```[a-zA-Z0-9_+-]*\n?"), ""),         # code fences
    (re.compile(r"`(.+?)`"), r"\1"),                    # `inline code`
    (re.compile(r"^#{1,6}\s+", re.MULTILINE), ""),      # # headings
    (re.compile(r"\[([^\]]+)\]\([^\)]+\)"), r"\1"),     # [text](url) -> text
    (re.compile(r"\n{3,}"), "\n\n"),                    # collapse blank runs
)


def strip_markdown(text: str) -> str:
    """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).

    Replaces the identical ``_strip_markdown()`` functions previously
    duplicated in sms.py, bluebubbles.py, and feishu.py.
    """
    for pattern, replacement in _MARKDOWN_RULES:
        text = pattern.sub(replacement, text)
    return text.strip()
# ─── Thread Participation Tracking ───────────────────────────────────────────
class ThreadParticipationTracker:
    """Persistent tracking of threads the bot has participated in.

    Replaces the identical ``_load/_save_participated_threads`` +
    ``_mark_thread_participated`` pattern previously duplicated in
    discord.py and matrix.py.

    Usage::

        self._threads = ThreadParticipationTracker("discord")
        # Check membership:
        if thread_id in self._threads:
            ...
        # Mark participation:
        self._threads.mark(thread_id)
    """

    # Kept for backward compatibility; the effective cap is the
    # ``max_tracked`` constructor argument.
    _MAX_TRACKED = 500

    def __init__(self, platform_name: str, max_tracked: int = 500):
        self._platform = platform_name
        self._max_tracked = max_tracked
        self._threads: set = self._load()

    def _state_path(self) -> Path:
        """Path to the persisted thread set: ``<hermes home>/<platform>_threads.json``."""
        from hermes_constants import get_hermes_home
        return get_hermes_home() / f"{self._platform}_threads.json"

    def _load(self) -> set:
        """Load persisted thread IDs from disk (best-effort; empty set on error)."""
        try:
            path = self._state_path()
            if path.exists():
                data = json.loads(path.read_text(encoding="utf-8"))
                # Only accept the expected list payload; anything else
                # (e.g. a hand-edited dict) is ignored, matching the
                # validation the per-adapter implementations performed.
                if isinstance(data, list):
                    return set(data)
        except Exception as e:
            logging.getLogger(__name__).debug(
                "Could not load %s thread state: %s", self._platform, e
            )
        return set()

    def _save(self) -> None:
        """Persist the current thread set to disk (best-effort; never raises).

        Called from ``mark()`` inside message-handling paths, so a transient
        disk/permission error must not propagate into the adapter — the
        per-adapter implementations this replaces were explicitly
        best-effort for the same reason.
        """
        try:
            path = self._state_path()
            thread_list = list(self._threads)
            if len(thread_list) > self._max_tracked:
                # Trim to the cap (threads get archived; cache must not grow
                # unboundedly) and keep the in-memory set consistent.
                thread_list = thread_list[-self._max_tracked:]
                self._threads = set(thread_list)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(thread_list), encoding="utf-8")
        except Exception as e:
            logging.getLogger(__name__).debug(
                "Could not save %s thread state: %s", self._platform, e
            )

    def mark(self, thread_id: str) -> None:
        """Mark *thread_id* as participated and persist."""
        if thread_id not in self._threads:
            self._threads.add(thread_id)
            self._save()

    def __contains__(self, thread_id: str) -> bool:
        return thread_id in self._threads

    def clear(self) -> None:
        """Forget all tracked threads (in memory only; does not rewrite disk)."""
        self._threads.clear()
# ─── Phone Number Redaction ──────────────────────────────────────────────────
def redact_phone(phone: str) -> str:
    """Redact a phone number for logging, preserving country code and last 4.

    Replaces the identical ``_redact_phone()`` functions in signal.py,
    sms.py, and bluebubbles.py.
    """
    if not phone:
        return "<none>"
    # Long numbers keep 4 chars at each end; mid-length keep 2; very short
    # numbers (<= 4 chars) are fully masked.
    length = len(phone)
    if length > 8:
        return f"{phone[:4]}****{phone[-4:]}"
    if length > 4:
        return f"{phone[:2]}****{phone[-2:]}"
    return "****"

View file

@ -92,6 +92,7 @@ from gateway.platforms.base import (
ProcessingOutcome,
SendResult,
)
from gateway.platforms.helpers import ThreadParticipationTracker
logger = logging.getLogger(__name__)
@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter):
self._pending_megolm: list = []
# Thread participation tracking (for require_mention bypass)
self._bot_participated_threads: set = self._load_participated_threads()
self._MAX_TRACKED_THREADS = 500
self._threads = ThreadParticipationTracker("matrix")
# Mention/thread gating — parsed once from env vars.
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter):
# Require-mention gating.
if not is_dm:
is_free_room = room_id in self._free_rooms
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
in_bot_thread = bool(thread_id and thread_id in self._threads)
if self._require_mention and not is_free_room and not in_bot_thread:
if not is_mentioned:
return None
@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter):
# DM mention-thread.
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
thread_id = event_id
self._track_thread(thread_id)
self._threads.mark(thread_id)
# Strip mention from body.
if is_mentioned:
@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter):
# Auto-thread.
if not is_dm and not thread_id and self._auto_thread:
thread_id = event_id
self._track_thread(thread_id)
self._threads.mark(thread_id)
display_name = await self._get_display_name(room_id, sender)
source = self.build_source(
@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter):
)
if thread_id:
self._track_thread(thread_id)
self._threads.mark(thread_id)
self._background_read_receipt(room_id, event_id)
@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter):
for rid in self._joined_rooms
}
# ------------------------------------------------------------------
# Thread participation tracking
# ------------------------------------------------------------------
@staticmethod
def _thread_state_path() -> Path:
    """Return the path of the JSON file persisting the thread-participation set.

    Lives under the Hermes home directory as ``matrix_threads.json``.
    """
    # Function-scope import keeps hermes_cli out of module import time.
    from hermes_cli.config import get_hermes_home
    return get_hermes_home() / "matrix_threads.json"
@classmethod
def _load_participated_threads(cls) -> set:
    """Load persisted thread IDs from disk.

    Returns an empty set when the state file is absent, unreadable, or
    does not contain a JSON list — loading is strictly best-effort.
    """
    path = cls._thread_state_path()
    try:
        if path.exists():
            data = json.loads(path.read_text(encoding="utf-8"))
            # Only a JSON list is a valid persisted shape; anything else
            # is treated as corrupt and ignored.
            if isinstance(data, list):
                return set(data)
    except Exception as e:
        # Corrupt or unreadable state is non-fatal; start fresh.
        logger.debug("Could not load matrix thread state: %s", e)
    return set()
def _save_participated_threads(self) -> None:
    """Persist the current thread set to disk (best-effort).

    Caps the stored list at ``self._MAX_TRACKED_THREADS`` entries.
    NOTE(review): ``list(set)`` ordering is arbitrary in CPython, so the
    "last N" kept after truncation is not the N most recent threads —
    confirm whether recency matters for this cap.
    """
    path = self._thread_state_path()
    try:
        thread_list = list(self._bot_participated_threads)
        if len(thread_list) > self._MAX_TRACKED_THREADS:
            thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
        # Resync the in-memory set with whatever will be written.
        self._bot_participated_threads = set(thread_list)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(thread_list), encoding="utf-8")
    except Exception as e:
        # Persistence failure must never break message handling.
        logger.debug("Could not save matrix thread state: %s", e)
def _track_thread(self, thread_id: str) -> None:
"""Add a thread to the participation set and persist."""
if thread_id not in self._bot_participated_threads:
self._bot_participated_threads.add(thread_id)
self._save_participated_threads()
# ------------------------------------------------------------------
# Mention detection helpers
# ------------------------------------------------------------------

View file

@ -18,11 +18,11 @@ import json
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter):
or os.getenv("MATTERMOST_REPLY_MODE", "off")
).lower()
# Dedup cache: post_id → timestamp (prevent reprocessing)
self._seen_posts: Dict[str, float] = {}
self._SEEN_MAX = 2000
self._SEEN_TTL = 300 # 5 minutes
# Dedup cache (prevent reprocessing)
self._dedup = MessageDeduplicator()
# ------------------------------------------------------------------
# HTTP helpers
@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter):
post_id = post.get("id", "")
# Dedup.
self._prune_seen()
if post_id in self._seen_posts:
if self._dedup.is_duplicate(post_id):
return
self._seen_posts[post_id] = time.time()
# Build message event.
channel_id = post.get("channel_id", "")
@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter):
await self.handle_message(msg_event)
def _prune_seen(self) -> None:
"""Remove expired entries from the dedup cache."""
if len(self._seen_posts) < self._SEEN_MAX:
return
now = time.time()
self._seen_posts = {
pid: ts
for pid, ts in self._seen_posts.items()
if now - ts < self._SEEN_TTL
}

View file

@ -37,6 +37,7 @@ from gateway.platforms.base import (
cache_document_from_bytes,
cache_image_from_url,
)
from gateway.platforms.helpers import redact_phone
logger = logging.getLogger(__name__)
@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
# E.164 phone number pattern for redaction
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _redact_phone(phone: str) -> str:
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
if not phone:
return "<none>"
if len(phone) <= 8:
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
return phone[:4] + "****" + phone[-4:]
def _parse_comma_list(value: str) -> List[str]:
"""Split a comma-separated string into a list, stripping whitespace."""
@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter):
self._recent_sent_timestamps: set = set()
self._max_recent_timestamps = 50
self._phone_lock_identity: Optional[str] = None
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
self.http_url, _redact_phone(self.account),
self.http_url, redact_phone(self.account),
"enabled" if self.group_allow_from else "disabled")
# ------------------------------------------------------------------
@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter):
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone
try:
from gateway.status import acquire_scoped_lock
self._phone_lock_identity = self.account
acquired, existing = acquire_scoped_lock(
"signal-phone",
self._phone_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this Signal account"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second Signal listener."
)
logger.error("Signal: %s", message)
self._set_fatal_error("signal_phone_lock", message, retryable=False)
if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'):
return False
except Exception as e:
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
await self.client.aclose()
self.client = None
if self._phone_lock_identity:
try:
from gateway.status import release_scoped_lock
release_scoped_lock("signal-phone", self._phone_lock_identity)
except Exception as e:
logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
self._phone_lock_identity = None
self._release_platform_lock()
logger.info("Signal: disconnected")
@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter):
)
logger.debug("Signal: message from %s in %s: %s",
_redact_phone(sender), chat_id[:20], (text or "")[:50])
redact_phone(sender), chat_id[:20], (text or "")[:50])
await self.handle_message(event)

View file

@ -33,6 +33,7 @@ from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter):
self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient
self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id
self._channel_team: Dict[str, str] = {} # channel_id → team_id
# Dedup cache: event_ts → timestamp. Prevents duplicate bot
# responses when Socket Mode reconnects redeliver events.
self._seen_messages: Dict[str, float] = {}
self._SEEN_TTL = 300 # 5 minutes
self._SEEN_MAX = 2000 # prune threshold
# Dedup cache: prevents duplicate bot responses when Socket Mode
# reconnects redeliver events.
self._dedup = MessageDeduplicator()
# Track pending approval message_ts → resolved flag to prevent
# double-clicks on approval buttons.
self._approval_resolved: Dict[str, bool] = {}
@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter):
logger.warning("[Slack] Failed to read %s: %s", tokens_file, e)
try:
# Acquire scoped lock to prevent duplicate app token usage
from gateway.status import acquire_scoped_lock
self._token_lock_identity = app_token
acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
if not acquired:
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
logger.error('[%s] %s', self.name, message)
self._set_fatal_error('slack_token_lock', message, retryable=False)
if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'):
return False
# First token is the primary — used for AsyncApp / Socket Mode
@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter):
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
self._running = False
# Release the token lock (use stored identity, not re-read env)
try:
from gateway.status import release_scoped_lock
if getattr(self, '_token_lock_identity', None):
release_scoped_lock('slack-app-token', self._token_lock_identity)
self._token_lock_identity = None
except Exception:
pass
self._release_platform_lock()
logger.info("[Slack] Disconnected")
@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter):
"""Handle an incoming Slack message event."""
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
event_ts = event.get("ts", "")
if event_ts:
now = time.time()
if event_ts in self._seen_messages:
return
self._seen_messages[event_ts] = now
if len(self._seen_messages) > self._SEEN_MAX:
cutoff = now - self._SEEN_TTL
self._seen_messages = {
k: v for k, v in self._seen_messages.items()
if v > cutoff
}
if event_ts and self._dedup.is_duplicate(event_ts):
return
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
# "none" — ignore all bot messages (default, backward-compatible)

View file

@ -10,6 +10,9 @@ Shares credentials with the optional telephony skill — same env vars:
Gateway-specific env vars:
- SMS_WEBHOOK_PORT (default 8080)
- SMS_WEBHOOK_HOST (default 0.0.0.0)
- SMS_WEBHOOK_URL (public URL for Twilio signature validation; required)
- SMS_INSECURE_NO_SIGNATURE (true to disable signature validation; dev only)
- SMS_ALLOWED_USERS (comma-separated E.164 phone numbers)
- SMS_ALLOW_ALL_USERS (true/false)
- SMS_HOME_CHANNEL (phone number for cron delivery)
@ -17,9 +20,10 @@ Gateway-specific env vars:
import asyncio
import base64
import hashlib
import hmac
import logging
import os
import re
import urllib.parse
from typing import Any, Dict, Optional
@ -30,24 +34,14 @@ from gateway.platforms.base import (
MessageType,
SendResult,
)
from gateway.platforms.helpers import redact_phone, strip_markdown
logger = logging.getLogger(__name__)
TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
DEFAULT_WEBHOOK_PORT = 8080
# E.164 phone number pattern for redaction
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
def _redact_phone(phone: str) -> str:
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
if not phone:
return "<none>"
if len(phone) <= 8:
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
return phone[:5] + "***" + phone[-4:]
DEFAULT_WEBHOOK_HOST = "0.0.0.0"
def check_sms_requirements() -> bool:
@ -77,6 +71,8 @@ class SmsAdapter(BasePlatformAdapter):
self._webhook_port: int = int(
os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
)
self._webhook_host: str = os.getenv("SMS_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST)
self._webhook_url: str = os.getenv("SMS_WEBHOOK_URL", "").strip()
self._runner = None
self._http_session: Optional["aiohttp.ClientSession"] = None
@ -98,13 +94,33 @@ class SmsAdapter(BasePlatformAdapter):
logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies")
return False
insecure_no_sig = os.getenv("SMS_INSECURE_NO_SIGNATURE", "").lower() == "true"
if not self._webhook_url and not insecure_no_sig:
logger.error(
"[sms] Refusing to start: SMS_WEBHOOK_URL is required for Twilio "
"signature validation. Set it to the public URL configured in your "
"Twilio console (e.g. https://example.com/webhooks/twilio). "
"For local development without validation, set "
"SMS_INSECURE_NO_SIGNATURE=true (NOT recommended for production).",
)
return False
if insecure_no_sig and not self._webhook_url:
logger.warning(
"[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation "
"is DISABLED. Any client that can reach port %d can inject messages. "
"Do NOT use this in production.",
self._webhook_port,
)
app = web.Application()
app.router.add_post("/webhooks/twilio", self._handle_webhook)
app.router.add_get("/health", lambda _: web.Response(text="ok"))
self._runner = web.AppRunner(app)
await self._runner.setup()
site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port)
site = web.TCPSite(self._runner, self._webhook_host, self._webhook_port)
await site.start()
self._http_session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=30),
@ -112,9 +128,10 @@ class SmsAdapter(BasePlatformAdapter):
self._running = True
logger.info(
"[sms] Twilio webhook server listening on port %d, from: %s",
"[sms] Twilio webhook server listening on %s:%d, from: %s",
self._webhook_host,
self._webhook_port,
_redact_phone(self._from_number),
redact_phone(self._from_number),
)
return True
@ -163,7 +180,7 @@ class SmsAdapter(BasePlatformAdapter):
error_msg = body.get("message", str(body))
logger.error(
"[sms] send failed to %s: %s %s",
_redact_phone(chat_id),
redact_phone(chat_id),
resp.status,
error_msg,
)
@ -174,7 +191,7 @@ class SmsAdapter(BasePlatformAdapter):
msg_sid = body.get("sid", "")
last_result = SendResult(success=True, message_id=msg_sid)
except Exception as e:
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e)
return SendResult(success=False, error=str(e))
finally:
# Close session only if we created a fallback (no persistent session)
@ -192,16 +209,75 @@ class SmsAdapter(BasePlatformAdapter):
def format_message(self, content: str) -> str:
"""Strip markdown — SMS renders it as literal characters."""
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
content = re.sub(r"```[a-z]*\n?", "", content)
content = re.sub(r"`(.+?)`", r"\1", content)
content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
content = re.sub(r"\n{3,}", "\n\n", content)
return content.strip()
return strip_markdown(content)
# ------------------------------------------------------------------
# Twilio signature validation
# ------------------------------------------------------------------
def _validate_twilio_signature(
self, url: str, post_params: dict, signature: str,
) -> bool:
"""Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64).
Tries both with and without the default port for the URL scheme,
since Twilio may sign with either variant.
Algorithm: https://www.twilio.com/docs/usage/security#validating-requests
"""
if self._check_signature(url, post_params, signature):
return True
variant = self._port_variant_url(url)
if variant and self._check_signature(variant, post_params, signature):
return True
return False
def _check_signature(
self, url: str, post_params: dict, signature: str,
) -> bool:
"""Compute and compare a single Twilio signature."""
data_to_sign = url
for key in sorted(post_params.keys()):
data_to_sign += key + post_params[key]
mac = hmac.new(
self._auth_token.encode("utf-8"),
data_to_sign.encode("utf-8"),
hashlib.sha1,
)
computed = base64.b64encode(mac.digest()).decode("utf-8")
return hmac.compare_digest(computed, signature)
@staticmethod
def _port_variant_url(url: str) -> str | None:
"""Return the URL with the default port toggled, or None.
Only toggles default ports (443 for https, 80 for http).
Non-standard ports are never modified.
"""
parsed = urllib.parse.urlparse(url)
default_ports = {"https": 443, "http": 80}
default_port = default_ports.get(parsed.scheme)
if default_port is None:
return None
if parsed.port == default_port:
# Has explicit default port → strip it
return urllib.parse.urlunparse(
(parsed.scheme, parsed.hostname, parsed.path,
parsed.params, parsed.query, parsed.fragment)
)
elif parsed.port is None:
# No port → add default
netloc = f"{parsed.hostname}:{default_port}"
return urllib.parse.urlunparse(
(parsed.scheme, netloc, parsed.path,
parsed.params, parsed.query, parsed.fragment)
)
# Non-standard port — no variant
return None
# ------------------------------------------------------------------
# Twilio webhook handler
@ -213,7 +289,7 @@ class SmsAdapter(BasePlatformAdapter):
try:
raw = await request.read()
# Twilio sends form-encoded data, not JSON
form = urllib.parse.parse_qs(raw.decode("utf-8"))
form = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
except Exception as e:
logger.error("[sms] webhook parse error: %s", e)
return web.Response(
@ -222,6 +298,27 @@ class SmsAdapter(BasePlatformAdapter):
status=400,
)
# Validate Twilio request signature when SMS_WEBHOOK_URL is configured
if self._webhook_url:
twilio_sig = request.headers.get("X-Twilio-Signature", "")
if not twilio_sig:
logger.warning("[sms] Rejected: missing X-Twilio-Signature header")
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
status=403,
)
flat_params = {k: v[0] for k, v in form.items() if v}
if not self._validate_twilio_signature(
self._webhook_url, flat_params, twilio_sig
):
logger.warning("[sms] Rejected: invalid Twilio signature")
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
status=403,
)
# Extract fields (parse_qs returns lists)
from_number = (form.get("From", [""]))[0].strip()
to_number = (form.get("To", [""]))[0].strip()
@ -236,7 +333,7 @@ class SmsAdapter(BasePlatformAdapter):
# Ignore messages from our own number (echo prevention)
if from_number == self._from_number:
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number))
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
@ -244,8 +341,8 @@ class SmsAdapter(BasePlatformAdapter):
logger.info(
"[sms] inbound from %s -> %s: %s",
_redact_phone(from_number),
_redact_phone(to_number),
redact_phone(from_number),
redact_phone(to_number),
text[:80],
)

View file

@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter):
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
self._pending_text_batches: Dict[str, MessageEvent] = {}
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
self._token_lock_identity: Optional[str] = None
self._polling_error_task: Optional[asyncio.Task] = None
self._polling_conflict_count: int = 0
self._polling_network_error_count: int = 0
@ -300,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter):
# Exhausted retries — fatal
message = (
"Another Telegram bot poller is already using this token. "
"Another process is already polling this Telegram bot token "
"(possibly OpenClaw or another Hermes instance). "
"Hermes stopped Telegram polling after %d retries. "
"Make sure only one gateway instance is running for this bot token."
"Only one poller can run per token — stop the other process "
"and restart with 'hermes start'."
% MAX_CONFLICT_RETRIES
)
logger.error("[%s] %s Original error: %s", self.name, message, error)
@ -497,23 +498,7 @@ class TelegramAdapter(BasePlatformAdapter):
return False
try:
from gateway.status import acquire_scoped_lock
self._token_lock_identity = self.config.token
acquired, existing = acquire_scoped_lock(
"telegram-bot-token",
self._token_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this Telegram bot token"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second Telegram poller."
)
logger.error("[%s] %s", self.name, message)
self._set_fatal_error("telegram_token_lock", message, retryable=False)
if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'):
return False
# Build the application
@ -737,12 +722,7 @@ class TelegramAdapter(BasePlatformAdapter):
return True
except Exception as e:
if self._token_lock_identity:
try:
from gateway.status import release_scoped_lock
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
except Exception:
pass
self._release_platform_lock()
message = f"Telegram startup failed: {e}"
self._set_fatal_error("telegram_connect_error", message, retryable=True)
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
@ -768,12 +748,7 @@ class TelegramAdapter(BasePlatformAdapter):
await self._app.shutdown()
except Exception as e:
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
if self._token_lock_identity:
try:
from gateway.status import release_scoped_lock
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
except Exception as e:
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
self._release_platform_lock()
for task in self._pending_photo_batch_tasks.values():
if task and not task.done():
@ -784,7 +759,6 @@ class TelegramAdapter(BasePlatformAdapter):
self._mark_disconnected()
self._app = None
self._bot = None
self._token_lock_identity = None
logger.info("[%s] Disconnected from Telegram", self.name)
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:

View file

@ -59,6 +59,7 @@ except ImportError:
httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0
HEARTBEAT_INTERVAL_SECONDS = 30.0
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000
IMAGE_MAX_BYTES = 10 * 1024 * 1024
@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter):
self._listen_task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._pending_responses: Dict[str, asyncio.Future] = {}
self._seen_messages: Dict[str, float] = {}
self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
self._reply_req_ids: Dict[str, str] = {}
# Text batching: merge rapid successive messages (Telegram-style).
@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter):
await self._http_client.aclose()
self._http_client = None
self._seen_messages.clear()
self._dedup.clear()
logger.info("[%s] Disconnected", self.name)
async def _cleanup_ws(self) -> None:
@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter):
return
msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex)
if self._is_duplicate(msg_id):
if self._dedup.is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
return
self._remember_reply_req_id(msg_id, self._payload_req_id(payload))
@ -636,6 +636,13 @@ class WeComAdapter(BasePlatformAdapter):
if voice_text:
text_parts.append(voice_text)
# Extract appmsg title (filename) for WeCom AI Bot attachments
if msgtype == "appmsg":
appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {}
title = str(appmsg.get("title") or "").strip()
if title:
text_parts.append(title)
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
quote_type = str(quote.get("msgtype") or "").lower()
if quote_type == "text":
@ -668,6 +675,13 @@ class WeComAdapter(BasePlatformAdapter):
refs.append(("image", body["image"]))
if msgtype == "file" and isinstance(body.get("file"), dict):
refs.append(("file", body["file"]))
# Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel)
if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict):
appmsg = body["appmsg"]
if isinstance(appmsg.get("file"), dict):
refs.append(("file", appmsg["file"]))
elif isinstance(appmsg.get("image"), dict):
refs.append(("image", appmsg["image"]))
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
quote_type = str(quote.get("msgtype") or "").lower()
@ -825,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter):
wildcard = self._groups.get("*")
return wildcard if isinstance(wildcard, dict) else {}
def _is_duplicate(self, msg_id: str) -> bool:
now = time.time()
if len(self._seen_messages) > DEDUP_MAX_SIZE:
cutoff = now - DEDUP_WINDOW_SECONDS
self._seen_messages = {
key: ts for key, ts in self._seen_messages.items() if ts > cutoff
}
if self._reply_req_ids:
self._reply_req_ids = {
key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages
}
if msg_id in self._seen_messages:
return True
self._seen_messages[msg_id] = now
return False
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
normalized_message_id = str(message_id or "").strip()
normalized_req_id = str(req_id or "").strip()

View file

@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate
CRYPTO_AVAILABLE = False
from gateway.config import Platform, PlatformConfig
from gateway.platforms.helpers import MessageDeduplicator
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -63,6 +64,7 @@ from gateway.platforms.base import (
cache_image_from_bytes,
)
from hermes_constants import get_hermes_home
from utils import atomic_json_write
ILINK_BASE_URL = "https://ilinkai.weixin.qq.com"
WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
@ -206,7 +208,7 @@ def save_weixin_account(
"saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
path = _account_file(hermes_home, account_id)
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
atomic_json_write(path, payload)
try:
path.chmod(0o600)
except OSError:
@ -269,7 +271,7 @@ class ContextTokenStore:
if key.startswith(prefix)
}
try:
self._path(account_id).write_text(json.dumps(payload), encoding="utf-8")
atomic_json_write(self._path(account_id), payload)
except Exception as exc:
logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc)
@ -868,7 +870,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str:
def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None:
path = _sync_buf_path(hermes_home, account_id)
path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8")
atomic_json_write(path, {"get_updates_buf": sync_buf})
async def qr_login(
@ -1007,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter):
self._typing_cache = TypingTicketCache()
self._session: Optional[aiohttp.ClientSession] = None
self._poll_task: Optional[asyncio.Task] = None
self._seen_messages: Dict[str, float] = {}
self._token_lock_identity: Optional[str] = None
self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip()
self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip()
@ -1016,6 +1017,16 @@ class WeixinAdapter(BasePlatformAdapter):
self._cdn_base_url = str(
extra.get("cdn_base_url") or os.getenv("WEIXIN_CDN_BASE_URL", WEIXIN_CDN_BASE_URL)
).strip().rstrip("/")
self._send_chunk_delay_seconds = float(
extra.get("send_chunk_delay_seconds") or os.getenv("WEIXIN_SEND_CHUNK_DELAY_SECONDS", "0.35")
)
self._send_chunk_retries = int(
extra.get("send_chunk_retries") or os.getenv("WEIXIN_SEND_CHUNK_RETRIES", "2")
)
self._send_chunk_retry_delay_seconds = float(
extra.get("send_chunk_retry_delay_seconds")
or os.getenv("WEIXIN_SEND_CHUNK_RETRY_DELAY_SECONDS", "1.0")
)
self._dm_policy = str(extra.get("dm_policy") or os.getenv("WEIXIN_DM_POLICY", "open")).strip().lower()
self._group_policy = str(extra.get("group_policy") or os.getenv("WEIXIN_GROUP_POLICY", "disabled")).strip().lower()
allow_from = extra.get("allow_from")
@ -1066,23 +1077,7 @@ class WeixinAdapter(BasePlatformAdapter):
return False
try:
from gateway.status import acquire_scoped_lock
self._token_lock_identity = self._token
acquired, existing = acquire_scoped_lock(
"weixin-bot-token",
self._token_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this Weixin token"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second Weixin poller."
)
logger.error("[%s] %s", self.name, message)
self._set_fatal_error("weixin_token_lock", message, retryable=False)
if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'):
return False
except Exception as exc:
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
@ -1106,12 +1101,7 @@ class WeixinAdapter(BasePlatformAdapter):
if self._session and not self._session.closed:
await self._session.close()
self._session = None
if self._token_lock_identity:
try:
from gateway.status import release_scoped_lock
release_scoped_lock("weixin-bot-token", self._token_lock_identity)
except Exception as exc:
logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True)
self._release_platform_lock()
self._mark_disconnected()
logger.info("[%s] Disconnected", self.name)
@ -1189,16 +1179,8 @@ class WeixinAdapter(BasePlatformAdapter):
return
message_id = str(message.get("message_id") or "").strip()
if message_id:
now = time.time()
self._seen_messages = {
key: value
for key, value in self._seen_messages.items()
if now - value < MESSAGE_DEDUP_TTL_SECONDS
}
if message_id in self._seen_messages:
return
self._seen_messages[message_id] = now
if message_id and self._dedup.is_duplicate(message_id):
return
chat_type, effective_chat_id = _guess_chat_type(message, self._account_id)
if chat_type == "group":
@ -1374,6 +1356,47 @@ class WeixinAdapter(BasePlatformAdapter):
content, self.MAX_MESSAGE_LENGTH, self._split_multiline_messages,
)
async def _send_text_chunk(
    self,
    *,
    chat_id: str,
    chunk: str,
    context_token: Optional[str],
    client_id: str,
) -> None:
    """Send a single text chunk with per-chunk retry and linear backoff.

    Makes up to ``self._send_chunk_retries + 1`` attempts; before
    attempt k+1 it waits ``self._send_chunk_retry_delay_seconds * k``
    seconds. Raises the last captured error once attempts are exhausted.
    """
    last_error: Optional[Exception] = None
    for attempt in range(self._send_chunk_retries + 1):
        try:
            await _send_message(
                self._session,
                base_url=self._base_url,
                token=self._token,
                to=chat_id,
                text=chunk,
                context_token=context_token,
                client_id=client_id,
            )
            return
        except Exception as exc:
            last_error = exc
            if attempt >= self._send_chunk_retries:
                break
            wait = self._send_chunk_retry_delay_seconds * (attempt + 1)
            logger.warning(
                "[%s] send chunk failed to=%s attempt=%d/%d, retrying in %.2fs: %s",
                self.name,
                _safe_id(chat_id),
                attempt + 1,
                self._send_chunk_retries + 1,
                wait,
                exc,
            )
            if wait > 0:
                await asyncio.sleep(wait)
    # Exhausted retries. A bare `assert` here would be stripped under
    # `python -O` and silently fall through without raising, so raise
    # explicitly instead.
    if last_error is None:  # pragma: no cover — loop always sets it on failure
        raise RuntimeError("send chunk failed without capturing an error")
    raise last_error
async def send(
self,
chat_id: str,
@ -1388,19 +1411,16 @@ class WeixinAdapter(BasePlatformAdapter):
try:
chunks = self._split_text(self.format_message(content))
for idx, chunk in enumerate(chunks):
if idx > 0:
await asyncio.sleep(0.3)
client_id = f"hermes-weixin-{uuid.uuid4().hex}"
await _send_message(
self._session,
base_url=self._base_url,
token=self._token,
to=chat_id,
text=chunk,
await self._send_text_chunk(
chat_id=chat_id,
chunk=chunk,
context_token=context_token,
client_id=client_id,
)
last_message_id = client_id
if idx < len(chunks) - 1 and self._send_chunk_delay_seconds > 0:
await asyncio.sleep(self._send_chunk_delay_seconds)
return SendResult(success=True, message_id=last_message_id)
except Exception as exc:
logger.error("[%s] send failed to=%s: %s", self.name, _safe_id(chat_id), exc)

View file

@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._bridge_log: Optional[Path] = None
self._poll_task: Optional[asyncio.Task] = None
self._http_session: Optional["aiohttp.ClientSession"] = None
self._session_lock_identity: Optional[str] = None
def _whatsapp_require_mention(self) -> bool:
configured = self.config.extra.get("require_mention")
@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
# Acquire scoped lock to prevent duplicate sessions
try:
from gateway.status import acquire_scoped_lock
self._session_lock_identity = str(self._session_path)
acquired, existing = acquire_scoped_lock(
"whatsapp-session",
self._session_lock_identity,
metadata={"platform": self.platform.value},
)
if not acquired:
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
message = (
"Another local Hermes gateway is already using this WhatsApp session"
+ (f" (PID {owner_pid})." if owner_pid else ".")
+ " Stop the other gateway before starting a second WhatsApp bridge."
)
logger.error("[%s] %s", self.name, message)
self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'):
return False
except Exception as e:
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
return True
except Exception as e:
if self._session_lock_identity:
try:
from gateway.status import release_scoped_lock
release_scoped_lock("whatsapp-session", self._session_lock_identity)
except Exception:
pass
self._release_platform_lock()
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
self._close_bridge_log()
return False
@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
await self._http_session.close()
self._http_session = None
if self._session_lock_identity:
try:
from gateway.status import release_scoped_lock
release_scoped_lock("whatsapp-session", self._session_lock_identity)
except Exception as e:
logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
self._release_platform_lock()
self._mark_disconnected()
self._bridge_process = None
self._close_bridge_log()
self._session_lock_identity = None
print(f"[{self.name}] Disconnected")
async def send(

View file

@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str:
return "\n".join(parts)
def _dequeue_pending_text(adapter, session_key: str) -> str | None:
"""Consume and return the text of a pending queued message.
def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
"""Consume and return the full pending event for a session.
Preserves media context for captionless photo/document events by
building a placeholder so the message isn't silently dropped.
Queued follow-ups must preserve their media metadata so they can re-enter
the normal image/STT/document preprocessing path instead of being reduced
to a placeholder string.
"""
event = adapter.get_pending_message(session_key)
if not event:
return None
text = event.text
if not text and getattr(event, "media_urls", None):
text = _build_media_placeholder(event)
return text
return adapter.get_pending_message(session_key)
def _check_unavailable_skill(command_name: str) -> str | None:
@ -1465,7 +1460,18 @@ class GatewayRunner:
logger.info("Recovered %s background process(es) from previous run", recovered)
except Exception as e:
logger.warning("Process checkpoint recovery: %s", e)
# Suspend sessions that were active when the gateway last exited.
# This prevents stuck sessions from being blindly resumed on restart,
# which can create an unrecoverable loop (#7536). Suspended sessions
# auto-reset on the next incoming message, giving the user a clean start.
try:
suspended = self.session_store.suspend_recently_active()
if suspended:
logger.info("Suspended %d in-flight session(s) from previous run", suspended)
except Exception as e:
logger.warning("Session suspension on startup failed: %s", e)
connected_count = 0
enabled_platform_count = 0
startup_nonretryable_errors: list[str] = []
@ -2221,6 +2227,13 @@ class GatewayRunner:
# are system-generated and must skip user authorization.
if getattr(event, "internal", False):
pass
elif source.user_id is None:
# Messages with no user identity (Telegram service messages,
# channel forwards, anonymous admin actions) cannot be
# authorized — drop silently instead of triggering the pairing
# flow with a None user_id.
logger.debug("Ignoring message with no user_id from %s", source.platform.value)
return None
elif not self._is_user_authorized(source):
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
# In DMs: offer pairing code. In groups: silently ignore.
@ -2370,8 +2383,11 @@ class GatewayRunner:
self._pending_messages.pop(_quick_key, None)
if _quick_key in self._running_agents:
del self._running_agents[_quick_key]
logger.info("HARD STOP for session %s — session lock released", _quick_key[:20])
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
# Mark session suspended so the next message starts fresh
# instead of resuming the stuck context (#7536).
self.session_store.suspend_session(_quick_key)
logger.info("HARD STOP for session %s — suspended, session lock released", _quick_key[:20])
return "⚡ Force-stopped. The session is suspended — your next message will start fresh."
# /reset and /new must bypass the running-agent guard so they
# actually dispatch as commands instead of being queued as user
@ -2761,6 +2777,162 @@ class GatewayRunner:
del self._running_agents[_quick_key]
self._running_agents_ts.pop(_quick_key, None)
async def _prepare_inbound_message_text(
    self,
    *,
    event: MessageEvent,
    source: SessionSource,
    history: List[Dict[str, Any]],
) -> Optional[str]:
    """Prepare inbound event text for the agent.

    Keep the normal inbound path and the queued follow-up path on the same
    preprocessing pipeline so sender attribution, image enrichment, STT,
    document notes, reply context, and @ references all behave the same.

    Args:
        event: The inbound message event (text plus optional media).
        source: Session source describing platform/chat/thread/user.
        history: Prior conversation messages, used to decide whether
            quoted reply text needs to be re-injected.

    Returns:
        The fully-enriched message text, or ``None`` when the message was
        blocked (e.g. a refused @ context injection) and the caller should
        stop processing.
    """
    history = history or []
    message_text = event.text or ""
    # Sender attribution: in shared (non-per-user) thread sessions,
    # prefix each message with the sender's name so the agent can tell
    # participants apart.  DMs are single-user by nature and skip this.
    _is_shared_thread = (
        source.chat_type != "dm"
        and source.thread_id
        and not getattr(self.config, "thread_sessions_per_user", False)
    )
    if _is_shared_thread and source.user_name:
        message_text = f"[{source.user_name}] {message_text}"
    if event.media_urls:
        # Partition attachments into images and audio; media_types may be
        # shorter than media_urls, so fall back to the message type.
        image_paths = []
        audio_paths = []
        for i, path in enumerate(event.media_urls):
            mtype = event.media_types[i] if i < len(event.media_types) else ""
            if mtype.startswith("image/") or event.message_type == MessageType.PHOTO:
                image_paths.append(path)
            if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO):
                audio_paths.append(path)
        if image_paths:
            message_text = await self._enrich_message_with_vision(
                message_text,
                image_paths,
            )
        if audio_paths:
            message_text = await self._enrich_message_with_transcription(
                message_text,
                audio_paths,
            )
            # If STT failed, tell the user directly so they know voice
            # isn't configured — don't rely on the agent to relay it.
            _stt_fail_markers = (
                "No STT provider",
                "STT is disabled",
                "can't listen",
                "VOICE_TOOLS_OPENAI_KEY",
            )
            if any(marker in message_text for marker in _stt_fail_markers):
                _stt_adapter = self.adapters.get(source.platform)
                _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
                if _stt_adapter:
                    try:
                        _stt_msg = (
                            "🎤 I received your voice message but can't transcribe it — "
                            "no speech-to-text provider is configured.\n\n"
                            "To enable voice: install faster-whisper "
                            "(`pip install faster-whisper` in the Hermes venv) "
                            "and set `stt.enabled: true` in config.yaml, "
                            "then /restart the gateway."
                        )
                        if self._has_setup_skill():
                            _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
                        await _stt_adapter.send(
                            source.chat_id,
                            _stt_msg,
                            metadata=_stt_meta,
                        )
                    except Exception:
                        # Best-effort notification; never fail the pipeline.
                        pass
    # Enrich document attachments with a context note so the agent knows
    # the file name and on-disk location.
    if event.media_urls and event.message_type == MessageType.DOCUMENT:
        import mimetypes as _mimetypes
        _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
        for i, path in enumerate(event.media_urls):
            mtype = event.media_types[i] if i < len(event.media_types) else ""
            # Fall back to extension-based detection when the MIME type
            # is missing or the generic octet-stream placeholder.
            if mtype in ("", "application/octet-stream"):
                import os as _os2
                _ext = _os2.path.splitext(path)[1].lower()
                if _ext in _TEXT_EXTENSIONS:
                    mtype = "text/plain"
                else:
                    guessed, _ = _mimetypes.guess_type(path)
                    if guessed:
                        mtype = guessed
            if not mtype.startswith(("application/", "text/")):
                continue
            import os as _os
            import re as _re
            basename = _os.path.basename(path)
            # Stored format: doc_<12hex>_<original_filename>; recover the
            # original filename when present.
            parts = basename.split("_", 2)
            display_name = parts[2] if len(parts) >= 3 else basename
            # Sanitize to prevent prompt injection via filenames.
            display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
            if mtype.startswith("text/"):
                context_note = (
                    f"[The user sent a text document: '{display_name}'. "
                    f"Its content has been included below. "
                    f"The file is also saved at: {path}]"
                )
            else:
                context_note = (
                    f"[The user sent a document: '{display_name}'. "
                    f"The file is saved at: {path}. "
                    f"Ask the user what they'd like you to do with it.]"
                )
            message_text = f"{context_note}\n\n{message_text}"
    # Inject reply context when the user replies to a message that is not
    # in the current history (previous session, cron delivery, etc.) so
    # the agent knows what's being referenced.
    if getattr(event, "reply_to_text", None) and event.reply_to_message_id:
        reply_snippet = event.reply_to_text[:500]
        # Match on a 200-char prefix to tolerate truncation differences.
        found_in_history = any(
            reply_snippet[:200] in (msg.get("content") or "")
            for msg in history
            if msg.get("role") in ("assistant", "user", "tool")
        )
        if not found_in_history:
            message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
    # Expand @ context references (@file:, @folder:, @diff, ...).  A
    # blocked expansion aborts processing (returns None) after informing
    # the user; a failed expansion is non-fatal.
    if "@" in message_text:
        try:
            from agent.context_references import preprocess_context_references_async
            from agent.model_metadata import get_model_context_length
            _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
            _msg_ctx_len = get_model_context_length(
                self._model,
                base_url=self._base_url or "",
            )
            _ctx_result = await preprocess_context_references_async(
                message_text,
                cwd=_msg_cwd,
                context_length=_msg_ctx_len,
                allowed_root=_msg_cwd,
            )
            if _ctx_result.blocked:
                _adapter = self.adapters.get(source.platform)
                if _adapter:
                    await _adapter.send(
                        source.chat_id,
                        "\n".join(_ctx_result.warnings) or "Context injection refused.",
                    )
                return None
            if _ctx_result.expanded:
                message_text = _ctx_result.message
        except Exception as exc:
            logger.debug("@ context reference expansion failed: %s", exc)
    return message_text
async def _handle_message_with_agent(self, event, source, _quick_key: str):
"""Inner handler that runs under the _running_agents sentinel guard."""
_msg_start_time = time.time()
@ -2812,7 +2984,9 @@ class GatewayRunner:
# so the agent knows this is a fresh conversation (not an intentional /reset).
if getattr(session_entry, 'was_auto_reset', False):
reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle'
if reset_reason == "daily":
if reset_reason == "suspended":
context_note = "[System note: The user's previous session was stopped and suspended. This is a fresh conversation with no prior context.]"
elif reset_reason == "daily":
context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]"
else:
context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]"
@ -2829,7 +3003,9 @@ class GatewayRunner:
)
platform_name = source.platform.value if source.platform else ""
had_activity = getattr(session_entry, 'reset_had_activity', False)
should_notify = (
# Suspended sessions always notify (they were explicitly stopped
# or crashed mid-operation) — skip the policy check.
should_notify = reset_reason == "suspended" or (
policy.notify
and had_activity
and platform_name not in policy.notify_exclude_platforms
@ -2837,7 +3013,9 @@ class GatewayRunner:
if should_notify:
adapter = self.adapters.get(source.platform)
if adapter:
if reset_reason == "daily":
if reset_reason == "suspended":
reason_text = "previous session was stopped or interrupted"
elif reset_reason == "daily":
reason_text = f"daily schedule at {policy.at_hour}:00"
else:
hours = policy.idle_minutes // 60
@ -3195,149 +3373,13 @@ class GatewayRunner:
# attachments (documents, audio, etc.) are not sent to the vision
# tool even when they appear in the same message.
# -----------------------------------------------------------------
message_text = event.text or ""
# -----------------------------------------------------------------
# Sender attribution for shared thread sessions.
#
# When multiple users share a single thread session (the default for
# threads), prefix each message with [sender name] so the agent can
# tell participants apart. Skip for DMs (single-user by nature) and
# when per-user thread isolation is explicitly enabled.
# -----------------------------------------------------------------
_is_shared_thread = (
source.chat_type != "dm"
and source.thread_id
and not getattr(self.config, "thread_sessions_per_user", False)
message_text = await self._prepare_inbound_message_text(
event=event,
source=source,
history=history,
)
if _is_shared_thread and source.user_name:
message_text = f"[{source.user_name}] {message_text}"
if event.media_urls:
image_paths = []
for i, path in enumerate(event.media_urls):
# Check media_types if available; otherwise infer from message type
mtype = event.media_types[i] if i < len(event.media_types) else ""
is_image = (
mtype.startswith("image/")
or event.message_type == MessageType.PHOTO
)
if is_image:
image_paths.append(path)
if image_paths:
message_text = await self._enrich_message_with_vision(
message_text, image_paths
)
# -----------------------------------------------------------------
# Auto-transcribe voice/audio messages sent by the user
# -----------------------------------------------------------------
if event.media_urls:
audio_paths = []
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
is_audio = (
mtype.startswith("audio/")
or event.message_type in (MessageType.VOICE, MessageType.AUDIO)
)
if is_audio:
audio_paths.append(path)
if audio_paths:
message_text = await self._enrich_message_with_transcription(
message_text, audio_paths
)
# If STT failed, send a direct message to the user so they
# know voice isn't configured — don't rely on the agent to
# relay the error clearly.
_stt_fail_markers = (
"No STT provider",
"STT is disabled",
"can't listen",
"VOICE_TOOLS_OPENAI_KEY",
)
if any(m in message_text for m in _stt_fail_markers):
_stt_adapter = self.adapters.get(source.platform)
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
if _stt_adapter:
try:
_stt_msg = (
"🎤 I received your voice message but can't transcribe it — "
"no speech-to-text provider is configured.\n\n"
"To enable voice: install faster-whisper "
"(`pip install faster-whisper` in the Hermes venv) "
"and set `stt.enabled: true` in config.yaml, "
"then /restart the gateway."
)
# Point to setup skill if it's installed
if self._has_setup_skill():
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
await _stt_adapter.send(
source.chat_id, _stt_msg,
metadata=_stt_meta,
)
except Exception:
pass
# -----------------------------------------------------------------
# Enrich document messages with context notes for the agent
# -----------------------------------------------------------------
if event.media_urls and event.message_type == MessageType.DOCUMENT:
import mimetypes as _mimetypes
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
# Fall back to extension-based detection when MIME type is unreliable.
if mtype in ("", "application/octet-stream"):
import os as _os2
_ext = _os2.path.splitext(path)[1].lower()
if _ext in _TEXT_EXTENSIONS:
mtype = "text/plain"
else:
guessed, _ = _mimetypes.guess_type(path)
if guessed:
mtype = guessed
if not mtype.startswith(("application/", "text/")):
continue
# Extract display filename by stripping the doc_{uuid12}_ prefix
import os as _os
basename = _os.path.basename(path)
# Format: doc_<12hex>_<original_filename>
parts = basename.split("_", 2)
display_name = parts[2] if len(parts) >= 3 else basename
# Sanitize to prevent prompt injection via filenames
import re as _re
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
if mtype.startswith("text/"):
context_note = (
f"[The user sent a text document: '{display_name}'. "
f"Its content has been included below. "
f"The file is also saved at: {path}]"
)
else:
context_note = (
f"[The user sent a document: '{display_name}'. "
f"The file is saved at: {path}. "
f"Ask the user what they'd like you to do with it.]"
)
message_text = f"{context_note}\n\n{message_text}"
# -----------------------------------------------------------------
# Inject reply context when user replies to a message not in history.
# Telegram (and other platforms) let users reply to specific messages,
# but if the quoted message is from a previous session, cron delivery,
# or background task, the agent has no context about what's being
# referenced. Prepend the quoted text so the agent understands. (#1594)
# -----------------------------------------------------------------
if getattr(event, 'reply_to_text', None) and event.reply_to_message_id:
reply_snippet = event.reply_to_text[:500]
found_in_history = any(
reply_snippet[:200] in (msg.get("content") or "")
for msg in history
if msg.get("role") in ("assistant", "user", "tool")
)
if not found_in_history:
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
if message_text is None:
return
try:
# Emit agent:start hook
@ -3349,30 +3391,6 @@ class GatewayRunner:
}
await self.hooks.emit("agent:start", hook_ctx)
# Expand @ context references (@file:, @folder:, @diff, etc.)
if "@" in message_text:
try:
from agent.context_references import preprocess_context_references_async
from agent.model_metadata import get_model_context_length
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
_msg_ctx_len = get_model_context_length(
self._model, base_url=self._base_url or "")
_ctx_result = await preprocess_context_references_async(
message_text, cwd=_msg_cwd,
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
if _ctx_result.blocked:
_adapter = self.adapters.get(source.platform)
if _adapter:
await _adapter.send(
source.chat_id,
"\n".join(_ctx_result.warnings) or "Context injection refused.",
)
return
if _ctx_result.expanded:
message_text = _ctx_result.message
except Exception as exc:
logger.debug("@ context reference expansion failed: %s", exc)
# Run the agent
agent_result = await self._run_agent(
message=message_text,
@ -4010,25 +4028,31 @@ class GatewayRunner:
handles /stop before this method is reached. This handler fires
only through normal command dispatch (no running agent) or as a
fallback. Force-clean the session lock in all cases for safety.
When there IS a running/pending agent, the session is also marked
as *suspended* so the next message starts a fresh session instead
of resuming the stuck context (#7536).
"""
source = event.source
session_entry = self.session_store.get_or_create_session(source)
session_key = session_entry.session_key
agent = self._running_agents.get(session_key)
if agent is _AGENT_PENDING_SENTINEL:
# Force-clean the sentinel so the session is unlocked.
if session_key in self._running_agents:
del self._running_agents[session_key]
logger.info("HARD STOP (pending) for session %s — sentinel cleared", session_key[:20])
return "⚡ Force-stopped. The agent was still starting — session unlocked."
self.session_store.suspend_session(session_key)
logger.info("HARD STOP (pending) for session %s — suspended, sentinel cleared", session_key[:20])
return "⚡ Force-stopped. The agent was still starting — your next message will start fresh."
if agent:
agent.interrupt("Stop requested")
# Force-clean the session lock so a truly hung agent doesn't
# keep it locked forever.
if session_key in self._running_agents:
del self._running_agents[session_key]
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
self.session_store.suspend_session(session_key)
return "⚡ Force-stopped. Your next message will start a fresh session."
else:
return "No active task to stop."
@ -6694,6 +6718,8 @@ class GatewayRunner:
chat_id=context.source.chat_id,
chat_name=context.source.chat_name or "",
thread_id=str(context.source.thread_id) if context.source.thread_id else "",
user_id=str(context.source.user_id) if context.source.user_id else "",
user_name=str(context.source.user_name) if context.source.user_name else "",
)
def _clear_session_env(self, tokens: list) -> None:
@ -6906,6 +6932,8 @@ class GatewayRunner:
platform_name = watcher.get("platform", "")
chat_id = watcher.get("chat_id", "")
thread_id = watcher.get("thread_id", "")
user_id = watcher.get("user_id", "")
user_name = watcher.get("user_name", "")
agent_notify = watcher.get("notify_on_complete", False)
notify_mode = self._load_background_notifications_mode()
@ -6961,6 +6989,8 @@ class GatewayRunner:
platform=_platform_enum,
chat_id=chat_id,
thread_id=thread_id or None,
user_id=user_id or None,
user_name=user_name or None,
)
synth_event = MessageEvent(
text=synth_text,
@ -8115,17 +8145,16 @@ class GatewayRunner:
# Get pending message from adapter.
# Use session_key (not source.chat_id) to match adapter's storage keys.
pending_event = None
pending = None
if result and adapter and session_key:
if result.get("interrupted"):
pending = _dequeue_pending_text(adapter, session_key)
if not pending and result.get("interrupt_message"):
pending = result.get("interrupt_message")
else:
pending = _dequeue_pending_text(adapter, session_key)
if pending:
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
pending_event = _dequeue_pending_event(adapter, session_key)
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
pending = result.get("interrupt_message")
elif pending_event:
pending = pending_event.text or _build_media_placeholder(pending_event)
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
# Safety net: if the pending text is a slash command (e.g. "/stop",
# "/new"), discard it — commands should never be passed to the agent
# as user input. The primary fix is in base.py (commands bypass the
@ -8143,27 +8172,29 @@ class GatewayRunner:
"commands must not be passed as agent input",
_pending_cmd_word,
)
pending_event = None
pending = None
except Exception:
pass
if self._draining and pending:
if self._draining and (pending_event or pending):
logger.info(
"Discarding pending follow-up for session %s during gateway %s",
session_key[:20] if session_key else "?",
self._status_action_label(),
)
pending_event = None
pending = None
if pending:
if pending_event or pending:
logger.debug("Processing pending message: '%s...'", pending[:40])
# Clear the adapter's interrupt event so the next _run_agent call
# doesn't immediately re-trigger the interrupt before the new agent
# even makes its first API call (this was causing an infinite loop).
if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions:
adapter._active_sessions[session_key].clear()
# Cap recursion depth to prevent resource exhaustion when the
# user sends multiple messages while the agent keeps failing. (#816)
if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH:
@ -8172,9 +8203,10 @@ class GatewayRunner:
"queueing message instead of recursing.",
_interrupt_depth, session_key,
)
# Queue the pending message for normal processing on next turn
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, 'queue_message'):
if adapter and pending_event:
merge_pending_message_event(adapter._pending_messages, session_key, pending_event)
elif adapter and hasattr(adapter, 'queue_message'):
adapter.queue_message(session_key, pending)
return result_holder[0] or {"final_response": response, "messages": history}
@ -8189,23 +8221,37 @@ class GatewayRunner:
if first_response and not _already_streamed:
try:
await adapter.send(source.chat_id, first_response,
metadata=getattr(event, "metadata", None))
metadata={"thread_id": source.thread_id} if source.thread_id else None)
except Exception as e:
logger.warning("Failed to send first response before queued message: %s", e)
# else: interrupted — discard the interrupted response ("Operation
# interrupted." is just noise; the user already knows they sent a
# new message).
# Process the pending message with updated history
updated_history = result.get("messages", history)
next_source = source
next_message = pending
next_message_id = None
if pending_event is not None:
next_source = getattr(pending_event, "source", None) or source
next_message = await self._prepare_inbound_message_text(
event=pending_event,
source=next_source,
history=updated_history,
)
if next_message is None:
return result
next_message_id = getattr(pending_event, "message_id", None)
return await self._run_agent(
message=pending,
message=next_message,
context_prompt=context_prompt,
history=updated_history,
source=source,
source=next_source,
session_id=session_id,
session_key=session_key,
_interrupt_depth=_interrupt_depth + 1,
event_message_id=next_message_id,
)
finally:
# Stop progress sender, interrupt monitor, and notification task

View file

@ -368,6 +368,11 @@ class SessionEntry:
# survives gateway restarts (the old in-memory _pre_flushed_sessions
# set was lost on restart, causing redundant re-flushes).
memory_flushed: bool = False
# When True the next call to get_or_create_session() will auto-reset
# this session (create a new session_id) so the user starts fresh.
# Set by /stop to break stuck-resume loops (#7536).
suspended: bool = False
def to_dict(self) -> Dict[str, Any]:
result = {
@ -387,6 +392,7 @@ class SessionEntry:
"estimated_cost_usd": self.estimated_cost_usd,
"cost_status": self.cost_status,
"memory_flushed": self.memory_flushed,
"suspended": self.suspended,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@ -423,6 +429,7 @@ class SessionEntry:
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
cost_status=data.get("cost_status", "unknown"),
memory_flushed=data.get("memory_flushed", False),
suspended=data.get("suspended", False),
)
@ -698,7 +705,12 @@ class SessionStore:
if session_key in self._entries and not force_new:
entry = self._entries[session_key]
reset_reason = self._should_reset(entry, source)
# Auto-reset sessions marked as suspended (e.g. after /stop
# broke a stuck loop — #7536).
if entry.suspended:
reset_reason = "suspended"
else:
reset_reason = self._should_reset(entry, source)
if not reset_reason:
entry.updated_at = now
self._save()
@ -771,6 +783,44 @@ class SessionStore:
entry.last_prompt_tokens = last_prompt_tokens
self._save()
def suspend_session(self, session_key: str) -> bool:
    """Flag a session so its next access triggers an auto-reset.

    ``/stop`` uses this to break stuck-resume loops: a suspended session
    is recreated fresh instead of being resumed after a gateway restart
    (#7536).

    Args:
        session_key: Key of the session to suspend.

    Returns:
        True when the session exists and was flagged, False otherwise.
    """
    with self._lock:
        self._ensure_loaded_locked()
        entry = self._entries.get(session_key)
        if entry is None:
            return False
        entry.suspended = True
        self._save()
        return True
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
    """Suspend every session updated within the last *max_age_seconds*.

    Invoked at gateway startup so sessions that were likely in-flight
    when the previous process exited are not blindly resumed (#7536).
    Long-idle sessions are left untouched — resuming them is harmless.

    Args:
        max_age_seconds: Activity window; only sessions updated this
            recently are suspended.

    Returns:
        The number of sessions newly marked suspended.
    """
    import time as _time

    cutoff = _time.time() - max_age_seconds
    with self._lock:
        self._ensure_loaded_locked()
        # Collect first, then flag, so the count is known before saving.
        fresh = [
            entry
            for entry in self._entries.values()
            if not entry.suspended and entry.updated_at >= cutoff
        ]
        for entry in fresh:
            entry.suspended = True
        if fresh:
            self._save()
    return len(fresh)
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
"""Force reset a session, creating a new session ID."""
db_end_session_id = None

View file

@ -46,12 +46,16 @@ _SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", defau
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
_SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="")
_SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="")
_VAR_MAP = {
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
"HERMES_SESSION_USER_ID": _SESSION_USER_ID,
"HERMES_SESSION_USER_NAME": _SESSION_USER_NAME,
}
@ -60,6 +64,8 @@ def set_session_vars(
chat_id: str = "",
chat_name: str = "",
thread_id: str = "",
user_id: str = "",
user_name: str = "",
) -> list:
"""Set all session context variables and return reset tokens.
@ -74,6 +80,8 @@ def set_session_vars(
_SESSION_CHAT_ID.set(chat_id),
_SESSION_CHAT_NAME.set(chat_name),
_SESSION_THREAD_ID.set(thread_id),
_SESSION_USER_ID.set(user_id),
_SESSION_USER_NAME.set(user_name),
]
return tokens
@ -87,6 +95,8 @@ def clear_session_vars(tokens: list) -> None:
_SESSION_CHAT_ID,
_SESSION_CHAT_NAME,
_SESSION_THREAD_ID,
_SESSION_USER_ID,
_SESSION_USER_NAME,
]
for var, token in zip(vars_in_order, tokens):
var.reset(token)

View file

@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
}
# =============================================================================
# Anthropic Key Helper
# =============================================================================
def get_anthropic_key() -> str:
    """Return the first usable Anthropic credential, or ``""``.

    Each variable named in ``PROVIDER_REGISTRY["anthropic"].api_key_env_vars``
    is checked first in the ``.env`` file (via ``get_env_value``) and then
    in the process environment; the first non-empty value wins.  Lookup
    order mirrors the registry tuple:
    ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN
    """
    from hermes_cli.config import get_env_value

    for env_name in PROVIDER_REGISTRY["anthropic"].api_key_env_vars:
        credential = get_env_value(env_name)
        if not credential:
            credential = os.getenv(env_name, "")
        if credential:
            return credential
    return ""
# =============================================================================
# Kimi Code Endpoint Detection
# =============================================================================

View file

@ -52,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = (
# Known OpenClaw directory names (current + legacy)
_OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot")
def _warn_if_gateway_running(auto_yes: bool) -> None:
    """Warn — and optionally abort — when a gateway is already connected.

    Migrating bot tokens while a gateway is actively polling causes
    conflicts (e.g. Telegram 409 "terminated by other getUpdates
    request"), so surface the risk and let the user decide whether to
    proceed.  Exits the process when the user declines.

    Args:
        auto_yes: When True, skip the interactive confirmation.
    """
    from gateway.status import get_running_pid, read_runtime_status

    if not get_running_pid():
        return
    runtime = read_runtime_status() or {}
    platform_map = runtime.get("platforms") or {}
    connected = [
        name
        for name, info in platform_map.items()
        if isinstance(info, dict) and info.get("state") == "connected"
    ]
    if not connected:
        return
    print()
    print_error(
        "Hermes gateway is running with active connections: "
        + ", ".join(connected)
    )
    print_info(
        "Migrating bot tokens while the gateway is active will cause "
        "conflicts (Telegram, Discord, and Slack only allow one active "
        "session per token)."
    )
    print_info("Recommendation: stop the gateway first with 'hermes stop'.")
    print()
    if not auto_yes and not prompt_yes_no("Continue anyway?", default=False):
        print_info("Migration cancelled. Stop the gateway and try again.")
        sys.exit(0)
# State files commonly found in OpenClaw workspace directories that cause
# confusion after migration (the agent discovers them and writes to them)
_WORKSPACE_STATE_GLOBS = (
@ -252,6 +287,10 @@ def _cmd_migrate(args):
print_info(f"Workspace: {workspace_target}")
print()
# Check if a gateway is running with connected platforms — migrating tokens
# while the gateway is active will cause conflicts (e.g. Telegram 409).
_warn_if_gateway_running(auto_yes)
# Ensure config.yaml exists before migration tries to read it
config_path = get_config_path()
if not config_path.exists():

79
hermes_cli/cli_output.py Normal file
View file

@ -0,0 +1,79 @@
"""Shared CLI output helpers for Hermes CLI modules.
Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
functions previously duplicated across setup.py, tools_config.py,
mcp_config.py, and memory_setup.py.
"""
import getpass
import sys
from hermes_cli.colors import Colors, color
# ─── Print Helpers ────────────────────────────────────────────────────────────
def print_info(text: str) -> None:
    """Print *text* dimmed, indented one space, as an informational line."""
    message = color(f" {text}", Colors.DIM)
    print(message)
def print_success(text: str) -> None:
    """Print *text* in green to signal success."""
    message = color(f"{text}", Colors.GREEN)
    print(message)
def print_warning(text: str) -> None:
    """Print *text* in yellow to signal a warning."""
    message = color(f"{text}", Colors.YELLOW)
    print(message)
def print_error(text: str) -> None:
    """Print *text* in red to signal an error."""
    message = color(f"{text}", Colors.RED)
    print(message)
def print_header(text: str) -> None:
    """Print *text* in yellow, preceded by a blank line, as a section header."""
    message = color(f"\n {text}", Colors.YELLOW)
    print(message)
# ─── Input Prompts ────────────────────────────────────────────────────────────
def prompt(
    question: str,
    default: str | None = None,
    password: bool = False,
) -> str:
    """Read a line of user input, with optional default and password masking.

    The question is rendered in yellow; when *default* is set it is shown
    in brackets and returned if the user just presses Enter. Input is
    stripped of surrounding whitespace. On Ctrl-C or EOF a newline is
    printed and an empty string is returned.
    """
    if default:
        suffix = f" [{default}]"
    else:
        suffix = ""
    display = color(f" {question}{suffix}: ", Colors.YELLOW)
    # getpass hides the typed characters; input echoes them normally.
    reader = getpass.getpass if password else input
    try:
        answer = reader(display).strip()
    except (KeyboardInterrupt, EOFError):
        print()
        return ""
    if answer:
        return answer
    return default or ""
def prompt_yes_no(question: str, default: bool = True) -> bool:
    """Ask a yes/no question; pressing Enter accepts *default*.

    Any answer starting with "y"/"Y" counts as yes; everything else is no.
    """
    hint = "Y/n" if default else "y/N"
    reply = prompt(f"{question} ({hint})")
    if reply:
        return reply.lower().startswith("y")
    return default

View file

@ -1497,7 +1497,7 @@ _KNOWN_ROOT_KEYS = {
# Valid fields inside a custom_providers list entry
_VALID_CUSTOM_PROVIDER_FIELDS = {
"name", "base_url", "api_key", "api_mode", "models",
"name", "base_url", "api_key", "api_mode", "model", "models",
"context_length", "rate_limit_delay",
}
@ -2582,7 +2582,8 @@ def show_config():
for env_key, name in keys:
value = get_env_value(env_key)
print(f" {name:<14} {redact_key(value)}")
anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY")
from hermes_cli.auth import get_anthropic_key
anthropic_value = get_anthropic_key()
print(f" {'Anthropic':<14} {redact_key(anthropic_value)}")
# Model settings
@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str):
# Write only user config back (not the full merged defaults)
ensure_hermes_home()
with open(config_path, 'w', encoding="utf-8") as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
from utils import atomic_yaml_write
atomic_yaml_write(config_path, user_config, sort_keys=False)
# Keep .env in sync for keys that terminal_tool reads directly from env vars.
# config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc.

View file

@ -336,8 +336,8 @@ def run_doctor(args):
model_section[k] = raw_config.pop(k)
else:
raw_config.pop(k)
with open(config_path, "w") as f:
yaml.dump(raw_config, f, default_flow_style=False)
from utils import atomic_yaml_write
atomic_yaml_write(config_path, raw_config)
check_ok("Migrated stale root-level keys into model section")
fixed_count += 1
else:
@ -686,7 +686,8 @@ def run_doctor(args):
else:
check_warn("OpenRouter API", "(not configured)")
anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY")
from hermes_cli.auth import get_anthropic_key
anthropic_key = get_anthropic_key()
if anthropic_key:
print(" Checking Anthropic API...", end="", flush=True)
try:

View file

@ -157,30 +157,54 @@ def _request_gateway_self_restart(pid: int) -> bool:
return True
def find_gateway_pids(exclude_pids: set | None = None) -> list:
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
"""Find PIDs of running gateway processes.
Args:
exclude_pids: PIDs to exclude from the result (e.g. service-managed
PIDs that should not be killed during a stale-process sweep).
all_profiles: When ``True``, return gateway PIDs across **all**
profiles (the pre-7923 global behaviour). ``hermes update``
needs this because a code update affects every profile.
When ``False`` (default), only PIDs belonging to the current
Hermes profile are returned.
"""
pids = []
_exclude = exclude_pids or set()
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
patterns = [
"hermes_cli.main gateway",
"hermes_cli.main --profile",
"hermes_cli.main -p",
"hermes_cli/main.py gateway",
"hermes_cli/main.py --profile",
"hermes_cli/main.py -p",
"hermes gateway",
"gateway/run.py",
]
current_home = str(get_hermes_home().resolve())
current_profile_arg = _profile_arg(current_home)
current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else ""
def _matches_current_profile(command: str) -> bool:
if current_profile_name:
return (
f"--profile {current_profile_name}" in command
or f"-p {current_profile_name}" in command
or f"HERMES_HOME={current_home}" in command
)
if "--profile " in command or " -p " in command:
return False
if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command:
return False
return True
try:
if is_windows():
# Windows: use wmic to search command lines
result = subprocess.run(
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
capture_output=True, text=True, timeout=10
)
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
current_cmd = ""
for line in result.stdout.split('\n'):
line = line.strip()
@ -188,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
current_cmd = line[len("CommandLine="):]
elif line.startswith("ProcessId="):
pid_str = line[len("ProcessId="):]
if any(p in current_cmd for p in patterns):
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
try:
pid = int(pid_str)
if pid != os.getpid() and pid not in pids and pid not in _exclude:
@ -198,41 +222,57 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
current_cmd = ""
else:
result = subprocess.run(
["ps", "aux"],
["ps", "eww", "-ax", "-o", "pid=,command="],
capture_output=True,
text=True,
timeout=10,
)
for line in result.stdout.split('\n'):
# Skip grep and current process
if 'grep' in line or str(os.getpid()) in line:
stripped = line.strip()
if not stripped or 'grep' in stripped:
continue
for pattern in patterns:
if pattern in line:
parts = line.split()
if len(parts) > 1:
try:
pid = int(parts[1])
if pid not in pids and pid not in _exclude:
pids.append(pid)
except ValueError:
continue
break
except Exception:
pid = None
command = ""
parts = stripped.split(None, 1)
if len(parts) == 2:
try:
pid = int(parts[0])
command = parts[1]
except ValueError:
pid = None
if pid is None:
aux_parts = stripped.split()
if len(aux_parts) > 10 and aux_parts[1].isdigit():
pid = int(aux_parts[1])
command = " ".join(aux_parts[10:])
if pid is None:
continue
if pid == os.getpid() or pid in pids or pid in _exclude:
continue
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
pids.append(pid)
except (OSError, subprocess.TimeoutExpired):
pass
return pids
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int:
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
all_profiles: bool = False) -> int:
"""Kill any running gateway processes. Returns count killed.
Args:
force: Use the platform's force-kill mechanism instead of graceful terminate.
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
restarted and should not be killed).
all_profiles: When ``True``, kill across all profiles. Passed
through to :func:`find_gateway_pids`.
"""
pids = find_gateway_pids(exclude_pids=exclude_pids)
pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles)
killed = 0
for pid in pids:
@ -633,6 +673,17 @@ def print_systemd_linger_guidance() -> None:
print(" If you want the gateway user service to survive logout, run:")
print(" sudo loginctl enable-linger $USER")
def _launchd_user_home() -> Path:
"""Return the real macOS user home for launchd artifacts.
Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but
launchd user agents still live under the actual account home.
"""
import pwd
return Path(pwd.getpwuid(os.getuid()).pw_dir)
def get_launchd_plist_path() -> Path:
"""Return the launchd plist path, scoped per profile.
@ -641,7 +692,7 @@ def get_launchd_plist_path() -> Path:
"""
suffix = _profile_suffix()
name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway"
return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist"
return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist"
def _detect_venv_dir() -> Path | None:
"""Detect the active virtualenv directory.
@ -839,6 +890,25 @@ def _normalize_service_definition(text: str) -> str:
return "\n".join(line.rstrip() for line in text.strip().splitlines())
def _normalize_launchd_plist_for_comparison(text: str) -> str:
    """Normalize launchd plist text for staleness comparison.

    The generated plist embeds a PATH assembled from the invoking shell so
    user-installed tools stay reachable under launchd. That PATH varies
    between shells, which would make a raw text comparison unstable, so
    mask the PATH payload before comparing.
    """
    import re

    collapsed = _normalize_service_definition(text)
    # DOTALL so the <string> payload may span multiple lines.
    path_entry = re.compile(
        r'(<key>PATH</key>\s*<string>)(.*?)(</string>)',
        re.S,
    )
    return path_entry.sub(r'\1__HERMES_PATH__\3', collapsed)
def systemd_unit_is_current(system: bool = False) -> bool:
unit_path = get_systemd_unit_path(system=system)
if not unit_path.exists():
@ -1220,7 +1290,7 @@ def launchd_plist_is_current() -> bool:
installed = plist_path.read_text(encoding="utf-8")
expected = generate_launchd_plist()
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected)
def refresh_launchd_plist_if_needed() -> bool:
@ -1981,6 +2051,36 @@ def _setup_whatsapp():
cmd_whatsapp(argparse.Namespace())
def _setup_email():
    """Configure Email through the standard platform setup flow."""
    matches = (p for p in _PLATFORMS if p["key"] == "email")
    _setup_standard_platform(next(matches))
def _setup_sms():
    """Configure SMS (Twilio) through the standard platform setup flow."""
    matches = (p for p in _PLATFORMS if p["key"] == "sms")
    _setup_standard_platform(next(matches))
def _setup_dingtalk():
    """Configure DingTalk through the standard platform setup flow."""
    matches = (p for p in _PLATFORMS if p["key"] == "dingtalk")
    _setup_standard_platform(next(matches))
def _setup_feishu():
    """Configure Feishu / Lark through the standard platform setup flow."""
    matches = (p for p in _PLATFORMS if p["key"] == "feishu")
    _setup_standard_platform(next(matches))
def _setup_wecom():
    """Configure WeCom (Enterprise WeChat) through the standard platform setup flow."""
    matches = (p for p in _PLATFORMS if p["key"] == "wecom")
    _setup_standard_platform(next(matches))
def _is_service_installed() -> bool:
"""Check if the gateway is installed as a system service."""
if supports_systemd_services():
@ -2540,7 +2640,7 @@ def gateway_command(args):
service_available = True
except subprocess.CalledProcessError:
pass
killed = kill_gateway_processes()
killed = kill_gateway_processes(all_profiles=True)
total = killed + (1 if service_available else 0)
if total:
print(f"✓ Stopped {total} gateway process(es) across all profiles")

View file

@ -2758,13 +2758,8 @@ def _model_flow_anthropic(config, current_model=""):
from hermes_cli.models import _PROVIDER_MODELS
# Check ALL credential sources
existing_key = (
get_env_value("ANTHROPIC_TOKEN")
or os.getenv("ANTHROPIC_TOKEN", "")
or get_env_value("ANTHROPIC_API_KEY")
or os.getenv("ANTHROPIC_API_KEY", "")
or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
)
from hermes_cli.auth import get_anthropic_key
existing_key = get_anthropic_key()
cc_available = False
try:
from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid
@ -4090,7 +4085,7 @@ def cmd_update(args):
# Exclude PIDs that belong to just-restarted services so we don't
# immediately kill the process that systemd/launchd just spawned.
service_pids = _get_service_pids()
manual_pids = find_gateway_pids(exclude_pids=service_pids)
manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True)
for pid in manual_pids:
try:
os.kill(pid, _signal.SIGTERM)

View file

@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool:
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
display = f" {question}"
if default:
display += f" [{default}]"
display += ": "
try:
if password:
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
return value.strip() or default
except (KeyboardInterrupt, EOFError):
print()
return default
from hermes_cli.cli_output import prompt as _shared_prompt
return _shared_prompt(question, default=default, password=password)
# ─── Config Helpers ───────────────────────────────────────────────────────────

View file

@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -
items: list of (label, description) tuples.
Returns selected index, or default on escape/quit.
"""
try:
import curses
result = [default]
def _menu(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
curses.init_pair(3, curses.COLOR_CYAN, -1)
cursor = default
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Title
try:
stdscr.addnstr(0, 0, title, max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
except curses.error:
pass
for i, (label, desc) in enumerate(items):
y = i + 3
if y >= max_y - 1:
break
arrow = "" if i == cursor else " "
line = f" {arrow} {label}"
if desc:
line += f" {desc}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord('k')):
cursor = (cursor - 1) % len(items)
elif key in (curses.KEY_DOWN, ord('j')):
cursor = (cursor + 1) % len(items)
elif key in (curses.KEY_ENTER, 10, 13):
result[0] = cursor
return
elif key in (27, ord('q')):
return
curses.wrapper(_menu)
return result[0]
except Exception:
# Fallback: numbered input
print(f"\n {title}\n")
for i, (label, desc) in enumerate(items):
marker = "" if i == default else " "
d = f" {desc}" if desc else ""
print(f" {marker} {i + 1}. {label}{d}")
while True:
try:
val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
if not val:
return default
idx = int(val) - 1
if 0 <= idx < len(items):
return idx
except (ValueError, EOFError):
return default
from hermes_cli.curses_ui import curses_radiolist
# Format (label, desc) tuples into display strings
display_items = [
f"{label} {desc}" if desc else label
for label, desc in items
]
return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:

45
hermes_cli/platforms.py Normal file
View file

@ -0,0 +1,45 @@
"""
Shared platform registry for Hermes Agent.
Single source of truth for platform metadata consumed by both
skills_config (label display) and tools_config (default toolset
resolution). Import ``PLATFORMS`` from here instead of maintaining
duplicate dicts in each module.
"""
from collections import OrderedDict
from typing import NamedTuple
class PlatformInfo(NamedTuple):
    """Immutable metadata record for a single platform entry.

    Being a NamedTuple, field order is part of the contract:
    ``label`` first, ``default_toolset`` second.
    """

    # Display label used in menus (includes an emoji prefix).
    label: str
    # Toolset key enabled by default on this platform.
    default_toolset: str
# Ordered so that TUI menus are deterministic.
# Keys are the canonical platform identifiers shared across CLI modules;
# each value pairs the display label with the platform's default toolset.
PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
    ("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")),
    ("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")),
    ("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")),
    ("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")),
    ("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")),
    ("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")),
    ("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")),
    ("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")),
    ("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")),
    ("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")),
    ("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")),
    ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")),
    ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")),
    ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
    ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
    ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
    ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
])
def platform_label(key: str, default: str = "") -> str:
    """Return the display label for a platform key, or *default* if unknown."""
    try:
        return PLATFORMS[key].label
    except KeyError:
        return default

View file

@ -304,6 +304,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
api_mode = _parse_api_mode(entry.get("api_mode"))
if api_mode:
result["api_mode"] = api_mode
model_name = str(entry.get("model", "") or "").strip()
if model_name:
result["model"] = model_name
return result
return None
@ -329,6 +332,11 @@ def _resolve_named_custom_runtime(
# Check if a credential pool exists for this custom endpoint
pool_result = _try_resolve_from_custom_pool(base_url, "custom", custom_provider.get("api_mode"))
if pool_result:
# Propagate the model name even when using pooled credentials —
# the pool doesn't know about the custom_providers model field.
model_name = custom_provider.get("model")
if model_name:
pool_result["model"] = model_name
return pool_result
api_key_candidates = [
@ -339,7 +347,7 @@ def _resolve_named_custom_runtime(
]
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
return {
result = {
"provider": "custom",
"api_mode": custom_provider.get("api_mode")
or _detect_api_mode_for_url(base_url)
@ -348,6 +356,11 @@ def _resolve_named_custom_runtime(
"api_key": api_key or "no-key-required",
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
}
# Propagate the model name so callers can override self.model when the
# provider name differs from the actual model string the API expects.
if custom_provider.get("model"):
result["model"] = custom_provider["model"]
return result
def _resolve_openrouter_runtime(

View file

@ -197,24 +197,12 @@ def print_header(title: str):
print(color(f"{title}", Colors.CYAN, Colors.BOLD))
def print_info(text: str):
"""Print info text."""
print(color(f" {text}", Colors.DIM))
def print_success(text: str):
"""Print success message."""
print(color(f"{text}", Colors.GREEN))
def print_warning(text: str):
"""Print warning message."""
print(color(f"{text}", Colors.YELLOW))
def print_error(text: str):
"""Print error message."""
print(color(f"{text}", Colors.RED))
from hermes_cli.cli_output import ( # noqa: E402
print_error,
print_info,
print_success,
print_warning,
)
def is_interactive_stdin() -> bool:
@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu using curses to avoid simple_term_menu rendering bugs."""
try:
import curses
result_holder = [default]
def _curses_menu(stdscr):
curses.curs_set(0)
if curses.has_colors():
curses.start_color()
curses.use_default_colors()
curses.init_pair(1, curses.COLOR_GREEN, -1)
curses.init_pair(2, curses.COLOR_YELLOW, -1)
cursor = default
scroll_offset = 0
while True:
stdscr.clear()
max_y, max_x = stdscr.getmaxyx()
# Rows available for list items: rows 2..(max_y-2) inclusive.
visible = max(1, max_y - 3)
# Scroll the viewport so the cursor is always visible.
if cursor < scroll_offset:
scroll_offset = cursor
elif cursor >= scroll_offset + visible:
scroll_offset = cursor - visible + 1
scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
try:
stdscr.addnstr(
0,
0,
question,
max_x - 1,
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
)
except curses.error:
pass
for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
y = row + 2
if y >= max_y - 1:
break
arrow = "" if i == cursor else " "
line = f" {arrow} {choices[i]}"
attr = curses.A_NORMAL
if i == cursor:
attr = curses.A_BOLD
if curses.has_colors():
attr |= curses.color_pair(1)
try:
stdscr.addnstr(y, 0, line, max_x - 1, attr)
except curses.error:
pass
stdscr.refresh()
key = stdscr.getch()
if key in (curses.KEY_UP, ord("k")):
cursor = (cursor - 1) % len(choices)
elif key in (curses.KEY_DOWN, ord("j")):
cursor = (cursor + 1) % len(choices)
elif key in (curses.KEY_ENTER, 10, 13):
result_holder[0] = cursor
return
elif key in (27, ord("q")):
return
curses.wrapper(_curses_menu)
from hermes_cli.curses_ui import flush_stdin
flush_stdin()
return result_holder[0]
except Exception:
return -1
"""Single-select menu using curses. Delegates to curses_radiolist."""
from hermes_cli.curses_ui import curses_radiolist
return curses_radiolist(question, choices, selected=default, cancel_returns=-1)
@ -2052,6 +1969,42 @@ def _setup_weixin():
_gateway_setup_weixin()
def _setup_signal():
    """Delegate Signal configuration to the gateway setup routine."""
    from hermes_cli.gateway import _setup_signal as gateway_setup
    gateway_setup()
def _setup_email():
    """Delegate Email configuration to the gateway setup routine."""
    from hermes_cli.gateway import _setup_email as gateway_setup
    gateway_setup()
def _setup_sms():
    """Delegate SMS (Twilio) configuration to the gateway setup routine."""
    from hermes_cli.gateway import _setup_sms as gateway_setup
    gateway_setup()
def _setup_dingtalk():
    """Delegate DingTalk configuration to the gateway setup routine."""
    from hermes_cli.gateway import _setup_dingtalk as gateway_setup
    gateway_setup()
def _setup_feishu():
    """Delegate Feishu / Lark configuration to the gateway setup routine."""
    from hermes_cli.gateway import _setup_feishu as gateway_setup
    gateway_setup()
def _setup_wecom():
    """Delegate WeCom (Enterprise WeChat) configuration to the gateway setup routine."""
    from hermes_cli.gateway import _setup_wecom as gateway_setup
    gateway_setup()
def _setup_bluebubbles():
"""Configure BlueBubbles iMessage gateway."""
print_header("BlueBubbles (iMessage)")
@ -2168,9 +2121,15 @@ _GATEWAY_PLATFORMS = [
("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram),
("Discord", "DISCORD_BOT_TOKEN", _setup_discord),
("Slack", "SLACK_BOT_TOKEN", _setup_slack),
("Signal", "SIGNAL_HTTP_URL", _setup_signal),
("Email", "EMAIL_ADDRESS", _setup_email),
("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms),
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk),
("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu),
("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom),
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
@ -2212,10 +2171,17 @@ def setup_gateway(config: dict):
get_env_value("TELEGRAM_BOT_TOKEN")
or get_env_value("DISCORD_BOT_TOKEN")
or get_env_value("SLACK_BOT_TOKEN")
or get_env_value("SIGNAL_HTTP_URL")
or get_env_value("EMAIL_ADDRESS")
or get_env_value("TWILIO_ACCOUNT_SID")
or get_env_value("MATTERMOST_TOKEN")
or get_env_value("MATRIX_ACCESS_TOKEN")
or get_env_value("MATRIX_PASSWORD")
or get_env_value("WHATSAPP_ENABLED")
or get_env_value("DINGTALK_CLIENT_ID")
or get_env_value("FEISHU_APP_ID")
or get_env_value("WECOM_BOT_ID")
or get_env_value("WEIXIN_ACCOUNT_ID")
or get_env_value("BLUEBUBBLES_SERVER_URL")
or get_env_value("WEBHOOK_ENABLED")
)
@ -2404,12 +2370,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]
platforms.append("Discord")
if get_env_value("SLACK_BOT_TOKEN"):
platforms.append("Slack")
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
platforms.append("WhatsApp")
if get_env_value("SIGNAL_ACCOUNT"):
platforms.append("Signal")
if get_env_value("EMAIL_ADDRESS"):
platforms.append("Email")
if get_env_value("TWILIO_ACCOUNT_SID"):
platforms.append("SMS")
if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"):
platforms.append("Matrix")
if get_env_value("MATTERMOST_TOKEN"):
platforms.append("Mattermost")
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
platforms.append("WhatsApp")
if get_env_value("DINGTALK_CLIENT_ID"):
platforms.append("DingTalk")
if get_env_value("FEISHU_APP_ID"):
platforms.append("Feishu")
if get_env_value("WECOM_BOT_ID"):
platforms.append("WeCom")
if get_env_value("WEIXIN_ACCOUNT_ID"):
platforms.append("Weixin")
if get_env_value("BLUEBUBBLES_SERVER_URL"):
platforms.append("BlueBubbles")
if get_env_value("WEBHOOK_ENABLED"):
platforms.append("Webhooks")
if platforms:
return ", ".join(platforms)
return None # No platforms configured — section must run

View file

@ -15,25 +15,12 @@ from typing import List, Optional, Set
from hermes_cli.config import load_config, save_config
from hermes_cli.colors import Colors, color
from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label
PLATFORMS = {
"cli": "🖥️ CLI",
"telegram": "📱 Telegram",
"discord": "💬 Discord",
"slack": "💼 Slack",
"whatsapp": "📱 WhatsApp",
"signal": "📡 Signal",
"bluebubbles": "💬 BlueBubbles",
"email": "📧 Email",
"homeassistant": "🏠 Home Assistant",
"mattermost": "💬 Mattermost",
"matrix": "💬 Matrix",
"dingtalk": "💬 DingTalk",
"feishu": "🪽 Feishu",
"wecom": "💬 WeCom",
"weixin": "💬 Weixin",
"webhook": "🔗 Webhook",
}
# Backward-compatible view: {key: label_string} so existing code that
# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps
# working without changes to every call site.
PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"}
# ─── Config Helpers ───────────────────────────────────────────────────────────

View file

@ -141,11 +141,8 @@ def show_status(args):
display = redact_key(value) if not show_all else value
print(f" {name:<12} {check_mark(has_key)} {display}")
anthropic_value = (
get_env_value("ANTHROPIC_TOKEN")
or get_env_value("ANTHROPIC_API_KEY")
or ""
)
from hermes_cli.auth import get_anthropic_key
anthropic_value = get_anthropic_key()
anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value
print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}")

View file

@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve()
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
def _print_info(text: str):
print(color(f" {text}", Colors.DIM))
def _print_success(text: str):
print(color(f"{text}", Colors.GREEN))
def _print_warning(text: str):
print(color(f"{text}", Colors.YELLOW))
def _print_error(text: str):
print(color(f"{text}", Colors.RED))
def _prompt(question: str, default: str = None, password: bool = False) -> str:
if default:
display = f"{question} [{default}]: "
else:
display = f"{question}: "
try:
if password:
import getpass
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
return value.strip() or default or ""
except (KeyboardInterrupt, EOFError):
print()
return default or ""
from hermes_cli.cli_output import ( # noqa: E402 — late import block
print_error as _print_error,
print_info as _print_info,
print_success as _print_success,
print_warning as _print_warning,
prompt as _prompt,
)
# ─── Toolset Registry ─────────────────────────────────────────────────────────
@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set:
except Exception:
return set()
# Platform display config
# Platform display config — derived from the canonical registry so every
# module shares the same data. Kept as dict-of-dicts for backward
# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns.
from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY
PLATFORMS = {
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
k: {"label": info.label, "default_toolset": info.default_toolset}
for k, info in _PLATFORMS_REGISTRY.items()
}
@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool:
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
    """Single-select menu (arrow keys). Delegates to curses_radiolist.

    Args:
        question: Prompt text shown above the menu.
        choices: Options to pick from.
        default: Index pre-selected (and returned when the user cancels).

    Returns:
        Index of the chosen entry in ``choices``.
    """
    # The former inline curses implementation (plus the numbered-input
    # fallback) was superseded by the shared curses_radiolist helper; the
    # merge left that dead code stranded ahead of this delegation, making the
    # new code unreachable. Only the delegating version is kept.
    from hermes_cli.curses_ui import curses_radiolist
    return curses_radiolist(question, choices, selected=default, cancel_returns=default)
# ─── Token Estimation ────────────────────────────────────────────────────────

View file

@ -189,6 +189,33 @@ def is_wsl() -> bool:
return _wsl_detected
# ─── Well-Known Paths ─────────────────────────────────────────────────────────
def get_config_path() -> Path:
    """Return the path to ``config.yaml`` under HERMES_HOME.

    Single source of truth for the ``get_hermes_home() / "config.yaml"``
    expression previously repeated in 7+ files (skill_utils.py,
    hermes_logging.py, hermes_time.py, etc.).
    """
    home = get_hermes_home()
    return home.joinpath("config.yaml")
def get_skills_dir() -> Path:
    """Return the path to the skills directory under HERMES_HOME."""
    return get_hermes_home().joinpath("skills")
def get_logs_dir() -> Path:
    """Return the path to the logs directory under HERMES_HOME."""
    return get_hermes_home().joinpath("logs")
def get_env_path() -> Path:
    """Return the path to the ``.env`` file under HERMES_HOME."""
    return get_hermes_home().joinpath(".env")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"

View file

@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional
from hermes_constants import get_hermes_home
from hermes_constants import get_config_path, get_hermes_home
# Sentinel to track whether setup_logging() has already run. The function
# is idempotent — calling it twice is safe but the second call is a no-op
@ -246,7 +246,7 @@ def _read_logging_config():
"""
try:
import yaml
config_path = get_hermes_home() / "config.yaml"
config_path = get_config_path()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}

View file

@ -16,7 +16,7 @@ crashes due to a bad timezone string.
import logging
import os
from datetime import datetime
from hermes_constants import get_hermes_home
from hermes_constants import get_config_path
from typing import Optional
logger = logging.getLogger(__name__)
@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str:
# 2. config.yaml ``timezone`` key
try:
import yaml
hermes_home = get_hermes_home()
config_path = hermes_home / "config.yaml"
config_path = get_config_path()
if config_path.exists():
with open(config_path) as f:
cfg = yaml.safe_load(f) or {}

View file

@ -739,6 +739,7 @@ class AIAgent:
# Interrupt mechanism for breaking out of tool loops
self._interrupt_requested = False
self._interrupt_message = None # Optional message that triggered interrupt
self._execution_thread_id: int | None = None # Set at run_conversation() start
self._client_lock = threading.RLock()
# Subagent delegation state
@ -2832,8 +2833,10 @@ class AIAgent:
"""
self._interrupt_requested = True
self._interrupt_message = message
# Signal all tools to abort any in-flight operations immediately
_set_interrupt(True)
# Signal all tools to abort any in-flight operations immediately.
# Scope the interrupt to this agent's execution thread so other
# agents running in the same process (gateway) are not affected.
_set_interrupt(True, self._execution_thread_id)
# Propagate interrupt to any running child agents (subagent delegation)
with self._active_children_lock:
children_copy = list(self._active_children)
@ -2846,10 +2849,10 @@ class AIAgent:
print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
def clear_interrupt(self) -> None:
"""Clear any pending interrupt request and the global tool interrupt signal."""
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
self._interrupt_requested = False
self._interrupt_message = None
_set_interrupt(False)
_set_interrupt(False, self._execution_thread_id)
def _touch_activity(self, desc: str) -> None:
"""Update the last-activity timestamp and description (thread-safe)."""
@ -3443,6 +3446,7 @@ class AIAgent:
def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert internal chat-style messages to Responses input items."""
items: List[Dict[str, Any]] = []
seen_item_ids: set = set()
for msg in messages:
if not isinstance(msg, dict):
@ -3463,7 +3467,12 @@ class AIAgent:
if isinstance(codex_reasoning, list):
for ri in codex_reasoning:
if isinstance(ri, dict) and ri.get("encrypted_content"):
item_id = ri.get("id")
if item_id and item_id in seen_item_ids:
continue
items.append(ri)
if item_id:
seen_item_ids.add(item_id)
has_codex_reasoning = True
if content_text.strip():
@ -3543,6 +3552,7 @@ class AIAgent:
raise ValueError("Codex Responses input must be a list of input items.")
normalized: List[Dict[str, Any]] = []
seen_ids: set = set()
for idx, item in enumerate(raw_items):
if not isinstance(item, dict):
raise ValueError(f"Codex Responses input[{idx}] must be an object.")
@ -3595,8 +3605,12 @@ class AIAgent:
if item_type == "reasoning":
encrypted = item.get("encrypted_content")
if isinstance(encrypted, str) and encrypted:
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
item_id = item.get("id")
if isinstance(item_id, str) and item_id:
if item_id in seen_ids:
continue
seen_ids.add(item_id)
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
if isinstance(item_id, str) and item_id:
reasoning_item["id"] = item_id
summary = item.get("summary")
@ -7800,6 +7814,11 @@ class AIAgent:
compression_attempts = 0
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
# Record the execution thread so interrupt()/clear_interrupt() can
# scope the tool-level interrupt signal to THIS agent's thread only.
# Must be set before clear_interrupt() which uses it.
self._execution_thread_id = threading.current_thread().ident
# Clear any stale interrupt state at start
self.clear_interrupt()
@ -8278,8 +8297,24 @@ class AIAgent:
_text_parts.append(getattr(_blk, "text", ""))
_trunc_content = "\n".join(_text_parts) if _text_parts else None
# A response is "thinking exhausted" only when the model
# actually produced reasoning blocks but no visible text after
# them. Models that do not use <think> tags (e.g. GLM-4.7 on
# NVIDIA Build, minimax) may return content=None or an empty
# string for unrelated reasons — treat those as normal
# truncations that deserve continuation retries, not as
# thinking-budget exhaustion.
_has_think_tags = bool(
_trunc_content and re.search(
r'<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)[^>]*>',
_trunc_content,
re.IGNORECASE,
)
)
_thinking_exhausted = (
not _trunc_has_tool_calls and (
not _trunc_has_tool_calls
and _has_think_tags
and (
(_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
or _trunc_content is None
)
@ -9507,12 +9542,41 @@ class AIAgent:
invalid_json_args.append((tc.function.name, str(e)))
if invalid_json_args:
# Check if the invalid JSON is due to truncation rather
# than a model formatting mistake. Routers sometimes
# rewrite finish_reason from "length" to "tool_calls",
# hiding the truncation from the length handler above.
# Detect truncation: args that don't end with } or ]
# (after stripping whitespace) are cut off mid-stream.
_truncated = any(
not (tc.function.arguments or "").rstrip().endswith(("}", "]"))
for tc in assistant_message.tool_calls
if tc.function.name in {n for n, _ in invalid_json_args}
)
if _truncated:
self._vprint(
f"{self.log_prefix}⚠️ Truncated tool call arguments detected "
f"(finish_reason={finish_reason!r}) — refusing to execute.",
force=True,
)
self._invalid_json_retries = 0
self._cleanup_task_resources(effective_task_id)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": "Response truncated due to output length limit",
}
# Track retries for invalid JSON arguments
self._invalid_json_retries += 1
tool_name, error_msg = invalid_json_args[0]
self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
if self._invalid_json_retries < 3:
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
# Don't add anything to messages, just retry the API call

View file

@ -8,7 +8,7 @@
"name": "hermes-whatsapp-bridge",
"version": "1.0.0",
"dependencies": {
"@whiskeysockets/baileys": "7.0.0-rc.9",
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
"express": "^4.21.0",
"pino": "^9.0.0",
"qrcode-terminal": "^0.12.0"
@ -730,21 +730,22 @@
}
},
"node_modules/@whiskeysockets/baileys": {
"name": "baileys",
"version": "7.0.0-rc.9",
"resolved": "https://registry.npmjs.org/@whiskeysockets/baileys/-/baileys-7.0.0-rc.9.tgz",
"integrity": "sha512-YFm5gKXfDP9byCXCW3OPHKXLzrAKzolzgVUlRosHHgwbnf2YOO3XknkMm6J7+F0ns8OA0uuSBhgkRHTDtqkacw==",
"resolved": "git+ssh://git@github.com/WhiskeySockets/Baileys.git#01047debd81beb20da7b7779b08edcb06aa03770",
"hasInstallScript": true,
"license": "MIT",
"dependencies": {
"@cacheable/node-cache": "^1.4.0",
"@hapi/boom": "^9.1.3",
"async-mutex": "^0.5.0",
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node.git",
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node",
"lru-cache": "^11.1.0",
"music-metadata": "^11.7.0",
"p-queue": "^9.0.0",
"pino": "^9.6",
"protobufjs": "^7.2.4",
"whatsapp-rust-bridge": "0.5.2",
"ws": "^8.13.0"
},
"engines": {
@ -2125,6 +2126,12 @@
"node": ">= 0.8"
}
},
"node_modules/whatsapp-rust-bridge": {
"version": "0.5.2",
"resolved": "https://registry.npmjs.org/whatsapp-rust-bridge/-/whatsapp-rust-bridge-0.5.2.tgz",
"integrity": "sha512-6KBRNvxg6WMIwZ/euA8qVzj16qxMBzLllfmaJIP1JGAAfSvwn6nr8JDOMXeqpXPEOl71UfOG+79JwKEoT2b1Fw==",
"license": "MIT"
},
"node_modules/win-guid": {
"version": "0.2.1",
"resolved": "https://registry.npmjs.org/win-guid/-/win-guid-0.2.1.tgz",

View file

@ -8,7 +8,7 @@
"start": "node bridge.js"
},
"dependencies": {
"@whiskeysockets/baileys": "7.0.0-rc.9",
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
"express": "^4.21.0",
"qrcode-terminal": "^0.12.0",
"pino": "^9.0.0"

View file

@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout:
"http://0.0.0.0:5000",
"http://192.168.1.100:8000",
"http://10.0.0.5:1234",
"http://host.docker.internal:11434",
"http://host.containers.internal:11434",
"http://host.lima.internal:11434",
])
def test_local_endpoint_bumps_read_timeout(self, base_url):
"""Local endpoint + default timeout -> bumps to base_timeout."""
@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout:
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
_stream_read_timeout = _base_timeout
assert _stream_read_timeout == 120.0
class TestIsLocalEndpoint:
    """Direct unit tests for is_local_endpoint."""

    # Loopback, wildcard, IPv6 loopback, and RFC 1918 private ranges.
    _CLASSIC_LOCAL_URLS = [
        "http://localhost:11434",
        "http://127.0.0.1:8080",
        "http://0.0.0.0:5000",
        "http://[::1]:11434",
        "http://192.168.1.100:8000",
        "http://10.0.0.5:1234",
        "http://172.17.0.1:11434",
    ]
    # Special DNS names that container runtimes resolve to the host machine.
    _CONTAINER_DNS_URLS = [
        "http://host.docker.internal:11434",
        "http://host.docker.internal:8080/v1",
        "http://gateway.docker.internal:11434",
        "http://host.containers.internal:11434",
        "http://host.lima.internal:11434",
    ]
    # Genuinely remote endpoints, including a lookalike public domain.
    _REMOTE_URLS = [
        "https://api.openai.com",
        "https://openrouter.ai/api",
        "https://api.anthropic.com",
        "https://evil.docker.internal.example.com",
    ]

    @pytest.mark.parametrize("endpoint", _CLASSIC_LOCAL_URLS)
    def test_classic_local_addresses(self, endpoint):
        assert is_local_endpoint(endpoint) is True

    @pytest.mark.parametrize("endpoint", _CONTAINER_DNS_URLS)
    def test_container_dns_names(self, endpoint):
        assert is_local_endpoint(endpoint) is True

    @pytest.mark.parametrize("endpoint", _REMOTE_URLS)
    def test_remote_endpoints(self, endpoint):
        assert is_local_endpoint(endpoint) is False

View file

@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None):
config = PlatformConfig(enabled=True, token="e2e-test-token")
if platform == Platform.DISCORD:
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
from gateway.platforms.helpers import ThreadParticipationTracker
with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
adapter = DiscordAdapter(config)
platform_key = Platform.DISCORD
elif platform == Platform.SLACK:

View file

@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint:
)
assert resp.status == 200
assert "text/event-stream" in resp.headers.get("Content-Type", "")
assert resp.headers.get("X-Accel-Buffering") == "no"
body = await resp.text()
assert "data: " in body
assert "[DONE]" in body
assert "Hello!" in body
@pytest.mark.asyncio
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
    """Idle SSE streams should send keepalive comments while tools run silently."""
    import asyncio
    import gateway.platforms.api_server as api_server_mod

    async def _fake_agent_run(**kwargs):
        # Emit one delta, go quiet long enough for a keepalive, then finish.
        delta_cb = kwargs.get("stream_delta_callback")
        if delta_cb:
            delta_cb("Working")
        await asyncio.sleep(0.65)
        # NOTE(review): unguarded call — assumes the callback is always
        # supplied for stream=True requests; confirm against the caller.
        delta_cb("...done")
        final = {"final_response": "Working...done", "messages": [], "api_calls": 1}
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        return (final, usage)

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as client:
        keepalive_patch = patch.object(
            api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01
        )
        agent_patch = patch.object(adapter, "_run_agent", side_effect=_fake_agent_run)
        with keepalive_patch, agent_patch:
            response = await client.post(
                "/v1/chat/completions",
                json={
                    "model": "test",
                    "messages": [{"role": "user", "content": "do the thing"}],
                    "stream": True,
                },
            )
            assert response.status == 200
            payload = await response.text()
            assert ": keepalive" in payload
            assert "Working" in payload
            assert "...done" in payload
            assert "[DONE]" in payload
@pytest.mark.asyncio
async def test_stream_survives_tool_call_none_sentinel(self, adapter):
"""stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream.

View file

@ -119,28 +119,29 @@ class TestDeduplication:
def test_first_message_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
assert adapter._is_duplicate("msg-1") is False
assert adapter._dedup.is_duplicate("msg-1") is False
def test_second_same_message_is_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._is_duplicate("msg-1")
assert adapter._is_duplicate("msg-1") is True
adapter._dedup.is_duplicate("msg-1")
assert adapter._dedup.is_duplicate("msg-1") is True
def test_different_messages_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._is_duplicate("msg-1")
assert adapter._is_duplicate("msg-2") is False
adapter._dedup.is_duplicate("msg-1")
assert adapter._dedup.is_duplicate("msg-2") is False
def test_cache_cleanup_on_overflow(self):
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
max_size = adapter._dedup._max_size
# Fill beyond max
for i in range(DEDUP_MAX_SIZE + 10):
adapter._is_duplicate(f"msg-{i}")
for i in range(max_size + 10):
adapter._dedup.is_duplicate(f"msg-{i}")
# Cache should have been pruned
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
assert len(adapter._dedup._seen) <= max_size + 10
# ---------------------------------------------------------------------------
@ -253,13 +254,13 @@ class TestConnect:
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._session_webhooks["a"] = "http://x"
adapter._seen_messages["b"] = 1.0
adapter._dedup._seen["b"] = 1.0
adapter._http_client = AsyncMock()
adapter._stream_task = None
await adapter.disconnect()
assert len(adapter._session_webhooks) == 0
assert len(adapter._seen_messages) == 0
assert len(adapter._dedup._seen) == 0
assert adapter._http_client is None

View file

@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
assert ok is False
assert released == [("discord-bot-token", "test-token")]
assert adapter._token_lock_identity is None
assert adapter._platform_lock_identity is None

View file

@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
# Simulate bot having previously participated in thread 456
adapter._bot_participated_threads.add("456")
adapter._threads.mark("456")
thread = FakeThread(channel_id=456, name="existing thread")
message = make_message(channel=thread, content="follow-up without mention")
@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
await adapter._handle_message(message)
assert "555" in adapter._bot_participated_threads
assert "555" in adapter._threads
@pytest.mark.asyncio
@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
await adapter._handle_message(message)
assert "777" in adapter._bot_participated_threads
assert "777" in adapter._threads

View file

@ -1,6 +1,6 @@
"""Tests for Discord thread participation persistence.
Verifies that _bot_participated_threads survives adapter restarts by
Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by
being persisted to ~/.hermes/discord_threads.json.
"""
@ -25,13 +25,13 @@ class TestDiscordThreadPersistence:
def test_starts_empty_when_no_state_file(self, tmp_path):
adapter = self._make_adapter(tmp_path)
assert adapter._bot_participated_threads == set()
assert "$nonexistent" not in adapter._threads
def test_track_thread_persists_to_disk(self, tmp_path):
adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._track_thread("111")
adapter._track_thread("222")
adapter._threads.mark("111")
adapter._threads.mark("222")
state_file = tmp_path / "discord_threads.json"
assert state_file.exists()
@ -42,42 +42,43 @@ class TestDiscordThreadPersistence:
"""Threads tracked by one adapter instance are visible to the next."""
adapter1 = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter1._track_thread("aaa")
adapter1._track_thread("bbb")
adapter1._threads.mark("aaa")
adapter1._threads.mark("bbb")
adapter2 = self._make_adapter(tmp_path)
assert "aaa" in adapter2._bot_participated_threads
assert "bbb" in adapter2._bot_participated_threads
assert "aaa" in adapter2._threads
assert "bbb" in adapter2._threads
def test_duplicate_track_does_not_double_save(self, tmp_path):
adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._track_thread("111")
adapter._track_thread("111") # no-op
adapter._threads.mark("111")
adapter._threads.mark("111") # no-op
saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert saved.count("111") == 1
def test_caps_at_max_tracked_threads(self, tmp_path):
adapter = self._make_adapter(tmp_path)
adapter._MAX_TRACKED_THREADS = 5
adapter._threads._max_tracked = 5
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
for i in range(10):
adapter._track_thread(str(i))
adapter._threads.mark(str(i))
assert len(adapter._bot_participated_threads) == 5
saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert len(saved) == 5
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
state_file = tmp_path / "discord_threads.json"
state_file.write_text("not valid json{{{")
adapter = self._make_adapter(tmp_path)
assert adapter._bot_participated_threads == set()
assert "$nonexistent" not in adapter._threads
def test_missing_hermes_home_does_not_crash(self, tmp_path):
"""Load/save tolerate missing directories."""
fake_home = tmp_path / "nonexistent" / "deep"
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
from gateway.platforms.discord import DiscordAdapter
# _load should return empty set, not crash
threads = DiscordAdapter._load_participated_threads()
assert threads == set()
from gateway.platforms.helpers import ThreadParticipationTracker
# ThreadParticipationTracker should return empty set, not crash
tracker = ThreadParticipationTracker("discord")
assert "$test" not in tracker

View file

@ -195,6 +195,105 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
)
@pytest.mark.asyncio
async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path):
    """Synthetic completion event should carry user_id and user_name from the watcher."""
    import tools.process_registry as pr_module

    # One already-exited fake shell session for the watcher to observe.
    fake_sessions = [
        SimpleNamespace(
            output_buffer="done\n", exited=True, exit_code=0, command="echo test"
        ),
    ]
    monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(fake_sessions))

    # Collapse the watcher's polling delay so the test completes instantly.
    async def _no_sleep(*_args, **_kwargs):
        pass

    monkeypatch.setattr(asyncio, "sleep", _no_sleep)

    gateway = _build_runner(monkeypatch, tmp_path)
    discord_adapter = gateway.adapters[Platform.DISCORD]

    watcher_cfg = _watcher_dict_with_notify()
    watcher_cfg["user_id"] = "user-42"
    watcher_cfg["user_name"] = "alice"

    await gateway._run_process_watcher(watcher_cfg)

    assert discord_adapter.handle_message.await_count == 1
    emitted = discord_adapter.handle_message.await_args.args[0]
    assert emitted.source.user_id == "user-42"
    assert emitted.source.user_name == "alice"
@pytest.mark.asyncio
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
    """A non-internal event with user_id=None should be silently dropped."""
    import gateway.run as gateway_run

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    (tmp_path / "config.yaml").write_text("", encoding="utf-8")

    gw = GatewayRunner(GatewayConfig())
    fake_adapter = SimpleNamespace(send=AsyncMock())
    gw.adapters[Platform.TELEGRAM] = fake_adapter

    anonymous_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="123",
        chat_type="dm",
        user_id=None,
    )
    incoming = MessageEvent(
        text="service message",
        source=anonymous_source,
        internal=False,
    )

    outcome = await gw._handle_message(incoming)

    # Should return None (dropped) and NOT send any pairing message
    assert outcome is None
    assert fake_adapter.send.await_count == 0
@pytest.mark.asyncio
async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path):
    """A message with user_id=None must never call generate_code."""
    import gateway.run as gateway_run

    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    (tmp_path / "config.yaml").write_text("", encoding="utf-8")

    gw = GatewayRunner(GatewayConfig())
    gw.adapters[Platform.DISCORD] = SimpleNamespace(send=AsyncMock())

    # Wrap generate_code so every invocation is recorded before delegating.
    recorded_calls = []
    original_generate = gw.pairing_store.generate_code

    def tracking_generate(*args, **kwargs):
        recorded_calls.append((args, kwargs))
        return original_generate(*args, **kwargs)

    gw.pairing_store.generate_code = tracking_generate

    anonymous_source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="456",
        chat_type="dm",
        user_id=None,
    )
    await gw._handle_message(
        MessageEvent(text="anonymous", source=anonymous_source, internal=False)
    )

    assert not recorded_calls, (
        "Pairing code should NOT be generated for messages with user_id=None"
    )
@pytest.mark.asyncio
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""

View file

@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch):
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
adapter = _make_adapter()
adapter._bot_participated_threads.add("$thread1")
adapter._threads.mark("$thread1")
event = _make_event("hello without mention", thread_id="$thread1")
@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch):
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
adapter = _make_adapter()
adapter._bot_participated_threads.add("$thread_root")
adapter._threads.mark("$thread_root")
event = _make_event("reply in thread", thread_id="$thread_root")
await adapter._on_room_message(event)
@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch):
@pytest.mark.asyncio
async def test_auto_thread_tracks_participation(monkeypatch):
"""Auto-created threads are tracked in _bot_participated_threads."""
"""Auto-created threads are tracked in _threads."""
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false")
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
adapter = _make_adapter()
event = _make_event("hello", event_id="$msg1")
with patch.object(adapter, "_save_participated_threads"):
with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event)
assert "$msg1" in adapter._bot_participated_threads
assert "$msg1" in adapter._threads
# ---------------------------------------------------------------------------
@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch):
class TestThreadPersistence:
def test_empty_state_file(self, tmp_path, monkeypatch):
"""No state file → empty set."""
from gateway.platforms.matrix import MatrixAdapter
from gateway.platforms.helpers import ThreadParticipationTracker
monkeypatch.setattr(
MatrixAdapter, "_thread_state_path",
staticmethod(lambda: tmp_path / "matrix_threads.json"),
ThreadParticipationTracker, "_state_path",
lambda self: tmp_path / "matrix_threads.json",
)
adapter = _make_adapter()
loaded = adapter._load_participated_threads()
assert loaded == set()
assert "$nonexistent" not in adapter._threads
def test_track_thread_persists(self, tmp_path, monkeypatch):
"""_track_thread writes to disk."""
from gateway.platforms.matrix import MatrixAdapter
"""mark() writes to disk."""
from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json"
monkeypatch.setattr(
MatrixAdapter, "_thread_state_path",
staticmethod(lambda: state_path),
ThreadParticipationTracker, "_state_path",
lambda self: state_path,
)
adapter = _make_adapter()
adapter._track_thread("$thread_abc")
adapter._threads.mark("$thread_abc")
data = json.loads(state_path.read_text())
assert "$thread_abc" in data
def test_threads_survive_reload(self, tmp_path, monkeypatch):
"""Persisted threads are loaded by a new adapter instance."""
from gateway.platforms.matrix import MatrixAdapter
from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json"
state_path.write_text(json.dumps(["$t1", "$t2"]))
monkeypatch.setattr(
MatrixAdapter, "_thread_state_path",
staticmethod(lambda: state_path),
ThreadParticipationTracker, "_state_path",
lambda self: state_path,
)
adapter = _make_adapter()
assert "$t1" in adapter._bot_participated_threads
assert "$t2" in adapter._bot_participated_threads
assert "$t1" in adapter._threads
assert "$t2" in adapter._threads
def test_cap_max_tracked_threads(self, tmp_path, monkeypatch):
"""Thread set is trimmed to _MAX_TRACKED_THREADS."""
from gateway.platforms.matrix import MatrixAdapter
"""Thread set is trimmed to max_tracked."""
from gateway.platforms.helpers import ThreadParticipationTracker
state_path = tmp_path / "matrix_threads.json"
monkeypatch.setattr(
MatrixAdapter, "_thread_state_path",
staticmethod(lambda: state_path),
ThreadParticipationTracker, "_state_path",
lambda self: state_path,
)
adapter = _make_adapter()
adapter._MAX_TRACKED_THREADS = 5
adapter._threads._max_tracked = 5
for i in range(10):
adapter._bot_participated_threads.add(f"$t{i}")
adapter._save_participated_threads()
adapter._threads.mark(f"$t{i}")
data = json.loads(state_path.read_text())
assert len(data) == 5
@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch):
_set_dm(adapter)
event = _make_event("@hermes:example.org help me", event_id="$dm1")
with patch.object(adapter, "_save_participated_threads"):
with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event)
adapter.handle_message.assert_awaited_once()
@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
adapter = _make_adapter()
_set_dm(adapter)
adapter._bot_participated_threads.add("$existing_thread")
adapter._threads.mark("$existing_thread")
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
await adapter._on_room_message(event)
@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
@pytest.mark.asyncio
async def test_dm_mention_thread_tracks_participation(monkeypatch):
"""DM mention-thread tracks the thread in _bot_participated_threads."""
"""DM mention-thread tracks the thread in _threads."""
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch):
_set_dm(adapter)
event = _make_event("@hermes:example.org help", event_id="$dm1")
with patch.object(adapter, "_save_participated_threads"):
with patch.object(adapter._threads, "_save"):
await adapter._on_room_message(event)
assert "$dm1" in adapter._bot_participated_threads
assert "$dm1" in adapter._threads
# ---------------------------------------------------------------------------

View file

@ -614,25 +614,27 @@ class TestMattermostDedup:
assert self.adapter.handle_message.call_count == 2
def test_prune_seen_clears_expired(self):
"""_prune_seen should remove entries older than _SEEN_TTL."""
"""Dedup cache should remove entries older than TTL on overflow."""
now = time.time()
dedup = self.adapter._dedup
# Fill with enough expired entries to trigger pruning
for i in range(self.adapter._SEEN_MAX + 10):
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
for i in range(dedup._max_size + 10):
dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL)
# Add a fresh one
self.adapter._seen_posts["fresh"] = now
dedup._seen["fresh"] = now
self.adapter._prune_seen()
# Trigger pruning by calling is_duplicate with a new entry (over max_size)
dedup.is_duplicate("trigger_prune")
# Old entries should be pruned, fresh one kept
assert "fresh" in self.adapter._seen_posts
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
assert "fresh" in dedup._seen
assert len(dedup._seen) < dedup._max_size + 10
def test_seen_cache_tracks_post_ids(self):
"""Posts are tracked in _seen_posts dict."""
self.adapter._seen_posts["test_post"] = time.time()
assert "test_post" in self.adapter._seen_posts
"""Posts are tracked in the dedup cache."""
self.adapter._dedup._seen["test_post"] = time.time()
assert "test_post" in self.adapter._dedup._seen
# ---------------------------------------------------------------------------

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.run import _dequeue_pending_event
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -79,6 +80,26 @@ class TestQueueMessageStorage:
# Should be consumed (cleared)
assert adapter.get_pending_message(session_key) is None
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
"""A queued VOICE event keeps its media_urls/media_types through dequeue.

Regression guard: _dequeue_pending_event must hand back the exact
MessageEvent object (not a text-only copy) and clear the queue slot.
"""
adapter = _StubAdapter()
session_key = "telegram:user:voice"
event = MessageEvent(
text="",
message_type=MessageType.VOICE,
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
message_id="voice-q1",
media_urls=["/tmp/voice.ogg"],
media_types=["audio/ogg"],
)
adapter._pending_messages[session_key] = event
retrieved = _dequeue_pending_event(adapter, session_key)
# Identity check: the same object comes back, so all metadata survives.
assert retrieved is event
assert retrieved.media_urls == ["/tmp/voice.ogg"]
assert retrieved.media_types == ["audio/ogg"]
# Dequeue consumes the entry — a second read must find nothing.
assert adapter.get_pending_message(session_key) is None
def test_queue_does_not_set_interrupt_event(self):
"""The whole point of /queue — no interrupt signal."""
adapter = _StubAdapter()

View file

@ -18,6 +18,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
chat_id="-1001",
chat_name="Group",
chat_type="group",
user_id="123456",
user_name="alice",
thread_id="17585",
)
context = SessionContext(source=source, connected_platforms=[], home_channels={})
@ -25,6 +27,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
tokens = runner._set_session_env(context)
@ -33,6 +37,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group"
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
assert get_session_env("HERMES_SESSION_USER_NAME") == "alice"
assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585"
# os.environ should NOT be touched
@ -50,6 +56,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
source = SessionSource(
@ -57,12 +65,15 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
chat_id="-1001",
chat_name="Group",
chat_type="group",
user_id="123456",
user_name="alice",
thread_id="17585",
)
context = SessionContext(source=source, connected_platforms=[], home_channels={})
tokens = runner._set_session_env(context)
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
runner._clear_session_env(tokens)
@ -70,6 +81,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
assert get_session_env("HERMES_SESSION_CHAT_ID") == ""
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
assert get_session_env("HERMES_SESSION_USER_ID") == ""
assert get_session_env("HERMES_SESSION_USER_NAME") == ""
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""

View file

@ -114,16 +114,16 @@ class TestSignalAdapterInit:
class TestSignalHelpers:
def test_redact_phone_long(self):
from gateway.platforms.signal import _redact_phone
assert _redact_phone("+15551234567") == "+155****4567"
from gateway.platforms.helpers import redact_phone
assert redact_phone("+155****4567") == "+155****4567"
def test_redact_phone_short(self):
from gateway.platforms.signal import _redact_phone
assert _redact_phone("+12345") == "+1****45"
from gateway.platforms.helpers import redact_phone
assert redact_phone("+12345") == "+1****45"
def test_redact_phone_empty(self):
from gateway.platforms.signal import _redact_phone
assert _redact_phone("") == "<none>"
from gateway.platforms.helpers import redact_phone
assert redact_phone("") == "<none>"
def test_parse_comma_list(self):
from gateway.platforms.signal import _parse_comma_list

View file

@ -1,11 +1,14 @@
"""Tests for SMS (Twilio) platform integration.
Covers config loading, format/truncate, echo prevention,
requirements check, and toolset verification.
requirements check, toolset verification, and Twilio signature validation.
"""
import base64
import hashlib
import hmac
import os
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -213,3 +216,335 @@ class TestSmsToolset:
from tools.cronjob_tools import CRONJOB_SCHEMA
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
assert "sms" in deliver_desc.lower()
# ── Webhook host configuration ─────────────────────────────────────
class TestWebhookHostConfig:
"""Verify SMS webhook host/URL configuration read from environment.

Covers the DEFAULT_WEBHOOK_HOST constant, SMS_WEBHOOK_HOST override,
SMS_WEBHOOK_URL pass-through, and whitespace stripping of the URL.
"""
def test_default_host_is_all_interfaces(self):
# Default bind address when SMS_WEBHOOK_HOST is not set.
from gateway.platforms.sms import DEFAULT_WEBHOOK_HOST
assert DEFAULT_WEBHOOK_HOST == "0.0.0.0"
def test_host_from_env(self):
# SMS_WEBHOOK_HOST env var overrides the bind address.
from gateway.platforms.sms import SmsAdapter
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
"SMS_WEBHOOK_HOST": "127.0.0.1",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = SmsAdapter(pc)
assert adapter._webhook_host == "127.0.0.1"
def test_webhook_url_from_env(self):
# SMS_WEBHOOK_URL is stored verbatim on the adapter.
from gateway.platforms.sms import SmsAdapter
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = SmsAdapter(pc)
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
def test_webhook_url_stripped(self):
# Leading/trailing whitespace in the env var must not leak into the URL
# (it would break signature validation, which compares exact URLs).
from gateway.platforms.sms import SmsAdapter
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
"SMS_WEBHOOK_URL": " https://example.com/webhooks/twilio ",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = SmsAdapter(pc)
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
# ── Startup guard (fail-closed) ────────────────────────────────────
class TestStartupGuard:
"""Adapter must refuse to start without SMS_WEBHOOK_URL (fail-closed).

connect() returns False when no webhook URL is configured, unless the
explicit SMS_INSECURE_NO_SIGNATURE escape hatch is set.
"""
def _make_adapter(self, extra_env=None):
# Helper: build an SmsAdapter with minimal Twilio credentials plus
# any extra env vars the individual test needs.
from gateway.platforms.sms import SmsAdapter
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
}
if extra_env:
env.update(extra_env)
with patch.dict(os.environ, env, clear=False):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = SmsAdapter(pc)
return adapter
@pytest.mark.asyncio
async def test_refuses_start_without_webhook_url(self):
# No SMS_WEBHOOK_URL and no insecure flag -> connect() fails closed.
adapter = self._make_adapter()
result = await adapter.connect()
assert result is False
@pytest.mark.asyncio
async def test_insecure_flag_allows_start_without_url(self):
# SMS_INSECURE_NO_SIGNATURE=true explicitly opts out of the guard.
# aiohttp server machinery is mocked so no real socket is opened.
mock_session = AsyncMock()
with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}), \
patch("aiohttp.web.AppRunner") as mock_runner_cls, \
patch("aiohttp.web.TCPSite") as mock_site_cls, \
patch("aiohttp.ClientSession", return_value=mock_session):
mock_runner_cls.return_value.setup = AsyncMock()
mock_runner_cls.return_value.cleanup = AsyncMock()
mock_site_cls.return_value.start = AsyncMock()
adapter = self._make_adapter()
result = await adapter.connect()
assert result is True
await adapter.disconnect()
@pytest.mark.asyncio
async def test_webhook_url_allows_start(self):
# Providing SMS_WEBHOOK_URL satisfies the guard; connect() succeeds.
mock_session = AsyncMock()
with patch("aiohttp.web.AppRunner") as mock_runner_cls, \
patch("aiohttp.web.TCPSite") as mock_site_cls, \
patch("aiohttp.ClientSession", return_value=mock_session):
mock_runner_cls.return_value.setup = AsyncMock()
mock_runner_cls.return_value.cleanup = AsyncMock()
mock_site_cls.return_value.start = AsyncMock()
adapter = self._make_adapter(
extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"}
)
result = await adapter.connect()
assert result is True
await adapter.disconnect()
# ── Twilio signature validation ────────────────────────────────────
def _compute_twilio_signature(auth_token, url, params):
"""Reference implementation of Twilio's signature algorithm."""
data_to_sign = url
for key in sorted(params.keys()):
data_to_sign += key + params[key]
mac = hmac.new(
auth_token.encode("utf-8"),
data_to_sign.encode("utf-8"),
hashlib.sha1,
)
return base64.b64encode(mac.digest()).decode("utf-8")
class TestTwilioSignatureValidation:
"""Unit tests for SmsAdapter._validate_twilio_signature.

Exercises acceptance of correctly signed requests, rejection of bad or
wrongly keyed signatures, key-sort and empty-value handling, and the
default-port (:443 / :80) URL-variant matching behavior.
"""
def _make_adapter(self, auth_token="test_token_secret"):
# Helper: adapter whose validator is keyed with the given auth token.
from gateway.platforms.sms import SmsAdapter
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": auth_token,
"TWILIO_PHONE_NUMBER": "+15550001111",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key=auth_token)
adapter = SmsAdapter(pc)
return adapter
def test_valid_signature_accepted(self):
# Signature computed with the correct token over the same URL/params
# must validate.
adapter = self._make_adapter()
url = "https://example.com/webhooks/twilio"
params = {"From": "+15551234567", "Body": "hello", "To": "+15550001111"}
sig = _compute_twilio_signature("test_token_secret", url, params)
assert adapter._validate_twilio_signature(url, params, sig) is True
def test_invalid_signature_rejected(self):
# A garbage signature string must be rejected.
adapter = self._make_adapter()
url = "https://example.com/webhooks/twilio"
params = {"From": "+15551234567", "Body": "hello"}
assert adapter._validate_twilio_signature(url, params, "badsig") is False
def test_wrong_token_rejected(self):
# Signature computed under a different auth token must not validate.
adapter = self._make_adapter(auth_token="correct_token")
url = "https://example.com/webhooks/twilio"
params = {"From": "+15551234567", "Body": "hello"}
sig = _compute_twilio_signature("wrong_token", url, params)
assert adapter._validate_twilio_signature(url, params, sig) is False
def test_params_sorted_by_key(self):
"""Signature must be computed with params sorted alphabetically."""
adapter = self._make_adapter()
url = "https://example.com/webhooks/twilio"
params = {"Zebra": "last", "Alpha": "first", "Middle": "mid"}
sig = _compute_twilio_signature("test_token_secret", url, params)
assert adapter._validate_twilio_signature(url, params, sig) is True
def test_empty_param_values_included(self):
"""Blank values must be included in signature computation."""
adapter = self._make_adapter()
url = "https://example.com/webhooks/twilio"
params = {"From": "+15551234567", "Body": "", "SmsStatus": "received"}
sig = _compute_twilio_signature("test_token_secret", url, params)
assert adapter._validate_twilio_signature(url, params, sig) is True
def test_url_matters(self):
"""Different URLs produce different signatures."""
adapter = self._make_adapter()
params = {"Body": "hello"}
sig = _compute_twilio_signature(
"test_token_secret", "https://a.com/webhooks/twilio", params
)
assert adapter._validate_twilio_signature(
"https://b.com/webhooks/twilio", params, sig
) is False
def test_port_variant_443_matches_without_port(self):
"""Signature for https URL with :443 validates against URL without port."""
adapter = self._make_adapter()
params = {"From": "+15551234567", "Body": "hello"}
sig = _compute_twilio_signature(
"test_token_secret", "https://example.com:443/webhooks/twilio", params
)
assert adapter._validate_twilio_signature(
"https://example.com/webhooks/twilio", params, sig
) is True
def test_port_variant_without_port_matches_443(self):
"""Signature for https URL without port validates against URL with :443."""
adapter = self._make_adapter()
params = {"From": "+15551234567", "Body": "hello"}
sig = _compute_twilio_signature(
"test_token_secret", "https://example.com/webhooks/twilio", params
)
assert adapter._validate_twilio_signature(
"https://example.com:443/webhooks/twilio", params, sig
) is True
def test_non_standard_port_no_variant(self):
"""Non-standard port must NOT match URL without port."""
adapter = self._make_adapter()
params = {"From": "+15551234567", "Body": "hello"}
sig = _compute_twilio_signature(
"test_token_secret", "https://example.com/webhooks/twilio", params
)
assert adapter._validate_twilio_signature(
"https://example.com:8080/webhooks/twilio", params, sig
) is False
def test_port_variant_http_80(self):
"""Port variant also works for http with port 80."""
adapter = self._make_adapter()
params = {"From": "+15551234567", "Body": "hello"}
sig = _compute_twilio_signature(
"test_token_secret", "http://example.com:80/webhooks/twilio", params
)
assert adapter._validate_twilio_signature(
"http://example.com/webhooks/twilio", params, sig
) is True
# ── Webhook signature enforcement (handler-level) ──────────────────
class TestWebhookSignatureEnforcement:
"""Integration tests for signature validation in _handle_webhook.

Drives the webhook handler with mocked aiohttp requests and asserts the
HTTP status: 403 for missing/invalid signatures, 200 for valid ones or
when validation is explicitly disabled.
"""
def _make_adapter(self, webhook_url=""):
# Helper: adapter wired with a mocked message handler so inbound
# messages are swallowed instead of dispatched.
from gateway.platforms.sms import SmsAdapter
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "test_token_secret",
"TWILIO_PHONE_NUMBER": "+15550001111",
"SMS_WEBHOOK_URL": webhook_url,
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="test_token_secret")
adapter = SmsAdapter(pc)
adapter._message_handler = AsyncMock()
return adapter
def _mock_request(self, body, headers=None):
# Helper: fake aiohttp request exposing only read() and headers.
request = MagicMock()
request.read = AsyncMock(return_value=body)
request.headers = headers or {}
return request
@pytest.mark.asyncio
async def test_insecure_flag_skips_validation(self):
"""With SMS_INSECURE_NO_SIGNATURE=true and no URL, requests are accepted."""
env = {"SMS_INSECURE_NO_SIGNATURE": "true"}
with patch.dict(os.environ, env):
adapter = self._make_adapter(webhook_url="")
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
request = self._mock_request(body)
resp = await adapter._handle_webhook(request)
assert resp.status == 200
@pytest.mark.asyncio
async def test_insecure_flag_with_url_still_validates(self):
"""When both SMS_WEBHOOK_URL and SMS_INSECURE_NO_SIGNATURE are set,
validation stays active (URL takes precedence)."""
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
request = self._mock_request(body, headers={})
resp = await adapter._handle_webhook(request)
assert resp.status == 403
@pytest.mark.asyncio
async def test_missing_signature_returns_403(self):
# No X-Twilio-Signature header at all -> reject.
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
request = self._mock_request(body, headers={})
resp = await adapter._handle_webhook(request)
assert resp.status == 403
@pytest.mark.asyncio
async def test_invalid_signature_returns_403(self):
# A present but bogus signature header -> reject.
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
request = self._mock_request(body, headers={"X-Twilio-Signature": "invalid"})
resp = await adapter._handle_webhook(request)
assert resp.status == 403
@pytest.mark.asyncio
async def test_valid_signature_returns_200(self):
# Signature computed over the configured URL + decoded form params
# must be accepted. Note: params here are the URL-decoded values of
# the form-encoded body below.
webhook_url = "https://example.com/webhooks/twilio"
adapter = self._make_adapter(webhook_url=webhook_url)
params = {
"From": "+15551234567",
"To": "+15550001111",
"Body": "hello",
"MessageSid": "SM123",
}
sig = _compute_twilio_signature("test_token_secret", webhook_url, params)
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
resp = await adapter._handle_webhook(request)
assert resp.status == 200
@pytest.mark.asyncio
async def test_port_variant_signature_returns_200(self):
"""Signature computed with :443 should pass when URL configured without port."""
webhook_url = "https://example.com/webhooks/twilio"
adapter = self._make_adapter(webhook_url=webhook_url)
params = {
"From": "+15551234567",
"To": "+15550001111",
"Body": "hello",
"MessageSid": "SM123",
}
sig = _compute_twilio_signature(
"test_token_secret", "https://example.com:443/webhooks/twilio", params
)
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
resp = await adapter._handle_webhook(request)
assert resp.status == 200

View file

@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
import pytest
import yaml
from gateway.config import GatewayConfig, load_gateway_config
from gateway.config import GatewayConfig, Platform, load_gateway_config
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource
def test_gateway_config_stt_disabled_from_dict_nested():
@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
assert "No STT provider is configured" not in result
assert "trouble transcribing" in result
assert "caption" in result
@pytest.mark.asyncio
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
"""A queued VOICE event with media but no text gets transcribed.

Builds a GatewayRunner without running __init__ (via __new__), patches
the STT entry point, and checks the transcript plus a voice-message
framing phrase appear in the prepared inbound text.
"""
from gateway.run import GatewayRunner
# Bypass __init__ so no adapters/platform connections are created;
# set just the attributes _prepare_inbound_message_text reads.
runner = GatewayRunner.__new__(GatewayRunner)
runner.config = GatewayConfig(stt_enabled=True)
runner.adapters = {}
runner._model = "test-model"
runner._base_url = ""
runner._has_setup_skill = lambda: False
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="123",
chat_type="dm",
)
# Empty text + VOICE type + local media path: the shape a queued voice
# message has when dequeued.
event = MessageEvent(
text="",
message_type=MessageType.VOICE,
source=source,
media_urls=["/tmp/queued-voice.ogg"],
media_types=["audio/ogg"],
)
with patch(
"tools.transcription_tools.transcribe_audio",
return_value={
"success": True,
"transcript": "queued voice transcript",
"provider": "local_command",
},
):
result = await runner._prepare_inbound_message_text(
event=event,
source=source,
history=[],
)
assert result is not None
assert "queued voice transcript" in result
assert "voice message" in result.lower()

View file

@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch):
async def _noop():
return []
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
# Mock HTTPXRequest so the builder chain doesn't fail
monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock())
@pytest.mark.asyncio
@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch):
ok = await adapter.connect()
assert ok is False
assert adapter.fatal_error_code == "telegram_token_lock"
assert adapter.fatal_error_code == "telegram-bot-token_lock"
assert adapter.has_fatal_error is True
assert "already using this Telegram bot token" in adapter.fatal_error_message
assert "already in use" in adapter.fatal_error_message
@pytest.mark.asyncio
@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch):
)
builder = MagicMock()
builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
builder.build.return_value = app
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
)
builder = MagicMock()
builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
builder.build.return_value = app
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m
builder = MagicMock()
builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
app = SimpleNamespace(
bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()),
updater=SimpleNamespace(),
@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch):
)
builder = MagicMock()
builder.token.return_value = builder
builder.request.return_value = builder
builder.get_updates_request.return_value = builder
builder.build.return_value = app
monkeypatch.setattr(
"gateway.platforms.telegram.Application",

View file

@ -1,12 +1,14 @@
"""Tests for the Weixin platform adapter."""
import asyncio
import json
import os
from unittest.mock import AsyncMock, patch
from gateway.config import PlatformConfig
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
from gateway.platforms.weixin import WeixinAdapter
from gateway.platforms import weixin
from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter
from tools.send_message_tool import _parse_target_ref, _send_to_platform
@ -187,6 +189,70 @@ class TestWeixinConfig:
assert config.get_connected_platforms() == []
class TestWeixinStatePersistence:
"""Atomic-write guarantees for Weixin state files.

Each test forces os.replace (the atomic rename used by the save
helpers) to fail and asserts the pre-existing file is left intact —
i.e. a failed write never clobbers good state.
"""
def test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
account_path = tmp_path / "weixin" / "accounts" / "acct.json"
account_path.parent.mkdir(parents=True, exist_ok=True)
original = {"token": "old-token", "base_url": "https://old.example.com"}
account_path.write_text(json.dumps(original), encoding="utf-8")
def _boom(_src, _dst):
raise OSError("disk full")
# Break the final atomic rename inside the utils write helper.
monkeypatch.setattr("utils.os.replace", _boom)
try:
weixin.save_weixin_account(
str(tmp_path),
account_id="acct",
token="new-token",
base_url="https://new.example.com",
user_id="wxid_new",
)
except OSError:
pass
else:
raise AssertionError("expected save_weixin_account to propagate replace failure")
# The original account file must be untouched.
assert json.loads(account_path.read_text(encoding="utf-8")) == original
def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json"
token_path.parent.mkdir(parents=True, exist_ok=True)
token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8")
def _boom(_src, _dst):
raise OSError("disk full")
monkeypatch.setattr("utils.os.replace", _boom)
store = ContextTokenStore(str(tmp_path))
# Unlike the account save, the token store logs a warning instead of
# raising when persistence fails.
with patch.object(weixin.logger, "warning") as warning_mock:
store.set("acct", "user-b", "new-token")
assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"}
warning_mock.assert_called_once()
def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json"
sync_path.parent.mkdir(parents=True, exist_ok=True)
sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8")
def _boom(_src, _dst):
raise OSError("disk full")
monkeypatch.setattr("utils.os.replace", _boom)
try:
weixin._save_sync_buf(str(tmp_path), "acct", "new-sync")
except OSError:
pass
else:
raise AssertionError("expected _save_sync_buf to propagate replace failure")
assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"}
class TestWeixinSendMessageIntegration:
def test_parse_target_ref_accepts_weixin_ids(self):
assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)
@ -217,6 +283,55 @@ class TestWeixinSendMessageIntegration:
)
class TestWeixinChunkDelivery:
"""Chunked message delivery over the Weixin adapter.

Verifies inter-chunk pacing (asyncio.sleep between sends) and that a
transiently failing chunk is retried with the same client_id so the
remote side can deduplicate.
"""
def _connected_adapter(self) -> WeixinAdapter:
# Helper: adapter made to look connected without any real session.
adapter = _make_adapter()
adapter._session = object()
adapter._token = "test-token"
adapter._base_url = "https://weixin.example.com"
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
return adapter
@patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock)
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
def test_send_waits_between_multiple_chunks(self, send_message_mock, sleep_mock):
adapter = self._connected_adapter()
# Tiny limit forces the message to split into multiple chunks.
adapter.MAX_MESSAGE_LENGTH = 12
# Use double newlines so _pack_markdown_blocks splits into 3 blocks
result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird"))
assert result.success is True
assert send_message_mock.await_count == 3
# N chunks -> N-1 pauses between them.
assert sleep_mock.await_count == 2
@patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock)
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
def test_send_retries_failed_chunk_before_continuing(self, send_message_mock, sleep_mock):
adapter = self._connected_adapter()
adapter.MAX_MESSAGE_LENGTH = 12
calls = {"count": 0}
async def flaky_send(*args, **kwargs):
calls["count"] += 1
if calls["count"] == 2:
raise RuntimeError("temporary iLink failure")
send_message_mock.side_effect = flaky_send
# Use double newlines so _pack_markdown_blocks splits into 3 blocks
result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird"))
assert result.success is True
# 3 chunks, but chunk 2 fails once and retries → 4 _send_message calls total
assert send_message_mock.await_count == 4
# The retried chunk should reuse the same client_id for deduplication
first_try = send_message_mock.await_args_list[1].kwargs
retry = send_message_mock.await_args_list[2].kwargs
assert first_try["text"] == retry["text"]
assert first_try["client_id"] == retry["client_id"]
class TestWeixinRemoteMediaSafety:
def test_download_remote_media_blocks_unsafe_urls(self):
adapter = _make_adapter()

View file

@ -260,7 +260,7 @@ class TestWaitForGatewayExit:
def test_kill_gateway_processes_force_uses_helper(self, monkeypatch):
calls = []
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22])
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22])
monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force)))
killed = gateway.kill_gateway_processes(force=True)

View file

@ -1,6 +1,7 @@
"""Tests for gateway service management helpers."""
import os
import pwd
from pathlib import Path
from types import SimpleNamespace
@ -129,7 +130,7 @@ class TestGatewayStopCleanup:
monkeypatch.setattr(
gateway_cli,
"kill_gateway_processes",
lambda force=False: kill_calls.append(force) or 2,
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
)
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop"))
@ -155,7 +156,7 @@ class TestGatewayStopCleanup:
monkeypatch.setattr(
gateway_cli,
"kill_gateway_processes",
lambda force=False: kill_calls.append(force) or 2,
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
)
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True}))
@ -924,6 +925,23 @@ class TestProfileArg:
assert "<string>--profile</string>" in plist
assert "<string>mybot</string>" in plist
def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch):
"""launchd plist path must be rooted at the machine user's home.

Even when a Hermes profile overrides Path.home() and HERMES_HOME,
get_launchd_plist_path() should resolve via pwd.getpwuid (the real
account home), and the filename should carry the profile suffix.
"""
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
profile_dir.mkdir(parents=True)
machine_home = tmp_path / "machine-home"
machine_home.mkdir()
profile_home = profile_dir / "home"
profile_home.mkdir()
# Simulate a profile that redirects the apparent home directory.
monkeypatch.setattr(Path, "home", lambda: profile_home)
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
# The real account home comes from the passwd database, not Path.home().
monkeypatch.setattr(pwd, "getpwuid", lambda uid: SimpleNamespace(pw_dir=str(machine_home)))
plist_path = gateway_cli.get_launchd_plist_path()
assert plist_path == machine_home / "Library" / "LaunchAgents" / "ai.hermes.gateway-orcha.plist"
class TestRemapPathForUser:
"""Unit tests for _remap_path_for_user()."""

View file

@ -1214,3 +1214,115 @@ def test_openrouter_provider_not_affected_by_custom_fix(monkeypatch):
resolved = rp.resolve_runtime_provider(requested="openrouter")
assert resolved["provider"] == "openrouter"
# ------------------------------------------------------------------
# fix #7828 — custom_providers model field must propagate to runtime
# ------------------------------------------------------------------
def test_get_named_custom_provider_includes_model(monkeypatch):
"""_get_named_custom_provider should include the model field from config."""
# Stub config loading so the lookup sees exactly one custom provider
# entry carrying an explicit model.
monkeypatch.setattr(rp, "load_config", lambda: {
"custom_providers": [{
"name": "my-dashscope",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"api_key": "test-key",
"api_mode": "chat_completions",
"model": "qwen3.6-plus",
}],
})
result = rp._get_named_custom_provider("my-dashscope")
assert result is not None
assert result["model"] == "qwen3.6-plus"
def test_get_named_custom_provider_excludes_empty_model(monkeypatch):
"""Empty or whitespace-only model field should not appear in result."""
# None means "model key absent entirely"; "" and " " are present-but-blank.
for model_val in ["", " ", None]:
entry = {
"name": "test-ep",
"base_url": "https://example.com/v1",
"api_key": "key",
}
if model_val is not None:
entry["model"] = model_val
# Default arg e=entry binds the current dict (avoids the
# late-binding-closure pitfall inside the loop).
monkeypatch.setattr(rp, "load_config", lambda e=entry: {
"custom_providers": [e],
})
result = rp._get_named_custom_provider("test-ep")
assert result is not None
assert "model" not in result, (
f"model field {model_val!r} should not be included in result"
)
def test_named_custom_runtime_propagates_model_direct_path(monkeypatch):
"""Model should propagate through the direct (non-pool) resolution path."""
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
monkeypatch.setattr(
rp, "_get_named_custom_provider",
lambda p: {
"name": "my-server",
"base_url": "http://localhost:8000/v1",
"api_key": "test-key",
"model": "qwen3.6-plus",
},
)
# Ensure pool doesn't intercept
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
resolved = rp.resolve_runtime_provider(requested="my-server")
# The configured model and the "custom" provider tag must both survive.
assert resolved["model"] == "qwen3.6-plus"
assert resolved["provider"] == "custom"
def test_named_custom_runtime_propagates_model_pool_path(monkeypatch):
"""Model should propagate even when credential pool handles credentials."""
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
monkeypatch.setattr(
rp, "_get_named_custom_provider",
lambda p: {
"name": "my-server",
"base_url": "http://localhost:8000/v1",
"api_key": "test-key",
"model": "qwen3.6-plus",
},
)
# Pool returns a result (intercepting the normal path)
monkeypatch.setattr(
rp, "_try_resolve_from_custom_pool",
lambda *a, **k: {
"provider": "custom",
"api_mode": "chat_completions",
"base_url": "http://localhost:8000/v1",
"api_key": "pool-key",
"source": "pool:custom:my-server",
},
)
resolved = rp.resolve_runtime_provider(requested="my-server")
# The pool result has no model field of its own; the config's model
# must be merged into it, while the pool's credentials win.
assert resolved["model"] == "qwen3.6-plus", (
"model must be injected into pool result"
)
assert resolved["api_key"] == "pool-key", "pool credentials should be used"
def test_named_custom_runtime_no_model_when_absent(monkeypatch):
"""When custom_providers entry has no model field, runtime should not either."""
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
# Provider entry deliberately omits "model".
monkeypatch.setattr(
rp, "_get_named_custom_provider",
lambda p: {
"name": "my-server",
"base_url": "http://localhost:8000/v1",
"api_key": "test-key",
},
)
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
resolved = rp.resolve_runtime_provider(requested="my-server")
# No phantom/default model may be invented by the resolver.
assert "model" not in resolved

View file

@ -191,6 +191,19 @@ class TestLaunchdPlistPath:
raise AssertionError("PATH key not found in plist")
class TestLaunchdPlistCurrentness:
    """launchd_plist_is_current() should tolerate environment drift."""

    def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch):
        """A plist generated under one PATH is still 'current' under another."""
        target = tmp_path / "ai.hermes.gateway.plist"
        monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: target)
        # Generate the plist while PATH points at one set of directories...
        monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin")
        target.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8")
        # ...then verify currentness under a completely different PATH.
        monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin")
        assert gateway_cli.launchd_plist_is_current() is True
# ---------------------------------------------------------------------------
# cmd_update — macOS launchd detection
# ---------------------------------------------------------------------------
@ -536,7 +549,7 @@ class TestServicePidExclusion:
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
), patch.object(
gateway_cli, "find_gateway_pids",
side_effect=lambda exclude_pids=None: (
side_effect=lambda exclude_pids=None, all_profiles=False: (
[SERVICE_PID] if not exclude_pids else
[p for p in [SERVICE_PID] if p not in exclude_pids]
),
@ -579,7 +592,7 @@ class TestServicePidExclusion:
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
), patch.object(
gateway_cli, "find_gateway_pids",
side_effect=lambda exclude_pids=None: (
side_effect=lambda exclude_pids=None, all_profiles=False: (
[SERVICE_PID] if not exclude_pids else
[p for p in [SERVICE_PID] if p not in exclude_pids]
),
@ -618,7 +631,7 @@ class TestServicePidExclusion:
launchctl_loaded=True,
)
def fake_find(exclude_pids=None):
def fake_find(exclude_pids=None, all_profiles=False):
_exclude = exclude_pids or set()
return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude]
@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude:
pids = gateway_cli.find_gateway_pids()
assert 100 in pids
assert 200 in pids
def test_filters_to_current_profile(self, monkeypatch, tmp_path):
    """find_gateway_pids() keeps only PIDs whose cmdline matches our --profile arg."""
    profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
    profile_dir.mkdir(parents=True)
    monkeypatch.setattr(gateway_cli, "is_windows", lambda: False)
    monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)

    ps_output = (
        "100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n"
        "200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n"
    )

    def fake_run(cmd, **kwargs):
        # Pretend `ps` reported one running gateway per profile.
        return subprocess.CompletedProcess(cmd, 0, stdout=ps_output, stderr="")

    monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
    monkeypatch.setattr("os.getpid", lambda: 999)
    monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set())
    monkeypatch.setattr(
        gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha"
    )

    assert gateway_cli.find_gateway_pids() == [100]

View file

@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase):
def tearDown(self):
set_interrupt(False)
def _make_bare_agent(self):
    """Create a bare AIAgent via __new__ with all interrupt-related attrs."""
    from run_agent import AIAgent

    bare = AIAgent.__new__(AIAgent)
    # Minimal attribute set that the interrupt machinery touches.
    defaults = {
        "_interrupt_requested": False,
        "_interrupt_message": None,
        # None defaults to the current thread in set_interrupt
        "_execution_thread_id": None,
        "_active_children": [],
        "quiet_mode": True,
    }
    for attr, value in defaults.items():
        setattr(bare, attr, value)
    bare._active_children_lock = threading.Lock()
    return bare
def test_parent_interrupt_sets_child_flag(self):
"""When parent.interrupt() is called, child._interrupt_requested should be set."""
from run_agent import AIAgent
parent = AIAgent.__new__(AIAgent)
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent = self._make_bare_agent()
child = self._make_bare_agent()
parent._active_children.append(child)
@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase):
assert child._interrupt_message == "new user message"
assert is_interrupted() is True
def test_child_clear_interrupt_at_start_clears_global(self):
"""child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
This is the intended behavior at startup, but verify it doesn't
accidentally clear an interrupt intended for a running child.
def test_child_clear_interrupt_at_start_clears_thread(self):
"""child.clear_interrupt() at start of run_conversation clears the
per-thread interrupt flag for the current thread.
"""
from run_agent import AIAgent
child = AIAgent.__new__(AIAgent)
child = self._make_bare_agent()
child._interrupt_requested = True
child._interrupt_message = "msg"
child.quiet_mode = True
child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set
# Interrupt for current thread is set
set_interrupt(True)
assert is_interrupted() is True
# child.clear_interrupt() clears both
# child.clear_interrupt() clears both instance flag and thread flag
child.clear_interrupt()
assert child._interrupt_requested is False
assert is_interrupted() is False
def test_interrupt_during_child_api_call_detected(self):
"""Interrupt set during _interruptible_api_call is detected within 0.5s."""
from run_agent import AIAgent
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
child = self._make_bare_agent()
child.api_mode = "chat_completions"
child.log_prefix = ""
child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
def test_concurrent_interrupt_propagation(self):
"""Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
from run_agent import AIAgent
parent = AIAgent.__new__(AIAgent)
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent = self._make_bare_agent()
child = self._make_bare_agent()
# Register child (simulating what _run_single_child does)
parent._active_children.append(child)
@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase):
set_interrupt(False)
class TestPerThreadInterruptIsolation(unittest.TestCase):
    """Verify that interrupting one agent does NOT affect another agent's thread.

    This is the core fix for the gateway cross-session interrupt leak:
    multiple agents run in separate threads within the same process, and
    interrupting agent A must not kill agent B's running tools.
    """

    def setUp(self):
        # Start each test with a clean interrupt state for the current thread.
        set_interrupt(False)

    def tearDown(self):
        set_interrupt(False)

    def test_interrupt_only_affects_target_thread(self):
        """set_interrupt(True, tid) only makes is_interrupted() True on that thread."""
        results = {}
        barrier = threading.Barrier(2)

        def thread_a():
            """Agent A's execution thread — will be interrupted."""
            tid = threading.current_thread().ident
            results["a_tid"] = tid
            barrier.wait(timeout=5)  # sync with thread B
            time.sleep(0.2)  # let the interrupt arrive
            results["a_interrupted"] = is_interrupted()

        def thread_b():
            """Agent B's execution thread — should NOT be affected."""
            tid = threading.current_thread().ident
            results["b_tid"] = tid
            barrier.wait(timeout=5)  # sync with thread A
            time.sleep(0.2)
            results["b_interrupted"] = is_interrupted()

        ta = threading.Thread(target=thread_a)
        tb = threading.Thread(target=thread_b)
        ta.start()
        tb.start()
        # Wait for both threads to register their TIDs
        # (busy-wait keeps the test deterministic without extra sync primitives).
        time.sleep(0.05)
        while "a_tid" not in results or "b_tid" not in results:
            time.sleep(0.01)
        # Interrupt ONLY thread A (simulates gateway interrupting agent A)
        set_interrupt(True, results["a_tid"])
        ta.join(timeout=3)
        tb.join(timeout=3)
        assert results["a_interrupted"] is True, "Thread A should see the interrupt"
        assert results["b_interrupted"] is False, "Thread B must NOT see thread A's interrupt"

    def test_clear_interrupt_only_clears_target_thread(self):
        """Clearing one thread's interrupt doesn't clear another's."""
        # Fixed fake TIDs — no real threads needed to exercise the bookkeeping.
        tid_a = 99990001
        tid_b = 99990002
        set_interrupt(True, tid_a)
        set_interrupt(True, tid_b)
        # Clear only A
        set_interrupt(False, tid_a)
        # Simulate checking from thread B's perspective
        # (reach into the module's private per-thread state to assert it directly).
        from tools.interrupt import _interrupted_threads, _lock
        with _lock:
            assert tid_a not in _interrupted_threads
            assert tid_b in _interrupted_threads
        # Cleanup
        set_interrupt(False, tid_b)
if __name__ == "__main__":
unittest.main()

View file

@ -2087,8 +2087,9 @@ class TestRunConversation:
assert "Thinking Budget Exhausted" in result["final_response"]
assert "/thinkon" in result["final_response"]
def test_length_empty_content_detected_as_thinking_exhausted(self, agent):
"""When finish_reason='length' and content is None/empty, detect exhaustion."""
def test_length_empty_content_without_think_tags_retries_normally(self, agent):
"""When finish_reason='length' and content is None but no think tags,
fall through to normal continuation retry (not thinking-exhaustion)."""
self._setup_agent(agent)
resp = _mock_response(content=None, finish_reason="length")
agent.client.chat.completions.create.return_value = resp
@ -2100,12 +2101,10 @@ class TestRunConversation:
):
result = agent.run_conversation("hello")
# Without think tags, the agent should attempt continuation retries
# (up to 3), not immediately fire thinking-exhaustion.
assert result["api_calls"] == 3
assert result["completed"] is False
assert result["api_calls"] == 1
assert "reasoning" in result["error"].lower()
# User-friendly message is returned
assert result["final_response"] is not None
assert "Thinking Budget Exhausted" in result["final_response"]
def test_length_with_tool_calls_returns_partial_without_executing_tools(self, agent):
self._setup_agent(agent)
@ -2169,6 +2168,35 @@ class TestRunConversation:
mock_hfc.assert_called_once()
assert result["final_response"] == "Done!"
def test_truncated_tool_args_detected_when_finish_reason_not_length(self, agent):
    """When a router rewrites finish_reason from 'length' to 'tool_calls',
    truncated JSON arguments should still be detected and refused rather
    than wasting 3 retry attempts."""
    self._setup_agent(agent)
    agent.valid_tool_names.add("write_file")
    # Arguments are cut off mid-string: invalid JSON that signals truncation.
    truncated_call = _mock_tool_call(
        name="write_file",
        arguments='{"path":"report.md","content":"partial',
        call_id="c1",
    )
    response = _mock_response(
        content="", finish_reason="tool_calls", tool_calls=[truncated_call],
    )
    agent.client.chat.completions.create.return_value = response

    with (
        patch("run_agent.handle_function_call") as handler_mock,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("write the report")

    assert result["completed"] is False
    assert result["partial"] is True
    assert "truncated due to output length limit" in result["error"]
    # The malformed call must never reach the tool dispatcher.
    handler_mock.assert_not_called()
class TestRetryExhaustion:
"""Regression: retry_count > max_retries was dead code (off-by-one).

View file

@ -1104,3 +1104,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch
]
assert "enc_first" in encrypted_contents
assert "enc_second" in encrypted_contents
def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch):
    """Duplicate reasoning item IDs across multi-turn incomplete responses
    must be deduplicated so the Responses API doesn't reject with HTTP 400."""
    agent = _build_agent(monkeypatch)
    first_turn_items = [
        {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
        {"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"},
    ]
    second_turn_items = [
        # rs_aaa is duplicated from the previous turn
        {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
        {"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"},
    ]
    history = [
        {"role": "user", "content": "think hard"},
        {"role": "assistant", "content": "", "codex_reasoning_items": first_turn_items},
        {
            "role": "assistant",
            "content": "partial answer",
            "codex_reasoning_items": second_turn_items,
        },
    ]

    items = agent._chat_messages_to_responses_input(history)

    reasoning_ids = [item["id"] for item in items if item.get("type") == "reasoning"]
    # Each ID survives exactly once; the first occurrence of rs_aaa is kept.
    assert len(reasoning_ids) == 3
    for rid in ("rs_aaa", "rs_bbb", "rs_ccc"):
        assert reasoning_ids.count(rid) == 1
def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
    """_preflight_codex_input_items should also deduplicate reasoning items by ID."""
    agent = _build_agent(monkeypatch)
    raw_input = [
        {"role": "user", "content": [{"type": "input_text", "text": "hello"}]},
        {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
        {"role": "assistant", "content": "ok"},
        {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
        {"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"},
        {"role": "assistant", "content": "done"},
    ]

    normalized = agent._preflight_codex_input_items(raw_input)

    # The duplicated rs_xyz collapses to one entry; rs_zzz is untouched.
    surviving_ids = sorted(
        item["id"] for item in normalized if item.get("type") == "reasoning"
    )
    assert surviving_ids == ["rs_xyz", "rs_zzz"]

View file

@ -0,0 +1,158 @@
"""Tests for _reap_orphaned_browser_sessions() — kills orphaned agent-browser
daemons whose Python parent exited without cleaning up."""
import os
import signal
import textwrap
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
@pytest.fixture
def fake_tmpdir(tmp_path):
    """Patch _socket_safe_tmpdir to return a temp dir we control."""
    patcher = patch(
        "tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)
    )
    patcher.start()
    try:
        yield tmp_path
    finally:
        patcher.stop()
@pytest.fixture(autouse=True)
def _isolate_sessions():
    """Ensure _active_sessions is empty for each test."""
    import tools.browser_tool as bt

    saved = dict(bt._active_sessions)
    bt._active_sessions.clear()
    try:
        yield
    finally:
        # Restore whatever the process had before the test ran.
        bt._active_sessions.clear()
        bt._active_sessions.update(saved)
def _make_socket_dir(tmpdir, session_name, pid=None):
"""Create a fake agent-browser socket directory with optional PID file."""
d = tmpdir / f"agent-browser-{session_name}"
d.mkdir()
if pid is not None:
(d / f"{session_name}.pid").write_text(str(pid))
return d
class TestReapOrphanedBrowserSessions:
    """Tests for the orphan reaper function."""

    def test_no_socket_dirs_is_noop(self, fake_tmpdir):
        """No socket dirs => nothing happens, no errors."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        _reap_orphaned_browser_sessions()  # should not raise

    def test_stale_dir_without_pid_file_is_removed(self, fake_tmpdir):
        """Socket dir with no PID file is cleaned up."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        d = _make_socket_dir(fake_tmpdir, "h_abc1234567")
        assert d.exists()
        _reap_orphaned_browser_sessions()
        assert not d.exists()

    def test_stale_dir_with_dead_pid_is_removed(self, fake_tmpdir):
        """Socket dir whose daemon PID is dead gets cleaned up."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        # PID chosen far above any plausible live process id.
        d = _make_socket_dir(fake_tmpdir, "h_dead123456", pid=999999999)
        assert d.exists()
        _reap_orphaned_browser_sessions()
        assert not d.exists()

    def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir):
        """Alive daemon not tracked by _active_sessions gets SIGTERM."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345)
        kill_calls = []

        # Fix: dropped the unused `original_kill = os.kill` binding — the
        # mock never delegated to it, so it was dead code.
        def mock_kill(pid, sig):
            kill_calls.append((pid, sig))
            if sig == 0:
                return  # pretend process exists
            # Don't actually kill anything

        with patch("os.kill", side_effect=mock_kill):
            _reap_orphaned_browser_sessions()
        # Should have checked existence (sig 0) then killed (SIGTERM)
        assert (12345, 0) in kill_calls
        assert (12345, signal.SIGTERM) in kill_calls

    def test_tracked_session_is_not_reaped(self, fake_tmpdir):
        """Sessions tracked in _active_sessions are left alone."""
        import tools.browser_tool as bt
        from tools.browser_tool import _reap_orphaned_browser_sessions
        session_name = "h_tracked1234"
        d = _make_socket_dir(fake_tmpdir, session_name, pid=12345)
        # Register the session as actively tracked
        bt._active_sessions["some_task"] = {"session_name": session_name}
        kill_calls = []

        def mock_kill(pid, sig):
            kill_calls.append((pid, sig))

        with patch("os.kill", side_effect=mock_kill):
            _reap_orphaned_browser_sessions()
        # Should NOT have tried to kill anything
        assert len(kill_calls) == 0
        # Dir should still exist
        assert d.exists()

    def test_permission_error_on_kill_check_skips(self, fake_tmpdir):
        """If we can't check the PID (PermissionError), skip it."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        d = _make_socket_dir(fake_tmpdir, "h_perm1234567", pid=12345)

        def mock_kill(pid, sig):
            if sig == 0:
                raise PermissionError("not our process")

        with patch("os.kill", side_effect=mock_kill):
            _reap_orphaned_browser_sessions()
        # Dir should still exist (we didn't touch someone else's process)
        assert d.exists()

    def test_cdp_sessions_are_also_reaped(self, fake_tmpdir):
        """CDP sessions (cdp_ prefix) are also scanned."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        d = _make_socket_dir(fake_tmpdir, "cdp_abc1234567")
        assert d.exists()
        _reap_orphaned_browser_sessions()
        # No PID file → cleaned up
        assert not d.exists()

    def test_non_hermes_dirs_are_ignored(self, fake_tmpdir):
        """Socket dirs that don't match our naming pattern are left alone."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        # Create a dir that doesn't match h_* or cdp_* pattern
        d = fake_tmpdir / "agent-browser-other_session"
        d.mkdir()
        (d / "other_session.pid").write_text("12345")
        _reap_orphaned_browser_sessions()
        # Should NOT be touched
        assert d.exists()

    def test_corrupt_pid_file_is_cleaned(self, fake_tmpdir):
        """PID file with non-integer content is cleaned up."""
        from tools.browser_tool import _reap_orphaned_browser_sessions
        d = _make_socket_dir(fake_tmpdir, "h_corrupt1234")
        (d / "h_corrupt1234.pid").write_text("not-a-number")
        _reap_orphaned_browser_sessions()
        assert not d.exists()

View file

@ -1,9 +1,6 @@
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
import logging
import os
import json
import shutil
import subprocess
import pytest
from pathlib import Path
@ -42,6 +39,19 @@ def checkpoint_base(tmp_path):
return tmp_path / "checkpoints"
@pytest.fixture()
def fake_home(tmp_path, monkeypatch):
"""Set a deterministic fake home for expanduser/path-home behavior."""
home = tmp_path / "home"
home.mkdir()
monkeypatch.setenv("HOME", str(home))
monkeypatch.setenv("USERPROFILE", str(home))
monkeypatch.delenv("HOMEDRIVE", raising=False)
monkeypatch.delenv("HOMEPATH", raising=False)
monkeypatch.setattr(Path, "home", classmethod(lambda cls: home))
return home
@pytest.fixture()
def mgr(work_dir, checkpoint_base, monkeypatch):
"""CheckpointManager with redirected checkpoint base."""
@ -78,6 +88,16 @@ class TestShadowRepoPath:
p = _shadow_repo_path(str(work_dir))
assert str(p).startswith(str(checkpoint_base))
def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch):
    """'~/project' and its fully expanded form map to the same shadow repo."""
    monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
    project = fake_home / "project"
    project.mkdir()
    via_tilde = _shadow_repo_path(f"~/{project.name}")
    via_expanded = _shadow_repo_path(str(project))
    assert via_tilde == via_expanded
# =========================================================================
# Shadow repo init
@ -221,6 +241,20 @@ class TestListCheckpoints:
assert result[0]["reason"] == "third"
assert result[2]["reason"] == "first"
def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch):
    """A checkpoint taken via '~/project' is visible when listing the expanded path."""
    monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
    manager = CheckpointManager(enabled=True, max_snapshots=50)
    project = fake_home / "project"
    project.mkdir()
    (project / "main.py").write_text("v1\n")

    assert manager.ensure_checkpoint(f"~/{project.name}", "initial") is True

    checkpoints = manager.list_checkpoints(str(project))
    assert len(checkpoints) == 1
    assert checkpoints[0]["reason"] == "initial"
# =========================================================================
# CheckpointManager — restoring
@ -271,6 +305,28 @@ class TestRestore:
assert len(all_cps) >= 2
assert "pre-rollback" in all_cps[0]["reason"]
def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch):
    """diff() and restore() both accept a tilde-prefixed working dir."""
    monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
    manager = CheckpointManager(enabled=True, max_snapshots=50)
    project = fake_home / "project"
    project.mkdir()
    tracked_file = project / "main.py"
    tracked_file.write_text("original\n")
    tilde_path = f"~/{project.name}"

    assert manager.ensure_checkpoint(tilde_path, "initial") is True
    manager.new_turn()
    tracked_file.write_text("changed\n")

    snapshot = manager.list_checkpoints(str(project))[0]
    # The diff against the snapshot must mention the changed file...
    diffed = manager.diff(tilde_path, snapshot["hash"])
    assert diffed["success"] is True
    assert "main.py" in diffed["diff"]
    # ...and restoring must bring the original content back.
    restored = manager.restore(tilde_path, snapshot["hash"])
    assert restored["success"] is True
    assert tracked_file.read_text() == "original\n"
# =========================================================================
# CheckpointManager — working dir resolution
@ -310,6 +366,19 @@ class TestWorkingDirResolution:
result = mgr.get_working_dir_for_path(str(filepath))
assert result == str(filepath.parent)
def test_resolves_tilde_path_to_project_root(self, fake_home):
    """A tilde file path inside a project resolves to the project root."""
    manager = CheckpointManager(enabled=True)
    project = fake_home / "myproject"
    project.mkdir()
    # pyproject.toml marks the project root.
    (project / "pyproject.toml").write_text("[project]\n")
    src_dir = project / "src"
    src_dir.mkdir()
    (src_dir / "main.py").write_text("x\n")

    resolved = manager.get_working_dir_for_path(f"~/{project.name}/src/main.py")
    assert resolved == str(project)
# =========================================================================
# Git env isolation
@ -333,6 +402,14 @@ class TestGitEnvIsolation:
env = _git_env(shadow, str(tmp_path))
assert "GIT_INDEX_FILE" not in env
def test_expands_tilde_in_work_tree(self, fake_home, tmp_path):
    """GIT_WORK_TREE must carry the fully expanded, resolved work dir."""
    shadow_repo = tmp_path / "shadow"
    work_tree = fake_home / "work"
    work_tree.mkdir()
    env = _git_env(shadow_repo, f"~/{work_tree.name}")
    assert env["GIT_WORK_TREE"] == str(work_tree.resolve())
# =========================================================================
# format_checkpoint_list
@ -384,6 +461,8 @@ class TestErrorResilience:
assert result is False
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
work = tmp_path / "work"
work.mkdir()
completed = subprocess.CompletedProcess(
args=["git", "diff", "--cached", "--quiet"],
returncode=1,
@ -395,7 +474,7 @@ class TestErrorResilience:
ok, stdout, stderr = _run_git(
["diff", "--cached", "--quiet"],
tmp_path / "shadow",
str(tmp_path / "work"),
str(work),
allowed_returncodes={1},
)
assert ok is False
@ -403,6 +482,38 @@ class TestErrorResilience:
assert stderr == ""
assert not caplog.records
def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog):
    """A nonexistent work dir yields a path error, not a 'git missing' error."""
    missing = tmp_path / "missing"
    with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
        ok, stdout, stderr = _run_git(
            ["status"], tmp_path / "shadow", str(missing)
        )
    assert (ok, stdout) == (False, "")
    assert "working directory not found" in stderr
    # The failure must not be misattributed to a missing git binary.
    logged = [r.getMessage() for r in caplog.records]
    assert not any("Git executable not found" in msg for msg in logged)
def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog):
    """When the git binary itself is absent, report 'git not found' and log it."""
    work = tmp_path / "work"
    work.mkdir()

    def raise_missing_git(*args, **kwargs):
        # Simulate the OS failing to locate the git executable.
        raise FileNotFoundError(2, "No such file or directory", "git")

    monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git)
    with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
        ok, stdout, stderr = _run_git(
            ["status"], tmp_path / "shadow", str(work)
        )
    assert (ok, stdout) == (False, "")
    assert stderr == "git not found"
    logged = [r.getMessage() for r in caplog.records]
    assert any("Git executable not found" in msg for msg in logged)
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
"""Checkpoint failures should never raise — they're silently logged."""
def broken_run_git(*args, **kwargs):
@ -411,3 +522,68 @@ class TestErrorResilience:
# Should not raise
result = mgr.ensure_checkpoint(str(work_dir), "test")
assert result is False
# =========================================================================
# Security / Input validation
# =========================================================================
class TestSecurity:
    """Input validation: commit hashes and file paths must be sanitized."""

    def test_restore_rejects_argument_injection(self, mgr, work_dir):
        """Git flags masquerading as commit hashes are refused outright."""
        mgr.ensure_checkpoint(str(work_dir), "initial")
        # A long option disguised as a hash...
        res = mgr.restore(str(work_dir), "--patch")
        assert res["success"] is False
        assert "Invalid commit hash" in res["error"]
        assert "must not start with '-'" in res["error"]
        # ...and the short-option form.
        res = mgr.restore(str(work_dir), "-p")
        assert res["success"] is False
        assert "Invalid commit hash" in res["error"]

    def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir):
        """Shell metacharacters are not valid hex and must be rejected."""
        mgr.ensure_checkpoint(str(work_dir), "initial")
        # Git hashes should not contain characters like ;, &, |
        res = mgr.restore(str(work_dir), "abc; rm -rf /")
        assert res["success"] is False
        assert "expected 4-64 hex characters" in res["error"]
        res = mgr.diff(str(work_dir), "abc&def")
        assert res["success"] is False
        assert "expected 4-64 hex characters" in res["error"]

    def test_restore_rejects_path_traversal(self, mgr, work_dir):
        """file_path may not point outside the working directory."""
        mgr.ensure_checkpoint(str(work_dir), "initial")
        # Real commit hash but malicious path
        commit = mgr.list_checkpoints(str(work_dir))[0]["hash"]
        # Absolute path outside
        res = mgr.restore(str(work_dir), commit, file_path="/etc/passwd")
        assert res["success"] is False
        assert "got absolute path" in res["error"]
        # Relative traversal outside path
        res = mgr.restore(str(work_dir), commit, file_path="../outside_file.txt")
        assert res["success"] is False
        assert "escapes the working directory" in res["error"]

    def test_restore_accepts_valid_file_path(self, mgr, work_dir):
        """Legitimate relative paths inside the work dir restore cleanly."""
        mgr.ensure_checkpoint(str(work_dir), "initial")
        commit = mgr.list_checkpoints(str(work_dir))[0]["hash"]
        # Valid path inside directory
        res = mgr.restore(str(work_dir), commit, file_path="main.py")
        assert res["success"] is True
        # Another valid path with subdirectories
        nested = work_dir / "subdir"
        nested.mkdir()
        (nested / "test.txt").write_text("hello")
        mgr.new_turn()
        mgr.ensure_checkpoint(str(work_dir), "second")
        commit = mgr.list_checkpoints(str(work_dir))[0]["hash"]
        res = mgr.restore(str(work_dir), commit, file_path="subdir/test.txt")
        assert res["success"] is True

View file

@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase):
@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestInterruptHandling(unittest.TestCase):
def test_interrupt_event_stops_execution(self):
"""When _interrupt_event is set, execute_code should stop the script."""
"""When interrupt is set for the execution thread, execute_code should stop."""
code = "import time; time.sleep(60); print('should not reach')"
from tools.interrupt import set_interrupt
# Capture the main thread ID so we can target the interrupt correctly.
# execute_code runs in the current thread; set_interrupt needs its ID.
main_tid = threading.current_thread().ident
def set_interrupt_after_delay():
import time as _t
_t.sleep(1)
from tools.terminal_tool import _interrupt_event
_interrupt_event.set()
set_interrupt(True, main_tid)
t = threading.Thread(target=set_interrupt_after_delay, daemon=True)
t.start()
@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase):
self.assertEqual(result["status"], "interrupted")
self.assertIn("interrupted", result["output"])
finally:
from tools.terminal_tool import _interrupt_event
_interrupt_event.clear()
set_interrupt(False, main_tid)
t.join(timeout=3)

View file

@ -227,6 +227,8 @@ class TestCheckpointNotify:
"session_key": "sk1",
"watcher_platform": "telegram",
"watcher_chat_id": "123",
"watcher_user_id": "u123",
"watcher_user_name": "alice",
"watcher_thread_id": "42",
"watcher_interval": 5,
"notify_on_complete": True,
@ -236,6 +238,8 @@ class TestCheckpointNotify:
assert recovered == 1
assert len(registry.pending_watchers) == 1
assert registry.pending_watchers[0]["notify_on_complete"] is True
assert registry.pending_watchers[0]["user_id"] == "u123"
assert registry.pending_watchers[0]["user_name"] == "alice"
def test_recover_defaults_false(self, registry, tmp_path):
"""Old checkpoint entries without the field default to False."""

View file

@ -438,6 +438,8 @@ class TestCheckpoint:
s = _make_session()
s.watcher_platform = "telegram"
s.watcher_chat_id = "999"
s.watcher_user_id = "u123"
s.watcher_user_name = "alice"
s.watcher_thread_id = "42"
s.watcher_interval = 60
registry._running[s.id] = s
@ -447,6 +449,8 @@ class TestCheckpoint:
assert len(data) == 1
assert data[0]["watcher_platform"] == "telegram"
assert data[0]["watcher_chat_id"] == "999"
assert data[0]["watcher_user_id"] == "u123"
assert data[0]["watcher_user_name"] == "alice"
assert data[0]["watcher_thread_id"] == "42"
assert data[0]["watcher_interval"] == 60
@ -460,6 +464,8 @@ class TestCheckpoint:
"session_key": "sk1",
"watcher_platform": "telegram",
"watcher_chat_id": "123",
"watcher_user_id": "u123",
"watcher_user_name": "alice",
"watcher_thread_id": "42",
"watcher_interval": 60,
}]))
@ -471,6 +477,8 @@ class TestCheckpoint:
assert w["session_id"] == "proc_live"
assert w["platform"] == "telegram"
assert w["chat_id"] == "123"
assert w["user_id"] == "u123"
assert w["user_name"] == "alice"
assert w["thread_id"] == "42"
assert w["check_interval"] == 60

View file

@ -348,7 +348,7 @@ word word
result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md")
assert result["success"] is False
assert "boundary" in result["error"].lower()
assert "escapes" in result["error"].lower()
assert outside_file.read_text() == "old text here"
@ -412,7 +412,7 @@ class TestWriteFile:
result = _write_file("my-skill", "references/escape/owned.md", "malicious")
assert result["success"] is False
assert "boundary" in result["error"].lower()
assert "escapes" in result["error"].lower()
assert not (outside_dir / "owned.md").exists()
@ -449,7 +449,7 @@ class TestRemoveFile:
result = _remove_file("my-skill", "references/escape/keep.txt")
assert result["success"] is False
assert "boundary" in result["error"].lower()
assert "escapes" in result["error"].lower()
assert outside_file.exists()

View file

@ -124,6 +124,34 @@ class TestWriteToSandbox:
cmd = env.execute.call_args[0][0]
assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd
def test_path_with_spaces_is_quoted(self):
    """Whitespace in the remote path must end up inside shell quotes."""
    env = MagicMock()
    env.execute.return_value = {"output": "", "returncode": 0}
    spaced_path = "/tmp/hermes results/abc file.txt"
    _write_to_sandbox("content", spaced_path, env)
    issued = env.execute.call_args[0][0]
    # Both the parent dir (for mkdir) and the file path itself are quoted.
    assert "'/tmp/hermes results'" in issued
    assert "'/tmp/hermes results/abc file.txt'" in issued
def test_shell_metacharacters_neutralized(self):
    """Paths with shell metacharacters must be quoted to prevent injection."""
    env = MagicMock()
    env.execute.return_value = {"output": "", "returncode": 0}
    _write_to_sandbox("content", "/tmp/hermes-results/$(whoami).txt", env)
    issued = env.execute.call_args[0][0]
    # The $() must not appear unquoted — shlex.quote wraps it
    assert "'/tmp/hermes-results/$(whoami).txt'" in issued
def test_semicolon_injection_neutralized(self):
    """Semicolons in a path must stay inside quotes rather than split the command."""
    env = MagicMock()
    env.execute.return_value = {"output": "", "returncode": 0}
    malicious_path = "/tmp/x; rm -rf /; echo .txt"
    _write_to_sandbox("content", malicious_path, env)
    cmd = env.execute.call_args[0][0]
    # The semicolons must be inside quotes, not acting as command separators
    assert "'/tmp/x; rm -rf /; echo .txt'" in cmd
class TestResolveStorageDir:
def test_defaults_to_storage_dir_without_env(self):

View file

@ -473,13 +473,104 @@ def _cleanup_inactive_browser_sessions():
logger.warning("Error cleaning up inactive session %s: %s", task_id, e)
def _reap_orphaned_browser_sessions():
    """Scan for orphaned agent-browser daemon processes from previous runs.

    When the Python process that created a browser session exits uncleanly
    (SIGKILL, crash, gateway restart), the in-memory ``_active_sessions``
    tracking is lost but the node + Chromium processes keep running.

    This function scans the tmp directory for ``agent-browser-*`` socket dirs
    left behind by previous runs, reads the daemon PID files, and kills any
    daemons that are still alive but not tracked by the current process.

    Called once on cleanup-thread startup — not every 30 seconds — to avoid
    races with sessions being actively created.
    """
    # Local import: glob is only needed for this one-shot startup scan.
    import glob
    tmpdir = _socket_safe_tmpdir()
    pattern = os.path.join(tmpdir, "agent-browser-h_*")
    socket_dirs = glob.glob(pattern)
    # Also pick up CDP sessions
    socket_dirs += glob.glob(os.path.join(tmpdir, "agent-browser-cdp_*"))
    if not socket_dirs:
        return
    # Build set of session_names currently tracked by this process
    with _cleanup_lock:
        tracked_names = {
            info.get("session_name")
            for info in _active_sessions.values()
            if info.get("session_name")
        }
    reaped = 0
    for socket_dir in socket_dirs:
        dir_name = os.path.basename(socket_dir)
        # dir_name is "agent-browser-{session_name}"
        session_name = dir_name.removeprefix("agent-browser-")
        if not session_name:
            continue
        # Skip sessions that we are actively tracking
        if session_name in tracked_names:
            continue
        pid_file = os.path.join(socket_dir, f"{session_name}.pid")
        if not os.path.isfile(pid_file):
            # No PID file — just a stale dir, remove it
            shutil.rmtree(socket_dir, ignore_errors=True)
            continue
        try:
            daemon_pid = int(Path(pid_file).read_text().strip())
        except (ValueError, OSError):
            # Unreadable or corrupt PID file — treat the whole dir as stale.
            shutil.rmtree(socket_dir, ignore_errors=True)
            continue
        # Check if the daemon is still alive
        try:
            os.kill(daemon_pid, 0)  # signal 0 = existence check
        except ProcessLookupError:
            # Already dead, just clean up the dir
            shutil.rmtree(socket_dir, ignore_errors=True)
            continue
        except PermissionError:
            # Alive but owned by someone else — leave it alone
            continue
        # Daemon is alive and not tracked — orphan. Kill it.
        try:
            os.kill(daemon_pid, signal.SIGTERM)
            logger.info("Reaped orphaned browser daemon PID %d (session %s)",
                        daemon_pid, session_name)
            reaped += 1
        except (ProcessLookupError, PermissionError, OSError):
            # Best effort: the daemon may have exited between the check and the kill.
            pass
        # Clean up the socket directory
        shutil.rmtree(socket_dir, ignore_errors=True)
    if reaped:
        logger.info("Reaped %d orphaned browser session(s) from previous run(s)", reaped)
def _browser_cleanup_thread_worker():
"""
Background thread that periodically cleans up inactive browser sessions.
Runs every 30 seconds and checks for sessions that haven't been used
within the BROWSER_SESSION_INACTIVITY_TIMEOUT period.
On first run, also reaps orphaned sessions from previous process lifetimes.
"""
# One-time orphan reap on startup
try:
_reap_orphaned_browser_sessions()
except Exception as e:
logger.warning("Orphan reap error: %s", e)
while _cleanup_running:
try:
_cleanup_inactive_browser_sessions()

View file

@ -21,6 +21,7 @@ into the user's project directory.
import hashlib
import logging
import os
import re
import shutil
import subprocess
from pathlib import Path
@ -64,23 +65,72 @@ _GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", "
# Max files to snapshot — skip huge directories to avoid slowdowns.
_MAX_FILES = 50_000
# Valid git commit hash pattern: 4–64 hex chars (short or full SHA-1/SHA-256).
_COMMIT_HASH_RE = re.compile(r'^[0-9a-fA-F]{4,64}$')
# ---------------------------------------------------------------------------
# Input validation helpers
# ---------------------------------------------------------------------------
def _validate_commit_hash(commit_hash: str) -> Optional[str]:
    """Reject commit hashes that could be abused as git arguments.

    A value beginning with ``-`` would be parsed by git as a flag
    (e.g. ``--patch``, ``-p``) rather than a revision specifier, and
    anything that is not 4-64 hex characters is not a valid short or
    full SHA-1/SHA-256 hash.

    Returns:
        An error message string when the hash is unsafe, None when valid.
    """
    if not commit_hash or not commit_hash.strip():
        return "Empty commit hash"
    if commit_hash.startswith("-"):
        # Leading dash → would be consumed by git as an option.
        return f"Invalid commit hash (must not start with '-'): {commit_hash!r}"
    if _COMMIT_HASH_RE.match(commit_hash) is None:
        return f"Invalid commit hash (expected 4-64 hex characters): {commit_hash!r}"
    return None
def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]:
    """Reject file paths that are empty, absolute, or escape *working_dir*.

    Restore targets must always be expressed relative to the working
    directory; after joining and resolving, the result must still live
    under the (canonicalised) working directory.

    Returns:
        An error message string when the path is unsafe, None when valid.
    """
    if not (file_path and file_path.strip()):
        return "Empty file path"
    if os.path.isabs(file_path):
        # Absolute targets are never accepted — only workdir-relative ones.
        return f"File path must be relative, got absolute path: {file_path!r}"
    workdir_root = _normalize_path(working_dir)
    candidate = (workdir_root / file_path).resolve()
    try:
        candidate.relative_to(workdir_root)
    except ValueError:
        return f"File path escapes the working directory via traversal: {file_path!r}"
    return None
# ---------------------------------------------------------------------------
# Shadow repo helpers
# ---------------------------------------------------------------------------
def _normalize_path(path_value: str) -> Path:
"""Return a canonical absolute path for checkpoint operations."""
return Path(path_value).expanduser().resolve()
def _shadow_repo_path(working_dir: str) -> Path:
    """Deterministic shadow repo path: sha256(abs_path)[:16].

    The working directory is canonicalised via _normalize_path first so
    that relative or symlinked spellings of the same directory map to
    the same shadow repository under CHECKPOINT_BASE.
    """
    # Fix: removed the leftover pre-refactor assignment
    # (`abs_path = str(Path(working_dir).resolve())`) that was immediately
    # overwritten — dead diff residue.
    abs_path = str(_normalize_path(working_dir))
    dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16]
    return CHECKPOINT_BASE / dir_hash
def _git_env(shadow_repo: Path, working_dir: str) -> dict:
"""Build env dict that redirects git to the shadow repo."""
normalized_working_dir = _normalize_path(working_dir)
env = os.environ.copy()
env["GIT_DIR"] = str(shadow_repo)
env["GIT_WORK_TREE"] = str(Path(working_dir).resolve())
env["GIT_WORK_TREE"] = str(normalized_working_dir)
env.pop("GIT_INDEX_FILE", None)
env.pop("GIT_NAMESPACE", None)
env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None)
@ -100,7 +150,17 @@ def _run_git(
exits while preserving the normal ``ok = (returncode == 0)`` contract.
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
"""
env = _git_env(shadow_repo, working_dir)
normalized_working_dir = _normalize_path(working_dir)
if not normalized_working_dir.exists():
msg = f"working directory not found: {normalized_working_dir}"
logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg)
return False, "", msg
if not normalized_working_dir.is_dir():
msg = f"working directory is not a directory: {normalized_working_dir}"
logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg)
return False, "", msg
env = _git_env(shadow_repo, str(normalized_working_dir))
cmd = ["git"] + list(args)
allowed_returncodes = allowed_returncodes or set()
try:
@ -110,7 +170,7 @@ def _run_git(
text=True,
timeout=timeout,
env=env,
cwd=str(Path(working_dir).resolve()),
cwd=str(normalized_working_dir),
)
ok = result.returncode == 0
stdout = result.stdout.strip()
@ -125,9 +185,14 @@ def _run_git(
msg = f"git timed out after {timeout}s: {' '.join(cmd)}"
logger.error(msg, exc_info=True)
return False, "", msg
except FileNotFoundError:
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
return False, "", "git not found"
except FileNotFoundError as exc:
missing_target = getattr(exc, "filename", None)
if missing_target == "git":
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
return False, "", "git not found"
msg = f"working directory not found: {normalized_working_dir}"
logger.error("Git command failed before execution: %s (%s)", " ".join(cmd), msg, exc_info=True)
return False, "", msg
except Exception as exc:
logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True)
return False, "", str(exc)
@ -154,7 +219,7 @@ def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]:
)
(shadow_repo / "HERMES_WORKDIR").write_text(
str(Path(working_dir).resolve()) + "\n", encoding="utf-8"
str(_normalize_path(working_dir)) + "\n", encoding="utf-8"
)
logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir)
@ -229,7 +294,7 @@ class CheckpointManager:
if not self._git_available:
return False
abs_dir = str(Path(working_dir).resolve())
abs_dir = str(_normalize_path(working_dir))
# Skip root, home, and other overly broad directories
if abs_dir in ("/", str(Path.home())):
@ -254,7 +319,7 @@ class CheckpointManager:
Returns a list of dicts with keys: hash, short_hash, timestamp, reason,
files_changed, insertions, deletions. Most recent first.
"""
abs_dir = str(Path(working_dir).resolve())
abs_dir = str(_normalize_path(working_dir))
shadow = _shadow_repo_path(abs_dir)
if not (shadow / "HEAD").exists():
@ -311,7 +376,12 @@ class CheckpointManager:
Returns dict with success, diff text, and stat summary.
"""
abs_dir = str(Path(working_dir).resolve())
# Validate commit_hash to prevent git argument injection
hash_err = _validate_commit_hash(commit_hash)
if hash_err:
return {"success": False, "error": hash_err}
abs_dir = str(_normalize_path(working_dir))
shadow = _shadow_repo_path(abs_dir)
if not (shadow / "HEAD").exists():
@ -364,7 +434,19 @@ class CheckpointManager:
Returns dict with success/error info.
"""
abs_dir = str(Path(working_dir).resolve())
# Validate commit_hash to prevent git argument injection
hash_err = _validate_commit_hash(commit_hash)
if hash_err:
return {"success": False, "error": hash_err}
abs_dir = str(_normalize_path(working_dir))
# Validate file_path to prevent path traversal outside the working dir
if file_path:
path_err = _validate_file_path(file_path, abs_dir)
if path_err:
return {"success": False, "error": path_err}
shadow = _shadow_repo_path(abs_dir)
if not (shadow / "HEAD").exists():
@ -413,7 +495,7 @@ class CheckpointManager:
(directory containing .git, pyproject.toml, package.json, etc.).
Falls back to the file's parent directory.
"""
path = Path(file_path).resolve()
path = _normalize_path(file_path)
if path.is_dir():
candidate = path
else:

View file

@ -924,8 +924,8 @@ def execute_code(
# --- Local execution path (UDS) --- below this line is unchanged ---
# Import interrupt event from terminal_tool (cooperative cancellation)
from tools.terminal_tool import _interrupt_event
# Import per-thread interrupt check (cooperative cancellation)
from tools.interrupt import is_interrupted as _is_interrupted
# Resolve config
_cfg = _load_config()
@ -1114,7 +1114,7 @@ def execute_code(
status = "success"
while proc.poll() is None:
if _interrupt_event.is_set():
if _is_interrupted():
_kill_process_group(proc)
status = "interrupted"
break

View file

@ -80,20 +80,18 @@ def register_credential_file(
# Resolve symlinks and normalise ``..`` before the containment check so
# that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME.
try:
resolved = host_path.resolve()
hermes_home_resolved = hermes_home.resolve()
resolved.relative_to(hermes_home_resolved) # raises ValueError if outside
except ValueError:
from tools.path_security import validate_within_dir
containment_error = validate_within_dir(host_path, hermes_home)
if containment_error:
logger.warning(
"credential_files: rejected path traversal %r "
"(resolves to %s, outside HERMES_HOME %s)",
"credential_files: rejected path traversal %r (%s)",
relative_path,
resolved,
hermes_home_resolved,
containment_error,
)
return False
resolved = host_path.resolve()
if not resolved.is_file():
logger.debug("credential_files: skipping %s (not found)", resolved)
return False
@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]:
cfg = read_raw_config()
cred_files = cfg.get("terminal", {}).get("credential_files")
if isinstance(cred_files, list):
hermes_home_resolved = hermes_home.resolve()
from tools.path_security import validate_within_dir
for item in cred_files:
if isinstance(item, str) and item.strip():
rel = item.strip()
@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]:
"credential_files: rejected absolute config path %r", rel,
)
continue
host_path = (hermes_home / rel).resolve()
try:
host_path.relative_to(hermes_home_resolved)
except ValueError:
host_path = hermes_home / rel
containment_error = validate_within_dir(host_path, hermes_home)
if containment_error:
logger.warning(
"credential_files: rejected config path traversal %r "
"(resolves to %s, outside HERMES_HOME %s)",
rel, host_path, hermes_home_resolved,
"credential_files: rejected config path traversal %r (%s)",
rel, containment_error,
)
continue
if host_path.is_file():
resolved_path = host_path.resolve()
if resolved_path.is_file():
container_path = f"/root/.hermes/{rel}"
result.append({
"host_path": str(host_path),
"host_path": str(resolved_path),
"container_path": container_path,
})
except Exception as e:

View file

@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]:
)
# Validate containment after resolution
from tools.path_security import validate_within_dir
scripts_dir = get_hermes_home() / "scripts"
scripts_dir.mkdir(parents=True, exist_ok=True)
resolved = (scripts_dir / raw).resolve()
try:
resolved.relative_to(scripts_dir.resolve())
except ValueError:
containment_error = validate_within_dir(scripts_dir / raw, scripts_dir)
if containment_error:
return (
f"Script path escapes the scripts directory via traversal: {raw!r}"
)

View file

@ -1,8 +1,12 @@
"""Shared interrupt signaling for all tools.
"""Per-thread interrupt signaling for all tools.
Provides a global threading.Event that any tool can check to determine
if the user has requested an interrupt. The agent's interrupt() method
sets this event, and tools poll it during long-running operations.
Provides thread-scoped interrupt tracking so that interrupting one agent
session does not kill tools running in other sessions. This is critical
in the gateway where multiple agents run concurrently in the same process.
The agent stores its execution thread ID at the start of run_conversation()
and passes it to set_interrupt()/clear_interrupt(). Tools call
is_interrupted(), which checks the CURRENT thread — no argument needed.
Usage in tools:
from tools.interrupt import is_interrupted
@ -12,17 +16,61 @@ Usage in tools:
import threading
_interrupt_event = threading.Event()
# Set of thread idents that have been interrupted.
_interrupted_threads: set[int] = set()
_lock = threading.Lock()
def set_interrupt(active: bool, thread_id: int | None = None) -> None:
    """Set or clear the interrupt flag for a specific thread.

    Fix: removed the duplicated legacy event-based definition (diff
    residue) that preceded this per-thread version with the same name.

    Args:
        active: True to signal an interrupt, False to clear it.
        thread_id: Target thread ident. When None, targets the
            current thread (backward compat for CLI/tests).
    """
    if thread_id is None:
        thread_id = threading.current_thread().ident
    with _lock:
        if active:
            _interrupted_threads.add(thread_id)
        else:
            _interrupted_threads.discard(thread_id)
def is_interrupted() -> bool:
    """Check if an interrupt has been requested for the current thread.

    Safe to call from any thread — each thread only sees its own
    interrupt state.

    Fix: removed the stale ``return _interrupt_event.is_set()`` line
    (diff residue) that made the per-thread lookup below unreachable.
    """
    tid = threading.current_thread().ident
    with _lock:
        return tid in _interrupted_threads
# ---------------------------------------------------------------------------
# Backward-compatible _interrupt_event proxy
# ---------------------------------------------------------------------------
# Some legacy call sites (code_execution_tool, process_registry, tests)
# import _interrupt_event directly and call .is_set() / .set() / .clear().
# This shim maps those calls to the per-thread functions above so existing
# code keeps working while the underlying mechanism is thread-scoped.
class _ThreadAwareEventProxy:
    """Drop-in stand-in for the old module-level ``threading.Event``.

    Legacy call sites import ``_interrupt_event`` and call
    ``.is_set()/.set()/.clear()``; each call is forwarded to the
    per-thread helpers so those sites keep working on top of the
    thread-scoped mechanism.
    """

    def is_set(self) -> bool:
        # Event.is_set() → per-thread check for the calling thread.
        return is_interrupted()

    def set(self) -> None:  # noqa: A003
        set_interrupt(True)

    def clear(self) -> None:
        set_interrupt(False)

    def wait(self, timeout: float | None = None) -> bool:
        # A real Event would block; per-thread state cannot be waited on,
        # so report the current state immediately.
        return is_interrupted()


_interrupt_event = _ThreadAwareEventProxy()

43
tools/path_security.py Normal file
View file

@ -0,0 +1,43 @@
"""Shared path validation helpers for tool implementations.
Extracts the ``resolve() + relative_to()`` and ``..`` traversal check
patterns previously duplicated across skill_manager_tool, skills_tool,
skills_hub, cronjob_tools, and credential_files.
"""
import logging
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
def validate_within_dir(path: Path, root: Path) -> Optional[str]:
    """Check that *path* resolves to a location inside *root*.

    Both paths are canonicalised with ``Path.resolve()`` (following
    symlinks and collapsing ``..``) before the containment check.

    Returns:
        None when the path is safe, otherwise an error message string
        suitable for returning to the caller::

            error = validate_within_dir(user_path, allowed_root)
            if error:
                return json.dumps({"error": error})
    """
    try:
        path.resolve().relative_to(root.resolve())
    except (ValueError, OSError) as exc:
        # ValueError → resolved outside root; OSError → unresolvable path.
        return f"Path escapes allowed directory: {exc}"
    return None
def has_traversal_component(path_str: str) -> bool:
    """Return True when *path_str* contains a ``..`` path component.

    Cheap pre-check for obvious traversal attempts before doing a full
    ``resolve()``-based containment validation.
    """
    return any(part == ".." for part in Path(path_str).parts)

View file

@ -96,6 +96,8 @@ class ProcessSession:
# Watcher/notification metadata (persisted for crash recovery)
watcher_platform: str = ""
watcher_chat_id: str = ""
watcher_user_id: str = ""
watcher_user_name: str = ""
watcher_thread_id: str = ""
watcher_interval: int = 0 # 0 = no watcher configured
notify_on_complete: bool = False # Queue agent notification on exit
@ -695,7 +697,7 @@ class ProcessRegistry:
and output snapshot.
"""
from tools.ansi_strip import strip_ansi
from tools.terminal_tool import _interrupt_event
from tools.interrupt import is_interrupted as _is_interrupted
try:
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
@ -732,7 +734,7 @@ class ProcessRegistry:
result["timeout_note"] = timeout_note
return result
if _interrupt_event.is_set():
if _is_interrupted():
result = {
"status": "interrupted",
"output": strip_ansi(session.output_buffer[-1000:]),
@ -981,6 +983,8 @@ class ProcessRegistry:
"session_key": s.session_key,
"watcher_platform": s.watcher_platform,
"watcher_chat_id": s.watcher_chat_id,
"watcher_user_id": s.watcher_user_id,
"watcher_user_name": s.watcher_user_name,
"watcher_thread_id": s.watcher_thread_id,
"watcher_interval": s.watcher_interval,
"notify_on_complete": s.notify_on_complete,
@ -1042,6 +1046,8 @@ class ProcessRegistry:
detached=True, # Can't read output, but can report status + kill
watcher_platform=entry.get("watcher_platform", ""),
watcher_chat_id=entry.get("watcher_chat_id", ""),
watcher_user_id=entry.get("watcher_user_id", ""),
watcher_user_name=entry.get("watcher_user_name", ""),
watcher_thread_id=entry.get("watcher_thread_id", ""),
watcher_interval=entry.get("watcher_interval", 0),
notify_on_complete=entry.get("notify_on_complete", False),
@ -1060,6 +1066,8 @@ class ProcessRegistry:
"session_key": session.session_key,
"platform": session.watcher_platform,
"chat_id": session.watcher_chat_id,
"user_id": session.watcher_user_id,
"user_name": session.watcher_user_name,
"thread_id": session.watcher_thread_id,
"notify_on_complete": session.notify_on_complete,
})

View file

@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]:
Validate a file path for write_file/remove_file.
Must be under an allowed subdirectory and not escape the skill dir.
"""
from tools.path_security import has_traversal_component
if not file_path:
return "file_path is required."
normalized = Path(file_path)
# Prevent path traversal
if ".." in normalized.parts:
if has_traversal_component(file_path):
return "Path traversal ('..') is not allowed."
# Must be under an allowed subdirectory
@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]:
def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]:
"""Resolve a supporting-file path and ensure it stays within the skill directory."""
from tools.path_security import validate_within_dir
target = skill_dir / file_path
try:
resolved = target.resolve(strict=False)
skill_dir_resolved = skill_dir.resolve()
resolved.relative_to(skill_dir_resolved)
except ValueError:
return None, "Path escapes skill directory boundary."
except OSError as e:
return None, f"Invalid file path '{file_path}': {e}"
error = validate_within_dir(target, skill_dir)
if error:
return None, error
return target, None

View file

@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]:
return None
def _estimate_tokens(content: str) -> int:
"""
Rough token estimate (4 chars per token average).
Args:
content: Text content
Returns:
Estimated token count
"""
return len(content) // 4
# Token estimation — use the shared implementation from model_metadata.
from agent.model_metadata import estimate_tokens_rough as _estimate_tokens
def _parse_tags(tags_value) -> List[str]:
@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
# If a specific file path is requested, read that instead
if file_path and skill_dir:
from tools.path_security import validate_within_dir, has_traversal_component
# Security: Prevent path traversal attacks
normalized_path = Path(file_path)
if ".." in normalized_path.parts:
if has_traversal_component(file_path):
return json.dumps(
{
"success": False,
@ -962,24 +954,13 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
target_file = skill_dir / file_path
# Security: Verify resolved path is still within skill directory
try:
resolved = target_file.resolve()
skill_dir_resolved = skill_dir.resolve()
if not resolved.is_relative_to(skill_dir_resolved):
return json.dumps(
{
"success": False,
"error": "Path escapes skill directory boundary.",
"hint": "Use a relative path within the skill directory",
},
ensure_ascii=False,
)
except (OSError, ValueError):
traversal_error = validate_within_dir(target_file, skill_dir)
if traversal_error:
return json.dumps(
{
"success": False,
"error": f"Invalid file path: '{file_path}'",
"hint": "Use a valid relative path within the skill directory",
"error": traversal_error,
"hint": "Use a relative path within the skill directory",
},
ensure_ascii=False,
)

View file

@ -1427,8 +1427,12 @@ def terminal_tool(
if _gw_platform and not check_interval:
_gw_chat_id = _gse("HERMES_SESSION_CHAT_ID", "")
_gw_thread_id = _gse("HERMES_SESSION_THREAD_ID", "")
_gw_user_id = _gse("HERMES_SESSION_USER_ID", "")
_gw_user_name = _gse("HERMES_SESSION_USER_NAME", "")
proc_session.watcher_platform = _gw_platform
proc_session.watcher_chat_id = _gw_chat_id
proc_session.watcher_user_id = _gw_user_id
proc_session.watcher_user_name = _gw_user_name
proc_session.watcher_thread_id = _gw_thread_id
proc_session.watcher_interval = 5
process_registry.pending_watchers.append({
@ -1437,6 +1441,8 @@ def terminal_tool(
"session_key": session_key,
"platform": _gw_platform,
"chat_id": _gw_chat_id,
"user_id": _gw_user_id,
"user_name": _gw_user_name,
"thread_id": _gw_thread_id,
"notify_on_complete": True,
})
@ -1457,10 +1463,14 @@ def terminal_tool(
watcher_platform = _gse2("HERMES_SESSION_PLATFORM", "")
watcher_chat_id = _gse2("HERMES_SESSION_CHAT_ID", "")
watcher_thread_id = _gse2("HERMES_SESSION_THREAD_ID", "")
watcher_user_id = _gse2("HERMES_SESSION_USER_ID", "")
watcher_user_name = _gse2("HERMES_SESSION_USER_NAME", "")
# Store on session for checkpoint persistence
proc_session.watcher_platform = watcher_platform
proc_session.watcher_chat_id = watcher_chat_id
proc_session.watcher_user_id = watcher_user_id
proc_session.watcher_user_name = watcher_user_name
proc_session.watcher_thread_id = watcher_thread_id
proc_session.watcher_interval = effective_interval
@ -1470,6 +1480,8 @@ def terminal_tool(
"session_key": session_key,
"platform": watcher_platform,
"chat_id": watcher_chat_id,
"user_id": watcher_user_id,
"user_name": watcher_user_name,
"thread_id": watcher_thread_id,
})

View file

@ -24,6 +24,7 @@ Defense against context-window overflow operates at three levels:
import logging
import os
import shlex
import uuid
from tools.budget_config import (
@ -79,7 +80,7 @@ def _write_to_sandbox(content: str, remote_path: str, env) -> bool:
marker = _heredoc_marker(content)
storage_dir = os.path.dirname(remote_path)
cmd = (
f"mkdir -p {storage_dir} && cat > {remote_path} << '{marker}'\n"
f"mkdir -p {shlex.quote(storage_dir)} && cat > {shlex.quote(remote_path)} << '{marker}'\n"
f"{content}\n"
f"{marker}"
)

View file

@ -1,13 +1,16 @@
"""Shared utility functions for hermes-agent."""
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Union
from typing import Any, List, Optional, Union
import yaml
logger = logging.getLogger(__name__)
TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})
@ -124,3 +127,88 @@ def atomic_yaml_write(
except OSError:
pass
raise
# ─── JSON Helpers ─────────────────────────────────────────────────────────────
def safe_json_loads(text: str, default: Any = None) -> Any:
    """Parse *text* as JSON, falling back to *default* on any parse failure.

    Centralises the ``try: json.loads(x) except ...`` pattern previously
    duplicated across display.py, anthropic_adapter.py, auxiliary_client.py,
    and others.
    """
    try:
        parsed = json.loads(text)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        return default
    return parsed
def read_json_file(path: Path, default: Any = None) -> Any:
    """Read and parse a JSON file, falling back to *default* on any error.

    Covers missing/unreadable files as well as malformed JSON; failures
    are logged at DEBUG level rather than raised.
    """
    try:
        raw = Path(path).read_text(encoding="utf-8")
        return json.loads(raw)
    except (OSError, IOError, ValueError) as exc:  # JSONDecodeError ⊂ ValueError
        logger.debug("Failed to read %s: %s", path, exc)
        return default
def read_jsonl(path: Path) -> List[dict]:
    """Read a JSONL file (one JSON object per line).

    Blank lines are skipped; malformed lines raise ``json.JSONDecodeError``
    just as they would with a direct ``json.loads`` call.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (line.strip() for line in handle)
        return [json.loads(line) for line in stripped if line]
def append_jsonl(path: Path, entry: dict) -> None:
    """Append *entry* as one JSON line, creating parent directories as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    line = json.dumps(entry, ensure_ascii=False)
    with open(target, "a", encoding="utf-8") as handle:
        handle.write(f"{line}\n")
# ─── Environment Variable Helpers ─────────────────────────────────────────────
def env_str(key: str, default: str = "") -> str:
    """Read an environment variable with surrounding whitespace removed.

    Centralises the ``os.getenv("X", "").strip()`` pattern repeated across
    the codebase.
    """
    raw = os.getenv(key, default)
    return raw.strip()
def env_lower(key: str, default: str = "") -> str:
    """Read an environment variable, stripped and lowercased."""
    value = os.getenv(key, default)
    return value.strip().lower()
def env_int(key: str, default: int = 0) -> int:
    """Read an environment variable as an integer, with fallback.

    Missing, empty/whitespace-only, or non-numeric values all yield
    *default*.
    """
    text = os.getenv(key, "").strip()
    if text:
        try:
            return int(text)
        except (ValueError, TypeError):
            pass
    return default
def env_bool(key: str, default: bool = False) -> bool:
    """Read an environment variable as a boolean via is_truthy_value."""
    raw = os.getenv(key, "")
    return is_truthy_value(raw, default=default)

View file

@ -195,9 +195,12 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI
| `SIGNAL_IGNORE_STORIES` | Ignore Signal stories/status updates |
| `SIGNAL_ALLOW_ALL_USERS` | Allow all Signal users without an allowlist |
| `TWILIO_ACCOUNT_SID` | Twilio Account SID (shared with telephony skill) |
| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill) |
| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill; also used for webhook signature validation) |
| `TWILIO_PHONE_NUMBER` | Twilio phone number in E.164 format (shared with telephony skill) |
| `SMS_WEBHOOK_URL` | Public URL for Twilio signature validation — must match the webhook URL in Twilio Console (required) |
| `SMS_WEBHOOK_PORT` | Webhook listener port for inbound SMS (default: `8080`) |
| `SMS_WEBHOOK_HOST` | Webhook bind address (default: `0.0.0.0`) |
| `SMS_INSECURE_NO_SIGNATURE` | Set to `true` to disable Twilio signature validation (local dev only — not for production) |
| `SMS_ALLOWED_USERS` | Comma-separated E.164 phone numbers allowed to chat |
| `SMS_ALLOW_ALL_USERS` | Allow all SMS senders without an allowlist |
| `SMS_HOME_CHANNEL` | Phone number for cron job / notification delivery |

View file

@ -178,6 +178,8 @@ EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com
MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c
MATRIX_ALLOWED_USERS=@alice:matrix.org
DINGTALK_ALLOWED_USERS=user-id-1
FEISHU_ALLOWED_USERS=ou_xxxxxxxx,ou_yyyyyyyy
WECOM_ALLOWED_USERS=user-id-1,user-id-2
# Or allow
GATEWAY_ALLOWED_USERS=123456789,987654321

View file

@ -84,6 +84,13 @@ ngrok http 8080
Set the resulting public URL as your Twilio webhook.
:::
**Set `SMS_WEBHOOK_URL` to the same URL you configured in Twilio.** This is required for Twilio signature validation — the adapter will refuse to start without it:
```bash
# Must match the webhook URL in your Twilio Console
SMS_WEBHOOK_URL=https://your-server:8080/webhooks/twilio
```
The webhook port defaults to `8080`. Override with:
```bash
@ -101,9 +108,11 @@ hermes gateway
You should see:
```
[sms] Twilio webhook server listening on port 8080, from: +1555***4567
[sms] Twilio webhook server listening on 0.0.0.0:8080, from: +1555***4567
```
If you see `Refusing to start: SMS_WEBHOOK_URL is required`, set `SMS_WEBHOOK_URL` to the public URL configured in your Twilio Console (see Step 3).
Text your Twilio number — Hermes will respond via SMS.
---
@ -113,9 +122,12 @@ Text your Twilio number — Hermes will respond via SMS.
| Variable | Required | Description |
|----------|----------|-------------|
| `TWILIO_ACCOUNT_SID` | Yes | Twilio Account SID (starts with `AC`) |
| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token |
| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token (also used for webhook signature validation) |
| `TWILIO_PHONE_NUMBER` | Yes | Your Twilio phone number (E.164 format) |
| `SMS_WEBHOOK_URL` | Yes | Public URL for Twilio signature validation — must match the webhook URL in your Twilio Console |
| `SMS_WEBHOOK_PORT` | No | Webhook listener port (default: `8080`) |
| `SMS_WEBHOOK_HOST` | No | Webhook bind address (default: `0.0.0.0`) |
| `SMS_INSECURE_NO_SIGNATURE` | No | Set to `true` to disable signature validation (local dev only — **not for production**) |
| `SMS_ALLOWED_USERS` | No | Comma-separated E.164 phone numbers allowed to chat |
| `SMS_ALLOW_ALL_USERS` | No | Set to `true` to allow anyone (not recommended) |
| `SMS_HOME_CHANNEL` | No | Phone number for cron job / notification delivery |
@ -134,6 +146,21 @@ Text your Twilio number — Hermes will respond via SMS.
## Security
### Webhook signature validation
Hermes validates that inbound webhooks genuinely originate from Twilio by verifying the `X-Twilio-Signature` header (HMAC-SHA1). This prevents attackers from injecting forged messages.
**`SMS_WEBHOOK_URL` is required.** Set it to the public URL configured in your Twilio Console. The adapter will refuse to start without it.
For local development without a public URL, you can disable validation:
```bash
# Local dev only — NOT for production
SMS_INSECURE_NO_SIGNATURE=true
```
### User allowlists
**The gateway denies all users by default.** Configure an allowlist:
```bash