mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
ec553fdb49
93 changed files with 3531 additions and 1330 deletions
|
|
@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency.
|
|||
Used by AIAgent._execute_tool_calls for CLI feedback.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -14,6 +13,8 @@ from dataclasses import dataclass, field
|
|||
from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
|
||||
from utils import safe_json_loads
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
_RESET = "\033[0m"
|
||||
|
|
@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool:
|
|||
"""Conservatively detect whether a tool result represents success."""
|
||||
if not result:
|
||||
return False
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = safe_json_loads(result)
|
||||
if data is None:
|
||||
return False
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
|
@ -423,10 +423,7 @@ def extract_edit_diff(
|
|||
) -> str | None:
|
||||
"""Extract a unified diff from a file-edit tool result."""
|
||||
if tool_name == "patch" and result:
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = None
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
diff = data.get("diff")
|
||||
if isinstance(diff, str) and diff.strip():
|
||||
|
|
@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
|||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
exit_code = data.get("exit_code")
|
||||
if exit_code is not None and exit_code != 0:
|
||||
return True, f" [exit {exit_code}]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
logger.debug("Could not parse terminal result as JSON for exit code check")
|
||||
return False, ""
|
||||
|
||||
# Memory-specific: distinguish "full" from real errors
|
||||
if tool_name == "memory":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
|
||||
return True, " [full]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
logger.debug("Could not parse memory result as JSON for capacity check")
|
||||
|
||||
# Generic heuristic for non-terminal tools
|
||||
lower = result[:500].lower()
|
||||
|
|
|
|||
|
|
@ -179,6 +179,12 @@ _MAX_COMPLETION_KEYS = (
|
|||
|
||||
# Local server hostnames / address patterns
|
||||
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
||||
# Docker / Podman / Lima DNS names that resolve to the host machine
|
||||
_CONTAINER_LOCAL_SUFFIXES = (
|
||||
".docker.internal",
|
||||
".containers.internal",
|
||||
".lima.internal",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
|
|
@ -254,6 +260,9 @@ def is_local_endpoint(base_url: str) -> bool:
|
|||
return False
|
||||
if host in _LOCAL_HOSTS:
|
||||
return True
|
||||
# Docker / Podman / Lima internal DNS names (e.g. host.docker.internal)
|
||||
if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES):
|
||||
return True
|
||||
# RFC-1918 private ranges and link-local
|
||||
import ipaddress
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import threading
|
|||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_hermes_home, get_skills_dir
|
||||
from typing import Optional
|
||||
|
||||
from agent.skill_utils import (
|
||||
|
|
@ -548,8 +548,7 @@ def build_skills_system_prompt(
|
|||
are read-only — they appear in the index but new skills are always created
|
||||
in the local dir. Local skills take precedence when names collide.
|
||||
"""
|
||||
hermes_home = get_hermes_home()
|
||||
skills_dir = hermes_home / "skills"
|
||||
skills_dir = get_skills_dir()
|
||||
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
|
||||
|
||||
if not skills_dir.exists() and not external_dirs:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path, get_skills_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
|
|||
Reads the config file directly (no CLI config imports) to stay
|
||||
lightweight.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
return set()
|
||||
try:
|
||||
|
|
@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]:
|
|||
path. Only directories that actually exist are returned. Duplicates and
|
||||
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
return []
|
||||
try:
|
||||
|
|
@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]:
|
|||
if not isinstance(raw_dirs, list):
|
||||
return []
|
||||
|
||||
local_skills = (get_hermes_home() / "skills").resolve()
|
||||
local_skills = get_skills_dir().resolve()
|
||||
seen: Set[Path] = set()
|
||||
result: List[Path] = []
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
|
|||
The local dir is always first (and always included even if it doesn't exist
|
||||
yet — callers handle that). External dirs follow in config order.
|
||||
"""
|
||||
dirs = [get_hermes_home() / "skills"]
|
||||
dirs = [get_skills_dir()]
|
||||
dirs.extend(get_external_skills_dirs())
|
||||
return dirs
|
||||
|
||||
|
|
@ -384,7 +384,7 @@ def resolve_skill_config_values(
|
|||
current values (or the declared default if the key isn't set).
|
||||
Path values are expanded via ``os.path.expanduser``.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
config: Dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
|
|
|
|||
9
cli.py
9
cli.py
|
|
@ -2748,6 +2748,15 @@ class HermesCLI:
|
|||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
# When a custom_provider entry carries an explicit `model` field,
|
||||
# use it as the effective model name. Without this, running
|
||||
# `hermes chat --model <provider-name>` sends the provider name
|
||||
# (e.g. "my-provider") as the model string to the API instead of
|
||||
# the configured model (e.g. "qwen3.6-plus"), causing 400 errors.
|
||||
runtime_model = runtime.get("model")
|
||||
if runtime_model and isinstance(runtime_model, str):
|
||||
self.model = runtime_model
|
||||
|
||||
# Normalize model for the resolved provider (e.g. swap non-Codex
|
||||
# models when provider is openai-codex). Fixes #651.
|
||||
model_changed = self._normalize_model_for_provider(resolved_provider)
|
||||
|
|
|
|||
|
|
@ -722,6 +722,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
|||
provider_sort=pr.get("sort"),
|
||||
disabled_toolsets=["cronjob", "messaging", "clarify"],
|
||||
quiet_mode=True,
|
||||
skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd
|
||||
skip_memory=True, # Cron system prompts would corrupt user representations
|
||||
platform="cron",
|
||||
session_id=_cron_session_id,
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1"
|
|||
DEFAULT_PORT = 8642
|
||||
MAX_STORED_RESPONSES = 100
|
||||
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
|
||||
CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
|
|
@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
"""
|
||||
import queue as _q
|
||||
|
||||
sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"}
|
||||
sse_headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
}
|
||||
# CORS middleware can't inject headers into StreamResponse after
|
||||
# prepare() flushes them, so resolve CORS headers up front.
|
||||
origin = request.headers.get("Origin", "")
|
||||
|
|
@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
|
|
@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Helper — route a queue item to the correct SSE event.
|
||||
async def _emit(item):
|
||||
|
|
@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
"choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
return time.monotonic()
|
||||
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS:
|
||||
await response.write(b": keepalive\n\n")
|
||||
last_activity = time.monotonic()
|
||||
continue
|
||||
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
|
|
|||
|
|
@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC):
|
|||
result = handler(self)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
|
||||
def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
|
||||
"""Acquire a scoped lock for this adapter. Returns True on success."""
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._platform_lock_scope = scope
|
||||
self._platform_lock_identity = identity
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
scope, identity, metadata={'platform': self.platform.value}
|
||||
)
|
||||
if acquired:
|
||||
return True
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = (
|
||||
f'{resource_desc} already in use'
|
||||
+ (f' (PID {owner_pid})' if owner_pid else '')
|
||||
+ '. Stop the other gateway first.'
|
||||
)
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error(f'{scope}_lock', message, retryable=False)
|
||||
return False
|
||||
|
||||
def _release_platform_lock(self) -> None:
|
||||
"""Release the scoped lock acquired by _acquire_platform_lock."""
|
||||
identity = getattr(self, '_platform_lock_identity', None)
|
||||
if not identity:
|
||||
return
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock(self._platform_lock_scope, identity)
|
||||
self._platform_lock_identity = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from gateway.platforms.base import (
|
|||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
from gateway.platforms.helpers import strip_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str:
|
|||
return value.rstrip("/")
|
||||
|
||||
|
||||
def _strip_markdown(text: str) -> str:
|
||||
"""Strip common markdown formatting for iMessage plain-text delivery."""
|
||||
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
|
||||
text = re.sub(r"`(.+?)`", r"\1", text)
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
|||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
text = _strip_markdown(content or "")
|
||||
text = strip_markdown(content or "")
|
||||
if not text:
|
||||
return SendResult(success=False, error="BlueBubbles send requires text")
|
||||
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
|
||||
|
|
@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
|||
return info
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
return _strip_markdown(content)
|
||||
return strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound attachment downloading (from #4588)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ except ImportError:
|
|||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -52,8 +53,6 @@ from gateway.platforms.base import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_LENGTH = 20000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
_SESSION_WEBHOOKS_MAX = 500
|
||||
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
|
||||
|
|
@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
self._stream_task: Optional[asyncio.Task] = None
|
||||
self._http_client: Optional["httpx.AsyncClient"] = None
|
||||
|
||||
# Message deduplication: msg_id -> timestamp
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
# Message deduplication
|
||||
self._dedup = MessageDeduplicator(max_size=1000)
|
||||
# Map chat_id -> session_webhook for reply routing
|
||||
self._session_webhooks: Dict[str, str] = {}
|
||||
|
||||
|
|
@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
|
||||
self._stream_client = None
|
||||
self._session_webhooks.clear()
|
||||
self._seen_messages.clear()
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
# -- Inbound message processing -----------------------------------------
|
||||
|
|
@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
async def _on_message(self, message: "ChatbotMessage") -> None:
|
||||
"""Process an incoming DingTalk chatbot message."""
|
||||
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
|
||||
if self._is_duplicate(msg_id):
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
|
||||
return
|
||||
|
||||
|
|
@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
content = " ".join(parts).strip()
|
||||
return content
|
||||
|
||||
# -- Deduplication ------------------------------------------------------
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Check and record a message ID. Returns True if already seen."""
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
# -- Outbound messaging -------------------------------------------------
|
||||
|
||||
async def send(
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
|||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Track threads where the bot has participated so follow-up messages
|
||||
# in those threads don't require @mention. Persisted to disk so the
|
||||
# set survives gateway restarts.
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
self._threads = ThreadParticipationTracker("discord")
|
||||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._bot_task: Optional[asyncio.Task] = None
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
# Dedup cache: message_id → timestamp. Prevents duplicate bot
|
||||
# responses when Discord RESUME replays events after reconnects.
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Dedup cache: prevents duplicate bot responses when Discord
|
||||
# RESUME replays events after reconnects.
|
||||
self._dedup = MessageDeduplicator()
|
||||
# Reply threading mode: "off" (no replies), "first" (reply on first
|
||||
# chunk only, default), "all" (reply-reference on every chunk).
|
||||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||||
|
|
@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate bot token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('discord_token_lock', message, retryable=False)
|
||||
if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
|
||||
return False
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
|
|
@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||||
msg_id = str(message.id)
|
||||
now = time.time()
|
||||
if msg_id in adapter_self._seen_messages:
|
||||
if adapter_self._dedup.is_duplicate(str(message.id)):
|
||||
return
|
||||
adapter_self._seen_messages[msg_id] = now
|
||||
if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX:
|
||||
cutoff = now - adapter_self._SEEN_TTL
|
||||
adapter_self._seen_messages = {
|
||||
k: v for k, v in adapter_self._seen_messages.items()
|
||||
if v > cutoff
|
||||
}
|
||||
|
||||
# Always ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
|
|
@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
return False
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
|
|
@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self._client = None
|
||||
self._ready_event.clear()
|
||||
|
||||
# Release the token lock
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
|
|
@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
# Track thread participation so follow-ups don't require @mention
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# If a message was provided, kick off a new Hermes session in the thread
|
||||
starter = (message or "").strip()
|
||||
|
|
@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "discord_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load discord thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
# Trim to most recent entries if over cap
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save discord thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
|
|
@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
in_bot_thread = is_thread and thread_id in self._bot_participated_threads
|
||||
in_bot_thread = is_thread and thread_id in self._threads
|
||||
|
||||
if require_mention and not is_free_channel and not in_bot_thread:
|
||||
if self._client.user not in message.mentions:
|
||||
|
|
@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
|
|
@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Track thread participation so the bot won't require @mention for
|
||||
# follow-up messages in threads it has already engaged in.
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Only batch plain text messages — commands, media, etc. dispatch
|
||||
# immediately since they won't be split by the Discord client.
|
||||
|
|
|
|||
|
|
@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str:
|
|||
|
||||
|
||||
def _strip_markdown_to_plain_text(text: str) -> str:
|
||||
"""Strip markdown formatting to plain text for Feishu text fallbacks.
|
||||
|
||||
Delegates common markdown stripping to the shared helper and adds
|
||||
Feishu-specific patterns (blockquotes, strikethrough, underline tags,
|
||||
horizontal rules, \\r\\n normalisation).
|
||||
"""
|
||||
from gateway.platforms.helpers import strip_markdown
|
||||
plain = text.replace("\r\n", "\n")
|
||||
plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain)
|
||||
plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain)
|
||||
plain = re.sub(r"`([^`\n]+)`", r"\1", plain)
|
||||
plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain)
|
||||
plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain)
|
||||
plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain)
|
||||
plain = re.sub(r"<u>([\s\S]*?)</u>", r"\1", plain)
|
||||
plain = re.sub(r"\n{3,}", "\n\n", plain)
|
||||
return plain.strip()
|
||||
plain = strip_markdown(plain)
|
||||
return plain
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]:
|
||||
|
|
|
|||
261
gateway/platforms/helpers.py
Normal file
261
gateway/platforms/helpers.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
"""Shared helper classes for gateway platform adapters.
|
||||
|
||||
Extracts common patterns that were duplicated across 5-7 adapters:
|
||||
message deduplication, text batch aggregation, markdown stripping,
|
||||
and thread participation tracking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── Message Deduplication ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MessageDeduplicator:
|
||||
"""TTL-based message deduplication cache.
|
||||
|
||||
Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
|
||||
previously duplicated in discord, slack, dingtalk, wecom, weixin,
|
||||
mattermost, and feishu adapters.
|
||||
|
||||
Usage::
|
||||
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
||||
# In message handler:
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
return
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
|
||||
self._seen: Dict[str, float] = {}
|
||||
self._max_size = max_size
|
||||
self._ttl = ttl_seconds
|
||||
|
||||
def is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Return True if *msg_id* was already seen within the TTL window."""
|
||||
if not msg_id:
|
||||
return False
|
||||
now = time.time()
|
||||
if msg_id in self._seen:
|
||||
return True
|
||||
self._seen[msg_id] = now
|
||||
if len(self._seen) > self._max_size:
|
||||
cutoff = now - self._ttl
|
||||
self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
|
||||
return False
|
||||
|
||||
def clear(self):
|
||||
"""Clear all tracked messages."""
|
||||
self._seen.clear()
|
||||
|
||||
|
||||
# ─── Text Batch Aggregation ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TextBatchAggregator:
|
||||
"""Aggregates rapid-fire text events into single messages.
|
||||
|
||||
Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
|
||||
previously duplicated in telegram, discord, matrix, wecom, and feishu.
|
||||
|
||||
Usage::
|
||||
|
||||
self._text_batcher = TextBatchAggregator(
|
||||
handler=self._message_handler,
|
||||
batch_delay=0.6,
|
||||
split_threshold=1900,
|
||||
)
|
||||
|
||||
# In message dispatch:
|
||||
if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
|
||||
self._text_batcher.enqueue(event, session_key)
|
||||
return
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
*,
|
||||
batch_delay: float = 0.6,
|
||||
split_delay: float = 2.0,
|
||||
split_threshold: int = 4000,
|
||||
):
|
||||
self._handler = handler
|
||||
self._batch_delay = batch_delay
|
||||
self._split_delay = split_delay
|
||||
self._split_threshold = split_threshold
|
||||
self._pending: Dict[str, "MessageEvent"] = {}
|
||||
self._pending_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Return True if batching is active (delay > 0)."""
|
||||
return self._batch_delay > 0
|
||||
|
||||
def enqueue(self, event: "MessageEvent", key: str) -> None:
|
||||
"""Add *event* to the pending batch for *key*."""
|
||||
chunk_len = len(event.text or "")
|
||||
existing = self._pending.get(key)
|
||||
if not existing:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending[key] = event
|
||||
else:
|
||||
existing.text = f"{existing.text}\n{event.text}"
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
|
||||
# Cancel prior flush timer, start a new one
|
||||
prior = self._pending_tasks.get(key)
|
||||
if prior and not prior.done():
|
||||
prior.cancel()
|
||||
self._pending_tasks[key] = asyncio.create_task(self._flush(key))
|
||||
|
||||
async def _flush(self, key: str) -> None:
|
||||
"""Wait then dispatch the batched event for *key*."""
|
||||
current_task = self._pending_tasks.get(key)
|
||||
pending = self._pending.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
|
||||
# Use longer delay when the last chunk looks like a split message
|
||||
delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
event = self._pending.pop(key, None)
|
||||
if event:
|
||||
try:
|
||||
await self._handler(event)
|
||||
except Exception:
|
||||
logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
|
||||
|
||||
if self._pending_tasks.get(key) is current_task:
|
||||
self._pending_tasks.pop(key, None)
|
||||
|
||||
def cancel_all(self) -> None:
|
||||
"""Cancel all pending flush tasks."""
|
||||
for task in self._pending_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
self._pending_tasks.clear()
|
||||
self._pending.clear()
|
||||
|
||||
|
||||
# ─── Markdown Stripping ──────────────────────────────────────────────────────
|
||||
|
||||
# Pre-compiled regexes for performance
|
||||
_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL)
|
||||
_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL)
|
||||
_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL)
|
||||
_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL)
|
||||
_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
|
||||
_RE_INLINE_CODE = re.compile(r"`(.+?)`")
|
||||
_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
|
||||
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
|
||||
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
|
||||
|
||||
|
||||
def strip_markdown(text: str) -> str:
|
||||
"""Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).
|
||||
|
||||
Replaces the identical ``_strip_markdown()`` functions previously
|
||||
duplicated in sms.py, bluebubbles.py, and feishu.py.
|
||||
"""
|
||||
text = _RE_BOLD.sub(r"\1", text)
|
||||
text = _RE_ITALIC_STAR.sub(r"\1", text)
|
||||
text = _RE_BOLD_UNDER.sub(r"\1", text)
|
||||
text = _RE_ITALIC_UNDER.sub(r"\1", text)
|
||||
text = _RE_CODE_BLOCK.sub("", text)
|
||||
text = _RE_INLINE_CODE.sub(r"\1", text)
|
||||
text = _RE_HEADING.sub("", text)
|
||||
text = _RE_LINK.sub(r"\1", text)
|
||||
text = _RE_MULTI_NEWLINE.sub("\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ─── Thread Participation Tracking ───────────────────────────────────────────
|
||||
|
||||
|
||||
class ThreadParticipationTracker:
|
||||
"""Persistent tracking of threads the bot has participated in.
|
||||
|
||||
Replaces the identical ``_load/_save_participated_threads`` +
|
||||
``_mark_thread_participated`` pattern previously duplicated in
|
||||
discord.py and matrix.py.
|
||||
|
||||
Usage::
|
||||
|
||||
self._threads = ThreadParticipationTracker("discord")
|
||||
|
||||
# Check membership:
|
||||
if thread_id in self._threads:
|
||||
...
|
||||
|
||||
# Mark participation:
|
||||
self._threads.mark(thread_id)
|
||||
"""
|
||||
|
||||
_MAX_TRACKED = 500
|
||||
|
||||
def __init__(self, platform_name: str, max_tracked: int = 500):
|
||||
self._platform = platform_name
|
||||
self._max_tracked = max_tracked
|
||||
self._threads: set = self._load()
|
||||
|
||||
def _state_path(self) -> Path:
|
||||
from hermes_constants import get_hermes_home
|
||||
return get_hermes_home() / f"{self._platform}_threads.json"
|
||||
|
||||
def _load(self) -> set:
|
||||
path = self._state_path()
|
||||
if path.exists():
|
||||
try:
|
||||
return set(json.loads(path.read_text(encoding="utf-8")))
|
||||
except Exception:
|
||||
pass
|
||||
return set()
|
||||
|
||||
def _save(self) -> None:
|
||||
path = self._state_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
thread_list = list(self._threads)
|
||||
if len(thread_list) > self._max_tracked:
|
||||
thread_list = thread_list[-self._max_tracked:]
|
||||
self._threads = set(thread_list)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
|
||||
def mark(self, thread_id: str) -> None:
|
||||
"""Mark *thread_id* as participated and persist."""
|
||||
if thread_id not in self._threads:
|
||||
self._threads.add(thread_id)
|
||||
self._save()
|
||||
|
||||
def __contains__(self, thread_id: str) -> bool:
|
||||
return thread_id in self._threads
|
||||
|
||||
def clear(self) -> None:
|
||||
self._threads.clear()
|
||||
|
||||
|
||||
# ─── Phone Number Redaction ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging, preserving country code and last 4.
|
||||
|
||||
Replaces the identical ``_redact_phone()`` functions in signal.py,
|
||||
sms.py, and bluebubbles.py.
|
||||
"""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
|
@ -92,6 +92,7 @@ from gateway.platforms.base import (
|
|||
ProcessingOutcome,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
self._pending_megolm: list = []
|
||||
|
||||
# Thread participation tracking (for require_mention bypass)
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
self._threads = ThreadParticipationTracker("matrix")
|
||||
|
||||
# Mention/thread gating — parsed once from env vars.
|
||||
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
|
@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
# Require-mention gating.
|
||||
if not is_dm:
|
||||
is_free_room = room_id in self._free_rooms
|
||||
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
|
||||
in_bot_thread = bool(thread_id and thread_id in self._threads)
|
||||
if self._require_mention and not is_free_room and not in_bot_thread:
|
||||
if not is_mentioned:
|
||||
return None
|
||||
|
|
@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
# DM mention-thread.
|
||||
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
|
||||
thread_id = event_id
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Strip mention from body.
|
||||
if is_mentioned:
|
||||
|
|
@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
# Auto-thread.
|
||||
if not is_dm and not thread_id and self._auto_thread:
|
||||
thread_id = event_id
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
display_name = await self._get_display_name(room_id, sender)
|
||||
source = self.build_source(
|
||||
|
|
@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
self._background_read_receipt(room_id, event_id)
|
||||
|
||||
|
|
@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
for rid in self._joined_rooms
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation tracking
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "matrix_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load matrix thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save matrix thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mention detection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -18,11 +18,11 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
or os.getenv("MATTERMOST_REPLY_MODE", "off")
|
||||
).lower()
|
||||
|
||||
# Dedup cache: post_id → timestamp (prevent reprocessing)
|
||||
self._seen_posts: Dict[str, float] = {}
|
||||
self._SEEN_MAX = 2000
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
# Dedup cache (prevent reprocessing)
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
|
|
@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
post_id = post.get("id", "")
|
||||
|
||||
# Dedup.
|
||||
self._prune_seen()
|
||||
if post_id in self._seen_posts:
|
||||
if self._dedup.is_duplicate(post_id):
|
||||
return
|
||||
self._seen_posts[post_id] = time.time()
|
||||
|
||||
# Build message event.
|
||||
channel_id = post.get("channel_id", "")
|
||||
|
|
@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
def _prune_seen(self) -> None:
|
||||
"""Remove expired entries from the dedup cache."""
|
||||
if len(self._seen_posts) < self._SEEN_MAX:
|
||||
return
|
||||
now = time.time()
|
||||
self._seen_posts = {
|
||||
pid: ts
|
||||
for pid, ts in self._seen_posts.items()
|
||||
if now - ts < self._SEEN_TTL
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from gateway.platforms.base import (
|
|||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0
|
|||
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
|
||||
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
|
|
@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
self._recent_sent_timestamps: set = set()
|
||||
self._max_recent_timestamps = 50
|
||||
|
||||
self._phone_lock_identity: Optional[str] = None
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
self.http_url, redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
|
||||
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._phone_lock_identity = self.account
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"signal-phone",
|
||||
self._phone_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Signal account"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Signal listener."
|
||||
)
|
||||
logger.error("Signal: %s", message)
|
||||
self._set_fatal_error("signal_phone_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
|
||||
|
|
@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
await self.client.aclose()
|
||||
self.client = None
|
||||
|
||||
if self._phone_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("signal-phone", self._phone_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
|
||||
self._phone_lock_identity = None
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("Signal: disconnected")
|
||||
|
||||
|
|
@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from pathlib import Path as _Path
|
|||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient
|
||||
self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id
|
||||
self._channel_team: Dict[str, str] = {} # channel_id → team_id
|
||||
# Dedup cache: event_ts → timestamp. Prevents duplicate bot
|
||||
# responses when Socket Mode reconnects redeliver events.
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Dedup cache: prevents duplicate bot responses when Socket Mode
|
||||
# reconnects redeliver events.
|
||||
self._dedup = MessageDeduplicator()
|
||||
# Track pending approval message_ts → resolved flag to prevent
|
||||
# double-clicks on approval buttons.
|
||||
self._approval_resolved: Dict[str, bool] = {}
|
||||
|
|
@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
logger.warning("[Slack] Failed to read %s: %s", tokens_file, e)
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate app token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = app_token
|
||||
acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('slack_token_lock', message, retryable=False)
|
||||
if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'):
|
||||
return False
|
||||
|
||||
# First token is the primary — used for AsyncApp / Socket Mode
|
||||
|
|
@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
|
||||
self._running = False
|
||||
|
||||
# Release the token lock (use stored identity, not re-read env)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('slack-app-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[Slack] Disconnected")
|
||||
|
||||
|
|
@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
"""Handle an incoming Slack message event."""
|
||||
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
|
||||
event_ts = event.get("ts", "")
|
||||
if event_ts:
|
||||
now = time.time()
|
||||
if event_ts in self._seen_messages:
|
||||
return
|
||||
self._seen_messages[event_ts] = now
|
||||
if len(self._seen_messages) > self._SEEN_MAX:
|
||||
cutoff = now - self._SEEN_TTL
|
||||
self._seen_messages = {
|
||||
k: v for k, v in self._seen_messages.items()
|
||||
if v > cutoff
|
||||
}
|
||||
if event_ts and self._dedup.is_duplicate(event_ts):
|
||||
return
|
||||
|
||||
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
|
||||
# "none" — ignore all bot messages (default, backward-compatible)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ Shares credentials with the optional telephony skill — same env vars:
|
|||
|
||||
Gateway-specific env vars:
|
||||
- SMS_WEBHOOK_PORT (default 8080)
|
||||
- SMS_WEBHOOK_HOST (default 0.0.0.0)
|
||||
- SMS_WEBHOOK_URL (public URL for Twilio signature validation — required)
|
||||
- SMS_INSECURE_NO_SIGNATURE (true to disable signature validation — dev only)
|
||||
- SMS_ALLOWED_USERS (comma-separated E.164 phone numbers)
|
||||
- SMS_ALLOW_ALL_USERS (true/false)
|
||||
- SMS_HOME_CHANNEL (phone number for cron delivery)
|
||||
|
|
@ -17,9 +20,10 @@ Gateway-specific env vars:
|
|||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
|
@ -30,24 +34,14 @@ from gateway.platforms.base import (
|
|||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.platforms.helpers import redact_phone, strip_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
|
||||
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
|
||||
DEFAULT_WEBHOOK_PORT = 8080
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:5] + "***" + phone[-4:]
|
||||
DEFAULT_WEBHOOK_HOST = "0.0.0.0"
|
||||
|
||||
|
||||
def check_sms_requirements() -> bool:
|
||||
|
|
@ -77,6 +71,8 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
self._webhook_port: int = int(
|
||||
os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
|
||||
)
|
||||
self._webhook_host: str = os.getenv("SMS_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST)
|
||||
self._webhook_url: str = os.getenv("SMS_WEBHOOK_URL", "").strip()
|
||||
self._runner = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
|
||||
|
|
@ -98,13 +94,33 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies")
|
||||
return False
|
||||
|
||||
insecure_no_sig = os.getenv("SMS_INSECURE_NO_SIGNATURE", "").lower() == "true"
|
||||
|
||||
if not self._webhook_url and not insecure_no_sig:
|
||||
logger.error(
|
||||
"[sms] Refusing to start: SMS_WEBHOOK_URL is required for Twilio "
|
||||
"signature validation. Set it to the public URL configured in your "
|
||||
"Twilio console (e.g. https://example.com/webhooks/twilio). "
|
||||
"For local development without validation, set "
|
||||
"SMS_INSECURE_NO_SIGNATURE=true (NOT recommended for production).",
|
||||
)
|
||||
return False
|
||||
|
||||
if insecure_no_sig and not self._webhook_url:
|
||||
logger.warning(
|
||||
"[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation "
|
||||
"is DISABLED. Any client that can reach port %d can inject messages. "
|
||||
"Do NOT use this in production.",
|
||||
self._webhook_port,
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/webhooks/twilio", self._handle_webhook)
|
||||
app.router.add_get("/health", lambda _: web.Response(text="ok"))
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port)
|
||||
site = web.TCPSite(self._runner, self._webhook_host, self._webhook_port)
|
||||
await site.start()
|
||||
self._http_session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
|
|
@ -112,9 +128,10 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
self._running = True
|
||||
|
||||
logger.info(
|
||||
"[sms] Twilio webhook server listening on port %d, from: %s",
|
||||
"[sms] Twilio webhook server listening on %s:%d, from: %s",
|
||||
self._webhook_host,
|
||||
self._webhook_port,
|
||||
_redact_phone(self._from_number),
|
||||
redact_phone(self._from_number),
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
@ -163,7 +180,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
error_msg = body.get("message", str(body))
|
||||
logger.error(
|
||||
"[sms] send failed to %s: %s %s",
|
||||
_redact_phone(chat_id),
|
||||
redact_phone(chat_id),
|
||||
resp.status,
|
||||
error_msg,
|
||||
)
|
||||
|
|
@ -174,7 +191,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
msg_sid = body.get("sid", "")
|
||||
last_result = SendResult(success=True, message_id=msg_sid)
|
||||
except Exception as e:
|
||||
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
|
||||
logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
finally:
|
||||
# Close session only if we created a fallback (no persistent session)
|
||||
|
|
@ -192,16 +209,75 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""Strip markdown — SMS renders it as literal characters."""
|
||||
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"```[a-z]*\n?", "", content)
|
||||
content = re.sub(r"`(.+?)`", r"\1", content)
|
||||
content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
|
||||
content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
|
||||
content = re.sub(r"\n{3,}", "\n\n", content)
|
||||
return content.strip()
|
||||
return strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio signature validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_twilio_signature(
|
||||
self, url: str, post_params: dict, signature: str,
|
||||
) -> bool:
|
||||
"""Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64).
|
||||
|
||||
Tries both with and without the default port for the URL scheme,
|
||||
since Twilio may sign with either variant.
|
||||
|
||||
Algorithm: https://www.twilio.com/docs/usage/security#validating-requests
|
||||
"""
|
||||
if self._check_signature(url, post_params, signature):
|
||||
return True
|
||||
|
||||
variant = self._port_variant_url(url)
|
||||
if variant and self._check_signature(variant, post_params, signature):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _check_signature(
|
||||
self, url: str, post_params: dict, signature: str,
|
||||
) -> bool:
|
||||
"""Compute and compare a single Twilio signature."""
|
||||
data_to_sign = url
|
||||
for key in sorted(post_params.keys()):
|
||||
data_to_sign += key + post_params[key]
|
||||
mac = hmac.new(
|
||||
self._auth_token.encode("utf-8"),
|
||||
data_to_sign.encode("utf-8"),
|
||||
hashlib.sha1,
|
||||
)
|
||||
computed = base64.b64encode(mac.digest()).decode("utf-8")
|
||||
return hmac.compare_digest(computed, signature)
|
||||
|
||||
@staticmethod
|
||||
def _port_variant_url(url: str) -> str | None:
|
||||
"""Return the URL with the default port toggled, or None.
|
||||
|
||||
Only toggles default ports (443 for https, 80 for http).
|
||||
Non-standard ports are never modified.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
default_ports = {"https": 443, "http": 80}
|
||||
default_port = default_ports.get(parsed.scheme)
|
||||
if default_port is None:
|
||||
return None
|
||||
|
||||
if parsed.port == default_port:
|
||||
# Has explicit default port → strip it
|
||||
return urllib.parse.urlunparse(
|
||||
(parsed.scheme, parsed.hostname, parsed.path,
|
||||
parsed.params, parsed.query, parsed.fragment)
|
||||
)
|
||||
elif parsed.port is None:
|
||||
# No port → add default
|
||||
netloc = f"{parsed.hostname}:{default_port}"
|
||||
return urllib.parse.urlunparse(
|
||||
(parsed.scheme, netloc, parsed.path,
|
||||
parsed.params, parsed.query, parsed.fragment)
|
||||
)
|
||||
|
||||
# Non-standard port — no variant
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio webhook handler
|
||||
|
|
@ -213,7 +289,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
raw = await request.read()
|
||||
# Twilio sends form-encoded data, not JSON
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"))
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
|
||||
except Exception as e:
|
||||
logger.error("[sms] webhook parse error: %s", e)
|
||||
return web.Response(
|
||||
|
|
@ -222,6 +298,27 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
status=400,
|
||||
)
|
||||
|
||||
# Validate Twilio request signature when SMS_WEBHOOK_URL is configured
|
||||
if self._webhook_url:
|
||||
twilio_sig = request.headers.get("X-Twilio-Signature", "")
|
||||
if not twilio_sig:
|
||||
logger.warning("[sms] Rejected: missing X-Twilio-Signature header")
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=403,
|
||||
)
|
||||
flat_params = {k: v[0] for k, v in form.items() if v}
|
||||
if not self._validate_twilio_signature(
|
||||
self._webhook_url, flat_params, twilio_sig
|
||||
):
|
||||
logger.warning("[sms] Rejected: invalid Twilio signature")
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=403,
|
||||
)
|
||||
|
||||
# Extract fields (parse_qs returns lists)
|
||||
from_number = (form.get("From", [""]))[0].strip()
|
||||
to_number = (form.get("To", [""]))[0].strip()
|
||||
|
|
@ -236,7 +333,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
|
||||
# Ignore messages from our own number (echo prevention)
|
||||
if from_number == self._from_number:
|
||||
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
|
||||
logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number))
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
|
|
@ -244,8 +341,8 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
|
||||
logger.info(
|
||||
"[sms] inbound from %s -> %s: %s",
|
||||
_redact_phone(from_number),
|
||||
_redact_phone(to_number),
|
||||
redact_phone(from_number),
|
||||
redact_phone(to_number),
|
||||
text[:80],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
self._polling_conflict_count: int = 0
|
||||
self._polling_network_error_count: int = 0
|
||||
|
|
@ -300,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
|
||||
# Exhausted retries — fatal
|
||||
message = (
|
||||
"Another Telegram bot poller is already using this token. "
|
||||
"Another process is already polling this Telegram bot token "
|
||||
"(possibly OpenClaw or another Hermes instance). "
|
||||
"Hermes stopped Telegram polling after %d retries. "
|
||||
"Make sure only one gateway instance is running for this bot token."
|
||||
"Only one poller can run per token — stop the other process "
|
||||
"and restart with 'hermes start'."
|
||||
% MAX_CONFLICT_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Original error: %s", self.name, message, error)
|
||||
|
|
@ -497,23 +498,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"telegram-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Telegram bot token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Telegram poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("telegram_token_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'):
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
|
|
@ -737,12 +722,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
message = f"Telegram startup failed: {e}"
|
||||
self._set_fatal_error("telegram_connect_error", message, retryable=True)
|
||||
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
|
||||
|
|
@ -768,12 +748,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
|
||||
for task in self._pending_photo_batch_tasks.values():
|
||||
if task and not task.done():
|
||||
|
|
@ -784,7 +759,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
self._mark_disconnected()
|
||||
self._app = None
|
||||
self._bot = None
|
||||
self._token_lock_identity = None
|
||||
logger.info("[%s] Disconnected from Telegram", self.name)
|
||||
|
||||
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ except ImportError:
|
|||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0
|
|||
HEARTBEAT_INTERVAL_SECONDS = 30.0
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
|
||||
IMAGE_MAX_BYTES = 10 * 1024 * 1024
|
||||
|
|
@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._pending_responses: Dict[str, asyncio.Future] = {}
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
|
||||
self._reply_req_ids: Dict[str, str] = {}
|
||||
|
||||
# Text batching: merge rapid successive messages (Telegram-style).
|
||||
|
|
@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
self._seen_messages.clear()
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def _cleanup_ws(self) -> None:
|
||||
|
|
@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
return
|
||||
|
||||
msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex)
|
||||
if self._is_duplicate(msg_id):
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
|
||||
return
|
||||
self._remember_reply_req_id(msg_id, self._payload_req_id(payload))
|
||||
|
|
@ -636,6 +636,13 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
if voice_text:
|
||||
text_parts.append(voice_text)
|
||||
|
||||
# Extract appmsg title (filename) for WeCom AI Bot attachments
|
||||
if msgtype == "appmsg":
|
||||
appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {}
|
||||
title = str(appmsg.get("title") or "").strip()
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
if quote_type == "text":
|
||||
|
|
@ -668,6 +675,13 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
refs.append(("image", body["image"]))
|
||||
if msgtype == "file" and isinstance(body.get("file"), dict):
|
||||
refs.append(("file", body["file"]))
|
||||
# Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel)
|
||||
if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict):
|
||||
appmsg = body["appmsg"]
|
||||
if isinstance(appmsg.get("file"), dict):
|
||||
refs.append(("file", appmsg["file"]))
|
||||
elif isinstance(appmsg.get("image"), dict):
|
||||
refs.append(("image", appmsg["image"]))
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
|
|
@ -825,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
wildcard = self._groups.get("*")
|
||||
return wildcard if isinstance(wildcard, dict) else {}
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {
|
||||
key: ts for key, ts in self._seen_messages.items() if ts > cutoff
|
||||
}
|
||||
if self._reply_req_ids:
|
||||
self._reply_req_ids = {
|
||||
key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages
|
||||
}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
|
||||
normalized_message_id = str(message_id or "").strip()
|
||||
normalized_req_id = str(req_id or "").strip()
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate
|
|||
CRYPTO_AVAILABLE = False
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -63,6 +64,7 @@ from gateway.platforms.base import (
|
|||
cache_image_from_bytes,
|
||||
)
|
||||
from hermes_constants import get_hermes_home
|
||||
from utils import atomic_json_write
|
||||
|
||||
ILINK_BASE_URL = "https://ilinkai.weixin.qq.com"
|
||||
WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
|
||||
|
|
@ -206,7 +208,7 @@ def save_weixin_account(
|
|||
"saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
}
|
||||
path = _account_file(hermes_home, account_id)
|
||||
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
atomic_json_write(path, payload)
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
|
|
@ -269,7 +271,7 @@ class ContextTokenStore:
|
|||
if key.startswith(prefix)
|
||||
}
|
||||
try:
|
||||
self._path(account_id).write_text(json.dumps(payload), encoding="utf-8")
|
||||
atomic_json_write(self._path(account_id), payload)
|
||||
except Exception as exc:
|
||||
logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc)
|
||||
|
||||
|
|
@ -868,7 +870,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str:
|
|||
|
||||
def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None:
|
||||
path = _sync_buf_path(hermes_home, account_id)
|
||||
path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8")
|
||||
atomic_json_write(path, {"get_updates_buf": sync_buf})
|
||||
|
||||
|
||||
async def qr_login(
|
||||
|
|
@ -1007,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
self._typing_cache = TypingTicketCache()
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
|
||||
|
||||
self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip()
|
||||
self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip()
|
||||
|
|
@ -1016,6 +1017,16 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
self._cdn_base_url = str(
|
||||
extra.get("cdn_base_url") or os.getenv("WEIXIN_CDN_BASE_URL", WEIXIN_CDN_BASE_URL)
|
||||
).strip().rstrip("/")
|
||||
self._send_chunk_delay_seconds = float(
|
||||
extra.get("send_chunk_delay_seconds") or os.getenv("WEIXIN_SEND_CHUNK_DELAY_SECONDS", "0.35")
|
||||
)
|
||||
self._send_chunk_retries = int(
|
||||
extra.get("send_chunk_retries") or os.getenv("WEIXIN_SEND_CHUNK_RETRIES", "2")
|
||||
)
|
||||
self._send_chunk_retry_delay_seconds = float(
|
||||
extra.get("send_chunk_retry_delay_seconds")
|
||||
or os.getenv("WEIXIN_SEND_CHUNK_RETRY_DELAY_SECONDS", "1.0")
|
||||
)
|
||||
self._dm_policy = str(extra.get("dm_policy") or os.getenv("WEIXIN_DM_POLICY", "open")).strip().lower()
|
||||
self._group_policy = str(extra.get("group_policy") or os.getenv("WEIXIN_GROUP_POLICY", "disabled")).strip().lower()
|
||||
allow_from = extra.get("allow_from")
|
||||
|
|
@ -1066,23 +1077,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self._token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"weixin-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Weixin token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Weixin poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("weixin_token_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'):
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
|
||||
|
|
@ -1106,12 +1101,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("weixin-bot-token", self._token_lock_identity)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
self._mark_disconnected()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
|
|
@ -1189,16 +1179,8 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
return
|
||||
|
||||
message_id = str(message.get("message_id") or "").strip()
|
||||
if message_id:
|
||||
now = time.time()
|
||||
self._seen_messages = {
|
||||
key: value
|
||||
for key, value in self._seen_messages.items()
|
||||
if now - value < MESSAGE_DEDUP_TTL_SECONDS
|
||||
}
|
||||
if message_id in self._seen_messages:
|
||||
return
|
||||
self._seen_messages[message_id] = now
|
||||
if message_id and self._dedup.is_duplicate(message_id):
|
||||
return
|
||||
|
||||
chat_type, effective_chat_id = _guess_chat_type(message, self._account_id)
|
||||
if chat_type == "group":
|
||||
|
|
@ -1374,6 +1356,47 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
content, self.MAX_MESSAGE_LENGTH, self._split_multiline_messages,
|
||||
)
|
||||
|
||||
async def _send_text_chunk(
|
||||
self,
|
||||
*,
|
||||
chat_id: str,
|
||||
chunk: str,
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
"""Send a single text chunk with per-chunk retry and backoff."""
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(self._send_chunk_retries + 1):
|
||||
try:
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
text=chunk,
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt >= self._send_chunk_retries:
|
||||
break
|
||||
wait = self._send_chunk_retry_delay_seconds * (attempt + 1)
|
||||
logger.warning(
|
||||
"[%s] send chunk failed to=%s attempt=%d/%d, retrying in %.2fs: %s",
|
||||
self.name,
|
||||
_safe_id(chat_id),
|
||||
attempt + 1,
|
||||
self._send_chunk_retries + 1,
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
|
@ -1388,19 +1411,16 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
chunks = self._split_text(self.format_message(content))
|
||||
for idx, chunk in enumerate(chunks):
|
||||
if idx > 0:
|
||||
await asyncio.sleep(0.3)
|
||||
client_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
text=chunk,
|
||||
await self._send_text_chunk(
|
||||
chat_id=chat_id,
|
||||
chunk=chunk,
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
last_message_id = client_id
|
||||
if idx < len(chunks) - 1 and self._send_chunk_delay_seconds > 0:
|
||||
await asyncio.sleep(self._send_chunk_delay_seconds)
|
||||
return SendResult(success=True, message_id=last_message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
|
|
|
|||
|
|
@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
def _whatsapp_require_mention(self) -> bool:
|
||||
configured = self.config.extra.get("require_mention")
|
||||
|
|
@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
|
||||
# Acquire scoped lock to prevent duplicate sessions
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._session_lock_identity = str(self._session_path)
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"whatsapp-session",
|
||||
self._session_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this WhatsApp session"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second WhatsApp bridge."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
|
||||
|
|
@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
|
@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
self._session_lock_identity = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
|
|
|
|||
458
gateway/run.py
458
gateway/run.py
|
|
@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str:
|
|||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _dequeue_pending_text(adapter, session_key: str) -> str | None:
|
||||
"""Consume and return the text of a pending queued message.
|
||||
def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
|
||||
"""Consume and return the full pending event for a session.
|
||||
|
||||
Preserves media context for captionless photo/document events by
|
||||
building a placeholder so the message isn't silently dropped.
|
||||
Queued follow-ups must preserve their media metadata so they can re-enter
|
||||
the normal image/STT/document preprocessing path instead of being reduced
|
||||
to a placeholder string.
|
||||
"""
|
||||
event = adapter.get_pending_message(session_key)
|
||||
if not event:
|
||||
return None
|
||||
text = event.text
|
||||
if not text and getattr(event, "media_urls", None):
|
||||
text = _build_media_placeholder(event)
|
||||
return text
|
||||
return adapter.get_pending_message(session_key)
|
||||
|
||||
|
||||
def _check_unavailable_skill(command_name: str) -> str | None:
|
||||
|
|
@ -1465,7 +1460,18 @@ class GatewayRunner:
|
|||
logger.info("Recovered %s background process(es) from previous run", recovered)
|
||||
except Exception as e:
|
||||
logger.warning("Process checkpoint recovery: %s", e)
|
||||
|
||||
|
||||
# Suspend sessions that were active when the gateway last exited.
|
||||
# This prevents stuck sessions from being blindly resumed on restart,
|
||||
# which can create an unrecoverable loop (#7536). Suspended sessions
|
||||
# auto-reset on the next incoming message, giving the user a clean start.
|
||||
try:
|
||||
suspended = self.session_store.suspend_recently_active()
|
||||
if suspended:
|
||||
logger.info("Suspended %d in-flight session(s) from previous run", suspended)
|
||||
except Exception as e:
|
||||
logger.warning("Session suspension on startup failed: %s", e)
|
||||
|
||||
connected_count = 0
|
||||
enabled_platform_count = 0
|
||||
startup_nonretryable_errors: list[str] = []
|
||||
|
|
@ -2221,6 +2227,13 @@ class GatewayRunner:
|
|||
# are system-generated and must skip user authorization.
|
||||
if getattr(event, "internal", False):
|
||||
pass
|
||||
elif source.user_id is None:
|
||||
# Messages with no user identity (Telegram service messages,
|
||||
# channel forwards, anonymous admin actions) cannot be
|
||||
# authorized — drop silently instead of triggering the pairing
|
||||
# flow with a None user_id.
|
||||
logger.debug("Ignoring message with no user_id from %s", source.platform.value)
|
||||
return None
|
||||
elif not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
# In DMs: offer pairing code. In groups: silently ignore.
|
||||
|
|
@ -2370,8 +2383,11 @@ class GatewayRunner:
|
|||
self._pending_messages.pop(_quick_key, None)
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
logger.info("HARD STOP for session %s — session lock released", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
||||
# Mark session suspended so the next message starts fresh
|
||||
# instead of resuming the stuck context (#7536).
|
||||
self.session_store.suspend_session(_quick_key)
|
||||
logger.info("HARD STOP for session %s — suspended, session lock released", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The session is suspended — your next message will start fresh."
|
||||
|
||||
# /reset and /new must bypass the running-agent guard so they
|
||||
# actually dispatch as commands instead of being queued as user
|
||||
|
|
@ -2761,6 +2777,162 @@ class GatewayRunner:
|
|||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
|
||||
async def _prepare_inbound_message_text(
|
||||
self,
|
||||
*,
|
||||
event: MessageEvent,
|
||||
source: SessionSource,
|
||||
history: List[Dict[str, Any]],
|
||||
) -> Optional[str]:
|
||||
"""Prepare inbound event text for the agent.
|
||||
|
||||
Keep the normal inbound path and the queued follow-up path on the same
|
||||
preprocessing pipeline so sender attribution, image enrichment, STT,
|
||||
document notes, reply context, and @ references all behave the same.
|
||||
"""
|
||||
history = history or []
|
||||
message_text = event.text or ""
|
||||
|
||||
_is_shared_thread = (
|
||||
source.chat_type != "dm"
|
||||
and source.thread_id
|
||||
and not getattr(self.config, "thread_sessions_per_user", False)
|
||||
)
|
||||
if _is_shared_thread and source.user_name:
|
||||
message_text = f"[{source.user_name}] {message_text}"
|
||||
|
||||
if event.media_urls:
|
||||
image_paths = []
|
||||
audio_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
if mtype.startswith("image/") or event.message_type == MessageType.PHOTO:
|
||||
image_paths.append(path)
|
||||
if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO):
|
||||
audio_paths.append(path)
|
||||
|
||||
if image_paths:
|
||||
message_text = await self._enrich_message_with_vision(
|
||||
message_text,
|
||||
image_paths,
|
||||
)
|
||||
|
||||
if audio_paths:
|
||||
message_text = await self._enrich_message_with_transcription(
|
||||
message_text,
|
||||
audio_paths,
|
||||
)
|
||||
_stt_fail_markers = (
|
||||
"No STT provider",
|
||||
"STT is disabled",
|
||||
"can't listen",
|
||||
"VOICE_TOOLS_OPENAI_KEY",
|
||||
)
|
||||
if any(marker in message_text for marker in _stt_fail_markers):
|
||||
_stt_adapter = self.adapters.get(source.platform)
|
||||
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
if _stt_adapter:
|
||||
try:
|
||||
_stt_msg = (
|
||||
"🎤 I received your voice message but can't transcribe it — "
|
||||
"no speech-to-text provider is configured.\n\n"
|
||||
"To enable voice: install faster-whisper "
|
||||
"(`pip install faster-whisper` in the Hermes venv) "
|
||||
"and set `stt.enabled: true` in config.yaml, "
|
||||
"then /restart the gateway."
|
||||
)
|
||||
if self._has_setup_skill():
|
||||
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
|
||||
await _stt_adapter.send(
|
||||
source.chat_id,
|
||||
_stt_msg,
|
||||
metadata=_stt_meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if event.media_urls and event.message_type == MessageType.DOCUMENT:
|
||||
import mimetypes as _mimetypes
|
||||
|
||||
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
if mtype in ("", "application/octet-stream"):
|
||||
import os as _os2
|
||||
|
||||
_ext = _os2.path.splitext(path)[1].lower()
|
||||
if _ext in _TEXT_EXTENSIONS:
|
||||
mtype = "text/plain"
|
||||
else:
|
||||
guessed, _ = _mimetypes.guess_type(path)
|
||||
if guessed:
|
||||
mtype = guessed
|
||||
if not mtype.startswith(("application/", "text/")):
|
||||
continue
|
||||
|
||||
import os as _os
|
||||
import re as _re
|
||||
|
||||
basename = _os.path.basename(path)
|
||||
parts = basename.split("_", 2)
|
||||
display_name = parts[2] if len(parts) >= 3 else basename
|
||||
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
|
||||
if mtype.startswith("text/"):
|
||||
context_note = (
|
||||
f"[The user sent a text document: '{display_name}'. "
|
||||
f"Its content has been included below. "
|
||||
f"The file is also saved at: {path}]"
|
||||
)
|
||||
else:
|
||||
context_note = (
|
||||
f"[The user sent a document: '{display_name}'. "
|
||||
f"The file is saved at: {path}. "
|
||||
f"Ask the user what they'd like you to do with it.]"
|
||||
)
|
||||
message_text = f"{context_note}\n\n{message_text}"
|
||||
|
||||
if getattr(event, "reply_to_text", None) and event.reply_to_message_id:
|
||||
reply_snippet = event.reply_to_text[:500]
|
||||
found_in_history = any(
|
||||
reply_snippet[:200] in (msg.get("content") or "")
|
||||
for msg in history
|
||||
if msg.get("role") in ("assistant", "user", "tool")
|
||||
)
|
||||
if not found_in_history:
|
||||
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
|
||||
|
||||
if "@" in message_text:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model,
|
||||
base_url=self._base_url or "",
|
||||
)
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text,
|
||||
cwd=_msg_cwd,
|
||||
context_length=_msg_ctx_len,
|
||||
allowed_root=_msg_cwd,
|
||||
)
|
||||
if _ctx_result.blocked:
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
await _adapter.send(
|
||||
source.chat_id,
|
||||
"\n".join(_ctx_result.warnings) or "Context injection refused.",
|
||||
)
|
||||
return None
|
||||
if _ctx_result.expanded:
|
||||
message_text = _ctx_result.message
|
||||
except Exception as exc:
|
||||
logger.debug("@ context reference expansion failed: %s", exc)
|
||||
|
||||
return message_text
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
_msg_start_time = time.time()
|
||||
|
|
@ -2812,7 +2984,9 @@ class GatewayRunner:
|
|||
# so the agent knows this is a fresh conversation (not an intentional /reset).
|
||||
if getattr(session_entry, 'was_auto_reset', False):
|
||||
reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle'
|
||||
if reset_reason == "daily":
|
||||
if reset_reason == "suspended":
|
||||
context_note = "[System note: The user's previous session was stopped and suspended. This is a fresh conversation with no prior context.]"
|
||||
elif reset_reason == "daily":
|
||||
context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]"
|
||||
else:
|
||||
context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]"
|
||||
|
|
@ -2829,7 +3003,9 @@ class GatewayRunner:
|
|||
)
|
||||
platform_name = source.platform.value if source.platform else ""
|
||||
had_activity = getattr(session_entry, 'reset_had_activity', False)
|
||||
should_notify = (
|
||||
# Suspended sessions always notify (they were explicitly stopped
|
||||
# or crashed mid-operation) — skip the policy check.
|
||||
should_notify = reset_reason == "suspended" or (
|
||||
policy.notify
|
||||
and had_activity
|
||||
and platform_name not in policy.notify_exclude_platforms
|
||||
|
|
@ -2837,7 +3013,9 @@ class GatewayRunner:
|
|||
if should_notify:
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
if reset_reason == "daily":
|
||||
if reset_reason == "suspended":
|
||||
reason_text = "previous session was stopped or interrupted"
|
||||
elif reset_reason == "daily":
|
||||
reason_text = f"daily schedule at {policy.at_hour}:00"
|
||||
else:
|
||||
hours = policy.idle_minutes // 60
|
||||
|
|
@ -3195,149 +3373,13 @@ class GatewayRunner:
|
|||
# attachments (documents, audio, etc.) are not sent to the vision
|
||||
# tool even when they appear in the same message.
|
||||
# -----------------------------------------------------------------
|
||||
message_text = event.text or ""
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Sender attribution for shared thread sessions.
|
||||
#
|
||||
# When multiple users share a single thread session (the default for
|
||||
# threads), prefix each message with [sender name] so the agent can
|
||||
# tell participants apart. Skip for DMs (single-user by nature) and
|
||||
# when per-user thread isolation is explicitly enabled.
|
||||
# -----------------------------------------------------------------
|
||||
_is_shared_thread = (
|
||||
source.chat_type != "dm"
|
||||
and source.thread_id
|
||||
and not getattr(self.config, "thread_sessions_per_user", False)
|
||||
message_text = await self._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=history,
|
||||
)
|
||||
if _is_shared_thread and source.user_name:
|
||||
message_text = f"[{source.user_name}] {message_text}"
|
||||
|
||||
if event.media_urls:
|
||||
image_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
# Check media_types if available; otherwise infer from message type
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
is_image = (
|
||||
mtype.startswith("image/")
|
||||
or event.message_type == MessageType.PHOTO
|
||||
)
|
||||
if is_image:
|
||||
image_paths.append(path)
|
||||
if image_paths:
|
||||
message_text = await self._enrich_message_with_vision(
|
||||
message_text, image_paths
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Auto-transcribe voice/audio messages sent by the user
|
||||
# -----------------------------------------------------------------
|
||||
if event.media_urls:
|
||||
audio_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
is_audio = (
|
||||
mtype.startswith("audio/")
|
||||
or event.message_type in (MessageType.VOICE, MessageType.AUDIO)
|
||||
)
|
||||
if is_audio:
|
||||
audio_paths.append(path)
|
||||
if audio_paths:
|
||||
message_text = await self._enrich_message_with_transcription(
|
||||
message_text, audio_paths
|
||||
)
|
||||
# If STT failed, send a direct message to the user so they
|
||||
# know voice isn't configured — don't rely on the agent to
|
||||
# relay the error clearly.
|
||||
_stt_fail_markers = (
|
||||
"No STT provider",
|
||||
"STT is disabled",
|
||||
"can't listen",
|
||||
"VOICE_TOOLS_OPENAI_KEY",
|
||||
)
|
||||
if any(m in message_text for m in _stt_fail_markers):
|
||||
_stt_adapter = self.adapters.get(source.platform)
|
||||
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
if _stt_adapter:
|
||||
try:
|
||||
_stt_msg = (
|
||||
"🎤 I received your voice message but can't transcribe it — "
|
||||
"no speech-to-text provider is configured.\n\n"
|
||||
"To enable voice: install faster-whisper "
|
||||
"(`pip install faster-whisper` in the Hermes venv) "
|
||||
"and set `stt.enabled: true` in config.yaml, "
|
||||
"then /restart the gateway."
|
||||
)
|
||||
# Point to setup skill if it's installed
|
||||
if self._has_setup_skill():
|
||||
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
|
||||
await _stt_adapter.send(
|
||||
source.chat_id, _stt_msg,
|
||||
metadata=_stt_meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Enrich document messages with context notes for the agent
|
||||
# -----------------------------------------------------------------
|
||||
if event.media_urls and event.message_type == MessageType.DOCUMENT:
|
||||
import mimetypes as _mimetypes
|
||||
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
# Fall back to extension-based detection when MIME type is unreliable.
|
||||
if mtype in ("", "application/octet-stream"):
|
||||
import os as _os2
|
||||
_ext = _os2.path.splitext(path)[1].lower()
|
||||
if _ext in _TEXT_EXTENSIONS:
|
||||
mtype = "text/plain"
|
||||
else:
|
||||
guessed, _ = _mimetypes.guess_type(path)
|
||||
if guessed:
|
||||
mtype = guessed
|
||||
if not mtype.startswith(("application/", "text/")):
|
||||
continue
|
||||
# Extract display filename by stripping the doc_{uuid12}_ prefix
|
||||
import os as _os
|
||||
basename = _os.path.basename(path)
|
||||
# Format: doc_<12hex>_<original_filename>
|
||||
parts = basename.split("_", 2)
|
||||
display_name = parts[2] if len(parts) >= 3 else basename
|
||||
# Sanitize to prevent prompt injection via filenames
|
||||
import re as _re
|
||||
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
|
||||
if mtype.startswith("text/"):
|
||||
context_note = (
|
||||
f"[The user sent a text document: '{display_name}'. "
|
||||
f"Its content has been included below. "
|
||||
f"The file is also saved at: {path}]"
|
||||
)
|
||||
else:
|
||||
context_note = (
|
||||
f"[The user sent a document: '{display_name}'. "
|
||||
f"The file is saved at: {path}. "
|
||||
f"Ask the user what they'd like you to do with it.]"
|
||||
)
|
||||
message_text = f"{context_note}\n\n{message_text}"
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Inject reply context when user replies to a message not in history.
|
||||
# Telegram (and other platforms) let users reply to specific messages,
|
||||
# but if the quoted message is from a previous session, cron delivery,
|
||||
# or background task, the agent has no context about what's being
|
||||
# referenced. Prepend the quoted text so the agent understands. (#1594)
|
||||
# -----------------------------------------------------------------
|
||||
if getattr(event, 'reply_to_text', None) and event.reply_to_message_id:
|
||||
reply_snippet = event.reply_to_text[:500]
|
||||
found_in_history = any(
|
||||
reply_snippet[:200] in (msg.get("content") or "")
|
||||
for msg in history
|
||||
if msg.get("role") in ("assistant", "user", "tool")
|
||||
)
|
||||
if not found_in_history:
|
||||
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
|
||||
if message_text is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Emit agent:start hook
|
||||
|
|
@ -3349,30 +3391,6 @@ class GatewayRunner:
|
|||
}
|
||||
await self.hooks.emit("agent:start", hook_ctx)
|
||||
|
||||
# Expand @ context references (@file:, @folder:, @diff, etc.)
|
||||
if "@" in message_text:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
from agent.model_metadata import get_model_context_length
|
||||
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model, base_url=self._base_url or "")
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text, cwd=_msg_cwd,
|
||||
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
|
||||
if _ctx_result.blocked:
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
await _adapter.send(
|
||||
source.chat_id,
|
||||
"\n".join(_ctx_result.warnings) or "Context injection refused.",
|
||||
)
|
||||
return
|
||||
if _ctx_result.expanded:
|
||||
message_text = _ctx_result.message
|
||||
except Exception as exc:
|
||||
logger.debug("@ context reference expansion failed: %s", exc)
|
||||
|
||||
# Run the agent
|
||||
agent_result = await self._run_agent(
|
||||
message=message_text,
|
||||
|
|
@ -4010,25 +4028,31 @@ class GatewayRunner:
|
|||
handles /stop before this method is reached. This handler fires
|
||||
only through normal command dispatch (no running agent) or as a
|
||||
fallback. Force-clean the session lock in all cases for safety.
|
||||
|
||||
When there IS a running/pending agent, the session is also marked
|
||||
as *suspended* so the next message starts a fresh session instead
|
||||
of resuming the stuck context (#7536).
|
||||
"""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
|
||||
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
logger.info("HARD STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
||||
self.session_store.suspend_session(session_key)
|
||||
logger.info("HARD STOP (pending) for session %s — suspended, sentinel cleared", session_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — your next message will start fresh."
|
||||
if agent:
|
||||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
||||
self.session_store.suspend_session(session_key)
|
||||
return "⚡ Force-stopped. Your next message will start a fresh session."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
||||
|
|
@ -6694,6 +6718,8 @@ class GatewayRunner:
|
|||
chat_id=context.source.chat_id,
|
||||
chat_name=context.source.chat_name or "",
|
||||
thread_id=str(context.source.thread_id) if context.source.thread_id else "",
|
||||
user_id=str(context.source.user_id) if context.source.user_id else "",
|
||||
user_name=str(context.source.user_name) if context.source.user_name else "",
|
||||
)
|
||||
|
||||
def _clear_session_env(self, tokens: list) -> None:
|
||||
|
|
@ -6906,6 +6932,8 @@ class GatewayRunner:
|
|||
platform_name = watcher.get("platform", "")
|
||||
chat_id = watcher.get("chat_id", "")
|
||||
thread_id = watcher.get("thread_id", "")
|
||||
user_id = watcher.get("user_id", "")
|
||||
user_name = watcher.get("user_name", "")
|
||||
agent_notify = watcher.get("notify_on_complete", False)
|
||||
notify_mode = self._load_background_notifications_mode()
|
||||
|
||||
|
|
@ -6961,6 +6989,8 @@ class GatewayRunner:
|
|||
platform=_platform_enum,
|
||||
chat_id=chat_id,
|
||||
thread_id=thread_id or None,
|
||||
user_id=user_id or None,
|
||||
user_name=user_name or None,
|
||||
)
|
||||
synth_event = MessageEvent(
|
||||
text=synth_text,
|
||||
|
|
@ -8115,17 +8145,16 @@ class GatewayRunner:
|
|||
|
||||
# Get pending message from adapter.
|
||||
# Use session_key (not source.chat_id) to match adapter's storage keys.
|
||||
pending_event = None
|
||||
pending = None
|
||||
if result and adapter and session_key:
|
||||
if result.get("interrupted"):
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if not pending and result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
else:
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if pending:
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
elif pending_event:
|
||||
pending = pending_event.text or _build_media_placeholder(pending_event)
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
# Safety net: if the pending text is a slash command (e.g. "/stop",
|
||||
# "/new"), discard it — commands should never be passed to the agent
|
||||
# as user input. The primary fix is in base.py (commands bypass the
|
||||
|
|
@ -8143,27 +8172,29 @@ class GatewayRunner:
|
|||
"commands must not be passed as agent input",
|
||||
_pending_cmd_word,
|
||||
)
|
||||
pending_event = None
|
||||
pending = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self._draining and pending:
|
||||
if self._draining and (pending_event or pending):
|
||||
logger.info(
|
||||
"Discarding pending follow-up for session %s during gateway %s",
|
||||
session_key[:20] if session_key else "?",
|
||||
self._status_action_label(),
|
||||
)
|
||||
pending_event = None
|
||||
pending = None
|
||||
|
||||
if pending:
|
||||
if pending_event or pending:
|
||||
logger.debug("Processing pending message: '%s...'", pending[:40])
|
||||
|
||||
|
||||
# Clear the adapter's interrupt event so the next _run_agent call
|
||||
# doesn't immediately re-trigger the interrupt before the new agent
|
||||
# even makes its first API call (this was causing an infinite loop).
|
||||
if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions:
|
||||
adapter._active_sessions[session_key].clear()
|
||||
|
||||
|
||||
# Cap recursion depth to prevent resource exhaustion when the
|
||||
# user sends multiple messages while the agent keeps failing. (#816)
|
||||
if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH:
|
||||
|
|
@ -8172,9 +8203,10 @@ class GatewayRunner:
|
|||
"queueing message instead of recursing.",
|
||||
_interrupt_depth, session_key,
|
||||
)
|
||||
# Queue the pending message for normal processing on next turn
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, 'queue_message'):
|
||||
if adapter and pending_event:
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, pending_event)
|
||||
elif adapter and hasattr(adapter, 'queue_message'):
|
||||
adapter.queue_message(session_key, pending)
|
||||
return result_holder[0] or {"final_response": response, "messages": history}
|
||||
|
||||
|
|
@ -8189,23 +8221,37 @@ class GatewayRunner:
|
|||
if first_response and not _already_streamed:
|
||||
try:
|
||||
await adapter.send(source.chat_id, first_response,
|
||||
metadata=getattr(event, "metadata", None))
|
||||
metadata={"thread_id": source.thread_id} if source.thread_id else None)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send first response before queued message: %s", e)
|
||||
# else: interrupted — discard the interrupted response ("Operation
|
||||
# interrupted." is just noise; the user already knows they sent a
|
||||
# new message).
|
||||
|
||||
# Process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
next_source = source
|
||||
next_message = pending
|
||||
next_message_id = None
|
||||
if pending_event is not None:
|
||||
next_source = getattr(pending_event, "source", None) or source
|
||||
next_message = await self._prepare_inbound_message_text(
|
||||
event=pending_event,
|
||||
source=next_source,
|
||||
history=updated_history,
|
||||
)
|
||||
if next_message is None:
|
||||
return result
|
||||
next_message_id = getattr(pending_event, "message_id", None)
|
||||
|
||||
return await self._run_agent(
|
||||
message=pending,
|
||||
message=next_message,
|
||||
context_prompt=context_prompt,
|
||||
history=updated_history,
|
||||
source=source,
|
||||
source=next_source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
_interrupt_depth=_interrupt_depth + 1,
|
||||
event_message_id=next_message_id,
|
||||
)
|
||||
finally:
|
||||
# Stop progress sender, interrupt monitor, and notification task
|
||||
|
|
|
|||
|
|
@ -368,6 +368,11 @@ class SessionEntry:
|
|||
# survives gateway restarts (the old in-memory _pre_flushed_sessions
|
||||
# set was lost on restart, causing redundant re-flushes).
|
||||
memory_flushed: bool = False
|
||||
|
||||
# When True the next call to get_or_create_session() will auto-reset
|
||||
# this session (create a new session_id) so the user starts fresh.
|
||||
# Set by /stop to break stuck-resume loops (#7536).
|
||||
suspended: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
|
|
@ -387,6 +392,7 @@ class SessionEntry:
|
|||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"cost_status": self.cost_status,
|
||||
"memory_flushed": self.memory_flushed,
|
||||
"suspended": self.suspended,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
|
|
@ -423,6 +429,7 @@ class SessionEntry:
|
|||
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
|
||||
cost_status=data.get("cost_status", "unknown"),
|
||||
memory_flushed=data.get("memory_flushed", False),
|
||||
suspended=data.get("suspended", False),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -698,7 +705,12 @@ class SessionStore:
|
|||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
# Auto-reset sessions marked as suspended (e.g. after /stop
|
||||
# broke a stuck loop — #7536).
|
||||
if entry.suspended:
|
||||
reset_reason = "suspended"
|
||||
else:
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
if not reset_reason:
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
|
|
@ -771,6 +783,44 @@ class SessionStore:
|
|||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
self._save()
|
||||
|
||||
def suspend_session(self, session_key: str) -> bool:
|
||||
"""Mark a session as suspended so it auto-resets on next access.
|
||||
|
||||
Used by ``/stop`` to prevent stuck sessions from being resumed
|
||||
after a gateway restart (#7536). Returns True if the session
|
||||
existed and was marked.
|
||||
"""
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
if session_key in self._entries:
|
||||
self._entries[session_key].suspended = True
|
||||
self._save()
|
||||
return True
|
||||
return False
|
||||
|
||||
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
|
||||
"""Mark recently-active sessions as suspended.
|
||||
|
||||
Called on gateway startup to prevent sessions that were likely
|
||||
in-flight when the gateway last exited from being blindly resumed
|
||||
(#7536). Only suspends sessions updated within *max_age_seconds*
|
||||
to avoid resetting long-idle sessions that are harmless to resume.
|
||||
Returns the number of sessions that were suspended.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
cutoff = _time.time() - max_age_seconds
|
||||
count = 0
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
for entry in self._entries.values():
|
||||
if not entry.suspended and entry.updated_at >= cutoff:
|
||||
entry.suspended = True
|
||||
count += 1
|
||||
if count:
|
||||
self._save()
|
||||
return count
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
db_end_session_id = None
|
||||
|
|
|
|||
|
|
@ -46,12 +46,16 @@ _SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", defau
|
|||
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
|
||||
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
|
||||
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
|
||||
_SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="")
|
||||
_SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="")
|
||||
|
||||
_VAR_MAP = {
|
||||
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
|
||||
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
|
||||
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
|
||||
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
|
||||
"HERMES_SESSION_USER_ID": _SESSION_USER_ID,
|
||||
"HERMES_SESSION_USER_NAME": _SESSION_USER_NAME,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -60,6 +64,8 @@ def set_session_vars(
|
|||
chat_id: str = "",
|
||||
chat_name: str = "",
|
||||
thread_id: str = "",
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
) -> list:
|
||||
"""Set all session context variables and return reset tokens.
|
||||
|
||||
|
|
@ -74,6 +80,8 @@ def set_session_vars(
|
|||
_SESSION_CHAT_ID.set(chat_id),
|
||||
_SESSION_CHAT_NAME.set(chat_name),
|
||||
_SESSION_THREAD_ID.set(thread_id),
|
||||
_SESSION_USER_ID.set(user_id),
|
||||
_SESSION_USER_NAME.set(user_name),
|
||||
]
|
||||
return tokens
|
||||
|
||||
|
|
@ -87,6 +95,8 @@ def clear_session_vars(tokens: list) -> None:
|
|||
_SESSION_CHAT_ID,
|
||||
_SESSION_CHAT_NAME,
|
||||
_SESSION_THREAD_ID,
|
||||
_SESSION_USER_ID,
|
||||
_SESSION_USER_NAME,
|
||||
]
|
||||
for var, token in zip(vars_in_order, tokens):
|
||||
var.reset(token)
|
||||
|
|
|
|||
|
|
@ -261,6 +261,28 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
|||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Key Helper
|
||||
# =============================================================================
|
||||
|
||||
def get_anthropic_key() -> str:
|
||||
"""Return the first usable Anthropic credential, or ``""``.
|
||||
|
||||
Checks both the ``.env`` file (via ``get_env_value``) and the process
|
||||
environment (``os.getenv``). The fallback order mirrors the
|
||||
``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple:
|
||||
|
||||
ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN
|
||||
"""
|
||||
from hermes_cli.config import get_env_value
|
||||
|
||||
for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars:
|
||||
value = get_env_value(var) or os.getenv(var, "")
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi Code Endpoint Detection
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -52,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = (
|
|||
# Known OpenClaw directory names (current + legacy)
|
||||
_OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot")
|
||||
|
||||
def _warn_if_gateway_running(auto_yes: bool) -> None:
|
||||
"""Check if a Hermes gateway is running with connected platforms.
|
||||
|
||||
Migrating bot tokens while the gateway is polling will cause conflicts
|
||||
(e.g. Telegram 409 "terminated by other getUpdates request"). Warn the
|
||||
user and let them decide whether to continue.
|
||||
"""
|
||||
from gateway.status import get_running_pid, read_runtime_status
|
||||
|
||||
if not get_running_pid():
|
||||
return
|
||||
|
||||
data = read_runtime_status() or {}
|
||||
platforms = data.get("platforms") or {}
|
||||
connected = [name for name, info in platforms.items()
|
||||
if isinstance(info, dict) and info.get("state") == "connected"]
|
||||
if not connected:
|
||||
return
|
||||
|
||||
print()
|
||||
print_error(
|
||||
"Hermes gateway is running with active connections: "
|
||||
+ ", ".join(connected)
|
||||
)
|
||||
print_info(
|
||||
"Migrating bot tokens while the gateway is active will cause "
|
||||
"conflicts (Telegram, Discord, and Slack only allow one active "
|
||||
"session per token)."
|
||||
)
|
||||
print_info("Recommendation: stop the gateway first with 'hermes stop'.")
|
||||
print()
|
||||
if not auto_yes and not prompt_yes_no("Continue anyway?", default=False):
|
||||
print_info("Migration cancelled. Stop the gateway and try again.")
|
||||
sys.exit(0)
|
||||
|
||||
# State files commonly found in OpenClaw workspace directories that cause
|
||||
# confusion after migration (the agent discovers them and writes to them)
|
||||
_WORKSPACE_STATE_GLOBS = (
|
||||
|
|
@ -252,6 +287,10 @@ def _cmd_migrate(args):
|
|||
print_info(f"Workspace: {workspace_target}")
|
||||
print()
|
||||
|
||||
# Check if a gateway is running with connected platforms — migrating tokens
|
||||
# while the gateway is active will cause conflicts (e.g. Telegram 409).
|
||||
_warn_if_gateway_running(auto_yes)
|
||||
|
||||
# Ensure config.yaml exists before migration tries to read it
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
|
|
|
|||
79
hermes_cli/cli_output.py
Normal file
79
hermes_cli/cli_output.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Shared CLI output helpers for Hermes CLI modules.
|
||||
|
||||
Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
|
||||
functions previously duplicated across setup.py, tools_config.py,
|
||||
mcp_config.py, and memory_setup.py.
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
# ─── Print Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def print_info(text: str) -> None:
|
||||
"""Print a dim informational message."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def print_success(text: str) -> None:
|
||||
"""Print a green success message with ✓ prefix."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def print_warning(text: str) -> None:
|
||||
"""Print a yellow warning message with ⚠ prefix."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def print_error(text: str) -> None:
|
||||
"""Print a red error message with ✗ prefix."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def print_header(text: str) -> None:
|
||||
"""Print a bold yellow header."""
|
||||
print(color(f"\n {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
# ─── Input Prompts ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def prompt(
|
||||
question: str,
|
||||
default: str | None = None,
|
||||
password: bool = False,
|
||||
) -> str:
|
||||
"""Prompt the user for input with optional default and password masking.
|
||||
|
||||
Replaces the four independent ``_prompt()`` / ``prompt()`` implementations
|
||||
in setup.py, tools_config.py, mcp_config.py, and memory_setup.py.
|
||||
|
||||
Returns the user's input (stripped), or *default* if the user presses Enter.
|
||||
Returns empty string on Ctrl-C or EOF.
|
||||
"""
|
||||
suffix = f" [{default}]" if default else ""
|
||||
display = color(f" {question}{suffix}: ", Colors.YELLOW)
|
||||
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(display)
|
||||
else:
|
||||
value = input(display)
|
||||
value = value.strip()
|
||||
return value if value else (default or "")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return ""
|
||||
|
||||
|
||||
def prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
"""Prompt for a yes/no answer. Returns bool."""
|
||||
hint = "Y/n" if default else "y/N"
|
||||
answer = prompt(f"{question} ({hint})")
|
||||
if not answer:
|
||||
return default
|
||||
return answer.lower().startswith("y")
|
||||
|
|
@ -1497,7 +1497,7 @@ _KNOWN_ROOT_KEYS = {
|
|||
|
||||
# Valid fields inside a custom_providers list entry
|
||||
_VALID_CUSTOM_PROVIDER_FIELDS = {
|
||||
"name", "base_url", "api_key", "api_mode", "models",
|
||||
"name", "base_url", "api_key", "api_mode", "model", "models",
|
||||
"context_length", "rate_limit_delay",
|
||||
}
|
||||
|
||||
|
|
@ -2582,7 +2582,8 @@ def show_config():
|
|||
for env_key, name in keys:
|
||||
value = get_env_value(env_key)
|
||||
print(f" {name:<14} {redact_key(value)}")
|
||||
anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY")
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_value = get_anthropic_key()
|
||||
print(f" {'Anthropic':<14} {redact_key(anthropic_value)}")
|
||||
|
||||
# Model settings
|
||||
|
|
@ -2798,8 +2799,8 @@ def set_config_value(key: str, value: str):
|
|||
|
||||
# Write only user config back (not the full merged defaults)
|
||||
ensure_hermes_home()
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
from utils import atomic_yaml_write
|
||||
atomic_yaml_write(config_path, user_config, sort_keys=False)
|
||||
|
||||
# Keep .env in sync for keys that terminal_tool reads directly from env vars.
|
||||
# config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc.
|
||||
|
|
|
|||
|
|
@ -336,8 +336,8 @@ def run_doctor(args):
|
|||
model_section[k] = raw_config.pop(k)
|
||||
else:
|
||||
raw_config.pop(k)
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(raw_config, f, default_flow_style=False)
|
||||
from utils import atomic_yaml_write
|
||||
atomic_yaml_write(config_path, raw_config)
|
||||
check_ok("Migrated stale root-level keys into model section")
|
||||
fixed_count += 1
|
||||
else:
|
||||
|
|
@ -686,7 +686,8 @@ def run_doctor(args):
|
|||
else:
|
||||
check_warn("OpenRouter API", "(not configured)")
|
||||
|
||||
anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY")
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_key = get_anthropic_key()
|
||||
if anthropic_key:
|
||||
print(" Checking Anthropic API...", end="", flush=True)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -157,30 +157,54 @@ def _request_gateway_self_restart(pid: int) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
"""
|
||||
pids = []
|
||||
_exclude = exclude_pids or set()
|
||||
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli.main --profile",
|
||||
"hermes_cli.main -p",
|
||||
"hermes_cli/main.py gateway",
|
||||
"hermes_cli/main.py --profile",
|
||||
"hermes_cli/main.py -p",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
]
|
||||
current_home = str(get_hermes_home().resolve())
|
||||
current_profile_arg = _profile_arg(current_home)
|
||||
current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else ""
|
||||
|
||||
def _matches_current_profile(command: str) -> bool:
|
||||
if current_profile_name:
|
||||
return (
|
||||
f"--profile {current_profile_name}" in command
|
||||
or f"-p {current_profile_name}" in command
|
||||
or f"HERMES_HOME={current_home}" in command
|
||||
)
|
||||
|
||||
if "--profile " in command or " -p " in command:
|
||||
return False
|
||||
if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command:
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
if is_windows():
|
||||
# Windows: use wmic to search command lines
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
)
|
||||
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
line = line.strip()
|
||||
|
|
@ -188,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
|||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns):
|
||||
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
if pid != os.getpid() and pid not in pids and pid not in _exclude:
|
||||
|
|
@ -198,41 +222,57 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
|||
current_cmd = ""
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["ps", "aux"],
|
||||
["ps", "eww", "-ax", "-o", "pid=,command="],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
# Skip grep and current process
|
||||
if 'grep' in line or str(os.getpid()) in line:
|
||||
stripped = line.strip()
|
||||
if not stripped or 'grep' in stripped:
|
||||
continue
|
||||
for pattern in patterns:
|
||||
if pattern in line:
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
if pid not in pids and pid not in _exclude:
|
||||
pids.append(pid)
|
||||
except ValueError:
|
||||
continue
|
||||
break
|
||||
except Exception:
|
||||
|
||||
pid = None
|
||||
command = ""
|
||||
|
||||
parts = stripped.split(None, 1)
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
command = parts[1]
|
||||
except ValueError:
|
||||
pid = None
|
||||
|
||||
if pid is None:
|
||||
aux_parts = stripped.split()
|
||||
if len(aux_parts) > 10 and aux_parts[1].isdigit():
|
||||
pid = int(aux_parts[1])
|
||||
command = " ".join(aux_parts[10:])
|
||||
|
||||
if pid is None:
|
||||
continue
|
||||
if pid == os.getpid() or pid in pids or pid in _exclude:
|
||||
continue
|
||||
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
|
||||
pids.append(pid)
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int:
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
|
||||
all_profiles: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
|
||||
Args:
|
||||
force: Use the platform's force-kill mechanism instead of graceful terminate.
|
||||
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
|
||||
restarted and should not be killed).
|
||||
all_profiles: When ``True``, kill across all profiles. Passed
|
||||
through to :func:`find_gateway_pids`.
|
||||
"""
|
||||
pids = find_gateway_pids(exclude_pids=exclude_pids)
|
||||
pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles)
|
||||
killed = 0
|
||||
|
||||
for pid in pids:
|
||||
|
|
@ -633,6 +673,17 @@ def print_systemd_linger_guidance() -> None:
|
|||
print(" If you want the gateway user service to survive logout, run:")
|
||||
print(" sudo loginctl enable-linger $USER")
|
||||
|
||||
def _launchd_user_home() -> Path:
|
||||
"""Return the real macOS user home for launchd artifacts.
|
||||
|
||||
Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but
|
||||
launchd user agents still live under the actual account home.
|
||||
"""
|
||||
import pwd
|
||||
|
||||
return Path(pwd.getpwuid(os.getuid()).pw_dir)
|
||||
|
||||
|
||||
def get_launchd_plist_path() -> Path:
|
||||
"""Return the launchd plist path, scoped per profile.
|
||||
|
||||
|
|
@ -641,7 +692,7 @@ def get_launchd_plist_path() -> Path:
|
|||
"""
|
||||
suffix = _profile_suffix()
|
||||
name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway"
|
||||
return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist"
|
||||
return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist"
|
||||
|
||||
def _detect_venv_dir() -> Path | None:
|
||||
"""Detect the active virtualenv directory.
|
||||
|
|
@ -839,6 +890,25 @@ def _normalize_service_definition(text: str) -> str:
|
|||
return "\n".join(line.rstrip() for line in text.strip().splitlines())
|
||||
|
||||
|
||||
def _normalize_launchd_plist_for_comparison(text: str) -> str:
|
||||
"""Normalize launchd plist text for staleness checks.
|
||||
|
||||
The generated plist intentionally captures a broad PATH assembled from the
|
||||
invoking shell so user-installed tools remain reachable under launchd.
|
||||
That makes raw text comparison unstable across shells, so ignore the PATH
|
||||
payload when deciding whether the installed plist is stale.
|
||||
"""
|
||||
import re
|
||||
|
||||
normalized = _normalize_service_definition(text)
|
||||
return re.sub(
|
||||
r'(<key>PATH</key>\s*<string>)(.*?)(</string>)',
|
||||
r'\1__HERMES_PATH__\3',
|
||||
normalized,
|
||||
flags=re.S,
|
||||
)
|
||||
|
||||
|
||||
def systemd_unit_is_current(system: bool = False) -> bool:
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if not unit_path.exists():
|
||||
|
|
@ -1220,7 +1290,7 @@ def launchd_plist_is_current() -> bool:
|
|||
|
||||
installed = plist_path.read_text(encoding="utf-8")
|
||||
expected = generate_launchd_plist()
|
||||
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
|
||||
return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected)
|
||||
|
||||
|
||||
def refresh_launchd_plist_if_needed() -> bool:
|
||||
|
|
@ -1981,6 +2051,36 @@ def _setup_whatsapp():
|
|||
cmd_whatsapp(argparse.Namespace())
|
||||
|
||||
|
||||
def _setup_email():
|
||||
"""Configure Email via the standard platform setup."""
|
||||
email_platform = next(p for p in _PLATFORMS if p["key"] == "email")
|
||||
_setup_standard_platform(email_platform)
|
||||
|
||||
|
||||
def _setup_sms():
|
||||
"""Configure SMS (Twilio) via the standard platform setup."""
|
||||
sms_platform = next(p for p in _PLATFORMS if p["key"] == "sms")
|
||||
_setup_standard_platform(sms_platform)
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via the standard platform setup."""
|
||||
dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
|
||||
|
||||
def _setup_feishu():
|
||||
"""Configure Feishu / Lark via the standard platform setup."""
|
||||
feishu_platform = next(p for p in _PLATFORMS if p["key"] == "feishu")
|
||||
_setup_standard_platform(feishu_platform)
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
"""Configure WeCom (Enterprise WeChat) via the standard platform setup."""
|
||||
wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom")
|
||||
_setup_standard_platform(wecom_platform)
|
||||
|
||||
|
||||
def _is_service_installed() -> bool:
|
||||
"""Check if the gateway is installed as a system service."""
|
||||
if supports_systemd_services():
|
||||
|
|
@ -2540,7 +2640,7 @@ def gateway_command(args):
|
|||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
killed = kill_gateway_processes()
|
||||
killed = kill_gateway_processes(all_profiles=True)
|
||||
total = killed + (1 if service_available else 0)
|
||||
if total:
|
||||
print(f"✓ Stopped {total} gateway process(es) across all profiles")
|
||||
|
|
|
|||
|
|
@ -2758,13 +2758,8 @@ def _model_flow_anthropic(config, current_model=""):
|
|||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
|
||||
# Check ALL credential sources
|
||||
existing_key = (
|
||||
get_env_value("ANTHROPIC_TOKEN")
|
||||
or os.getenv("ANTHROPIC_TOKEN", "")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
or os.getenv("ANTHROPIC_API_KEY", "")
|
||||
or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
|
||||
)
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
existing_key = get_anthropic_key()
|
||||
cc_available = False
|
||||
try:
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid
|
||||
|
|
@ -4090,7 +4085,7 @@ def cmd_update(args):
|
|||
# Exclude PIDs that belong to just-restarted services so we don't
|
||||
# immediately kill the process that systemd/launchd just spawned.
|
||||
service_pids = _get_service_pids()
|
||||
manual_pids = find_gateway_pids(exclude_pids=service_pids)
|
||||
manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True)
|
||||
for pid in manual_pids:
|
||||
try:
|
||||
os.kill(pid, _signal.SIGTERM)
|
||||
|
|
|
|||
|
|
@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool:
|
|||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
from hermes_cli.cli_output import prompt as _shared_prompt
|
||||
return _shared_prompt(question, default=default, password=password)
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -
|
|||
items: list of (label, description) tuples.
|
||||
Returns selected index, or default on escape/quit.
|
||||
"""
|
||||
try:
|
||||
import curses
|
||||
result = [default]
|
||||
|
||||
def _menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
curses.init_pair(3, curses.COLOR_CYAN, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Title
|
||||
try:
|
||||
stdscr.addnstr(0, 0, title, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
|
||||
curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, (label, desc) in enumerate(items):
|
||||
y = i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {label}"
|
||||
if desc:
|
||||
line += f" {desc}"
|
||||
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(items)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(items)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_menu)
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Fallback: numbered input
|
||||
print(f"\n {title}\n")
|
||||
for i, (label, desc) in enumerate(items):
|
||||
marker = "→" if i == default else " "
|
||||
d = f" {desc}" if desc else ""
|
||||
print(f" {marker} {i + 1}. {label}{d}")
|
||||
while True:
|
||||
try:
|
||||
val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(items):
|
||||
return idx
|
||||
except (ValueError, EOFError):
|
||||
return default
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
# Format (label, desc) tuples into display strings
|
||||
display_items = [
|
||||
f"{label} {desc}" if desc else label
|
||||
for label, desc in items
|
||||
]
|
||||
return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
|
||||
|
||||
|
||||
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
|
||||
|
|
|
|||
45
hermes_cli/platforms.py
Normal file
45
hermes_cli/platforms.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
"""
|
||||
Shared platform registry for Hermes Agent.
|
||||
|
||||
Single source of truth for platform metadata consumed by both
|
||||
skills_config (label display) and tools_config (default toolset
|
||||
resolution). Import ``PLATFORMS`` from here instead of maintaining
|
||||
duplicate dicts in each module.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class PlatformInfo(NamedTuple):
|
||||
"""Metadata for a single platform entry."""
|
||||
label: str
|
||||
default_toolset: str
|
||||
|
||||
|
||||
# Ordered so that TUI menus are deterministic.
|
||||
PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
|
||||
("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")),
|
||||
("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")),
|
||||
("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")),
|
||||
("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")),
|
||||
("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")),
|
||||
("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")),
|
||||
("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")),
|
||||
("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")),
|
||||
("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")),
|
||||
("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")),
|
||||
("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")),
|
||||
("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")),
|
||||
("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")),
|
||||
("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
|
||||
("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
|
||||
("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
|
||||
("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
|
||||
])
|
||||
|
||||
|
||||
def platform_label(key: str, default: str = "") -> str:
|
||||
"""Return the display label for a platform key, or *default*."""
|
||||
info = PLATFORMS.get(key)
|
||||
return info.label if info is not None else default
|
||||
|
|
@ -304,6 +304,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
api_mode = _parse_api_mode(entry.get("api_mode"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
model_name = str(entry.get("model", "") or "").strip()
|
||||
if model_name:
|
||||
result["model"] = model_name
|
||||
return result
|
||||
|
||||
return None
|
||||
|
|
@ -329,6 +332,11 @@ def _resolve_named_custom_runtime(
|
|||
# Check if a credential pool exists for this custom endpoint
|
||||
pool_result = _try_resolve_from_custom_pool(base_url, "custom", custom_provider.get("api_mode"))
|
||||
if pool_result:
|
||||
# Propagate the model name even when using pooled credentials —
|
||||
# the pool doesn't know about the custom_providers model field.
|
||||
model_name = custom_provider.get("model")
|
||||
if model_name:
|
||||
pool_result["model"] = model_name
|
||||
return pool_result
|
||||
|
||||
api_key_candidates = [
|
||||
|
|
@ -339,7 +347,7 @@ def _resolve_named_custom_runtime(
|
|||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
return {
|
||||
result = {
|
||||
"provider": "custom",
|
||||
"api_mode": custom_provider.get("api_mode")
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
|
|
@ -348,6 +356,11 @@ def _resolve_named_custom_runtime(
|
|||
"api_key": api_key or "no-key-required",
|
||||
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||
}
|
||||
# Propagate the model name so callers can override self.model when the
|
||||
# provider name differs from the actual model string the API expects.
|
||||
if custom_provider.get("model"):
|
||||
result["model"] = custom_provider["model"]
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_openrouter_runtime(
|
||||
|
|
|
|||
|
|
@ -197,24 +197,12 @@ def print_header(title: str):
|
|||
print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
def print_info(text: str):
|
||||
"""Print info text."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def print_success(text: str):
|
||||
"""Print success message."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def print_warning(text: str):
|
||||
"""Print warning message."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def print_error(text: str):
|
||||
"""Print error message."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
from hermes_cli.cli_output import ( # noqa: E402
|
||||
print_error,
|
||||
print_info,
|
||||
print_success,
|
||||
print_warning,
|
||||
)
|
||||
|
||||
|
||||
def is_interactive_stdin() -> bool:
|
||||
|
|
@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
|
|||
|
||||
|
||||
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu using curses to avoid simple_term_menu rendering bugs."""
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Rows available for list items: rows 2..(max_y-2) inclusive.
|
||||
visible = max(1, max_y - 3)
|
||||
|
||||
# Scroll the viewport so the cursor is always visible.
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible:
|
||||
scroll_offset = cursor - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
|
||||
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
0,
|
||||
question,
|
||||
max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
|
||||
y = row + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {choices[i]}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
except Exception:
|
||||
return -1
|
||||
"""Single-select menu using curses. Delegates to curses_radiolist."""
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
return curses_radiolist(question, choices, selected=default, cancel_returns=-1)
|
||||
|
||||
|
||||
|
||||
|
|
@ -2052,6 +1969,42 @@ def _setup_weixin():
|
|||
_gateway_setup_weixin()
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Configure Signal via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_signal as _gateway_setup_signal
|
||||
_gateway_setup_signal()
|
||||
|
||||
|
||||
def _setup_email():
|
||||
"""Configure Email via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_email as _gateway_setup_email
|
||||
_gateway_setup_email()
|
||||
|
||||
|
||||
def _setup_sms():
|
||||
"""Configure SMS (Twilio) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_sms as _gateway_setup_sms
|
||||
_gateway_setup_sms()
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk
|
||||
_gateway_setup_dingtalk()
|
||||
|
||||
|
||||
def _setup_feishu():
|
||||
"""Configure Feishu / Lark via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu
|
||||
_gateway_setup_feishu()
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
"""Configure WeCom (Enterprise WeChat) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom
|
||||
_gateway_setup_wecom()
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
"""Configure BlueBubbles iMessage gateway."""
|
||||
print_header("BlueBubbles (iMessage)")
|
||||
|
|
@ -2168,9 +2121,15 @@ _GATEWAY_PLATFORMS = [
|
|||
("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram),
|
||||
("Discord", "DISCORD_BOT_TOKEN", _setup_discord),
|
||||
("Slack", "SLACK_BOT_TOKEN", _setup_slack),
|
||||
("Signal", "SIGNAL_HTTP_URL", _setup_signal),
|
||||
("Email", "EMAIL_ADDRESS", _setup_email),
|
||||
("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms),
|
||||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk),
|
||||
("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu),
|
||||
("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom),
|
||||
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
|
|
@ -2212,10 +2171,17 @@ def setup_gateway(config: dict):
|
|||
get_env_value("TELEGRAM_BOT_TOKEN")
|
||||
or get_env_value("DISCORD_BOT_TOKEN")
|
||||
or get_env_value("SLACK_BOT_TOKEN")
|
||||
or get_env_value("SIGNAL_HTTP_URL")
|
||||
or get_env_value("EMAIL_ADDRESS")
|
||||
or get_env_value("TWILIO_ACCOUNT_SID")
|
||||
or get_env_value("MATTERMOST_TOKEN")
|
||||
or get_env_value("MATRIX_ACCESS_TOKEN")
|
||||
or get_env_value("MATRIX_PASSWORD")
|
||||
or get_env_value("WHATSAPP_ENABLED")
|
||||
or get_env_value("DINGTALK_CLIENT_ID")
|
||||
or get_env_value("FEISHU_APP_ID")
|
||||
or get_env_value("WECOM_BOT_ID")
|
||||
or get_env_value("WEIXIN_ACCOUNT_ID")
|
||||
or get_env_value("BLUEBUBBLES_SERVER_URL")
|
||||
or get_env_value("WEBHOOK_ENABLED")
|
||||
)
|
||||
|
|
@ -2404,12 +2370,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]
|
|||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if get_env_value("EMAIL_ADDRESS"):
|
||||
platforms.append("Email")
|
||||
if get_env_value("TWILIO_ACCOUNT_SID"):
|
||||
platforms.append("SMS")
|
||||
if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"):
|
||||
platforms.append("Matrix")
|
||||
if get_env_value("MATTERMOST_TOKEN"):
|
||||
platforms.append("Mattermost")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("DINGTALK_CLIENT_ID"):
|
||||
platforms.append("DingTalk")
|
||||
if get_env_value("FEISHU_APP_ID"):
|
||||
platforms.append("Feishu")
|
||||
if get_env_value("WECOM_BOT_ID"):
|
||||
platforms.append("WeCom")
|
||||
if get_env_value("WEIXIN_ACCOUNT_ID"):
|
||||
platforms.append("Weixin")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL"):
|
||||
platforms.append("BlueBubbles")
|
||||
if get_env_value("WEBHOOK_ENABLED"):
|
||||
platforms.append("Webhooks")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
|
|
|||
|
|
@ -15,25 +15,12 @@ from typing import List, Optional, Set
|
|||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label
|
||||
|
||||
PLATFORMS = {
|
||||
"cli": "🖥️ CLI",
|
||||
"telegram": "📱 Telegram",
|
||||
"discord": "💬 Discord",
|
||||
"slack": "💼 Slack",
|
||||
"whatsapp": "📱 WhatsApp",
|
||||
"signal": "📡 Signal",
|
||||
"bluebubbles": "💬 BlueBubbles",
|
||||
"email": "📧 Email",
|
||||
"homeassistant": "🏠 Home Assistant",
|
||||
"mattermost": "💬 Mattermost",
|
||||
"matrix": "💬 Matrix",
|
||||
"dingtalk": "💬 DingTalk",
|
||||
"feishu": "🪽 Feishu",
|
||||
"wecom": "💬 WeCom",
|
||||
"weixin": "💬 Weixin",
|
||||
"webhook": "🔗 Webhook",
|
||||
}
|
||||
# Backward-compatible view: {key: label_string} so existing code that
|
||||
# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps
|
||||
# working without changes to every call site.
|
||||
PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"}
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -141,11 +141,8 @@ def show_status(args):
|
|||
display = redact_key(value) if not show_all else value
|
||||
print(f" {name:<12} {check_mark(has_key)} {display}")
|
||||
|
||||
anthropic_value = (
|
||||
get_env_value("ANTHROPIC_TOKEN")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
or ""
|
||||
)
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_value = get_anthropic_key()
|
||||
anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value
|
||||
print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}")
|
||||
|
||||
|
|
|
|||
|
|
@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
|||
|
||||
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
|
||||
|
||||
def _print_info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _print_success(text: str):
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
def _print_warning(text: str):
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _print_error(text: str):
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
else:
|
||||
display = f"{question}: "
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default or ""
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default or ""
|
||||
from hermes_cli.cli_output import ( # noqa: E402 — late import block
|
||||
print_error as _print_error,
|
||||
print_info as _print_info,
|
||||
print_success as _print_success,
|
||||
print_warning as _print_warning,
|
||||
prompt as _prompt,
|
||||
)
|
||||
|
||||
# ─── Toolset Registry ─────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set:
|
|||
except Exception:
|
||||
return set()
|
||||
|
||||
# Platform display config
|
||||
# Platform display config — derived from the canonical registry so every
|
||||
# module shares the same data. Kept as dict-of-dicts for backward
|
||||
# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns.
|
||||
from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY
|
||||
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
|
||||
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
|
||||
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
|
||||
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
|
||||
k: {"label": info.label, "default_toolset": info.default_toolset}
|
||||
for k, info in _PLATFORMS_REGISTRY.items()
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool:
|
|||
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
|
||||
rendering bugs in tmux, iTerm, and other non-standard terminals."""
|
||||
|
||||
# Curses-based single-select — works in tmux, iTerm, and standard terminals
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(0, 0, question, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, c in enumerate(choices):
|
||||
y = i + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {c}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: numbered input (Windows without curses, etc.)
|
||||
print(color(question, Colors.YELLOW))
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(choices):
|
||||
return idx
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
"""Single-select menu (arrow keys). Delegates to curses_radiolist."""
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
return curses_radiolist(question, choices, selected=default, cancel_returns=default)
|
||||
|
||||
|
||||
# ─── Token Estimation ────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -189,6 +189,33 @@ def is_wsl() -> bool:
|
|||
return _wsl_detected
|
||||
|
||||
|
||||
# ─── Well-Known Paths ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Return the path to ``config.yaml`` under HERMES_HOME.
|
||||
|
||||
Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated
|
||||
in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.).
|
||||
"""
|
||||
return get_hermes_home() / "config.yaml"
|
||||
|
||||
|
||||
def get_skills_dir() -> Path:
|
||||
"""Return the path to the skills directory under HERMES_HOME."""
|
||||
return get_hermes_home() / "skills"
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""Return the path to the logs directory under HERMES_HOME."""
|
||||
return get_hermes_home() / "logs"
|
||||
|
||||
|
||||
def get_env_path() -> Path:
|
||||
"""Return the path to the ``.env`` file under HERMES_HOME."""
|
||||
return get_hermes_home() / ".env"
|
||||
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path, get_hermes_home
|
||||
|
||||
# Sentinel to track whether setup_logging() has already run. The function
|
||||
# is idempotent — calling it twice is safe but the second call is a no-op
|
||||
|
|
@ -246,7 +246,7 @@ def _read_logging_config():
|
|||
"""
|
||||
try:
|
||||
import yaml
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ crashes due to a bad timezone string.
|
|||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str:
|
|||
# 2. config.yaml ``timezone`` key
|
||||
try:
|
||||
import yaml
|
||||
hermes_home = get_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
|
|
|||
80
run_agent.py
80
run_agent.py
|
|
@ -739,6 +739,7 @@ class AIAgent:
|
|||
# Interrupt mechanism for breaking out of tool loops
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# Subagent delegation state
|
||||
|
|
@ -2832,8 +2833,10 @@ class AIAgent:
|
|||
"""
|
||||
self._interrupt_requested = True
|
||||
self._interrupt_message = message
|
||||
# Signal all tools to abort any in-flight operations immediately
|
||||
_set_interrupt(True)
|
||||
# Signal all tools to abort any in-flight operations immediately.
|
||||
# Scope the interrupt to this agent's execution thread so other
|
||||
# agents running in the same process (gateway) are not affected.
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
|
|
@ -2846,10 +2849,10 @@ class AIAgent:
|
|||
print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||
|
||||
def clear_interrupt(self) -> None:
|
||||
"""Clear any pending interrupt request and the global tool interrupt signal."""
|
||||
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
_set_interrupt(False)
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
|
|
@ -3443,6 +3446,7 @@ class AIAgent:
|
|||
def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert internal chat-style messages to Responses input items."""
|
||||
items: List[Dict[str, Any]] = []
|
||||
seen_item_ids: set = set()
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
|
|
@ -3463,7 +3467,12 @@ class AIAgent:
|
|||
if isinstance(codex_reasoning, list):
|
||||
for ri in codex_reasoning:
|
||||
if isinstance(ri, dict) and ri.get("encrypted_content"):
|
||||
item_id = ri.get("id")
|
||||
if item_id and item_id in seen_item_ids:
|
||||
continue
|
||||
items.append(ri)
|
||||
if item_id:
|
||||
seen_item_ids.add(item_id)
|
||||
has_codex_reasoning = True
|
||||
|
||||
if content_text.strip():
|
||||
|
|
@ -3543,6 +3552,7 @@ class AIAgent:
|
|||
raise ValueError("Codex Responses input must be a list of input items.")
|
||||
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
seen_ids: set = set()
|
||||
for idx, item in enumerate(raw_items):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"Codex Responses input[{idx}] must be an object.")
|
||||
|
|
@ -3595,8 +3605,12 @@ class AIAgent:
|
|||
if item_type == "reasoning":
|
||||
encrypted = item.get("encrypted_content")
|
||||
if isinstance(encrypted, str) and encrypted:
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
item_id = item.get("id")
|
||||
if isinstance(item_id, str) and item_id:
|
||||
if item_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(item_id)
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
if isinstance(item_id, str) and item_id:
|
||||
reasoning_item["id"] = item_id
|
||||
summary = item.get("summary")
|
||||
|
|
@ -7800,6 +7814,11 @@ class AIAgent:
|
|||
compression_attempts = 0
|
||||
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
|
||||
|
||||
# Record the execution thread so interrupt()/clear_interrupt() can
|
||||
# scope the tool-level interrupt signal to THIS agent's thread only.
|
||||
# Must be set before clear_interrupt() which uses it.
|
||||
self._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
|
||||
|
|
@ -8278,8 +8297,24 @@ class AIAgent:
|
|||
_text_parts.append(getattr(_blk, "text", ""))
|
||||
_trunc_content = "\n".join(_text_parts) if _text_parts else None
|
||||
|
||||
# A response is "thinking exhausted" only when the model
|
||||
# actually produced reasoning blocks but no visible text after
|
||||
# them. Models that do not use <think> tags (e.g. GLM-4.7 on
|
||||
# NVIDIA Build, minimax) may return content=None or an empty
|
||||
# string for unrelated reasons — treat those as normal
|
||||
# truncations that deserve continuation retries, not as
|
||||
# thinking-budget exhaustion.
|
||||
_has_think_tags = bool(
|
||||
_trunc_content and re.search(
|
||||
r'<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)[^>]*>',
|
||||
_trunc_content,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
)
|
||||
_thinking_exhausted = (
|
||||
not _trunc_has_tool_calls and (
|
||||
not _trunc_has_tool_calls
|
||||
and _has_think_tags
|
||||
and (
|
||||
(_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
|
||||
or _trunc_content is None
|
||||
)
|
||||
|
|
@ -9507,12 +9542,41 @@ class AIAgent:
|
|||
invalid_json_args.append((tc.function.name, str(e)))
|
||||
|
||||
if invalid_json_args:
|
||||
# Check if the invalid JSON is due to truncation rather
|
||||
# than a model formatting mistake. Routers sometimes
|
||||
# rewrite finish_reason from "length" to "tool_calls",
|
||||
# hiding the truncation from the length handler above.
|
||||
# Detect truncation: args that don't end with } or ]
|
||||
# (after stripping whitespace) are cut off mid-stream.
|
||||
_truncated = any(
|
||||
not (tc.function.arguments or "").rstrip().endswith(("}", "]"))
|
||||
for tc in assistant_message.tool_calls
|
||||
if tc.function.name in {n for n, _ in invalid_json_args}
|
||||
)
|
||||
if _truncated:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Truncated tool call arguments detected "
|
||||
f"(finish_reason={finish_reason!r}) — refusing to execute.",
|
||||
force=True,
|
||||
)
|
||||
self._invalid_json_retries = 0
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Response truncated due to output length limit",
|
||||
}
|
||||
|
||||
# Track retries for invalid JSON arguments
|
||||
self._invalid_json_retries += 1
|
||||
|
||||
|
||||
tool_name, error_msg = invalid_json_args[0]
|
||||
self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
|
||||
|
||||
|
||||
if self._invalid_json_retries < 3:
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
|
||||
# Don't add anything to messages, just retry the API call
|
||||
|
|
|
|||
15
scripts/whatsapp-bridge/package-lock.json
generated
15
scripts/whatsapp-bridge/package-lock.json
generated
|
|
@ -8,7 +8,7 @@
|
|||
"name": "hermes-whatsapp-bridge",
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
|
||||
"express": "^4.21.0",
|
||||
"pino": "^9.0.0",
|
||||
"qrcode-terminal": "^0.12.0"
|
||||
|
|
@ -730,21 +730,22 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@whiskeysockets/baileys": {
|
||||
"name": "baileys",
|
||||
"version": "7.0.0-rc.9",
|
||||
"resolved": "https://registry.npmjs.org/@whiskeysockets/baileys/-/baileys-7.0.0-rc.9.tgz",
|
||||
"integrity": "sha512-YFm5gKXfDP9byCXCW3OPHKXLzrAKzolzgVUlRosHHgwbnf2YOO3XknkMm6J7+F0ns8OA0uuSBhgkRHTDtqkacw==",
|
||||
"resolved": "git+ssh://git@github.com/WhiskeySockets/Baileys.git#01047debd81beb20da7b7779b08edcb06aa03770",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@cacheable/node-cache": "^1.4.0",
|
||||
"@hapi/boom": "^9.1.3",
|
||||
"async-mutex": "^0.5.0",
|
||||
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node.git",
|
||||
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node",
|
||||
"lru-cache": "^11.1.0",
|
||||
"music-metadata": "^11.7.0",
|
||||
"p-queue": "^9.0.0",
|
||||
"pino": "^9.6",
|
||||
"protobufjs": "^7.2.4",
|
||||
"whatsapp-rust-bridge": "0.5.2",
|
||||
"ws": "^8.13.0"
|
||||
},
|
||||
"engines": {
|
||||
|
|
@ -2125,6 +2126,12 @@
|
|||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/whatsapp-rust-bridge": {
|
||||
"version": "0.5.2",
|
||||
"resolved": "https://registry.npmjs.org/whatsapp-rust-bridge/-/whatsapp-rust-bridge-0.5.2.tgz",
|
||||
"integrity": "sha512-6KBRNvxg6WMIwZ/euA8qVzj16qxMBzLllfmaJIP1JGAAfSvwn6nr8JDOMXeqpXPEOl71UfOG+79JwKEoT2b1Fw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/win-guid": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/win-guid/-/win-guid-0.2.1.tgz",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
"start": "node bridge.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
|
||||
"express": "^4.21.0",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout:
|
|||
"http://0.0.0.0:5000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_local_endpoint_bumps_read_timeout(self, base_url):
|
||||
"""Local endpoint + default timeout -> bumps to base_timeout."""
|
||||
|
|
@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout:
|
|||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 120.0
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
"""Direct unit tests for is_local_endpoint."""
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://localhost:11434",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://0.0.0.0:5000",
|
||||
"http://[::1]:11434",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://172.17.0.1:11434",
|
||||
])
|
||||
def test_classic_local_addresses(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.docker.internal:8080/v1",
|
||||
"http://gateway.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_container_dns_names(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://api.openai.com",
|
||||
"https://openrouter.ai/api",
|
||||
"https://api.anthropic.com",
|
||||
"https://evil.docker.internal.example.com",
|
||||
])
|
||||
def test_remote_endpoints(self, url):
|
||||
assert is_local_endpoint(url) is False
|
||||
|
|
|
|||
|
|
@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None):
|
|||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
|
||||
if platform == Platform.DISCORD:
|
||||
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
|
||||
adapter = DiscordAdapter(config)
|
||||
platform_key = Platform.DISCORD
|
||||
elif platform == Platform.SLACK:
|
||||
|
|
|
|||
|
|
@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint:
|
|||
)
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" in resp.headers.get("Content-Type", "")
|
||||
assert resp.headers.get("X-Accel-Buffering") == "no"
|
||||
body = await resp.text()
|
||||
assert "data: " in body
|
||||
assert "[DONE]" in body
|
||||
assert "Hello!" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
||||
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
||||
import asyncio
|
||||
import gateway.platforms.api_server as api_server_mod
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
if cb:
|
||||
cb("Working")
|
||||
await asyncio.sleep(0.65)
|
||||
cb("...done")
|
||||
return (
|
||||
{"final_response": "Working...done", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01),
|
||||
patch.object(adapter, "_run_agent", side_effect=_mock_run_agent),
|
||||
):
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "do the thing"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
assert ": keepalive" in body
|
||||
assert "Working" in body
|
||||
assert "...done" in body
|
||||
assert "[DONE]" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_survives_tool_call_none_sentinel(self, adapter):
|
||||
"""stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream.
|
||||
|
|
|
|||
|
|
@ -119,28 +119,29 @@ class TestDeduplication:
|
|||
def test_first_message_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
assert adapter._is_duplicate("msg-1") is False
|
||||
assert adapter._dedup.is_duplicate("msg-1") is False
|
||||
|
||||
def test_second_same_message_is_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-1") is True
|
||||
adapter._dedup.is_duplicate("msg-1")
|
||||
assert adapter._dedup.is_duplicate("msg-1") is True
|
||||
|
||||
def test_different_messages_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-2") is False
|
||||
adapter._dedup.is_duplicate("msg-1")
|
||||
assert adapter._dedup.is_duplicate("msg-2") is False
|
||||
|
||||
def test_cache_cleanup_on_overflow(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
max_size = adapter._dedup._max_size
|
||||
# Fill beyond max
|
||||
for i in range(DEDUP_MAX_SIZE + 10):
|
||||
adapter._is_duplicate(f"msg-{i}")
|
||||
for i in range(max_size + 10):
|
||||
adapter._dedup.is_duplicate(f"msg-{i}")
|
||||
# Cache should have been pruned
|
||||
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
|
||||
assert len(adapter._dedup._seen) <= max_size + 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -253,13 +254,13 @@ class TestConnect:
|
|||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._session_webhooks["a"] = "http://x"
|
||||
adapter._seen_messages["b"] = 1.0
|
||||
adapter._dedup._seen["b"] = 1.0
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._stream_task = None
|
||||
|
||||
await adapter.disconnect()
|
||||
assert len(adapter._session_webhooks) == 0
|
||||
assert len(adapter._seen_messages) == 0
|
||||
assert len(adapter._dedup._seen) == 0
|
||||
assert adapter._http_client is None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
|||
|
||||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._token_lock_identity is None
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
|
|
|||
|
|
@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch
|
|||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Simulate bot having previously participated in thread 456
|
||||
adapter._bot_participated_threads.add("456")
|
||||
adapter._threads.mark("456")
|
||||
|
||||
thread = FakeThread(channel_id=456, name="existing thread")
|
||||
message = make_message(channel=thread, content="follow-up without mention")
|
||||
|
|
@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
|
|||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "555" in adapter._bot_participated_threads
|
||||
assert "555" in adapter._threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
|
|||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._bot_participated_threads
|
||||
assert "777" in adapter._threads
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for Discord thread participation persistence.
|
||||
|
||||
Verifies that _bot_participated_threads survives adapter restarts by
|
||||
Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by
|
||||
being persisted to ~/.hermes/discord_threads.json.
|
||||
"""
|
||||
|
||||
|
|
@ -25,13 +25,13 @@ class TestDiscordThreadPersistence:
|
|||
|
||||
def test_starts_empty_when_no_state_file(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_track_thread_persists_to_disk(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("222")
|
||||
adapter._threads.mark("111")
|
||||
adapter._threads.mark("222")
|
||||
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
assert state_file.exists()
|
||||
|
|
@ -42,42 +42,43 @@ class TestDiscordThreadPersistence:
|
|||
"""Threads tracked by one adapter instance are visible to the next."""
|
||||
adapter1 = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter1._track_thread("aaa")
|
||||
adapter1._track_thread("bbb")
|
||||
adapter1._threads.mark("aaa")
|
||||
adapter1._threads.mark("bbb")
|
||||
|
||||
adapter2 = self._make_adapter(tmp_path)
|
||||
assert "aaa" in adapter2._bot_participated_threads
|
||||
assert "bbb" in adapter2._bot_participated_threads
|
||||
assert "aaa" in adapter2._threads
|
||||
assert "bbb" in adapter2._threads
|
||||
|
||||
def test_duplicate_track_does_not_double_save(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("111") # no-op
|
||||
adapter._threads.mark("111")
|
||||
adapter._threads.mark("111") # no-op
|
||||
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert saved.count("111") == 1
|
||||
|
||||
def test_caps_at_max_tracked_threads(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
adapter._threads._max_tracked = 5
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
for i in range(10):
|
||||
adapter._track_thread(str(i))
|
||||
adapter._threads.mark(str(i))
|
||||
|
||||
assert len(adapter._bot_participated_threads) == 5
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert len(saved) == 5
|
||||
|
||||
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
state_file.write_text("not valid json{{{")
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_missing_hermes_home_does_not_crash(self, tmp_path):
|
||||
"""Load/save tolerate missing directories."""
|
||||
fake_home = tmp_path / "nonexistent" / "deep"
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
# _load should return empty set, not crash
|
||||
threads = DiscordAdapter._load_participated_threads()
|
||||
assert threads == set()
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
# ThreadParticipationTracker should return empty set, not crash
|
||||
tracker = ThreadParticipationTracker("discord")
|
||||
assert "$test" not in tracker
|
||||
|
|
|
|||
|
|
@ -195,6 +195,105 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path):
|
||||
"""Synthetic completion event should carry user_id and user_name from the watcher."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [
|
||||
SimpleNamespace(
|
||||
output_buffer="done\n", exited=True, exit_code=0, command="echo test"
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path)
|
||||
adapter = runner.adapters[Platform.DISCORD]
|
||||
|
||||
watcher = _watcher_dict_with_notify()
|
||||
watcher["user_id"] = "user-42"
|
||||
watcher["user_name"] = "alice"
|
||||
|
||||
await runner._run_process_watcher(watcher)
|
||||
|
||||
assert adapter.handle_message.await_count == 1
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.user_id == "user-42"
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
|
||||
"""A non-internal event with user_id=None should be silently dropped."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
user_id=None,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="service message",
|
||||
source=source,
|
||||
internal=False,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
# Should return None (dropped) and NOT send any pairing message
|
||||
assert result is None
|
||||
assert adapter.send.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path):
|
||||
"""A message with user_id=None must never call generate_code."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
|
||||
generate_called = False
|
||||
original_generate = runner.pairing_store.generate_code
|
||||
|
||||
def tracking_generate(*args, **kwargs):
|
||||
nonlocal generate_called
|
||||
generate_called = True
|
||||
return original_generate(*args, **kwargs)
|
||||
|
||||
runner.pairing_store.generate_code = tracking_generate
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="456",
|
||||
chat_type="dm",
|
||||
user_id=None,
|
||||
)
|
||||
event = MessageEvent(text="anonymous", source=source, internal=False)
|
||||
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert not generate_called, (
|
||||
"Pairing code should NOT be generated for messages with user_id=None"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
|
||||
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch):
|
|||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$thread1")
|
||||
adapter._threads.mark("$thread1")
|
||||
|
||||
event = _make_event("hello without mention", thread_id="$thread1")
|
||||
|
||||
|
|
@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch):
|
|||
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$thread_root")
|
||||
adapter._threads.mark("$thread_root")
|
||||
event = _make_event("reply in thread", thread_id="$thread_root")
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
|
|
@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_tracks_participation(monkeypatch):
|
||||
"""Auto-created threads are tracked in _bot_participated_threads."""
|
||||
"""Auto-created threads are tracked in _threads."""
|
||||
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false")
|
||||
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter = _make_adapter()
|
||||
event = _make_event("hello", event_id="$msg1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
assert "$msg1" in adapter._bot_participated_threads
|
||||
assert "$msg1" in adapter._threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch):
|
|||
class TestThreadPersistence:
|
||||
def test_empty_state_file(self, tmp_path, monkeypatch):
|
||||
"""No state file → empty set."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: tmp_path / "matrix_threads.json"),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: tmp_path / "matrix_threads.json",
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
loaded = adapter._load_participated_threads()
|
||||
assert loaded == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_track_thread_persists(self, tmp_path, monkeypatch):
|
||||
"""_track_thread writes to disk."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
"""mark() writes to disk."""
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._track_thread("$thread_abc")
|
||||
adapter._threads.mark("$thread_abc")
|
||||
|
||||
data = json.loads(state_path.read_text())
|
||||
assert "$thread_abc" in data
|
||||
|
||||
def test_threads_survive_reload(self, tmp_path, monkeypatch):
|
||||
"""Persisted threads are loaded by a new adapter instance."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
state_path.write_text(json.dumps(["$t1", "$t2"]))
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
assert "$t1" in adapter._bot_participated_threads
|
||||
assert "$t2" in adapter._bot_participated_threads
|
||||
assert "$t1" in adapter._threads
|
||||
assert "$t2" in adapter._threads
|
||||
|
||||
def test_cap_max_tracked_threads(self, tmp_path, monkeypatch):
|
||||
"""Thread set is trimmed to _MAX_TRACKED_THREADS."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
"""Thread set is trimmed to max_tracked."""
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
adapter._threads._max_tracked = 5
|
||||
|
||||
for i in range(10):
|
||||
adapter._bot_participated_threads.add(f"$t{i}")
|
||||
adapter._save_participated_threads()
|
||||
adapter._threads.mark(f"$t{i}")
|
||||
|
||||
data = json.loads(state_path.read_text())
|
||||
assert len(data) == 5
|
||||
|
|
@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch):
|
|||
_set_dm(adapter)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
|
|
@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
|||
|
||||
adapter = _make_adapter()
|
||||
_set_dm(adapter)
|
||||
adapter._bot_participated_threads.add("$existing_thread")
|
||||
adapter._threads.mark("$existing_thread")
|
||||
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
|
|
@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
||||
"""DM mention-thread tracks the thread in _bot_participated_threads."""
|
||||
"""DM mention-thread tracks the thread in _threads."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
|
|
@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
|||
_set_dm(adapter)
|
||||
event = _make_event("@hermes:example.org help", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
assert "$dm1" in adapter._bot_participated_threads
|
||||
assert "$dm1" in adapter._threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -614,25 +614,27 @@ class TestMattermostDedup:
|
|||
assert self.adapter.handle_message.call_count == 2
|
||||
|
||||
def test_prune_seen_clears_expired(self):
|
||||
"""_prune_seen should remove entries older than _SEEN_TTL."""
|
||||
"""Dedup cache should remove entries older than TTL on overflow."""
|
||||
now = time.time()
|
||||
dedup = self.adapter._dedup
|
||||
# Fill with enough expired entries to trigger pruning
|
||||
for i in range(self.adapter._SEEN_MAX + 10):
|
||||
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
|
||||
for i in range(dedup._max_size + 10):
|
||||
dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL)
|
||||
|
||||
# Add a fresh one
|
||||
self.adapter._seen_posts["fresh"] = now
|
||||
dedup._seen["fresh"] = now
|
||||
|
||||
self.adapter._prune_seen()
|
||||
# Trigger pruning by calling is_duplicate with a new entry (over max_size)
|
||||
dedup.is_duplicate("trigger_prune")
|
||||
|
||||
# Old entries should be pruned, fresh one kept
|
||||
assert "fresh" in self.adapter._seen_posts
|
||||
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
|
||||
assert "fresh" in dedup._seen
|
||||
assert len(dedup._seen) < dedup._max_size + 10
|
||||
|
||||
def test_seen_cache_tracks_post_ids(self):
|
||||
"""Posts are tracked in _seen_posts dict."""
|
||||
self.adapter._seen_posts["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._seen_posts
|
||||
"""Posts are tracked in the dedup cache."""
|
||||
self.adapter._dedup._seen["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._dedup._seen
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.run import _dequeue_pending_event
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -79,6 +80,26 @@ class TestQueueMessageStorage:
|
|||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:voice"
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="voice-q1",
|
||||
media_urls=["/tmp/voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = _dequeue_pending_event(adapter, session_key)
|
||||
|
||||
assert retrieved is event
|
||||
assert retrieved.media_urls == ["/tmp/voice.ogg"]
|
||||
assert retrieved.media_types == ["audio/ogg"]
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
|||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
user_id="123456",
|
||||
user_name="alice",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
|
@ -25,6 +27,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
|||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
|
|
@ -33,6 +37,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
|||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
|
||||
assert get_session_env("HERMES_SESSION_USER_NAME") == "alice"
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
# os.environ should NOT be touched
|
||||
|
|
@ -50,6 +56,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
|||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
source = SessionSource(
|
||||
|
|
@ -57,12 +65,15 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
|||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
user_id="123456",
|
||||
user_name="alice",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
|
@ -70,6 +81,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
|||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_USER_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -114,16 +114,16 @@ class TestSignalAdapterInit:
|
|||
|
||||
class TestSignalHelpers:
|
||||
def test_redact_phone_long(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+15551234567") == "+155****4567"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("+155****4567") == "+155****4567"
|
||||
|
||||
def test_redact_phone_short(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+12345") == "+1****45"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("+12345") == "+1****45"
|
||||
|
||||
def test_redact_phone_empty(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("") == "<none>"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("") == "<none>"
|
||||
|
||||
def test_parse_comma_list(self):
|
||||
from gateway.platforms.signal import _parse_comma_list
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
"""Tests for SMS (Twilio) platform integration.
|
||||
|
||||
Covers config loading, format/truncate, echo prevention,
|
||||
requirements check, and toolset verification.
|
||||
requirements check, toolset verification, and Twilio signature validation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -213,3 +216,335 @@ class TestSmsToolset:
|
|||
from tools.cronjob_tools import CRONJOB_SCHEMA
|
||||
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
|
||||
assert "sms" in deliver_desc.lower()
|
||||
|
||||
|
||||
# ── Webhook host configuration ─────────────────────────────────────
|
||||
|
||||
class TestWebhookHostConfig:
|
||||
"""Verify SMS_WEBHOOK_HOST env var and default."""
|
||||
|
||||
def test_default_host_is_all_interfaces(self):
|
||||
from gateway.platforms.sms import DEFAULT_WEBHOOK_HOST
|
||||
assert DEFAULT_WEBHOOK_HOST == "0.0.0.0"
|
||||
|
||||
def test_host_from_env(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_HOST": "127.0.0.1",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_host == "127.0.0.1"
|
||||
|
||||
def test_webhook_url_from_env(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
|
||||
|
||||
def test_webhook_url_stripped(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": " https://example.com/webhooks/twilio ",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
|
||||
|
||||
|
||||
# ── Startup guard (fail-closed) ────────────────────────────────────
|
||||
|
||||
class TestStartupGuard:
|
||||
"""Adapter must refuse to start without SMS_WEBHOOK_URL."""
|
||||
|
||||
def _make_adapter(self, extra_env=None):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
if extra_env:
|
||||
env.update(extra_env)
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_start_without_webhook_url(self):
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_allows_start_without_url(self):
|
||||
mock_session = AsyncMock()
|
||||
with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}), \
|
||||
patch("aiohttp.web.AppRunner") as mock_runner_cls, \
|
||||
patch("aiohttp.web.TCPSite") as mock_site_cls, \
|
||||
patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
mock_runner_cls.return_value.setup = AsyncMock()
|
||||
mock_runner_cls.return_value.cleanup = AsyncMock()
|
||||
mock_site_cls.return_value.start = AsyncMock()
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
await adapter.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_url_allows_start(self):
|
||||
mock_session = AsyncMock()
|
||||
with patch("aiohttp.web.AppRunner") as mock_runner_cls, \
|
||||
patch("aiohttp.web.TCPSite") as mock_site_cls, \
|
||||
patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
mock_runner_cls.return_value.setup = AsyncMock()
|
||||
mock_runner_cls.return_value.cleanup = AsyncMock()
|
||||
mock_site_cls.return_value.start = AsyncMock()
|
||||
adapter = self._make_adapter(
|
||||
extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"}
|
||||
)
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
await adapter.disconnect()
|
||||
|
||||
|
||||
# ── Twilio signature validation ────────────────────────────────────
|
||||
|
||||
def _compute_twilio_signature(auth_token, url, params):
|
||||
"""Reference implementation of Twilio's signature algorithm."""
|
||||
data_to_sign = url
|
||||
for key in sorted(params.keys()):
|
||||
data_to_sign += key + params[key]
|
||||
mac = hmac.new(
|
||||
auth_token.encode("utf-8"),
|
||||
data_to_sign.encode("utf-8"),
|
||||
hashlib.sha1,
|
||||
)
|
||||
return base64.b64encode(mac.digest()).decode("utf-8")
|
||||
|
||||
|
||||
class TestTwilioSignatureValidation:
|
||||
"""Unit tests for SmsAdapter._validate_twilio_signature."""
|
||||
|
||||
def _make_adapter(self, auth_token="test_token_secret"):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": auth_token,
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key=auth_token)
|
||||
adapter = SmsAdapter(pc)
|
||||
return adapter
|
||||
|
||||
def test_valid_signature_accepted(self):
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello", "To": "+15550001111"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_invalid_signature_rejected(self):
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
assert adapter._validate_twilio_signature(url, params, "badsig") is False
|
||||
|
||||
def test_wrong_token_rejected(self):
|
||||
adapter = self._make_adapter(auth_token="correct_token")
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature("wrong_token", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is False
|
||||
|
||||
def test_params_sorted_by_key(self):
|
||||
"""Signature must be computed with params sorted alphabetically."""
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"Zebra": "last", "Alpha": "first", "Middle": "mid"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_empty_param_values_included(self):
|
||||
"""Blank values must be included in signature computation."""
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "", "SmsStatus": "received"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_url_matters(self):
|
||||
"""Different URLs produce different signatures."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://a.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://b.com/webhooks/twilio", params, sig
|
||||
) is False
|
||||
|
||||
def test_port_variant_443_matches_without_port(self):
|
||||
"""Signature for https URL with :443 validates against URL without port."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com:443/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
def test_port_variant_without_port_matches_443(self):
|
||||
"""Signature for https URL without port validates against URL with :443."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com:443/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
def test_non_standard_port_no_variant(self):
|
||||
"""Non-standard port must NOT match URL without port."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com:8080/webhooks/twilio", params, sig
|
||||
) is False
|
||||
|
||||
def test_port_variant_http_80(self):
|
||||
"""Port variant also works for http with port 80."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "http://example.com:80/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"http://example.com/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
|
||||
# ── Webhook signature enforcement (handler-level) ──────────────────
|
||||
|
||||
class TestWebhookSignatureEnforcement:
|
||||
"""Integration tests for signature validation in _handle_webhook."""
|
||||
|
||||
def _make_adapter(self, webhook_url=""):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "test_token_secret",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": webhook_url,
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="test_token_secret")
|
||||
adapter = SmsAdapter(pc)
|
||||
adapter._message_handler = AsyncMock()
|
||||
return adapter
|
||||
|
||||
def _mock_request(self, body, headers=None):
|
||||
request = MagicMock()
|
||||
request.read = AsyncMock(return_value=body)
|
||||
request.headers = headers or {}
|
||||
return request
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_skips_validation(self):
|
||||
"""With SMS_INSECURE_NO_SIGNATURE=true and no URL, requests are accepted."""
|
||||
env = {"SMS_INSECURE_NO_SIGNATURE": "true"}
|
||||
with patch.dict(os.environ, env):
|
||||
adapter = self._make_adapter(webhook_url="")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body)
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_with_url_still_validates(self):
|
||||
"""When both SMS_WEBHOOK_URL and SMS_INSECURE_NO_SIGNATURE are set,
|
||||
validation stays active (URL takes precedence)."""
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_signature_returns_403(self):
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_signature_returns_403(self):
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": "invalid"})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_signature_returns_200(self):
|
||||
webhook_url = "https://example.com/webhooks/twilio"
|
||||
adapter = self._make_adapter(webhook_url=webhook_url)
|
||||
params = {
|
||||
"From": "+15551234567",
|
||||
"To": "+15550001111",
|
||||
"Body": "hello",
|
||||
"MessageSid": "SM123",
|
||||
}
|
||||
sig = _compute_twilio_signature("test_token_secret", webhook_url, params)
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_port_variant_signature_returns_200(self):
|
||||
"""Signature computed with :443 should pass when URL configured without port."""
|
||||
webhook_url = "https://example.com/webhooks/twilio"
|
||||
adapter = self._make_adapter(webhook_url=webhook_url)
|
||||
params = {
|
||||
"From": "+15551234567",
|
||||
"To": "+15550001111",
|
||||
"Body": "hello",
|
||||
"MessageSid": "SM123",
|
||||
}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com:443/webhooks/twilio", params
|
||||
)
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
from gateway.config import GatewayConfig, Platform, load_gateway_config
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
|
|
@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
|
|||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
runner.adapters = {}
|
||||
runner._model = "test-model"
|
||||
runner._base_url = ""
|
||||
runner._has_setup_skill = lambda: False
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=source,
|
||||
media_urls=["/tmp/queued-voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={
|
||||
"success": True,
|
||||
"transcript": "queued voice transcript",
|
||||
"provider": "local_command",
|
||||
},
|
||||
):
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "queued voice transcript" in result
|
||||
assert "voice message" in result.lower()
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch):
|
|||
async def _noop():
|
||||
return []
|
||||
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
|
||||
# Mock HTTPXRequest so the builder chain doesn't fail
|
||||
monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch):
|
|||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_token_lock"
|
||||
assert adapter.fatal_error_code == "telegram-bot-token_lock"
|
||||
assert adapter.has_fatal_error is True
|
||||
assert "already using this Telegram bot token" in adapter.fatal_error_message
|
||||
assert "already in use" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch):
|
|||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
|
|
@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
|
|||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
|
|
@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m
|
|||
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
app = SimpleNamespace(
|
||||
bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()),
|
||||
updater=SimpleNamespace(),
|
||||
|
|
@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch):
|
|||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.telegram.Application",
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
"""Tests for the Weixin platform adapter."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
|
||||
from gateway.platforms.weixin import WeixinAdapter
|
||||
from gateway.platforms import weixin
|
||||
from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter
|
||||
from tools.send_message_tool import _parse_target_ref, _send_to_platform
|
||||
|
||||
|
||||
|
|
@ -187,6 +189,70 @@ class TestWeixinConfig:
|
|||
assert config.get_connected_platforms() == []
|
||||
|
||||
|
||||
class TestWeixinStatePersistence:
|
||||
def test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
|
||||
account_path = tmp_path / "weixin" / "accounts" / "acct.json"
|
||||
account_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
original = {"token": "old-token", "base_url": "https://old.example.com"}
|
||||
account_path.write_text(json.dumps(original), encoding="utf-8")
|
||||
|
||||
def _boom(_src, _dst):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.os.replace", _boom)
|
||||
|
||||
try:
|
||||
weixin.save_weixin_account(
|
||||
str(tmp_path),
|
||||
account_id="acct",
|
||||
token="new-token",
|
||||
base_url="https://new.example.com",
|
||||
user_id="wxid_new",
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("expected save_weixin_account to propagate replace failure")
|
||||
|
||||
assert json.loads(account_path.read_text(encoding="utf-8")) == original
|
||||
|
||||
def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
|
||||
token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json"
|
||||
token_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8")
|
||||
|
||||
def _boom(_src, _dst):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.os.replace", _boom)
|
||||
|
||||
store = ContextTokenStore(str(tmp_path))
|
||||
with patch.object(weixin.logger, "warning") as warning_mock:
|
||||
store.set("acct", "user-b", "new-token")
|
||||
|
||||
assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
|
||||
sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json"
|
||||
sync_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8")
|
||||
|
||||
def _boom(_src, _dst):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.os.replace", _boom)
|
||||
|
||||
try:
|
||||
weixin._save_sync_buf(str(tmp_path), "acct", "new-sync")
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("expected _save_sync_buf to propagate replace failure")
|
||||
|
||||
assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"}
|
||||
|
||||
|
||||
class TestWeixinSendMessageIntegration:
|
||||
def test_parse_target_ref_accepts_weixin_ids(self):
|
||||
assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)
|
||||
|
|
@ -217,6 +283,55 @@ class TestWeixinSendMessageIntegration:
|
|||
)
|
||||
|
||||
|
||||
class TestWeixinChunkDelivery:
|
||||
def _connected_adapter(self) -> WeixinAdapter:
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
|
||||
return adapter
|
||||
|
||||
@patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
|
||||
def test_send_waits_between_multiple_chunks(self, send_message_mock, sleep_mock):
|
||||
adapter = self._connected_adapter()
|
||||
adapter.MAX_MESSAGE_LENGTH = 12
|
||||
|
||||
# Use double newlines so _pack_markdown_blocks splits into 3 blocks
|
||||
result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird"))
|
||||
|
||||
assert result.success is True
|
||||
assert send_message_mock.await_count == 3
|
||||
assert sleep_mock.await_count == 2
|
||||
|
||||
@patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
|
||||
def test_send_retries_failed_chunk_before_continuing(self, send_message_mock, sleep_mock):
|
||||
adapter = self._connected_adapter()
|
||||
adapter.MAX_MESSAGE_LENGTH = 12
|
||||
calls = {"count": 0}
|
||||
|
||||
async def flaky_send(*args, **kwargs):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 2:
|
||||
raise RuntimeError("temporary iLink failure")
|
||||
|
||||
send_message_mock.side_effect = flaky_send
|
||||
|
||||
# Use double newlines so _pack_markdown_blocks splits into 3 blocks
|
||||
result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird"))
|
||||
|
||||
assert result.success is True
|
||||
# 3 chunks, but chunk 2 fails once and retries → 4 _send_message calls total
|
||||
assert send_message_mock.await_count == 4
|
||||
# The retried chunk should reuse the same client_id for deduplication
|
||||
first_try = send_message_mock.await_args_list[1].kwargs
|
||||
retry = send_message_mock.await_args_list[2].kwargs
|
||||
assert first_try["text"] == retry["text"]
|
||||
assert first_try["client_id"] == retry["client_id"]
|
||||
|
||||
|
||||
class TestWeixinRemoteMediaSafety:
|
||||
def test_download_remote_media_blocks_unsafe_urls(self):
|
||||
adapter = _make_adapter()
|
||||
|
|
|
|||
|
|
@ -260,7 +260,7 @@ class TestWaitForGatewayExit:
|
|||
def test_kill_gateway_processes_force_uses_helper(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22])
|
||||
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22])
|
||||
monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force)))
|
||||
|
||||
killed = gateway.kill_gateway_processes(force=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for gateway service management helpers."""
|
||||
|
||||
import os
|
||||
import pwd
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
|
@ -129,7 +130,7 @@ class TestGatewayStopCleanup:
|
|||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
lambda force=False: kill_calls.append(force) or 2,
|
||||
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop"))
|
||||
|
|
@ -155,7 +156,7 @@ class TestGatewayStopCleanup:
|
|||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
lambda force=False: kill_calls.append(force) or 2,
|
||||
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True}))
|
||||
|
|
@ -924,6 +925,23 @@ class TestProfileArg:
|
|||
assert "<string>--profile</string>" in plist
|
||||
assert "<string>mybot</string>" in plist
|
||||
|
||||
def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch):
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
|
||||
profile_dir.mkdir(parents=True)
|
||||
machine_home = tmp_path / "machine-home"
|
||||
machine_home.mkdir()
|
||||
profile_home = profile_dir / "home"
|
||||
profile_home.mkdir()
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: profile_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
|
||||
monkeypatch.setattr(pwd, "getpwuid", lambda uid: SimpleNamespace(pw_dir=str(machine_home)))
|
||||
|
||||
plist_path = gateway_cli.get_launchd_plist_path()
|
||||
|
||||
assert plist_path == machine_home / "Library" / "LaunchAgents" / "ai.hermes.gateway-orcha.plist"
|
||||
|
||||
|
||||
class TestRemapPathForUser:
|
||||
"""Unit tests for _remap_path_for_user()."""
|
||||
|
|
|
|||
|
|
@ -1214,3 +1214,115 @@ def test_openrouter_provider_not_affected_by_custom_fix(monkeypatch):
|
|||
|
||||
resolved = rp.resolve_runtime_provider(requested="openrouter")
|
||||
assert resolved["provider"] == "openrouter"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# fix #7828 — custom_providers model field must propagate to runtime
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_named_custom_provider_includes_model(monkeypatch):
|
||||
"""_get_named_custom_provider should include the model field from config."""
|
||||
monkeypatch.setattr(rp, "load_config", lambda: {
|
||||
"custom_providers": [{
|
||||
"name": "my-dashscope",
|
||||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "test-key",
|
||||
"api_mode": "chat_completions",
|
||||
"model": "qwen3.6-plus",
|
||||
}],
|
||||
})
|
||||
|
||||
result = rp._get_named_custom_provider("my-dashscope")
|
||||
assert result is not None
|
||||
assert result["model"] == "qwen3.6-plus"
|
||||
|
||||
|
||||
def test_get_named_custom_provider_excludes_empty_model(monkeypatch):
|
||||
"""Empty or whitespace-only model field should not appear in result."""
|
||||
for model_val in ["", " ", None]:
|
||||
entry = {
|
||||
"name": "test-ep",
|
||||
"base_url": "https://example.com/v1",
|
||||
"api_key": "key",
|
||||
}
|
||||
if model_val is not None:
|
||||
entry["model"] = model_val
|
||||
|
||||
monkeypatch.setattr(rp, "load_config", lambda e=entry: {
|
||||
"custom_providers": [e],
|
||||
})
|
||||
|
||||
result = rp._get_named_custom_provider("test-ep")
|
||||
assert result is not None
|
||||
assert "model" not in result, (
|
||||
f"model field {model_val!r} should not be included in result"
|
||||
)
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_model_direct_path(monkeypatch):
|
||||
"""Model should propagate through the direct (non-pool) resolution path."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
"model": "qwen3.6-plus",
|
||||
},
|
||||
)
|
||||
# Ensure pool doesn't intercept
|
||||
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
assert resolved["model"] == "qwen3.6-plus"
|
||||
assert resolved["provider"] == "custom"
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_model_pool_path(monkeypatch):
|
||||
"""Model should propagate even when credential pool handles credentials."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
"model": "qwen3.6-plus",
|
||||
},
|
||||
)
|
||||
# Pool returns a result (intercepting the normal path)
|
||||
monkeypatch.setattr(
|
||||
rp, "_try_resolve_from_custom_pool",
|
||||
lambda *a, **k: {
|
||||
"provider": "custom",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "pool-key",
|
||||
"source": "pool:custom:my-server",
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
assert resolved["model"] == "qwen3.6-plus", (
|
||||
"model must be injected into pool result"
|
||||
)
|
||||
assert resolved["api_key"] == "pool-key", "pool credentials should be used"
|
||||
|
||||
|
||||
def test_named_custom_runtime_no_model_when_absent(monkeypatch):
|
||||
"""When custom_providers entry has no model field, runtime should not either."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
assert "model" not in resolved
|
||||
|
|
|
|||
|
|
@ -191,6 +191,19 @@ class TestLaunchdPlistPath:
|
|||
raise AssertionError("PATH key not found in plist")
|
||||
|
||||
|
||||
class TestLaunchdPlistCurrentness:
|
||||
def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin")
|
||||
plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin")
|
||||
|
||||
assert gateway_cli.launchd_plist_is_current() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update — macOS launchd detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -536,7 +549,7 @@ class TestServicePidExclusion:
|
|||
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
|
||||
), patch.object(
|
||||
gateway_cli, "find_gateway_pids",
|
||||
side_effect=lambda exclude_pids=None: (
|
||||
side_effect=lambda exclude_pids=None, all_profiles=False: (
|
||||
[SERVICE_PID] if not exclude_pids else
|
||||
[p for p in [SERVICE_PID] if p not in exclude_pids]
|
||||
),
|
||||
|
|
@ -579,7 +592,7 @@ class TestServicePidExclusion:
|
|||
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
|
||||
), patch.object(
|
||||
gateway_cli, "find_gateway_pids",
|
||||
side_effect=lambda exclude_pids=None: (
|
||||
side_effect=lambda exclude_pids=None, all_profiles=False: (
|
||||
[SERVICE_PID] if not exclude_pids else
|
||||
[p for p in [SERVICE_PID] if p not in exclude_pids]
|
||||
),
|
||||
|
|
@ -618,7 +631,7 @@ class TestServicePidExclusion:
|
|||
launchctl_loaded=True,
|
||||
)
|
||||
|
||||
def fake_find(exclude_pids=None):
|
||||
def fake_find(exclude_pids=None, all_profiles=False):
|
||||
_exclude = exclude_pids or set()
|
||||
return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude]
|
||||
|
||||
|
|
@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude:
|
|||
pids = gateway_cli.find_gateway_pids()
|
||||
assert 100 in pids
|
||||
assert 200 in pids
|
||||
|
||||
def test_filters_to_current_profile(self, monkeypatch, tmp_path):
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
|
||||
profile_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(gateway_cli, "is_windows", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return subprocess.CompletedProcess(
|
||||
cmd, 0,
|
||||
stdout=(
|
||||
"100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n"
|
||||
"200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n"
|
||||
),
|
||||
stderr="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
monkeypatch.setattr("os.getpid", lambda: 999)
|
||||
monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set())
|
||||
monkeypatch.setattr(gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha")
|
||||
|
||||
pids = gateway_cli.find_gateway_pids()
|
||||
|
||||
assert pids == [100]
|
||||
|
|
|
|||
|
|
@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def _make_bare_agent(self):
|
||||
"""Create a bare AIAgent via __new__ with all interrupt-related attrs."""
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent._interrupt_requested = False
|
||||
agent._interrupt_message = None
|
||||
agent._execution_thread_id = None # defaults to current thread in set_interrupt
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.quiet_mode = True
|
||||
return agent
|
||||
|
||||
def test_parent_interrupt_sets_child_flag(self):
|
||||
"""When parent.interrupt() is called, child._interrupt_requested should be set."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent._active_children_lock = threading.Lock()
|
||||
parent.quiet_mode = True
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
child.quiet_mode = True
|
||||
parent = self._make_bare_agent()
|
||||
child = self._make_bare_agent()
|
||||
|
||||
parent._active_children.append(child)
|
||||
|
||||
|
|
@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
assert child._interrupt_message == "new user message"
|
||||
assert is_interrupted() is True
|
||||
|
||||
def test_child_clear_interrupt_at_start_clears_global(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
|
||||
|
||||
This is the intended behavior at startup, but verify it doesn't
|
||||
accidentally clear an interrupt intended for a running child.
|
||||
def test_child_clear_interrupt_at_start_clears_thread(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the
|
||||
per-thread interrupt flag for the current thread.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child = self._make_bare_agent()
|
||||
child._interrupt_requested = True
|
||||
child._interrupt_message = "msg"
|
||||
child.quiet_mode = True
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
|
||||
# Global is set
|
||||
# Interrupt for current thread is set
|
||||
set_interrupt(True)
|
||||
assert is_interrupted() is True
|
||||
|
||||
# child.clear_interrupt() clears both
|
||||
# child.clear_interrupt() clears both instance flag and thread flag
|
||||
child.clear_interrupt()
|
||||
assert child._interrupt_requested is False
|
||||
assert is_interrupted() is False
|
||||
|
||||
def test_interrupt_during_child_api_call_detected(self):
|
||||
"""Interrupt set during _interruptible_api_call is detected within 0.5s."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
child.quiet_mode = True
|
||||
child = self._make_bare_agent()
|
||||
child.api_mode = "chat_completions"
|
||||
child.log_prefix = ""
|
||||
child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
|
||||
|
|
@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
|
||||
def test_concurrent_interrupt_propagation(self):
|
||||
"""Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent._active_children_lock = threading.Lock()
|
||||
parent.quiet_mode = True
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
child.quiet_mode = True
|
||||
parent = self._make_bare_agent()
|
||||
child = self._make_bare_agent()
|
||||
|
||||
# Register child (simulating what _run_single_child does)
|
||||
parent._active_children.append(child)
|
||||
|
|
@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
set_interrupt(False)
|
||||
|
||||
|
||||
class TestPerThreadInterruptIsolation(unittest.TestCase):
|
||||
"""Verify that interrupting one agent does NOT affect another agent's thread.
|
||||
|
||||
This is the core fix for the gateway cross-session interrupt leak:
|
||||
multiple agents run in separate threads within the same process, and
|
||||
interrupting agent A must not kill agent B's running tools.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def test_interrupt_only_affects_target_thread(self):
|
||||
"""set_interrupt(True, tid) only makes is_interrupted() True on that thread."""
|
||||
results = {}
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def thread_a():
|
||||
"""Agent A's execution thread — will be interrupted."""
|
||||
tid = threading.current_thread().ident
|
||||
results["a_tid"] = tid
|
||||
barrier.wait(timeout=5) # sync with thread B
|
||||
time.sleep(0.2) # let the interrupt arrive
|
||||
results["a_interrupted"] = is_interrupted()
|
||||
|
||||
def thread_b():
|
||||
"""Agent B's execution thread — should NOT be affected."""
|
||||
tid = threading.current_thread().ident
|
||||
results["b_tid"] = tid
|
||||
barrier.wait(timeout=5) # sync with thread A
|
||||
time.sleep(0.2)
|
||||
results["b_interrupted"] = is_interrupted()
|
||||
|
||||
ta = threading.Thread(target=thread_a)
|
||||
tb = threading.Thread(target=thread_b)
|
||||
ta.start()
|
||||
tb.start()
|
||||
|
||||
# Wait for both threads to register their TIDs
|
||||
time.sleep(0.05)
|
||||
while "a_tid" not in results or "b_tid" not in results:
|
||||
time.sleep(0.01)
|
||||
|
||||
# Interrupt ONLY thread A (simulates gateway interrupting agent A)
|
||||
set_interrupt(True, results["a_tid"])
|
||||
|
||||
ta.join(timeout=3)
|
||||
tb.join(timeout=3)
|
||||
|
||||
assert results["a_interrupted"] is True, "Thread A should see the interrupt"
|
||||
assert results["b_interrupted"] is False, "Thread B must NOT see thread A's interrupt"
|
||||
|
||||
def test_clear_interrupt_only_clears_target_thread(self):
|
||||
"""Clearing one thread's interrupt doesn't clear another's."""
|
||||
tid_a = 99990001
|
||||
tid_b = 99990002
|
||||
set_interrupt(True, tid_a)
|
||||
set_interrupt(True, tid_b)
|
||||
|
||||
# Clear only A
|
||||
set_interrupt(False, tid_a)
|
||||
|
||||
# Simulate checking from thread B's perspective
|
||||
from tools.interrupt import _interrupted_threads, _lock
|
||||
with _lock:
|
||||
assert tid_a not in _interrupted_threads
|
||||
assert tid_b in _interrupted_threads
|
||||
|
||||
# Cleanup
|
||||
set_interrupt(False, tid_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -2087,8 +2087,9 @@ class TestRunConversation:
|
|||
assert "Thinking Budget Exhausted" in result["final_response"]
|
||||
assert "/thinkon" in result["final_response"]
|
||||
|
||||
def test_length_empty_content_detected_as_thinking_exhausted(self, agent):
|
||||
"""When finish_reason='length' and content is None/empty, detect exhaustion."""
|
||||
def test_length_empty_content_without_think_tags_retries_normally(self, agent):
|
||||
"""When finish_reason='length' and content is None but no think tags,
|
||||
fall through to normal continuation retry (not thinking-exhaustion)."""
|
||||
self._setup_agent(agent)
|
||||
resp = _mock_response(content=None, finish_reason="length")
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
|
|
@ -2100,12 +2101,10 @@ class TestRunConversation:
|
|||
):
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
# Without think tags, the agent should attempt continuation retries
|
||||
# (up to 3), not immediately fire thinking-exhaustion.
|
||||
assert result["api_calls"] == 3
|
||||
assert result["completed"] is False
|
||||
assert result["api_calls"] == 1
|
||||
assert "reasoning" in result["error"].lower()
|
||||
# User-friendly message is returned
|
||||
assert result["final_response"] is not None
|
||||
assert "Thinking Budget Exhausted" in result["final_response"]
|
||||
|
||||
def test_length_with_tool_calls_returns_partial_without_executing_tools(self, agent):
|
||||
self._setup_agent(agent)
|
||||
|
|
@ -2169,6 +2168,35 @@ class TestRunConversation:
|
|||
mock_hfc.assert_called_once()
|
||||
assert result["final_response"] == "Done!"
|
||||
|
||||
def test_truncated_tool_args_detected_when_finish_reason_not_length(self, agent):
|
||||
"""When a router rewrites finish_reason from 'length' to 'tool_calls',
|
||||
truncated JSON arguments should still be detected and refused rather
|
||||
than wasting 3 retry attempts."""
|
||||
self._setup_agent(agent)
|
||||
agent.valid_tool_names.add("write_file")
|
||||
bad_tc = _mock_tool_call(
|
||||
name="write_file",
|
||||
arguments='{"path":"report.md","content":"partial',
|
||||
call_id="c1",
|
||||
)
|
||||
resp = _mock_response(
|
||||
content="", finish_reason="tool_calls", tool_calls=[bad_tc],
|
||||
)
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
|
||||
with (
|
||||
patch("run_agent.handle_function_call") as mock_handle_function_call,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("write the report")
|
||||
|
||||
assert result["completed"] is False
|
||||
assert result["partial"] is True
|
||||
assert "truncated due to output length limit" in result["error"]
|
||||
mock_handle_function_call.assert_not_called()
|
||||
|
||||
|
||||
class TestRetryExhaustion:
|
||||
"""Regression: retry_count > max_retries was dead code (off-by-one).
|
||||
|
|
|
|||
|
|
@ -1104,3 +1104,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch
|
|||
]
|
||||
assert "enc_first" in encrypted_contents
|
||||
assert "enc_second" in encrypted_contents
|
||||
|
||||
|
||||
def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
"""Duplicate reasoning item IDs across multi-turn incomplete responses
|
||||
must be deduplicated so the Responses API doesn't reject with HTTP 400."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
messages = [
|
||||
{"role": "user", "content": "think hard"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"codex_reasoning_items": [
|
||||
{"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
|
||||
{"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "partial answer",
|
||||
"codex_reasoning_items": [
|
||||
# rs_aaa is duplicated from the previous turn
|
||||
{"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
|
||||
{"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"},
|
||||
],
|
||||
},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
|
||||
reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"]
|
||||
# rs_aaa should appear only once (first occurrence kept)
|
||||
assert reasoning_ids.count("rs_aaa") == 1
|
||||
# rs_bbb and rs_ccc should each appear once
|
||||
assert reasoning_ids.count("rs_bbb") == 1
|
||||
assert reasoning_ids.count("rs_ccc") == 1
|
||||
assert len(reasoning_ids) == 3
|
||||
|
||||
|
||||
def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
"""_preflight_codex_input_items should also deduplicate reasoning items by ID."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
raw_input = [
|
||||
{"role": "user", "content": [{"type": "input_text", "text": "hello"}]},
|
||||
{"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
{"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
|
||||
{"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
|
||||
reasoning_items = [it for it in normalized if it.get("type") == "reasoning"]
|
||||
reasoning_ids = [it["id"] for it in reasoning_items]
|
||||
assert reasoning_ids.count("rs_xyz") == 1
|
||||
assert reasoning_ids.count("rs_zzz") == 1
|
||||
assert len(reasoning_items) == 2
|
||||
|
|
|
|||
158
tests/tools/test_browser_orphan_reaper.py
Normal file
158
tests/tools/test_browser_orphan_reaper.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""Tests for _reap_orphaned_browser_sessions() — kills orphaned agent-browser
|
||||
daemons whose Python parent exited without cleaning up."""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_tmpdir(tmp_path):
|
||||
"""Patch _socket_safe_tmpdir to return a temp dir we control."""
|
||||
with patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)):
|
||||
yield tmp_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_sessions():
|
||||
"""Ensure _active_sessions is empty for each test."""
|
||||
import tools.browser_tool as bt
|
||||
orig = bt._active_sessions.copy()
|
||||
bt._active_sessions.clear()
|
||||
yield
|
||||
bt._active_sessions.clear()
|
||||
bt._active_sessions.update(orig)
|
||||
|
||||
|
||||
def _make_socket_dir(tmpdir, session_name, pid=None):
|
||||
"""Create a fake agent-browser socket directory with optional PID file."""
|
||||
d = tmpdir / f"agent-browser-{session_name}"
|
||||
d.mkdir()
|
||||
if pid is not None:
|
||||
(d / f"{session_name}.pid").write_text(str(pid))
|
||||
return d
|
||||
|
||||
|
||||
class TestReapOrphanedBrowserSessions:
|
||||
"""Tests for the orphan reaper function."""
|
||||
|
||||
def test_no_socket_dirs_is_noop(self, fake_tmpdir):
|
||||
"""No socket dirs => nothing happens, no errors."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
_reap_orphaned_browser_sessions() # should not raise
|
||||
|
||||
def test_stale_dir_without_pid_file_is_removed(self, fake_tmpdir):
|
||||
"""Socket dir with no PID file is cleaned up."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
d = _make_socket_dir(fake_tmpdir, "h_abc1234567")
|
||||
assert d.exists()
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
||||
def test_stale_dir_with_dead_pid_is_removed(self, fake_tmpdir):
|
||||
"""Socket dir whose daemon PID is dead gets cleaned up."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
d = _make_socket_dir(fake_tmpdir, "h_dead123456", pid=999999999)
|
||||
assert d.exists()
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
||||
def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir):
|
||||
"""Alive daemon not tracked by _active_sessions gets SIGTERM."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345)
|
||||
|
||||
kill_calls = []
|
||||
original_kill = os.kill
|
||||
|
||||
def mock_kill(pid, sig):
|
||||
kill_calls.append((pid, sig))
|
||||
if sig == 0:
|
||||
return # pretend process exists
|
||||
# Don't actually kill anything
|
||||
|
||||
with patch("os.kill", side_effect=mock_kill):
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Should have checked existence (sig 0) then killed (SIGTERM)
|
||||
assert (12345, 0) in kill_calls
|
||||
assert (12345, signal.SIGTERM) in kill_calls
|
||||
|
||||
def test_tracked_session_is_not_reaped(self, fake_tmpdir):
|
||||
"""Sessions tracked in _active_sessions are left alone."""
|
||||
import tools.browser_tool as bt
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
session_name = "h_tracked1234"
|
||||
d = _make_socket_dir(fake_tmpdir, session_name, pid=12345)
|
||||
|
||||
# Register the session as actively tracked
|
||||
bt._active_sessions["some_task"] = {"session_name": session_name}
|
||||
|
||||
kill_calls = []
|
||||
|
||||
def mock_kill(pid, sig):
|
||||
kill_calls.append((pid, sig))
|
||||
|
||||
with patch("os.kill", side_effect=mock_kill):
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Should NOT have tried to kill anything
|
||||
assert len(kill_calls) == 0
|
||||
# Dir should still exist
|
||||
assert d.exists()
|
||||
|
||||
def test_permission_error_on_kill_check_skips(self, fake_tmpdir):
|
||||
"""If we can't check the PID (PermissionError), skip it."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_perm1234567", pid=12345)
|
||||
|
||||
def mock_kill(pid, sig):
|
||||
if sig == 0:
|
||||
raise PermissionError("not our process")
|
||||
|
||||
with patch("os.kill", side_effect=mock_kill):
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Dir should still exist (we didn't touch someone else's process)
|
||||
assert d.exists()
|
||||
|
||||
def test_cdp_sessions_are_also_reaped(self, fake_tmpdir):
|
||||
"""CDP sessions (cdp_ prefix) are also scanned."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "cdp_abc1234567")
|
||||
assert d.exists()
|
||||
_reap_orphaned_browser_sessions()
|
||||
# No PID file → cleaned up
|
||||
assert not d.exists()
|
||||
|
||||
def test_non_hermes_dirs_are_ignored(self, fake_tmpdir):
|
||||
"""Socket dirs that don't match our naming pattern are left alone."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
# Create a dir that doesn't match h_* or cdp_* pattern
|
||||
d = fake_tmpdir / "agent-browser-other_session"
|
||||
d.mkdir()
|
||||
(d / "other_session.pid").write_text("12345")
|
||||
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Should NOT be touched
|
||||
assert d.exists()
|
||||
|
||||
def test_corrupt_pid_file_is_cleaned(self, fake_tmpdir):
|
||||
"""PID file with non-integer content is cleaned up."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_corrupt1234")
|
||||
(d / "h_corrupt1234.pid").write_text("not-a-number")
|
||||
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
|
@ -1,9 +1,6 @@
|
|||
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
|
@ -42,6 +39,19 @@ def checkpoint_base(tmp_path):
|
|||
return tmp_path / "checkpoints"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_home(tmp_path, monkeypatch):
|
||||
"""Set a deterministic fake home for expanduser/path-home behavior."""
|
||||
home = tmp_path / "home"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HOME", str(home))
|
||||
monkeypatch.setenv("USERPROFILE", str(home))
|
||||
monkeypatch.delenv("HOMEDRIVE", raising=False)
|
||||
monkeypatch.delenv("HOMEPATH", raising=False)
|
||||
monkeypatch.setattr(Path, "home", classmethod(lambda cls: home))
|
||||
return home
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mgr(work_dir, checkpoint_base, monkeypatch):
|
||||
"""CheckpointManager with redirected checkpoint base."""
|
||||
|
|
@ -78,6 +88,16 @@ class TestShadowRepoPath:
|
|||
p = _shadow_repo_path(str(work_dir))
|
||||
assert str(p).startswith(str(checkpoint_base))
|
||||
|
||||
def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
expanded_path = str(project)
|
||||
|
||||
assert _shadow_repo_path(tilde_path) == _shadow_repo_path(expanded_path)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Shadow repo init
|
||||
|
|
@ -221,6 +241,20 @@ class TestListCheckpoints:
|
|||
assert result[0]["reason"] == "third"
|
||||
assert result[2]["reason"] == "first"
|
||||
|
||||
def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
mgr = CheckpointManager(enabled=True, max_snapshots=50)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
(project / "main.py").write_text("v1\n")
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
assert mgr.ensure_checkpoint(tilde_path, "initial") is True
|
||||
|
||||
listed = mgr.list_checkpoints(str(project))
|
||||
assert len(listed) == 1
|
||||
assert listed[0]["reason"] == "initial"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CheckpointManager — restoring
|
||||
|
|
@ -271,6 +305,28 @@ class TestRestore:
|
|||
assert len(all_cps) >= 2
|
||||
assert "pre-rollback" in all_cps[0]["reason"]
|
||||
|
||||
def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
mgr = CheckpointManager(enabled=True, max_snapshots=50)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
file_path = project / "main.py"
|
||||
file_path.write_text("original\n")
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
assert mgr.ensure_checkpoint(tilde_path, "initial") is True
|
||||
mgr.new_turn()
|
||||
|
||||
file_path.write_text("changed\n")
|
||||
checkpoints = mgr.list_checkpoints(str(project))
|
||||
diff_result = mgr.diff(tilde_path, checkpoints[0]["hash"])
|
||||
assert diff_result["success"] is True
|
||||
assert "main.py" in diff_result["diff"]
|
||||
|
||||
restore_result = mgr.restore(tilde_path, checkpoints[0]["hash"])
|
||||
assert restore_result["success"] is True
|
||||
assert file_path.read_text() == "original\n"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CheckpointManager — working dir resolution
|
||||
|
|
@ -310,6 +366,19 @@ class TestWorkingDirResolution:
|
|||
result = mgr.get_working_dir_for_path(str(filepath))
|
||||
assert result == str(filepath.parent)
|
||||
|
||||
def test_resolves_tilde_path_to_project_root(self, fake_home):
|
||||
mgr = CheckpointManager(enabled=True)
|
||||
project = fake_home / "myproject"
|
||||
project.mkdir()
|
||||
(project / "pyproject.toml").write_text("[project]\n")
|
||||
subdir = project / "src"
|
||||
subdir.mkdir()
|
||||
filepath = subdir / "main.py"
|
||||
filepath.write_text("x\n")
|
||||
|
||||
result = mgr.get_working_dir_for_path(f"~/{project.name}/src/main.py")
|
||||
assert result == str(project)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Git env isolation
|
||||
|
|
@ -333,6 +402,14 @@ class TestGitEnvIsolation:
|
|||
env = _git_env(shadow, str(tmp_path))
|
||||
assert "GIT_INDEX_FILE" not in env
|
||||
|
||||
def test_expands_tilde_in_work_tree(self, fake_home, tmp_path):
|
||||
shadow = tmp_path / "shadow"
|
||||
work = fake_home / "work"
|
||||
work.mkdir()
|
||||
|
||||
env = _git_env(shadow, f"~/{work.name}")
|
||||
assert env["GIT_WORK_TREE"] == str(work.resolve())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_checkpoint_list
|
||||
|
|
@ -384,6 +461,8 @@ class TestErrorResilience:
|
|||
assert result is False
|
||||
|
||||
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
|
||||
work = tmp_path / "work"
|
||||
work.mkdir()
|
||||
completed = subprocess.CompletedProcess(
|
||||
args=["git", "diff", "--cached", "--quiet"],
|
||||
returncode=1,
|
||||
|
|
@ -395,7 +474,7 @@ class TestErrorResilience:
|
|||
ok, stdout, stderr = _run_git(
|
||||
["diff", "--cached", "--quiet"],
|
||||
tmp_path / "shadow",
|
||||
str(tmp_path / "work"),
|
||||
str(work),
|
||||
allowed_returncodes={1},
|
||||
)
|
||||
assert ok is False
|
||||
|
|
@ -403,6 +482,38 @@ class TestErrorResilience:
|
|||
assert stderr == ""
|
||||
assert not caplog.records
|
||||
|
||||
def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog):
|
||||
missing = tmp_path / "missing"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["status"],
|
||||
tmp_path / "shadow",
|
||||
str(missing),
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert "working directory not found" in stderr
|
||||
assert not any("Git executable not found" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog):
|
||||
work = tmp_path / "work"
|
||||
work.mkdir()
|
||||
|
||||
def raise_missing_git(*args, **kwargs):
|
||||
raise FileNotFoundError(2, "No such file or directory", "git")
|
||||
|
||||
monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git)
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["status"],
|
||||
tmp_path / "shadow",
|
||||
str(work),
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert stderr == "git not found"
|
||||
assert any("Git executable not found" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
|
||||
"""Checkpoint failures should never raise — they're silently logged."""
|
||||
def broken_run_git(*args, **kwargs):
|
||||
|
|
@ -411,3 +522,68 @@ class TestErrorResilience:
|
|||
# Should not raise
|
||||
result = mgr.ensure_checkpoint(str(work_dir), "test")
|
||||
assert result is False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Security / Input validation
|
||||
# =========================================================================
|
||||
|
||||
class TestSecurity:
|
||||
def test_restore_rejects_argument_injection(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Try to pass a git flag as a commit hash
|
||||
result = mgr.restore(str(work_dir), "--patch")
|
||||
assert result["success"] is False
|
||||
assert "Invalid commit hash" in result["error"]
|
||||
assert "must not start with '-'" in result["error"]
|
||||
|
||||
result = mgr.restore(str(work_dir), "-p")
|
||||
assert result["success"] is False
|
||||
assert "Invalid commit hash" in result["error"]
|
||||
|
||||
def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Git hashes should not contain characters like ;, &, |
|
||||
result = mgr.restore(str(work_dir), "abc; rm -rf /")
|
||||
assert result["success"] is False
|
||||
assert "expected 4-64 hex characters" in result["error"]
|
||||
|
||||
result = mgr.diff(str(work_dir), "abc&def")
|
||||
assert result["success"] is False
|
||||
assert "expected 4-64 hex characters" in result["error"]
|
||||
|
||||
def test_restore_rejects_path_traversal(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Real commit hash but malicious path
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
# Absolute path outside
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="/etc/passwd")
|
||||
assert result["success"] is False
|
||||
assert "got absolute path" in result["error"]
|
||||
|
||||
# Relative traversal outside path
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="../outside_file.txt")
|
||||
assert result["success"] is False
|
||||
assert "escapes the working directory" in result["error"]
|
||||
|
||||
def test_restore_accepts_valid_file_path(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
# Valid path inside directory
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="main.py")
|
||||
assert result["success"] is True
|
||||
|
||||
# Another valid path with subdirectories
|
||||
(work_dir / "subdir").mkdir()
|
||||
(work_dir / "subdir" / "test.txt").write_text("hello")
|
||||
mgr.new_turn()
|
||||
mgr.ensure_checkpoint(str(work_dir), "second")
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="subdir/test.txt")
|
||||
assert result["success"] is True
|
||||
|
|
|
|||
|
|
@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase):
|
|||
@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
|
||||
class TestInterruptHandling(unittest.TestCase):
|
||||
def test_interrupt_event_stops_execution(self):
|
||||
"""When _interrupt_event is set, execute_code should stop the script."""
|
||||
"""When interrupt is set for the execution thread, execute_code should stop."""
|
||||
code = "import time; time.sleep(60); print('should not reach')"
|
||||
from tools.interrupt import set_interrupt
|
||||
|
||||
# Capture the main thread ID so we can target the interrupt correctly.
|
||||
# execute_code runs in the current thread; set_interrupt needs its ID.
|
||||
main_tid = threading.current_thread().ident
|
||||
|
||||
def set_interrupt_after_delay():
|
||||
import time as _t
|
||||
_t.sleep(1)
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
_interrupt_event.set()
|
||||
set_interrupt(True, main_tid)
|
||||
|
||||
t = threading.Thread(target=set_interrupt_after_delay, daemon=True)
|
||||
t.start()
|
||||
|
|
@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase):
|
|||
self.assertEqual(result["status"], "interrupted")
|
||||
self.assertIn("interrupted", result["output"])
|
||||
finally:
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
_interrupt_event.clear()
|
||||
set_interrupt(False, main_tid)
|
||||
t.join(timeout=3)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,8 @@ class TestCheckpointNotify:
|
|||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_user_id": "u123",
|
||||
"watcher_user_name": "alice",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 5,
|
||||
"notify_on_complete": True,
|
||||
|
|
@ -236,6 +238,8 @@ class TestCheckpointNotify:
|
|||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 1
|
||||
assert registry.pending_watchers[0]["notify_on_complete"] is True
|
||||
assert registry.pending_watchers[0]["user_id"] == "u123"
|
||||
assert registry.pending_watchers[0]["user_name"] == "alice"
|
||||
|
||||
def test_recover_defaults_false(self, registry, tmp_path):
|
||||
"""Old checkpoint entries without the field default to False."""
|
||||
|
|
|
|||
|
|
@ -438,6 +438,8 @@ class TestCheckpoint:
|
|||
s = _make_session()
|
||||
s.watcher_platform = "telegram"
|
||||
s.watcher_chat_id = "999"
|
||||
s.watcher_user_id = "u123"
|
||||
s.watcher_user_name = "alice"
|
||||
s.watcher_thread_id = "42"
|
||||
s.watcher_interval = 60
|
||||
registry._running[s.id] = s
|
||||
|
|
@ -447,6 +449,8 @@ class TestCheckpoint:
|
|||
assert len(data) == 1
|
||||
assert data[0]["watcher_platform"] == "telegram"
|
||||
assert data[0]["watcher_chat_id"] == "999"
|
||||
assert data[0]["watcher_user_id"] == "u123"
|
||||
assert data[0]["watcher_user_name"] == "alice"
|
||||
assert data[0]["watcher_thread_id"] == "42"
|
||||
assert data[0]["watcher_interval"] == 60
|
||||
|
||||
|
|
@ -460,6 +464,8 @@ class TestCheckpoint:
|
|||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_user_id": "u123",
|
||||
"watcher_user_name": "alice",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 60,
|
||||
}]))
|
||||
|
|
@ -471,6 +477,8 @@ class TestCheckpoint:
|
|||
assert w["session_id"] == "proc_live"
|
||||
assert w["platform"] == "telegram"
|
||||
assert w["chat_id"] == "123"
|
||||
assert w["user_id"] == "u123"
|
||||
assert w["user_name"] == "alice"
|
||||
assert w["thread_id"] == "42"
|
||||
assert w["check_interval"] == 60
|
||||
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ word word
|
|||
result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert "escapes" in result["error"].lower()
|
||||
assert outside_file.read_text() == "old text here"
|
||||
|
||||
|
||||
|
|
@ -412,7 +412,7 @@ class TestWriteFile:
|
|||
result = _write_file("my-skill", "references/escape/owned.md", "malicious")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert "escapes" in result["error"].lower()
|
||||
assert not (outside_dir / "owned.md").exists()
|
||||
|
||||
|
||||
|
|
@ -449,7 +449,7 @@ class TestRemoveFile:
|
|||
result = _remove_file("my-skill", "references/escape/keep.txt")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert "escapes" in result["error"].lower()
|
||||
assert outside_file.exists()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,34 @@ class TestWriteToSandbox:
|
|||
cmd = env.execute.call_args[0][0]
|
||||
assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd
|
||||
|
||||
def test_path_with_spaces_is_quoted(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
remote_path = "/tmp/hermes results/abc file.txt"
|
||||
_write_to_sandbox("content", remote_path, env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
assert "'/tmp/hermes results'" in cmd
|
||||
assert "'/tmp/hermes results/abc file.txt'" in cmd
|
||||
|
||||
def test_shell_metacharacters_neutralized(self):
|
||||
"""Paths with shell metacharacters must be quoted to prevent injection."""
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
malicious_path = "/tmp/hermes-results/$(whoami).txt"
|
||||
_write_to_sandbox("content", malicious_path, env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
# The $() must not appear unquoted — shlex.quote wraps it
|
||||
assert "'/tmp/hermes-results/$(whoami).txt'" in cmd
|
||||
|
||||
def test_semicolon_injection_neutralized(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
malicious_path = "/tmp/x; rm -rf /; echo .txt"
|
||||
_write_to_sandbox("content", malicious_path, env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
# The semicolons must be inside quotes, not acting as command separators
|
||||
assert "'/tmp/x; rm -rf /; echo .txt'" in cmd
|
||||
|
||||
|
||||
class TestResolveStorageDir:
|
||||
def test_defaults_to_storage_dir_without_env(self):
|
||||
|
|
|
|||
|
|
@ -473,13 +473,104 @@ def _cleanup_inactive_browser_sessions():
|
|||
logger.warning("Error cleaning up inactive session %s: %s", task_id, e)
|
||||
|
||||
|
||||
def _reap_orphaned_browser_sessions():
|
||||
"""Scan for orphaned agent-browser daemon processes from previous runs.
|
||||
|
||||
When the Python process that created a browser session exits uncleanly
|
||||
(SIGKILL, crash, gateway restart), the in-memory ``_active_sessions``
|
||||
tracking is lost but the node + Chromium processes keep running.
|
||||
|
||||
This function scans the tmp directory for ``agent-browser-*`` socket dirs
|
||||
left behind by previous runs, reads the daemon PID files, and kills any
|
||||
daemons that are still alive but not tracked by the current process.
|
||||
|
||||
Called once on cleanup-thread startup — not every 30 seconds — to avoid
|
||||
races with sessions being actively created.
|
||||
"""
|
||||
import glob
|
||||
|
||||
tmpdir = _socket_safe_tmpdir()
|
||||
pattern = os.path.join(tmpdir, "agent-browser-h_*")
|
||||
socket_dirs = glob.glob(pattern)
|
||||
# Also pick up CDP sessions
|
||||
socket_dirs += glob.glob(os.path.join(tmpdir, "agent-browser-cdp_*"))
|
||||
|
||||
if not socket_dirs:
|
||||
return
|
||||
|
||||
# Build set of session_names currently tracked by this process
|
||||
with _cleanup_lock:
|
||||
tracked_names = {
|
||||
info.get("session_name")
|
||||
for info in _active_sessions.values()
|
||||
if info.get("session_name")
|
||||
}
|
||||
|
||||
reaped = 0
|
||||
for socket_dir in socket_dirs:
|
||||
dir_name = os.path.basename(socket_dir)
|
||||
# dir_name is "agent-browser-{session_name}"
|
||||
session_name = dir_name.removeprefix("agent-browser-")
|
||||
if not session_name:
|
||||
continue
|
||||
|
||||
# Skip sessions that we are actively tracking
|
||||
if session_name in tracked_names:
|
||||
continue
|
||||
|
||||
pid_file = os.path.join(socket_dir, f"{session_name}.pid")
|
||||
if not os.path.isfile(pid_file):
|
||||
# No PID file — just a stale dir, remove it
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
continue
|
||||
|
||||
try:
|
||||
daemon_pid = int(Path(pid_file).read_text().strip())
|
||||
except (ValueError, OSError):
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
continue
|
||||
|
||||
# Check if the daemon is still alive
|
||||
try:
|
||||
os.kill(daemon_pid, 0) # signal 0 = existence check
|
||||
except ProcessLookupError:
|
||||
# Already dead, just clean up the dir
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
continue
|
||||
except PermissionError:
|
||||
# Alive but owned by someone else — leave it alone
|
||||
continue
|
||||
|
||||
# Daemon is alive and not tracked — orphan. Kill it.
|
||||
try:
|
||||
os.kill(daemon_pid, signal.SIGTERM)
|
||||
logger.info("Reaped orphaned browser daemon PID %d (session %s)",
|
||||
daemon_pid, session_name)
|
||||
reaped += 1
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# Clean up the socket directory
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
|
||||
if reaped:
|
||||
logger.info("Reaped %d orphaned browser session(s) from previous run(s)", reaped)
|
||||
|
||||
|
||||
def _browser_cleanup_thread_worker():
|
||||
"""
|
||||
Background thread that periodically cleans up inactive browser sessions.
|
||||
|
||||
Runs every 30 seconds and checks for sessions that haven't been used
|
||||
within the BROWSER_SESSION_INACTIVITY_TIMEOUT period.
|
||||
On first run, also reaps orphaned sessions from previous process lifetimes.
|
||||
"""
|
||||
# One-time orphan reap on startup
|
||||
try:
|
||||
_reap_orphaned_browser_sessions()
|
||||
except Exception as e:
|
||||
logger.warning("Orphan reap error: %s", e)
|
||||
|
||||
while _cleanup_running:
|
||||
try:
|
||||
_cleanup_inactive_browser_sessions()
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ into the user's project directory.
|
|||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
|
@ -64,23 +65,72 @@ _GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", "
|
|||
# Max files to snapshot — skip huge directories to avoid slowdowns.
|
||||
_MAX_FILES = 50_000
|
||||
|
||||
# Valid git commit hash pattern: 4–40 hex chars (short or full SHA-1/SHA-256).
|
||||
_COMMIT_HASH_RE = re.compile(r'^[0-9a-fA-F]{4,64}$')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _validate_commit_hash(commit_hash: str) -> Optional[str]:
|
||||
"""Validate a commit hash to prevent git argument injection.
|
||||
|
||||
Returns an error string if invalid, None if valid.
|
||||
Values starting with '-' would be interpreted as git flags
|
||||
(e.g., '--patch', '-p') instead of revision specifiers.
|
||||
"""
|
||||
if not commit_hash or not commit_hash.strip():
|
||||
return "Empty commit hash"
|
||||
if commit_hash.startswith("-"):
|
||||
return f"Invalid commit hash (must not start with '-'): {commit_hash!r}"
|
||||
if not _COMMIT_HASH_RE.match(commit_hash):
|
||||
return f"Invalid commit hash (expected 4-64 hex characters): {commit_hash!r}"
|
||||
return None
|
||||
|
||||
|
||||
def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]:
|
||||
"""Validate a file path to prevent path traversal outside the working directory.
|
||||
|
||||
Returns an error string if invalid, None if valid.
|
||||
"""
|
||||
if not file_path or not file_path.strip():
|
||||
return "Empty file path"
|
||||
# Reject absolute paths — restore targets must be relative to the workdir
|
||||
if os.path.isabs(file_path):
|
||||
return f"File path must be relative, got absolute path: {file_path!r}"
|
||||
# Resolve and check containment within working_dir
|
||||
abs_workdir = _normalize_path(working_dir)
|
||||
resolved = (abs_workdir / file_path).resolve()
|
||||
try:
|
||||
resolved.relative_to(abs_workdir)
|
||||
except ValueError:
|
||||
return f"File path escapes the working directory via traversal: {file_path!r}"
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shadow repo helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _normalize_path(path_value: str) -> Path:
|
||||
"""Return a canonical absolute path for checkpoint operations."""
|
||||
return Path(path_value).expanduser().resolve()
|
||||
|
||||
|
||||
def _shadow_repo_path(working_dir: str) -> Path:
|
||||
"""Deterministic shadow repo path: sha256(abs_path)[:16]."""
|
||||
abs_path = str(Path(working_dir).resolve())
|
||||
abs_path = str(_normalize_path(working_dir))
|
||||
dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16]
|
||||
return CHECKPOINT_BASE / dir_hash
|
||||
|
||||
|
||||
def _git_env(shadow_repo: Path, working_dir: str) -> dict:
|
||||
"""Build env dict that redirects git to the shadow repo."""
|
||||
normalized_working_dir = _normalize_path(working_dir)
|
||||
env = os.environ.copy()
|
||||
env["GIT_DIR"] = str(shadow_repo)
|
||||
env["GIT_WORK_TREE"] = str(Path(working_dir).resolve())
|
||||
env["GIT_WORK_TREE"] = str(normalized_working_dir)
|
||||
env.pop("GIT_INDEX_FILE", None)
|
||||
env.pop("GIT_NAMESPACE", None)
|
||||
env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None)
|
||||
|
|
@ -100,7 +150,17 @@ def _run_git(
|
|||
exits while preserving the normal ``ok = (returncode == 0)`` contract.
|
||||
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
|
||||
"""
|
||||
env = _git_env(shadow_repo, working_dir)
|
||||
normalized_working_dir = _normalize_path(working_dir)
|
||||
if not normalized_working_dir.exists():
|
||||
msg = f"working directory not found: {normalized_working_dir}"
|
||||
logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg)
|
||||
return False, "", msg
|
||||
if not normalized_working_dir.is_dir():
|
||||
msg = f"working directory is not a directory: {normalized_working_dir}"
|
||||
logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg)
|
||||
return False, "", msg
|
||||
|
||||
env = _git_env(shadow_repo, str(normalized_working_dir))
|
||||
cmd = ["git"] + list(args)
|
||||
allowed_returncodes = allowed_returncodes or set()
|
||||
try:
|
||||
|
|
@ -110,7 +170,7 @@ def _run_git(
|
|||
text=True,
|
||||
timeout=timeout,
|
||||
env=env,
|
||||
cwd=str(Path(working_dir).resolve()),
|
||||
cwd=str(normalized_working_dir),
|
||||
)
|
||||
ok = result.returncode == 0
|
||||
stdout = result.stdout.strip()
|
||||
|
|
@ -125,9 +185,14 @@ def _run_git(
|
|||
msg = f"git timed out after {timeout}s: {' '.join(cmd)}"
|
||||
logger.error(msg, exc_info=True)
|
||||
return False, "", msg
|
||||
except FileNotFoundError:
|
||||
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
|
||||
return False, "", "git not found"
|
||||
except FileNotFoundError as exc:
|
||||
missing_target = getattr(exc, "filename", None)
|
||||
if missing_target == "git":
|
||||
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
|
||||
return False, "", "git not found"
|
||||
msg = f"working directory not found: {normalized_working_dir}"
|
||||
logger.error("Git command failed before execution: %s (%s)", " ".join(cmd), msg, exc_info=True)
|
||||
return False, "", msg
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True)
|
||||
return False, "", str(exc)
|
||||
|
|
@ -154,7 +219,7 @@ def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]:
|
|||
)
|
||||
|
||||
(shadow_repo / "HERMES_WORKDIR").write_text(
|
||||
str(Path(working_dir).resolve()) + "\n", encoding="utf-8"
|
||||
str(_normalize_path(working_dir)) + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir)
|
||||
|
|
@ -229,7 +294,7 @@ class CheckpointManager:
|
|||
if not self._git_available:
|
||||
return False
|
||||
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
|
||||
# Skip root, home, and other overly broad directories
|
||||
if abs_dir in ("/", str(Path.home())):
|
||||
|
|
@ -254,7 +319,7 @@ class CheckpointManager:
|
|||
Returns a list of dicts with keys: hash, short_hash, timestamp, reason,
|
||||
files_changed, insertions, deletions. Most recent first.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
|
|
@ -311,7 +376,12 @@ class CheckpointManager:
|
|||
|
||||
Returns dict with success, diff text, and stat summary.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
# Validate commit_hash to prevent git argument injection
|
||||
hash_err = _validate_commit_hash(commit_hash)
|
||||
if hash_err:
|
||||
return {"success": False, "error": hash_err}
|
||||
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
|
|
@ -364,7 +434,19 @@ class CheckpointManager:
|
|||
|
||||
Returns dict with success/error info.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
# Validate commit_hash to prevent git argument injection
|
||||
hash_err = _validate_commit_hash(commit_hash)
|
||||
if hash_err:
|
||||
return {"success": False, "error": hash_err}
|
||||
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
|
||||
# Validate file_path to prevent path traversal outside the working dir
|
||||
if file_path:
|
||||
path_err = _validate_file_path(file_path, abs_dir)
|
||||
if path_err:
|
||||
return {"success": False, "error": path_err}
|
||||
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
|
|
@ -413,7 +495,7 @@ class CheckpointManager:
|
|||
(directory containing .git, pyproject.toml, package.json, etc.).
|
||||
Falls back to the file's parent directory.
|
||||
"""
|
||||
path = Path(file_path).resolve()
|
||||
path = _normalize_path(file_path)
|
||||
if path.is_dir():
|
||||
candidate = path
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -924,8 +924,8 @@ def execute_code(
|
|||
|
||||
# --- Local execution path (UDS) --- below this line is unchanged ---
|
||||
|
||||
# Import interrupt event from terminal_tool (cooperative cancellation)
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
# Import per-thread interrupt check (cooperative cancellation)
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
|
||||
# Resolve config
|
||||
_cfg = _load_config()
|
||||
|
|
@ -1114,7 +1114,7 @@ def execute_code(
|
|||
|
||||
status = "success"
|
||||
while proc.poll() is None:
|
||||
if _interrupt_event.is_set():
|
||||
if _is_interrupted():
|
||||
_kill_process_group(proc)
|
||||
status = "interrupted"
|
||||
break
|
||||
|
|
|
|||
|
|
@ -80,20 +80,18 @@ def register_credential_file(
|
|||
|
||||
# Resolve symlinks and normalise ``..`` before the containment check so
|
||||
# that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME.
|
||||
try:
|
||||
resolved = host_path.resolve()
|
||||
hermes_home_resolved = hermes_home.resolve()
|
||||
resolved.relative_to(hermes_home_resolved) # raises ValueError if outside
|
||||
except ValueError:
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
containment_error = validate_within_dir(host_path, hermes_home)
|
||||
if containment_error:
|
||||
logger.warning(
|
||||
"credential_files: rejected path traversal %r "
|
||||
"(resolves to %s, outside HERMES_HOME %s)",
|
||||
"credential_files: rejected path traversal %r (%s)",
|
||||
relative_path,
|
||||
resolved,
|
||||
hermes_home_resolved,
|
||||
containment_error,
|
||||
)
|
||||
return False
|
||||
|
||||
resolved = host_path.resolve()
|
||||
if not resolved.is_file():
|
||||
logger.debug("credential_files: skipping %s (not found)", resolved)
|
||||
return False
|
||||
|
|
@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]:
|
|||
cfg = read_raw_config()
|
||||
cred_files = cfg.get("terminal", {}).get("credential_files")
|
||||
if isinstance(cred_files, list):
|
||||
hermes_home_resolved = hermes_home.resolve()
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
for item in cred_files:
|
||||
if isinstance(item, str) and item.strip():
|
||||
rel = item.strip()
|
||||
|
|
@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]:
|
|||
"credential_files: rejected absolute config path %r", rel,
|
||||
)
|
||||
continue
|
||||
host_path = (hermes_home / rel).resolve()
|
||||
try:
|
||||
host_path.relative_to(hermes_home_resolved)
|
||||
except ValueError:
|
||||
host_path = hermes_home / rel
|
||||
containment_error = validate_within_dir(host_path, hermes_home)
|
||||
if containment_error:
|
||||
logger.warning(
|
||||
"credential_files: rejected config path traversal %r "
|
||||
"(resolves to %s, outside HERMES_HOME %s)",
|
||||
rel, host_path, hermes_home_resolved,
|
||||
"credential_files: rejected config path traversal %r (%s)",
|
||||
rel, containment_error,
|
||||
)
|
||||
continue
|
||||
if host_path.is_file():
|
||||
resolved_path = host_path.resolve()
|
||||
if resolved_path.is_file():
|
||||
container_path = f"/root/.hermes/{rel}"
|
||||
result.append({
|
||||
"host_path": str(host_path),
|
||||
"host_path": str(resolved_path),
|
||||
"container_path": container_path,
|
||||
})
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]:
|
|||
)
|
||||
|
||||
# Validate containment after resolution
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
scripts_dir = get_hermes_home() / "scripts"
|
||||
scripts_dir.mkdir(parents=True, exist_ok=True)
|
||||
resolved = (scripts_dir / raw).resolve()
|
||||
try:
|
||||
resolved.relative_to(scripts_dir.resolve())
|
||||
except ValueError:
|
||||
containment_error = validate_within_dir(scripts_dir / raw, scripts_dir)
|
||||
if containment_error:
|
||||
return (
|
||||
f"Script path escapes the scripts directory via traversal: {raw!r}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
"""Shared interrupt signaling for all tools.
|
||||
"""Per-thread interrupt signaling for all tools.
|
||||
|
||||
Provides a global threading.Event that any tool can check to determine
|
||||
if the user has requested an interrupt. The agent's interrupt() method
|
||||
sets this event, and tools poll it during long-running operations.
|
||||
Provides thread-scoped interrupt tracking so that interrupting one agent
|
||||
session does not kill tools running in other sessions. This is critical
|
||||
in the gateway where multiple agents run concurrently in the same process.
|
||||
|
||||
The agent stores its execution thread ID at the start of run_conversation()
|
||||
and passes it to set_interrupt()/clear_interrupt(). Tools call
|
||||
is_interrupted() which checks the CURRENT thread — no argument needed.
|
||||
|
||||
Usage in tools:
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -12,17 +16,61 @@ Usage in tools:
|
|||
|
||||
import threading
|
||||
|
||||
_interrupt_event = threading.Event()
|
||||
# Set of thread idents that have been interrupted.
|
||||
_interrupted_threads: set[int] = set()
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def set_interrupt(active: bool) -> None:
|
||||
"""Called by the agent to signal or clear the interrupt."""
|
||||
if active:
|
||||
_interrupt_event.set()
|
||||
else:
|
||||
_interrupt_event.clear()
|
||||
def set_interrupt(active: bool, thread_id: int | None = None) -> None:
|
||||
"""Set or clear interrupt for a specific thread.
|
||||
|
||||
Args:
|
||||
active: True to signal interrupt, False to clear it.
|
||||
thread_id: Target thread ident. When None, targets the
|
||||
current thread (backward compat for CLI/tests).
|
||||
"""
|
||||
tid = thread_id if thread_id is not None else threading.current_thread().ident
|
||||
with _lock:
|
||||
if active:
|
||||
_interrupted_threads.add(tid)
|
||||
else:
|
||||
_interrupted_threads.discard(tid)
|
||||
|
||||
|
||||
def is_interrupted() -> bool:
|
||||
"""Check if an interrupt has been requested. Safe to call from any thread."""
|
||||
return _interrupt_event.is_set()
|
||||
"""Check if an interrupt has been requested for the current thread.
|
||||
|
||||
Safe to call from any thread — each thread only sees its own
|
||||
interrupt state.
|
||||
"""
|
||||
tid = threading.current_thread().ident
|
||||
with _lock:
|
||||
return tid in _interrupted_threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward-compatible _interrupt_event proxy
|
||||
# ---------------------------------------------------------------------------
|
||||
# Some legacy call sites (code_execution_tool, process_registry, tests)
|
||||
# import _interrupt_event directly and call .is_set() / .set() / .clear().
|
||||
# This shim maps those calls to the per-thread functions above so existing
|
||||
# code keeps working while the underlying mechanism is thread-scoped.
|
||||
|
||||
class _ThreadAwareEventProxy:
|
||||
"""Drop-in proxy that maps threading.Event methods to per-thread state."""
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return is_interrupted()
|
||||
|
||||
def set(self) -> None: # noqa: A003
|
||||
set_interrupt(True)
|
||||
|
||||
def clear(self) -> None:
|
||||
set_interrupt(False)
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
"""Not truly supported — returns current state immediately."""
|
||||
return self.is_set()
|
||||
|
||||
|
||||
_interrupt_event = _ThreadAwareEventProxy()
|
||||
|
|
|
|||
43
tools/path_security.py
Normal file
43
tools/path_security.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""Shared path validation helpers for tool implementations.
|
||||
|
||||
Extracts the ``resolve() + relative_to()`` and ``..`` traversal check
|
||||
patterns previously duplicated across skill_manager_tool, skills_tool,
|
||||
skills_hub, cronjob_tools, and credential_files.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_within_dir(path: Path, root: Path) -> Optional[str]:
|
||||
"""Ensure *path* resolves to a location within *root*.
|
||||
|
||||
Returns an error message string if validation fails, or ``None`` if the
|
||||
path is safe. Uses ``Path.resolve()`` to follow symlinks and normalize
|
||||
``..`` components.
|
||||
|
||||
Usage::
|
||||
|
||||
error = validate_within_dir(user_path, allowed_root)
|
||||
if error:
|
||||
return json.dumps({"error": error})
|
||||
"""
|
||||
try:
|
||||
resolved = path.resolve()
|
||||
root_resolved = root.resolve()
|
||||
resolved.relative_to(root_resolved)
|
||||
except (ValueError, OSError) as exc:
|
||||
return f"Path escapes allowed directory: {exc}"
|
||||
return None
|
||||
|
||||
|
||||
def has_traversal_component(path_str: str) -> bool:
|
||||
"""Return True if *path_str* contains ``..`` traversal components.
|
||||
|
||||
Quick check for obvious traversal attempts before doing full resolution.
|
||||
"""
|
||||
parts = Path(path_str).parts
|
||||
return ".." in parts
|
||||
|
|
@ -96,6 +96,8 @@ class ProcessSession:
|
|||
# Watcher/notification metadata (persisted for crash recovery)
|
||||
watcher_platform: str = ""
|
||||
watcher_chat_id: str = ""
|
||||
watcher_user_id: str = ""
|
||||
watcher_user_name: str = ""
|
||||
watcher_thread_id: str = ""
|
||||
watcher_interval: int = 0 # 0 = no watcher configured
|
||||
notify_on_complete: bool = False # Queue agent notification on exit
|
||||
|
|
@ -695,7 +697,7 @@ class ProcessRegistry:
|
|||
and output snapshot.
|
||||
"""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
|
||||
try:
|
||||
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
|
||||
|
|
@ -732,7 +734,7 @@ class ProcessRegistry:
|
|||
result["timeout_note"] = timeout_note
|
||||
return result
|
||||
|
||||
if _interrupt_event.is_set():
|
||||
if _is_interrupted():
|
||||
result = {
|
||||
"status": "interrupted",
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
|
|
@ -981,6 +983,8 @@ class ProcessRegistry:
|
|||
"session_key": s.session_key,
|
||||
"watcher_platform": s.watcher_platform,
|
||||
"watcher_chat_id": s.watcher_chat_id,
|
||||
"watcher_user_id": s.watcher_user_id,
|
||||
"watcher_user_name": s.watcher_user_name,
|
||||
"watcher_thread_id": s.watcher_thread_id,
|
||||
"watcher_interval": s.watcher_interval,
|
||||
"notify_on_complete": s.notify_on_complete,
|
||||
|
|
@ -1042,6 +1046,8 @@ class ProcessRegistry:
|
|||
detached=True, # Can't read output, but can report status + kill
|
||||
watcher_platform=entry.get("watcher_platform", ""),
|
||||
watcher_chat_id=entry.get("watcher_chat_id", ""),
|
||||
watcher_user_id=entry.get("watcher_user_id", ""),
|
||||
watcher_user_name=entry.get("watcher_user_name", ""),
|
||||
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||
watcher_interval=entry.get("watcher_interval", 0),
|
||||
notify_on_complete=entry.get("notify_on_complete", False),
|
||||
|
|
@ -1060,6 +1066,8 @@ class ProcessRegistry:
|
|||
"session_key": session.session_key,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"user_id": session.watcher_user_id,
|
||||
"user_name": session.watcher_user_name,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
"notify_on_complete": session.notify_on_complete,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -219,13 +219,15 @@ def _validate_file_path(file_path: str) -> Optional[str]:
|
|||
Validate a file path for write_file/remove_file.
|
||||
Must be under an allowed subdirectory and not escape the skill dir.
|
||||
"""
|
||||
from tools.path_security import has_traversal_component
|
||||
|
||||
if not file_path:
|
||||
return "file_path is required."
|
||||
|
||||
normalized = Path(file_path)
|
||||
|
||||
# Prevent path traversal
|
||||
if ".." in normalized.parts:
|
||||
if has_traversal_component(file_path):
|
||||
return "Path traversal ('..') is not allowed."
|
||||
|
||||
# Must be under an allowed subdirectory
|
||||
|
|
@ -242,15 +244,12 @@ def _validate_file_path(file_path: str) -> Optional[str]:
|
|||
|
||||
def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Path], Optional[str]]:
|
||||
"""Resolve a supporting-file path and ensure it stays within the skill directory."""
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
target = skill_dir / file_path
|
||||
try:
|
||||
resolved = target.resolve(strict=False)
|
||||
skill_dir_resolved = skill_dir.resolve()
|
||||
resolved.relative_to(skill_dir_resolved)
|
||||
except ValueError:
|
||||
return None, "Path escapes skill directory boundary."
|
||||
except OSError as e:
|
||||
return None, f"Invalid file path '{file_path}': {e}"
|
||||
error = validate_within_dir(target, skill_dir)
|
||||
if error:
|
||||
return None, error
|
||||
return target, None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -447,17 +447,8 @@ def _get_category_from_path(skill_path: Path) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
def _estimate_tokens(content: str) -> int:
|
||||
"""
|
||||
Rough token estimate (4 chars per token average).
|
||||
|
||||
Args:
|
||||
content: Text content
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
return len(content) // 4
|
||||
# Token estimation — use the shared implementation from model_metadata.
|
||||
from agent.model_metadata import estimate_tokens_rough as _estimate_tokens
|
||||
|
||||
|
||||
def _parse_tags(tags_value) -> List[str]:
|
||||
|
|
@ -947,9 +938,10 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
|||
|
||||
# If a specific file path is requested, read that instead
|
||||
if file_path and skill_dir:
|
||||
from tools.path_security import validate_within_dir, has_traversal_component
|
||||
|
||||
# Security: Prevent path traversal attacks
|
||||
normalized_path = Path(file_path)
|
||||
if ".." in normalized_path.parts:
|
||||
if has_traversal_component(file_path):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
|
|
@ -962,24 +954,13 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
|||
target_file = skill_dir / file_path
|
||||
|
||||
# Security: Verify resolved path is still within skill directory
|
||||
try:
|
||||
resolved = target_file.resolve()
|
||||
skill_dir_resolved = skill_dir.resolve()
|
||||
if not resolved.is_relative_to(skill_dir_resolved):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Path escapes skill directory boundary.",
|
||||
"hint": "Use a relative path within the skill directory",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except (OSError, ValueError):
|
||||
traversal_error = validate_within_dir(target_file, skill_dir)
|
||||
if traversal_error:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": f"Invalid file path: '{file_path}'",
|
||||
"hint": "Use a valid relative path within the skill directory",
|
||||
"error": traversal_error,
|
||||
"hint": "Use a relative path within the skill directory",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1427,8 +1427,12 @@ def terminal_tool(
|
|||
if _gw_platform and not check_interval:
|
||||
_gw_chat_id = _gse("HERMES_SESSION_CHAT_ID", "")
|
||||
_gw_thread_id = _gse("HERMES_SESSION_THREAD_ID", "")
|
||||
_gw_user_id = _gse("HERMES_SESSION_USER_ID", "")
|
||||
_gw_user_name = _gse("HERMES_SESSION_USER_NAME", "")
|
||||
proc_session.watcher_platform = _gw_platform
|
||||
proc_session.watcher_chat_id = _gw_chat_id
|
||||
proc_session.watcher_user_id = _gw_user_id
|
||||
proc_session.watcher_user_name = _gw_user_name
|
||||
proc_session.watcher_thread_id = _gw_thread_id
|
||||
proc_session.watcher_interval = 5
|
||||
process_registry.pending_watchers.append({
|
||||
|
|
@ -1437,6 +1441,8 @@ def terminal_tool(
|
|||
"session_key": session_key,
|
||||
"platform": _gw_platform,
|
||||
"chat_id": _gw_chat_id,
|
||||
"user_id": _gw_user_id,
|
||||
"user_name": _gw_user_name,
|
||||
"thread_id": _gw_thread_id,
|
||||
"notify_on_complete": True,
|
||||
})
|
||||
|
|
@ -1457,10 +1463,14 @@ def terminal_tool(
|
|||
watcher_platform = _gse2("HERMES_SESSION_PLATFORM", "")
|
||||
watcher_chat_id = _gse2("HERMES_SESSION_CHAT_ID", "")
|
||||
watcher_thread_id = _gse2("HERMES_SESSION_THREAD_ID", "")
|
||||
watcher_user_id = _gse2("HERMES_SESSION_USER_ID", "")
|
||||
watcher_user_name = _gse2("HERMES_SESSION_USER_NAME", "")
|
||||
|
||||
# Store on session for checkpoint persistence
|
||||
proc_session.watcher_platform = watcher_platform
|
||||
proc_session.watcher_chat_id = watcher_chat_id
|
||||
proc_session.watcher_user_id = watcher_user_id
|
||||
proc_session.watcher_user_name = watcher_user_name
|
||||
proc_session.watcher_thread_id = watcher_thread_id
|
||||
proc_session.watcher_interval = effective_interval
|
||||
|
||||
|
|
@ -1470,6 +1480,8 @@ def terminal_tool(
|
|||
"session_key": session_key,
|
||||
"platform": watcher_platform,
|
||||
"chat_id": watcher_chat_id,
|
||||
"user_id": watcher_user_id,
|
||||
"user_name": watcher_user_name,
|
||||
"thread_id": watcher_thread_id,
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ Defense against context-window overflow operates at three levels:
|
|||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import uuid
|
||||
|
||||
from tools.budget_config import (
|
||||
|
|
@ -79,7 +80,7 @@ def _write_to_sandbox(content: str, remote_path: str, env) -> bool:
|
|||
marker = _heredoc_marker(content)
|
||||
storage_dir = os.path.dirname(remote_path)
|
||||
cmd = (
|
||||
f"mkdir -p {storage_dir} && cat > {remote_path} << '{marker}'\n"
|
||||
f"mkdir -p {shlex.quote(storage_dir)} && cat > {shlex.quote(remote_path)} << '{marker}'\n"
|
||||
f"{content}\n"
|
||||
f"{marker}"
|
||||
)
|
||||
|
|
|
|||
90
utils.py
90
utils.py
|
|
@ -1,13 +1,16 @@
|
|||
"""Shared utility functions for hermes-agent."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})
|
||||
|
||||
|
|
@ -124,3 +127,88 @@ def atomic_yaml_write(
|
|||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
# ─── JSON Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def safe_json_loads(text: str, default: Any = None) -> Any:
|
||||
"""Parse JSON, returning *default* on any parse error.
|
||||
|
||||
Replaces the ``try: json.loads(x) except (JSONDecodeError, TypeError)``
|
||||
pattern duplicated across display.py, anthropic_adapter.py,
|
||||
auxiliary_client.py, and others.
|
||||
"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def read_json_file(path: Path, default: Any = None) -> Any:
|
||||
"""Read and parse a JSON file, returning *default* on any error.
|
||||
|
||||
Replaces the repeated ``try: json.loads(path.read_text()) except ...``
|
||||
pattern in anthropic_adapter.py, auxiliary_client.py, credential_pool.py,
|
||||
and skill_utils.py.
|
||||
"""
|
||||
try:
|
||||
return json.loads(Path(path).read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError, IOError, ValueError) as exc:
|
||||
logger.debug("Failed to read %s: %s", path, exc)
|
||||
return default
|
||||
|
||||
|
||||
def read_jsonl(path: Path) -> List[dict]:
|
||||
"""Read a JSONL file (one JSON object per line).
|
||||
|
||||
Returns a list of parsed objects, skipping blank lines.
|
||||
"""
|
||||
entries = []
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
entries.append(json.loads(line))
|
||||
return entries
|
||||
|
||||
|
||||
def append_jsonl(path: Path, entry: dict) -> None:
|
||||
"""Append a single JSON object as a new line to a JSONL file."""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
# ─── Environment Variable Helpers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def env_str(key: str, default: str = "") -> str:
|
||||
"""Read an environment variable, stripped of whitespace.
|
||||
|
||||
Replaces the ``os.getenv("X", "").strip()`` pattern repeated 50+ times
|
||||
across runtime_provider.py, anthropic_adapter.py, models.py, etc.
|
||||
"""
|
||||
return os.getenv(key, default).strip()
|
||||
|
||||
|
||||
def env_lower(key: str, default: str = "") -> str:
|
||||
"""Read an environment variable, stripped and lowercased."""
|
||||
return os.getenv(key, default).strip().lower()
|
||||
|
||||
|
||||
def env_int(key: str, default: int = 0) -> int:
|
||||
"""Read an environment variable as an integer, with fallback."""
|
||||
raw = os.getenv(key, "").strip()
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return int(raw)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
def env_bool(key: str, default: bool = False) -> bool:
|
||||
"""Read an environment variable as a boolean."""
|
||||
return is_truthy_value(os.getenv(key, ""), default=default)
|
||||
|
|
|
|||
|
|
@ -195,9 +195,12 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI
|
|||
| `SIGNAL_IGNORE_STORIES` | Ignore Signal stories/status updates |
|
||||
| `SIGNAL_ALLOW_ALL_USERS` | Allow all Signal users without an allowlist |
|
||||
| `TWILIO_ACCOUNT_SID` | Twilio Account SID (shared with telephony skill) |
|
||||
| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill) |
|
||||
| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill; also used for webhook signature validation) |
|
||||
| `TWILIO_PHONE_NUMBER` | Twilio phone number in E.164 format (shared with telephony skill) |
|
||||
| `SMS_WEBHOOK_URL` | Public URL for Twilio signature validation — must match the webhook URL in Twilio Console (required) |
|
||||
| `SMS_WEBHOOK_PORT` | Webhook listener port for inbound SMS (default: `8080`) |
|
||||
| `SMS_WEBHOOK_HOST` | Webhook bind address (default: `0.0.0.0`) |
|
||||
| `SMS_INSECURE_NO_SIGNATURE` | Set to `true` to disable Twilio signature validation (local dev only — not for production) |
|
||||
| `SMS_ALLOWED_USERS` | Comma-separated E.164 phone numbers allowed to chat |
|
||||
| `SMS_ALLOW_ALL_USERS` | Allow all SMS senders without an allowlist |
|
||||
| `SMS_HOME_CHANNEL` | Phone number for cron job / notification delivery |
|
||||
|
|
|
|||
|
|
@ -178,6 +178,8 @@ EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com
|
|||
MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c
|
||||
MATRIX_ALLOWED_USERS=@alice:matrix.org
|
||||
DINGTALK_ALLOWED_USERS=user-id-1
|
||||
FEISHU_ALLOWED_USERS=ou_xxxxxxxx,ou_yyyyyyyy
|
||||
WECOM_ALLOWED_USERS=user-id-1,user-id-2
|
||||
|
||||
# Or allow
|
||||
GATEWAY_ALLOWED_USERS=123456789,987654321
|
||||
|
|
|
|||
|
|
@ -84,6 +84,13 @@ ngrok http 8080
|
|||
Set the resulting public URL as your Twilio webhook.
|
||||
:::
|
||||
|
||||
**Set `SMS_WEBHOOK_URL` to the same URL you configured in Twilio.** This is required for Twilio signature validation — the adapter will refuse to start without it:
|
||||
|
||||
```bash
|
||||
# Must match the webhook URL in your Twilio Console
|
||||
SMS_WEBHOOK_URL=https://your-server:8080/webhooks/twilio
|
||||
```
|
||||
|
||||
The webhook port defaults to `8080`. Override with:
|
||||
|
||||
```bash
|
||||
|
|
@ -101,9 +108,11 @@ hermes gateway
|
|||
You should see:
|
||||
|
||||
```
|
||||
[sms] Twilio webhook server listening on port 8080, from: +1555***4567
|
||||
[sms] Twilio webhook server listening on 0.0.0.0:8080, from: +1555***4567
|
||||
```
|
||||
|
||||
If you see `Refusing to start: SMS_WEBHOOK_URL is required`, set `SMS_WEBHOOK_URL` to the public URL configured in your Twilio Console (see Step 3).
|
||||
|
||||
Text your Twilio number — Hermes will respond via SMS.
|
||||
|
||||
---
|
||||
|
|
@ -113,9 +122,12 @@ Text your Twilio number — Hermes will respond via SMS.
|
|||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `TWILIO_ACCOUNT_SID` | Yes | Twilio Account SID (starts with `AC`) |
|
||||
| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token |
|
||||
| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token (also used for webhook signature validation) |
|
||||
| `TWILIO_PHONE_NUMBER` | Yes | Your Twilio phone number (E.164 format) |
|
||||
| `SMS_WEBHOOK_URL` | Yes | Public URL for Twilio signature validation — must match the webhook URL in your Twilio Console |
|
||||
| `SMS_WEBHOOK_PORT` | No | Webhook listener port (default: `8080`) |
|
||||
| `SMS_WEBHOOK_HOST` | No | Webhook bind address (default: `0.0.0.0`) |
|
||||
| `SMS_INSECURE_NO_SIGNATURE` | No | Set to `true` to disable signature validation (local dev only — **not for production**) |
|
||||
| `SMS_ALLOWED_USERS` | No | Comma-separated E.164 phone numbers allowed to chat |
|
||||
| `SMS_ALLOW_ALL_USERS` | No | Set to `true` to allow anyone (not recommended) |
|
||||
| `SMS_HOME_CHANNEL` | No | Phone number for cron job / notification delivery |
|
||||
|
|
@ -134,6 +146,21 @@ Text your Twilio number — Hermes will respond via SMS.
|
|||
|
||||
## Security
|
||||
|
||||
### Webhook signature validation
|
||||
|
||||
Hermes validates that inbound webhooks genuinely originate from Twilio by verifying the `X-Twilio-Signature` header (HMAC-SHA1). This prevents attackers from injecting forged messages.
|
||||
|
||||
**`SMS_WEBHOOK_URL` is required.** Set it to the public URL configured in your Twilio Console. The adapter will refuse to start without it.
|
||||
|
||||
For local development without a public URL, you can disable validation:
|
||||
|
||||
```bash
|
||||
# Local dev only — NOT for production
|
||||
SMS_INSECURE_NO_SIGNATURE=true
|
||||
```
|
||||
|
||||
### User allowlists
|
||||
|
||||
**The gateway denies all users by default.** Configure an allowlist:
|
||||
|
||||
```bash
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue