mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Merge branch 'feat/ink-refactor' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
5552e1ffe1
152 changed files with 7146 additions and 2558 deletions
|
|
@ -89,6 +89,15 @@
|
|||
# Optional base URL override:
|
||||
# HERMES_QWEN_BASE_URL=https://portal.qwen.ai/v1
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Xiaomi MiMo)
|
||||
# =============================================================================
|
||||
# Xiaomi MiMo models (mimo-v2-pro, mimo-v2-omni, mimo-v2-flash).
|
||||
# Get your key at: https://platform.xiaomimimo.com
|
||||
# XIAOMI_API_KEY=your_key_here
|
||||
# Optional base URL override:
|
||||
# XIAOMI_BASE_URL=https://api.xiaomimimo.com/v1
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -23,17 +23,13 @@ Resolution order for vision/multimodal tasks (auto mode):
|
|||
6. Custom endpoint (for local vision models: Qwen-VL, LLaVA, Pixtral, etc.)
|
||||
7. None
|
||||
|
||||
Per-task provider overrides (e.g. AUXILIARY_VISION_PROVIDER,
|
||||
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task.
|
||||
Per-task overrides are configured in config.yaml under the ``auxiliary:`` section
|
||||
(e.g. ``auxiliary.vision.provider``, ``auxiliary.compression.model``).
|
||||
Default "auto" follows the chains above.
|
||||
|
||||
Per-task model overrides (e.g. AUXILIARY_VISION_MODEL,
|
||||
AUXILIARY_WEB_EXTRACT_MODEL) let callers use a different model slug
|
||||
than the provider's default.
|
||||
|
||||
Per-task direct endpoint overrides (e.g. AUXILIARY_VISION_BASE_URL,
|
||||
AUXILIARY_VISION_API_KEY) let callers route a specific auxiliary task to a
|
||||
custom OpenAI-compatible endpoint without touching the main model settings.
|
||||
Legacy env var overrides (AUXILIARY_{TASK}_PROVIDER, AUXILIARY_{TASK}_MODEL,
|
||||
AUXILIARY_{TASK}_BASE_URL, etc.) are still read as a backward-compat fallback
|
||||
but config.yaml takes priority. New configuration should always use config.yaml.
|
||||
|
||||
Payment / credit exhaustion fallback:
|
||||
When a resolved provider returns HTTP 402 or a credit-related error,
|
||||
|
|
@ -111,6 +107,14 @@ _API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
|||
"kilocode": "google/gemini-3-flash-preview",
|
||||
}
|
||||
|
||||
# Vision-specific model overrides for direct providers.
|
||||
# When the user's main provider has a dedicated vision/multimodal model that
|
||||
# differs from their main chat model, map it here. The vision auto-detect
|
||||
# "exotic provider" branch checks this before falling back to the main model.
|
||||
_PROVIDER_VISION_MODELS: Dict[str, str] = {
|
||||
"xiaomi": "mimo-v2-omni",
|
||||
}
|
||||
|
||||
# OpenRouter app attribution headers
|
||||
_OR_HEADERS = {
|
||||
"HTTP-Referer": "https://hermes-agent.nousresearch.com",
|
||||
|
|
@ -1687,16 +1691,18 @@ def resolve_vision_provider_client(
|
|||
if sync_client is not None:
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
# Exotic provider (DeepSeek, Alibaba, named custom, etc.)
|
||||
# Exotic provider (DeepSeek, Alibaba, Xiaomi, named custom, etc.)
|
||||
# Use provider-specific vision model if available, otherwise main model.
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, main_model)
|
||||
main_provider, vision_model)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using active provider %s (%s)",
|
||||
main_provider, rpc_model or main_model,
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or main_model)
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
|
||||
# Fall back through aggregators.
|
||||
for candidate in _VISION_AUTO_PROVIDER_ORDER:
|
||||
|
|
@ -1958,8 +1964,8 @@ def _resolve_task_provider_model(
|
|||
|
||||
Priority:
|
||||
1. Explicit provider/model/base_url/api_key args (always win)
|
||||
2. Env var overrides (AUXILIARY_{TASK}_*, CONTEXT_{TASK}_*)
|
||||
3. Config file (auxiliary.{task}.* or compression.*)
|
||||
2. Config file (auxiliary.{task}.* or compression.*)
|
||||
3. Env var overrides (backward-compat: AUXILIARY_{TASK}_*, CONTEXT_{TASK}_*)
|
||||
4. "auto" (full auto-detection chain)
|
||||
|
||||
Returns (provider, model, base_url, api_key, api_mode) where model may
|
||||
|
|
@ -2002,10 +2008,11 @@ def _resolve_task_provider_model(
|
|||
_sbu = comp.get("summary_base_url") or ""
|
||||
cfg_base_url = cfg_base_url or _sbu.strip() or None
|
||||
|
||||
# Env vars are backward-compat fallback only — config.yaml is primary.
|
||||
env_model = _get_auxiliary_env_override(task, "MODEL") if task else None
|
||||
env_api_mode = _get_auxiliary_env_override(task, "API_MODE") if task else None
|
||||
resolved_model = model or env_model or cfg_model
|
||||
resolved_api_mode = env_api_mode or cfg_api_mode
|
||||
resolved_model = model or cfg_model or env_model
|
||||
resolved_api_mode = cfg_api_mode or env_api_mode
|
||||
|
||||
if base_url:
|
||||
return "custom", resolved_model, base_url, api_key, resolved_api_mode
|
||||
|
|
@ -2013,19 +2020,23 @@ def _resolve_task_provider_model(
|
|||
return provider, resolved_model, base_url, api_key, resolved_api_mode
|
||||
|
||||
if task:
|
||||
# Config.yaml is the primary source for per-task overrides.
|
||||
if cfg_base_url:
|
||||
return "custom", resolved_model, cfg_base_url, cfg_api_key, resolved_api_mode
|
||||
if cfg_provider and cfg_provider != "auto":
|
||||
return cfg_provider, resolved_model, None, None, resolved_api_mode
|
||||
|
||||
# Env vars are backward-compat fallback for users who haven't
|
||||
# migrated to config.yaml yet.
|
||||
env_base_url = _get_auxiliary_env_override(task, "BASE_URL")
|
||||
env_api_key = _get_auxiliary_env_override(task, "API_KEY")
|
||||
if env_base_url:
|
||||
return "custom", resolved_model, env_base_url, env_api_key or cfg_api_key, resolved_api_mode
|
||||
return "custom", resolved_model, env_base_url, env_api_key, resolved_api_mode
|
||||
|
||||
env_provider = _get_auxiliary_provider(task)
|
||||
if env_provider != "auto":
|
||||
return env_provider, resolved_model, None, None, resolved_api_mode
|
||||
|
||||
if cfg_base_url:
|
||||
return "custom", resolved_model, cfg_base_url, cfg_api_key, resolved_api_mode
|
||||
if cfg_provider and cfg_provider != "auto":
|
||||
return cfg_provider, resolved_model, None, None, resolved_api_mode
|
||||
return "auto", resolved_model, None, None, resolved_api_mode
|
||||
|
||||
return "auto", resolved_model, None, None, resolved_api_mode
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Pure display functions and classes with no AIAgent dependency.
|
|||
Used by AIAgent._execute_tool_calls for CLI feedback.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -14,6 +13,8 @@ from dataclasses import dataclass, field
|
|||
from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
|
||||
from utils import safe_json_loads
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
_RESET = "\033[0m"
|
||||
|
|
@ -372,9 +373,8 @@ def _result_succeeded(result: str | None) -> bool:
|
|||
"""Conservatively detect whether a tool result represents success."""
|
||||
if not result:
|
||||
return False
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = safe_json_loads(result)
|
||||
if data is None:
|
||||
return False
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
|
@ -423,10 +423,7 @@ def extract_edit_diff(
|
|||
) -> str | None:
|
||||
"""Extract a unified diff from a file-edit tool result."""
|
||||
if tool_name == "patch" and result:
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = None
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
diff = data.get("diff")
|
||||
if isinstance(diff, str) and diff.strip():
|
||||
|
|
@ -780,23 +777,19 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
|||
return False, ""
|
||||
|
||||
if tool_name == "terminal":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
exit_code = data.get("exit_code")
|
||||
if exit_code is not None and exit_code != 0:
|
||||
return True, f" [exit {exit_code}]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
logger.debug("Could not parse terminal result as JSON for exit code check")
|
||||
return False, ""
|
||||
|
||||
# Memory-specific: distinguish "full" from real errors
|
||||
if tool_name == "memory":
|
||||
try:
|
||||
data = json.loads(result)
|
||||
data = safe_json_loads(result)
|
||||
if isinstance(data, dict):
|
||||
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
|
||||
return True, " [full]"
|
||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||
logger.debug("Could not parse memory result as JSON for capacity check")
|
||||
|
||||
# Generic heuristic for non-terminal tools
|
||||
lower = result[:500].lower()
|
||||
|
|
|
|||
|
|
@ -27,12 +27,14 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
|||
"gemini", "zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
||||
"qwen-oauth",
|
||||
"xiaomi",
|
||||
"custom", "local",
|
||||
# Common aliases
|
||||
"google", "google-gemini", "google-ai-studio",
|
||||
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
||||
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
||||
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
||||
"mimo", "xiaomi-mimo",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
|
|
@ -149,9 +151,10 @@ DEFAULT_CONTEXT_LENGTHS = {
|
|||
"moonshotai/Kimi-K2.5": 262144,
|
||||
"moonshotai/Kimi-K2-Thinking": 262144,
|
||||
"MiniMaxAI/MiniMax-M2.5": 204800,
|
||||
"XiaomiMiMo/MiMo-V2-Flash": 32768,
|
||||
"mimo-v2-pro": 1048576,
|
||||
"mimo-v2-omni": 1048576,
|
||||
"XiaomiMiMo/MiMo-V2-Flash": 256000,
|
||||
"mimo-v2-pro": 1000000,
|
||||
"mimo-v2-omni": 256000,
|
||||
"mimo-v2-flash": 256000,
|
||||
"zai-org/GLM-5": 202752,
|
||||
}
|
||||
|
||||
|
|
@ -176,6 +179,12 @@ _MAX_COMPLETION_KEYS = (
|
|||
|
||||
# Local server hostnames / address patterns
|
||||
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
||||
# Docker / Podman / Lima DNS names that resolve to the host machine
|
||||
_CONTAINER_LOCAL_SUFFIXES = (
|
||||
".docker.internal",
|
||||
".containers.internal",
|
||||
".lima.internal",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
|
|
@ -211,6 +220,8 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
|||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
"api.xiaomimimo.com": "xiaomi",
|
||||
"xiaomimimo.com": "xiaomi",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -249,6 +260,9 @@ def is_local_endpoint(base_url: str) -> bool:
|
|||
return False
|
||||
if host in _LOCAL_HOSTS:
|
||||
return True
|
||||
# Docker / Podman / Lima internal DNS names (e.g. host.docker.internal)
|
||||
if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES):
|
||||
return True
|
||||
# RFC-1918 private ranges and link-local
|
||||
import ipaddress
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
|||
"gemini": "google",
|
||||
"google": "google",
|
||||
"xai": "xai",
|
||||
"xiaomi": "xiaomi",
|
||||
"nvidia": "nvidia",
|
||||
"groq": "groq",
|
||||
"mistral": "mistral",
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import threading
|
|||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_hermes_home, get_skills_dir
|
||||
from typing import Optional
|
||||
|
||||
from agent.skill_utils import (
|
||||
|
|
@ -548,8 +548,7 @@ def build_skills_system_prompt(
|
|||
are read-only — they appear in the index but new skills are always created
|
||||
in the local dir. Local skills take precedence when names collide.
|
||||
"""
|
||||
hermes_home = get_hermes_home()
|
||||
skills_dir = hermes_home / "skills"
|
||||
skills_dir = get_skills_dir()
|
||||
external_dirs = get_all_skills_dirs()[1:] # skip local (index 0)
|
||||
|
||||
if not skills_dir.exists() and not external_dirs:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path, get_skills_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]:
|
|||
Reads the config file directly (no CLI config imports) to stay
|
||||
lightweight.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
return set()
|
||||
try:
|
||||
|
|
@ -178,7 +178,7 @@ def get_external_skills_dirs() -> List[Path]:
|
|||
path. Only directories that actually exist are returned. Duplicates and
|
||||
paths that resolve to the local ``~/.hermes/skills/`` are silently skipped.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
return []
|
||||
try:
|
||||
|
|
@ -200,7 +200,7 @@ def get_external_skills_dirs() -> List[Path]:
|
|||
if not isinstance(raw_dirs, list):
|
||||
return []
|
||||
|
||||
local_skills = (get_hermes_home() / "skills").resolve()
|
||||
local_skills = get_skills_dir().resolve()
|
||||
seen: Set[Path] = set()
|
||||
result: List[Path] = []
|
||||
|
||||
|
|
@ -230,7 +230,7 @@ def get_all_skills_dirs() -> List[Path]:
|
|||
The local dir is always first (and always included even if it doesn't exist
|
||||
yet — callers handle that). External dirs follow in config order.
|
||||
"""
|
||||
dirs = [get_hermes_home() / "skills"]
|
||||
dirs = [get_skills_dir()]
|
||||
dirs.extend(get_external_skills_dirs())
|
||||
return dirs
|
||||
|
||||
|
|
@ -384,7 +384,7 @@ def resolve_skill_config_values(
|
|||
current values (or the declared default if the key isn't set).
|
||||
Path values are expanded via ``os.path.expanduser``.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
config: Dict[str, Any] = {}
|
||||
if config_path.exists():
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ model:
|
|||
# "minimax" - MiniMax global (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - MiniMax China (requires: MINIMAX_CN_API_KEY)
|
||||
# "huggingface" - Hugging Face Inference (requires: HF_TOKEN)
|
||||
# "xiaomi" - Xiaomi MiMo (requires: XIAOMI_API_KEY)
|
||||
# "kilocode" - KiloCode gateway (requires: KILOCODE_API_KEY)
|
||||
# "ai-gateway" - Vercel AI Gateway (requires: AI_GATEWAY_API_KEY)
|
||||
#
|
||||
|
|
|
|||
9
cli.py
9
cli.py
|
|
@ -2748,6 +2748,15 @@ class HermesCLI:
|
|||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
# When a custom_provider entry carries an explicit `model` field,
|
||||
# use it as the effective model name. Without this, running
|
||||
# `hermes chat --model <provider-name>` sends the provider name
|
||||
# (e.g. "my-provider") as the model string to the API instead of
|
||||
# the configured model (e.g. "qwen3.6-plus"), causing 400 errors.
|
||||
runtime_model = runtime.get("model")
|
||||
if runtime_model and isinstance(runtime_model, str):
|
||||
self.model = runtime_model
|
||||
|
||||
# Normalize model for the resolved provider (e.g. swap non-Codex
|
||||
# models when provider is openai-codex). Fixes #651.
|
||||
model_changed = self._normalize_model_for_provider(resolved_provider)
|
||||
|
|
|
|||
|
|
@ -722,6 +722,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
|||
provider_sort=pr.get("sort"),
|
||||
disabled_toolsets=["cronjob", "messaging", "clarify"],
|
||||
quiet_mode=True,
|
||||
skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd
|
||||
skip_memory=True, # Cron system prompts would corrupt user representations
|
||||
platform="cron",
|
||||
session_id=_cron_session_id,
|
||||
|
|
|
|||
|
|
@ -11,12 +11,14 @@ When you run `hermes setup` for the first time and Hermes detects `~/.openclaw`,
|
|||
### 2. CLI Command (quick, scriptable)
|
||||
|
||||
```bash
|
||||
hermes claw migrate # Full migration with confirmation prompt
|
||||
hermes claw migrate --dry-run # Preview what would happen
|
||||
hermes claw migrate # Preview then migrate (always shows preview first)
|
||||
hermes claw migrate --dry-run # Preview only, no changes
|
||||
hermes claw migrate --preset user-data # Migrate without API keys/secrets
|
||||
hermes claw migrate --yes # Skip confirmation prompt
|
||||
```
|
||||
|
||||
The migration always shows a full preview of what will be imported before making any changes. You review the preview and confirm before anything is written.
|
||||
|
||||
**All options:**
|
||||
|
||||
| Flag | Description |
|
||||
|
|
@ -39,7 +41,7 @@ Ask the agent to run the migration for you:
|
|||
```
|
||||
|
||||
The agent will use the `openclaw-migration` skill to:
|
||||
1. Run a dry-run first to preview changes
|
||||
1. Run a preview first to show what would change
|
||||
2. Ask about conflict resolution (SOUL.md, skills, etc.)
|
||||
3. Let you choose between `user-data` and `full` presets
|
||||
4. Execute the migration with your choices
|
||||
|
|
@ -58,16 +60,31 @@ The agent will use the `openclaw-migration` skill to:
|
|||
| Messaging settings | `~/.openclaw/config.yaml` (TELEGRAM_ALLOWED_USERS, MESSAGING_CWD) | `~/.hermes/.env` |
|
||||
| TTS assets | `~/.openclaw/workspace/tts/` | `~/.hermes/tts/` |
|
||||
|
||||
Workspace files are also checked at `workspace.default/` and `workspace-main/` as fallback paths (OpenClaw renamed `workspace/` to `workspace-main/` in recent versions).
|
||||
|
||||
### `full` preset (adds to `user-data`)
|
||||
| Item | Source | Destination |
|
||||
|------|--------|-------------|
|
||||
| Telegram bot token | `~/.openclaw/config.yaml` | `~/.hermes/.env` |
|
||||
| OpenRouter API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| OpenAI API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| Anthropic API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| ElevenLabs API key | `~/.openclaw/.env` or config | `~/.hermes/.env` |
|
||||
| Telegram bot token | `openclaw.json` channels config | `~/.hermes/.env` |
|
||||
| OpenRouter API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
| OpenAI API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
| Anthropic API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
| ElevenLabs API key | `.env`, `openclaw.json`, or `openclaw.json["env"]` | `~/.hermes/.env` |
|
||||
|
||||
Only these 6 allowlisted secrets are ever imported. Other credentials are skipped and reported.
|
||||
API keys are searched across four sources: inline config values, `~/.openclaw/.env`, the `openclaw.json` `"env"` sub-object, and per-agent auth profiles.
|
||||
|
||||
Only allowlisted secrets are ever imported. Other credentials are skipped and reported.
|
||||
|
||||
## OpenClaw Schema Compatibility
|
||||
|
||||
The migration handles both old and current OpenClaw config layouts:
|
||||
|
||||
- **Channel tokens**: Reads from flat paths (`channels.telegram.botToken`) and the newer `accounts.default` layout (`channels.telegram.accounts.default.botToken`)
|
||||
- **TTS provider**: OpenClaw renamed "edge" to "microsoft" — both are recognized and mapped to Hermes' "edge"
|
||||
- **Provider API types**: Both short (`openai`, `anthropic`) and hyphenated (`openai-completions`, `anthropic-messages`, `google-generative-ai`) values are mapped correctly
|
||||
- **thinkingDefault**: All enum values are handled including newer ones (`minimal`, `xhigh`, `adaptive`)
|
||||
- **Matrix**: Uses `accessToken` field (not `botToken`)
|
||||
- **SecretRef formats**: Plain strings, env templates (`${VAR}`), and `source: "env"` SecretRefs are resolved. `source: "file"` and `source: "exec"` SecretRefs produce a warning — add those keys manually after migration.
|
||||
|
||||
## Conflict Handling
|
||||
|
||||
|
|
@ -84,18 +101,24 @@ For skills, you can also use `--skill-conflict rename` to import conflicting ski
|
|||
|
||||
## Migration Report
|
||||
|
||||
Every migration (including dry runs) produces a report showing:
|
||||
Every migration produces a report showing:
|
||||
- **Migrated items** — what was successfully imported
|
||||
- **Conflicts** — items skipped because they already exist
|
||||
- **Skipped items** — items not found in the source
|
||||
- **Errors** — items that failed to import
|
||||
|
||||
For execute runs, the full report is saved to `~/.hermes/migration/openclaw/<timestamp>/`.
|
||||
For executed migrations, the full report is saved to `~/.hermes/migration/openclaw/<timestamp>/`.
|
||||
|
||||
## Post-Migration Notes
|
||||
|
||||
- **Skills require a new session** — imported skills take effect after restarting your agent or starting a new chat.
|
||||
- **WhatsApp requires re-pairing** — WhatsApp uses QR-code pairing, not token-based auth. Run `hermes whatsapp` to pair.
|
||||
- **Archive cleanup** — after migration, you'll be offered to rename `~/.openclaw/` to `.openclaw.pre-migration/` to prevent state confusion. You can also run `hermes claw cleanup` later.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "OpenClaw directory not found"
|
||||
The migration looks for `~/.openclaw` by default. If your OpenClaw is installed elsewhere, use `--source`:
|
||||
The migration looks for `~/.openclaw` by default, then tries `~/.clawdbot` and `~/.moldbot`. If your OpenClaw is installed elsewhere, use `--source`:
|
||||
```bash
|
||||
hermes claw migrate --source /path/to/.openclaw
|
||||
```
|
||||
|
|
@ -108,3 +131,12 @@ hermes skills install openclaw-migration
|
|||
|
||||
### Memory overflow
|
||||
If your OpenClaw MEMORY.md or USER.md exceeds Hermes' character limits, excess entries are exported to an overflow file in the migration report directory. You can manually review and add the most important ones.
|
||||
|
||||
### API keys not found
|
||||
Keys might be stored in different places depending on your OpenClaw setup:
|
||||
- `~/.openclaw/.env` file
|
||||
- Inline in `openclaw.json` under `models.providers.*.apiKey`
|
||||
- In `openclaw.json` under the `"env"` or `"env.vars"` sub-objects
|
||||
- In `~/.openclaw/agents/main/agent/auth-profiles.json`
|
||||
|
||||
The migration checks all four. If keys use `source: "file"` or `source: "exec"` SecretRefs, they can't be resolved automatically — add them via `hermes config set`.
|
||||
|
|
|
|||
|
|
@ -1017,6 +1017,9 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
|||
weixin_group_allowed_users = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "").strip()
|
||||
if weixin_group_allowed_users:
|
||||
extra["group_allow_from"] = weixin_group_allowed_users
|
||||
weixin_split_multiline = os.getenv("WEIXIN_SPLIT_MULTILINE_MESSAGES", "").strip()
|
||||
if weixin_split_multiline:
|
||||
extra["split_multiline_messages"] = weixin_split_multiline
|
||||
weixin_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip()
|
||||
if weixin_home:
|
||||
config.platforms[Platform.WEIXIN].home_channel = HomeChannel(
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1"
|
|||
DEFAULT_PORT = 8642
|
||||
MAX_STORED_RESPONSES = 100
|
||||
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
|
||||
CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
|
|
@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
"""
|
||||
import queue as _q
|
||||
|
||||
sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"}
|
||||
sse_headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
}
|
||||
# CORS middleware can't inject headers into StreamResponse after
|
||||
# prepare() flushes them, so resolve CORS headers up front.
|
||||
origin = request.headers.get("Origin", "")
|
||||
|
|
@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
|
|
@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Helper — route a queue item to the correct SSE event.
|
||||
async def _emit(item):
|
||||
|
|
@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
"choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
return time.monotonic()
|
||||
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS:
|
||||
await response.write(b": keepalive\n\n")
|
||||
last_activity = time.monotonic()
|
||||
continue
|
||||
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
|
|
|||
|
|
@ -823,7 +823,36 @@ class BasePlatformAdapter(ABC):
|
|||
result = handler(self)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
|
||||
def _acquire_platform_lock(self, scope: str, identity: str, resource_desc: str) -> bool:
|
||||
"""Acquire a scoped lock for this adapter. Returns True on success."""
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._platform_lock_scope = scope
|
||||
self._platform_lock_identity = identity
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
scope, identity, metadata={'platform': self.platform.value}
|
||||
)
|
||||
if acquired:
|
||||
return True
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = (
|
||||
f'{resource_desc} already in use'
|
||||
+ (f' (PID {owner_pid})' if owner_pid else '')
|
||||
+ '. Stop the other gateway first.'
|
||||
)
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error(f'{scope}_lock', message, retryable=False)
|
||||
return False
|
||||
|
||||
def _release_platform_lock(self) -> None:
|
||||
"""Release the scoped lock acquired by _acquire_platform_lock."""
|
||||
identity = getattr(self, '_platform_lock_identity', None)
|
||||
if not identity:
|
||||
return
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock(self._platform_lock_scope, identity)
|
||||
self._platform_lock_identity = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from gateway.platforms.base import (
|
|||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
)
|
||||
from gateway.platforms.helpers import strip_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -89,18 +90,7 @@ def _normalize_server_url(raw: str) -> str:
|
|||
return value.rstrip("/")
|
||||
|
||||
|
||||
def _strip_markdown(text: str) -> str:
|
||||
"""Strip common markdown formatting for iMessage plain-text delivery."""
|
||||
text = re.sub(r"\*\*(.+?)\*\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"__(.+?)__", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"_(.+?)_", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"```[a-zA-Z0-9_+-]*\n?", "", text)
|
||||
text = re.sub(r"`(.+?)`", r"\1", text)
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
||||
text = re.sub(r"\[([^\]]+)\]\(([^\)]+)\)", r"\1", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -393,7 +383,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
|||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
text = _strip_markdown(content or "")
|
||||
text = strip_markdown(content or "")
|
||||
if not text:
|
||||
return SendResult(success=False, error="BlueBubbles send requires text")
|
||||
chunks = self.truncate_message(text, max_length=self.MAX_MESSAGE_LENGTH)
|
||||
|
|
@ -679,7 +669,7 @@ class BlueBubblesAdapter(BasePlatformAdapter):
|
|||
return info
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
return _strip_markdown(content)
|
||||
return strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound attachment downloading (from #4588)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ except ImportError:
|
|||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -52,8 +53,6 @@ from gateway.platforms.base import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_MESSAGE_LENGTH = 20000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
_SESSION_WEBHOOKS_MAX = 500
|
||||
_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/')
|
||||
|
|
@ -89,8 +88,8 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
self._stream_task: Optional[asyncio.Task] = None
|
||||
self._http_client: Optional["httpx.AsyncClient"] = None
|
||||
|
||||
# Message deduplication: msg_id -> timestamp
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
# Message deduplication
|
||||
self._dedup = MessageDeduplicator(max_size=1000)
|
||||
# Map chat_id -> session_webhook for reply routing
|
||||
self._session_webhooks: Dict[str, str] = {}
|
||||
|
||||
|
|
@ -170,7 +169,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
|
||||
self._stream_client = None
|
||||
self._session_webhooks.clear()
|
||||
self._seen_messages.clear()
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
# -- Inbound message processing -----------------------------------------
|
||||
|
|
@ -178,7 +177,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
async def _on_message(self, message: "ChatbotMessage") -> None:
|
||||
"""Process an incoming DingTalk chatbot message."""
|
||||
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
|
||||
if self._is_duplicate(msg_id):
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
|
||||
return
|
||||
|
||||
|
|
@ -256,20 +255,6 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
content = " ".join(parts).strip()
|
||||
return content
|
||||
|
||||
# -- Deduplication ------------------------------------------------------
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Check and record a message ID. Returns True if already seen."""
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
# -- Outbound messaging -------------------------------------------------
|
||||
|
||||
async def send(
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
|||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.helpers import MessageDeduplicator, ThreadParticipationTracker
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -450,18 +451,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Track threads where the bot has participated so follow-up messages
|
||||
# in those threads don't require @mention. Persisted to disk so the
|
||||
# set survives gateway restarts.
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
self._threads = ThreadParticipationTracker("discord")
|
||||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._bot_task: Optional[asyncio.Task] = None
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
# Dedup cache: message_id → timestamp. Prevents duplicate bot
|
||||
# responses when Discord RESUME replays events after reconnects.
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Dedup cache: prevents duplicate bot responses when Discord
|
||||
# RESUME replays events after reconnects.
|
||||
self._dedup = MessageDeduplicator()
|
||||
# Reply threading mode: "off" (no replies), "first" (reply on first
|
||||
# chunk only, default), "all" (reply-reference on every chunk).
|
||||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||||
|
|
@ -502,18 +499,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate bot token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('discord_token_lock', message, retryable=False)
|
||||
if not self._acquire_platform_lock('discord-bot-token', self.config.token, 'Discord bot token'):
|
||||
return False
|
||||
|
||||
|
||||
# Parse allowed user entries (may contain usernames or IDs)
|
||||
allowed_env = os.getenv("DISCORD_ALLOWED_USERS", "")
|
||||
if allowed_env:
|
||||
|
|
@ -569,17 +557,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||||
msg_id = str(message.id)
|
||||
now = time.time()
|
||||
if msg_id in adapter_self._seen_messages:
|
||||
if adapter_self._dedup.is_duplicate(str(message.id)):
|
||||
return
|
||||
adapter_self._seen_messages[msg_id] = now
|
||||
if len(adapter_self._seen_messages) > adapter_self._SEEN_MAX:
|
||||
cutoff = now - adapter_self._SEEN_TTL
|
||||
adapter_self._seen_messages = {
|
||||
k: v for k, v in adapter_self._seen_messages.items()
|
||||
if v > cutoff
|
||||
}
|
||||
|
||||
# Always ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
|
|
@ -685,23 +664,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("[%s] Timeout waiting for connection to Discord", self.name, exc_info=True)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
return False
|
||||
except Exception as e: # pragma: no cover - defensive logging
|
||||
logger.error("[%s] Failed to connect to Discord: %s", self.name, e, exc_info=True)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
|
|
@ -723,14 +690,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self._client = None
|
||||
self._ready_event.clear()
|
||||
|
||||
# Release the token lock
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('discord-bot-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
|
|
@ -1870,7 +1830,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
# Track thread participation so follow-ups don't require @mention
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# If a message was provided, kick off a new Hermes session in the thread
|
||||
starter = (message or "").strip()
|
||||
|
|
@ -2241,49 +2201,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "discord_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load discord thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
# Trim to most recent entries if over cap
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save discord thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
|
|
@ -2335,7 +2252,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
in_bot_thread = is_thread and thread_id in self._bot_participated_threads
|
||||
in_bot_thread = is_thread and thread_id in self._threads
|
||||
|
||||
if require_mention and not is_free_channel and not in_bot_thread:
|
||||
if self._client.user not in message.mentions:
|
||||
|
|
@ -2361,7 +2278,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
|
|
@ -2545,7 +2462,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Track thread participation so the bot won't require @mention for
|
||||
# follow-up messages in threads it has already engaged in.
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Only batch plain text messages — commands, media, etc. dispatch
|
||||
# immediately since they won't be split by the Discord client.
|
||||
|
|
|
|||
|
|
@ -360,19 +360,21 @@ def _render_code_block_element(element: Dict[str, Any]) -> str:
|
|||
|
||||
|
||||
def _strip_markdown_to_plain_text(text: str) -> str:
|
||||
"""Strip markdown formatting to plain text for Feishu text fallbacks.
|
||||
|
||||
Delegates common markdown stripping to the shared helper and adds
|
||||
Feishu-specific patterns (blockquotes, strikethrough, underline tags,
|
||||
horizontal rules, \\r\\n normalisation).
|
||||
"""
|
||||
from gateway.platforms.helpers import strip_markdown
|
||||
plain = text.replace("\r\n", "\n")
|
||||
plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain)
|
||||
plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE)
|
||||
plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain)
|
||||
plain = re.sub(r"`([^`\n]+)`", r"\1", plain)
|
||||
plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain)
|
||||
plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain)
|
||||
plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain)
|
||||
plain = re.sub(r"<u>([\s\S]*?)</u>", r"\1", plain)
|
||||
plain = re.sub(r"\n{3,}", "\n\n", plain)
|
||||
return plain.strip()
|
||||
plain = strip_markdown(plain)
|
||||
return plain
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: Optional[int] = None, min_value: int = 0) -> Optional[int]:
|
||||
|
|
|
|||
261
gateway/platforms/helpers.py
Normal file
261
gateway/platforms/helpers.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
"""Shared helper classes for gateway platform adapters.
|
||||
|
||||
Extracts common patterns that were duplicated across 5-7 adapters:
|
||||
message deduplication, text batch aggregation, markdown stripping,
|
||||
and thread participation tracking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── Message Deduplication ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MessageDeduplicator:
|
||||
"""TTL-based message deduplication cache.
|
||||
|
||||
Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
|
||||
previously duplicated in discord, slack, dingtalk, wecom, weixin,
|
||||
mattermost, and feishu adapters.
|
||||
|
||||
Usage::
|
||||
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
||||
# In message handler:
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
return
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
|
||||
self._seen: Dict[str, float] = {}
|
||||
self._max_size = max_size
|
||||
self._ttl = ttl_seconds
|
||||
|
||||
def is_duplicate(self, msg_id: str) -> bool:
|
||||
"""Return True if *msg_id* was already seen within the TTL window."""
|
||||
if not msg_id:
|
||||
return False
|
||||
now = time.time()
|
||||
if msg_id in self._seen:
|
||||
return True
|
||||
self._seen[msg_id] = now
|
||||
if len(self._seen) > self._max_size:
|
||||
cutoff = now - self._ttl
|
||||
self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
|
||||
return False
|
||||
|
||||
def clear(self):
|
||||
"""Clear all tracked messages."""
|
||||
self._seen.clear()
|
||||
|
||||
|
||||
# ─── Text Batch Aggregation ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TextBatchAggregator:
|
||||
"""Aggregates rapid-fire text events into single messages.
|
||||
|
||||
Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
|
||||
previously duplicated in telegram, discord, matrix, wecom, and feishu.
|
||||
|
||||
Usage::
|
||||
|
||||
self._text_batcher = TextBatchAggregator(
|
||||
handler=self._message_handler,
|
||||
batch_delay=0.6,
|
||||
split_threshold=1900,
|
||||
)
|
||||
|
||||
# In message dispatch:
|
||||
if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
|
||||
self._text_batcher.enqueue(event, session_key)
|
||||
return
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler,
|
||||
*,
|
||||
batch_delay: float = 0.6,
|
||||
split_delay: float = 2.0,
|
||||
split_threshold: int = 4000,
|
||||
):
|
||||
self._handler = handler
|
||||
self._batch_delay = batch_delay
|
||||
self._split_delay = split_delay
|
||||
self._split_threshold = split_threshold
|
||||
self._pending: Dict[str, "MessageEvent"] = {}
|
||||
self._pending_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Return True if batching is active (delay > 0)."""
|
||||
return self._batch_delay > 0
|
||||
|
||||
def enqueue(self, event: "MessageEvent", key: str) -> None:
|
||||
"""Add *event* to the pending batch for *key*."""
|
||||
chunk_len = len(event.text or "")
|
||||
existing = self._pending.get(key)
|
||||
if not existing:
|
||||
event._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
self._pending[key] = event
|
||||
else:
|
||||
existing.text = f"{existing.text}\n{event.text}"
|
||||
existing._last_chunk_len = chunk_len # type: ignore[attr-defined]
|
||||
|
||||
# Cancel prior flush timer, start a new one
|
||||
prior = self._pending_tasks.get(key)
|
||||
if prior and not prior.done():
|
||||
prior.cancel()
|
||||
self._pending_tasks[key] = asyncio.create_task(self._flush(key))
|
||||
|
||||
async def _flush(self, key: str) -> None:
|
||||
"""Wait then dispatch the batched event for *key*."""
|
||||
current_task = self._pending_tasks.get(key)
|
||||
pending = self._pending.get(key)
|
||||
last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0
|
||||
|
||||
# Use longer delay when the last chunk looks like a split message
|
||||
delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
event = self._pending.pop(key, None)
|
||||
if event:
|
||||
try:
|
||||
await self._handler(event)
|
||||
except Exception:
|
||||
logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
|
||||
|
||||
if self._pending_tasks.get(key) is current_task:
|
||||
self._pending_tasks.pop(key, None)
|
||||
|
||||
def cancel_all(self) -> None:
|
||||
"""Cancel all pending flush tasks."""
|
||||
for task in self._pending_tasks.values():
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
self._pending_tasks.clear()
|
||||
self._pending.clear()
|
||||
|
||||
|
||||
# ─── Markdown Stripping ──────────────────────────────────────────────────────
|
||||
|
||||
# Pre-compiled regexes for performance
|
||||
_RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL)
|
||||
_RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL)
|
||||
_RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL)
|
||||
_RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL)
|
||||
_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
|
||||
_RE_INLINE_CODE = re.compile(r"`(.+?)`")
|
||||
_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
|
||||
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
|
||||
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
|
||||
|
||||
|
||||
def strip_markdown(text: str) -> str:
|
||||
"""Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).
|
||||
|
||||
Replaces the identical ``_strip_markdown()`` functions previously
|
||||
duplicated in sms.py, bluebubbles.py, and feishu.py.
|
||||
"""
|
||||
text = _RE_BOLD.sub(r"\1", text)
|
||||
text = _RE_ITALIC_STAR.sub(r"\1", text)
|
||||
text = _RE_BOLD_UNDER.sub(r"\1", text)
|
||||
text = _RE_ITALIC_UNDER.sub(r"\1", text)
|
||||
text = _RE_CODE_BLOCK.sub("", text)
|
||||
text = _RE_INLINE_CODE.sub(r"\1", text)
|
||||
text = _RE_HEADING.sub("", text)
|
||||
text = _RE_LINK.sub(r"\1", text)
|
||||
text = _RE_MULTI_NEWLINE.sub("\n\n", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
# ─── Thread Participation Tracking ───────────────────────────────────────────
|
||||
|
||||
|
||||
class ThreadParticipationTracker:
|
||||
"""Persistent tracking of threads the bot has participated in.
|
||||
|
||||
Replaces the identical ``_load/_save_participated_threads`` +
|
||||
``_mark_thread_participated`` pattern previously duplicated in
|
||||
discord.py and matrix.py.
|
||||
|
||||
Usage::
|
||||
|
||||
self._threads = ThreadParticipationTracker("discord")
|
||||
|
||||
# Check membership:
|
||||
if thread_id in self._threads:
|
||||
...
|
||||
|
||||
# Mark participation:
|
||||
self._threads.mark(thread_id)
|
||||
"""
|
||||
|
||||
_MAX_TRACKED = 500
|
||||
|
||||
def __init__(self, platform_name: str, max_tracked: int = 500):
|
||||
self._platform = platform_name
|
||||
self._max_tracked = max_tracked
|
||||
self._threads: set = self._load()
|
||||
|
||||
def _state_path(self) -> Path:
|
||||
from hermes_constants import get_hermes_home
|
||||
return get_hermes_home() / f"{self._platform}_threads.json"
|
||||
|
||||
def _load(self) -> set:
|
||||
path = self._state_path()
|
||||
if path.exists():
|
||||
try:
|
||||
return set(json.loads(path.read_text(encoding="utf-8")))
|
||||
except Exception:
|
||||
pass
|
||||
return set()
|
||||
|
||||
def _save(self) -> None:
|
||||
path = self._state_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
thread_list = list(self._threads)
|
||||
if len(thread_list) > self._max_tracked:
|
||||
thread_list = thread_list[-self._max_tracked:]
|
||||
self._threads = set(thread_list)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
|
||||
def mark(self, thread_id: str) -> None:
|
||||
"""Mark *thread_id* as participated and persist."""
|
||||
if thread_id not in self._threads:
|
||||
self._threads.add(thread_id)
|
||||
self._save()
|
||||
|
||||
def __contains__(self, thread_id: str) -> bool:
|
||||
return thread_id in self._threads
|
||||
|
||||
def clear(self) -> None:
|
||||
self._threads.clear()
|
||||
|
||||
|
||||
# ─── Phone Number Redaction ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging, preserving country code and last 4.
|
||||
|
||||
Replaces the identical ``_redact_phone()`` functions in signal.py,
|
||||
sms.py, and bluebubbles.py.
|
||||
"""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
|
@ -92,6 +92,7 @@ from gateway.platforms.base import (
|
|||
ProcessingOutcome,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -216,8 +217,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
self._pending_megolm: list = []
|
||||
|
||||
# Thread participation tracking (for require_mention bypass)
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
self._threads = ThreadParticipationTracker("matrix")
|
||||
|
||||
# Mention/thread gating — parsed once from env vars.
|
||||
self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
|
@ -1019,7 +1019,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
# Require-mention gating.
|
||||
if not is_dm:
|
||||
is_free_room = room_id in self._free_rooms
|
||||
in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads)
|
||||
in_bot_thread = bool(thread_id and thread_id in self._threads)
|
||||
if self._require_mention and not is_free_room and not in_bot_thread:
|
||||
if not is_mentioned:
|
||||
return None
|
||||
|
|
@ -1027,7 +1027,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
# DM mention-thread.
|
||||
if is_dm and not thread_id and self._dm_mention_threads and is_mentioned:
|
||||
thread_id = event_id
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
# Strip mention from body.
|
||||
if is_mentioned:
|
||||
|
|
@ -1036,7 +1036,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
# Auto-thread.
|
||||
if not is_dm and not thread_id and self._auto_thread:
|
||||
thread_id = event_id
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
display_name = await self._get_display_name(room_id, sender)
|
||||
source = self.build_source(
|
||||
|
|
@ -1048,7 +1048,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
if thread_id:
|
||||
self._track_thread(thread_id)
|
||||
self._threads.mark(thread_id)
|
||||
|
||||
self._background_read_receipt(room_id, event_id)
|
||||
|
||||
|
|
@ -1697,48 +1697,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
for rid in self._joined_rooms
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread participation tracking
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _thread_state_path() -> Path:
|
||||
"""Path to the persisted thread participation set."""
|
||||
from hermes_cli.config import get_hermes_home
|
||||
return get_hermes_home() / "matrix_threads.json"
|
||||
|
||||
@classmethod
|
||||
def _load_participated_threads(cls) -> set:
|
||||
"""Load persisted thread IDs from disk."""
|
||||
path = cls._thread_state_path()
|
||||
try:
|
||||
if path.exists():
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return set(data)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load matrix thread state: %s", e)
|
||||
return set()
|
||||
|
||||
def _save_participated_threads(self) -> None:
|
||||
"""Persist the current thread set to disk (best-effort)."""
|
||||
path = self._thread_state_path()
|
||||
try:
|
||||
thread_list = list(self._bot_participated_threads)
|
||||
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||
self._bot_participated_threads = set(thread_list)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.debug("Could not save matrix thread state: %s", e)
|
||||
|
||||
def _track_thread(self, thread_id: str) -> None:
|
||||
"""Add a thread to the participation set and persist."""
|
||||
if thread_id not in self._bot_participated_threads:
|
||||
self._bot_participated_threads.add(thread_id)
|
||||
self._save_participated_threads()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mention detection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -18,11 +18,11 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -96,10 +96,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
or os.getenv("MATTERMOST_REPLY_MODE", "off")
|
||||
).lower()
|
||||
|
||||
# Dedup cache: post_id → timestamp (prevent reprocessing)
|
||||
self._seen_posts: Dict[str, float] = {}
|
||||
self._SEEN_MAX = 2000
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
# Dedup cache (prevent reprocessing)
|
||||
self._dedup = MessageDeduplicator()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
|
|
@ -604,10 +602,8 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
post_id = post.get("id", "")
|
||||
|
||||
# Dedup.
|
||||
self._prune_seen()
|
||||
if post_id in self._seen_posts:
|
||||
if self._dedup.is_duplicate(post_id):
|
||||
return
|
||||
self._seen_posts[post_id] = time.time()
|
||||
|
||||
# Build message event.
|
||||
channel_id = post.get("channel_id", "")
|
||||
|
|
@ -734,13 +730,4 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
def _prune_seen(self) -> None:
|
||||
"""Remove expired entries from the dedup cache."""
|
||||
if len(self._seen_posts) < self._SEEN_MAX:
|
||||
return
|
||||
now = time.time()
|
||||
self._seen_posts = {
|
||||
pid: ts
|
||||
for pid, ts in self._seen_posts.items()
|
||||
if now - ts < self._SEEN_TTL
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from gateway.platforms.base import (
|
|||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -51,22 +52,10 @@ SSE_RETRY_DELAY_MAX = 60.0
|
|||
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
|
||||
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
|
|
@ -184,10 +173,8 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
self._recent_sent_timestamps: set = set()
|
||||
self._max_recent_timestamps = 50
|
||||
|
||||
self._phone_lock_identity: Optional[str] = None
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
self.http_url, redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -202,23 +189,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
|
||||
# Acquire scoped lock to prevent duplicate Signal listeners for the same phone
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._phone_lock_identity = self.account
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"signal-phone",
|
||||
self._phone_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Signal account"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Signal listener."
|
||||
)
|
||||
logger.error("Signal: %s", message)
|
||||
self._set_fatal_error("signal_phone_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('signal-phone', self.account, 'Signal account'):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e)
|
||||
|
|
@ -270,13 +241,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
await self.client.aclose()
|
||||
self.client = None
|
||||
|
||||
if self._phone_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("signal-phone", self._phone_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True)
|
||||
self._phone_lock_identity = None
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("Signal: disconnected")
|
||||
|
||||
|
|
@ -542,7 +507,7 @@ class SignalAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from pathlib import Path as _Path
|
|||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -89,11 +90,9 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
self._team_clients: Dict[str, AsyncWebClient] = {} # team_id → WebClient
|
||||
self._team_bot_user_ids: Dict[str, str] = {} # team_id → bot_user_id
|
||||
self._channel_team: Dict[str, str] = {} # channel_id → team_id
|
||||
# Dedup cache: event_ts → timestamp. Prevents duplicate bot
|
||||
# responses when Socket Mode reconnects redeliver events.
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Dedup cache: prevents duplicate bot responses when Socket Mode
|
||||
# reconnects redeliver events.
|
||||
self._dedup = MessageDeduplicator()
|
||||
# Track pending approval message_ts → resolved flag to prevent
|
||||
# double-clicks on approval buttons.
|
||||
self._approval_resolved: Dict[str, bool] = {}
|
||||
|
|
@ -152,15 +151,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
logger.warning("[Slack] Failed to read %s: %s", tokens_file, e)
|
||||
|
||||
try:
|
||||
# Acquire scoped lock to prevent duplicate app token usage
|
||||
from gateway.status import acquire_scoped_lock
|
||||
self._token_lock_identity = app_token
|
||||
acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'})
|
||||
if not acquired:
|
||||
owner_pid = existing.get('pid') if isinstance(existing, dict) else None
|
||||
message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.'
|
||||
logger.error('[%s] %s', self.name, message)
|
||||
self._set_fatal_error('slack_token_lock', message, retryable=False)
|
||||
if not self._acquire_platform_lock('slack-app-token', app_token, 'Slack app token'):
|
||||
return False
|
||||
|
||||
# First token is the primary — used for AsyncApp / Socket Mode
|
||||
|
|
@ -247,14 +238,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True)
|
||||
self._running = False
|
||||
|
||||
# Release the token lock (use stored identity, not re-read env)
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
if getattr(self, '_token_lock_identity', None):
|
||||
release_scoped_lock('slack-app-token', self._token_lock_identity)
|
||||
self._token_lock_identity = None
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
|
||||
logger.info("[Slack] Disconnected")
|
||||
|
||||
|
|
@ -953,17 +937,8 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
"""Handle an incoming Slack message event."""
|
||||
# Dedup: Slack Socket Mode can redeliver events after reconnects (#4777)
|
||||
event_ts = event.get("ts", "")
|
||||
if event_ts:
|
||||
now = time.time()
|
||||
if event_ts in self._seen_messages:
|
||||
return
|
||||
self._seen_messages[event_ts] = now
|
||||
if len(self._seen_messages) > self._SEEN_MAX:
|
||||
cutoff = now - self._SEEN_TTL
|
||||
self._seen_messages = {
|
||||
k: v for k, v in self._seen_messages.items()
|
||||
if v > cutoff
|
||||
}
|
||||
if event_ts and self._dedup.is_duplicate(event_ts):
|
||||
return
|
||||
|
||||
# Bot message filtering (SLACK_ALLOW_BOTS / config allow_bots):
|
||||
# "none" — ignore all bot messages (default, backward-compatible)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ Shares credentials with the optional telephony skill — same env vars:
|
|||
|
||||
Gateway-specific env vars:
|
||||
- SMS_WEBHOOK_PORT (default 8080)
|
||||
- SMS_WEBHOOK_HOST (default 0.0.0.0)
|
||||
- SMS_WEBHOOK_URL (public URL for Twilio signature validation — required)
|
||||
- SMS_INSECURE_NO_SIGNATURE (true to disable signature validation — dev only)
|
||||
- SMS_ALLOWED_USERS (comma-separated E.164 phone numbers)
|
||||
- SMS_ALLOW_ALL_USERS (true/false)
|
||||
- SMS_HOME_CHANNEL (phone number for cron delivery)
|
||||
|
|
@ -17,9 +20,10 @@ Gateway-specific env vars:
|
|||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
|
@ -30,24 +34,14 @@ from gateway.platforms.base import (
|
|||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.platforms.helpers import redact_phone, strip_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
|
||||
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
|
||||
DEFAULT_WEBHOOK_PORT = 8080
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:5] + "***" + phone[-4:]
|
||||
DEFAULT_WEBHOOK_HOST = "0.0.0.0"
|
||||
|
||||
|
||||
def check_sms_requirements() -> bool:
|
||||
|
|
@ -77,6 +71,8 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
self._webhook_port: int = int(
|
||||
os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
|
||||
)
|
||||
self._webhook_host: str = os.getenv("SMS_WEBHOOK_HOST", DEFAULT_WEBHOOK_HOST)
|
||||
self._webhook_url: str = os.getenv("SMS_WEBHOOK_URL", "").strip()
|
||||
self._runner = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
|
||||
|
|
@ -98,13 +94,33 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies")
|
||||
return False
|
||||
|
||||
insecure_no_sig = os.getenv("SMS_INSECURE_NO_SIGNATURE", "").lower() == "true"
|
||||
|
||||
if not self._webhook_url and not insecure_no_sig:
|
||||
logger.error(
|
||||
"[sms] Refusing to start: SMS_WEBHOOK_URL is required for Twilio "
|
||||
"signature validation. Set it to the public URL configured in your "
|
||||
"Twilio console (e.g. https://example.com/webhooks/twilio). "
|
||||
"For local development without validation, set "
|
||||
"SMS_INSECURE_NO_SIGNATURE=true (NOT recommended for production).",
|
||||
)
|
||||
return False
|
||||
|
||||
if insecure_no_sig and not self._webhook_url:
|
||||
logger.warning(
|
||||
"[sms] SMS_INSECURE_NO_SIGNATURE=true — Twilio signature validation "
|
||||
"is DISABLED. Any client that can reach port %d can inject messages. "
|
||||
"Do NOT use this in production.",
|
||||
self._webhook_port,
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_post("/webhooks/twilio", self._handle_webhook)
|
||||
app.router.add_get("/health", lambda _: web.Response(text="ok"))
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "0.0.0.0", self._webhook_port)
|
||||
site = web.TCPSite(self._runner, self._webhook_host, self._webhook_port)
|
||||
await site.start()
|
||||
self._http_session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
|
|
@ -112,9 +128,10 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
self._running = True
|
||||
|
||||
logger.info(
|
||||
"[sms] Twilio webhook server listening on port %d, from: %s",
|
||||
"[sms] Twilio webhook server listening on %s:%d, from: %s",
|
||||
self._webhook_host,
|
||||
self._webhook_port,
|
||||
_redact_phone(self._from_number),
|
||||
redact_phone(self._from_number),
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
@ -163,7 +180,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
error_msg = body.get("message", str(body))
|
||||
logger.error(
|
||||
"[sms] send failed to %s: %s %s",
|
||||
_redact_phone(chat_id),
|
||||
redact_phone(chat_id),
|
||||
resp.status,
|
||||
error_msg,
|
||||
)
|
||||
|
|
@ -174,7 +191,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
msg_sid = body.get("sid", "")
|
||||
last_result = SendResult(success=True, message_id=msg_sid)
|
||||
except Exception as e:
|
||||
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
|
||||
logger.error("[sms] send error to %s: %s", redact_phone(chat_id), e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
finally:
|
||||
# Close session only if we created a fallback (no persistent session)
|
||||
|
|
@ -192,16 +209,75 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""Strip markdown — SMS renders it as literal characters."""
|
||||
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"_(.+?)_", r"\1", content, flags=re.DOTALL)
|
||||
content = re.sub(r"```[a-z]*\n?", "", content)
|
||||
content = re.sub(r"`(.+?)`", r"\1", content)
|
||||
content = re.sub(r"^#{1,6}\s+", "", content, flags=re.MULTILINE)
|
||||
content = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", content)
|
||||
content = re.sub(r"\n{3,}", "\n\n", content)
|
||||
return content.strip()
|
||||
return strip_markdown(content)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio signature validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_twilio_signature(
|
||||
self, url: str, post_params: dict, signature: str,
|
||||
) -> bool:
|
||||
"""Validate ``X-Twilio-Signature`` header (HMAC-SHA1, base64).
|
||||
|
||||
Tries both with and without the default port for the URL scheme,
|
||||
since Twilio may sign with either variant.
|
||||
|
||||
Algorithm: https://www.twilio.com/docs/usage/security#validating-requests
|
||||
"""
|
||||
if self._check_signature(url, post_params, signature):
|
||||
return True
|
||||
|
||||
variant = self._port_variant_url(url)
|
||||
if variant and self._check_signature(variant, post_params, signature):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _check_signature(
|
||||
self, url: str, post_params: dict, signature: str,
|
||||
) -> bool:
|
||||
"""Compute and compare a single Twilio signature."""
|
||||
data_to_sign = url
|
||||
for key in sorted(post_params.keys()):
|
||||
data_to_sign += key + post_params[key]
|
||||
mac = hmac.new(
|
||||
self._auth_token.encode("utf-8"),
|
||||
data_to_sign.encode("utf-8"),
|
||||
hashlib.sha1,
|
||||
)
|
||||
computed = base64.b64encode(mac.digest()).decode("utf-8")
|
||||
return hmac.compare_digest(computed, signature)
|
||||
|
||||
@staticmethod
|
||||
def _port_variant_url(url: str) -> str | None:
|
||||
"""Return the URL with the default port toggled, or None.
|
||||
|
||||
Only toggles default ports (443 for https, 80 for http).
|
||||
Non-standard ports are never modified.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
default_ports = {"https": 443, "http": 80}
|
||||
default_port = default_ports.get(parsed.scheme)
|
||||
if default_port is None:
|
||||
return None
|
||||
|
||||
if parsed.port == default_port:
|
||||
# Has explicit default port → strip it
|
||||
return urllib.parse.urlunparse(
|
||||
(parsed.scheme, parsed.hostname, parsed.path,
|
||||
parsed.params, parsed.query, parsed.fragment)
|
||||
)
|
||||
elif parsed.port is None:
|
||||
# No port → add default
|
||||
netloc = f"{parsed.hostname}:{default_port}"
|
||||
return urllib.parse.urlunparse(
|
||||
(parsed.scheme, netloc, parsed.path,
|
||||
parsed.params, parsed.query, parsed.fragment)
|
||||
)
|
||||
|
||||
# Non-standard port — no variant
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Twilio webhook handler
|
||||
|
|
@ -213,7 +289,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
raw = await request.read()
|
||||
# Twilio sends form-encoded data, not JSON
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"))
|
||||
form = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
|
||||
except Exception as e:
|
||||
logger.error("[sms] webhook parse error: %s", e)
|
||||
return web.Response(
|
||||
|
|
@ -222,6 +298,27 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
status=400,
|
||||
)
|
||||
|
||||
# Validate Twilio request signature when SMS_WEBHOOK_URL is configured
|
||||
if self._webhook_url:
|
||||
twilio_sig = request.headers.get("X-Twilio-Signature", "")
|
||||
if not twilio_sig:
|
||||
logger.warning("[sms] Rejected: missing X-Twilio-Signature header")
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=403,
|
||||
)
|
||||
flat_params = {k: v[0] for k, v in form.items() if v}
|
||||
if not self._validate_twilio_signature(
|
||||
self._webhook_url, flat_params, twilio_sig
|
||||
):
|
||||
logger.warning("[sms] Rejected: invalid Twilio signature")
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
status=403,
|
||||
)
|
||||
|
||||
# Extract fields (parse_qs returns lists)
|
||||
from_number = (form.get("From", [""]))[0].strip()
|
||||
to_number = (form.get("To", [""]))[0].strip()
|
||||
|
|
@ -236,7 +333,7 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
|
||||
# Ignore messages from our own number (echo prevention)
|
||||
if from_number == self._from_number:
|
||||
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
|
||||
logger.debug("[sms] ignoring echo from own number %s", redact_phone(from_number))
|
||||
return web.Response(
|
||||
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
|
||||
content_type="application/xml",
|
||||
|
|
@ -244,8 +341,8 @@ class SmsAdapter(BasePlatformAdapter):
|
|||
|
||||
logger.info(
|
||||
"[sms] inbound from %s -> %s: %s",
|
||||
_redact_phone(from_number),
|
||||
_redact_phone(to_number),
|
||||
redact_phone(from_number),
|
||||
redact_phone(to_number),
|
||||
text[:80],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -147,7 +147,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_TELEGRAM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
self._polling_conflict_count: int = 0
|
||||
self._polling_network_error_count: int = 0
|
||||
|
|
@ -300,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
|
||||
# Exhausted retries — fatal
|
||||
message = (
|
||||
"Another Telegram bot poller is already using this token. "
|
||||
"Another process is already polling this Telegram bot token "
|
||||
"(possibly OpenClaw or another Hermes instance). "
|
||||
"Hermes stopped Telegram polling after %d retries. "
|
||||
"Make sure only one gateway instance is running for this bot token."
|
||||
"Only one poller can run per token — stop the other process "
|
||||
"and restart with 'hermes start'."
|
||||
% MAX_CONFLICT_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Original error: %s", self.name, message, error)
|
||||
|
|
@ -497,23 +498,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self.config.token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"telegram-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Telegram bot token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Telegram poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("telegram_token_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('telegram-bot-token', self.config.token, 'Telegram bot token'):
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
|
|
@ -737,12 +722,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
message = f"Telegram startup failed: {e}"
|
||||
self._set_fatal_error("telegram_connect_error", message, retryable=True)
|
||||
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
|
||||
|
|
@ -768,12 +748,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error during Telegram disconnect: %s", self.name, e, exc_info=True)
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing Telegram token lock: %s", self.name, e, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
|
||||
for task in self._pending_photo_batch_tasks.values():
|
||||
if task and not task.done():
|
||||
|
|
@ -784,7 +759,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
self._mark_disconnected()
|
||||
self._app = None
|
||||
self._bot = None
|
||||
self._token_lock_identity = None
|
||||
logger.info("[%s] Disconnected from Telegram", self.name)
|
||||
|
||||
def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool:
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ except ImportError:
|
|||
httpx = None # type: ignore[assignment]
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -92,7 +93,6 @@ REQUEST_TIMEOUT_SECONDS = 15.0
|
|||
HEARTBEAT_INTERVAL_SECONDS = 30.0
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
|
||||
IMAGE_MAX_BYTES = 10 * 1024 * 1024
|
||||
|
|
@ -172,7 +172,7 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._pending_responses: Dict[str, asyncio.Future] = {}
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._dedup = MessageDeduplicator(max_size=DEDUP_MAX_SIZE)
|
||||
self._reply_req_ids: Dict[str, str] = {}
|
||||
|
||||
# Text batching: merge rapid successive messages (Telegram-style).
|
||||
|
|
@ -250,7 +250,7 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
await self._http_client.aclose()
|
||||
self._http_client = None
|
||||
|
||||
self._seen_messages.clear()
|
||||
self._dedup.clear()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
async def _cleanup_ws(self) -> None:
|
||||
|
|
@ -476,7 +476,7 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
return
|
||||
|
||||
msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex)
|
||||
if self._is_duplicate(msg_id):
|
||||
if self._dedup.is_duplicate(msg_id):
|
||||
logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id)
|
||||
return
|
||||
self._remember_reply_req_id(msg_id, self._payload_req_id(payload))
|
||||
|
|
@ -636,6 +636,13 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
if voice_text:
|
||||
text_parts.append(voice_text)
|
||||
|
||||
# Extract appmsg title (filename) for WeCom AI Bot attachments
|
||||
if msgtype == "appmsg":
|
||||
appmsg = body.get("appmsg") if isinstance(body.get("appmsg"), dict) else {}
|
||||
title = str(appmsg.get("title") or "").strip()
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
if quote_type == "text":
|
||||
|
|
@ -668,6 +675,13 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
refs.append(("image", body["image"]))
|
||||
if msgtype == "file" and isinstance(body.get("file"), dict):
|
||||
refs.append(("file", body["file"]))
|
||||
# Handle appmsg (WeCom AI Bot attachments with PDF/Word/Excel)
|
||||
if msgtype == "appmsg" and isinstance(body.get("appmsg"), dict):
|
||||
appmsg = body["appmsg"]
|
||||
if isinstance(appmsg.get("file"), dict):
|
||||
refs.append(("file", appmsg["file"]))
|
||||
elif isinstance(appmsg.get("image"), dict):
|
||||
refs.append(("image", appmsg["image"]))
|
||||
|
||||
quote = body.get("quote") if isinstance(body.get("quote"), dict) else {}
|
||||
quote_type = str(quote.get("msgtype") or "").lower()
|
||||
|
|
@ -825,24 +839,6 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
wildcard = self._groups.get("*")
|
||||
return wildcard if isinstance(wildcard, dict) else {}
|
||||
|
||||
def _is_duplicate(self, msg_id: str) -> bool:
|
||||
now = time.time()
|
||||
if len(self._seen_messages) > DEDUP_MAX_SIZE:
|
||||
cutoff = now - DEDUP_WINDOW_SECONDS
|
||||
self._seen_messages = {
|
||||
key: ts for key, ts in self._seen_messages.items() if ts > cutoff
|
||||
}
|
||||
if self._reply_req_ids:
|
||||
self._reply_req_ids = {
|
||||
key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages
|
||||
}
|
||||
|
||||
if msg_id in self._seen_messages:
|
||||
return True
|
||||
|
||||
self._seen_messages[msg_id] = now
|
||||
return False
|
||||
|
||||
def _remember_reply_req_id(self, message_id: str, req_id: str) -> None:
|
||||
normalized_message_id = str(message_id or "").strip()
|
||||
normalized_req_id = str(req_id or "").strip()
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ except ImportError: # pragma: no cover - dependency gate
|
|||
CRYPTO_AVAILABLE = False
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.helpers import MessageDeduplicator
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -63,6 +64,7 @@ from gateway.platforms.base import (
|
|||
cache_image_from_bytes,
|
||||
)
|
||||
from hermes_constants import get_hermes_home
|
||||
from utils import atomic_json_write
|
||||
|
||||
ILINK_BASE_URL = "https://ilinkai.weixin.qq.com"
|
||||
WEIXIN_CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
|
||||
|
|
@ -206,7 +208,7 @@ def save_weixin_account(
|
|||
"saved_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
}
|
||||
path = _account_file(hermes_home, account_id)
|
||||
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
atomic_json_write(path, payload)
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
|
|
@ -269,7 +271,7 @@ class ContextTokenStore:
|
|||
if key.startswith(prefix)
|
||||
}
|
||||
try:
|
||||
self._path(account_id).write_text(json.dumps(payload), encoding="utf-8")
|
||||
atomic_json_write(self._path(account_id), payload)
|
||||
except Exception as exc:
|
||||
logger.warning("weixin: failed to persist context tokens for %s: %s", _safe_id(account_id), exc)
|
||||
|
||||
|
|
@ -755,23 +757,58 @@ def _pack_markdown_blocks_for_weixin(content: str, max_length: int) -> List[str]
|
|||
return packed
|
||||
|
||||
|
||||
def _split_text_for_weixin_delivery(content: str, max_length: int) -> List[str]:
|
||||
def _split_text_for_weixin_delivery(
|
||||
content: str, max_length: int, split_per_line: bool = False,
|
||||
) -> List[str]:
|
||||
"""Split content into sequential Weixin messages.
|
||||
|
||||
Prefer one message per top-level line/markdown unit when the author used
|
||||
explicit line breaks. Oversized units fall back to block-aware packing so
|
||||
long code fences still split safely.
|
||||
"""
|
||||
if len(content) <= max_length and "\n" not in content:
|
||||
return [content]
|
||||
*compact* (default): Keep everything in a single message whenever it fits
|
||||
within the platform limit, even when the author used explicit line breaks.
|
||||
Only fall back to block-aware packing when the payload exceeds
|
||||
``max_length``.
|
||||
|
||||
chunks: List[str] = []
|
||||
for unit in _split_delivery_units_for_weixin(content):
|
||||
if len(unit) <= max_length:
|
||||
chunks.append(unit)
|
||||
continue
|
||||
chunks.extend(_pack_markdown_blocks_for_weixin(unit, max_length))
|
||||
return chunks or [content]
|
||||
*per_line* (``split_per_line=True``): Legacy behavior — top-level line
|
||||
breaks become separate chat messages; oversized units still use
|
||||
block-aware packing.
|
||||
|
||||
The active mode is controlled via ``config.yaml`` ->
|
||||
``platforms.weixin.extra.split_multiline_messages`` (``true`` / ``false``)
|
||||
or the env var ``WEIXIN_SPLIT_MULTILINE_MESSAGES``.
|
||||
"""
|
||||
if split_per_line:
|
||||
# Legacy: one message per top-level delivery unit.
|
||||
if len(content) <= max_length and "\n" not in content:
|
||||
return [content]
|
||||
chunks: List[str] = []
|
||||
for unit in _split_delivery_units_for_weixin(content):
|
||||
if len(unit) <= max_length:
|
||||
chunks.append(unit)
|
||||
continue
|
||||
chunks.extend(_pack_markdown_blocks_for_weixin(unit, max_length))
|
||||
return chunks or [content]
|
||||
|
||||
# Compact (default): single message when under the limit.
|
||||
if len(content) <= max_length:
|
||||
return [content]
|
||||
return _pack_markdown_blocks_for_weixin(content, max_length) or [content]
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool = True) -> bool:
|
||||
"""Coerce a config value to bool, tolerating strings like ``"true"``."""
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
text = str(value).strip().lower()
|
||||
if not text:
|
||||
return default
|
||||
if text in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if text in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return default
|
||||
|
||||
|
||||
def _extract_text(item_list: List[Dict[str, Any]]) -> str:
|
||||
|
|
@ -833,7 +870,7 @@ def _load_sync_buf(hermes_home: str, account_id: str) -> str:
|
|||
|
||||
def _save_sync_buf(hermes_home: str, account_id: str, sync_buf: str) -> None:
|
||||
path = _sync_buf_path(hermes_home, account_id)
|
||||
path.write_text(json.dumps({"get_updates_buf": sync_buf}), encoding="utf-8")
|
||||
atomic_json_write(path, {"get_updates_buf": sync_buf})
|
||||
|
||||
|
||||
async def qr_login(
|
||||
|
|
@ -972,8 +1009,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
self._typing_cache = TypingTicketCache()
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._dedup = MessageDeduplicator(ttl_seconds=MESSAGE_DEDUP_TTL_SECONDS)
|
||||
|
||||
self._account_id = str(extra.get("account_id") or os.getenv("WEIXIN_ACCOUNT_ID", "")).strip()
|
||||
self._token = str(config.token or extra.get("token") or os.getenv("WEIXIN_TOKEN", "")).strip()
|
||||
|
|
@ -981,6 +1017,16 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
self._cdn_base_url = str(
|
||||
extra.get("cdn_base_url") or os.getenv("WEIXIN_CDN_BASE_URL", WEIXIN_CDN_BASE_URL)
|
||||
).strip().rstrip("/")
|
||||
self._send_chunk_delay_seconds = float(
|
||||
extra.get("send_chunk_delay_seconds") or os.getenv("WEIXIN_SEND_CHUNK_DELAY_SECONDS", "0.35")
|
||||
)
|
||||
self._send_chunk_retries = int(
|
||||
extra.get("send_chunk_retries") or os.getenv("WEIXIN_SEND_CHUNK_RETRIES", "2")
|
||||
)
|
||||
self._send_chunk_retry_delay_seconds = float(
|
||||
extra.get("send_chunk_retry_delay_seconds")
|
||||
or os.getenv("WEIXIN_SEND_CHUNK_RETRY_DELAY_SECONDS", "1.0")
|
||||
)
|
||||
self._dm_policy = str(extra.get("dm_policy") or os.getenv("WEIXIN_DM_POLICY", "open")).strip().lower()
|
||||
self._group_policy = str(extra.get("group_policy") or os.getenv("WEIXIN_GROUP_POLICY", "disabled")).strip().lower()
|
||||
allow_from = extra.get("allow_from")
|
||||
|
|
@ -991,6 +1037,11 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
group_allow_from = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "")
|
||||
self._allow_from = self._coerce_list(allow_from)
|
||||
self._group_allow_from = self._coerce_list(group_allow_from)
|
||||
self._split_multiline_messages = _coerce_bool(
|
||||
extra.get("split_multiline_messages")
|
||||
or os.getenv("WEIXIN_SPLIT_MULTILINE_MESSAGES"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
if self._account_id and not self._token:
|
||||
persisted = load_weixin_account(hermes_home, self._account_id)
|
||||
|
|
@ -1026,23 +1077,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
return False
|
||||
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._token_lock_identity = self._token
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"weixin-bot-token",
|
||||
self._token_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this Weixin token"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second Weixin poller."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("weixin_token_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('weixin-bot-token', self._token, 'Weixin bot token'):
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.debug("[%s] Token lock unavailable (non-fatal): %s", self.name, exc)
|
||||
|
|
@ -1066,12 +1101,7 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
if self._token_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("weixin-bot-token", self._token_lock_identity)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Error releasing Weixin token lock: %s", self.name, exc, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
self._mark_disconnected()
|
||||
logger.info("[%s] Disconnected", self.name)
|
||||
|
||||
|
|
@ -1149,16 +1179,8 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
return
|
||||
|
||||
message_id = str(message.get("message_id") or "").strip()
|
||||
if message_id:
|
||||
now = time.time()
|
||||
self._seen_messages = {
|
||||
key: value
|
||||
for key, value in self._seen_messages.items()
|
||||
if now - value < MESSAGE_DEDUP_TTL_SECONDS
|
||||
}
|
||||
if message_id in self._seen_messages:
|
||||
return
|
||||
self._seen_messages[message_id] = now
|
||||
if message_id and self._dedup.is_duplicate(message_id):
|
||||
return
|
||||
|
||||
chat_type, effective_chat_id = _guess_chat_type(message, self._account_id)
|
||||
if chat_type == "group":
|
||||
|
|
@ -1330,7 +1352,50 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
logger.debug("[%s] getConfig failed for %s: %s", self.name, _safe_id(user_id), exc)
|
||||
|
||||
def _split_text(self, content: str) -> List[str]:
|
||||
return _split_text_for_weixin_delivery(content, self.MAX_MESSAGE_LENGTH)
|
||||
return _split_text_for_weixin_delivery(
|
||||
content, self.MAX_MESSAGE_LENGTH, self._split_multiline_messages,
|
||||
)
|
||||
|
||||
async def _send_text_chunk(
|
||||
self,
|
||||
*,
|
||||
chat_id: str,
|
||||
chunk: str,
|
||||
context_token: Optional[str],
|
||||
client_id: str,
|
||||
) -> None:
|
||||
"""Send a single text chunk with per-chunk retry and backoff."""
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(self._send_chunk_retries + 1):
|
||||
try:
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
text=chunk,
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
return
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt >= self._send_chunk_retries:
|
||||
break
|
||||
wait = self._send_chunk_retry_delay_seconds * (attempt + 1)
|
||||
logger.warning(
|
||||
"[%s] send chunk failed to=%s attempt=%d/%d, retrying in %.2fs: %s",
|
||||
self.name,
|
||||
_safe_id(chat_id),
|
||||
attempt + 1,
|
||||
self._send_chunk_retries + 1,
|
||||
wait,
|
||||
exc,
|
||||
)
|
||||
if wait > 0:
|
||||
await asyncio.sleep(wait)
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
async def send(
|
||||
self,
|
||||
|
|
@ -1344,18 +1409,18 @@ class WeixinAdapter(BasePlatformAdapter):
|
|||
context_token = self._token_store.get(self._account_id, chat_id)
|
||||
last_message_id: Optional[str] = None
|
||||
try:
|
||||
for chunk in self._split_text(self.format_message(content)):
|
||||
chunks = self._split_text(self.format_message(content))
|
||||
for idx, chunk in enumerate(chunks):
|
||||
client_id = f"hermes-weixin-{uuid.uuid4().hex}"
|
||||
await _send_message(
|
||||
self._session,
|
||||
base_url=self._base_url,
|
||||
token=self._token,
|
||||
to=chat_id,
|
||||
text=chunk,
|
||||
await self._send_text_chunk(
|
||||
chat_id=chat_id,
|
||||
chunk=chunk,
|
||||
context_token=context_token,
|
||||
client_id=client_id,
|
||||
)
|
||||
last_message_id = client_id
|
||||
if idx < len(chunks) - 1 and self._send_chunk_delay_seconds > 0:
|
||||
await asyncio.sleep(self._send_chunk_delay_seconds)
|
||||
return SendResult(success=True, message_id=last_message_id)
|
||||
except Exception as exc:
|
||||
logger.error("[%s] send failed to=%s: %s", self.name, _safe_id(chat_id), exc)
|
||||
|
|
|
|||
|
|
@ -145,7 +145,6 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
self._bridge_log: Optional[Path] = None
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
self._http_session: Optional["aiohttp.ClientSession"] = None
|
||||
self._session_lock_identity: Optional[str] = None
|
||||
|
||||
def _whatsapp_require_mention(self) -> bool:
|
||||
configured = self.config.extra.get("require_mention")
|
||||
|
|
@ -290,23 +289,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
|
||||
# Acquire scoped lock to prevent duplicate sessions
|
||||
try:
|
||||
from gateway.status import acquire_scoped_lock
|
||||
|
||||
self._session_lock_identity = str(self._session_path)
|
||||
acquired, existing = acquire_scoped_lock(
|
||||
"whatsapp-session",
|
||||
self._session_lock_identity,
|
||||
metadata={"platform": self.platform.value},
|
||||
)
|
||||
if not acquired:
|
||||
owner_pid = existing.get("pid") if isinstance(existing, dict) else None
|
||||
message = (
|
||||
"Another local Hermes gateway is already using this WhatsApp session"
|
||||
+ (f" (PID {owner_pid})." if owner_pid else ".")
|
||||
+ " Stop the other gateway before starting a second WhatsApp bridge."
|
||||
)
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_session_lock", message, retryable=False)
|
||||
if not self._acquire_platform_lock('whatsapp-session', str(self._session_path), 'WhatsApp session'):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e)
|
||||
|
|
@ -468,12 +451,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
return True
|
||||
|
||||
except Exception as e:
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
self._release_platform_lock()
|
||||
logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True)
|
||||
self._close_bridge_log()
|
||||
return False
|
||||
|
|
@ -546,17 +524,11 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
if self._session_lock_identity:
|
||||
try:
|
||||
from gateway.status import release_scoped_lock
|
||||
release_scoped_lock("whatsapp-session", self._session_lock_identity)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True)
|
||||
self._release_platform_lock()
|
||||
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
self._session_lock_identity = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
|
|
|
|||
458
gateway/run.py
458
gateway/run.py
|
|
@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str:
|
|||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _dequeue_pending_text(adapter, session_key: str) -> str | None:
|
||||
"""Consume and return the text of a pending queued message.
|
||||
def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
|
||||
"""Consume and return the full pending event for a session.
|
||||
|
||||
Preserves media context for captionless photo/document events by
|
||||
building a placeholder so the message isn't silently dropped.
|
||||
Queued follow-ups must preserve their media metadata so they can re-enter
|
||||
the normal image/STT/document preprocessing path instead of being reduced
|
||||
to a placeholder string.
|
||||
"""
|
||||
event = adapter.get_pending_message(session_key)
|
||||
if not event:
|
||||
return None
|
||||
text = event.text
|
||||
if not text and getattr(event, "media_urls", None):
|
||||
text = _build_media_placeholder(event)
|
||||
return text
|
||||
return adapter.get_pending_message(session_key)
|
||||
|
||||
|
||||
def _check_unavailable_skill(command_name: str) -> str | None:
|
||||
|
|
@ -1465,7 +1460,18 @@ class GatewayRunner:
|
|||
logger.info("Recovered %s background process(es) from previous run", recovered)
|
||||
except Exception as e:
|
||||
logger.warning("Process checkpoint recovery: %s", e)
|
||||
|
||||
|
||||
# Suspend sessions that were active when the gateway last exited.
|
||||
# This prevents stuck sessions from being blindly resumed on restart,
|
||||
# which can create an unrecoverable loop (#7536). Suspended sessions
|
||||
# auto-reset on the next incoming message, giving the user a clean start.
|
||||
try:
|
||||
suspended = self.session_store.suspend_recently_active()
|
||||
if suspended:
|
||||
logger.info("Suspended %d in-flight session(s) from previous run", suspended)
|
||||
except Exception as e:
|
||||
logger.warning("Session suspension on startup failed: %s", e)
|
||||
|
||||
connected_count = 0
|
||||
enabled_platform_count = 0
|
||||
startup_nonretryable_errors: list[str] = []
|
||||
|
|
@ -2221,6 +2227,13 @@ class GatewayRunner:
|
|||
# are system-generated and must skip user authorization.
|
||||
if getattr(event, "internal", False):
|
||||
pass
|
||||
elif source.user_id is None:
|
||||
# Messages with no user identity (Telegram service messages,
|
||||
# channel forwards, anonymous admin actions) cannot be
|
||||
# authorized — drop silently instead of triggering the pairing
|
||||
# flow with a None user_id.
|
||||
logger.debug("Ignoring message with no user_id from %s", source.platform.value)
|
||||
return None
|
||||
elif not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
# In DMs: offer pairing code. In groups: silently ignore.
|
||||
|
|
@ -2370,8 +2383,11 @@ class GatewayRunner:
|
|||
self._pending_messages.pop(_quick_key, None)
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
logger.info("HARD STOP for session %s — session lock released", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
||||
# Mark session suspended so the next message starts fresh
|
||||
# instead of resuming the stuck context (#7536).
|
||||
self.session_store.suspend_session(_quick_key)
|
||||
logger.info("HARD STOP for session %s — suspended, session lock released", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The session is suspended — your next message will start fresh."
|
||||
|
||||
# /reset and /new must bypass the running-agent guard so they
|
||||
# actually dispatch as commands instead of being queued as user
|
||||
|
|
@ -2761,6 +2777,162 @@ class GatewayRunner:
|
|||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
|
||||
async def _prepare_inbound_message_text(
|
||||
self,
|
||||
*,
|
||||
event: MessageEvent,
|
||||
source: SessionSource,
|
||||
history: List[Dict[str, Any]],
|
||||
) -> Optional[str]:
|
||||
"""Prepare inbound event text for the agent.
|
||||
|
||||
Keep the normal inbound path and the queued follow-up path on the same
|
||||
preprocessing pipeline so sender attribution, image enrichment, STT,
|
||||
document notes, reply context, and @ references all behave the same.
|
||||
"""
|
||||
history = history or []
|
||||
message_text = event.text or ""
|
||||
|
||||
_is_shared_thread = (
|
||||
source.chat_type != "dm"
|
||||
and source.thread_id
|
||||
and not getattr(self.config, "thread_sessions_per_user", False)
|
||||
)
|
||||
if _is_shared_thread and source.user_name:
|
||||
message_text = f"[{source.user_name}] {message_text}"
|
||||
|
||||
if event.media_urls:
|
||||
image_paths = []
|
||||
audio_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
if mtype.startswith("image/") or event.message_type == MessageType.PHOTO:
|
||||
image_paths.append(path)
|
||||
if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO):
|
||||
audio_paths.append(path)
|
||||
|
||||
if image_paths:
|
||||
message_text = await self._enrich_message_with_vision(
|
||||
message_text,
|
||||
image_paths,
|
||||
)
|
||||
|
||||
if audio_paths:
|
||||
message_text = await self._enrich_message_with_transcription(
|
||||
message_text,
|
||||
audio_paths,
|
||||
)
|
||||
_stt_fail_markers = (
|
||||
"No STT provider",
|
||||
"STT is disabled",
|
||||
"can't listen",
|
||||
"VOICE_TOOLS_OPENAI_KEY",
|
||||
)
|
||||
if any(marker in message_text for marker in _stt_fail_markers):
|
||||
_stt_adapter = self.adapters.get(source.platform)
|
||||
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
if _stt_adapter:
|
||||
try:
|
||||
_stt_msg = (
|
||||
"🎤 I received your voice message but can't transcribe it — "
|
||||
"no speech-to-text provider is configured.\n\n"
|
||||
"To enable voice: install faster-whisper "
|
||||
"(`pip install faster-whisper` in the Hermes venv) "
|
||||
"and set `stt.enabled: true` in config.yaml, "
|
||||
"then /restart the gateway."
|
||||
)
|
||||
if self._has_setup_skill():
|
||||
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
|
||||
await _stt_adapter.send(
|
||||
source.chat_id,
|
||||
_stt_msg,
|
||||
metadata=_stt_meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if event.media_urls and event.message_type == MessageType.DOCUMENT:
|
||||
import mimetypes as _mimetypes
|
||||
|
||||
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
if mtype in ("", "application/octet-stream"):
|
||||
import os as _os2
|
||||
|
||||
_ext = _os2.path.splitext(path)[1].lower()
|
||||
if _ext in _TEXT_EXTENSIONS:
|
||||
mtype = "text/plain"
|
||||
else:
|
||||
guessed, _ = _mimetypes.guess_type(path)
|
||||
if guessed:
|
||||
mtype = guessed
|
||||
if not mtype.startswith(("application/", "text/")):
|
||||
continue
|
||||
|
||||
import os as _os
|
||||
import re as _re
|
||||
|
||||
basename = _os.path.basename(path)
|
||||
parts = basename.split("_", 2)
|
||||
display_name = parts[2] if len(parts) >= 3 else basename
|
||||
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
|
||||
if mtype.startswith("text/"):
|
||||
context_note = (
|
||||
f"[The user sent a text document: '{display_name}'. "
|
||||
f"Its content has been included below. "
|
||||
f"The file is also saved at: {path}]"
|
||||
)
|
||||
else:
|
||||
context_note = (
|
||||
f"[The user sent a document: '{display_name}'. "
|
||||
f"The file is saved at: {path}. "
|
||||
f"Ask the user what they'd like you to do with it.]"
|
||||
)
|
||||
message_text = f"{context_note}\n\n{message_text}"
|
||||
|
||||
if getattr(event, "reply_to_text", None) and event.reply_to_message_id:
|
||||
reply_snippet = event.reply_to_text[:500]
|
||||
found_in_history = any(
|
||||
reply_snippet[:200] in (msg.get("content") or "")
|
||||
for msg in history
|
||||
if msg.get("role") in ("assistant", "user", "tool")
|
||||
)
|
||||
if not found_in_history:
|
||||
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
|
||||
|
||||
if "@" in message_text:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model,
|
||||
base_url=self._base_url or "",
|
||||
)
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text,
|
||||
cwd=_msg_cwd,
|
||||
context_length=_msg_ctx_len,
|
||||
allowed_root=_msg_cwd,
|
||||
)
|
||||
if _ctx_result.blocked:
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
await _adapter.send(
|
||||
source.chat_id,
|
||||
"\n".join(_ctx_result.warnings) or "Context injection refused.",
|
||||
)
|
||||
return None
|
||||
if _ctx_result.expanded:
|
||||
message_text = _ctx_result.message
|
||||
except Exception as exc:
|
||||
logger.debug("@ context reference expansion failed: %s", exc)
|
||||
|
||||
return message_text
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
_msg_start_time = time.time()
|
||||
|
|
@ -2812,7 +2984,9 @@ class GatewayRunner:
|
|||
# so the agent knows this is a fresh conversation (not an intentional /reset).
|
||||
if getattr(session_entry, 'was_auto_reset', False):
|
||||
reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle'
|
||||
if reset_reason == "daily":
|
||||
if reset_reason == "suspended":
|
||||
context_note = "[System note: The user's previous session was stopped and suspended. This is a fresh conversation with no prior context.]"
|
||||
elif reset_reason == "daily":
|
||||
context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]"
|
||||
else:
|
||||
context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]"
|
||||
|
|
@ -2829,7 +3003,9 @@ class GatewayRunner:
|
|||
)
|
||||
platform_name = source.platform.value if source.platform else ""
|
||||
had_activity = getattr(session_entry, 'reset_had_activity', False)
|
||||
should_notify = (
|
||||
# Suspended sessions always notify (they were explicitly stopped
|
||||
# or crashed mid-operation) — skip the policy check.
|
||||
should_notify = reset_reason == "suspended" or (
|
||||
policy.notify
|
||||
and had_activity
|
||||
and platform_name not in policy.notify_exclude_platforms
|
||||
|
|
@ -2837,7 +3013,9 @@ class GatewayRunner:
|
|||
if should_notify:
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
if reset_reason == "daily":
|
||||
if reset_reason == "suspended":
|
||||
reason_text = "previous session was stopped or interrupted"
|
||||
elif reset_reason == "daily":
|
||||
reason_text = f"daily schedule at {policy.at_hour}:00"
|
||||
else:
|
||||
hours = policy.idle_minutes // 60
|
||||
|
|
@ -3195,149 +3373,13 @@ class GatewayRunner:
|
|||
# attachments (documents, audio, etc.) are not sent to the vision
|
||||
# tool even when they appear in the same message.
|
||||
# -----------------------------------------------------------------
|
||||
message_text = event.text or ""
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Sender attribution for shared thread sessions.
|
||||
#
|
||||
# When multiple users share a single thread session (the default for
|
||||
# threads), prefix each message with [sender name] so the agent can
|
||||
# tell participants apart. Skip for DMs (single-user by nature) and
|
||||
# when per-user thread isolation is explicitly enabled.
|
||||
# -----------------------------------------------------------------
|
||||
_is_shared_thread = (
|
||||
source.chat_type != "dm"
|
||||
and source.thread_id
|
||||
and not getattr(self.config, "thread_sessions_per_user", False)
|
||||
message_text = await self._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=history,
|
||||
)
|
||||
if _is_shared_thread and source.user_name:
|
||||
message_text = f"[{source.user_name}] {message_text}"
|
||||
|
||||
if event.media_urls:
|
||||
image_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
# Check media_types if available; otherwise infer from message type
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
is_image = (
|
||||
mtype.startswith("image/")
|
||||
or event.message_type == MessageType.PHOTO
|
||||
)
|
||||
if is_image:
|
||||
image_paths.append(path)
|
||||
if image_paths:
|
||||
message_text = await self._enrich_message_with_vision(
|
||||
message_text, image_paths
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Auto-transcribe voice/audio messages sent by the user
|
||||
# -----------------------------------------------------------------
|
||||
if event.media_urls:
|
||||
audio_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
is_audio = (
|
||||
mtype.startswith("audio/")
|
||||
or event.message_type in (MessageType.VOICE, MessageType.AUDIO)
|
||||
)
|
||||
if is_audio:
|
||||
audio_paths.append(path)
|
||||
if audio_paths:
|
||||
message_text = await self._enrich_message_with_transcription(
|
||||
message_text, audio_paths
|
||||
)
|
||||
# If STT failed, send a direct message to the user so they
|
||||
# know voice isn't configured — don't rely on the agent to
|
||||
# relay the error clearly.
|
||||
_stt_fail_markers = (
|
||||
"No STT provider",
|
||||
"STT is disabled",
|
||||
"can't listen",
|
||||
"VOICE_TOOLS_OPENAI_KEY",
|
||||
)
|
||||
if any(m in message_text for m in _stt_fail_markers):
|
||||
_stt_adapter = self.adapters.get(source.platform)
|
||||
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
if _stt_adapter:
|
||||
try:
|
||||
_stt_msg = (
|
||||
"🎤 I received your voice message but can't transcribe it — "
|
||||
"no speech-to-text provider is configured.\n\n"
|
||||
"To enable voice: install faster-whisper "
|
||||
"(`pip install faster-whisper` in the Hermes venv) "
|
||||
"and set `stt.enabled: true` in config.yaml, "
|
||||
"then /restart the gateway."
|
||||
)
|
||||
# Point to setup skill if it's installed
|
||||
if self._has_setup_skill():
|
||||
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
|
||||
await _stt_adapter.send(
|
||||
source.chat_id, _stt_msg,
|
||||
metadata=_stt_meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Enrich document messages with context notes for the agent
|
||||
# -----------------------------------------------------------------
|
||||
if event.media_urls and event.message_type == MessageType.DOCUMENT:
|
||||
import mimetypes as _mimetypes
|
||||
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
|
||||
for i, path in enumerate(event.media_urls):
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
# Fall back to extension-based detection when MIME type is unreliable.
|
||||
if mtype in ("", "application/octet-stream"):
|
||||
import os as _os2
|
||||
_ext = _os2.path.splitext(path)[1].lower()
|
||||
if _ext in _TEXT_EXTENSIONS:
|
||||
mtype = "text/plain"
|
||||
else:
|
||||
guessed, _ = _mimetypes.guess_type(path)
|
||||
if guessed:
|
||||
mtype = guessed
|
||||
if not mtype.startswith(("application/", "text/")):
|
||||
continue
|
||||
# Extract display filename by stripping the doc_{uuid12}_ prefix
|
||||
import os as _os
|
||||
basename = _os.path.basename(path)
|
||||
# Format: doc_<12hex>_<original_filename>
|
||||
parts = basename.split("_", 2)
|
||||
display_name = parts[2] if len(parts) >= 3 else basename
|
||||
# Sanitize to prevent prompt injection via filenames
|
||||
import re as _re
|
||||
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
|
||||
if mtype.startswith("text/"):
|
||||
context_note = (
|
||||
f"[The user sent a text document: '{display_name}'. "
|
||||
f"Its content has been included below. "
|
||||
f"The file is also saved at: {path}]"
|
||||
)
|
||||
else:
|
||||
context_note = (
|
||||
f"[The user sent a document: '{display_name}'. "
|
||||
f"The file is saved at: {path}. "
|
||||
f"Ask the user what they'd like you to do with it.]"
|
||||
)
|
||||
message_text = f"{context_note}\n\n{message_text}"
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Inject reply context when user replies to a message not in history.
|
||||
# Telegram (and other platforms) let users reply to specific messages,
|
||||
# but if the quoted message is from a previous session, cron delivery,
|
||||
# or background task, the agent has no context about what's being
|
||||
# referenced. Prepend the quoted text so the agent understands. (#1594)
|
||||
# -----------------------------------------------------------------
|
||||
if getattr(event, 'reply_to_text', None) and event.reply_to_message_id:
|
||||
reply_snippet = event.reply_to_text[:500]
|
||||
found_in_history = any(
|
||||
reply_snippet[:200] in (msg.get("content") or "")
|
||||
for msg in history
|
||||
if msg.get("role") in ("assistant", "user", "tool")
|
||||
)
|
||||
if not found_in_history:
|
||||
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
|
||||
if message_text is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Emit agent:start hook
|
||||
|
|
@ -3349,30 +3391,6 @@ class GatewayRunner:
|
|||
}
|
||||
await self.hooks.emit("agent:start", hook_ctx)
|
||||
|
||||
# Expand @ context references (@file:, @folder:, @diff, etc.)
|
||||
if "@" in message_text:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
from agent.model_metadata import get_model_context_length
|
||||
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model, base_url=self._base_url or "")
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text, cwd=_msg_cwd,
|
||||
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
|
||||
if _ctx_result.blocked:
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
await _adapter.send(
|
||||
source.chat_id,
|
||||
"\n".join(_ctx_result.warnings) or "Context injection refused.",
|
||||
)
|
||||
return
|
||||
if _ctx_result.expanded:
|
||||
message_text = _ctx_result.message
|
||||
except Exception as exc:
|
||||
logger.debug("@ context reference expansion failed: %s", exc)
|
||||
|
||||
# Run the agent
|
||||
agent_result = await self._run_agent(
|
||||
message=message_text,
|
||||
|
|
@ -4010,25 +4028,31 @@ class GatewayRunner:
|
|||
handles /stop before this method is reached. This handler fires
|
||||
only through normal command dispatch (no running agent) or as a
|
||||
fallback. Force-clean the session lock in all cases for safety.
|
||||
|
||||
When there IS a running/pending agent, the session is also marked
|
||||
as *suspended* so the next message starts a fresh session instead
|
||||
of resuming the stuck context (#7536).
|
||||
"""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
|
||||
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
logger.info("HARD STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
||||
self.session_store.suspend_session(session_key)
|
||||
logger.info("HARD STOP (pending) for session %s — suspended, sentinel cleared", session_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — your next message will start fresh."
|
||||
if agent:
|
||||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
return "⚡ Force-stopped. The session is unlocked — you can send a new message."
|
||||
self.session_store.suspend_session(session_key)
|
||||
return "⚡ Force-stopped. Your next message will start a fresh session."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
||||
|
|
@ -6694,6 +6718,8 @@ class GatewayRunner:
|
|||
chat_id=context.source.chat_id,
|
||||
chat_name=context.source.chat_name or "",
|
||||
thread_id=str(context.source.thread_id) if context.source.thread_id else "",
|
||||
user_id=str(context.source.user_id) if context.source.user_id else "",
|
||||
user_name=str(context.source.user_name) if context.source.user_name else "",
|
||||
)
|
||||
|
||||
def _clear_session_env(self, tokens: list) -> None:
|
||||
|
|
@ -6906,6 +6932,8 @@ class GatewayRunner:
|
|||
platform_name = watcher.get("platform", "")
|
||||
chat_id = watcher.get("chat_id", "")
|
||||
thread_id = watcher.get("thread_id", "")
|
||||
user_id = watcher.get("user_id", "")
|
||||
user_name = watcher.get("user_name", "")
|
||||
agent_notify = watcher.get("notify_on_complete", False)
|
||||
notify_mode = self._load_background_notifications_mode()
|
||||
|
||||
|
|
@ -6961,6 +6989,8 @@ class GatewayRunner:
|
|||
platform=_platform_enum,
|
||||
chat_id=chat_id,
|
||||
thread_id=thread_id or None,
|
||||
user_id=user_id or None,
|
||||
user_name=user_name or None,
|
||||
)
|
||||
synth_event = MessageEvent(
|
||||
text=synth_text,
|
||||
|
|
@ -8115,17 +8145,16 @@ class GatewayRunner:
|
|||
|
||||
# Get pending message from adapter.
|
||||
# Use session_key (not source.chat_id) to match adapter's storage keys.
|
||||
pending_event = None
|
||||
pending = None
|
||||
if result and adapter and session_key:
|
||||
if result.get("interrupted"):
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if not pending and result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
else:
|
||||
pending = _dequeue_pending_text(adapter, session_key)
|
||||
if pending:
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
elif pending_event:
|
||||
pending = pending_event.text or _build_media_placeholder(pending_event)
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
# Safety net: if the pending text is a slash command (e.g. "/stop",
|
||||
# "/new"), discard it — commands should never be passed to the agent
|
||||
# as user input. The primary fix is in base.py (commands bypass the
|
||||
|
|
@ -8143,27 +8172,29 @@ class GatewayRunner:
|
|||
"commands must not be passed as agent input",
|
||||
_pending_cmd_word,
|
||||
)
|
||||
pending_event = None
|
||||
pending = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self._draining and pending:
|
||||
if self._draining and (pending_event or pending):
|
||||
logger.info(
|
||||
"Discarding pending follow-up for session %s during gateway %s",
|
||||
session_key[:20] if session_key else "?",
|
||||
self._status_action_label(),
|
||||
)
|
||||
pending_event = None
|
||||
pending = None
|
||||
|
||||
if pending:
|
||||
if pending_event or pending:
|
||||
logger.debug("Processing pending message: '%s...'", pending[:40])
|
||||
|
||||
|
||||
# Clear the adapter's interrupt event so the next _run_agent call
|
||||
# doesn't immediately re-trigger the interrupt before the new agent
|
||||
# even makes its first API call (this was causing an infinite loop).
|
||||
if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions:
|
||||
adapter._active_sessions[session_key].clear()
|
||||
|
||||
|
||||
# Cap recursion depth to prevent resource exhaustion when the
|
||||
# user sends multiple messages while the agent keeps failing. (#816)
|
||||
if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH:
|
||||
|
|
@ -8172,9 +8203,10 @@ class GatewayRunner:
|
|||
"queueing message instead of recursing.",
|
||||
_interrupt_depth, session_key,
|
||||
)
|
||||
# Queue the pending message for normal processing on next turn
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, 'queue_message'):
|
||||
if adapter and pending_event:
|
||||
merge_pending_message_event(adapter._pending_messages, session_key, pending_event)
|
||||
elif adapter and hasattr(adapter, 'queue_message'):
|
||||
adapter.queue_message(session_key, pending)
|
||||
return result_holder[0] or {"final_response": response, "messages": history}
|
||||
|
||||
|
|
@ -8189,23 +8221,37 @@ class GatewayRunner:
|
|||
if first_response and not _already_streamed:
|
||||
try:
|
||||
await adapter.send(source.chat_id, first_response,
|
||||
metadata=getattr(event, "metadata", None))
|
||||
metadata={"thread_id": source.thread_id} if source.thread_id else None)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send first response before queued message: %s", e)
|
||||
# else: interrupted — discard the interrupted response ("Operation
|
||||
# interrupted." is just noise; the user already knows they sent a
|
||||
# new message).
|
||||
|
||||
# Process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
next_source = source
|
||||
next_message = pending
|
||||
next_message_id = None
|
||||
if pending_event is not None:
|
||||
next_source = getattr(pending_event, "source", None) or source
|
||||
next_message = await self._prepare_inbound_message_text(
|
||||
event=pending_event,
|
||||
source=next_source,
|
||||
history=updated_history,
|
||||
)
|
||||
if next_message is None:
|
||||
return result
|
||||
next_message_id = getattr(pending_event, "message_id", None)
|
||||
|
||||
return await self._run_agent(
|
||||
message=pending,
|
||||
message=next_message,
|
||||
context_prompt=context_prompt,
|
||||
history=updated_history,
|
||||
source=source,
|
||||
source=next_source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
_interrupt_depth=_interrupt_depth + 1,
|
||||
event_message_id=next_message_id,
|
||||
)
|
||||
finally:
|
||||
# Stop progress sender, interrupt monitor, and notification task
|
||||
|
|
|
|||
|
|
@ -368,6 +368,11 @@ class SessionEntry:
|
|||
# survives gateway restarts (the old in-memory _pre_flushed_sessions
|
||||
# set was lost on restart, causing redundant re-flushes).
|
||||
memory_flushed: bool = False
|
||||
|
||||
# When True the next call to get_or_create_session() will auto-reset
|
||||
# this session (create a new session_id) so the user starts fresh.
|
||||
# Set by /stop to break stuck-resume loops (#7536).
|
||||
suspended: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
|
|
@ -387,6 +392,7 @@ class SessionEntry:
|
|||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"cost_status": self.cost_status,
|
||||
"memory_flushed": self.memory_flushed,
|
||||
"suspended": self.suspended,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
|
|
@ -423,6 +429,7 @@ class SessionEntry:
|
|||
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
|
||||
cost_status=data.get("cost_status", "unknown"),
|
||||
memory_flushed=data.get("memory_flushed", False),
|
||||
suspended=data.get("suspended", False),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -698,7 +705,12 @@ class SessionStore:
|
|||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
# Auto-reset sessions marked as suspended (e.g. after /stop
|
||||
# broke a stuck loop — #7536).
|
||||
if entry.suspended:
|
||||
reset_reason = "suspended"
|
||||
else:
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
if not reset_reason:
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
|
|
@ -771,6 +783,44 @@ class SessionStore:
|
|||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
self._save()
|
||||
|
||||
def suspend_session(self, session_key: str) -> bool:
|
||||
"""Mark a session as suspended so it auto-resets on next access.
|
||||
|
||||
Used by ``/stop`` to prevent stuck sessions from being resumed
|
||||
after a gateway restart (#7536). Returns True if the session
|
||||
existed and was marked.
|
||||
"""
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
if session_key in self._entries:
|
||||
self._entries[session_key].suspended = True
|
||||
self._save()
|
||||
return True
|
||||
return False
|
||||
|
||||
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
|
||||
"""Mark recently-active sessions as suspended.
|
||||
|
||||
Called on gateway startup to prevent sessions that were likely
|
||||
in-flight when the gateway last exited from being blindly resumed
|
||||
(#7536). Only suspends sessions updated within *max_age_seconds*
|
||||
to avoid resetting long-idle sessions that are harmless to resume.
|
||||
Returns the number of sessions that were suspended.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
cutoff = _time.time() - max_age_seconds
|
||||
count = 0
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
for entry in self._entries.values():
|
||||
if not entry.suspended and entry.updated_at >= cutoff:
|
||||
entry.suspended = True
|
||||
count += 1
|
||||
if count:
|
||||
self._save()
|
||||
return count
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
db_end_session_id = None
|
||||
|
|
|
|||
|
|
@ -46,12 +46,16 @@ _SESSION_PLATFORM: ContextVar[str] = ContextVar("HERMES_SESSION_PLATFORM", defau
|
|||
_SESSION_CHAT_ID: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_ID", default="")
|
||||
_SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", default="")
|
||||
_SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="")
|
||||
_SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="")
|
||||
_SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="")
|
||||
|
||||
_VAR_MAP = {
|
||||
"HERMES_SESSION_PLATFORM": _SESSION_PLATFORM,
|
||||
"HERMES_SESSION_CHAT_ID": _SESSION_CHAT_ID,
|
||||
"HERMES_SESSION_CHAT_NAME": _SESSION_CHAT_NAME,
|
||||
"HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID,
|
||||
"HERMES_SESSION_USER_ID": _SESSION_USER_ID,
|
||||
"HERMES_SESSION_USER_NAME": _SESSION_USER_NAME,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -60,6 +64,8 @@ def set_session_vars(
|
|||
chat_id: str = "",
|
||||
chat_name: str = "",
|
||||
thread_id: str = "",
|
||||
user_id: str = "",
|
||||
user_name: str = "",
|
||||
) -> list:
|
||||
"""Set all session context variables and return reset tokens.
|
||||
|
||||
|
|
@ -74,6 +80,8 @@ def set_session_vars(
|
|||
_SESSION_CHAT_ID.set(chat_id),
|
||||
_SESSION_CHAT_NAME.set(chat_name),
|
||||
_SESSION_THREAD_ID.set(thread_id),
|
||||
_SESSION_USER_ID.set(user_id),
|
||||
_SESSION_USER_NAME.set(user_name),
|
||||
]
|
||||
return tokens
|
||||
|
||||
|
|
@ -87,6 +95,8 @@ def clear_session_vars(tokens: list) -> None:
|
|||
_SESSION_CHAT_ID,
|
||||
_SESSION_CHAT_NAME,
|
||||
_SESSION_THREAD_ID,
|
||||
_SESSION_USER_ID,
|
||||
_SESSION_USER_NAME,
|
||||
]
|
||||
for var, token in zip(vars_in_order, tokens):
|
||||
var.reset(token)
|
||||
|
|
|
|||
|
|
@ -250,9 +250,39 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
|||
api_key_env_vars=("HF_TOKEN",),
|
||||
base_url_env_var="HF_BASE_URL",
|
||||
),
|
||||
"xiaomi": ProviderConfig(
|
||||
id="xiaomi",
|
||||
name="Xiaomi MiMo",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.xiaomimimo.com/v1",
|
||||
api_key_env_vars=("XIAOMI_API_KEY",),
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Key Helper
|
||||
# =============================================================================
|
||||
|
||||
def get_anthropic_key() -> str:
|
||||
"""Return the first usable Anthropic credential, or ``""``.
|
||||
|
||||
Checks both the ``.env`` file (via ``get_env_value``) and the process
|
||||
environment (``os.getenv``). The fallback order mirrors the
|
||||
``PROVIDER_REGISTRY["anthropic"].api_key_env_vars`` tuple:
|
||||
|
||||
ANTHROPIC_API_KEY -> ANTHROPIC_TOKEN -> CLAUDE_CODE_OAUTH_TOKEN
|
||||
"""
|
||||
from hermes_cli.config import get_env_value
|
||||
|
||||
for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars:
|
||||
value = get_env_value(var) or os.getenv(var, "")
|
||||
if value:
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi Code Endpoint Detection
|
||||
# =============================================================================
|
||||
|
|
@ -908,6 +938,7 @@ def resolve_provider(
|
|||
"opencode": "opencode-zen", "zen": "opencode-zen",
|
||||
"qwen-portal": "qwen-oauth", "qwen-cli": "qwen-oauth", "qwen-oauth": "qwen-oauth",
|
||||
"hf": "huggingface", "hugging-face": "huggingface", "huggingface-hub": "huggingface",
|
||||
"mimo": "xiaomi", "xiaomi-mimo": "xiaomi",
|
||||
"go": "opencode-go", "opencode-go-sub": "opencode-go",
|
||||
"kilo": "kilocode", "kilo-code": "kilocode", "kilo-gateway": "kilocode",
|
||||
# Local server aliases — route through the generic custom provider
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""hermes claw — OpenClaw migration commands.
|
||||
|
||||
Usage:
|
||||
hermes claw migrate # Interactive migration from ~/.openclaw
|
||||
hermes claw migrate --dry-run # Preview what would be migrated
|
||||
hermes claw migrate # Preview then migrate (always shows preview first)
|
||||
hermes claw migrate --dry-run # Preview only, no changes
|
||||
hermes claw migrate --yes # Skip confirmation prompt
|
||||
hermes claw migrate --preset full --overwrite # Full migration, overwrite conflicts
|
||||
hermes claw cleanup # Archive leftover OpenClaw directories
|
||||
hermes claw cleanup --dry-run # Preview what would be archived
|
||||
|
|
@ -51,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = (
|
|||
# Known OpenClaw directory names (current + legacy)
|
||||
_OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot")
|
||||
|
||||
def _warn_if_gateway_running(auto_yes: bool) -> None:
|
||||
"""Check if a Hermes gateway is running with connected platforms.
|
||||
|
||||
Migrating bot tokens while the gateway is polling will cause conflicts
|
||||
(e.g. Telegram 409 "terminated by other getUpdates request"). Warn the
|
||||
user and let them decide whether to continue.
|
||||
"""
|
||||
from gateway.status import get_running_pid, read_runtime_status
|
||||
|
||||
if not get_running_pid():
|
||||
return
|
||||
|
||||
data = read_runtime_status() or {}
|
||||
platforms = data.get("platforms") or {}
|
||||
connected = [name for name, info in platforms.items()
|
||||
if isinstance(info, dict) and info.get("state") == "connected"]
|
||||
if not connected:
|
||||
return
|
||||
|
||||
print()
|
||||
print_error(
|
||||
"Hermes gateway is running with active connections: "
|
||||
+ ", ".join(connected)
|
||||
)
|
||||
print_info(
|
||||
"Migrating bot tokens while the gateway is active will cause "
|
||||
"conflicts (Telegram, Discord, and Slack only allow one active "
|
||||
"session per token)."
|
||||
)
|
||||
print_info("Recommendation: stop the gateway first with 'hermes stop'.")
|
||||
print()
|
||||
if not auto_yes and not prompt_yes_no("Continue anyway?", default=False):
|
||||
print_info("Migration cancelled. Stop the gateway and try again.")
|
||||
sys.exit(0)
|
||||
|
||||
# State files commonly found in OpenClaw workspace directories that cause
|
||||
# confusion after migration (the agent discovers them and writes to them)
|
||||
_WORKSPACE_STATE_GLOBS = (
|
||||
|
|
@ -237,12 +273,12 @@ def _cmd_migrate(args):
|
|||
|
||||
# Show what we're doing
|
||||
hermes_home = get_hermes_home()
|
||||
auto_yes = getattr(args, "yes", False)
|
||||
print()
|
||||
print_header("Migration Settings")
|
||||
print_info(f"Source: {source_dir}")
|
||||
print_info(f"Target: {hermes_home}")
|
||||
print_info(f"Preset: {preset}")
|
||||
print_info(f"Mode: {'dry run (preview only)' if dry_run else 'execute'}")
|
||||
print_info(f"Overwrite: {'yes' if overwrite else 'no (skip conflicts)'}")
|
||||
print_info(f"Secrets: {'yes (allowlisted only)' if migrate_secrets else 'no'}")
|
||||
if skill_conflict != "skip":
|
||||
|
|
@ -251,31 +287,85 @@ def _cmd_migrate(args):
|
|||
print_info(f"Workspace: {workspace_target}")
|
||||
print()
|
||||
|
||||
# For execute mode (non-dry-run), confirm unless --yes was passed
|
||||
if not dry_run and not getattr(args, "yes", False):
|
||||
if not prompt_yes_no("Proceed with migration?", default=True):
|
||||
print_info("Migration cancelled.")
|
||||
return
|
||||
# Check if a gateway is running with connected platforms — migrating tokens
|
||||
# while the gateway is active will cause conflicts (e.g. Telegram 409).
|
||||
_warn_if_gateway_running(auto_yes)
|
||||
|
||||
# Ensure config.yaml exists before migration tries to read it
|
||||
config_path = get_config_path()
|
||||
if not config_path.exists():
|
||||
save_config(load_config())
|
||||
|
||||
# Load and run the migration
|
||||
# Load the migration module
|
||||
try:
|
||||
mod = _load_migration_module(script_path)
|
||||
if mod is None:
|
||||
print_error("Could not load migration script.")
|
||||
return
|
||||
except Exception as e:
|
||||
print()
|
||||
print_error(f"Could not load migration script: {e}")
|
||||
logger.debug("OpenClaw migration error", exc_info=True)
|
||||
return
|
||||
|
||||
selected = mod.resolve_selected_options(None, None, preset=preset)
|
||||
ws_target = Path(workspace_target).resolve() if workspace_target else None
|
||||
selected = mod.resolve_selected_options(None, None, preset=preset)
|
||||
ws_target = Path(workspace_target).resolve() if workspace_target else None
|
||||
|
||||
# ── Phase 1: Always preview first ──────────────────────────
|
||||
try:
|
||||
preview = mod.Migrator(
|
||||
source_root=source_dir.resolve(),
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=False,
|
||||
workspace_target=ws_target,
|
||||
overwrite=overwrite,
|
||||
migrate_secrets=migrate_secrets,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
preset_name=preset,
|
||||
skill_conflict_mode=skill_conflict,
|
||||
)
|
||||
preview_report = preview.migrate()
|
||||
except Exception as e:
|
||||
print()
|
||||
print_error(f"Migration preview failed: {e}")
|
||||
logger.debug("OpenClaw migration preview error", exc_info=True)
|
||||
return
|
||||
|
||||
preview_summary = preview_report.get("summary", {})
|
||||
preview_count = preview_summary.get("migrated", 0)
|
||||
|
||||
if preview_count == 0:
|
||||
print()
|
||||
print_info("Nothing to migrate from OpenClaw.")
|
||||
_print_migration_report(preview_report, dry_run=True)
|
||||
return
|
||||
|
||||
print()
|
||||
print_header(f"Migration Preview — {preview_count} item(s) would be imported")
|
||||
print_info("No changes have been made yet. Review the list below:")
|
||||
_print_migration_report(preview_report, dry_run=True)
|
||||
|
||||
# If --dry-run, stop here
|
||||
if dry_run:
|
||||
return
|
||||
|
||||
# ── Phase 2: Confirm and execute ───────────────────────────
|
||||
print()
|
||||
if not auto_yes:
|
||||
if not sys.stdin.isatty():
|
||||
print_info("Non-interactive session — preview only.")
|
||||
print_info("To execute, re-run with: hermes claw migrate --yes")
|
||||
return
|
||||
if not prompt_yes_no("Proceed with migration?", default=True):
|
||||
print_info("Migration cancelled.")
|
||||
return
|
||||
|
||||
try:
|
||||
migrator = mod.Migrator(
|
||||
source_root=source_dir.resolve(),
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=not dry_run,
|
||||
execute=True,
|
||||
workspace_target=ws_target,
|
||||
overwrite=overwrite,
|
||||
migrate_secrets=migrate_secrets,
|
||||
|
|
@ -292,11 +382,11 @@ def _cmd_migrate(args):
|
|||
return
|
||||
|
||||
# Print results
|
||||
_print_migration_report(report, dry_run)
|
||||
_print_migration_report(report, dry_run=False)
|
||||
|
||||
# After successful non-dry-run migration, offer to archive the source directory
|
||||
if not dry_run and report.get("summary", {}).get("migrated", 0) > 0:
|
||||
_offer_source_archival(source_dir, getattr(args, "yes", False))
|
||||
# After successful migration, offer to archive the source directory
|
||||
if report.get("summary", {}).get("migrated", 0) > 0:
|
||||
_offer_source_archival(source_dir, auto_yes)
|
||||
|
||||
|
||||
def _offer_source_archival(source_dir: Path, auto_yes: bool = False):
|
||||
|
|
@ -330,6 +420,11 @@ def _offer_source_archival(source_dir: Path, auto_yes: bool = False):
|
|||
print_info("You can always rename it back if needed.")
|
||||
print()
|
||||
|
||||
if not auto_yes and not sys.stdin.isatty():
|
||||
print_info("Non-interactive session — skipping archival.")
|
||||
print_info("Run later with: hermes claw cleanup")
|
||||
return
|
||||
|
||||
if auto_yes or prompt_yes_no(f"Archive {source_dir} now?", default=True):
|
||||
try:
|
||||
archive_path = _archive_directory(source_dir)
|
||||
|
|
@ -433,6 +528,9 @@ def _cmd_cleanup(args):
|
|||
if dry_run:
|
||||
archive_path = _archive_directory(source_dir, dry_run=True)
|
||||
print_info(f"Would archive: {source_dir} → {archive_path}")
|
||||
elif not auto_yes and not sys.stdin.isatty():
|
||||
print_info(f"Non-interactive session — would archive: {source_dir}")
|
||||
print_info("To execute, re-run with: hermes claw cleanup --yes")
|
||||
else:
|
||||
if auto_yes or prompt_yes_no(f"Archive {source_dir}?", default=True):
|
||||
try:
|
||||
|
|
|
|||
79
hermes_cli/cli_output.py
Normal file
79
hermes_cli/cli_output.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Shared CLI output helpers for Hermes CLI modules.
|
||||
|
||||
Extracts the identical ``print_info/success/warning/error`` and ``prompt()``
|
||||
functions previously duplicated across setup.py, tools_config.py,
|
||||
mcp_config.py, and memory_setup.py.
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
|
||||
# ─── Print Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def print_info(text: str) -> None:
|
||||
"""Print a dim informational message."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def print_success(text: str) -> None:
|
||||
"""Print a green success message with ✓ prefix."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def print_warning(text: str) -> None:
|
||||
"""Print a yellow warning message with ⚠ prefix."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def print_error(text: str) -> None:
|
||||
"""Print a red error message with ✗ prefix."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def print_header(text: str) -> None:
|
||||
"""Print a bold yellow header."""
|
||||
print(color(f"\n {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
# ─── Input Prompts ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def prompt(
|
||||
question: str,
|
||||
default: str | None = None,
|
||||
password: bool = False,
|
||||
) -> str:
|
||||
"""Prompt the user for input with optional default and password masking.
|
||||
|
||||
Replaces the four independent ``_prompt()`` / ``prompt()`` implementations
|
||||
in setup.py, tools_config.py, mcp_config.py, and memory_setup.py.
|
||||
|
||||
Returns the user's input (stripped), or *default* if the user presses Enter.
|
||||
Returns empty string on Ctrl-C or EOF.
|
||||
"""
|
||||
suffix = f" [{default}]" if default else ""
|
||||
display = color(f" {question}{suffix}: ", Colors.YELLOW)
|
||||
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(display)
|
||||
else:
|
||||
value = input(display)
|
||||
value = value.strip()
|
||||
return value if value else (default or "")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return ""
|
||||
|
||||
|
||||
def prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
"""Prompt for a yes/no answer. Returns bool."""
|
||||
hint = "Y/n" if default else "y/N"
|
||||
answer = prompt(f"{question} ({hint})")
|
||||
if not answer:
|
||||
return default
|
||||
return answer.lower().startswith("y")
|
||||
|
|
@ -32,7 +32,6 @@ _ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
|||
_EXTRA_ENV_KEYS = frozenset({
|
||||
"OPENAI_API_KEY", "OPENAI_BASE_URL",
|
||||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
|
||||
"AUXILIARY_VISION_MODEL",
|
||||
"DISCORD_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL",
|
||||
"SIGNAL_ACCOUNT", "SIGNAL_HTTP_URL",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
|
|
@ -868,6 +867,21 @@ OPTIONAL_ENV_VARS = {
|
|||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"XIAOMI_API_KEY": {
|
||||
"description": "Xiaomi MiMo API key for MiMo models (mimo-v2-pro, mimo-v2-omni, mimo-v2-flash)",
|
||||
"prompt": "Xiaomi MiMo API Key",
|
||||
"url": "https://platform.xiaomimimo.com",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"XIAOMI_BASE_URL": {
|
||||
"description": "Xiaomi MiMo base URL override (default: https://api.xiaomimimo.com/v1)",
|
||||
"prompt": "Xiaomi base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"EXA_API_KEY": {
|
||||
|
|
@ -1483,7 +1497,7 @@ _KNOWN_ROOT_KEYS = {
|
|||
|
||||
# Valid fields inside a custom_providers list entry
|
||||
_VALID_CUSTOM_PROVIDER_FIELDS = {
|
||||
"name", "base_url", "api_key", "api_mode", "models",
|
||||
"name", "base_url", "api_key", "api_mode", "model", "models",
|
||||
"context_length", "rate_limit_delay",
|
||||
}
|
||||
|
||||
|
|
@ -2568,7 +2582,8 @@ def show_config():
|
|||
for env_key, name in keys:
|
||||
value = get_env_value(env_key)
|
||||
print(f" {name:<14} {redact_key(value)}")
|
||||
anthropic_value = get_env_value("ANTHROPIC_TOKEN") or get_env_value("ANTHROPIC_API_KEY")
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_value = get_anthropic_key()
|
||||
print(f" {'Anthropic':<14} {redact_key(anthropic_value)}")
|
||||
|
||||
# Model settings
|
||||
|
|
@ -2784,8 +2799,8 @@ def set_config_value(key: str, value: str):
|
|||
|
||||
# Write only user config back (not the full merged defaults)
|
||||
ensure_hermes_home()
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
from utils import atomic_yaml_write
|
||||
atomic_yaml_write(config_path, user_config, sort_keys=False)
|
||||
|
||||
# Keep .env in sync for keys that terminal_tool reads directly from env vars.
|
||||
# config.yaml is authoritative, but terminal_tool only reads TERMINAL_ENV etc.
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ _PROVIDER_ENV_HINTS = (
|
|||
"AI_GATEWAY_API_KEY",
|
||||
"OPENCODE_ZEN_API_KEY",
|
||||
"OPENCODE_GO_API_KEY",
|
||||
"XIAOMI_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -335,8 +336,8 @@ def run_doctor(args):
|
|||
model_section[k] = raw_config.pop(k)
|
||||
else:
|
||||
raw_config.pop(k)
|
||||
with open(config_path, "w") as f:
|
||||
yaml.dump(raw_config, f, default_flow_style=False)
|
||||
from utils import atomic_yaml_write
|
||||
atomic_yaml_write(config_path, raw_config)
|
||||
check_ok("Migrated stale root-level keys into model section")
|
||||
fixed_count += 1
|
||||
else:
|
||||
|
|
@ -685,7 +686,8 @@ def run_doctor(args):
|
|||
else:
|
||||
check_warn("OpenRouter API", "(not configured)")
|
||||
|
||||
anthropic_key = os.getenv("ANTHROPIC_TOKEN") or os.getenv("ANTHROPIC_API_KEY")
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_key = get_anthropic_key()
|
||||
if anthropic_key:
|
||||
print(" Checking Anthropic API...", end="", flush=True)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -157,30 +157,54 @@ def _request_gateway_self_restart(pid: int) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
"""
|
||||
pids = []
|
||||
_exclude = exclude_pids or set()
|
||||
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli.main --profile",
|
||||
"hermes_cli.main -p",
|
||||
"hermes_cli/main.py gateway",
|
||||
"hermes_cli/main.py --profile",
|
||||
"hermes_cli/main.py -p",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
]
|
||||
current_home = str(get_hermes_home().resolve())
|
||||
current_profile_arg = _profile_arg(current_home)
|
||||
current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else ""
|
||||
|
||||
def _matches_current_profile(command: str) -> bool:
|
||||
if current_profile_name:
|
||||
return (
|
||||
f"--profile {current_profile_name}" in command
|
||||
or f"-p {current_profile_name}" in command
|
||||
or f"HERMES_HOME={current_home}" in command
|
||||
)
|
||||
|
||||
if "--profile " in command or " -p " in command:
|
||||
return False
|
||||
if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command:
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
if is_windows():
|
||||
# Windows: use wmic to search command lines
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
)
|
||||
# Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n"
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
line = line.strip()
|
||||
|
|
@ -188,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
|||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns):
|
||||
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
if pid != os.getpid() and pid not in pids and pid not in _exclude:
|
||||
|
|
@ -198,41 +222,57 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list:
|
|||
current_cmd = ""
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["ps", "aux"],
|
||||
["ps", "eww", "-ax", "-o", "pid=,command="],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
# Skip grep and current process
|
||||
if 'grep' in line or str(os.getpid()) in line:
|
||||
stripped = line.strip()
|
||||
if not stripped or 'grep' in stripped:
|
||||
continue
|
||||
for pattern in patterns:
|
||||
if pattern in line:
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
if pid not in pids and pid not in _exclude:
|
||||
pids.append(pid)
|
||||
except ValueError:
|
||||
continue
|
||||
break
|
||||
except Exception:
|
||||
|
||||
pid = None
|
||||
command = ""
|
||||
|
||||
parts = stripped.split(None, 1)
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
pid = int(parts[0])
|
||||
command = parts[1]
|
||||
except ValueError:
|
||||
pid = None
|
||||
|
||||
if pid is None:
|
||||
aux_parts = stripped.split()
|
||||
if len(aux_parts) > 10 and aux_parts[1].isdigit():
|
||||
pid = int(aux_parts[1])
|
||||
command = " ".join(aux_parts[10:])
|
||||
|
||||
if pid is None:
|
||||
continue
|
||||
if pid == os.getpid() or pid in pids or pid in _exclude:
|
||||
continue
|
||||
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
|
||||
pids.append(pid)
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int:
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
|
||||
all_profiles: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
|
||||
Args:
|
||||
force: Use the platform's force-kill mechanism instead of graceful terminate.
|
||||
exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just
|
||||
restarted and should not be killed).
|
||||
all_profiles: When ``True``, kill across all profiles. Passed
|
||||
through to :func:`find_gateway_pids`.
|
||||
"""
|
||||
pids = find_gateway_pids(exclude_pids=exclude_pids)
|
||||
pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles)
|
||||
killed = 0
|
||||
|
||||
for pid in pids:
|
||||
|
|
@ -633,6 +673,17 @@ def print_systemd_linger_guidance() -> None:
|
|||
print(" If you want the gateway user service to survive logout, run:")
|
||||
print(" sudo loginctl enable-linger $USER")
|
||||
|
||||
def _launchd_user_home() -> Path:
|
||||
"""Return the real macOS user home for launchd artifacts.
|
||||
|
||||
Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but
|
||||
launchd user agents still live under the actual account home.
|
||||
"""
|
||||
import pwd
|
||||
|
||||
return Path(pwd.getpwuid(os.getuid()).pw_dir)
|
||||
|
||||
|
||||
def get_launchd_plist_path() -> Path:
|
||||
"""Return the launchd plist path, scoped per profile.
|
||||
|
||||
|
|
@ -641,7 +692,7 @@ def get_launchd_plist_path() -> Path:
|
|||
"""
|
||||
suffix = _profile_suffix()
|
||||
name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway"
|
||||
return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist"
|
||||
return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist"
|
||||
|
||||
def _detect_venv_dir() -> Path | None:
|
||||
"""Detect the active virtualenv directory.
|
||||
|
|
@ -839,6 +890,25 @@ def _normalize_service_definition(text: str) -> str:
|
|||
return "\n".join(line.rstrip() for line in text.strip().splitlines())
|
||||
|
||||
|
||||
def _normalize_launchd_plist_for_comparison(text: str) -> str:
|
||||
"""Normalize launchd plist text for staleness checks.
|
||||
|
||||
The generated plist intentionally captures a broad PATH assembled from the
|
||||
invoking shell so user-installed tools remain reachable under launchd.
|
||||
That makes raw text comparison unstable across shells, so ignore the PATH
|
||||
payload when deciding whether the installed plist is stale.
|
||||
"""
|
||||
import re
|
||||
|
||||
normalized = _normalize_service_definition(text)
|
||||
return re.sub(
|
||||
r'(<key>PATH</key>\s*<string>)(.*?)(</string>)',
|
||||
r'\1__HERMES_PATH__\3',
|
||||
normalized,
|
||||
flags=re.S,
|
||||
)
|
||||
|
||||
|
||||
def systemd_unit_is_current(system: bool = False) -> bool:
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if not unit_path.exists():
|
||||
|
|
@ -1220,7 +1290,7 @@ def launchd_plist_is_current() -> bool:
|
|||
|
||||
installed = plist_path.read_text(encoding="utf-8")
|
||||
expected = generate_launchd_plist()
|
||||
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
|
||||
return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected)
|
||||
|
||||
|
||||
def refresh_launchd_plist_if_needed() -> bool:
|
||||
|
|
@ -1981,6 +2051,36 @@ def _setup_whatsapp():
|
|||
cmd_whatsapp(argparse.Namespace())
|
||||
|
||||
|
||||
def _setup_email():
|
||||
"""Configure Email via the standard platform setup."""
|
||||
email_platform = next(p for p in _PLATFORMS if p["key"] == "email")
|
||||
_setup_standard_platform(email_platform)
|
||||
|
||||
|
||||
def _setup_sms():
|
||||
"""Configure SMS (Twilio) via the standard platform setup."""
|
||||
sms_platform = next(p for p in _PLATFORMS if p["key"] == "sms")
|
||||
_setup_standard_platform(sms_platform)
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via the standard platform setup."""
|
||||
dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk")
|
||||
_setup_standard_platform(dingtalk_platform)
|
||||
|
||||
|
||||
def _setup_feishu():
|
||||
"""Configure Feishu / Lark via the standard platform setup."""
|
||||
feishu_platform = next(p for p in _PLATFORMS if p["key"] == "feishu")
|
||||
_setup_standard_platform(feishu_platform)
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
"""Configure WeCom (Enterprise WeChat) via the standard platform setup."""
|
||||
wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom")
|
||||
_setup_standard_platform(wecom_platform)
|
||||
|
||||
|
||||
def _is_service_installed() -> bool:
|
||||
"""Check if the gateway is installed as a system service."""
|
||||
if supports_systemd_services():
|
||||
|
|
@ -2540,7 +2640,7 @@ def gateway_command(args):
|
|||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
killed = kill_gateway_processes()
|
||||
killed = kill_gateway_processes(all_profiles=True)
|
||||
total = killed + (1 if service_available else 0)
|
||||
if total:
|
||||
print(f"✓ Stopped {total} gateway process(es) across all profiles")
|
||||
|
|
|
|||
|
|
@ -606,18 +606,58 @@ def _print_tui_exit_summary(session_id: Optional[str]) -> None:
|
|||
)
|
||||
|
||||
|
||||
def _find_bundled_tui() -> Optional[Path]:
|
||||
"""Find a bundled copy of the TUI.
|
||||
Does *not* read from the `npm run build` dist dir,
|
||||
as this would be a footgun when developing
|
||||
"""
|
||||
bundled_tui_dir = os.environ.get("HERMES_TUI_DIR")
|
||||
if bundled_tui_dir and (Path(bundled_tui_dir) / "dist" / "entry.js").exists():
|
||||
return Path(bundled_tui_dir)
|
||||
def _find_bundled_tui(tui_dir: Path) -> Optional[Path]:
|
||||
"""Directory whose dist/entry.js we should run: HERMES_TUI_DIR first, else repo ui-tui."""
|
||||
env = os.environ.get("HERMES_TUI_DIR")
|
||||
if env:
|
||||
p = Path(env)
|
||||
if (p / "dist" / "entry.js").exists():
|
||||
return p
|
||||
if (tui_dir / "dist" / "entry.js").exists():
|
||||
return tui_dir
|
||||
return None
|
||||
|
||||
def _make_tui_argv(tui_dir: Path) -> tuple[list[str], Path]:
|
||||
"""Gets argv to run tui + the working directory. Will npm install deps in dev mode."""
|
||||
|
||||
def _tui_build_needed(tui_dir: Path) -> bool:
|
||||
entry = tui_dir / "dist" / "entry.js"
|
||||
if not entry.exists():
|
||||
return True
|
||||
dist_m = entry.stat().st_mtime
|
||||
skip = frozenset({"node_modules", "dist"})
|
||||
for dirpath, dirnames, filenames in os.walk(tui_dir, topdown=True):
|
||||
dirnames[:] = [d for d in dirnames if d not in skip]
|
||||
for fn in filenames:
|
||||
if fn.endswith((".ts", ".tsx")):
|
||||
if os.path.getmtime(os.path.join(dirpath, fn)) > dist_m:
|
||||
return True
|
||||
for meta in ("package.json", "package-lock.json", "tsconfig.json", "tsconfig.build.json"):
|
||||
mp = tui_dir / meta
|
||||
if mp.exists() and mp.stat().st_mtime > dist_m:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _hermes_ink_bundle_stale(tui_dir: Path) -> bool:
|
||||
ink_root = tui_dir / "packages" / "hermes-ink"
|
||||
bundle = ink_root / "dist" / "ink-bundle.js"
|
||||
if not bundle.exists():
|
||||
return True
|
||||
bm = bundle.stat().st_mtime
|
||||
skip = frozenset({"node_modules", "dist"})
|
||||
for dirpath, dirnames, filenames in os.walk(ink_root, topdown=True):
|
||||
dirnames[:] = [d for d in dirnames if d not in skip]
|
||||
for fn in filenames:
|
||||
if fn.endswith((".ts", ".tsx")):
|
||||
if os.path.getmtime(os.path.join(dirpath, fn)) > bm:
|
||||
return True
|
||||
mp = ink_root / "package.json"
|
||||
if mp.exists() and mp.stat().st_mtime > bm:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _make_tui_argv(tui_dir: Path, tui_dev: bool) -> tuple[list[str], Path]:
|
||||
"""Ink TUI: --dev → tsx src; else node dist (HERMES_TUI_DIR or ui-tui, build when stale)."""
|
||||
def _node_bin(bin: str)-> str:
|
||||
path = shutil.which(bin)
|
||||
if not path:
|
||||
|
|
@ -625,17 +665,15 @@ def _make_tui_argv(tui_dir: Path) -> tuple[list[str], Path]:
|
|||
sys.exit(1)
|
||||
return path
|
||||
|
||||
# use prebuilt TUI if it exists
|
||||
bundled = _find_bundled_tui()
|
||||
if bundled:
|
||||
node = _node_bin("node")
|
||||
return [node, str(bundled / "dist" / "entry.js")], bundled
|
||||
# pre-built dist (nix / HERMES_TUI_DIR) needs no npm at all.
|
||||
if not tui_dev:
|
||||
bundled = _find_bundled_tui(tui_dir)
|
||||
if bundled:
|
||||
node = _node_bin("node")
|
||||
return [node, str(bundled / "dist" / "entry.js")], bundled
|
||||
|
||||
# dev mode - run via tsx
|
||||
|
||||
# install deps if needed
|
||||
npm = _node_bin("npm")
|
||||
if not (tui_dir / "node_modules").exists():
|
||||
npm = _node_bin("npm")
|
||||
print("Installing TUI dependencies…")
|
||||
result = subprocess.run(
|
||||
[npm, "install", "--silent", "--no-fund", "--no-audit", "--progress=false"],
|
||||
|
|
@ -652,23 +690,60 @@ def _make_tui_argv(tui_dir: Path) -> tuple[list[str], Path]:
|
|||
print(preview)
|
||||
sys.exit(1)
|
||||
|
||||
tsx = tui_dir / "node_modules" / ".bin" / "tsx"
|
||||
if tsx.exists():
|
||||
return [str(tsx), "src/entry.tsx"], tui_dir
|
||||
if tui_dev:
|
||||
if _hermes_ink_bundle_stale(tui_dir):
|
||||
result = subprocess.run(
|
||||
[npm, "run", "build", "--prefix", "packages/hermes-ink"],
|
||||
cwd=str(tui_dir),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
combined = f"{result.stdout or ''}{result.stderr or ''}".strip()
|
||||
preview = "\n".join(combined.splitlines()[-30:])
|
||||
print("@hermes/ink build failed.")
|
||||
if preview:
|
||||
print(preview)
|
||||
sys.exit(1)
|
||||
tsx = tui_dir / "node_modules" / ".bin" / "tsx"
|
||||
if tsx.exists():
|
||||
return [str(tsx), "src/entry.tsx"], tui_dir
|
||||
return [npm, "start"], tui_dir
|
||||
|
||||
npm = _node_bin("npm")
|
||||
return [npm, "start"], tui_dir
|
||||
if _tui_build_needed(tui_dir):
|
||||
result = subprocess.run(
|
||||
[npm, "run", "build"],
|
||||
cwd=str(tui_dir),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
combined = f"{result.stdout or ''}{result.stderr or ''}".strip()
|
||||
preview = "\n".join(combined.splitlines()[-30:])
|
||||
print("TUI build failed.")
|
||||
if preview:
|
||||
print(preview)
|
||||
sys.exit(1)
|
||||
|
||||
def _launch_tui(resume_session_id: Optional[str] = None):
|
||||
root = _find_bundled_tui(tui_dir)
|
||||
if not root:
|
||||
print("TUI build did not produce dist/entry.js")
|
||||
sys.exit(1)
|
||||
|
||||
node = _node_bin("node")
|
||||
return [node, str(root / "dist" / "entry.js")], root
|
||||
|
||||
def _launch_tui(resume_session_id: Optional[str] = None, tui_dev: bool = False):
|
||||
"""Replace current process with the Ink TUI."""
|
||||
tui_dir = PROJECT_ROOT / "ui-tui"
|
||||
|
||||
env = os.environ.copy()
|
||||
env["HERMES_ROOT"] = os.environ.get("HERMES_ROOT", str(PROJECT_ROOT))
|
||||
env["HERMES_PYTHON_SRC_ROOT"] = os.environ.get("HERMES_PYTHON_SRC_ROOT", str(PROJECT_ROOT))
|
||||
env.setdefault("HERMES_CWD", os.getcwd())
|
||||
if resume_session_id:
|
||||
env["HERMES_TUI_RESUME"] = resume_session_id
|
||||
|
||||
argv, cwd = _make_tui_argv(tui_dir)
|
||||
argv, cwd = _make_tui_argv(tui_dir, tui_dev)
|
||||
try:
|
||||
code = subprocess.call(argv, cwd=str(cwd), env=env)
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -718,9 +793,6 @@ def cmd_chat(args):
|
|||
# If resolution fails, keep the original value — _init_agent will
|
||||
# report "Session not found" with the original input
|
||||
|
||||
if use_tui:
|
||||
_launch_tui(getattr(args, "resume", None))
|
||||
|
||||
# First-run guard: check if any provider is configured before launching
|
||||
if not _has_any_provider_configured():
|
||||
print()
|
||||
|
|
@ -770,6 +842,13 @@ def cmd_chat(args):
|
|||
if getattr(args, "source", None):
|
||||
os.environ["HERMES_SESSION_SOURCE"] = args.source
|
||||
|
||||
|
||||
if use_tui:
|
||||
_launch_tui(
|
||||
getattr(args, "resume", None),
|
||||
tui_dev=getattr(args, "tui_dev", False),
|
||||
)
|
||||
|
||||
# Import and run the CLI
|
||||
from cli import main as cli_main
|
||||
|
||||
|
|
@ -1069,6 +1148,7 @@ def select_provider_and_model(args=None):
|
|||
"kilocode": "Kilo Code",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"huggingface": "Hugging Face",
|
||||
"xiaomi": "Xiaomi MiMo",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
active_label = provider_labels.get(active, active) if active else "none"
|
||||
|
|
@ -1101,6 +1181,7 @@ def select_provider_and_model(args=None):
|
|||
("opencode-go", "OpenCode Go (open models, $10/month subscription)"),
|
||||
("ai-gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"),
|
||||
("alibaba", "Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"),
|
||||
("xiaomi", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"),
|
||||
]
|
||||
|
||||
def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]:
|
||||
|
|
@ -1212,7 +1293,7 @@ def select_provider_and_model(args=None):
|
|||
_model_flow_anthropic(config, current_model)
|
||||
elif selected_provider == "kimi-coding":
|
||||
_model_flow_kimi(config, current_model)
|
||||
elif selected_provider in ("gemini", "zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba", "huggingface"):
|
||||
elif selected_provider in ("gemini", "zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba", "huggingface", "xiaomi"):
|
||||
_model_flow_api_key_provider(config, selected_provider, current_model)
|
||||
|
||||
# ── Post-switch cleanup: clear stale OPENAI_BASE_URL ──────────────
|
||||
|
|
@ -2682,13 +2763,8 @@ def _model_flow_anthropic(config, current_model=""):
|
|||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
|
||||
# Check ALL credential sources
|
||||
existing_key = (
|
||||
get_env_value("ANTHROPIC_TOKEN")
|
||||
or os.getenv("ANTHROPIC_TOKEN", "")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
or os.getenv("ANTHROPIC_API_KEY", "")
|
||||
or os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "")
|
||||
)
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
existing_key = get_anthropic_key()
|
||||
cc_available = False
|
||||
try:
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, is_claude_code_token_valid
|
||||
|
|
@ -3062,6 +3138,8 @@ def _update_via_zip(args):
|
|||
)
|
||||
_install_python_dependencies_with_optional_fallback(pip_cmd)
|
||||
|
||||
_update_node_dependencies()
|
||||
|
||||
# Sync skills
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
|
|
@ -3581,9 +3659,42 @@ def _install_python_dependencies_with_optional_fallback(
|
|||
print(f" ⚠ Skipped optional extras that still failed: {', '.join(failed_extras)}")
|
||||
|
||||
|
||||
def _update_node_dependencies() -> None:
|
||||
npm = shutil.which("npm")
|
||||
if not npm:
|
||||
return
|
||||
|
||||
paths = (
|
||||
("repo root", PROJECT_ROOT),
|
||||
("ui-tui", PROJECT_ROOT / "ui-tui"),
|
||||
)
|
||||
if not any((path / "package.json").exists() for _, path in paths):
|
||||
return
|
||||
|
||||
print("→ Updating Node.js dependencies...")
|
||||
for label, path in paths:
|
||||
if not (path / "package.json").exists():
|
||||
continue
|
||||
|
||||
result = subprocess.run(
|
||||
[npm, "install", "--silent", "--no-fund", "--no-audit", "--progress=false"],
|
||||
cwd=path,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
print(f" ✓ {label}")
|
||||
continue
|
||||
|
||||
print(f" ⚠ npm install failed in {label}")
|
||||
stderr = (result.stderr or "").strip()
|
||||
if stderr:
|
||||
print(f" {stderr.splitlines()[-1]}")
|
||||
|
||||
|
||||
def cmd_update(args):
|
||||
"""Update Hermes Agent to the latest version."""
|
||||
import shutil
|
||||
from hermes_cli.config import is_managed, managed_error
|
||||
|
||||
if is_managed():
|
||||
|
|
@ -3802,13 +3913,8 @@ def cmd_update(args):
|
|||
)
|
||||
_install_python_dependencies_with_optional_fallback(pip_cmd)
|
||||
|
||||
# Check for Node.js deps
|
||||
if (PROJECT_ROOT / "package.json").exists():
|
||||
import shutil
|
||||
if shutil.which("npm"):
|
||||
print("→ Updating Node.js dependencies...")
|
||||
subprocess.run(["npm", "install", "--silent"], cwd=PROJECT_ROOT, check=False)
|
||||
|
||||
_update_node_dependencies()
|
||||
|
||||
print()
|
||||
print("✓ Code updated!")
|
||||
|
||||
|
|
@ -4014,7 +4120,7 @@ def cmd_update(args):
|
|||
# Exclude PIDs that belong to just-restarted services so we don't
|
||||
# immediately kill the process that systemd/launchd just spawned.
|
||||
service_pids = _get_service_pids()
|
||||
manual_pids = find_gateway_pids(exclude_pids=service_pids)
|
||||
manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True)
|
||||
for pid in manual_pids:
|
||||
try:
|
||||
os.kill(pid, _signal.SIGTERM)
|
||||
|
|
@ -4463,7 +4569,14 @@ For more help on a command:
|
|||
default=False,
|
||||
help="Launch the Ink-based terminal UI instead of the classic REPL"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dev",
|
||||
dest="tui_dev",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="With --tui: run TypeScript sources via tsx (skip dist build)",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
# =========================================================================
|
||||
|
|
@ -4498,7 +4611,7 @@ For more help on a command:
|
|||
)
|
||||
chat_parser.add_argument(
|
||||
"--provider",
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "gemini", "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode"],
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "gemini", "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "xiaomi"],
|
||||
default=None,
|
||||
help="Inference provider (default: auto)"
|
||||
)
|
||||
|
|
@ -4569,6 +4682,13 @@ For more help on a command:
|
|||
default=False,
|
||||
help="Launch the Ink-based terminal UI instead of the classic REPL"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--dev",
|
||||
dest="tui_dev",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="With --tui: run TypeScript sources via tsx (skip dist build)",
|
||||
)
|
||||
chat_parser.set_defaults(func=cmd_chat)
|
||||
|
||||
# =========================================================================
|
||||
|
|
@ -5558,7 +5678,8 @@ For more help on a command:
|
|||
claw_migrate = claw_subparsers.add_parser(
|
||||
"migrate",
|
||||
help="Migrate from OpenClaw to Hermes",
|
||||
description="Import settings, memories, skills, and API keys from an OpenClaw installation"
|
||||
description="Import settings, memories, skills, and API keys from an OpenClaw installation. "
|
||||
"Always shows a preview before making changes."
|
||||
)
|
||||
claw_migrate.add_argument(
|
||||
"--source",
|
||||
|
|
@ -5567,7 +5688,7 @@ For more help on a command:
|
|||
claw_migrate.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Preview what would be migrated without making changes"
|
||||
help="Preview only — stop after showing what would be migrated"
|
||||
)
|
||||
claw_migrate.add_argument(
|
||||
"--preset",
|
||||
|
|
|
|||
|
|
@ -57,19 +57,8 @@ def _confirm(question: str, default: bool = True) -> bool:
|
|||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
from hermes_cli.cli_output import prompt as _shared_prompt
|
||||
return _shared_prompt(question, default=default, password=password)
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -25,85 +25,13 @@ def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -
|
|||
items: list of (label, description) tuples.
|
||||
Returns selected index, or default on escape/quit.
|
||||
"""
|
||||
try:
|
||||
import curses
|
||||
result = [default]
|
||||
|
||||
def _menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
curses.init_pair(3, curses.COLOR_CYAN, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Title
|
||||
try:
|
||||
stdscr.addnstr(0, 0, title, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
stdscr.addnstr(1, 0, " ↑↓ navigate ⏎ select q quit", max_x - 1,
|
||||
curses.color_pair(3) if curses.has_colors() else curses.A_DIM)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, (label, desc) in enumerate(items):
|
||||
y = i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {label}"
|
||||
if desc:
|
||||
line += f" {desc}"
|
||||
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line[:max_x - 1], max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(items)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(items)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_menu)
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Fallback: numbered input
|
||||
print(f"\n {title}\n")
|
||||
for i, (label, desc) in enumerate(items):
|
||||
marker = "→" if i == default else " "
|
||||
d = f" {desc}" if desc else ""
|
||||
print(f" {marker} {i + 1}. {label}{d}")
|
||||
while True:
|
||||
try:
|
||||
val = input(f"\n Select [1-{len(items)}] ({default + 1}): ")
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(items):
|
||||
return idx
|
||||
except (ValueError, EOFError):
|
||||
return default
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
# Format (label, desc) tuples into display strings
|
||||
display_items = [
|
||||
f"{label} {desc}" if desc else label
|
||||
for label, desc in items
|
||||
]
|
||||
return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
|
||||
|
||||
|
||||
def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ _MATCHING_PREFIX_STRIP_PROVIDERS: frozenset[str] = frozenset({
|
|||
"minimax-cn",
|
||||
"alibaba",
|
||||
"qwen-oauth",
|
||||
"xiaomi",
|
||||
"custom",
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -188,6 +188,11 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
],
|
||||
"xiaomi": [
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"mimo-v2-flash",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
|
|
@ -493,6 +498,7 @@ _PROVIDER_LABELS = {
|
|||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"qwen-oauth": "Qwen OAuth (Portal)",
|
||||
"huggingface": "Hugging Face",
|
||||
"xiaomi": "Xiaomi MiMo",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
|
||||
|
|
@ -535,6 +541,8 @@ _PROVIDER_ALIASES = {
|
|||
"hf": "huggingface",
|
||||
"hugging-face": "huggingface",
|
||||
"huggingface-hub": "huggingface",
|
||||
"mimo": "xiaomi",
|
||||
"xiaomi-mimo": "xiaomi",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -819,7 +827,7 @@ def list_available_providers() -> list[dict[str, str]]:
|
|||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"gemini", "huggingface",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba",
|
||||
"qwen-oauth",
|
||||
"qwen-oauth", "xiaomi",
|
||||
"opencode-zen", "opencode-go",
|
||||
"ai-gateway", "deepseek", "custom",
|
||||
]
|
||||
|
|
|
|||
45
hermes_cli/platforms.py
Normal file
45
hermes_cli/platforms.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
"""
|
||||
Shared platform registry for Hermes Agent.
|
||||
|
||||
Single source of truth for platform metadata consumed by both
|
||||
skills_config (label display) and tools_config (default toolset
|
||||
resolution). Import ``PLATFORMS`` from here instead of maintaining
|
||||
duplicate dicts in each module.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class PlatformInfo(NamedTuple):
|
||||
"""Metadata for a single platform entry."""
|
||||
label: str
|
||||
default_toolset: str
|
||||
|
||||
|
||||
# Ordered so that TUI menus are deterministic.
|
||||
PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
|
||||
("cli", PlatformInfo(label="🖥️ CLI", default_toolset="hermes-cli")),
|
||||
("telegram", PlatformInfo(label="📱 Telegram", default_toolset="hermes-telegram")),
|
||||
("discord", PlatformInfo(label="💬 Discord", default_toolset="hermes-discord")),
|
||||
("slack", PlatformInfo(label="💼 Slack", default_toolset="hermes-slack")),
|
||||
("whatsapp", PlatformInfo(label="📱 WhatsApp", default_toolset="hermes-whatsapp")),
|
||||
("signal", PlatformInfo(label="📡 Signal", default_toolset="hermes-signal")),
|
||||
("bluebubbles", PlatformInfo(label="💙 BlueBubbles", default_toolset="hermes-bluebubbles")),
|
||||
("email", PlatformInfo(label="📧 Email", default_toolset="hermes-email")),
|
||||
("homeassistant", PlatformInfo(label="🏠 Home Assistant", default_toolset="hermes-homeassistant")),
|
||||
("mattermost", PlatformInfo(label="💬 Mattermost", default_toolset="hermes-mattermost")),
|
||||
("matrix", PlatformInfo(label="💬 Matrix", default_toolset="hermes-matrix")),
|
||||
("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")),
|
||||
("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")),
|
||||
("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
|
||||
("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
|
||||
("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
|
||||
("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
|
||||
])
|
||||
|
||||
|
||||
def platform_label(key: str, default: str = "") -> str:
|
||||
"""Return the display label for a platform key, or *default*."""
|
||||
info = PLATFORMS.get(key)
|
||||
return info.label if info is not None else default
|
||||
|
|
@ -132,6 +132,10 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
|||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"xiaomi": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -222,6 +226,10 @@ ALIASES: Dict[str, str] = {
|
|||
"hugging-face": "huggingface",
|
||||
"huggingface-hub": "huggingface",
|
||||
|
||||
# xiaomi
|
||||
"mimo": "xiaomi",
|
||||
"xiaomi-mimo": "xiaomi",
|
||||
|
||||
# Local server aliases → virtual "local" concept (resolved via user config)
|
||||
"lmstudio": "lmstudio",
|
||||
"lm-studio": "lmstudio",
|
||||
|
|
@ -242,6 +250,7 @@ _LABEL_OVERRIDES: Dict[str, str] = {
|
|||
"nous": "Nous Portal",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"copilot-acp": "GitHub Copilot ACP",
|
||||
"xiaomi": "Xiaomi MiMo",
|
||||
"local": "Local endpoint",
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -304,6 +304,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
api_mode = _parse_api_mode(entry.get("api_mode"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
model_name = str(entry.get("model", "") or "").strip()
|
||||
if model_name:
|
||||
result["model"] = model_name
|
||||
return result
|
||||
|
||||
return None
|
||||
|
|
@ -329,6 +332,11 @@ def _resolve_named_custom_runtime(
|
|||
# Check if a credential pool exists for this custom endpoint
|
||||
pool_result = _try_resolve_from_custom_pool(base_url, "custom", custom_provider.get("api_mode"))
|
||||
if pool_result:
|
||||
# Propagate the model name even when using pooled credentials —
|
||||
# the pool doesn't know about the custom_providers model field.
|
||||
model_name = custom_provider.get("model")
|
||||
if model_name:
|
||||
pool_result["model"] = model_name
|
||||
return pool_result
|
||||
|
||||
api_key_candidates = [
|
||||
|
|
@ -339,7 +347,7 @@ def _resolve_named_custom_runtime(
|
|||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
return {
|
||||
result = {
|
||||
"provider": "custom",
|
||||
"api_mode": custom_provider.get("api_mode")
|
||||
or _detect_api_mode_for_url(base_url)
|
||||
|
|
@ -348,6 +356,11 @@ def _resolve_named_custom_runtime(
|
|||
"api_key": api_key or "no-key-required",
|
||||
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||
}
|
||||
# Propagate the model name so callers can override self.model when the
|
||||
# provider name differs from the actual model string the API expects.
|
||||
if custom_provider.get("model"):
|
||||
result["model"] = custom_provider["model"]
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_openrouter_runtime(
|
||||
|
|
|
|||
|
|
@ -197,24 +197,12 @@ def print_header(title: str):
|
|||
print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
|
||||
def print_info(text: str):
|
||||
"""Print info text."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
|
||||
def print_success(text: str):
|
||||
"""Print success message."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
|
||||
def print_warning(text: str):
|
||||
"""Print warning message."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
|
||||
def print_error(text: str):
|
||||
"""Print error message."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
from hermes_cli.cli_output import ( # noqa: E402
|
||||
print_error,
|
||||
print_info,
|
||||
print_success,
|
||||
print_warning,
|
||||
)
|
||||
|
||||
|
||||
def is_interactive_stdin() -> bool:
|
||||
|
|
@ -269,80 +257,9 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
|
|||
|
||||
|
||||
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu using curses to avoid simple_term_menu rendering bugs."""
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
scroll_offset = 0
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
|
||||
# Rows available for list items: rows 2..(max_y-2) inclusive.
|
||||
visible = max(1, max_y - 3)
|
||||
|
||||
# Scroll the viewport so the cursor is always visible.
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible:
|
||||
scroll_offset = cursor - visible + 1
|
||||
scroll_offset = max(0, min(scroll_offset, max(0, len(choices) - visible)))
|
||||
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
0,
|
||||
question,
|
||||
max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for row, i in enumerate(range(scroll_offset, min(scroll_offset + visible, len(choices)))):
|
||||
y = row + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {choices[i]}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
except Exception:
|
||||
return -1
|
||||
"""Single-select menu using curses. Delegates to curses_radiolist."""
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
return curses_radiolist(question, choices, selected=default, cancel_returns=-1)
|
||||
|
||||
|
||||
|
||||
|
|
@ -2052,6 +1969,42 @@ def _setup_weixin():
|
|||
_gateway_setup_weixin()
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Configure Signal via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_signal as _gateway_setup_signal
|
||||
_gateway_setup_signal()
|
||||
|
||||
|
||||
def _setup_email():
|
||||
"""Configure Email via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_email as _gateway_setup_email
|
||||
_gateway_setup_email()
|
||||
|
||||
|
||||
def _setup_sms():
|
||||
"""Configure SMS (Twilio) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_sms as _gateway_setup_sms
|
||||
_gateway_setup_sms()
|
||||
|
||||
|
||||
def _setup_dingtalk():
|
||||
"""Configure DingTalk via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk
|
||||
_gateway_setup_dingtalk()
|
||||
|
||||
|
||||
def _setup_feishu():
|
||||
"""Configure Feishu / Lark via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu
|
||||
_gateway_setup_feishu()
|
||||
|
||||
|
||||
def _setup_wecom():
|
||||
"""Configure WeCom (Enterprise WeChat) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom
|
||||
_gateway_setup_wecom()
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
"""Configure BlueBubbles iMessage gateway."""
|
||||
print_header("BlueBubbles (iMessage)")
|
||||
|
|
@ -2168,9 +2121,15 @@ _GATEWAY_PLATFORMS = [
|
|||
("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram),
|
||||
("Discord", "DISCORD_BOT_TOKEN", _setup_discord),
|
||||
("Slack", "SLACK_BOT_TOKEN", _setup_slack),
|
||||
("Signal", "SIGNAL_HTTP_URL", _setup_signal),
|
||||
("Email", "EMAIL_ADDRESS", _setup_email),
|
||||
("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms),
|
||||
("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix),
|
||||
("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost),
|
||||
("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
|
||||
("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk),
|
||||
("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu),
|
||||
("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom),
|
||||
("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
|
||||
("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
|
||||
("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
|
||||
|
|
@ -2212,10 +2171,17 @@ def setup_gateway(config: dict):
|
|||
get_env_value("TELEGRAM_BOT_TOKEN")
|
||||
or get_env_value("DISCORD_BOT_TOKEN")
|
||||
or get_env_value("SLACK_BOT_TOKEN")
|
||||
or get_env_value("SIGNAL_HTTP_URL")
|
||||
or get_env_value("EMAIL_ADDRESS")
|
||||
or get_env_value("TWILIO_ACCOUNT_SID")
|
||||
or get_env_value("MATTERMOST_TOKEN")
|
||||
or get_env_value("MATRIX_ACCESS_TOKEN")
|
||||
or get_env_value("MATRIX_PASSWORD")
|
||||
or get_env_value("WHATSAPP_ENABLED")
|
||||
or get_env_value("DINGTALK_CLIENT_ID")
|
||||
or get_env_value("FEISHU_APP_ID")
|
||||
or get_env_value("WECOM_BOT_ID")
|
||||
or get_env_value("WEIXIN_ACCOUNT_ID")
|
||||
or get_env_value("BLUEBUBBLES_SERVER_URL")
|
||||
or get_env_value("WEBHOOK_ENABLED")
|
||||
)
|
||||
|
|
@ -2404,12 +2370,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]
|
|||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if get_env_value("EMAIL_ADDRESS"):
|
||||
platforms.append("Email")
|
||||
if get_env_value("TWILIO_ACCOUNT_SID"):
|
||||
platforms.append("SMS")
|
||||
if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"):
|
||||
platforms.append("Matrix")
|
||||
if get_env_value("MATTERMOST_TOKEN"):
|
||||
platforms.append("Mattermost")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("DINGTALK_CLIENT_ID"):
|
||||
platforms.append("DingTalk")
|
||||
if get_env_value("FEISHU_APP_ID"):
|
||||
platforms.append("Feishu")
|
||||
if get_env_value("WECOM_BOT_ID"):
|
||||
platforms.append("WeCom")
|
||||
if get_env_value("WEIXIN_ACCOUNT_ID"):
|
||||
platforms.append("Weixin")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL"):
|
||||
platforms.append("BlueBubbles")
|
||||
if get_env_value("WEBHOOK_ENABLED"):
|
||||
platforms.append("Webhooks")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
|
|
|||
|
|
@ -15,25 +15,12 @@ from typing import List, Optional, Set
|
|||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_cli.platforms import PLATFORMS as _PLATFORMS, platform_label
|
||||
|
||||
PLATFORMS = {
|
||||
"cli": "🖥️ CLI",
|
||||
"telegram": "📱 Telegram",
|
||||
"discord": "💬 Discord",
|
||||
"slack": "💼 Slack",
|
||||
"whatsapp": "📱 WhatsApp",
|
||||
"signal": "📡 Signal",
|
||||
"bluebubbles": "💬 BlueBubbles",
|
||||
"email": "📧 Email",
|
||||
"homeassistant": "🏠 Home Assistant",
|
||||
"mattermost": "💬 Mattermost",
|
||||
"matrix": "💬 Matrix",
|
||||
"dingtalk": "💬 DingTalk",
|
||||
"feishu": "🪽 Feishu",
|
||||
"wecom": "💬 WeCom",
|
||||
"weixin": "💬 Weixin",
|
||||
"webhook": "🔗 Webhook",
|
||||
}
|
||||
# Backward-compatible view: {key: label_string} so existing code that
|
||||
# iterates ``PLATFORMS.items()`` or calls ``PLATFORMS.get(key)`` keeps
|
||||
# working without changes to every call site.
|
||||
PLATFORMS = {k: info.label for k, info in _PLATFORMS.items() if k != "api_server"}
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -141,11 +141,8 @@ def show_status(args):
|
|||
display = redact_key(value) if not show_all else value
|
||||
print(f" {name:<12} {check_mark(has_key)} {display}")
|
||||
|
||||
anthropic_value = (
|
||||
get_env_value("ANTHROPIC_TOKEN")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
or ""
|
||||
)
|
||||
from hermes_cli.auth import get_anthropic_key
|
||||
anthropic_value = get_anthropic_key()
|
||||
anthropic_display = redact_key(anthropic_value) if not show_all else anthropic_value
|
||||
print(f" {'Anthropic':<12} {check_mark(bool(anthropic_value))} {anthropic_display}")
|
||||
|
||||
|
|
|
|||
|
|
@ -33,33 +33,13 @@ PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
|||
|
||||
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
|
||||
|
||||
def _print_info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _print_success(text: str):
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
def _print_warning(text: str):
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _print_error(text: str):
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
else:
|
||||
display = f"{question}: "
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default or ""
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default or ""
|
||||
from hermes_cli.cli_output import ( # noqa: E402 — late import block
|
||||
print_error as _print_error,
|
||||
print_info as _print_info,
|
||||
print_success as _print_success,
|
||||
print_warning as _print_warning,
|
||||
prompt as _prompt,
|
||||
)
|
||||
|
||||
# ─── Toolset Registry ─────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -118,25 +98,14 @@ def _get_plugin_toolset_keys() -> set:
|
|||
except Exception:
|
||||
return set()
|
||||
|
||||
# Platform display config
|
||||
# Platform display config — derived from the canonical registry so every
|
||||
# module shares the same data. Kept as dict-of-dicts for backward
|
||||
# compatibility with existing ``PLATFORMS[key]["label"]`` access patterns.
|
||||
from hermes_cli.platforms import PLATFORMS as _PLATFORMS_REGISTRY
|
||||
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
"telegram": {"label": "📱 Telegram", "default_toolset": "hermes-telegram"},
|
||||
"discord": {"label": "💬 Discord", "default_toolset": "hermes-discord"},
|
||||
"slack": {"label": "💼 Slack", "default_toolset": "hermes-slack"},
|
||||
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"bluebubbles": {"label": "💙 BlueBubbles", "default_toolset": "hermes-bluebubbles"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"},
|
||||
"wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"},
|
||||
"weixin": {"label": "💬 Weixin", "default_toolset": "hermes-weixin"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
"mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"},
|
||||
"webhook": {"label": "🔗 Webhook", "default_toolset": "hermes-webhook"},
|
||||
k: {"label": info.label, "default_toolset": info.default_toolset}
|
||||
for k, info in _PLATFORMS_REGISTRY.items()
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -677,86 +646,9 @@ def _toolset_has_keys(ts_key: str, config: dict = None) -> bool:
|
|||
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
|
||||
rendering bugs in tmux, iTerm, and other non-standard terminals."""
|
||||
|
||||
# Curses-based single-select — works in tmux, iTerm, and standard terminals
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(0, 0, question, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, c in enumerate(choices):
|
||||
y = i + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {c}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
from hermes_cli.curses_ui import flush_stdin
|
||||
flush_stdin()
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: numbered input (Windows without curses, etc.)
|
||||
print(color(question, Colors.YELLOW))
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(choices):
|
||||
return idx
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
"""Single-select menu (arrow keys). Delegates to curses_radiolist."""
|
||||
from hermes_cli.curses_ui import curses_radiolist
|
||||
return curses_radiolist(question, choices, selected=default, cancel_returns=default)
|
||||
|
||||
|
||||
# ─── Token Estimation ────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -189,6 +189,33 @@ def is_wsl() -> bool:
|
|||
return _wsl_detected
|
||||
|
||||
|
||||
# ─── Well-Known Paths ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Return the path to ``config.yaml`` under HERMES_HOME.
|
||||
|
||||
Replaces the ``get_hermes_home() / "config.yaml"`` pattern repeated
|
||||
in 7+ files (skill_utils.py, hermes_logging.py, hermes_time.py, etc.).
|
||||
"""
|
||||
return get_hermes_home() / "config.yaml"
|
||||
|
||||
|
||||
def get_skills_dir() -> Path:
|
||||
"""Return the path to the skills directory under HERMES_HOME."""
|
||||
return get_hermes_home() / "skills"
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""Return the path to the logs directory under HERMES_HOME."""
|
||||
return get_hermes_home() / "logs"
|
||||
|
||||
|
||||
def get_env_path() -> Path:
|
||||
"""Return the path to the ``.env`` file under HERMES_HOME."""
|
||||
return get_hermes_home() / ".env"
|
||||
|
||||
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
OPENROUTER_MODELS_URL = f"{OPENROUTER_BASE_URL}/models"
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from logging.handlers import RotatingFileHandler
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path, get_hermes_home
|
||||
|
||||
# Sentinel to track whether setup_logging() has already run. The function
|
||||
# is idempotent — calling it twice is safe but the second call is a no-op
|
||||
|
|
@ -246,7 +246,7 @@ def _read_logging_config():
|
|||
"""
|
||||
try:
|
||||
import yaml
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ crashes due to a bad timezone string.
|
|||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_config_path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -48,8 +48,7 @@ def _resolve_timezone_name() -> str:
|
|||
# 2. config.yaml ``timezone`` key
|
||||
try:
|
||||
import yaml
|
||||
hermes_home = get_hermes_home()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
|
|
|||
12
nix/tui.nix
12
nix/tui.nix
|
|
@ -4,7 +4,7 @@ let
|
|||
src = ../ui-tui;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-QQixyLmsn5+Y1daHifzDaNQbaoZjm+ezGrGoLXcc95U=";
|
||||
hash = "sha256-+EhRRuvXi5hJupseHblF+MGxs84ijRMIH4qt5+2yYi8=";
|
||||
};
|
||||
|
||||
packageJson = builtins.fromJSON (builtins.readFile (src + "/package.json"));
|
||||
|
|
@ -28,6 +28,10 @@ pkgs.buildNpmPackage {
|
|||
# runtime node_modules
|
||||
cp -r node_modules $out/lib/hermes-tui/node_modules
|
||||
|
||||
# @hermes/ink is a file: dependency, we need to copy it in fr
|
||||
rm -f $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
cp -r packages/hermes-ink $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
|
||||
# package.json needed for "type": "module" resolution
|
||||
cp package.json $out/lib/hermes-tui/
|
||||
|
||||
|
|
@ -36,7 +40,7 @@ pkgs.buildNpmPackage {
|
|||
|
||||
nativeBuildInputs = [
|
||||
(pkgs.writeShellScriptBin "update_tui_lockfile" ''
|
||||
set -euo pipefail
|
||||
set -euox pipefail
|
||||
|
||||
# get root of repo
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
|
@ -45,7 +49,7 @@ pkgs.buildNpmPackage {
|
|||
cd "$REPO_ROOT/ui-tui"
|
||||
rm -rf node_modules/
|
||||
npm cache clean --force
|
||||
npm install
|
||||
CI=true npm install # ci env var to suppress annoying unicode install banner lag
|
||||
${pkgs.lib.getExe npm-lockfile-fix} ./package-lock.json
|
||||
|
||||
NIX_FILE="$REPO_ROOT/nix/tui.nix"
|
||||
|
|
@ -65,7 +69,7 @@ pkgs.buildNpmPackage {
|
|||
STAMP_VALUE="${npmLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-tui: installing npm dependencies..."
|
||||
cd ui-tui && npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
cd ui-tui && CI=true npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -617,6 +617,19 @@ class Migrator:
|
|||
candidate = self.source_root / rel
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
# OpenClaw renamed workspace/ to workspace-main/ (and workspace-{agentId}
|
||||
# for multi-agent). Try the new path as a fallback.
|
||||
if rel.startswith("workspace/"):
|
||||
suffix = rel[len("workspace/"):]
|
||||
for variant in ("workspace-main", "workspace-assistant"):
|
||||
alt = self.source_root / variant / suffix
|
||||
if alt.exists():
|
||||
return alt
|
||||
elif rel.startswith("workspace.default/"):
|
||||
suffix = rel[len("workspace.default/"):]
|
||||
alt = self.source_root / "workspace-main" / suffix
|
||||
if alt.exists():
|
||||
return alt
|
||||
return None
|
||||
|
||||
def resolve_skill_destination(self, destination: Path) -> Path:
|
||||
|
|
@ -1033,11 +1046,8 @@ class Migrator:
|
|||
def migrate_secret_settings(self, config: Dict[str, Any]) -> None:
|
||||
secret_additions: Dict[str, str] = {}
|
||||
|
||||
telegram_token = (
|
||||
config.get("channels", {})
|
||||
.get("telegram", {})
|
||||
.get("botToken")
|
||||
)
|
||||
tg_cfg = config.get("channels", {}).get("telegram", {})
|
||||
telegram_token = self._get_channel_field(tg_cfg, "botToken") if isinstance(tg_cfg, dict) else None
|
||||
if isinstance(telegram_token, str) and telegram_token.strip():
|
||||
secret_additions["TELEGRAM_BOT_TOKEN"] = telegram_token.strip()
|
||||
|
||||
|
|
@ -1057,15 +1067,28 @@ class Migrator:
|
|||
"""Resolve a channel config value that may be a SecretRef."""
|
||||
return resolve_secret_input(value, self.load_openclaw_env())
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_field(ch_cfg: Dict[str, Any], field: str) -> Any:
|
||||
"""Get a field from channel config, checking both flat and accounts.default layout."""
|
||||
val = ch_cfg.get(field)
|
||||
if val is not None:
|
||||
return val
|
||||
accounts = ch_cfg.get("accounts")
|
||||
if isinstance(accounts, dict):
|
||||
default = accounts.get("default")
|
||||
if isinstance(default, dict):
|
||||
return default.get(field)
|
||||
return None
|
||||
|
||||
def migrate_discord_settings(self, config: Optional[Dict[str, Any]] = None) -> None:
|
||||
config = config or self.load_openclaw_config()
|
||||
additions: Dict[str, str] = {}
|
||||
discord = config.get("channels", {}).get("discord", {})
|
||||
if isinstance(discord, dict):
|
||||
token = discord.get("token")
|
||||
token = self._get_channel_field(discord, "token")
|
||||
if isinstance(token, str) and token.strip():
|
||||
additions["DISCORD_BOT_TOKEN"] = token.strip()
|
||||
allow_from = discord.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(discord, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
|
|
@ -1080,13 +1103,13 @@ class Migrator:
|
|||
additions: Dict[str, str] = {}
|
||||
slack = config.get("channels", {}).get("slack", {})
|
||||
if isinstance(slack, dict):
|
||||
bot_token = slack.get("botToken")
|
||||
bot_token = self._get_channel_field(slack, "botToken")
|
||||
if isinstance(bot_token, str) and bot_token.strip():
|
||||
additions["SLACK_BOT_TOKEN"] = bot_token.strip()
|
||||
app_token = slack.get("appToken")
|
||||
app_token = self._get_channel_field(slack, "appToken")
|
||||
if isinstance(app_token, str) and app_token.strip():
|
||||
additions["SLACK_APP_TOKEN"] = app_token.strip()
|
||||
allow_from = slack.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(slack, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
|
|
@ -1101,7 +1124,7 @@ class Migrator:
|
|||
additions: Dict[str, str] = {}
|
||||
whatsapp = config.get("channels", {}).get("whatsapp", {})
|
||||
if isinstance(whatsapp, dict):
|
||||
allow_from = whatsapp.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(whatsapp, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
|
|
@ -1116,13 +1139,13 @@ class Migrator:
|
|||
additions: Dict[str, str] = {}
|
||||
signal = config.get("channels", {}).get("signal", {})
|
||||
if isinstance(signal, dict):
|
||||
account = signal.get("account")
|
||||
account = self._get_channel_field(signal, "account")
|
||||
if isinstance(account, str) and account.strip():
|
||||
additions["SIGNAL_ACCOUNT"] = account.strip()
|
||||
http_url = signal.get("httpUrl")
|
||||
http_url = self._get_channel_field(signal, "httpUrl")
|
||||
if isinstance(http_url, str) and http_url.strip():
|
||||
additions["SIGNAL_HTTP_URL"] = http_url.strip()
|
||||
allow_from = signal.get("allowFrom", [])
|
||||
allow_from = self._get_channel_field(signal, "allowFrom") or []
|
||||
if isinstance(allow_from, list):
|
||||
users = [str(u).strip() for u in allow_from if str(u).strip()]
|
||||
if users:
|
||||
|
|
@ -1161,6 +1184,16 @@ class Migrator:
|
|||
raw_key = provider_cfg.get("apiKey")
|
||||
api_key = resolve_secret_input(raw_key, openclaw_env)
|
||||
if not api_key:
|
||||
# Warn if a SecretRef with file/exec source was silently unresolvable
|
||||
if isinstance(raw_key, dict) and raw_key.get("source") in ("file", "exec"):
|
||||
self.record(
|
||||
"provider-keys",
|
||||
self.source_root / "openclaw.json",
|
||||
None,
|
||||
"skipped",
|
||||
f"Provider '{provider_name}' uses a {raw_key['source']}-backed SecretRef "
|
||||
f"that cannot be auto-migrated. Add this key manually via: hermes config set",
|
||||
)
|
||||
continue
|
||||
|
||||
base_url = provider_cfg.get("baseUrl", "")
|
||||
|
|
@ -1224,6 +1257,21 @@ class Migrator:
|
|||
if val and hermes_key not in secret_additions:
|
||||
secret_additions[hermes_key] = val
|
||||
|
||||
# Check the openclaw.json "env" sub-object — some OpenClaw setups
|
||||
# store API keys here instead of in a separate .env file.
|
||||
# Keys can be at env.<KEY> or env.vars.<KEY>.
|
||||
json_env = config.get("env")
|
||||
if isinstance(json_env, dict):
|
||||
env_vars = json_env.get("vars")
|
||||
sources = [json_env]
|
||||
if isinstance(env_vars, dict):
|
||||
sources.append(env_vars)
|
||||
for src in sources:
|
||||
for oc_key, hermes_key in env_key_mapping.items():
|
||||
val = src.get(oc_key)
|
||||
if isinstance(val, str) and val.strip() and hermes_key not in secret_additions:
|
||||
secret_additions[hermes_key] = val.strip()
|
||||
|
||||
# Check per-agent auth-profiles.json for additional credentials
|
||||
auth_profiles_path = self.source_root / "agents" / "main" / "agent" / "auth-profiles.json"
|
||||
if auth_profiles_path.exists():
|
||||
|
|
@ -1324,8 +1372,9 @@ class Migrator:
|
|||
tts_data: Dict[str, Any] = {}
|
||||
|
||||
provider = tts.get("provider")
|
||||
if isinstance(provider, str) and provider in ("elevenlabs", "openai", "edge"):
|
||||
tts_data["provider"] = provider
|
||||
if isinstance(provider, str) and provider in ("elevenlabs", "openai", "edge", "microsoft"):
|
||||
# OpenClaw renamed "edge" to "microsoft"; Hermes still uses "edge"
|
||||
tts_data["provider"] = "edge" if provider == "microsoft" else provider
|
||||
|
||||
# TTS provider settings live under messages.tts.providers.{provider}
|
||||
# in OpenClaw (not messages.tts.elevenlabs directly)
|
||||
|
|
@ -1374,9 +1423,9 @@ class Migrator:
|
|||
tts_data["openai"] = oai_settings
|
||||
|
||||
edge_tts = (
|
||||
(providers.get("edge") or {})
|
||||
if isinstance(providers.get("edge"), dict) else
|
||||
(tts.get("edge") or {})
|
||||
(providers.get("edge") or providers.get("microsoft") or {})
|
||||
if isinstance(providers.get("edge"), dict) or isinstance(providers.get("microsoft"), dict) else
|
||||
(tts.get("edge") or tts.get("microsoft") or {})
|
||||
)
|
||||
if isinstance(edge_tts, dict):
|
||||
edge_voice = edge_tts.get("voice")
|
||||
|
|
@ -1890,11 +1939,11 @@ class Migrator:
|
|||
if defaults.get("thinkingDefault"):
|
||||
# Map OpenClaw thinking -> Hermes reasoning_effort
|
||||
thinking = defaults["thinkingDefault"]
|
||||
if thinking in ("always", "high"):
|
||||
if thinking in ("always", "high", "xhigh"):
|
||||
agent_cfg["reasoning_effort"] = "high"
|
||||
elif thinking in ("auto", "medium"):
|
||||
elif thinking in ("auto", "medium", "adaptive"):
|
||||
agent_cfg["reasoning_effort"] = "medium"
|
||||
elif thinking in ("off", "low", "none"):
|
||||
elif thinking in ("off", "low", "none", "minimal"):
|
||||
agent_cfg["reasoning_effort"] = "low"
|
||||
changes = True
|
||||
|
||||
|
|
@ -2099,10 +2148,14 @@ class Migrator:
|
|||
f"Provider '{prov_name}' already exists")
|
||||
continue
|
||||
|
||||
api_type = prov_cfg.get("apiType") or prov_cfg.get("type") or "openai"
|
||||
api_type = prov_cfg.get("apiType") or prov_cfg.get("api") or prov_cfg.get("type") or "openai"
|
||||
api_mode_map = {
|
||||
"openai": "chat_completions",
|
||||
"openai-completions": "chat_completions",
|
||||
"openai-responses": "chat_completions",
|
||||
"anthropic": "anthropic_messages",
|
||||
"anthropic-messages": "anthropic_messages",
|
||||
"google-generative-ai": "chat_completions",
|
||||
"cohere": "chat_completions",
|
||||
}
|
||||
entry = {
|
||||
|
|
@ -2142,7 +2195,7 @@ class Migrator:
|
|||
|
||||
# Extended channel token/allowlist mapping
|
||||
CHANNEL_ENV_MAP = {
|
||||
"matrix": {"token": "MATRIX_ACCESS_TOKEN", "allowFrom": "MATRIX_ALLOWED_USERS",
|
||||
"matrix": {"token": "MATRIX...OKEN", "tokenField": "accessToken", "allowFrom": "MATRIX_ALLOWED_USERS",
|
||||
"extras": {"homeserverUrl": "MATRIX_HOMESERVER_URL", "userId": "MATRIX_USER_ID"}},
|
||||
"mattermost": {"token": "MATTERMOST_BOT_TOKEN", "allowFrom": "MATTERMOST_ALLOWED_USERS",
|
||||
"extras": {"url": "MATTERMOST_URL", "teamId": "MATTERMOST_TEAM_ID"}},
|
||||
|
|
@ -2160,19 +2213,21 @@ class Migrator:
|
|||
if not ch_cfg:
|
||||
continue
|
||||
|
||||
# Extract tokens
|
||||
if ch_mapping.get("token") and ch_cfg.get("botToken") and self.migrate_secrets:
|
||||
self._set_env_var(ch_mapping["token"], ch_cfg["botToken"],
|
||||
f"channels.{ch_name}.botToken")
|
||||
if ch_mapping.get("allowFrom") and ch_cfg.get("allowFrom"):
|
||||
allow_val = ch_cfg["allowFrom"]
|
||||
# Extract tokens (check flat path, then accounts.default)
|
||||
token_field = ch_mapping.get("tokenField", "botToken")
|
||||
bot_token = self._get_channel_field(ch_cfg, token_field)
|
||||
if ch_mapping.get("token") and bot_token and self.migrate_secrets:
|
||||
self._set_env_var(ch_mapping["token"], str(bot_token),
|
||||
f"channels.{ch_name}.{token_field}")
|
||||
allow_val = self._get_channel_field(ch_cfg, "allowFrom")
|
||||
if ch_mapping.get("allowFrom") and allow_val:
|
||||
if isinstance(allow_val, list):
|
||||
allow_val = ",".join(str(x) for x in allow_val)
|
||||
self._set_env_var(ch_mapping["allowFrom"], str(allow_val),
|
||||
f"channels.{ch_name}.allowFrom")
|
||||
# Extra fields
|
||||
for oc_key, env_key in (ch_mapping.get("extras") or {}).items():
|
||||
val = ch_cfg.get(oc_key)
|
||||
val = self._get_channel_field(ch_cfg, oc_key)
|
||||
if val:
|
||||
if isinstance(val, list):
|
||||
val = ",".join(str(x) for x in val)
|
||||
|
|
@ -2495,6 +2550,33 @@ class Migrator:
|
|||
elif has_cron_store_archive:
|
||||
notes.append("- Run `hermes cron` to recreate scheduled tasks (see archived cron-store)")
|
||||
|
||||
# Check if skills were imported
|
||||
has_skills = any(i.kind == "skills" and i.status == "migrated" for i in self.items)
|
||||
if has_skills:
|
||||
notes.extend([
|
||||
"",
|
||||
"## Imported Skills",
|
||||
"",
|
||||
"Imported skills require a new session to take effect. After migration,",
|
||||
"restart your agent or start a new chat session, then run `/skills`",
|
||||
"to verify they loaded correctly.",
|
||||
"",
|
||||
])
|
||||
|
||||
# Check if WhatsApp was detected
|
||||
has_whatsapp = any(i.kind == "whatsapp-settings" and i.status == "migrated" for i in self.items)
|
||||
if has_whatsapp:
|
||||
notes.extend([
|
||||
"",
|
||||
"## WhatsApp Requires Re-Pairing",
|
||||
"",
|
||||
"WhatsApp uses QR-code pairing, not token-based auth. Your allowlist",
|
||||
"was migrated, but you must re-pair the device by running:",
|
||||
"",
|
||||
" hermes whatsapp",
|
||||
"",
|
||||
])
|
||||
|
||||
notes.extend([
|
||||
"- Run `hermes gateway install` if you need the gateway service",
|
||||
"- Review `~/.hermes/config.yaml` for any adjustments",
|
||||
|
|
|
|||
190
run_agent.py
190
run_agent.py
|
|
@ -739,6 +739,7 @@ class AIAgent:
|
|||
# Interrupt mechanism for breaking out of tool loops
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None # Optional message that triggered interrupt
|
||||
self._execution_thread_id: int | None = None # Set at run_conversation() start
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# Subagent delegation state
|
||||
|
|
@ -1406,6 +1407,12 @@ class AIAgent:
|
|||
else:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
|
||||
|
||||
# Check immediately so CLI users see the warning at startup.
|
||||
# Gateway status_callback is not yet wired, so any warning is stored
|
||||
# in _compression_warning and replayed in the first run_conversation().
|
||||
self._compression_warning = None
|
||||
self._check_compression_model_feasibility()
|
||||
|
||||
# Snapshot primary runtime for per-turn restoration. When fallback
|
||||
# activates during a turn, the next turn restores these values so the
|
||||
# preferred model gets a fresh attempt each time. Uses a single dict
|
||||
|
|
@ -1697,6 +1704,104 @@ class AIAgent:
|
|||
except Exception:
|
||||
logger.debug("status_callback error in _emit_status", exc_info=True)
|
||||
|
||||
def _check_compression_model_feasibility(self) -> None:
|
||||
"""Warn at session start if the auxiliary compression model's context
|
||||
window is smaller than the main model's compression threshold.
|
||||
|
||||
When the auxiliary model cannot fit the content that needs summarising,
|
||||
compression will either fail outright (the LLM call errors) or produce
|
||||
a severely truncated summary.
|
||||
|
||||
Called during ``__init__`` so CLI users see the warning immediately
|
||||
(via ``_vprint``). The gateway sets ``status_callback`` *after*
|
||||
construction, so ``_replay_compression_warning()`` re-sends the
|
||||
stored warning through the callback on the first
|
||||
``run_conversation()`` call.
|
||||
"""
|
||||
if not self.compression_enabled:
|
||||
return
|
||||
try:
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
from agent.model_metadata import get_model_context_length
|
||||
|
||||
client, aux_model = get_text_auxiliary_client("compression")
|
||||
if client is None or not aux_model:
|
||||
msg = (
|
||||
"⚠ No auxiliary LLM provider configured — context "
|
||||
"compression will drop middle turns without a summary. "
|
||||
"Run `hermes setup` or set OPENROUTER_API_KEY."
|
||||
)
|
||||
self._compression_warning = msg
|
||||
self._emit_status(msg)
|
||||
logger.warning(
|
||||
"No auxiliary LLM provider for compression — "
|
||||
"summaries will be unavailable."
|
||||
)
|
||||
return
|
||||
|
||||
aux_base_url = str(getattr(client, "base_url", ""))
|
||||
aux_api_key = str(getattr(client, "api_key", ""))
|
||||
aux_context = get_model_context_length(
|
||||
aux_model,
|
||||
base_url=aux_base_url,
|
||||
api_key=aux_api_key,
|
||||
)
|
||||
|
||||
threshold = self.context_compressor.threshold_tokens
|
||||
if aux_context < threshold:
|
||||
# Suggest a threshold that would fit the aux model,
|
||||
# rounded down to a clean percentage.
|
||||
safe_pct = int((aux_context / self.context_compressor.context_length) * 100)
|
||||
msg = (
|
||||
f"⚠ Compression model ({aux_model}) context "
|
||||
f"is {aux_context:,} tokens, but the main model's "
|
||||
f"compression threshold is {threshold:,} tokens. "
|
||||
f"Context compression will not be possible — the "
|
||||
f"content to summarise will exceed the auxiliary "
|
||||
f"model's context window.\n"
|
||||
f" Fix options (config.yaml):\n"
|
||||
f" 1. Use a larger compression model:\n"
|
||||
f" auxiliary:\n"
|
||||
f" compression:\n"
|
||||
f" model: <model-with-{threshold:,}+-context>\n"
|
||||
f" 2. Lower the compression threshold to fit "
|
||||
f"the current model:\n"
|
||||
f" compression:\n"
|
||||
f" threshold: 0.{safe_pct:02d}"
|
||||
)
|
||||
self._compression_warning = msg
|
||||
self._emit_status(msg)
|
||||
logger.warning(
|
||||
"Auxiliary compression model %s has %d token context, "
|
||||
"below the main model's compression threshold of %d "
|
||||
"tokens — compression summaries will fail or be "
|
||||
"severely truncated.",
|
||||
aux_model,
|
||||
aux_context,
|
||||
threshold,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Compression feasibility check failed (non-fatal): %s", exc
|
||||
)
|
||||
|
||||
def _replay_compression_warning(self) -> None:
|
||||
"""Re-send the compression warning through ``status_callback``.
|
||||
|
||||
During ``__init__`` the gateway's ``status_callback`` is not yet
|
||||
wired, so ``_emit_status`` only reaches ``_vprint`` (CLI). This
|
||||
method is called once at the start of the first
|
||||
``run_conversation()`` — by then the gateway has set the callback,
|
||||
so every platform (Telegram, Discord, Slack, etc.) receives the
|
||||
warning.
|
||||
"""
|
||||
msg = getattr(self, "_compression_warning", None)
|
||||
if msg and self.status_callback:
|
||||
try:
|
||||
self.status_callback("lifecycle", msg)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _is_direct_openai_url(self, base_url: str = None) -> bool:
|
||||
"""Return True when a base URL targets OpenAI's native API."""
|
||||
url = (base_url or self._base_url_lower).lower()
|
||||
|
|
@ -2728,8 +2833,10 @@ class AIAgent:
|
|||
"""
|
||||
self._interrupt_requested = True
|
||||
self._interrupt_message = message
|
||||
# Signal all tools to abort any in-flight operations immediately
|
||||
_set_interrupt(True)
|
||||
# Signal all tools to abort any in-flight operations immediately.
|
||||
# Scope the interrupt to this agent's execution thread so other
|
||||
# agents running in the same process (gateway) are not affected.
|
||||
_set_interrupt(True, self._execution_thread_id)
|
||||
# Propagate interrupt to any running child agents (subagent delegation)
|
||||
with self._active_children_lock:
|
||||
children_copy = list(self._active_children)
|
||||
|
|
@ -2742,10 +2849,10 @@ class AIAgent:
|
|||
print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||
|
||||
def clear_interrupt(self) -> None:
|
||||
"""Clear any pending interrupt request and the global tool interrupt signal."""
|
||||
"""Clear any pending interrupt request and the per-thread tool interrupt signal."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
_set_interrupt(False)
|
||||
_set_interrupt(False, self._execution_thread_id)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
|
|
@ -3339,6 +3446,7 @@ class AIAgent:
|
|||
def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert internal chat-style messages to Responses input items."""
|
||||
items: List[Dict[str, Any]] = []
|
||||
seen_item_ids: set = set()
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
|
|
@ -3359,7 +3467,12 @@ class AIAgent:
|
|||
if isinstance(codex_reasoning, list):
|
||||
for ri in codex_reasoning:
|
||||
if isinstance(ri, dict) and ri.get("encrypted_content"):
|
||||
item_id = ri.get("id")
|
||||
if item_id and item_id in seen_item_ids:
|
||||
continue
|
||||
items.append(ri)
|
||||
if item_id:
|
||||
seen_item_ids.add(item_id)
|
||||
has_codex_reasoning = True
|
||||
|
||||
if content_text.strip():
|
||||
|
|
@ -3439,6 +3552,7 @@ class AIAgent:
|
|||
raise ValueError("Codex Responses input must be a list of input items.")
|
||||
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
seen_ids: set = set()
|
||||
for idx, item in enumerate(raw_items):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"Codex Responses input[{idx}] must be an object.")
|
||||
|
|
@ -3491,8 +3605,12 @@ class AIAgent:
|
|||
if item_type == "reasoning":
|
||||
encrypted = item.get("encrypted_content")
|
||||
if isinstance(encrypted, str) and encrypted:
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
item_id = item.get("id")
|
||||
if isinstance(item_id, str) and item_id:
|
||||
if item_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(item_id)
|
||||
reasoning_item = {"type": "reasoning", "encrypted_content": encrypted}
|
||||
if isinstance(item_id, str) and item_id:
|
||||
reasoning_item["id"] = item_id
|
||||
summary = item.get("summary")
|
||||
|
|
@ -7469,6 +7587,12 @@ class AIAgent:
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
# Replay compression warning through status_callback for gateway
|
||||
# platforms (the callback was not wired during __init__).
|
||||
if self._compression_warning:
|
||||
self._replay_compression_warning()
|
||||
self._compression_warning = None # send once
|
||||
|
||||
# NOTE: _turns_since_memory and _iters_since_skill are NOT reset here.
|
||||
# They are initialized in __init__ and must persist across run_conversation
|
||||
# calls so that nudge logic accumulates correctly in CLI mode.
|
||||
|
|
@ -7690,6 +7814,11 @@ class AIAgent:
|
|||
compression_attempts = 0
|
||||
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
|
||||
|
||||
# Record the execution thread so interrupt()/clear_interrupt() can
|
||||
# scope the tool-level interrupt signal to THIS agent's thread only.
|
||||
# Must be set before clear_interrupt() which uses it.
|
||||
self._execution_thread_id = threading.current_thread().ident
|
||||
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
|
||||
|
|
@ -8168,8 +8297,24 @@ class AIAgent:
|
|||
_text_parts.append(getattr(_blk, "text", ""))
|
||||
_trunc_content = "\n".join(_text_parts) if _text_parts else None
|
||||
|
||||
# A response is "thinking exhausted" only when the model
|
||||
# actually produced reasoning blocks but no visible text after
|
||||
# them. Models that do not use <think> tags (e.g. GLM-4.7 on
|
||||
# NVIDIA Build, minimax) may return content=None or an empty
|
||||
# string for unrelated reasons — treat those as normal
|
||||
# truncations that deserve continuation retries, not as
|
||||
# thinking-budget exhaustion.
|
||||
_has_think_tags = bool(
|
||||
_trunc_content and re.search(
|
||||
r'<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)[^>]*>',
|
||||
_trunc_content,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
)
|
||||
_thinking_exhausted = (
|
||||
not _trunc_has_tool_calls and (
|
||||
not _trunc_has_tool_calls
|
||||
and _has_think_tags
|
||||
and (
|
||||
(_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
|
||||
or _trunc_content is None
|
||||
)
|
||||
|
|
@ -9397,12 +9542,41 @@ class AIAgent:
|
|||
invalid_json_args.append((tc.function.name, str(e)))
|
||||
|
||||
if invalid_json_args:
|
||||
# Check if the invalid JSON is due to truncation rather
|
||||
# than a model formatting mistake. Routers sometimes
|
||||
# rewrite finish_reason from "length" to "tool_calls",
|
||||
# hiding the truncation from the length handler above.
|
||||
# Detect truncation: args that don't end with } or ]
|
||||
# (after stripping whitespace) are cut off mid-stream.
|
||||
_truncated = any(
|
||||
not (tc.function.arguments or "").rstrip().endswith(("}", "]"))
|
||||
for tc in assistant_message.tool_calls
|
||||
if tc.function.name in {n for n, _ in invalid_json_args}
|
||||
)
|
||||
if _truncated:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Truncated tool call arguments detected "
|
||||
f"(finish_reason={finish_reason!r}) — refusing to execute.",
|
||||
force=True,
|
||||
)
|
||||
self._invalid_json_retries = 0
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Response truncated due to output length limit",
|
||||
}
|
||||
|
||||
# Track retries for invalid JSON arguments
|
||||
self._invalid_json_retries += 1
|
||||
|
||||
|
||||
tool_name, error_msg = invalid_json_args[0]
|
||||
self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
|
||||
|
||||
|
||||
if self._invalid_json_retries < 3:
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
|
||||
# Don't add anything to messages, just retry the API call
|
||||
|
|
|
|||
15
scripts/whatsapp-bridge/package-lock.json
generated
15
scripts/whatsapp-bridge/package-lock.json
generated
|
|
@ -8,7 +8,7 @@
|
|||
"name": "hermes-whatsapp-bridge",
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
|
||||
"express": "^4.21.0",
|
||||
"pino": "^9.0.0",
|
||||
"qrcode-terminal": "^0.12.0"
|
||||
|
|
@ -730,21 +730,22 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@whiskeysockets/baileys": {
|
||||
"name": "baileys",
|
||||
"version": "7.0.0-rc.9",
|
||||
"resolved": "https://registry.npmjs.org/@whiskeysockets/baileys/-/baileys-7.0.0-rc.9.tgz",
|
||||
"integrity": "sha512-YFm5gKXfDP9byCXCW3OPHKXLzrAKzolzgVUlRosHHgwbnf2YOO3XknkMm6J7+F0ns8OA0uuSBhgkRHTDtqkacw==",
|
||||
"resolved": "git+ssh://git@github.com/WhiskeySockets/Baileys.git#01047debd81beb20da7b7779b08edcb06aa03770",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@cacheable/node-cache": "^1.4.0",
|
||||
"@hapi/boom": "^9.1.3",
|
||||
"async-mutex": "^0.5.0",
|
||||
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node.git",
|
||||
"libsignal": "git+https://github.com/whiskeysockets/libsignal-node",
|
||||
"lru-cache": "^11.1.0",
|
||||
"music-metadata": "^11.7.0",
|
||||
"p-queue": "^9.0.0",
|
||||
"pino": "^9.6",
|
||||
"protobufjs": "^7.2.4",
|
||||
"whatsapp-rust-bridge": "0.5.2",
|
||||
"ws": "^8.13.0"
|
||||
},
|
||||
"engines": {
|
||||
|
|
@ -2125,6 +2126,12 @@
|
|||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/whatsapp-rust-bridge": {
|
||||
"version": "0.5.2",
|
||||
"resolved": "https://registry.npmjs.org/whatsapp-rust-bridge/-/whatsapp-rust-bridge-0.5.2.tgz",
|
||||
"integrity": "sha512-6KBRNvxg6WMIwZ/euA8qVzj16qxMBzLllfmaJIP1JGAAfSvwn6nr8JDOMXeqpXPEOl71UfOG+79JwKEoT2b1Fw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/win-guid": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/win-guid/-/win-guid-0.2.1.tgz",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
"start": "node bridge.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"@whiskeysockets/baileys": "WhiskeySockets/Baileys#fix/abprops-abt-fetch",
|
||||
"express": "^4.21.0",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout:
|
|||
"http://0.0.0.0:5000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_local_endpoint_bumps_read_timeout(self, base_url):
|
||||
"""Local endpoint + default timeout -> bumps to base_timeout."""
|
||||
|
|
@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout:
|
|||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 120.0
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
"""Direct unit tests for is_local_endpoint."""
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://localhost:11434",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://0.0.0.0:5000",
|
||||
"http://[::1]:11434",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://172.17.0.1:11434",
|
||||
])
|
||||
def test_classic_local_addresses(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.docker.internal:8080/v1",
|
||||
"http://gateway.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_container_dns_names(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://api.openai.com",
|
||||
"https://openrouter.ai/api",
|
||||
"https://api.anthropic.com",
|
||||
"https://evil.docker.internal.example.com",
|
||||
])
|
||||
def test_remote_endpoints(self, url):
|
||||
assert is_local_endpoint(url) is False
|
||||
|
|
|
|||
|
|
@ -211,7 +211,8 @@ def make_adapter(platform: Platform, runner=None):
|
|||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
|
||||
if platform == Platform.DISCORD:
|
||||
with patch.object(DiscordAdapter, "_load_participated_threads", return_value=set()):
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
with patch.object(ThreadParticipationTracker, "_load", return_value=set()):
|
||||
adapter = DiscordAdapter(config)
|
||||
platform_key = Platform.DISCORD
|
||||
elif platform == Platform.SLACK:
|
||||
|
|
|
|||
|
|
@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint:
|
|||
)
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" in resp.headers.get("Content-Type", "")
|
||||
assert resp.headers.get("X-Accel-Buffering") == "no"
|
||||
body = await resp.text()
|
||||
assert "data: " in body
|
||||
assert "[DONE]" in body
|
||||
assert "Hello!" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
||||
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
||||
import asyncio
|
||||
import gateway.platforms.api_server as api_server_mod
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
if cb:
|
||||
cb("Working")
|
||||
await asyncio.sleep(0.65)
|
||||
cb("...done")
|
||||
return (
|
||||
{"final_response": "Working...done", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01),
|
||||
patch.object(adapter, "_run_agent", side_effect=_mock_run_agent),
|
||||
):
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "do the thing"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
assert ": keepalive" in body
|
||||
assert "Working" in body
|
||||
assert "...done" in body
|
||||
assert "[DONE]" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_survives_tool_call_none_sentinel(self, adapter):
|
||||
"""stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream.
|
||||
|
|
|
|||
|
|
@ -119,28 +119,29 @@ class TestDeduplication:
|
|||
def test_first_message_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
assert adapter._is_duplicate("msg-1") is False
|
||||
assert adapter._dedup.is_duplicate("msg-1") is False
|
||||
|
||||
def test_second_same_message_is_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-1") is True
|
||||
adapter._dedup.is_duplicate("msg-1")
|
||||
assert adapter._dedup.is_duplicate("msg-1") is True
|
||||
|
||||
def test_different_messages_not_duplicate(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-2") is False
|
||||
adapter._dedup.is_duplicate("msg-1")
|
||||
assert adapter._dedup.is_duplicate("msg-2") is False
|
||||
|
||||
def test_cache_cleanup_on_overflow(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
max_size = adapter._dedup._max_size
|
||||
# Fill beyond max
|
||||
for i in range(DEDUP_MAX_SIZE + 10):
|
||||
adapter._is_duplicate(f"msg-{i}")
|
||||
for i in range(max_size + 10):
|
||||
adapter._dedup.is_duplicate(f"msg-{i}")
|
||||
# Cache should have been pruned
|
||||
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
|
||||
assert len(adapter._dedup._seen) <= max_size + 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -253,13 +254,13 @@ class TestConnect:
|
|||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._session_webhooks["a"] = "http://x"
|
||||
adapter._seen_messages["b"] = 1.0
|
||||
adapter._dedup._seen["b"] = 1.0
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._stream_task = None
|
||||
|
||||
await adapter.disconnect()
|
||||
assert len(adapter._session_webhooks) == 0
|
||||
assert len(adapter._seen_messages) == 0
|
||||
assert len(adapter._dedup._seen) == 0
|
||||
assert adapter._http_client is None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -137,4 +137,4 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
|||
|
||||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._token_lock_identity is None
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
|
|
|||
|
|
@ -302,7 +302,7 @@ async def test_discord_bot_thread_skips_mention_requirement(adapter, monkeypatch
|
|||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
# Simulate bot having previously participated in thread 456
|
||||
adapter._bot_participated_threads.add("456")
|
||||
adapter._threads.mark("456")
|
||||
|
||||
thread = FakeThread(channel_id=456, name="existing thread")
|
||||
message = make_message(channel=thread, content="follow-up without mention")
|
||||
|
|
@ -344,7 +344,7 @@ async def test_discord_auto_thread_tracks_participation(adapter, monkeypatch):
|
|||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "555" in adapter._bot_participated_threads
|
||||
assert "555" in adapter._threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -358,4 +358,4 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
|
|||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._bot_participated_threads
|
||||
assert "777" in adapter._threads
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for Discord thread participation persistence.
|
||||
|
||||
Verifies that _bot_participated_threads survives adapter restarts by
|
||||
Verifies that _threads (ThreadParticipationTracker) survives adapter restarts by
|
||||
being persisted to ~/.hermes/discord_threads.json.
|
||||
"""
|
||||
|
||||
|
|
@ -25,13 +25,13 @@ class TestDiscordThreadPersistence:
|
|||
|
||||
def test_starts_empty_when_no_state_file(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_track_thread_persists_to_disk(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("222")
|
||||
adapter._threads.mark("111")
|
||||
adapter._threads.mark("222")
|
||||
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
assert state_file.exists()
|
||||
|
|
@ -42,42 +42,43 @@ class TestDiscordThreadPersistence:
|
|||
"""Threads tracked by one adapter instance are visible to the next."""
|
||||
adapter1 = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter1._track_thread("aaa")
|
||||
adapter1._track_thread("bbb")
|
||||
adapter1._threads.mark("aaa")
|
||||
adapter1._threads.mark("bbb")
|
||||
|
||||
adapter2 = self._make_adapter(tmp_path)
|
||||
assert "aaa" in adapter2._bot_participated_threads
|
||||
assert "bbb" in adapter2._bot_participated_threads
|
||||
assert "aaa" in adapter2._threads
|
||||
assert "bbb" in adapter2._threads
|
||||
|
||||
def test_duplicate_track_does_not_double_save(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
adapter._track_thread("111")
|
||||
adapter._track_thread("111") # no-op
|
||||
adapter._threads.mark("111")
|
||||
adapter._threads.mark("111") # no-op
|
||||
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert saved.count("111") == 1
|
||||
|
||||
def test_caps_at_max_tracked_threads(self, tmp_path):
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
adapter._threads._max_tracked = 5
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
for i in range(10):
|
||||
adapter._track_thread(str(i))
|
||||
adapter._threads.mark(str(i))
|
||||
|
||||
assert len(adapter._bot_participated_threads) == 5
|
||||
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||
assert len(saved) == 5
|
||||
|
||||
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
|
||||
state_file = tmp_path / "discord_threads.json"
|
||||
state_file.write_text("not valid json{{{")
|
||||
adapter = self._make_adapter(tmp_path)
|
||||
assert adapter._bot_participated_threads == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_missing_hermes_home_does_not_crash(self, tmp_path):
|
||||
"""Load/save tolerate missing directories."""
|
||||
fake_home = tmp_path / "nonexistent" / "deep"
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
# _load should return empty set, not crash
|
||||
threads = DiscordAdapter._load_participated_threads()
|
||||
assert threads == set()
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
# ThreadParticipationTracker should return empty set, not crash
|
||||
tracker = ThreadParticipationTracker("discord")
|
||||
assert "$test" not in tracker
|
||||
|
|
|
|||
|
|
@ -195,6 +195,105 @@ async def test_internal_event_does_not_trigger_pairing(monkeypatch, tmp_path):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_complete_preserves_user_identity(monkeypatch, tmp_path):
|
||||
"""Synthetic completion event should carry user_id and user_name from the watcher."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [
|
||||
SimpleNamespace(
|
||||
output_buffer="done\n", exited=True, exit_code=0, command="echo test"
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path)
|
||||
adapter = runner.adapters[Platform.DISCORD]
|
||||
|
||||
watcher = _watcher_dict_with_notify()
|
||||
watcher["user_id"] = "user-42"
|
||||
watcher["user_name"] = "alice"
|
||||
|
||||
await runner._run_process_watcher(watcher)
|
||||
|
||||
assert adapter.handle_message.await_count == 1
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.user_id == "user-42"
|
||||
assert event.source.user_name == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_skips_pairing(monkeypatch, tmp_path):
|
||||
"""A non-internal event with user_id=None should be silently dropped."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
user_id=None,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="service message",
|
||||
source=source,
|
||||
internal=False,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
# Should return None (dropped) and NOT send any pairing message
|
||||
assert result is None
|
||||
assert adapter.send.await_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_user_id_does_not_generate_pairing_code(monkeypatch, tmp_path):
|
||||
"""A message with user_id=None must never call generate_code."""
|
||||
import gateway.run as gateway_run
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
(tmp_path / "config.yaml").write_text("", encoding="utf-8")
|
||||
|
||||
runner = GatewayRunner(GatewayConfig())
|
||||
adapter = SimpleNamespace(send=AsyncMock())
|
||||
runner.adapters[Platform.DISCORD] = adapter
|
||||
|
||||
generate_called = False
|
||||
original_generate = runner.pairing_store.generate_code
|
||||
|
||||
def tracking_generate(*args, **kwargs):
|
||||
nonlocal generate_called
|
||||
generate_called = True
|
||||
return original_generate(*args, **kwargs)
|
||||
|
||||
runner.pairing_store.generate_code = tracking_generate
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="456",
|
||||
chat_type="dm",
|
||||
user_id=None,
|
||||
)
|
||||
event = MessageEvent(text="anonymous", source=source, internal=False)
|
||||
|
||||
await runner._handle_message(event)
|
||||
|
||||
assert not generate_called, (
|
||||
"Pairing code should NOT be generated for messages with user_id=None"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_internal_event_without_user_triggers_pairing(monkeypatch, tmp_path):
|
||||
"""Verify the normal (non-internal) path still triggers pairing for unknown users."""
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ async def test_require_mention_bot_participated_thread(monkeypatch):
|
|||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$thread1")
|
||||
adapter._threads.mark("$thread1")
|
||||
|
||||
event = _make_event("hello without mention", thread_id="$thread1")
|
||||
|
||||
|
|
@ -298,7 +298,7 @@ async def test_auto_thread_preserves_existing_thread(monkeypatch):
|
|||
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._bot_participated_threads.add("$thread_root")
|
||||
adapter._threads.mark("$thread_root")
|
||||
event = _make_event("reply in thread", thread_id="$thread_root")
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
|
|
@ -340,17 +340,17 @@ async def test_auto_thread_disabled(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_thread_tracks_participation(monkeypatch):
|
||||
"""Auto-created threads are tracked in _bot_participated_threads."""
|
||||
"""Auto-created threads are tracked in _threads."""
|
||||
monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false")
|
||||
monkeypatch.delenv("MATRIX_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter = _make_adapter()
|
||||
event = _make_event("hello", event_id="$msg1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
assert "$msg1" in adapter._bot_participated_threads
|
||||
assert "$msg1" in adapter._threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -361,56 +361,54 @@ async def test_auto_thread_tracks_participation(monkeypatch):
|
|||
class TestThreadPersistence:
|
||||
def test_empty_state_file(self, tmp_path, monkeypatch):
|
||||
"""No state file → empty set."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: tmp_path / "matrix_threads.json"),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: tmp_path / "matrix_threads.json",
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
loaded = adapter._load_participated_threads()
|
||||
assert loaded == set()
|
||||
assert "$nonexistent" not in adapter._threads
|
||||
|
||||
def test_track_thread_persists(self, tmp_path, monkeypatch):
|
||||
"""_track_thread writes to disk."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
"""mark() writes to disk."""
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._track_thread("$thread_abc")
|
||||
adapter._threads.mark("$thread_abc")
|
||||
|
||||
data = json.loads(state_path.read_text())
|
||||
assert "$thread_abc" in data
|
||||
|
||||
def test_threads_survive_reload(self, tmp_path, monkeypatch):
|
||||
"""Persisted threads are loaded by a new adapter instance."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
state_path.write_text(json.dumps(["$t1", "$t2"]))
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
assert "$t1" in adapter._bot_participated_threads
|
||||
assert "$t2" in adapter._bot_participated_threads
|
||||
assert "$t1" in adapter._threads
|
||||
assert "$t2" in adapter._threads
|
||||
|
||||
def test_cap_max_tracked_threads(self, tmp_path, monkeypatch):
|
||||
"""Thread set is trimmed to _MAX_TRACKED_THREADS."""
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
"""Thread set is trimmed to max_tracked."""
|
||||
from gateway.platforms.helpers import ThreadParticipationTracker
|
||||
state_path = tmp_path / "matrix_threads.json"
|
||||
monkeypatch.setattr(
|
||||
MatrixAdapter, "_thread_state_path",
|
||||
staticmethod(lambda: state_path),
|
||||
ThreadParticipationTracker, "_state_path",
|
||||
lambda self: state_path,
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._MAX_TRACKED_THREADS = 5
|
||||
adapter._threads._max_tracked = 5
|
||||
|
||||
for i in range(10):
|
||||
adapter._bot_participated_threads.add(f"$t{i}")
|
||||
adapter._save_participated_threads()
|
||||
adapter._threads.mark(f"$t{i}")
|
||||
|
||||
data = json.loads(state_path.read_text())
|
||||
assert len(data) == 5
|
||||
|
|
@ -447,7 +445,7 @@ async def test_dm_mention_thread_creates_thread(monkeypatch):
|
|||
_set_dm(adapter)
|
||||
event = _make_event("@hermes:example.org help me", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
|
|
@ -480,7 +478,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
|||
|
||||
adapter = _make_adapter()
|
||||
_set_dm(adapter)
|
||||
adapter._bot_participated_threads.add("$existing_thread")
|
||||
adapter._threads.mark("$existing_thread")
|
||||
event = _make_event("@hermes:example.org help me", thread_id="$existing_thread")
|
||||
|
||||
await adapter._on_room_message(event)
|
||||
|
|
@ -491,7 +489,7 @@ async def test_dm_mention_thread_preserves_existing_thread(monkeypatch):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
||||
"""DM mention-thread tracks the thread in _bot_participated_threads."""
|
||||
"""DM mention-thread tracks the thread in _threads."""
|
||||
monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", "true")
|
||||
monkeypatch.setenv("MATRIX_AUTO_THREAD", "false")
|
||||
|
||||
|
|
@ -499,10 +497,10 @@ async def test_dm_mention_thread_tracks_participation(monkeypatch):
|
|||
_set_dm(adapter)
|
||||
event = _make_event("@hermes:example.org help", event_id="$dm1")
|
||||
|
||||
with patch.object(adapter, "_save_participated_threads"):
|
||||
with patch.object(adapter._threads, "_save"):
|
||||
await adapter._on_room_message(event)
|
||||
|
||||
assert "$dm1" in adapter._bot_participated_threads
|
||||
assert "$dm1" in adapter._threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -614,25 +614,27 @@ class TestMattermostDedup:
|
|||
assert self.adapter.handle_message.call_count == 2
|
||||
|
||||
def test_prune_seen_clears_expired(self):
|
||||
"""_prune_seen should remove entries older than _SEEN_TTL."""
|
||||
"""Dedup cache should remove entries older than TTL on overflow."""
|
||||
now = time.time()
|
||||
dedup = self.adapter._dedup
|
||||
# Fill with enough expired entries to trigger pruning
|
||||
for i in range(self.adapter._SEEN_MAX + 10):
|
||||
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
|
||||
for i in range(dedup._max_size + 10):
|
||||
dedup._seen[f"old_{i}"] = now - 600 # 10 min ago (older than default TTL)
|
||||
|
||||
# Add a fresh one
|
||||
self.adapter._seen_posts["fresh"] = now
|
||||
dedup._seen["fresh"] = now
|
||||
|
||||
self.adapter._prune_seen()
|
||||
# Trigger pruning by calling is_duplicate with a new entry (over max_size)
|
||||
dedup.is_duplicate("trigger_prune")
|
||||
|
||||
# Old entries should be pruned, fresh one kept
|
||||
assert "fresh" in self.adapter._seen_posts
|
||||
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
|
||||
assert "fresh" in dedup._seen
|
||||
assert len(dedup._seen) < dedup._max_size + 10
|
||||
|
||||
def test_seen_cache_tracks_post_ids(self):
|
||||
"""Posts are tracked in _seen_posts dict."""
|
||||
self.adapter._seen_posts["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._seen_posts
|
||||
"""Posts are tracked in the dedup cache."""
|
||||
self.adapter._dedup._seen["test_post"] = time.time()
|
||||
assert "test_post" in self.adapter._dedup._seen
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.run import _dequeue_pending_event
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -79,6 +80,26 @@ class TestQueueMessageStorage:
|
|||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:voice"
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="voice-q1",
|
||||
media_urls=["/tmp/voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = _dequeue_pending_event(adapter, session_key)
|
||||
|
||||
assert retrieved is event
|
||||
assert retrieved.media_urls == ["/tmp/voice.ogg"]
|
||||
assert retrieved.media_types == ["audio/ogg"]
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
|||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
user_id="123456",
|
||||
user_name="alice",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
|
@ -25,6 +27,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
|||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
|
|
@ -33,6 +37,8 @@ def test_set_session_env_sets_contextvars(monkeypatch):
|
|||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == "-1001"
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == "Group"
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
|
||||
assert get_session_env("HERMES_SESSION_USER_NAME") == "alice"
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == "17585"
|
||||
|
||||
# os.environ should NOT be touched
|
||||
|
|
@ -50,6 +56,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
|||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_USER_NAME", raising=False)
|
||||
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||
|
||||
source = SessionSource(
|
||||
|
|
@ -57,12 +65,15 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
|||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
user_id="123456",
|
||||
user_name="alice",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(source=source, connected_platforms=[], home_channels={})
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_PLATFORM") == "telegram"
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == "123456"
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
|
@ -70,6 +81,8 @@ def test_clear_session_env_restores_previous_state(monkeypatch):
|
|||
assert get_session_env("HERMES_SESSION_PLATFORM") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_CHAT_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_USER_ID") == ""
|
||||
assert get_session_env("HERMES_SESSION_USER_NAME") == ""
|
||||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -114,16 +114,16 @@ class TestSignalAdapterInit:
|
|||
|
||||
class TestSignalHelpers:
|
||||
def test_redact_phone_long(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+15551234567") == "+155****4567"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("+155****4567") == "+155****4567"
|
||||
|
||||
def test_redact_phone_short(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("+12345") == "+1****45"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("+12345") == "+1****45"
|
||||
|
||||
def test_redact_phone_empty(self):
|
||||
from gateway.platforms.signal import _redact_phone
|
||||
assert _redact_phone("") == "<none>"
|
||||
from gateway.platforms.helpers import redact_phone
|
||||
assert redact_phone("") == "<none>"
|
||||
|
||||
def test_parse_comma_list(self):
|
||||
from gateway.platforms.signal import _parse_comma_list
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
"""Tests for SMS (Twilio) platform integration.
|
||||
|
||||
Covers config loading, format/truncate, echo prevention,
|
||||
requirements check, and toolset verification.
|
||||
requirements check, toolset verification, and Twilio signature validation.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -213,3 +216,335 @@ class TestSmsToolset:
|
|||
from tools.cronjob_tools import CRONJOB_SCHEMA
|
||||
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
|
||||
assert "sms" in deliver_desc.lower()
|
||||
|
||||
|
||||
# ── Webhook host configuration ─────────────────────────────────────
|
||||
|
||||
class TestWebhookHostConfig:
|
||||
"""Verify SMS_WEBHOOK_HOST env var and default."""
|
||||
|
||||
def test_default_host_is_all_interfaces(self):
|
||||
from gateway.platforms.sms import DEFAULT_WEBHOOK_HOST
|
||||
assert DEFAULT_WEBHOOK_HOST == "0.0.0.0"
|
||||
|
||||
def test_host_from_env(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_HOST": "127.0.0.1",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_host == "127.0.0.1"
|
||||
|
||||
def test_webhook_url_from_env(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
|
||||
|
||||
def test_webhook_url_stripped(self):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": " https://example.com/webhooks/twilio ",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
assert adapter._webhook_url == "https://example.com/webhooks/twilio"
|
||||
|
||||
|
||||
# ── Startup guard (fail-closed) ────────────────────────────────────
|
||||
|
||||
class TestStartupGuard:
|
||||
"""Adapter must refuse to start without SMS_WEBHOOK_URL."""
|
||||
|
||||
def _make_adapter(self, extra_env=None):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "tok",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
if extra_env:
|
||||
env.update(extra_env)
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
pc = PlatformConfig(enabled=True, api_key="tok")
|
||||
adapter = SmsAdapter(pc)
|
||||
return adapter
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refuses_start_without_webhook_url(self):
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.connect()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_allows_start_without_url(self):
|
||||
mock_session = AsyncMock()
|
||||
with patch.dict(os.environ, {"SMS_INSECURE_NO_SIGNATURE": "true"}), \
|
||||
patch("aiohttp.web.AppRunner") as mock_runner_cls, \
|
||||
patch("aiohttp.web.TCPSite") as mock_site_cls, \
|
||||
patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
mock_runner_cls.return_value.setup = AsyncMock()
|
||||
mock_runner_cls.return_value.cleanup = AsyncMock()
|
||||
mock_site_cls.return_value.start = AsyncMock()
|
||||
adapter = self._make_adapter()
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
await adapter.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_url_allows_start(self):
|
||||
mock_session = AsyncMock()
|
||||
with patch("aiohttp.web.AppRunner") as mock_runner_cls, \
|
||||
patch("aiohttp.web.TCPSite") as mock_site_cls, \
|
||||
patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
mock_runner_cls.return_value.setup = AsyncMock()
|
||||
mock_runner_cls.return_value.cleanup = AsyncMock()
|
||||
mock_site_cls.return_value.start = AsyncMock()
|
||||
adapter = self._make_adapter(
|
||||
extra_env={"SMS_WEBHOOK_URL": "https://example.com/webhooks/twilio"}
|
||||
)
|
||||
result = await adapter.connect()
|
||||
assert result is True
|
||||
await adapter.disconnect()
|
||||
|
||||
|
||||
# ── Twilio signature validation ────────────────────────────────────
|
||||
|
||||
def _compute_twilio_signature(auth_token, url, params):
|
||||
"""Reference implementation of Twilio's signature algorithm."""
|
||||
data_to_sign = url
|
||||
for key in sorted(params.keys()):
|
||||
data_to_sign += key + params[key]
|
||||
mac = hmac.new(
|
||||
auth_token.encode("utf-8"),
|
||||
data_to_sign.encode("utf-8"),
|
||||
hashlib.sha1,
|
||||
)
|
||||
return base64.b64encode(mac.digest()).decode("utf-8")
|
||||
|
||||
|
||||
class TestTwilioSignatureValidation:
|
||||
"""Unit tests for SmsAdapter._validate_twilio_signature."""
|
||||
|
||||
def _make_adapter(self, auth_token="test_token_secret"):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": auth_token,
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key=auth_token)
|
||||
adapter = SmsAdapter(pc)
|
||||
return adapter
|
||||
|
||||
def test_valid_signature_accepted(self):
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello", "To": "+15550001111"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_invalid_signature_rejected(self):
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
assert adapter._validate_twilio_signature(url, params, "badsig") is False
|
||||
|
||||
def test_wrong_token_rejected(self):
|
||||
adapter = self._make_adapter(auth_token="correct_token")
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature("wrong_token", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is False
|
||||
|
||||
def test_params_sorted_by_key(self):
|
||||
"""Signature must be computed with params sorted alphabetically."""
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"Zebra": "last", "Alpha": "first", "Middle": "mid"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_empty_param_values_included(self):
|
||||
"""Blank values must be included in signature computation."""
|
||||
adapter = self._make_adapter()
|
||||
url = "https://example.com/webhooks/twilio"
|
||||
params = {"From": "+15551234567", "Body": "", "SmsStatus": "received"}
|
||||
sig = _compute_twilio_signature("test_token_secret", url, params)
|
||||
assert adapter._validate_twilio_signature(url, params, sig) is True
|
||||
|
||||
def test_url_matters(self):
|
||||
"""Different URLs produce different signatures."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://a.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://b.com/webhooks/twilio", params, sig
|
||||
) is False
|
||||
|
||||
def test_port_variant_443_matches_without_port(self):
|
||||
"""Signature for https URL with :443 validates against URL without port."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com:443/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
def test_port_variant_without_port_matches_443(self):
|
||||
"""Signature for https URL without port validates against URL with :443."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com:443/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
def test_non_standard_port_no_variant(self):
|
||||
"""Non-standard port must NOT match URL without port."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"https://example.com:8080/webhooks/twilio", params, sig
|
||||
) is False
|
||||
|
||||
def test_port_variant_http_80(self):
|
||||
"""Port variant also works for http with port 80."""
|
||||
adapter = self._make_adapter()
|
||||
params = {"From": "+15551234567", "Body": "hello"}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "http://example.com:80/webhooks/twilio", params
|
||||
)
|
||||
assert adapter._validate_twilio_signature(
|
||||
"http://example.com/webhooks/twilio", params, sig
|
||||
) is True
|
||||
|
||||
|
||||
# ── Webhook signature enforcement (handler-level) ──────────────────
|
||||
|
||||
class TestWebhookSignatureEnforcement:
|
||||
"""Integration tests for signature validation in _handle_webhook."""
|
||||
|
||||
def _make_adapter(self, webhook_url=""):
|
||||
from gateway.platforms.sms import SmsAdapter
|
||||
|
||||
env = {
|
||||
"TWILIO_ACCOUNT_SID": "ACtest",
|
||||
"TWILIO_AUTH_TOKEN": "test_token_secret",
|
||||
"TWILIO_PHONE_NUMBER": "+15550001111",
|
||||
"SMS_WEBHOOK_URL": webhook_url,
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
pc = PlatformConfig(enabled=True, api_key="test_token_secret")
|
||||
adapter = SmsAdapter(pc)
|
||||
adapter._message_handler = AsyncMock()
|
||||
return adapter
|
||||
|
||||
def _mock_request(self, body, headers=None):
|
||||
request = MagicMock()
|
||||
request.read = AsyncMock(return_value=body)
|
||||
request.headers = headers or {}
|
||||
return request
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_skips_validation(self):
|
||||
"""With SMS_INSECURE_NO_SIGNATURE=true and no URL, requests are accepted."""
|
||||
env = {"SMS_INSECURE_NO_SIGNATURE": "true"}
|
||||
with patch.dict(os.environ, env):
|
||||
adapter = self._make_adapter(webhook_url="")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body)
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insecure_flag_with_url_still_validates(self):
|
||||
"""When both SMS_WEBHOOK_URL and SMS_INSECURE_NO_SIGNATURE are set,
|
||||
validation stays active (URL takes precedence)."""
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_signature_returns_403(self):
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_signature_returns_403(self):
|
||||
adapter = self._make_adapter(webhook_url="https://example.com/webhooks/twilio")
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": "invalid"})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_signature_returns_200(self):
|
||||
webhook_url = "https://example.com/webhooks/twilio"
|
||||
adapter = self._make_adapter(webhook_url=webhook_url)
|
||||
params = {
|
||||
"From": "+15551234567",
|
||||
"To": "+15550001111",
|
||||
"Body": "hello",
|
||||
"MessageSid": "SM123",
|
||||
}
|
||||
sig = _compute_twilio_signature("test_token_secret", webhook_url, params)
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_port_variant_signature_returns_200(self):
|
||||
"""Signature computed with :443 should pass when URL configured without port."""
|
||||
webhook_url = "https://example.com/webhooks/twilio"
|
||||
adapter = self._make_adapter(webhook_url=webhook_url)
|
||||
params = {
|
||||
"From": "+15551234567",
|
||||
"To": "+15550001111",
|
||||
"Body": "hello",
|
||||
"MessageSid": "SM123",
|
||||
}
|
||||
sig = _compute_twilio_signature(
|
||||
"test_token_secret", "https://example.com:443/webhooks/twilio", params
|
||||
)
|
||||
body = b"From=%2B15551234567&To=%2B15550001111&Body=hello&MessageSid=SM123"
|
||||
request = self._mock_request(body, headers={"X-Twilio-Signature": sig})
|
||||
resp = await adapter._handle_webhook(request)
|
||||
assert resp.status == 200
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
from gateway.config import GatewayConfig, Platform, load_gateway_config
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
|
|
@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
|
|||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
runner.adapters = {}
|
||||
runner._model = "test-model"
|
||||
runner._base_url = ""
|
||||
runner._has_setup_skill = lambda: False
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=source,
|
||||
media_urls=["/tmp/queued-voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={
|
||||
"success": True,
|
||||
"transcript": "queued voice transcript",
|
||||
"provider": "local_command",
|
||||
},
|
||||
):
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "queued voice transcript" in result
|
||||
assert "voice message" in result.lower()
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ def _no_auto_discovery(monkeypatch):
|
|||
async def _noop():
|
||||
return []
|
||||
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
|
||||
# Mock HTTPXRequest so the builder chain doesn't fail
|
||||
monkeypatch.setattr("gateway.platforms.telegram.HTTPXRequest", lambda **kwargs: MagicMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -57,9 +59,9 @@ async def test_connect_rejects_same_host_token_lock(monkeypatch):
|
|||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_token_lock"
|
||||
assert adapter.fatal_error_code == "telegram-bot-token_lock"
|
||||
assert adapter.has_fatal_error is True
|
||||
assert "already using this Telegram bot token" in adapter.fatal_error_message
|
||||
assert "already in use" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -98,6 +100,8 @@ async def test_polling_conflict_retries_before_fatal(monkeypatch):
|
|||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
|
|
@ -172,6 +176,8 @@ async def test_polling_conflict_becomes_fatal_after_retries(monkeypatch):
|
|||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
|
|
@ -216,6 +222,8 @@ async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(m
|
|||
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
app = SimpleNamespace(
|
||||
bot=SimpleNamespace(delete_webhook=AsyncMock(), set_my_commands=AsyncMock()),
|
||||
updater=SimpleNamespace(),
|
||||
|
|
@ -265,6 +273,8 @@ async def test_connect_clears_webhook_before_polling(monkeypatch):
|
|||
)
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
builder.request.return_value = builder
|
||||
builder.get_updates_request.return_value = builder
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.telegram.Application",
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
"""Tests for the Weixin platform adapter."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.config import GatewayConfig, HomeChannel, Platform, _apply_env_overrides
|
||||
from gateway.platforms.weixin import WeixinAdapter
|
||||
from gateway.platforms import weixin
|
||||
from gateway.platforms.weixin import ContextTokenStore, WeixinAdapter
|
||||
from tools.send_message_tool import _parse_target_ref, _send_to_platform
|
||||
|
||||
|
||||
|
|
@ -62,15 +64,15 @@ class TestWeixinFormatting:
|
|||
|
||||
|
||||
class TestWeixinChunking:
|
||||
def test_split_text_sends_top_level_newlines_as_separate_messages(self):
|
||||
def test_split_text_keeps_short_multiline_message_in_single_chunk(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message("第一行\n第二行\n第三行")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["第一行", "第二行", "第三行"]
|
||||
assert chunks == ["第一行\n第二行\n第三行"]
|
||||
|
||||
def test_split_text_keeps_indented_followup_with_previous_line(self):
|
||||
def test_split_text_keeps_short_reformatted_table_in_single_chunk(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
content = adapter.format_message(
|
||||
|
|
@ -81,10 +83,7 @@ class TestWeixinChunking:
|
|||
)
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == [
|
||||
"- Setting: Timeout\n Value: 30s",
|
||||
"- Setting: Retries\n Value: 3",
|
||||
]
|
||||
assert chunks == [content]
|
||||
|
||||
def test_split_text_keeps_complete_code_block_together_when_possible(self):
|
||||
adapter = _make_adapter()
|
||||
|
|
@ -114,6 +113,23 @@ class TestWeixinChunking:
|
|||
assert all(len(chunk) <= adapter.MAX_MESSAGE_LENGTH for chunk in chunks)
|
||||
assert all(chunk.count("```") >= 2 for chunk in chunks)
|
||||
|
||||
def test_split_text_can_restore_legacy_multiline_splitting_via_config(self):
|
||||
adapter = WeixinAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
extra={
|
||||
"account_id": "acct",
|
||||
"token": "***",
|
||||
"split_multiline_messages": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
content = adapter.format_message("第一行\n第二行\n第三行")
|
||||
chunks = adapter._split_text(content)
|
||||
|
||||
assert chunks == ["第一行", "第二行", "第三行"]
|
||||
|
||||
|
||||
class TestWeixinConfig:
|
||||
def test_apply_env_overrides_configures_weixin(self):
|
||||
|
|
@ -127,6 +143,7 @@ class TestWeixinConfig:
|
|||
"WEIXIN_BASE_URL": "https://ilink.example.com/",
|
||||
"WEIXIN_CDN_BASE_URL": "https://cdn.example.com/c2c/",
|
||||
"WEIXIN_DM_POLICY": "allowlist",
|
||||
"WEIXIN_SPLIT_MULTILINE_MESSAGES": "true",
|
||||
"WEIXIN_ALLOWED_USERS": "wxid_1,wxid_2",
|
||||
"WEIXIN_HOME_CHANNEL": "wxid_1",
|
||||
"WEIXIN_HOME_CHANNEL_NAME": "Primary DM",
|
||||
|
|
@ -142,6 +159,7 @@ class TestWeixinConfig:
|
|||
assert platform_config.extra["base_url"] == "https://ilink.example.com"
|
||||
assert platform_config.extra["cdn_base_url"] == "https://cdn.example.com/c2c"
|
||||
assert platform_config.extra["dm_policy"] == "allowlist"
|
||||
assert platform_config.extra["split_multiline_messages"] == "true"
|
||||
assert platform_config.extra["allow_from"] == "wxid_1,wxid_2"
|
||||
assert platform_config.home_channel == HomeChannel(Platform.WEIXIN, "wxid_1", "Primary DM")
|
||||
|
||||
|
|
@ -171,6 +189,70 @@ class TestWeixinConfig:
|
|||
assert config.get_connected_platforms() == []
|
||||
|
||||
|
||||
class TestWeixinStatePersistence:
|
||||
def test_save_weixin_account_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
|
||||
account_path = tmp_path / "weixin" / "accounts" / "acct.json"
|
||||
account_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
original = {"token": "old-token", "base_url": "https://old.example.com"}
|
||||
account_path.write_text(json.dumps(original), encoding="utf-8")
|
||||
|
||||
def _boom(_src, _dst):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.os.replace", _boom)
|
||||
|
||||
try:
|
||||
weixin.save_weixin_account(
|
||||
str(tmp_path),
|
||||
account_id="acct",
|
||||
token="new-token",
|
||||
base_url="https://new.example.com",
|
||||
user_id="wxid_new",
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("expected save_weixin_account to propagate replace failure")
|
||||
|
||||
assert json.loads(account_path.read_text(encoding="utf-8")) == original
|
||||
|
||||
def test_context_token_persist_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
|
||||
token_path = tmp_path / "weixin" / "accounts" / "acct.context-tokens.json"
|
||||
token_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
token_path.write_text(json.dumps({"user-a": "old-token"}), encoding="utf-8")
|
||||
|
||||
def _boom(_src, _dst):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.os.replace", _boom)
|
||||
|
||||
store = ContextTokenStore(str(tmp_path))
|
||||
with patch.object(weixin.logger, "warning") as warning_mock:
|
||||
store.set("acct", "user-b", "new-token")
|
||||
|
||||
assert json.loads(token_path.read_text(encoding="utf-8")) == {"user-a": "old-token"}
|
||||
warning_mock.assert_called_once()
|
||||
|
||||
def test_save_sync_buf_preserves_existing_file_on_replace_failure(self, tmp_path, monkeypatch):
|
||||
sync_path = tmp_path / "weixin" / "accounts" / "acct.sync.json"
|
||||
sync_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sync_path.write_text(json.dumps({"get_updates_buf": "old-sync"}), encoding="utf-8")
|
||||
|
||||
def _boom(_src, _dst):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.os.replace", _boom)
|
||||
|
||||
try:
|
||||
weixin._save_sync_buf(str(tmp_path), "acct", "new-sync")
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("expected _save_sync_buf to propagate replace failure")
|
||||
|
||||
assert json.loads(sync_path.read_text(encoding="utf-8")) == {"get_updates_buf": "old-sync"}
|
||||
|
||||
|
||||
class TestWeixinSendMessageIntegration:
|
||||
def test_parse_target_ref_accepts_weixin_ids(self):
|
||||
assert _parse_target_ref("weixin", "wxid_test123") == ("wxid_test123", None, True)
|
||||
|
|
@ -201,6 +283,55 @@ class TestWeixinSendMessageIntegration:
|
|||
)
|
||||
|
||||
|
||||
class TestWeixinChunkDelivery:
|
||||
def _connected_adapter(self) -> WeixinAdapter:
|
||||
adapter = _make_adapter()
|
||||
adapter._session = object()
|
||||
adapter._token = "test-token"
|
||||
adapter._base_url = "https://weixin.example.com"
|
||||
adapter._token_store.get = lambda account_id, chat_id: "ctx-token"
|
||||
return adapter
|
||||
|
||||
@patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
|
||||
def test_send_waits_between_multiple_chunks(self, send_message_mock, sleep_mock):
|
||||
adapter = self._connected_adapter()
|
||||
adapter.MAX_MESSAGE_LENGTH = 12
|
||||
|
||||
# Use double newlines so _pack_markdown_blocks splits into 3 blocks
|
||||
result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird"))
|
||||
|
||||
assert result.success is True
|
||||
assert send_message_mock.await_count == 3
|
||||
assert sleep_mock.await_count == 2
|
||||
|
||||
@patch("gateway.platforms.weixin.asyncio.sleep", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.weixin._send_message", new_callable=AsyncMock)
|
||||
def test_send_retries_failed_chunk_before_continuing(self, send_message_mock, sleep_mock):
|
||||
adapter = self._connected_adapter()
|
||||
adapter.MAX_MESSAGE_LENGTH = 12
|
||||
calls = {"count": 0}
|
||||
|
||||
async def flaky_send(*args, **kwargs):
|
||||
calls["count"] += 1
|
||||
if calls["count"] == 2:
|
||||
raise RuntimeError("temporary iLink failure")
|
||||
|
||||
send_message_mock.side_effect = flaky_send
|
||||
|
||||
# Use double newlines so _pack_markdown_blocks splits into 3 blocks
|
||||
result = asyncio.run(adapter.send("wxid_test123", "first\n\nsecond\n\nthird"))
|
||||
|
||||
assert result.success is True
|
||||
# 3 chunks, but chunk 2 fails once and retries → 4 _send_message calls total
|
||||
assert send_message_mock.await_count == 4
|
||||
# The retried chunk should reuse the same client_id for deduplication
|
||||
first_try = send_message_mock.await_args_list[1].kwargs
|
||||
retry = send_message_mock.await_args_list[2].kwargs
|
||||
assert first_try["text"] == retry["text"]
|
||||
assert first_try["client_id"] == retry["client_id"]
|
||||
|
||||
|
||||
class TestWeixinRemoteMediaSafety:
|
||||
def test_download_remote_media_blocks_unsafe_urls(self):
|
||||
adapter = _make_adapter()
|
||||
|
|
|
|||
|
|
@ -289,12 +289,16 @@ class TestCmdMigrate:
|
|||
skill_conflict="skip", yes=False,
|
||||
)
|
||||
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
|
||||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=config_path),
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=True),
|
||||
patch.object(claw_mod, "_offer_source_archival"),
|
||||
patch("sys.stdin", mock_stdin),
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
|
|
@ -377,6 +381,16 @@ class TestCmdMigrate:
|
|||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text("")
|
||||
|
||||
# Preview must succeed before the confirmation prompt is shown
|
||||
fake_mod = ModuleType("openclaw_to_hermes")
|
||||
fake_mod.resolve_selected_options = MagicMock(return_value=set())
|
||||
fake_migrator = MagicMock()
|
||||
fake_migrator.migrate.return_value = {
|
||||
"summary": {"migrated": 1, "skipped": 0, "conflict": 0, "error": 0},
|
||||
"items": [{"kind": "soul", "status": "migrated", "source": "s", "destination": "d", "reason": ""}],
|
||||
}
|
||||
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
|
||||
|
||||
args = Namespace(
|
||||
source=str(openclaw_dir),
|
||||
dry_run=False, preset="full", overwrite=False,
|
||||
|
|
@ -384,9 +398,15 @@ class TestCmdMigrate:
|
|||
skill_conflict="skip", yes=False,
|
||||
)
|
||||
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
|
||||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=config_path),
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=False),
|
||||
patch("sys.stdin", mock_stdin),
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
|
|
@ -448,7 +468,7 @@ class TestCmdMigrate:
|
|||
claw_mod._cmd_migrate(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Migration failed" in captured.out
|
||||
assert "Could not load migration script" in captured.out
|
||||
|
||||
def test_full_preset_enables_secrets(self, tmp_path, capsys):
|
||||
"""The 'full' preset should set migrate_secrets=True automatically."""
|
||||
|
|
@ -511,7 +531,13 @@ class TestOfferSourceArchival:
|
|||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
|
||||
with patch.object(claw_mod, "prompt_yes_no", return_value=False):
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=False),
|
||||
patch("sys.stdin", mock_stdin),
|
||||
):
|
||||
claw_mod._offer_source_archival(source, auto_yes=False)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
|
|
@ -597,10 +623,14 @@ class TestCmdCleanup:
|
|||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
|
||||
mock_stdin = MagicMock()
|
||||
mock_stdin.isatty.return_value = True
|
||||
|
||||
args = Namespace(source=None, dry_run=False, yes=False)
|
||||
with (
|
||||
patch.object(claw_mod, "_find_openclaw_dirs", return_value=[openclaw]),
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=False),
|
||||
patch("sys.stdin", mock_stdin),
|
||||
):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
|
|
|
|||
|
|
@ -106,6 +106,49 @@ class TestCmdUpdateBranchFallback:
|
|||
pull_cmds = [c for c in commands if "pull" in c]
|
||||
assert len(pull_cmds) == 0
|
||||
|
||||
@patch("shutil.which")
|
||||
@patch("subprocess.run")
|
||||
def test_update_refreshes_repo_and_tui_node_dependencies(
|
||||
self, mock_run, mock_which, mock_args
|
||||
):
|
||||
mock_which.side_effect = {"uv": "/usr/bin/uv", "npm": "/usr/bin/npm"}.get
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
branch="main", verify_ok=True, commit_count="1"
|
||||
)
|
||||
|
||||
cmd_update(mock_args)
|
||||
|
||||
npm_calls = [
|
||||
(call.args[0], call.kwargs.get("cwd"))
|
||||
for call in mock_run.call_args_list
|
||||
if call.args and call.args[0][0] == "/usr/bin/npm"
|
||||
]
|
||||
|
||||
assert npm_calls == [
|
||||
(
|
||||
[
|
||||
"/usr/bin/npm",
|
||||
"install",
|
||||
"--silent",
|
||||
"--no-fund",
|
||||
"--no-audit",
|
||||
"--progress=false",
|
||||
],
|
||||
PROJECT_ROOT,
|
||||
),
|
||||
(
|
||||
[
|
||||
"/usr/bin/npm",
|
||||
"install",
|
||||
"--silent",
|
||||
"--no-fund",
|
||||
"--no-audit",
|
||||
"--progress=false",
|
||||
],
|
||||
PROJECT_ROOT / "ui-tui",
|
||||
),
|
||||
]
|
||||
|
||||
def test_update_non_interactive_skips_migration_prompt(self, mock_args, capsys):
|
||||
"""When stdin/stdout aren't TTYs, config migration prompt is skipped."""
|
||||
with patch("shutil.which", return_value=None), patch(
|
||||
|
|
|
|||
|
|
@ -260,7 +260,7 @@ class TestWaitForGatewayExit:
|
|||
def test_kill_gateway_processes_force_uses_helper(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22])
|
||||
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22])
|
||||
monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force)))
|
||||
|
||||
killed = gateway.kill_gateway_processes(force=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for gateway service management helpers."""
|
||||
|
||||
import os
|
||||
import pwd
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
|
@ -129,7 +130,7 @@ class TestGatewayStopCleanup:
|
|||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
lambda force=False: kill_calls.append(force) or 2,
|
||||
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop"))
|
||||
|
|
@ -155,7 +156,7 @@ class TestGatewayStopCleanup:
|
|||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
lambda force=False: kill_calls.append(force) or 2,
|
||||
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True}))
|
||||
|
|
@ -924,6 +925,23 @@ class TestProfileArg:
|
|||
assert "<string>--profile</string>" in plist
|
||||
assert "<string>mybot</string>" in plist
|
||||
|
||||
def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch):
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
|
||||
profile_dir.mkdir(parents=True)
|
||||
machine_home = tmp_path / "machine-home"
|
||||
machine_home.mkdir()
|
||||
profile_home = profile_dir / "home"
|
||||
profile_home.mkdir()
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: profile_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
|
||||
monkeypatch.setattr(pwd, "getpwuid", lambda uid: SimpleNamespace(pw_dir=str(machine_home)))
|
||||
|
||||
plist_path = gateway_cli.get_launchd_plist_path()
|
||||
|
||||
assert plist_path == machine_home / "Library" / "LaunchAgents" / "ai.hermes.gateway-orcha.plist"
|
||||
|
||||
|
||||
class TestRemapPathForUser:
|
||||
"""Unit tests for _remap_path_for_user()."""
|
||||
|
|
|
|||
|
|
@ -1214,3 +1214,115 @@ def test_openrouter_provider_not_affected_by_custom_fix(monkeypatch):
|
|||
|
||||
resolved = rp.resolve_runtime_provider(requested="openrouter")
|
||||
assert resolved["provider"] == "openrouter"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# fix #7828 — custom_providers model field must propagate to runtime
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_named_custom_provider_includes_model(monkeypatch):
|
||||
"""_get_named_custom_provider should include the model field from config."""
|
||||
monkeypatch.setattr(rp, "load_config", lambda: {
|
||||
"custom_providers": [{
|
||||
"name": "my-dashscope",
|
||||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"api_key": "test-key",
|
||||
"api_mode": "chat_completions",
|
||||
"model": "qwen3.6-plus",
|
||||
}],
|
||||
})
|
||||
|
||||
result = rp._get_named_custom_provider("my-dashscope")
|
||||
assert result is not None
|
||||
assert result["model"] == "qwen3.6-plus"
|
||||
|
||||
|
||||
def test_get_named_custom_provider_excludes_empty_model(monkeypatch):
|
||||
"""Empty or whitespace-only model field should not appear in result."""
|
||||
for model_val in ["", " ", None]:
|
||||
entry = {
|
||||
"name": "test-ep",
|
||||
"base_url": "https://example.com/v1",
|
||||
"api_key": "key",
|
||||
}
|
||||
if model_val is not None:
|
||||
entry["model"] = model_val
|
||||
|
||||
monkeypatch.setattr(rp, "load_config", lambda e=entry: {
|
||||
"custom_providers": [e],
|
||||
})
|
||||
|
||||
result = rp._get_named_custom_provider("test-ep")
|
||||
assert result is not None
|
||||
assert "model" not in result, (
|
||||
f"model field {model_val!r} should not be included in result"
|
||||
)
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_model_direct_path(monkeypatch):
|
||||
"""Model should propagate through the direct (non-pool) resolution path."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
"model": "qwen3.6-plus",
|
||||
},
|
||||
)
|
||||
# Ensure pool doesn't intercept
|
||||
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
assert resolved["model"] == "qwen3.6-plus"
|
||||
assert resolved["provider"] == "custom"
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_model_pool_path(monkeypatch):
|
||||
"""Model should propagate even when credential pool handles credentials."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
"model": "qwen3.6-plus",
|
||||
},
|
||||
)
|
||||
# Pool returns a result (intercepting the normal path)
|
||||
monkeypatch.setattr(
|
||||
rp, "_try_resolve_from_custom_pool",
|
||||
lambda *a, **k: {
|
||||
"provider": "custom",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "pool-key",
|
||||
"source": "pool:custom:my-server",
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
assert resolved["model"] == "qwen3.6-plus", (
|
||||
"model must be injected into pool result"
|
||||
)
|
||||
assert resolved["api_key"] == "pool-key", "pool credentials should be used"
|
||||
|
||||
|
||||
def test_named_custom_runtime_no_model_when_absent(monkeypatch):
|
||||
"""When custom_providers entry has no model field, runtime should not either."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-server",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-server")
|
||||
assert "model" not in resolved
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def test_cmd_chat_tui_continue_uses_latest_tui_session(monkeypatch):
|
|||
calls.append(source)
|
||||
return "20260408_235959_a1b2c3" if source == "tui" else None
|
||||
|
||||
def fake_launch(resume_session_id=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False):
|
||||
captured["resume"] = resume_session_id
|
||||
raise SystemExit(0)
|
||||
|
||||
|
|
@ -54,7 +54,7 @@ def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch):
|
|||
return "20260408_235959_d4e5f6"
|
||||
return None
|
||||
|
||||
def fake_launch(resume_session_id=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False):
|
||||
captured["resume"] = resume_session_id
|
||||
raise SystemExit(0)
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch):
|
|||
|
||||
captured = {}
|
||||
|
||||
def fake_launch(resume_session_id=None):
|
||||
def fake_launch(resume_session_id=None, tui_dev=False):
|
||||
captured["resume"] = resume_session_id
|
||||
raise SystemExit(0)
|
||||
|
||||
|
|
|
|||
|
|
@ -191,6 +191,19 @@ class TestLaunchdPlistPath:
|
|||
raise AssertionError("PATH key not found in plist")
|
||||
|
||||
|
||||
class TestLaunchdPlistCurrentness:
|
||||
def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin")
|
||||
plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin")
|
||||
|
||||
assert gateway_cli.launchd_plist_is_current() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update — macOS launchd detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -536,7 +549,7 @@ class TestServicePidExclusion:
|
|||
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
|
||||
), patch.object(
|
||||
gateway_cli, "find_gateway_pids",
|
||||
side_effect=lambda exclude_pids=None: (
|
||||
side_effect=lambda exclude_pids=None, all_profiles=False: (
|
||||
[SERVICE_PID] if not exclude_pids else
|
||||
[p for p in [SERVICE_PID] if p not in exclude_pids]
|
||||
),
|
||||
|
|
@ -579,7 +592,7 @@ class TestServicePidExclusion:
|
|||
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
|
||||
), patch.object(
|
||||
gateway_cli, "find_gateway_pids",
|
||||
side_effect=lambda exclude_pids=None: (
|
||||
side_effect=lambda exclude_pids=None, all_profiles=False: (
|
||||
[SERVICE_PID] if not exclude_pids else
|
||||
[p for p in [SERVICE_PID] if p not in exclude_pids]
|
||||
),
|
||||
|
|
@ -618,7 +631,7 @@ class TestServicePidExclusion:
|
|||
launchctl_loaded=True,
|
||||
)
|
||||
|
||||
def fake_find(exclude_pids=None):
|
||||
def fake_find(exclude_pids=None, all_profiles=False):
|
||||
_exclude = exclude_pids or set()
|
||||
return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude]
|
||||
|
||||
|
|
@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude:
|
|||
pids = gateway_cli.find_gateway_pids()
|
||||
assert 100 in pids
|
||||
assert 200 in pids
|
||||
|
||||
def test_filters_to_current_profile(self, monkeypatch, tmp_path):
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
|
||||
profile_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(gateway_cli, "is_windows", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return subprocess.CompletedProcess(
|
||||
cmd, 0,
|
||||
stdout=(
|
||||
"100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n"
|
||||
"200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n"
|
||||
),
|
||||
stderr="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
monkeypatch.setattr("os.getpid", lambda: 999)
|
||||
monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set())
|
||||
monkeypatch.setattr(gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha")
|
||||
|
||||
pids = gateway_cli.find_gateway_pids()
|
||||
|
||||
assert pids == [100]
|
||||
|
|
|
|||
327
tests/hermes_cli/test_xiaomi_provider.py
Normal file
327
tests/hermes_cli/test_xiaomi_provider.py
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
"""Tests for Xiaomi MiMo provider support."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure dotenv doesn't interfere
|
||||
if "dotenv" not in sys.modules:
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
sys.modules["dotenv"] = fake_dotenv
|
||||
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY,
|
||||
resolve_provider,
|
||||
get_api_key_provider_status,
|
||||
resolve_api_key_provider_credentials,
|
||||
AuthError,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider Registry
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiProviderRegistry:
|
||||
"""Verify Xiaomi is registered correctly in the PROVIDER_REGISTRY."""
|
||||
|
||||
def test_registered(self):
|
||||
assert "xiaomi" in PROVIDER_REGISTRY
|
||||
|
||||
def test_name(self):
|
||||
assert PROVIDER_REGISTRY["xiaomi"].name == "Xiaomi MiMo"
|
||||
|
||||
def test_auth_type(self):
|
||||
assert PROVIDER_REGISTRY["xiaomi"].auth_type == "api_key"
|
||||
|
||||
def test_inference_base_url(self):
|
||||
assert PROVIDER_REGISTRY["xiaomi"].inference_base_url == "https://api.xiaomimimo.com/v1"
|
||||
|
||||
def test_api_key_env_vars(self):
|
||||
assert PROVIDER_REGISTRY["xiaomi"].api_key_env_vars == ("XIAOMI_API_KEY",)
|
||||
|
||||
def test_base_url_env_var(self):
|
||||
assert PROVIDER_REGISTRY["xiaomi"].base_url_env_var == "XIAOMI_BASE_URL"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Aliases
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiAliases:
|
||||
"""All aliases should resolve to 'xiaomi'."""
|
||||
|
||||
@pytest.mark.parametrize("alias", [
|
||||
"xiaomi", "mimo", "xiaomi-mimo",
|
||||
])
|
||||
def test_alias_resolves(self, alias, monkeypatch):
|
||||
# Clear env to avoid auto-detection interfering
|
||||
for key in ("XIAOMI_API_KEY",):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
monkeypatch.setenv("XIAOMI_API_KEY", "sk-test-key-12345678")
|
||||
assert resolve_provider(alias) == "xiaomi"
|
||||
|
||||
def test_normalize_provider_models_py(self):
|
||||
from hermes_cli.models import normalize_provider
|
||||
assert normalize_provider("mimo") == "xiaomi"
|
||||
assert normalize_provider("xiaomi-mimo") == "xiaomi"
|
||||
|
||||
def test_normalize_provider_providers_py(self):
|
||||
from hermes_cli.providers import normalize_provider
|
||||
assert normalize_provider("mimo") == "xiaomi"
|
||||
assert normalize_provider("xiaomi-mimo") == "xiaomi"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auto-detection
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiAutoDetection:
|
||||
"""Setting XIAOMI_API_KEY should auto-detect the provider."""
|
||||
|
||||
def test_auto_detect(self, monkeypatch):
|
||||
# Clear all other provider env vars
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"DEEPSEEK_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY",
|
||||
"DASHSCOPE_API_KEY", "XAI_API_KEY", "KIMI_API_KEY",
|
||||
"MINIMAX_API_KEY", "AI_GATEWAY_API_KEY", "KILOCODE_API_KEY",
|
||||
"HF_TOKEN", "GLM_API_KEY", "COPILOT_GITHUB_TOKEN",
|
||||
"GH_TOKEN", "GITHUB_TOKEN", "MINIMAX_CN_API_KEY"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
monkeypatch.setenv("XIAOMI_API_KEY", "sk-xiaomi-test-12345678")
|
||||
provider = resolve_provider("auto")
|
||||
assert provider == "xiaomi"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Credentials
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiCredentials:
|
||||
"""Test credential resolution for the xiaomi provider."""
|
||||
|
||||
def test_status_configured(self, monkeypatch):
|
||||
monkeypatch.setenv("XIAOMI_API_KEY", "sk-test-12345678")
|
||||
status = get_api_key_provider_status("xiaomi")
|
||||
assert status["configured"]
|
||||
|
||||
def test_status_not_configured(self, monkeypatch):
|
||||
monkeypatch.delenv("XIAOMI_API_KEY", raising=False)
|
||||
status = get_api_key_provider_status("xiaomi")
|
||||
assert not status["configured"]
|
||||
|
||||
def test_resolve_credentials(self, monkeypatch):
|
||||
monkeypatch.setenv("XIAOMI_API_KEY", "sk-test-12345678")
|
||||
monkeypatch.delenv("XIAOMI_BASE_URL", raising=False)
|
||||
creds = resolve_api_key_provider_credentials("xiaomi")
|
||||
assert creds["api_key"] == "sk-test-12345678"
|
||||
assert creds["base_url"] == "https://api.xiaomimimo.com/v1"
|
||||
|
||||
def test_custom_base_url_override(self, monkeypatch):
|
||||
monkeypatch.setenv("XIAOMI_API_KEY", "sk-test-12345678")
|
||||
monkeypatch.setenv("XIAOMI_BASE_URL", "https://custom.xiaomi.example/v1")
|
||||
creds = resolve_api_key_provider_credentials("xiaomi")
|
||||
assert creds["base_url"] == "https://custom.xiaomi.example/v1"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model catalog (dynamic — no static list)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiModelCatalog:
|
||||
"""Xiaomi uses dynamic model discovery via models.dev."""
|
||||
|
||||
def test_models_dev_mapping(self):
|
||||
from agent.models_dev import PROVIDER_TO_MODELS_DEV
|
||||
assert PROVIDER_TO_MODELS_DEV["xiaomi"] == "xiaomi"
|
||||
|
||||
def test_static_model_list_fallback(self):
|
||||
"""Static _PROVIDER_MODELS fallback must exist for model picker."""
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
assert "xiaomi" in _PROVIDER_MODELS
|
||||
models = _PROVIDER_MODELS["xiaomi"]
|
||||
assert "mimo-v2-pro" in models
|
||||
assert "mimo-v2-omni" in models
|
||||
assert "mimo-v2-flash" in models
|
||||
|
||||
def test_list_agentic_models_mock(self, monkeypatch):
|
||||
"""When models.dev returns Xiaomi data, list_agentic_models should return models."""
|
||||
from agent import models_dev as md
|
||||
|
||||
fake_data = {
|
||||
"xiaomi": {
|
||||
"name": "Xiaomi",
|
||||
"api": "https://api.xiaomimimo.com/v1",
|
||||
"env": ["XIAOMI_API_KEY"],
|
||||
"models": {
|
||||
"mimo-v2-pro": {
|
||||
"limit": {"context": 1000000},
|
||||
"tool_call": True,
|
||||
},
|
||||
"mimo-v2-omni": {
|
||||
"limit": {"context": 256000},
|
||||
"tool_call": True,
|
||||
},
|
||||
"mimo-v2-flash": {
|
||||
"limit": {"context": 256000},
|
||||
"tool_call": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
monkeypatch.setattr(md, "fetch_models_dev", lambda: fake_data)
|
||||
|
||||
result = md.list_agentic_models("xiaomi")
|
||||
assert "mimo-v2-pro" in result
|
||||
assert "mimo-v2-flash" in result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Normalization
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiNormalization:
|
||||
"""Model name normalization — Xiaomi is a direct provider."""
|
||||
|
||||
def test_vendor_prefix_mapping(self):
|
||||
from hermes_cli.model_normalize import _VENDOR_PREFIXES
|
||||
assert _VENDOR_PREFIXES.get("mimo") == "xiaomi"
|
||||
|
||||
def test_matching_prefix_strip(self):
|
||||
"""xiaomi/mimo-v2-pro should normalize to mimo-v2-pro for direct API."""
|
||||
from hermes_cli.model_normalize import _MATCHING_PREFIX_STRIP_PROVIDERS
|
||||
assert "xiaomi" in _MATCHING_PREFIX_STRIP_PROVIDERS
|
||||
|
||||
def test_normalize_strips_provider_prefix(self):
|
||||
from hermes_cli.model_normalize import normalize_model_for_provider
|
||||
result = normalize_model_for_provider("xiaomi/mimo-v2-pro", "xiaomi")
|
||||
assert result == "mimo-v2-pro"
|
||||
|
||||
def test_normalize_bare_name_unchanged(self):
|
||||
from hermes_cli.model_normalize import normalize_model_for_provider
|
||||
result = normalize_model_for_provider("mimo-v2-pro", "xiaomi")
|
||||
assert result == "mimo-v2-pro"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# URL mapping
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiURLMapping:
|
||||
"""Test URL → provider inference for Xiaomi endpoints."""
|
||||
|
||||
def test_url_to_provider(self):
|
||||
from agent.model_metadata import _URL_TO_PROVIDER
|
||||
assert _URL_TO_PROVIDER.get("api.xiaomimimo.com") == "xiaomi"
|
||||
|
||||
def test_provider_prefixes(self):
|
||||
from agent.model_metadata import _PROVIDER_PREFIXES
|
||||
assert "xiaomi" in _PROVIDER_PREFIXES
|
||||
assert "mimo" in _PROVIDER_PREFIXES
|
||||
assert "xiaomi-mimo" in _PROVIDER_PREFIXES
|
||||
|
||||
def test_infer_from_url(self):
|
||||
from agent.model_metadata import _infer_provider_from_url
|
||||
assert _infer_provider_from_url("https://api.xiaomimimo.com/v1") == "xiaomi"
|
||||
|
||||
def test_infer_from_regional_urls(self):
|
||||
"""Regional token-plan endpoints should also resolve to xiaomi."""
|
||||
from agent.model_metadata import _infer_provider_from_url
|
||||
assert _infer_provider_from_url("https://token-plan-ams.xiaomimimo.com/v1") == "xiaomi"
|
||||
assert _infer_provider_from_url("https://token-plan-cn.xiaomimimo.com/v1") == "xiaomi"
|
||||
assert _infer_provider_from_url("https://token-plan-sgp.xiaomimimo.com/v1") == "xiaomi"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# providers.py
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiProvidersModule:
|
||||
"""Test Xiaomi in the unified providers module."""
|
||||
|
||||
def test_overlay_exists(self):
|
||||
from hermes_cli.providers import HERMES_OVERLAYS
|
||||
assert "xiaomi" in HERMES_OVERLAYS
|
||||
overlay = HERMES_OVERLAYS["xiaomi"]
|
||||
assert overlay.transport == "openai_chat"
|
||||
assert overlay.base_url_env_var == "XIAOMI_BASE_URL"
|
||||
assert not overlay.is_aggregator
|
||||
|
||||
def test_alias_resolves(self):
|
||||
from hermes_cli.providers import normalize_provider
|
||||
assert normalize_provider("mimo") == "xiaomi"
|
||||
assert normalize_provider("xiaomi-mimo") == "xiaomi"
|
||||
|
||||
def test_label(self):
|
||||
from hermes_cli.providers import get_label
|
||||
assert get_label("xiaomi") == "Xiaomi MiMo"
|
||||
|
||||
def test_get_provider(self):
|
||||
pdef = None
|
||||
try:
|
||||
from hermes_cli.providers import get_provider
|
||||
pdef = get_provider("xiaomi")
|
||||
except Exception:
|
||||
pass
|
||||
if pdef is not None:
|
||||
assert pdef.id == "xiaomi"
|
||||
assert pdef.transport == "openai_chat"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auxiliary client
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiAuxiliary:
|
||||
"""Xiaomi auxiliary routing: vision → omni, non-vision → user's main model, never flash."""
|
||||
|
||||
def test_no_flash_in_aux_models(self):
|
||||
"""mimo-v2-flash must NEVER be used for automatic aux routing."""
|
||||
from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS
|
||||
assert "xiaomi" not in _API_KEY_PROVIDER_AUX_MODELS
|
||||
|
||||
def test_vision_model_override(self):
|
||||
"""Xiaomi vision tasks should use mimo-v2-omni (multimodal), not the main model."""
|
||||
from agent.auxiliary_client import _PROVIDER_VISION_MODELS
|
||||
assert "xiaomi" in _PROVIDER_VISION_MODELS
|
||||
assert _PROVIDER_VISION_MODELS["xiaomi"] == "mimo-v2-omni"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Agent init (no SyntaxError, correct api_mode)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestXiaomiDoctor:
|
||||
"""Verify hermes doctor recognizes Xiaomi env vars."""
|
||||
|
||||
def test_provider_env_hints(self):
|
||||
from hermes_cli.doctor import _PROVIDER_ENV_HINTS
|
||||
assert "XIAOMI_API_KEY" in _PROVIDER_ENV_HINTS
|
||||
|
||||
|
||||
class TestXiaomiAgentInit:
|
||||
"""Verify the agent can be constructed with xiaomi provider without errors."""
|
||||
|
||||
def test_no_syntax_errors(self):
|
||||
"""Importing run_agent with xiaomi should not raise."""
|
||||
import importlib
|
||||
importlib.import_module("run_agent")
|
||||
|
||||
def test_api_mode_is_chat_completions(self):
|
||||
from hermes_cli.providers import HERMES_OVERLAYS, TRANSPORT_TO_API_MODE
|
||||
overlay = HERMES_OVERLAYS["xiaomi"]
|
||||
api_mode = TRANSPORT_TO_API_MODE[overlay.transport]
|
||||
assert api_mode == "chat_completions"
|
||||
279
tests/run_agent/test_compression_feasibility.py
Normal file
279
tests/run_agent/test_compression_feasibility.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Tests for _check_compression_model_feasibility() — warns when the
|
||||
auxiliary compression model's context is smaller than the main model's
|
||||
compression threshold.
|
||||
|
||||
Two-phase design:
|
||||
1. __init__ → runs the check, prints via _vprint (CLI), stores warning
|
||||
2. run_conversation (first call) → replays stored warning through
|
||||
status_callback (gateway platforms)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_agent(
|
||||
*,
|
||||
compression_enabled: bool = True,
|
||||
threshold_percent: float = 0.50,
|
||||
main_context: int = 200_000,
|
||||
) -> AIAgent:
|
||||
"""Build a minimal AIAgent with a compressor, skipping __init__."""
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.model = "test-main-model"
|
||||
agent.provider = "openrouter"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "sk-test"
|
||||
agent.quiet_mode = True
|
||||
agent.log_prefix = ""
|
||||
agent.compression_enabled = compression_enabled
|
||||
agent._print_fn = None
|
||||
agent.suppress_status_output = False
|
||||
agent._stream_consumers = []
|
||||
agent._executing_tools = False
|
||||
agent._mute_post_response = False
|
||||
agent.status_callback = None
|
||||
agent.tool_progress_callback = None
|
||||
agent._compression_warning = None
|
||||
|
||||
compressor = MagicMock(spec=ContextCompressor)
|
||||
compressor.context_length = main_context
|
||||
compressor.threshold_tokens = int(main_context * threshold_percent)
|
||||
agent.context_compressor = compressor
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
# ── Core warning logic ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=32_768)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_warns_when_aux_context_below_threshold(mock_get_client, mock_ctx_len):
|
||||
"""Warning emitted when aux model context < main model threshold."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
# threshold = 100,000 — aux has only 32,768
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "google/gemini-3-flash-preview")
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "Compression model" in messages[0]
|
||||
assert "32,768" in messages[0]
|
||||
assert "100,000" in messages[0]
|
||||
assert "will not be possible" in messages[0]
|
||||
# Actionable fix guidance included
|
||||
assert "Fix options" in messages[0]
|
||||
assert "auxiliary:" in messages[0]
|
||||
assert "compression:" in messages[0]
|
||||
assert "threshold:" in messages[0]
|
||||
# Warning stored for gateway replay
|
||||
assert agent._compression_warning is not None
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=200_000)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_no_warning_when_aux_context_sufficient(mock_get_client, mock_ctx_len):
|
||||
"""No warning when aux model context >= main model threshold."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
# threshold = 100,000 — aux has 200,000 (sufficient)
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "google/gemini-2.5-flash")
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(messages) == 0
|
||||
assert agent._compression_warning is None
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_warns_when_no_auxiliary_provider(mock_get_client):
|
||||
"""Warning emitted when no auxiliary provider is configured."""
|
||||
agent = _make_agent()
|
||||
mock_get_client.return_value = (None, None)
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "No auxiliary LLM provider" in messages[0]
|
||||
assert agent._compression_warning is not None
|
||||
|
||||
|
||||
def test_skips_check_when_compression_disabled():
|
||||
"""No check performed when compression is disabled."""
|
||||
agent = _make_agent(compression_enabled=False)
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(messages) == 0
|
||||
assert agent._compression_warning is None
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_exception_does_not_crash(mock_get_client):
|
||||
"""Exceptions in the check are caught — never blocks startup."""
|
||||
agent = _make_agent()
|
||||
mock_get_client.side_effect = RuntimeError("boom")
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
# Should not raise
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
# No user-facing message (error is debug-logged)
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=100_000)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_exact_threshold_boundary_no_warning(mock_get_client, mock_ctx_len):
|
||||
"""No warning when aux context exactly equals the threshold."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "test-model")
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=99_999)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_just_below_threshold_warns(mock_get_client, mock_ctx_len):
|
||||
"""Warning fires when aux context is one token below the threshold."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "small-model")
|
||||
|
||||
messages = []
|
||||
agent._emit_status = lambda msg: messages.append(msg)
|
||||
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "small-model" in messages[0]
|
||||
|
||||
|
||||
# ── Two-phase: __init__ + run_conversation replay ───────────────────
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=32_768)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_warning_stored_for_gateway_replay(mock_get_client, mock_ctx_len):
|
||||
"""__init__ stores the warning; _replay sends it through status_callback."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "google/gemini-3-flash-preview")
|
||||
|
||||
# Phase 1: __init__ — _emit_status prints (CLI) but callback is None
|
||||
vprint_messages = []
|
||||
agent._emit_status = lambda msg: vprint_messages.append(msg)
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert len(vprint_messages) == 1 # CLI got it
|
||||
assert agent._compression_warning is not None # stored for replay
|
||||
|
||||
# Phase 2: gateway wires callback post-init, then run_conversation replays
|
||||
callback_events = []
|
||||
agent.status_callback = lambda ev, msg: callback_events.append((ev, msg))
|
||||
agent._replay_compression_warning()
|
||||
|
||||
assert any(
|
||||
ev == "lifecycle" and "will not be possible" in msg
|
||||
for ev, msg in callback_events
|
||||
)
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=200_000)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_no_replay_when_no_warning(mock_get_client, mock_ctx_len):
|
||||
"""_replay_compression_warning is a no-op when there's no stored warning."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "big-model")
|
||||
|
||||
agent._emit_status = lambda msg: None
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert agent._compression_warning is None
|
||||
|
||||
callback_events = []
|
||||
agent.status_callback = lambda ev, msg: callback_events.append((ev, msg))
|
||||
agent._replay_compression_warning()
|
||||
|
||||
assert len(callback_events) == 0
|
||||
|
||||
|
||||
def test_replay_without_callback_is_noop():
|
||||
"""_replay_compression_warning doesn't crash when status_callback is None."""
|
||||
agent = _make_agent()
|
||||
agent._compression_warning = "some warning"
|
||||
agent.status_callback = None
|
||||
|
||||
# Should not raise
|
||||
agent._replay_compression_warning()
|
||||
|
||||
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=32_768)
|
||||
@patch("agent.auxiliary_client.get_text_auxiliary_client")
|
||||
def test_run_conversation_clears_warning_after_replay(mock_get_client, mock_ctx_len):
|
||||
"""After replay in run_conversation, _compression_warning is cleared
|
||||
so the warning is not sent again on subsequent turns."""
|
||||
agent = _make_agent(main_context=200_000, threshold_percent=0.50)
|
||||
mock_client = MagicMock()
|
||||
mock_client.base_url = "https://openrouter.ai/api/v1"
|
||||
mock_client.api_key = "sk-aux"
|
||||
mock_get_client.return_value = (mock_client, "small-model")
|
||||
|
||||
agent._emit_status = lambda msg: None
|
||||
agent._check_compression_model_feasibility()
|
||||
|
||||
assert agent._compression_warning is not None
|
||||
|
||||
# Simulate what run_conversation does
|
||||
callback_events = []
|
||||
agent.status_callback = lambda ev, msg: callback_events.append((ev, msg))
|
||||
if agent._compression_warning:
|
||||
agent._replay_compression_warning()
|
||||
agent._compression_warning = None # as in run_conversation
|
||||
|
||||
assert len(callback_events) == 1
|
||||
|
||||
# Second turn — nothing replayed
|
||||
callback_events.clear()
|
||||
if agent._compression_warning:
|
||||
agent._replay_compression_warning()
|
||||
agent._compression_warning = None
|
||||
|
||||
assert len(callback_events) == 0
|
||||
|
|
@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def _make_bare_agent(self):
|
||||
"""Create a bare AIAgent via __new__ with all interrupt-related attrs."""
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent._interrupt_requested = False
|
||||
agent._interrupt_message = None
|
||||
agent._execution_thread_id = None # defaults to current thread in set_interrupt
|
||||
agent._active_children = []
|
||||
agent._active_children_lock = threading.Lock()
|
||||
agent.quiet_mode = True
|
||||
return agent
|
||||
|
||||
def test_parent_interrupt_sets_child_flag(self):
|
||||
"""When parent.interrupt() is called, child._interrupt_requested should be set."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent._active_children_lock = threading.Lock()
|
||||
parent.quiet_mode = True
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
child.quiet_mode = True
|
||||
parent = self._make_bare_agent()
|
||||
child = self._make_bare_agent()
|
||||
|
||||
parent._active_children.append(child)
|
||||
|
||||
|
|
@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
assert child._interrupt_message == "new user message"
|
||||
assert is_interrupted() is True
|
||||
|
||||
def test_child_clear_interrupt_at_start_clears_global(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
|
||||
|
||||
This is the intended behavior at startup, but verify it doesn't
|
||||
accidentally clear an interrupt intended for a running child.
|
||||
def test_child_clear_interrupt_at_start_clears_thread(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the
|
||||
per-thread interrupt flag for the current thread.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child = self._make_bare_agent()
|
||||
child._interrupt_requested = True
|
||||
child._interrupt_message = "msg"
|
||||
child.quiet_mode = True
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
|
||||
# Global is set
|
||||
# Interrupt for current thread is set
|
||||
set_interrupt(True)
|
||||
assert is_interrupted() is True
|
||||
|
||||
# child.clear_interrupt() clears both
|
||||
# child.clear_interrupt() clears both instance flag and thread flag
|
||||
child.clear_interrupt()
|
||||
assert child._interrupt_requested is False
|
||||
assert is_interrupted() is False
|
||||
|
||||
def test_interrupt_during_child_api_call_detected(self):
|
||||
"""Interrupt set during _interruptible_api_call is detected within 0.5s."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
child.quiet_mode = True
|
||||
child = self._make_bare_agent()
|
||||
child.api_mode = "chat_completions"
|
||||
child.log_prefix = ""
|
||||
child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
|
||||
|
|
@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
|
||||
def test_concurrent_interrupt_propagation(self):
|
||||
"""Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent._active_children_lock = threading.Lock()
|
||||
parent.quiet_mode = True
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child._active_children_lock = threading.Lock()
|
||||
child.quiet_mode = True
|
||||
parent = self._make_bare_agent()
|
||||
child = self._make_bare_agent()
|
||||
|
||||
# Register child (simulating what _run_single_child does)
|
||||
parent._active_children.append(child)
|
||||
|
|
@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase):
|
|||
set_interrupt(False)
|
||||
|
||||
|
||||
class TestPerThreadInterruptIsolation(unittest.TestCase):
|
||||
"""Verify that interrupting one agent does NOT affect another agent's thread.
|
||||
|
||||
This is the core fix for the gateway cross-session interrupt leak:
|
||||
multiple agents run in separate threads within the same process, and
|
||||
interrupting agent A must not kill agent B's running tools.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def test_interrupt_only_affects_target_thread(self):
|
||||
"""set_interrupt(True, tid) only makes is_interrupted() True on that thread."""
|
||||
results = {}
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def thread_a():
|
||||
"""Agent A's execution thread — will be interrupted."""
|
||||
tid = threading.current_thread().ident
|
||||
results["a_tid"] = tid
|
||||
barrier.wait(timeout=5) # sync with thread B
|
||||
time.sleep(0.2) # let the interrupt arrive
|
||||
results["a_interrupted"] = is_interrupted()
|
||||
|
||||
def thread_b():
|
||||
"""Agent B's execution thread — should NOT be affected."""
|
||||
tid = threading.current_thread().ident
|
||||
results["b_tid"] = tid
|
||||
barrier.wait(timeout=5) # sync with thread A
|
||||
time.sleep(0.2)
|
||||
results["b_interrupted"] = is_interrupted()
|
||||
|
||||
ta = threading.Thread(target=thread_a)
|
||||
tb = threading.Thread(target=thread_b)
|
||||
ta.start()
|
||||
tb.start()
|
||||
|
||||
# Wait for both threads to register their TIDs
|
||||
time.sleep(0.05)
|
||||
while "a_tid" not in results or "b_tid" not in results:
|
||||
time.sleep(0.01)
|
||||
|
||||
# Interrupt ONLY thread A (simulates gateway interrupting agent A)
|
||||
set_interrupt(True, results["a_tid"])
|
||||
|
||||
ta.join(timeout=3)
|
||||
tb.join(timeout=3)
|
||||
|
||||
assert results["a_interrupted"] is True, "Thread A should see the interrupt"
|
||||
assert results["b_interrupted"] is False, "Thread B must NOT see thread A's interrupt"
|
||||
|
||||
def test_clear_interrupt_only_clears_target_thread(self):
|
||||
"""Clearing one thread's interrupt doesn't clear another's."""
|
||||
tid_a = 99990001
|
||||
tid_b = 99990002
|
||||
set_interrupt(True, tid_a)
|
||||
set_interrupt(True, tid_b)
|
||||
|
||||
# Clear only A
|
||||
set_interrupt(False, tid_a)
|
||||
|
||||
# Simulate checking from thread B's perspective
|
||||
from tools.interrupt import _interrupted_threads, _lock
|
||||
with _lock:
|
||||
assert tid_a not in _interrupted_threads
|
||||
assert tid_b in _interrupted_threads
|
||||
|
||||
# Cleanup
|
||||
set_interrupt(False, tid_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -2087,8 +2087,9 @@ class TestRunConversation:
|
|||
assert "Thinking Budget Exhausted" in result["final_response"]
|
||||
assert "/thinkon" in result["final_response"]
|
||||
|
||||
def test_length_empty_content_detected_as_thinking_exhausted(self, agent):
|
||||
"""When finish_reason='length' and content is None/empty, detect exhaustion."""
|
||||
def test_length_empty_content_without_think_tags_retries_normally(self, agent):
|
||||
"""When finish_reason='length' and content is None but no think tags,
|
||||
fall through to normal continuation retry (not thinking-exhaustion)."""
|
||||
self._setup_agent(agent)
|
||||
resp = _mock_response(content=None, finish_reason="length")
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
|
|
@ -2100,12 +2101,10 @@ class TestRunConversation:
|
|||
):
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
# Without think tags, the agent should attempt continuation retries
|
||||
# (up to 3), not immediately fire thinking-exhaustion.
|
||||
assert result["api_calls"] == 3
|
||||
assert result["completed"] is False
|
||||
assert result["api_calls"] == 1
|
||||
assert "reasoning" in result["error"].lower()
|
||||
# User-friendly message is returned
|
||||
assert result["final_response"] is not None
|
||||
assert "Thinking Budget Exhausted" in result["final_response"]
|
||||
|
||||
def test_length_with_tool_calls_returns_partial_without_executing_tools(self, agent):
|
||||
self._setup_agent(agent)
|
||||
|
|
@ -2169,6 +2168,35 @@ class TestRunConversation:
|
|||
mock_hfc.assert_called_once()
|
||||
assert result["final_response"] == "Done!"
|
||||
|
||||
def test_truncated_tool_args_detected_when_finish_reason_not_length(self, agent):
|
||||
"""When a router rewrites finish_reason from 'length' to 'tool_calls',
|
||||
truncated JSON arguments should still be detected and refused rather
|
||||
than wasting 3 retry attempts."""
|
||||
self._setup_agent(agent)
|
||||
agent.valid_tool_names.add("write_file")
|
||||
bad_tc = _mock_tool_call(
|
||||
name="write_file",
|
||||
arguments='{"path":"report.md","content":"partial',
|
||||
call_id="c1",
|
||||
)
|
||||
resp = _mock_response(
|
||||
content="", finish_reason="tool_calls", tool_calls=[bad_tc],
|
||||
)
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
|
||||
with (
|
||||
patch("run_agent.handle_function_call") as mock_handle_function_call,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("write the report")
|
||||
|
||||
assert result["completed"] is False
|
||||
assert result["partial"] is True
|
||||
assert "truncated due to output length limit" in result["error"]
|
||||
mock_handle_function_call.assert_not_called()
|
||||
|
||||
|
||||
class TestRetryExhaustion:
|
||||
"""Regression: retry_count > max_retries was dead code (off-by-one).
|
||||
|
|
|
|||
|
|
@ -1104,3 +1104,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch
|
|||
]
|
||||
assert "enc_first" in encrypted_contents
|
||||
assert "enc_second" in encrypted_contents
|
||||
|
||||
|
||||
def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
"""Duplicate reasoning item IDs across multi-turn incomplete responses
|
||||
must be deduplicated so the Responses API doesn't reject with HTTP 400."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
messages = [
|
||||
{"role": "user", "content": "think hard"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"codex_reasoning_items": [
|
||||
{"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
|
||||
{"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "partial answer",
|
||||
"codex_reasoning_items": [
|
||||
# rs_aaa is duplicated from the previous turn
|
||||
{"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
|
||||
{"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"},
|
||||
],
|
||||
},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
|
||||
reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"]
|
||||
# rs_aaa should appear only once (first occurrence kept)
|
||||
assert reasoning_ids.count("rs_aaa") == 1
|
||||
# rs_bbb and rs_ccc should each appear once
|
||||
assert reasoning_ids.count("rs_bbb") == 1
|
||||
assert reasoning_ids.count("rs_ccc") == 1
|
||||
assert len(reasoning_ids) == 3
|
||||
|
||||
|
||||
def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
"""_preflight_codex_input_items should also deduplicate reasoning items by ID."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
raw_input = [
|
||||
{"role": "user", "content": [{"type": "input_text", "text": "hello"}]},
|
||||
{"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
{"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
|
||||
{"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
|
||||
reasoning_items = [it for it in normalized if it.get("type") == "reasoning"]
|
||||
reasoning_ids = [it["id"] for it in reasoning_items]
|
||||
assert reasoning_ids.count("rs_xyz") == 1
|
||||
assert reasoning_ids.count("rs_zzz") == 1
|
||||
assert len(reasoning_items) == 2
|
||||
|
|
|
|||
158
tests/tools/test_browser_orphan_reaper.py
Normal file
158
tests/tools/test_browser_orphan_reaper.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""Tests for _reap_orphaned_browser_sessions() — kills orphaned agent-browser
|
||||
daemons whose Python parent exited without cleaning up."""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_tmpdir(tmp_path):
|
||||
"""Patch _socket_safe_tmpdir to return a temp dir we control."""
|
||||
with patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)):
|
||||
yield tmp_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_sessions():
|
||||
"""Ensure _active_sessions is empty for each test."""
|
||||
import tools.browser_tool as bt
|
||||
orig = bt._active_sessions.copy()
|
||||
bt._active_sessions.clear()
|
||||
yield
|
||||
bt._active_sessions.clear()
|
||||
bt._active_sessions.update(orig)
|
||||
|
||||
|
||||
def _make_socket_dir(tmpdir, session_name, pid=None):
|
||||
"""Create a fake agent-browser socket directory with optional PID file."""
|
||||
d = tmpdir / f"agent-browser-{session_name}"
|
||||
d.mkdir()
|
||||
if pid is not None:
|
||||
(d / f"{session_name}.pid").write_text(str(pid))
|
||||
return d
|
||||
|
||||
|
||||
class TestReapOrphanedBrowserSessions:
|
||||
"""Tests for the orphan reaper function."""
|
||||
|
||||
def test_no_socket_dirs_is_noop(self, fake_tmpdir):
|
||||
"""No socket dirs => nothing happens, no errors."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
_reap_orphaned_browser_sessions() # should not raise
|
||||
|
||||
def test_stale_dir_without_pid_file_is_removed(self, fake_tmpdir):
|
||||
"""Socket dir with no PID file is cleaned up."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
d = _make_socket_dir(fake_tmpdir, "h_abc1234567")
|
||||
assert d.exists()
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
||||
def test_stale_dir_with_dead_pid_is_removed(self, fake_tmpdir):
|
||||
"""Socket dir whose daemon PID is dead gets cleaned up."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
d = _make_socket_dir(fake_tmpdir, "h_dead123456", pid=999999999)
|
||||
assert d.exists()
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
||||
def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir):
|
||||
"""Alive daemon not tracked by _active_sessions gets SIGTERM."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_orphan12345", pid=12345)
|
||||
|
||||
kill_calls = []
|
||||
original_kill = os.kill
|
||||
|
||||
def mock_kill(pid, sig):
|
||||
kill_calls.append((pid, sig))
|
||||
if sig == 0:
|
||||
return # pretend process exists
|
||||
# Don't actually kill anything
|
||||
|
||||
with patch("os.kill", side_effect=mock_kill):
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Should have checked existence (sig 0) then killed (SIGTERM)
|
||||
assert (12345, 0) in kill_calls
|
||||
assert (12345, signal.SIGTERM) in kill_calls
|
||||
|
||||
def test_tracked_session_is_not_reaped(self, fake_tmpdir):
|
||||
"""Sessions tracked in _active_sessions are left alone."""
|
||||
import tools.browser_tool as bt
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
session_name = "h_tracked1234"
|
||||
d = _make_socket_dir(fake_tmpdir, session_name, pid=12345)
|
||||
|
||||
# Register the session as actively tracked
|
||||
bt._active_sessions["some_task"] = {"session_name": session_name}
|
||||
|
||||
kill_calls = []
|
||||
|
||||
def mock_kill(pid, sig):
|
||||
kill_calls.append((pid, sig))
|
||||
|
||||
with patch("os.kill", side_effect=mock_kill):
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Should NOT have tried to kill anything
|
||||
assert len(kill_calls) == 0
|
||||
# Dir should still exist
|
||||
assert d.exists()
|
||||
|
||||
def test_permission_error_on_kill_check_skips(self, fake_tmpdir):
|
||||
"""If we can't check the PID (PermissionError), skip it."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_perm1234567", pid=12345)
|
||||
|
||||
def mock_kill(pid, sig):
|
||||
if sig == 0:
|
||||
raise PermissionError("not our process")
|
||||
|
||||
with patch("os.kill", side_effect=mock_kill):
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Dir should still exist (we didn't touch someone else's process)
|
||||
assert d.exists()
|
||||
|
||||
def test_cdp_sessions_are_also_reaped(self, fake_tmpdir):
|
||||
"""CDP sessions (cdp_ prefix) are also scanned."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "cdp_abc1234567")
|
||||
assert d.exists()
|
||||
_reap_orphaned_browser_sessions()
|
||||
# No PID file → cleaned up
|
||||
assert not d.exists()
|
||||
|
||||
def test_non_hermes_dirs_are_ignored(self, fake_tmpdir):
|
||||
"""Socket dirs that don't match our naming pattern are left alone."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
# Create a dir that doesn't match h_* or cdp_* pattern
|
||||
d = fake_tmpdir / "agent-browser-other_session"
|
||||
d.mkdir()
|
||||
(d / "other_session.pid").write_text("12345")
|
||||
|
||||
_reap_orphaned_browser_sessions()
|
||||
|
||||
# Should NOT be touched
|
||||
assert d.exists()
|
||||
|
||||
def test_corrupt_pid_file_is_cleaned(self, fake_tmpdir):
|
||||
"""PID file with non-integer content is cleaned up."""
|
||||
from tools.browser_tool import _reap_orphaned_browser_sessions
|
||||
|
||||
d = _make_socket_dir(fake_tmpdir, "h_corrupt1234")
|
||||
(d / "h_corrupt1234.pid").write_text("not-a-number")
|
||||
|
||||
_reap_orphaned_browser_sessions()
|
||||
assert not d.exists()
|
||||
|
|
@ -1,9 +1,6 @@
|
|||
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
|
@ -42,6 +39,19 @@ def checkpoint_base(tmp_path):
|
|||
return tmp_path / "checkpoints"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_home(tmp_path, monkeypatch):
|
||||
"""Set a deterministic fake home for expanduser/path-home behavior."""
|
||||
home = tmp_path / "home"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HOME", str(home))
|
||||
monkeypatch.setenv("USERPROFILE", str(home))
|
||||
monkeypatch.delenv("HOMEDRIVE", raising=False)
|
||||
monkeypatch.delenv("HOMEPATH", raising=False)
|
||||
monkeypatch.setattr(Path, "home", classmethod(lambda cls: home))
|
||||
return home
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mgr(work_dir, checkpoint_base, monkeypatch):
|
||||
"""CheckpointManager with redirected checkpoint base."""
|
||||
|
|
@ -78,6 +88,16 @@ class TestShadowRepoPath:
|
|||
p = _shadow_repo_path(str(work_dir))
|
||||
assert str(p).startswith(str(checkpoint_base))
|
||||
|
||||
def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
expanded_path = str(project)
|
||||
|
||||
assert _shadow_repo_path(tilde_path) == _shadow_repo_path(expanded_path)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Shadow repo init
|
||||
|
|
@ -221,6 +241,20 @@ class TestListCheckpoints:
|
|||
assert result[0]["reason"] == "third"
|
||||
assert result[2]["reason"] == "first"
|
||||
|
||||
def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
mgr = CheckpointManager(enabled=True, max_snapshots=50)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
(project / "main.py").write_text("v1\n")
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
assert mgr.ensure_checkpoint(tilde_path, "initial") is True
|
||||
|
||||
listed = mgr.list_checkpoints(str(project))
|
||||
assert len(listed) == 1
|
||||
assert listed[0]["reason"] == "initial"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CheckpointManager — restoring
|
||||
|
|
@ -271,6 +305,28 @@ class TestRestore:
|
|||
assert len(all_cps) >= 2
|
||||
assert "pre-rollback" in all_cps[0]["reason"]
|
||||
|
||||
def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
mgr = CheckpointManager(enabled=True, max_snapshots=50)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
file_path = project / "main.py"
|
||||
file_path.write_text("original\n")
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
assert mgr.ensure_checkpoint(tilde_path, "initial") is True
|
||||
mgr.new_turn()
|
||||
|
||||
file_path.write_text("changed\n")
|
||||
checkpoints = mgr.list_checkpoints(str(project))
|
||||
diff_result = mgr.diff(tilde_path, checkpoints[0]["hash"])
|
||||
assert diff_result["success"] is True
|
||||
assert "main.py" in diff_result["diff"]
|
||||
|
||||
restore_result = mgr.restore(tilde_path, checkpoints[0]["hash"])
|
||||
assert restore_result["success"] is True
|
||||
assert file_path.read_text() == "original\n"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CheckpointManager — working dir resolution
|
||||
|
|
@ -310,6 +366,19 @@ class TestWorkingDirResolution:
|
|||
result = mgr.get_working_dir_for_path(str(filepath))
|
||||
assert result == str(filepath.parent)
|
||||
|
||||
def test_resolves_tilde_path_to_project_root(self, fake_home):
|
||||
mgr = CheckpointManager(enabled=True)
|
||||
project = fake_home / "myproject"
|
||||
project.mkdir()
|
||||
(project / "pyproject.toml").write_text("[project]\n")
|
||||
subdir = project / "src"
|
||||
subdir.mkdir()
|
||||
filepath = subdir / "main.py"
|
||||
filepath.write_text("x\n")
|
||||
|
||||
result = mgr.get_working_dir_for_path(f"~/{project.name}/src/main.py")
|
||||
assert result == str(project)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Git env isolation
|
||||
|
|
@ -333,6 +402,14 @@ class TestGitEnvIsolation:
|
|||
env = _git_env(shadow, str(tmp_path))
|
||||
assert "GIT_INDEX_FILE" not in env
|
||||
|
||||
def test_expands_tilde_in_work_tree(self, fake_home, tmp_path):
|
||||
shadow = tmp_path / "shadow"
|
||||
work = fake_home / "work"
|
||||
work.mkdir()
|
||||
|
||||
env = _git_env(shadow, f"~/{work.name}")
|
||||
assert env["GIT_WORK_TREE"] == str(work.resolve())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_checkpoint_list
|
||||
|
|
@ -384,6 +461,8 @@ class TestErrorResilience:
|
|||
assert result is False
|
||||
|
||||
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
|
||||
work = tmp_path / "work"
|
||||
work.mkdir()
|
||||
completed = subprocess.CompletedProcess(
|
||||
args=["git", "diff", "--cached", "--quiet"],
|
||||
returncode=1,
|
||||
|
|
@ -395,7 +474,7 @@ class TestErrorResilience:
|
|||
ok, stdout, stderr = _run_git(
|
||||
["diff", "--cached", "--quiet"],
|
||||
tmp_path / "shadow",
|
||||
str(tmp_path / "work"),
|
||||
str(work),
|
||||
allowed_returncodes={1},
|
||||
)
|
||||
assert ok is False
|
||||
|
|
@ -403,6 +482,38 @@ class TestErrorResilience:
|
|||
assert stderr == ""
|
||||
assert not caplog.records
|
||||
|
||||
def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog):
|
||||
missing = tmp_path / "missing"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["status"],
|
||||
tmp_path / "shadow",
|
||||
str(missing),
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert "working directory not found" in stderr
|
||||
assert not any("Git executable not found" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog):
|
||||
work = tmp_path / "work"
|
||||
work.mkdir()
|
||||
|
||||
def raise_missing_git(*args, **kwargs):
|
||||
raise FileNotFoundError(2, "No such file or directory", "git")
|
||||
|
||||
monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git)
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["status"],
|
||||
tmp_path / "shadow",
|
||||
str(work),
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert stderr == "git not found"
|
||||
assert any("Git executable not found" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
|
||||
"""Checkpoint failures should never raise — they're silently logged."""
|
||||
def broken_run_git(*args, **kwargs):
|
||||
|
|
@ -411,3 +522,68 @@ class TestErrorResilience:
|
|||
# Should not raise
|
||||
result = mgr.ensure_checkpoint(str(work_dir), "test")
|
||||
assert result is False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Security / Input validation
|
||||
# =========================================================================
|
||||
|
||||
class TestSecurity:
|
||||
def test_restore_rejects_argument_injection(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Try to pass a git flag as a commit hash
|
||||
result = mgr.restore(str(work_dir), "--patch")
|
||||
assert result["success"] is False
|
||||
assert "Invalid commit hash" in result["error"]
|
||||
assert "must not start with '-'" in result["error"]
|
||||
|
||||
result = mgr.restore(str(work_dir), "-p")
|
||||
assert result["success"] is False
|
||||
assert "Invalid commit hash" in result["error"]
|
||||
|
||||
def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Git hashes should not contain characters like ;, &, |
|
||||
result = mgr.restore(str(work_dir), "abc; rm -rf /")
|
||||
assert result["success"] is False
|
||||
assert "expected 4-64 hex characters" in result["error"]
|
||||
|
||||
result = mgr.diff(str(work_dir), "abc&def")
|
||||
assert result["success"] is False
|
||||
assert "expected 4-64 hex characters" in result["error"]
|
||||
|
||||
def test_restore_rejects_path_traversal(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Real commit hash but malicious path
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
# Absolute path outside
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="/etc/passwd")
|
||||
assert result["success"] is False
|
||||
assert "got absolute path" in result["error"]
|
||||
|
||||
# Relative traversal outside path
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="../outside_file.txt")
|
||||
assert result["success"] is False
|
||||
assert "escapes the working directory" in result["error"]
|
||||
|
||||
def test_restore_accepts_valid_file_path(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
# Valid path inside directory
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="main.py")
|
||||
assert result["success"] is True
|
||||
|
||||
# Another valid path with subdirectories
|
||||
(work_dir / "subdir").mkdir()
|
||||
(work_dir / "subdir" / "test.txt").write_text("hello")
|
||||
mgr.new_turn()
|
||||
mgr.ensure_checkpoint(str(work_dir), "second")
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="subdir/test.txt")
|
||||
assert result["success"] is True
|
||||
|
|
|
|||
|
|
@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase):
|
|||
@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
|
||||
class TestInterruptHandling(unittest.TestCase):
|
||||
def test_interrupt_event_stops_execution(self):
|
||||
"""When _interrupt_event is set, execute_code should stop the script."""
|
||||
"""When interrupt is set for the execution thread, execute_code should stop."""
|
||||
code = "import time; time.sleep(60); print('should not reach')"
|
||||
from tools.interrupt import set_interrupt
|
||||
|
||||
# Capture the main thread ID so we can target the interrupt correctly.
|
||||
# execute_code runs in the current thread; set_interrupt needs its ID.
|
||||
main_tid = threading.current_thread().ident
|
||||
|
||||
def set_interrupt_after_delay():
|
||||
import time as _t
|
||||
_t.sleep(1)
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
_interrupt_event.set()
|
||||
set_interrupt(True, main_tid)
|
||||
|
||||
t = threading.Thread(target=set_interrupt_after_delay, daemon=True)
|
||||
t.start()
|
||||
|
|
@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase):
|
|||
self.assertEqual(result["status"], "interrupted")
|
||||
self.assertIn("interrupted", result["output"])
|
||||
finally:
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
_interrupt_event.clear()
|
||||
set_interrupt(False, main_tid)
|
||||
t.join(timeout=3)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,8 @@ class TestCheckpointNotify:
|
|||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_user_id": "u123",
|
||||
"watcher_user_name": "alice",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 5,
|
||||
"notify_on_complete": True,
|
||||
|
|
@ -236,6 +238,8 @@ class TestCheckpointNotify:
|
|||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 1
|
||||
assert registry.pending_watchers[0]["notify_on_complete"] is True
|
||||
assert registry.pending_watchers[0]["user_id"] == "u123"
|
||||
assert registry.pending_watchers[0]["user_name"] == "alice"
|
||||
|
||||
def test_recover_defaults_false(self, registry, tmp_path):
|
||||
"""Old checkpoint entries without the field default to False."""
|
||||
|
|
|
|||
|
|
@ -438,6 +438,8 @@ class TestCheckpoint:
|
|||
s = _make_session()
|
||||
s.watcher_platform = "telegram"
|
||||
s.watcher_chat_id = "999"
|
||||
s.watcher_user_id = "u123"
|
||||
s.watcher_user_name = "alice"
|
||||
s.watcher_thread_id = "42"
|
||||
s.watcher_interval = 60
|
||||
registry._running[s.id] = s
|
||||
|
|
@ -447,6 +449,8 @@ class TestCheckpoint:
|
|||
assert len(data) == 1
|
||||
assert data[0]["watcher_platform"] == "telegram"
|
||||
assert data[0]["watcher_chat_id"] == "999"
|
||||
assert data[0]["watcher_user_id"] == "u123"
|
||||
assert data[0]["watcher_user_name"] == "alice"
|
||||
assert data[0]["watcher_thread_id"] == "42"
|
||||
assert data[0]["watcher_interval"] == 60
|
||||
|
||||
|
|
@ -460,6 +464,8 @@ class TestCheckpoint:
|
|||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_user_id": "u123",
|
||||
"watcher_user_name": "alice",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 60,
|
||||
}]))
|
||||
|
|
@ -471,6 +477,8 @@ class TestCheckpoint:
|
|||
assert w["session_id"] == "proc_live"
|
||||
assert w["platform"] == "telegram"
|
||||
assert w["chat_id"] == "123"
|
||||
assert w["user_id"] == "u123"
|
||||
assert w["user_name"] == "alice"
|
||||
assert w["thread_id"] == "42"
|
||||
assert w["check_interval"] == 60
|
||||
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ word word
|
|||
result = _patch_skill("my-skill", "old text", "new text", file_path="references/evil.md")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert "escapes" in result["error"].lower()
|
||||
assert outside_file.read_text() == "old text here"
|
||||
|
||||
|
||||
|
|
@ -412,7 +412,7 @@ class TestWriteFile:
|
|||
result = _write_file("my-skill", "references/escape/owned.md", "malicious")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert "escapes" in result["error"].lower()
|
||||
assert not (outside_dir / "owned.md").exists()
|
||||
|
||||
|
||||
|
|
@ -449,7 +449,7 @@ class TestRemoveFile:
|
|||
result = _remove_file("my-skill", "references/escape/keep.txt")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "boundary" in result["error"].lower()
|
||||
assert "escapes" in result["error"].lower()
|
||||
assert outside_file.exists()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,34 @@ class TestWriteToSandbox:
|
|||
cmd = env.execute.call_args[0][0]
|
||||
assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd
|
||||
|
||||
def test_path_with_spaces_is_quoted(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
remote_path = "/tmp/hermes results/abc file.txt"
|
||||
_write_to_sandbox("content", remote_path, env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
assert "'/tmp/hermes results'" in cmd
|
||||
assert "'/tmp/hermes results/abc file.txt'" in cmd
|
||||
|
||||
def test_shell_metacharacters_neutralized(self):
|
||||
"""Paths with shell metacharacters must be quoted to prevent injection."""
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
malicious_path = "/tmp/hermes-results/$(whoami).txt"
|
||||
_write_to_sandbox("content", malicious_path, env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
# The $() must not appear unquoted — shlex.quote wraps it
|
||||
assert "'/tmp/hermes-results/$(whoami).txt'" in cmd
|
||||
|
||||
def test_semicolon_injection_neutralized(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
malicious_path = "/tmp/x; rm -rf /; echo .txt"
|
||||
_write_to_sandbox("content", malicious_path, env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
# The semicolons must be inside quotes, not acting as command separators
|
||||
assert "'/tmp/x; rm -rf /; echo .txt'" in cmd
|
||||
|
||||
|
||||
class TestResolveStorageDir:
|
||||
def test_defaults_to_storage_dir_without_env(self):
|
||||
|
|
|
|||
|
|
@ -769,6 +769,62 @@ class TestResizeImageForVision:
|
|||
assert _RESIZE_TARGET_BYTES == 5 * 1024 * 1024
|
||||
assert _MAX_BASE64_BYTES > _RESIZE_TARGET_BYTES
|
||||
|
||||
def test_extreme_aspect_ratio_preserved(self, tmp_path):
|
||||
"""Extreme aspect ratios should be preserved during resize."""
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
pytest.skip("Pillow not installed")
|
||||
# Very wide panorama: 8000x200
|
||||
img = Image.new("RGB", (8000, 200), (100, 150, 200))
|
||||
path = tmp_path / "panorama.png"
|
||||
img.save(path, "PNG")
|
||||
|
||||
result = _resize_image_for_vision(path, mime_type="image/png",
|
||||
max_base64_bytes=50_000)
|
||||
assert result.startswith("data:image/")
|
||||
# Decode and check aspect ratio is roughly preserved
|
||||
import base64
|
||||
header, b64data = result.split(",", 1)
|
||||
raw = base64.b64decode(b64data)
|
||||
from io import BytesIO
|
||||
resized = Image.open(BytesIO(raw))
|
||||
original_ratio = 8000 / 200 # 40:1
|
||||
resized_ratio = resized.width / resized.height if resized.height > 0 else 0
|
||||
# Allow some tolerance (floor clamping), but ratio should stay above 10:1
|
||||
# With independent halving, ratio would collapse to ~1:1. Proportional
|
||||
# scaling should keep it well above 10.
|
||||
assert resized_ratio > 10, (
|
||||
f"Aspect ratio collapsed: {resized.width}x{resized.height} "
|
||||
f"(ratio {resized_ratio:.1f}, expected >10)"
|
||||
)
|
||||
|
||||
def test_tall_narrow_image_preserved(self, tmp_path):
|
||||
"""Tall narrow images should also preserve aspect ratio."""
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
pytest.skip("Pillow not installed")
|
||||
# Very tall: 200x6000
|
||||
img = Image.new("RGB", (200, 6000), (200, 100, 50))
|
||||
path = tmp_path / "tall.png"
|
||||
img.save(path, "PNG")
|
||||
|
||||
result = _resize_image_for_vision(path, mime_type="image/png",
|
||||
max_base64_bytes=50_000)
|
||||
assert result.startswith("data:image/")
|
||||
import base64
|
||||
from io import BytesIO
|
||||
header, b64data = result.split(",", 1)
|
||||
raw = base64.b64decode(b64data)
|
||||
resized = Image.open(BytesIO(raw))
|
||||
original_ratio = 6000 / 200 # 30:1 (h/w)
|
||||
resized_ratio = resized.height / resized.width if resized.width > 0 else 0
|
||||
assert resized_ratio > 5, (
|
||||
f"Aspect ratio collapsed: {resized.width}x{resized.height} "
|
||||
f"(h/w ratio {resized_ratio:.1f}, expected >5)"
|
||||
)
|
||||
|
||||
def test_no_pillow_returns_original(self, tmp_path):
|
||||
"""Without Pillow, oversized images should be returned as-is."""
|
||||
# Create a dummy file
|
||||
|
|
|
|||
|
|
@ -473,13 +473,104 @@ def _cleanup_inactive_browser_sessions():
|
|||
logger.warning("Error cleaning up inactive session %s: %s", task_id, e)
|
||||
|
||||
|
||||
def _reap_orphaned_browser_sessions():
|
||||
"""Scan for orphaned agent-browser daemon processes from previous runs.
|
||||
|
||||
When the Python process that created a browser session exits uncleanly
|
||||
(SIGKILL, crash, gateway restart), the in-memory ``_active_sessions``
|
||||
tracking is lost but the node + Chromium processes keep running.
|
||||
|
||||
This function scans the tmp directory for ``agent-browser-*`` socket dirs
|
||||
left behind by previous runs, reads the daemon PID files, and kills any
|
||||
daemons that are still alive but not tracked by the current process.
|
||||
|
||||
Called once on cleanup-thread startup — not every 30 seconds — to avoid
|
||||
races with sessions being actively created.
|
||||
"""
|
||||
import glob
|
||||
|
||||
tmpdir = _socket_safe_tmpdir()
|
||||
pattern = os.path.join(tmpdir, "agent-browser-h_*")
|
||||
socket_dirs = glob.glob(pattern)
|
||||
# Also pick up CDP sessions
|
||||
socket_dirs += glob.glob(os.path.join(tmpdir, "agent-browser-cdp_*"))
|
||||
|
||||
if not socket_dirs:
|
||||
return
|
||||
|
||||
# Build set of session_names currently tracked by this process
|
||||
with _cleanup_lock:
|
||||
tracked_names = {
|
||||
info.get("session_name")
|
||||
for info in _active_sessions.values()
|
||||
if info.get("session_name")
|
||||
}
|
||||
|
||||
reaped = 0
|
||||
for socket_dir in socket_dirs:
|
||||
dir_name = os.path.basename(socket_dir)
|
||||
# dir_name is "agent-browser-{session_name}"
|
||||
session_name = dir_name.removeprefix("agent-browser-")
|
||||
if not session_name:
|
||||
continue
|
||||
|
||||
# Skip sessions that we are actively tracking
|
||||
if session_name in tracked_names:
|
||||
continue
|
||||
|
||||
pid_file = os.path.join(socket_dir, f"{session_name}.pid")
|
||||
if not os.path.isfile(pid_file):
|
||||
# No PID file — just a stale dir, remove it
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
continue
|
||||
|
||||
try:
|
||||
daemon_pid = int(Path(pid_file).read_text().strip())
|
||||
except (ValueError, OSError):
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
continue
|
||||
|
||||
# Check if the daemon is still alive
|
||||
try:
|
||||
os.kill(daemon_pid, 0) # signal 0 = existence check
|
||||
except ProcessLookupError:
|
||||
# Already dead, just clean up the dir
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
continue
|
||||
except PermissionError:
|
||||
# Alive but owned by someone else — leave it alone
|
||||
continue
|
||||
|
||||
# Daemon is alive and not tracked — orphan. Kill it.
|
||||
try:
|
||||
os.kill(daemon_pid, signal.SIGTERM)
|
||||
logger.info("Reaped orphaned browser daemon PID %d (session %s)",
|
||||
daemon_pid, session_name)
|
||||
reaped += 1
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# Clean up the socket directory
|
||||
shutil.rmtree(socket_dir, ignore_errors=True)
|
||||
|
||||
if reaped:
|
||||
logger.info("Reaped %d orphaned browser session(s) from previous run(s)", reaped)
|
||||
|
||||
|
||||
def _browser_cleanup_thread_worker():
|
||||
"""
|
||||
Background thread that periodically cleans up inactive browser sessions.
|
||||
|
||||
Runs every 30 seconds and checks for sessions that haven't been used
|
||||
within the BROWSER_SESSION_INACTIVITY_TIMEOUT period.
|
||||
On first run, also reaps orphaned sessions from previous process lifetimes.
|
||||
"""
|
||||
# One-time orphan reap on startup
|
||||
try:
|
||||
_reap_orphaned_browser_sessions()
|
||||
except Exception as e:
|
||||
logger.warning("Orphan reap error: %s", e)
|
||||
|
||||
while _cleanup_running:
|
||||
try:
|
||||
_cleanup_inactive_browser_sessions()
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ into the user's project directory.
|
|||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
|
@ -64,23 +65,72 @@ _GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", "
|
|||
# Max files to snapshot — skip huge directories to avoid slowdowns.
|
||||
_MAX_FILES = 50_000
|
||||
|
||||
# Valid git commit hash pattern: 4–40 hex chars (short or full SHA-1/SHA-256).
|
||||
_COMMIT_HASH_RE = re.compile(r'^[0-9a-fA-F]{4,64}$')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _validate_commit_hash(commit_hash: str) -> Optional[str]:
|
||||
"""Validate a commit hash to prevent git argument injection.
|
||||
|
||||
Returns an error string if invalid, None if valid.
|
||||
Values starting with '-' would be interpreted as git flags
|
||||
(e.g., '--patch', '-p') instead of revision specifiers.
|
||||
"""
|
||||
if not commit_hash or not commit_hash.strip():
|
||||
return "Empty commit hash"
|
||||
if commit_hash.startswith("-"):
|
||||
return f"Invalid commit hash (must not start with '-'): {commit_hash!r}"
|
||||
if not _COMMIT_HASH_RE.match(commit_hash):
|
||||
return f"Invalid commit hash (expected 4-64 hex characters): {commit_hash!r}"
|
||||
return None
|
||||
|
||||
|
||||
def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]:
|
||||
"""Validate a file path to prevent path traversal outside the working directory.
|
||||
|
||||
Returns an error string if invalid, None if valid.
|
||||
"""
|
||||
if not file_path or not file_path.strip():
|
||||
return "Empty file path"
|
||||
# Reject absolute paths — restore targets must be relative to the workdir
|
||||
if os.path.isabs(file_path):
|
||||
return f"File path must be relative, got absolute path: {file_path!r}"
|
||||
# Resolve and check containment within working_dir
|
||||
abs_workdir = _normalize_path(working_dir)
|
||||
resolved = (abs_workdir / file_path).resolve()
|
||||
try:
|
||||
resolved.relative_to(abs_workdir)
|
||||
except ValueError:
|
||||
return f"File path escapes the working directory via traversal: {file_path!r}"
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shadow repo helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _normalize_path(path_value: str) -> Path:
|
||||
"""Return a canonical absolute path for checkpoint operations."""
|
||||
return Path(path_value).expanduser().resolve()
|
||||
|
||||
|
||||
def _shadow_repo_path(working_dir: str) -> Path:
|
||||
"""Deterministic shadow repo path: sha256(abs_path)[:16]."""
|
||||
abs_path = str(Path(working_dir).resolve())
|
||||
abs_path = str(_normalize_path(working_dir))
|
||||
dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16]
|
||||
return CHECKPOINT_BASE / dir_hash
|
||||
|
||||
|
||||
def _git_env(shadow_repo: Path, working_dir: str) -> dict:
|
||||
"""Build env dict that redirects git to the shadow repo."""
|
||||
normalized_working_dir = _normalize_path(working_dir)
|
||||
env = os.environ.copy()
|
||||
env["GIT_DIR"] = str(shadow_repo)
|
||||
env["GIT_WORK_TREE"] = str(Path(working_dir).resolve())
|
||||
env["GIT_WORK_TREE"] = str(normalized_working_dir)
|
||||
env.pop("GIT_INDEX_FILE", None)
|
||||
env.pop("GIT_NAMESPACE", None)
|
||||
env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None)
|
||||
|
|
@ -100,7 +150,17 @@ def _run_git(
|
|||
exits while preserving the normal ``ok = (returncode == 0)`` contract.
|
||||
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
|
||||
"""
|
||||
env = _git_env(shadow_repo, working_dir)
|
||||
normalized_working_dir = _normalize_path(working_dir)
|
||||
if not normalized_working_dir.exists():
|
||||
msg = f"working directory not found: {normalized_working_dir}"
|
||||
logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg)
|
||||
return False, "", msg
|
||||
if not normalized_working_dir.is_dir():
|
||||
msg = f"working directory is not a directory: {normalized_working_dir}"
|
||||
logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg)
|
||||
return False, "", msg
|
||||
|
||||
env = _git_env(shadow_repo, str(normalized_working_dir))
|
||||
cmd = ["git"] + list(args)
|
||||
allowed_returncodes = allowed_returncodes or set()
|
||||
try:
|
||||
|
|
@ -110,7 +170,7 @@ def _run_git(
|
|||
text=True,
|
||||
timeout=timeout,
|
||||
env=env,
|
||||
cwd=str(Path(working_dir).resolve()),
|
||||
cwd=str(normalized_working_dir),
|
||||
)
|
||||
ok = result.returncode == 0
|
||||
stdout = result.stdout.strip()
|
||||
|
|
@ -125,9 +185,14 @@ def _run_git(
|
|||
msg = f"git timed out after {timeout}s: {' '.join(cmd)}"
|
||||
logger.error(msg, exc_info=True)
|
||||
return False, "", msg
|
||||
except FileNotFoundError:
|
||||
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
|
||||
return False, "", "git not found"
|
||||
except FileNotFoundError as exc:
|
||||
missing_target = getattr(exc, "filename", None)
|
||||
if missing_target == "git":
|
||||
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
|
||||
return False, "", "git not found"
|
||||
msg = f"working directory not found: {normalized_working_dir}"
|
||||
logger.error("Git command failed before execution: %s (%s)", " ".join(cmd), msg, exc_info=True)
|
||||
return False, "", msg
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True)
|
||||
return False, "", str(exc)
|
||||
|
|
@ -154,7 +219,7 @@ def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]:
|
|||
)
|
||||
|
||||
(shadow_repo / "HERMES_WORKDIR").write_text(
|
||||
str(Path(working_dir).resolve()) + "\n", encoding="utf-8"
|
||||
str(_normalize_path(working_dir)) + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir)
|
||||
|
|
@ -229,7 +294,7 @@ class CheckpointManager:
|
|||
if not self._git_available:
|
||||
return False
|
||||
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
|
||||
# Skip root, home, and other overly broad directories
|
||||
if abs_dir in ("/", str(Path.home())):
|
||||
|
|
@ -254,7 +319,7 @@ class CheckpointManager:
|
|||
Returns a list of dicts with keys: hash, short_hash, timestamp, reason,
|
||||
files_changed, insertions, deletions. Most recent first.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
|
|
@ -311,7 +376,12 @@ class CheckpointManager:
|
|||
|
||||
Returns dict with success, diff text, and stat summary.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
# Validate commit_hash to prevent git argument injection
|
||||
hash_err = _validate_commit_hash(commit_hash)
|
||||
if hash_err:
|
||||
return {"success": False, "error": hash_err}
|
||||
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
|
|
@ -364,7 +434,19 @@ class CheckpointManager:
|
|||
|
||||
Returns dict with success/error info.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
# Validate commit_hash to prevent git argument injection
|
||||
hash_err = _validate_commit_hash(commit_hash)
|
||||
if hash_err:
|
||||
return {"success": False, "error": hash_err}
|
||||
|
||||
abs_dir = str(_normalize_path(working_dir))
|
||||
|
||||
# Validate file_path to prevent path traversal outside the working dir
|
||||
if file_path:
|
||||
path_err = _validate_file_path(file_path, abs_dir)
|
||||
if path_err:
|
||||
return {"success": False, "error": path_err}
|
||||
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
|
|
@ -413,7 +495,7 @@ class CheckpointManager:
|
|||
(directory containing .git, pyproject.toml, package.json, etc.).
|
||||
Falls back to the file's parent directory.
|
||||
"""
|
||||
path = Path(file_path).resolve()
|
||||
path = _normalize_path(file_path)
|
||||
if path.is_dir():
|
||||
candidate = path
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -924,8 +924,8 @@ def execute_code(
|
|||
|
||||
# --- Local execution path (UDS) --- below this line is unchanged ---
|
||||
|
||||
# Import interrupt event from terminal_tool (cooperative cancellation)
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
# Import per-thread interrupt check (cooperative cancellation)
|
||||
from tools.interrupt import is_interrupted as _is_interrupted
|
||||
|
||||
# Resolve config
|
||||
_cfg = _load_config()
|
||||
|
|
@ -1114,7 +1114,7 @@ def execute_code(
|
|||
|
||||
status = "success"
|
||||
while proc.poll() is None:
|
||||
if _interrupt_event.is_set():
|
||||
if _is_interrupted():
|
||||
_kill_process_group(proc)
|
||||
status = "interrupted"
|
||||
break
|
||||
|
|
|
|||
|
|
@ -80,20 +80,18 @@ def register_credential_file(
|
|||
|
||||
# Resolve symlinks and normalise ``..`` before the containment check so
|
||||
# that traversal like ``../. ssh/id_rsa`` cannot escape HERMES_HOME.
|
||||
try:
|
||||
resolved = host_path.resolve()
|
||||
hermes_home_resolved = hermes_home.resolve()
|
||||
resolved.relative_to(hermes_home_resolved) # raises ValueError if outside
|
||||
except ValueError:
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
containment_error = validate_within_dir(host_path, hermes_home)
|
||||
if containment_error:
|
||||
logger.warning(
|
||||
"credential_files: rejected path traversal %r "
|
||||
"(resolves to %s, outside HERMES_HOME %s)",
|
||||
"credential_files: rejected path traversal %r (%s)",
|
||||
relative_path,
|
||||
resolved,
|
||||
hermes_home_resolved,
|
||||
containment_error,
|
||||
)
|
||||
return False
|
||||
|
||||
resolved = host_path.resolve()
|
||||
if not resolved.is_file():
|
||||
logger.debug("credential_files: skipping %s (not found)", resolved)
|
||||
return False
|
||||
|
|
@ -142,7 +140,8 @@ def _load_config_files() -> List[Dict[str, str]]:
|
|||
cfg = read_raw_config()
|
||||
cred_files = cfg.get("terminal", {}).get("credential_files")
|
||||
if isinstance(cred_files, list):
|
||||
hermes_home_resolved = hermes_home.resolve()
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
for item in cred_files:
|
||||
if isinstance(item, str) and item.strip():
|
||||
rel = item.strip()
|
||||
|
|
@ -151,20 +150,19 @@ def _load_config_files() -> List[Dict[str, str]]:
|
|||
"credential_files: rejected absolute config path %r", rel,
|
||||
)
|
||||
continue
|
||||
host_path = (hermes_home / rel).resolve()
|
||||
try:
|
||||
host_path.relative_to(hermes_home_resolved)
|
||||
except ValueError:
|
||||
host_path = hermes_home / rel
|
||||
containment_error = validate_within_dir(host_path, hermes_home)
|
||||
if containment_error:
|
||||
logger.warning(
|
||||
"credential_files: rejected config path traversal %r "
|
||||
"(resolves to %s, outside HERMES_HOME %s)",
|
||||
rel, host_path, hermes_home_resolved,
|
||||
"credential_files: rejected config path traversal %r (%s)",
|
||||
rel, containment_error,
|
||||
)
|
||||
continue
|
||||
if host_path.is_file():
|
||||
resolved_path = host_path.resolve()
|
||||
if resolved_path.is_file():
|
||||
container_path = f"/root/.hermes/{rel}"
|
||||
result.append({
|
||||
"host_path": str(host_path),
|
||||
"host_path": str(resolved_path),
|
||||
"container_path": container_path,
|
||||
})
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -165,12 +165,12 @@ def _validate_cron_script_path(script: Optional[str]) -> Optional[str]:
|
|||
)
|
||||
|
||||
# Validate containment after resolution
|
||||
from tools.path_security import validate_within_dir
|
||||
|
||||
scripts_dir = get_hermes_home() / "scripts"
|
||||
scripts_dir.mkdir(parents=True, exist_ok=True)
|
||||
resolved = (scripts_dir / raw).resolve()
|
||||
try:
|
||||
resolved.relative_to(scripts_dir.resolve())
|
||||
except ValueError:
|
||||
containment_error = validate_within_dir(scripts_dir / raw, scripts_dir)
|
||||
if containment_error:
|
||||
return (
|
||||
f"Script path escapes the scripts directory via traversal: {raw!r}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
"""Shared interrupt signaling for all tools.
|
||||
"""Per-thread interrupt signaling for all tools.
|
||||
|
||||
Provides a global threading.Event that any tool can check to determine
|
||||
if the user has requested an interrupt. The agent's interrupt() method
|
||||
sets this event, and tools poll it during long-running operations.
|
||||
Provides thread-scoped interrupt tracking so that interrupting one agent
|
||||
session does not kill tools running in other sessions. This is critical
|
||||
in the gateway where multiple agents run concurrently in the same process.
|
||||
|
||||
The agent stores its execution thread ID at the start of run_conversation()
|
||||
and passes it to set_interrupt()/clear_interrupt(). Tools call
|
||||
is_interrupted() which checks the CURRENT thread — no argument needed.
|
||||
|
||||
Usage in tools:
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -12,17 +16,61 @@ Usage in tools:
|
|||
|
||||
import threading
|
||||
|
||||
_interrupt_event = threading.Event()
|
||||
# Set of thread idents that have been interrupted.
|
||||
_interrupted_threads: set[int] = set()
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def set_interrupt(active: bool) -> None:
|
||||
"""Called by the agent to signal or clear the interrupt."""
|
||||
if active:
|
||||
_interrupt_event.set()
|
||||
else:
|
||||
_interrupt_event.clear()
|
||||
def set_interrupt(active: bool, thread_id: int | None = None) -> None:
|
||||
"""Set or clear interrupt for a specific thread.
|
||||
|
||||
Args:
|
||||
active: True to signal interrupt, False to clear it.
|
||||
thread_id: Target thread ident. When None, targets the
|
||||
current thread (backward compat for CLI/tests).
|
||||
"""
|
||||
tid = thread_id if thread_id is not None else threading.current_thread().ident
|
||||
with _lock:
|
||||
if active:
|
||||
_interrupted_threads.add(tid)
|
||||
else:
|
||||
_interrupted_threads.discard(tid)
|
||||
|
||||
|
||||
def is_interrupted() -> bool:
|
||||
"""Check if an interrupt has been requested. Safe to call from any thread."""
|
||||
return _interrupt_event.is_set()
|
||||
"""Check if an interrupt has been requested for the current thread.
|
||||
|
||||
Safe to call from any thread — each thread only sees its own
|
||||
interrupt state.
|
||||
"""
|
||||
tid = threading.current_thread().ident
|
||||
with _lock:
|
||||
return tid in _interrupted_threads
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward-compatible _interrupt_event proxy
|
||||
# ---------------------------------------------------------------------------
|
||||
# Some legacy call sites (code_execution_tool, process_registry, tests)
|
||||
# import _interrupt_event directly and call .is_set() / .set() / .clear().
|
||||
# This shim maps those calls to the per-thread functions above so existing
|
||||
# code keeps working while the underlying mechanism is thread-scoped.
|
||||
|
||||
class _ThreadAwareEventProxy:
|
||||
"""Drop-in proxy that maps threading.Event methods to per-thread state."""
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return is_interrupted()
|
||||
|
||||
def set(self) -> None: # noqa: A003
|
||||
set_interrupt(True)
|
||||
|
||||
def clear(self) -> None:
|
||||
set_interrupt(False)
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
"""Not truly supported — returns current state immediately."""
|
||||
return self.is_set()
|
||||
|
||||
|
||||
_interrupt_event = _ThreadAwareEventProxy()
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue