diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..a690443f7 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,13 @@ +# Git +.git +.gitignore +.gitmodules + +# Dependencies +node_modules + +# CI/CD +.github + +# Environment files +.env \ No newline at end of file diff --git a/.env.example b/.env.example index d273a6966..bcb5708d6 100644 --- a/.env.example +++ b/.env.example @@ -59,12 +59,25 @@ OPENCODE_ZEN_API_KEY= # OpenCode Go provides access to open models (GLM-5, Kimi K2.5, MiniMax M2.5) # $10/month subscription. Get your key at: https://opencode.ai/auth OPENCODE_GO_API_KEY= + +# ============================================================================= +# LLM PROVIDER (Hugging Face Inference Providers) +# ============================================================================= +# Hugging Face routes to 20+ open models via unified OpenAI-compatible endpoint. +# Free tier included ($0.10/month), no markup on provider rates. +# Get your token at: https://huggingface.co/settings/tokens +# Required permission: "Make calls to Inference Providers" +HF_TOKEN= # OPENCODE_GO_BASE_URL=https://opencode.ai/zen/go/v1 # Override default base URL # ============================================================================= # TOOL API KEYS # ============================================================================= +# Exa API Key - AI-native web search and contents +# Get at: https://exa.ai +EXA_API_KEY= + # Parallel API Key - AI-native web search and extract # Get at: https://parallel.ai PARALLEL_API_KEY= @@ -85,7 +98,7 @@ FAL_KEY= HONCHO_API_KEY= # ============================================================================= -# TERMINAL TOOL CONFIGURATION (mini-swe-agent backend) +# TERMINAL TOOL CONFIGURATION # ============================================================================= # Backend type: "local", "singularity", "docker", "modal", or "ssh" # Terminal backend is configured in ~/.hermes/config.yaml (terminal.backend). diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 000000000..11b98c3a9 --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,61 @@ +name: Docker Build and Publish + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: docker-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-push: + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build image + uses: docker/build-push-action@v6 + with: + context: . + file: Dockerfile + load: true + tags: nousresearch/hermes-agent:test + cache-from: type=gha + cache-to: type=gha,mode=max + + - name: Test image starts + run: | + docker run --rm \ + -v /tmp/hermes-test:/opt/data \ + --entrypoint /opt/hermes/docker/entrypoint.sh \ + nousresearch/hermes-agent:test --help + + - name: Log in to Docker Hub + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Push image + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: docker/build-push-action@v6 + with: + context: . + file: Dockerfile + push: true + tags: | + nousresearch/hermes-agent:latest + nousresearch/hermes-agent:${{ github.sha }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/AGENTS.md b/AGENTS.md index 19c6f2797..8045c3d21 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -210,6 +210,10 @@ registry.register( The registry handles schema collection, dispatch, availability checking, and error wrapping. All handlers MUST return a JSON string. +**Path references in tool schemas**: If the schema description mentions file paths (e.g. default output directories), use `display_hermes_home()` to make them profile-aware. The schema is generated at import time, which is after `_apply_profile_override()` sets `HERMES_HOME`. + +**State files**: If a tool stores persistent state (caches, logs, checkpoints), use `get_hermes_home()` for the base directory — never `Path.home() / ".hermes"`. This ensures each profile gets its own state. + **Agent-level tools** (todo, memory): intercepted by `run_agent.py` before `handle_function_call()`. See `todo_tool.py` for the pattern. --- @@ -358,8 +362,69 @@ in config.yaml (or `HERMES_BACKGROUND_NOTIFICATIONS` env var): --- +## Profiles: Multi-Instance Support + +Hermes supports **profiles** — multiple fully isolated instances, each with its own +`HERMES_HOME` directory (config, API keys, memory, sessions, skills, gateway, etc.). + +The core mechanism: `_apply_profile_override()` in `hermes_cli/main.py` sets +`HERMES_HOME` before any module imports. All 119+ references to `get_hermes_home()` +automatically scope to the active profile. + +### Rules for profile-safe code + +1. **Use `get_hermes_home()` for all HERMES_HOME paths.** Import from `hermes_constants`. + NEVER hardcode `~/.hermes` or `Path.home() / ".hermes"` in code that reads/writes state. + ```python + # GOOD + from hermes_constants import get_hermes_home + config_path = get_hermes_home() / "config.yaml" + + # BAD — breaks profiles + config_path = Path.home() / ".hermes" / "config.yaml" + ``` + +2. **Use `display_hermes_home()` for user-facing messages.** Import from `hermes_constants`. + This returns `~/.hermes` for default or `~/.hermes/profiles/` for profiles. + ```python + # GOOD + from hermes_constants import display_hermes_home + print(f"Config saved to {display_hermes_home()}/config.yaml") + + # BAD — shows wrong path for profiles + print("Config saved to ~/.hermes/config.yaml") + ``` + +3. **Module-level constants are fine** — they cache `get_hermes_home()` at import time, + which is AFTER `_apply_profile_override()` sets the env var. Just use `get_hermes_home()`, + not `Path.home() / ".hermes"`. + +4. **Tests that mock `Path.home()` must also set `HERMES_HOME`** — since code now uses + `get_hermes_home()` (reads env var), not `Path.home() / ".hermes"`: + ```python + with patch.object(Path, "home", return_value=tmp_path), \ + patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}): + ... + ``` + +5. **Gateway platform adapters should use token locks** — if the adapter connects with + a unique credential (bot token, API key), call `acquire_scoped_lock()` from + `gateway.status` in the `connect()`/`start()` method and `release_scoped_lock()` in + `disconnect()`/`stop()`. This prevents two profiles from using the same credential. + See `gateway/platforms/telegram.py` for the canonical pattern. + +6. **Profile operations are HOME-anchored, not HERMES_HOME-anchored** — `_get_profiles_root()` + returns `Path.home() / ".hermes" / "profiles"`, NOT `get_hermes_home() / "profiles"`. + This is intentional — it lets `hermes -p coder profile list` see all profiles regardless + of which one is active. + ## Known Pitfalls +### DO NOT hardcode `~/.hermes` paths +Use `get_hermes_home()` from `hermes_constants` for code paths. Use `display_hermes_home()` +for user-facing print/log messages. Hardcoding `~/.hermes` breaks profiles — each profile +has its own `HERMES_HOME` directory. This was the source of 5 bugs fixed in PR #3575. + ### DO NOT use `simple_term_menu` for interactive menus Rendering bugs in tmux/iTerm2 — ghosting on scroll. Use `curses` (stdlib) instead. See `hermes_cli/tools_config.py` for the pattern. @@ -375,6 +440,19 @@ Tool schema descriptions must not mention tools from other toolsets by name (e.g ### Tests must not write to `~/.hermes/` The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Never hardcode `~/.hermes/` paths in tests. +**Profile tests**: When testing profile features, also mock `Path.home()` so that +`_get_profiles_root()` and `_get_default_hermes_home()` resolve within the temp dir. +Use the pattern from `tests/hermes_cli/test_profiles.py`: +```python +@pytest.fixture +def profile_env(tmp_path, monkeypatch): + home = tmp_path / ".hermes" + home.mkdir() + monkeypatch.setattr(Path, "home", lambda: tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + return home +``` + --- ## Testing diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..61b725d39 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM debian:13.4 + +RUN apt-get update +RUN apt-get install -y nodejs npm python3 python3-pip ripgrep ffmpeg gcc python3-dev libffi-dev + +COPY . /opt/hermes +WORKDIR /opt/hermes + +RUN pip install -e ".[all]" --break-system-packages +RUN npm install +RUN npx playwright install --with-deps chromium +WORKDIR /opt/hermes/scripts/whatsapp-bridge +RUN npm install + +WORKDIR /opt/hermes +RUN chmod +x /opt/hermes/docker/entrypoint.sh + +ENV HERMES_HOME=/opt/data +VOLUME [ "/opt/data" ] +ENTRYPOINT [ "/opt/hermes/docker/entrypoint.sh" ] \ No newline at end of file diff --git a/RELEASE_v0.5.0.md b/RELEASE_v0.5.0.md new file mode 100644 index 000000000..1f8ce9866 --- /dev/null +++ b/RELEASE_v0.5.0.md @@ -0,0 +1,348 @@ +# Hermes Agent v0.5.0 (v2026.3.28) + +**Release Date:** March 28, 2026 + +> The hardening release — Hugging Face provider, /model command overhaul, Telegram Private Chat Topics, native Modal SDK, plugin lifecycle hooks, tool-use enforcement for GPT models, Nix flake, 50+ security and reliability fixes, and a comprehensive supply chain audit. + +--- + +## ✨ Highlights + +- **Nous Portal now supports 400+ models** — The Nous Research inference portal has expanded dramatically, giving Hermes Agent users access to over 400 models through a single provider endpoint + +- **Hugging Face as a first-class inference provider** — Full integration with HF Inference API including curated agentic model picker that maps to OpenRouter analogues, live `/models` endpoint probe, and setup wizard flow ([#3419](https://github.com/NousResearch/hermes-agent/pull/3419), [#3440](https://github.com/NousResearch/hermes-agent/pull/3440)) + +- **Telegram Private Chat Topics** — Project-based conversations with functional skill binding per topic, enabling isolated workflows within a single Telegram chat ([#3163](https://github.com/NousResearch/hermes-agent/pull/3163)) + +- **Native Modal SDK backend** — Replaced swe-rex dependency with native Modal SDK (`Sandbox.create.aio` + `exec.aio`), eliminating tunnels and simplifying the Modal terminal backend ([#3538](https://github.com/NousResearch/hermes-agent/pull/3538)) + +- **Plugin lifecycle hooks activated** — `pre_llm_call`, `post_llm_call`, `on_session_start`, and `on_session_end` hooks now fire in the agent loop and CLI/gateway, completing the plugin hook system ([#3542](https://github.com/NousResearch/hermes-agent/pull/3542)) + +- **Improved OpenAI Model Reliability** — Added `GPT_TOOL_USE_GUIDANCE` to prevent GPT models from describing intended actions instead of making tool calls, plus automatic stripping of stale budget warnings from conversation history that caused models to avoid tools across turns ([#3528](https://github.com/NousResearch/hermes-agent/pull/3528)) + +- **Nix flake** — Full uv2nix build, NixOS module with persistent container mode, auto-generated config keys from Python source, and suffix PATHs for agent-friendliness ([#20](https://github.com/NousResearch/hermes-agent/pull/20), [#3274](https://github.com/NousResearch/hermes-agent/pull/3274), [#3061](https://github.com/NousResearch/hermes-agent/pull/3061)) by @alt-glitch + +- **Supply chain hardening** — Removed compromised `litellm` dependency, pinned all dependency version ranges, regenerated `uv.lock` with hashes, added CI workflow scanning PRs for supply chain attack patterns, and bumped deps to fix CVEs ([#2796](https://github.com/NousResearch/hermes-agent/pull/2796), [#2810](https://github.com/NousResearch/hermes-agent/pull/2810), [#2812](https://github.com/NousResearch/hermes-agent/pull/2812), [#2816](https://github.com/NousResearch/hermes-agent/pull/2816), [#3073](https://github.com/NousResearch/hermes-agent/pull/3073)) + +- **Anthropic output limits fix** — Replaced hardcoded 16K `max_tokens` with per-model native output limits (128K for Opus 4.6, 64K for Sonnet 4.6), fixing "Response truncated" and thinking-budget exhaustion on direct Anthropic API ([#3426](https://github.com/NousResearch/hermes-agent/pull/3426), [#3444](https://github.com/NousResearch/hermes-agent/pull/3444)) + +--- + +## 🏗️ Core Agent & Architecture + +### New Provider: Hugging Face +- First-class Hugging Face Inference API integration with auth, setup wizard, and model picker ([#3419](https://github.com/NousResearch/hermes-agent/pull/3419)) +- Curated model list mapping OpenRouter agentic defaults to HF equivalents — providers with 8+ curated models skip live `/models` probe for speed ([#3440](https://github.com/NousResearch/hermes-agent/pull/3440)) +- Added glm-5-turbo to Z.AI provider model list ([#3095](https://github.com/NousResearch/hermes-agent/pull/3095)) + +### Provider & Model Improvements +- `/model` command overhaul — extracted shared `switch_model()` pipeline for CLI and gateway, custom endpoint support, provider-aware routing ([#2795](https://github.com/NousResearch/hermes-agent/pull/2795), [#2799](https://github.com/NousResearch/hermes-agent/pull/2799)) +- Removed `/model` slash command from CLI and gateway in favor of `hermes model` subcommand ([#3080](https://github.com/NousResearch/hermes-agent/pull/3080)) +- Preserve `custom` provider instead of silently remapping to `openrouter` ([#2792](https://github.com/NousResearch/hermes-agent/pull/2792)) +- Read root-level `provider` and `base_url` from config.yaml into model config ([#3112](https://github.com/NousResearch/hermes-agent/pull/3112)) +- Align Nous Portal model slugs with OpenRouter naming ([#3253](https://github.com/NousResearch/hermes-agent/pull/3253)) +- Fix Alibaba provider default endpoint and model list ([#3484](https://github.com/NousResearch/hermes-agent/pull/3484)) +- Allow MiniMax users to override `/v1` → `/anthropic` auto-correction ([#3553](https://github.com/NousResearch/hermes-agent/pull/3553)) +- Migrate OAuth token refresh to `platform.claude.com` with fallback ([#3246](https://github.com/NousResearch/hermes-agent/pull/3246)) + +### Agent Loop & Conversation +- **Improved OpenAI model reliability** — `GPT_TOOL_USE_GUIDANCE` prevents GPT models from describing actions instead of calling tools + automatic budget warning stripping from history ([#3528](https://github.com/NousResearch/hermes-agent/pull/3528)) +- **Surface lifecycle events** — All retry, fallback, and compression events now surface to the user as formatted messages ([#3153](https://github.com/NousResearch/hermes-agent/pull/3153)) +- **Anthropic output limits** — Per-model native output limits instead of hardcoded 16K `max_tokens` ([#3426](https://github.com/NousResearch/hermes-agent/pull/3426)) +- **Thinking-budget exhaustion detection** — Skip useless continuation retries when model uses all output tokens on reasoning ([#3444](https://github.com/NousResearch/hermes-agent/pull/3444)) +- Always prefer streaming for API calls to prevent hung subagents ([#3120](https://github.com/NousResearch/hermes-agent/pull/3120)) +- Restore safe non-streaming fallback after stream failures ([#3020](https://github.com/NousResearch/hermes-agent/pull/3020)) +- Give subagents independent iteration budgets ([#3004](https://github.com/NousResearch/hermes-agent/pull/3004)) +- Update `api_key` in `_try_activate_fallback` for subagent auth ([#3103](https://github.com/NousResearch/hermes-agent/pull/3103)) +- Graceful return on max retries instead of crashing thread ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Count compression restarts toward retry limit ([#3070](https://github.com/NousResearch/hermes-agent/pull/3070)) +- Include tool tokens in preflight estimate, guard context probe persistence ([#3164](https://github.com/NousResearch/hermes-agent/pull/3164)) +- Update context compressor limits after fallback activation ([#3305](https://github.com/NousResearch/hermes-agent/pull/3305)) +- Validate empty user messages to prevent Anthropic API 400 errors ([#3322](https://github.com/NousResearch/hermes-agent/pull/3322)) +- GLM reasoning-only and max-length handling ([#3010](https://github.com/NousResearch/hermes-agent/pull/3010)) +- Increase API timeout default from 900s to 1800s for slow-thinking models ([#3431](https://github.com/NousResearch/hermes-agent/pull/3431)) +- Send `max_tokens` for Claude/OpenRouter + retry SSE connection errors ([#3497](https://github.com/NousResearch/hermes-agent/pull/3497)) +- Prevent AsyncOpenAI/httpx cross-loop deadlock in gateway mode ([#2701](https://github.com/NousResearch/hermes-agent/pull/2701)) by @ctlst + +### Streaming & Reasoning +- **Persist reasoning across gateway session turns** with new schema v6 columns (`reasoning`, `reasoning_details`, `codex_reasoning_items`) ([#2974](https://github.com/NousResearch/hermes-agent/pull/2974)) +- Detect and kill stale SSE connections ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Fix stale stream detector race causing spurious `RemoteProtocolError` ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Skip duplicate callback for ``-extracted reasoning during streaming ([#3116](https://github.com/NousResearch/hermes-agent/pull/3116)) +- Preserve reasoning fields in `rewrite_transcript` ([#3311](https://github.com/NousResearch/hermes-agent/pull/3311)) +- Preserve Gemini thought signatures in streamed tool calls ([#2997](https://github.com/NousResearch/hermes-agent/pull/2997)) +- Ensure first delta is fired during reasoning updates ([untagged commit](https://github.com/NousResearch/hermes-agent)) + +### Session & Memory +- **Session search recent sessions mode** — Omit query to browse recent sessions with titles, previews, and timestamps ([#2533](https://github.com/NousResearch/hermes-agent/pull/2533)) +- **Session config surfacing** on `/new`, `/reset`, and auto-reset ([#3321](https://github.com/NousResearch/hermes-agent/pull/3321)) +- **Third-party session isolation** — `--source` flag for isolating sessions by origin ([#3255](https://github.com/NousResearch/hermes-agent/pull/3255)) +- Add `/resume` CLI handler, session log truncation guard, `reopen_session` API ([#3315](https://github.com/NousResearch/hermes-agent/pull/3315)) +- Clear compressor summary and turn counter on `/clear` and `/new` ([#3102](https://github.com/NousResearch/hermes-agent/pull/3102)) +- Surface silent SessionDB failures that cause session data loss ([#2999](https://github.com/NousResearch/hermes-agent/pull/2999)) +- Session search fallback preview on summarization failure ([#3478](https://github.com/NousResearch/hermes-agent/pull/3478)) +- Prevent stale memory overwrites by flush agent ([#2687](https://github.com/NousResearch/hermes-agent/pull/2687)) + +### Context Compression +- Replace dead `summary_target_tokens` with ratio-based scaling ([#2554](https://github.com/NousResearch/hermes-agent/pull/2554)) +- Expose `compression.target_ratio`, `protect_last_n`, and `threshold` in `DEFAULT_CONFIG` ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Restore sane defaults and cap summary at 12K tokens ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Preserve transcript on `/compress` and hygiene compression ([#3556](https://github.com/NousResearch/hermes-agent/pull/3556)) +- Update context pressure warnings and token estimates after compaction ([untagged commit](https://github.com/NousResearch/hermes-agent)) + +### Architecture & Dependencies +- **Remove mini-swe-agent dependency** — Inline Docker and Modal backends directly ([#2804](https://github.com/NousResearch/hermes-agent/pull/2804)) +- **Replace swe-rex with native Modal SDK** for Modal backend ([#3538](https://github.com/NousResearch/hermes-agent/pull/3538)) +- **Plugin lifecycle hooks** — `pre_llm_call`, `post_llm_call`, `on_session_start`, `on_session_end` now fire in the agent loop ([#3542](https://github.com/NousResearch/hermes-agent/pull/3542)) +- Fix plugin toolsets invisible in `hermes tools` and standalone processes ([#3457](https://github.com/NousResearch/hermes-agent/pull/3457)) +- Consolidate `get_hermes_home()` and `parse_reasoning_effort()` ([#3062](https://github.com/NousResearch/hermes-agent/pull/3062)) +- Remove unused Hermes-native PKCE OAuth flow ([#3107](https://github.com/NousResearch/hermes-agent/pull/3107)) +- Remove ~100 unused imports across 55 files ([#3016](https://github.com/NousResearch/hermes-agent/pull/3016)) +- Fix 154 f-strings, simplify getattr/URL patterns, remove dead code ([#3119](https://github.com/NousResearch/hermes-agent/pull/3119)) + +--- + +## 📱 Messaging Platforms (Gateway) + +### Telegram +- **Private Chat Topics** — Project-based conversations with functional skill binding per topic, enabling isolated workflows within a single Telegram chat ([#3163](https://github.com/NousResearch/hermes-agent/pull/3163)) +- **Auto-discover fallback IPs via DNS-over-HTTPS** when `api.telegram.org` is unreachable ([#3376](https://github.com/NousResearch/hermes-agent/pull/3376)) +- **Configurable reply threading mode** ([#2907](https://github.com/NousResearch/hermes-agent/pull/2907)) +- Fall back to no `thread_id` on "Message thread not found" BadRequest ([#3390](https://github.com/NousResearch/hermes-agent/pull/3390)) +- Self-reschedule reconnect when `start_polling` fails after 502 ([#3268](https://github.com/NousResearch/hermes-agent/pull/3268)) + +### Discord +- Stop phantom typing indicator after agent turn completes ([#3003](https://github.com/NousResearch/hermes-agent/pull/3003)) + +### Slack +- Send tool call progress messages to correct Slack thread ([#3063](https://github.com/NousResearch/hermes-agent/pull/3063)) +- Scope progress thread fallback to Slack only ([#3488](https://github.com/NousResearch/hermes-agent/pull/3488)) + +### WhatsApp +- Download documents, audio, and video media from messages ([#2978](https://github.com/NousResearch/hermes-agent/pull/2978)) + +### Matrix +- Add missing Matrix entry in `PLATFORMS` dict ([#3473](https://github.com/NousResearch/hermes-agent/pull/3473)) +- Harden e2ee access-token handling ([#3562](https://github.com/NousResearch/hermes-agent/pull/3562)) +- Add backoff for `SyncError` in sync loop ([#3280](https://github.com/NousResearch/hermes-agent/pull/3280)) + +### Signal +- Track SSE keepalive comments as connection activity ([#3316](https://github.com/NousResearch/hermes-agent/pull/3316)) + +### Email +- Prevent unbounded growth of `_seen_uids` in EmailAdapter ([#3490](https://github.com/NousResearch/hermes-agent/pull/3490)) + +### Gateway Core +- **Config-gated `/verbose` command** for messaging platforms — toggle tool output verbosity from chat ([#3262](https://github.com/NousResearch/hermes-agent/pull/3262)) +- **Background review notifications** delivered to user chat ([#3293](https://github.com/NousResearch/hermes-agent/pull/3293)) +- **Retry transient send failures** and notify user on exhaustion ([#3288](https://github.com/NousResearch/hermes-agent/pull/3288)) +- Recover from hung agents — `/stop` hard-kills session lock ([#3104](https://github.com/NousResearch/hermes-agent/pull/3104)) +- Thread-safe `SessionStore` — protect `_entries` with `threading.Lock` ([#3052](https://github.com/NousResearch/hermes-agent/pull/3052)) +- Fix gateway token double-counting with cached agents — use absolute set instead of increment ([#3306](https://github.com/NousResearch/hermes-agent/pull/3306), [#3317](https://github.com/NousResearch/hermes-agent/pull/3317)) +- Fingerprint full auth token in agent cache signature ([#3247](https://github.com/NousResearch/hermes-agent/pull/3247)) +- Silence background agent terminal output ([#3297](https://github.com/NousResearch/hermes-agent/pull/3297)) +- Include per-platform `ALLOW_ALL` and `SIGNAL_GROUP` in startup allowlist check ([#3313](https://github.com/NousResearch/hermes-agent/pull/3313)) +- Include user-local bin paths in systemd unit PATH ([#3527](https://github.com/NousResearch/hermes-agent/pull/3527)) +- Track background task references in `GatewayRunner` ([#3254](https://github.com/NousResearch/hermes-agent/pull/3254)) +- Add request timeouts to HA, Email, Mattermost, SMS adapters ([#3258](https://github.com/NousResearch/hermes-agent/pull/3258)) +- Add media download retry to Mattermost, Slack, and base cache ([#3323](https://github.com/NousResearch/hermes-agent/pull/3323)) +- Detect virtualenv path instead of hardcoding `venv/` ([#2797](https://github.com/NousResearch/hermes-agent/pull/2797)) +- Use `TERMINAL_CWD` for context file discovery, not process cwd ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Stop loading hermes repo AGENTS.md into gateway sessions (~10k wasted tokens) ([#2891](https://github.com/NousResearch/hermes-agent/pull/2891)) + +--- + +## 🖥️ CLI & User Experience + +### Interactive CLI +- **Configurable busy input mode** + fix `/queue` always working ([#3298](https://github.com/NousResearch/hermes-agent/pull/3298)) +- **Preserve user input on multiline paste** ([#3065](https://github.com/NousResearch/hermes-agent/pull/3065)) +- **Tool generation callback** — streaming "preparing terminal…" updates during tool argument generation ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Show tool progress for substantive tools, not just "preparing" ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Buffer reasoning preview chunks and fix duplicate display ([#3013](https://github.com/NousResearch/hermes-agent/pull/3013)) +- Prevent reasoning box from rendering 3x during tool-calling loops ([#3405](https://github.com/NousResearch/hermes-agent/pull/3405)) +- Eliminate "Event loop is closed" / "Press ENTER to continue" during idle — three-layer fix with `neuter_async_httpx_del()`, custom exception handler, and stale client cleanup ([#3398](https://github.com/NousResearch/hermes-agent/pull/3398)) +- Fix status bar shows 26K instead of 260K for token counts with trailing zeros ([#3024](https://github.com/NousResearch/hermes-agent/pull/3024)) +- Fix status bar duplicates and degrades during long sessions ([#3291](https://github.com/NousResearch/hermes-agent/pull/3291)) +- Refresh TUI before background task output to prevent status bar overlap ([#3048](https://github.com/NousResearch/hermes-agent/pull/3048)) +- Suppress KawaiiSpinner animation under `patch_stdout` ([#2994](https://github.com/NousResearch/hermes-agent/pull/2994)) +- Skip KawaiiSpinner when TUI handles tool progress ([#2973](https://github.com/NousResearch/hermes-agent/pull/2973)) +- Guard `isatty()` against closed streams via `_is_tty` property ([#3056](https://github.com/NousResearch/hermes-agent/pull/3056)) +- Ensure single closure of streaming boxes during tool generation ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Cap context pressure percentage at 100% in display ([#3480](https://github.com/NousResearch/hermes-agent/pull/3480)) +- Clean up HTML error messages in CLI display ([#3069](https://github.com/NousResearch/hermes-agent/pull/3069)) +- Show HTTP status code and 400 body in API error output ([#3096](https://github.com/NousResearch/hermes-agent/pull/3096)) +- Extract useful info from HTML error pages, dump debug on max retries ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Prevent TypeError on startup when `base_url` is None ([#3068](https://github.com/NousResearch/hermes-agent/pull/3068)) +- Prevent update crash in non-TTY environments ([#3094](https://github.com/NousResearch/hermes-agent/pull/3094)) +- Handle EOFError in sessions delete/prune confirmation prompts ([#3101](https://github.com/NousResearch/hermes-agent/pull/3101)) +- Catch KeyboardInterrupt during `flush_memories` on exit and in exit cleanup handlers ([#3025](https://github.com/NousResearch/hermes-agent/pull/3025), [#3257](https://github.com/NousResearch/hermes-agent/pull/3257)) +- Guard `.strip()` against None values from YAML config ([#3552](https://github.com/NousResearch/hermes-agent/pull/3552)) +- Guard `config.get()` against YAML null values to prevent AttributeError ([#3377](https://github.com/NousResearch/hermes-agent/pull/3377)) +- Store asyncio task references to prevent GC mid-execution ([#3267](https://github.com/NousResearch/hermes-agent/pull/3267)) + +### Setup & Configuration +- Use explicit key mapping for returning-user menu dispatch instead of positional index ([#3083](https://github.com/NousResearch/hermes-agent/pull/3083)) +- Use `sys.executable` for pip in update commands to fix PEP 668 ([#3099](https://github.com/NousResearch/hermes-agent/pull/3099)) +- Harden `hermes update` against diverged history, non-main branches, and gateway edge cases ([#3492](https://github.com/NousResearch/hermes-agent/pull/3492)) +- OpenClaw migration overwrites defaults and setup wizard skips imported sections — fixed ([#3282](https://github.com/NousResearch/hermes-agent/pull/3282)) +- Stop recursive AGENTS.md walk, load top-level only ([#3110](https://github.com/NousResearch/hermes-agent/pull/3110)) +- Add macOS Homebrew paths to browser and terminal PATH resolution ([#2713](https://github.com/NousResearch/hermes-agent/pull/2713)) +- YAML boolean handling for `tool_progress` config ([#3300](https://github.com/NousResearch/hermes-agent/pull/3300)) +- Reset default SOUL.md to baseline identity text ([#3159](https://github.com/NousResearch/hermes-agent/pull/3159)) +- Reject relative cwd paths for container terminal backends ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Add explicit `hermes-api-server` toolset for API server platform ([#3304](https://github.com/NousResearch/hermes-agent/pull/3304)) +- Reorder setup wizard providers — OpenRouter first ([untagged commit](https://github.com/NousResearch/hermes-agent)) + +--- + +## 🔧 Tool System + +### API Server +- **Idempotency-Key support**, body size limit, and OpenAI error envelope ([#2903](https://github.com/NousResearch/hermes-agent/pull/2903)) +- Allow Idempotency-Key in CORS headers ([#3530](https://github.com/NousResearch/hermes-agent/pull/3530)) +- Cancel orphaned agent + true interrupt on SSE disconnect ([#3427](https://github.com/NousResearch/hermes-agent/pull/3427)) +- Fix streaming breaks when agent makes tool calls ([#2985](https://github.com/NousResearch/hermes-agent/pull/2985)) + +### Terminal & File Operations +- Handle addition-only hunks in V4A patch parser ([#3325](https://github.com/NousResearch/hermes-agent/pull/3325)) +- Exponential backoff for persistent shell polling ([#2996](https://github.com/NousResearch/hermes-agent/pull/2996)) +- Add timeout to subprocess calls in `context_references` ([#3469](https://github.com/NousResearch/hermes-agent/pull/3469)) + +### Browser & Vision +- Handle 402 insufficient credits error in vision tool ([#2802](https://github.com/NousResearch/hermes-agent/pull/2802)) +- Fix `browser_vision` ignores `auxiliary.vision.timeout` config ([#2901](https://github.com/NousResearch/hermes-agent/pull/2901)) +- Make browser command timeout configurable via config.yaml ([#2801](https://github.com/NousResearch/hermes-agent/pull/2801)) + +### MCP +- MCP toolset resolution for runtime and config ([#3252](https://github.com/NousResearch/hermes-agent/pull/3252)) +- Add MCP tool name collision protection ([#3077](https://github.com/NousResearch/hermes-agent/pull/3077)) + +### Auxiliary LLM +- Guard aux LLM calls against None content + reasoning fallback + retry ([#3449](https://github.com/NousResearch/hermes-agent/pull/3449)) +- Catch ImportError from `build_anthropic_client` in vision auto-detection ([#3312](https://github.com/NousResearch/hermes-agent/pull/3312)) + +### Other Tools +- Add request timeouts to `send_message_tool` HTTP calls ([#3162](https://github.com/NousResearch/hermes-agent/pull/3162)) by @memosr +- Auto-repair `jobs.json` with invalid control characters ([#3537](https://github.com/NousResearch/hermes-agent/pull/3537)) +- Enable fine-grained tool streaming for Claude/OpenRouter ([#3497](https://github.com/NousResearch/hermes-agent/pull/3497)) + +--- + +## 🧩 Skills Ecosystem + +### Skills System +- **Env var passthrough** for skills and user config — skills can declare environment variables to pass through ([#2807](https://github.com/NousResearch/hermes-agent/pull/2807)) +- Cache skills prompt with shared `skill_utils` module for faster TTFT ([#3421](https://github.com/NousResearch/hermes-agent/pull/3421)) +- Avoid redundant file re-read for skill conditions ([#2992](https://github.com/NousResearch/hermes-agent/pull/2992)) +- Use Git Trees API to prevent silent subdirectory loss during install ([#2995](https://github.com/NousResearch/hermes-agent/pull/2995)) +- Fix skills-sh install for deeply nested repo structures ([#2980](https://github.com/NousResearch/hermes-agent/pull/2980)) +- Handle null metadata in skill frontmatter ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Preserve trust for skills-sh identifiers + reduce resolution churn ([#3251](https://github.com/NousResearch/hermes-agent/pull/3251)) +- Agent-created skills were incorrectly treated as untrusted community content — fixed ([untagged commit](https://github.com/NousResearch/hermes-agent)) + +### New Skills +- **G0DM0D3 godmode jailbreaking skill** + docs ([#3157](https://github.com/NousResearch/hermes-agent/pull/3157)) +- **Docker management skill** added to optional-skills ([#3060](https://github.com/NousResearch/hermes-agent/pull/3060)) +- **OpenClaw migration v2** — 17 new modules, terminal recap for migrating from OpenClaw to Hermes ([#2906](https://github.com/NousResearch/hermes-agent/pull/2906)) + +--- + +## 🔒 Security & Reliability + +### Security Hardening +- **SSRF protection** added to `browser_navigate` ([#3058](https://github.com/NousResearch/hermes-agent/pull/3058)) +- **SSRF protection** added to `vision_tools` and `web_tools` (hardened) ([#2679](https://github.com/NousResearch/hermes-agent/pull/2679)) +- **Restrict subagent toolsets** to parent's enabled set ([#3269](https://github.com/NousResearch/hermes-agent/pull/3269)) +- **Prevent zip-slip path traversal** in self-update ([#3250](https://github.com/NousResearch/hermes-agent/pull/3250)) +- **Prevent shell injection** in `_expand_path` via `~user` path suffix ([#2685](https://github.com/NousResearch/hermes-agent/pull/2685)) +- **Normalize input** before dangerous command detection ([#3260](https://github.com/NousResearch/hermes-agent/pull/3260)) +- Make tirith block verdicts approvable instead of hard-blocking ([#3428](https://github.com/NousResearch/hermes-agent/pull/3428)) +- Remove compromised `litellm`/`typer`/`platformdirs` from deps ([#2796](https://github.com/NousResearch/hermes-agent/pull/2796)) +- Pin all dependency version ranges ([#2810](https://github.com/NousResearch/hermes-agent/pull/2810)) +- Regenerate `uv.lock` with hashes, use lockfile in setup ([#2812](https://github.com/NousResearch/hermes-agent/pull/2812)) +- Bump dependencies to fix CVEs + regenerate `uv.lock` ([#3073](https://github.com/NousResearch/hermes-agent/pull/3073)) +- Supply chain audit CI workflow for PR scanning ([#2816](https://github.com/NousResearch/hermes-agent/pull/2816)) + +### Reliability +- **SQLite WAL write-lock contention** causing 15-20s TUI freeze — fixed ([#3385](https://github.com/NousResearch/hermes-agent/pull/3385)) +- **SQLite concurrency hardening** + session transcript integrity ([#3249](https://github.com/NousResearch/hermes-agent/pull/3249)) +- Prevent recurring cron job re-fire on gateway crash/restart loop ([#3396](https://github.com/NousResearch/hermes-agent/pull/3396)) +- Mark cron session as ended after job completes ([#2998](https://github.com/NousResearch/hermes-agent/pull/2998)) + +--- + +## ⚡ Performance + +- **TTFT startup optimizations** — salvaged easy-win startup improvements ([#3395](https://github.com/NousResearch/hermes-agent/pull/3395)) +- Cache skills prompt with shared `skill_utils` module ([#3421](https://github.com/NousResearch/hermes-agent/pull/3421)) +- Avoid redundant file re-read for skill conditions in prompt builder ([#2992](https://github.com/NousResearch/hermes-agent/pull/2992)) + +--- + +## 🐛 Notable Bug Fixes + +- Fix gateway token double-counting with cached agents ([#3306](https://github.com/NousResearch/hermes-agent/pull/3306), [#3317](https://github.com/NousResearch/hermes-agent/pull/3317)) +- Fix "Event loop is closed" / "Press ENTER to continue" during idle sessions ([#3398](https://github.com/NousResearch/hermes-agent/pull/3398)) +- Fix reasoning box rendering 3x during tool-calling loops ([#3405](https://github.com/NousResearch/hermes-agent/pull/3405)) +- Fix status bar shows 26K instead of 260K for token counts ([#3024](https://github.com/NousResearch/hermes-agent/pull/3024)) +- Fix `/queue` always working regardless of config ([#3298](https://github.com/NousResearch/hermes-agent/pull/3298)) +- Fix phantom Discord typing indicator after agent turn ([#3003](https://github.com/NousResearch/hermes-agent/pull/3003)) +- Fix Slack progress messages appearing in wrong thread ([#3063](https://github.com/NousResearch/hermes-agent/pull/3063)) +- Fix WhatsApp media downloads (documents, audio, video) ([#2978](https://github.com/NousResearch/hermes-agent/pull/2978)) +- Fix Telegram "Message thread not found" killing progress messages ([#3390](https://github.com/NousResearch/hermes-agent/pull/3390)) +- Fix OpenClaw migration overwriting defaults ([#3282](https://github.com/NousResearch/hermes-agent/pull/3282)) +- Fix returning-user setup menu dispatching wrong section ([#3083](https://github.com/NousResearch/hermes-agent/pull/3083)) +- Fix `hermes update` PEP 668 "externally-managed-environment" error ([#3099](https://github.com/NousResearch/hermes-agent/pull/3099)) +- Fix subagents hitting `max_iterations` prematurely via shared budget ([#3004](https://github.com/NousResearch/hermes-agent/pull/3004)) +- Fix YAML boolean handling for `tool_progress` config ([#3300](https://github.com/NousResearch/hermes-agent/pull/3300)) +- Fix `config.get()` crashes on YAML null values ([#3377](https://github.com/NousResearch/hermes-agent/pull/3377)) +- Fix `.strip()` crash on None values from YAML config ([#3552](https://github.com/NousResearch/hermes-agent/pull/3552)) +- Fix hung agents on gateway — `/stop` now hard-kills session lock ([#3104](https://github.com/NousResearch/hermes-agent/pull/3104)) +- Fix `_custom` provider silently remapped to `openrouter` ([#2792](https://github.com/NousResearch/hermes-agent/pull/2792)) +- Fix Matrix missing from `PLATFORMS` dict ([#3473](https://github.com/NousResearch/hermes-agent/pull/3473)) +- Fix Email adapter unbounded `_seen_uids` growth ([#3490](https://github.com/NousResearch/hermes-agent/pull/3490)) + +--- + +## 🧪 Testing + +- Pin `agent-client-protocol` < 0.9 to handle breaking upstream release ([#3320](https://github.com/NousResearch/hermes-agent/pull/3320)) +- Catch anthropic ImportError in vision auto-detection tests ([#3312](https://github.com/NousResearch/hermes-agent/pull/3312)) +- Update retry-exhaust test for new graceful return behavior ([#3320](https://github.com/NousResearch/hermes-agent/pull/3320)) +- Add regression tests for null metadata frontmatter ([untagged commit](https://github.com/NousResearch/hermes-agent)) + +--- + +## 📚 Documentation + +- Update all docs for `/model` command overhaul and custom provider support ([#2800](https://github.com/NousResearch/hermes-agent/pull/2800)) +- Fix stale and incorrect documentation across 18 files ([#2805](https://github.com/NousResearch/hermes-agent/pull/2805)) +- Document 9 previously undocumented features ([#2814](https://github.com/NousResearch/hermes-agent/pull/2814)) +- Add missing skills, CLI commands, and messaging env vars to docs ([#2809](https://github.com/NousResearch/hermes-agent/pull/2809)) +- Fix api-server response storage documentation — SQLite, not in-memory ([#2819](https://github.com/NousResearch/hermes-agent/pull/2819)) +- Quote pip install extras to fix zsh glob errors ([#2815](https://github.com/NousResearch/hermes-agent/pull/2815)) +- Unify hooks documentation — add plugin hooks to hooks page, add `session:end` event ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Clarify two-mode behavior in `session_search` schema description ([untagged commit](https://github.com/NousResearch/hermes-agent)) +- Fix Discord Public Bot setting for Discord-provided invite link ([#3519](https://github.com/NousResearch/hermes-agent/pull/3519)) by @mehmoodosman +- Revise v0.4.0 changelog — fix feature attribution, reorder sections ([untagged commit](https://github.com/NousResearch/hermes-agent)) + +--- + +## 👥 Contributors + +### Core +- **@teknium1** — 157 PRs covering the full scope of this release + +### Community Contributors +- **@alt-glitch** (Siddharth Balyan) — 2 PRs: Nix flake with uv2nix build, NixOS module, and persistent container mode ([#20](https://github.com/NousResearch/hermes-agent/pull/20)); auto-generated config keys and suffix PATHs for Nix builds ([#3061](https://github.com/NousResearch/hermes-agent/pull/3061), [#3274](https://github.com/NousResearch/hermes-agent/pull/3274)) +- **@ctlst** — 1 PR: Prevent AsyncOpenAI/httpx cross-loop deadlock in gateway mode ([#2701](https://github.com/NousResearch/hermes-agent/pull/2701)) +- **@memosr** (memosr.eth) — 1 PR: Add request timeouts to `send_message_tool` HTTP calls ([#3162](https://github.com/NousResearch/hermes-agent/pull/3162)) +- **@mehmoodosman** (Osman Mehmood) — 1 PR: Fix Discord docs for Public Bot setting ([#3519](https://github.com/NousResearch/hermes-agent/pull/3519)) + +### All Contributors +@alt-glitch, @ctlst, @mehmoodosman, @memosr, @teknium1 + +--- + +**Full Changelog**: [v2026.3.23...v2026.3.28](https://github.com/NousResearch/hermes-agent/compare/v2026.3.23...v2026.3.28) diff --git a/acp_adapter/entry.py b/acp_adapter/entry.py index fe13ce703..02e44c15e 100644 --- a/acp_adapter/entry.py +++ b/acp_adapter/entry.py @@ -74,7 +74,7 @@ def main() -> None: agent = HermesACPAgent() try: - asyncio.run(acp.run_agent(agent)) + asyncio.run(acp.run_agent(agent, use_unstable_protocol=True)) except KeyboardInterrupt: logger.info("Shutting down (KeyboardInterrupt)") except Exception: diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 64c1e5185..a5780fb69 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -25,6 +25,9 @@ from acp.schema import ( NewSessionResponse, PromptResponse, ResumeSessionResponse, + SetSessionConfigOptionResponse, + SetSessionModelResponse, + SetSessionModeResponse, ResourceContentBlock, SessionCapabilities, SessionForkCapabilities, @@ -94,11 +97,14 @@ class HermesACPAgent(acp.Agent): async def initialize( self, - protocol_version: int, + protocol_version: int | None = None, client_capabilities: ClientCapabilities | None = None, client_info: Implementation | None = None, **kwargs: Any, ) -> InitializeResponse: + resolved_protocol_version = ( + protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION + ) provider = detect_provider() auth_methods = None if provider: @@ -111,7 +117,11 @@ class HermesACPAgent(acp.Agent): ] client_name = client_info.name if client_info else "unknown" - logger.info("Initialize from %s (protocol v%s)", client_name, protocol_version) + logger.info( + "Initialize from %s (protocol v%s)", + client_name, + resolved_protocol_version, + ) return InitializeResponse( protocol_version=acp.PROTOCOL_VERSION, @@ -471,7 +481,7 @@ class HermesACPAgent(acp.Agent): async def set_session_model( self, model_id: str, session_id: str, **kwargs: Any - ): + ) -> SetSessionModelResponse | None: """Switch the model for a session (called by ACP protocol).""" state = self.session_manager.get_session(session_id) if state: @@ -489,4 +499,37 @@ class HermesACPAgent(acp.Agent): ) self.session_manager.save_session(session_id) logger.info("Session %s: model switched to %s", session_id, model_id) + return SetSessionModelResponse() + logger.warning("Session %s: model switch requested for missing session", session_id) return None + + async def set_session_mode( + self, mode_id: str, session_id: str, **kwargs: Any + ) -> SetSessionModeResponse | None: + """Persist the editor-requested mode so ACP clients do not fail on mode switches.""" + state = self.session_manager.get_session(session_id) + if state is None: + logger.warning("Session %s: mode switch requested for missing session", session_id) + return None + setattr(state, "mode", mode_id) + self.session_manager.save_session(session_id) + logger.info("Session %s: mode switched to %s", session_id, mode_id) + return SetSessionModeResponse() + + async def set_config_option( + self, config_id: str, session_id: str, value: str, **kwargs: Any + ) -> SetSessionConfigOptionResponse | None: + """Accept ACP config option updates even when Hermes has no typed ACP config surface yet.""" + state = self.session_manager.get_session(session_id) + if state is None: + logger.warning("Session %s: config update requested for missing session", session_id) + return None + + options = getattr(state, "config_options", None) + if not isinstance(options, dict): + options = {} + options[str(config_id)] = value + setattr(state, "config_options", options) + self.session_manager.save_session(session_id) + logger.info("Session %s: config option %s updated", session_id, config_id) + return SetSessionConfigOptionResponse(config_options=[]) diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index 4c41c823c..a2a052d0a 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -35,6 +35,54 @@ ADAPTIVE_EFFORT_MAP = { "minimal": "low", } +# ── Max output token limits per Anthropic model ─────────────────────── +# Source: Anthropic docs + Cline model catalog. Anthropic's API requires +# max_tokens as a mandatory field. Previously we hardcoded 16384, which +# starves thinking-enabled models (thinking tokens count toward the limit). +_ANTHROPIC_OUTPUT_LIMITS = { + # Claude 4.6 + "claude-opus-4-6": 128_000, + "claude-sonnet-4-6": 64_000, + # Claude 4.5 + "claude-opus-4-5": 64_000, + "claude-sonnet-4-5": 64_000, + "claude-haiku-4-5": 64_000, + # Claude 4 + "claude-opus-4": 32_000, + "claude-sonnet-4": 64_000, + # Claude 3.7 + "claude-3-7-sonnet": 128_000, + # Claude 3.5 + "claude-3-5-sonnet": 8_192, + "claude-3-5-haiku": 8_192, + # Claude 3 + "claude-3-opus": 4_096, + "claude-3-sonnet": 4_096, + "claude-3-haiku": 4_096, +} + +# For any model not in the table, assume the highest current limit. +# Future Anthropic models are unlikely to have *less* output capacity. +_ANTHROPIC_DEFAULT_OUTPUT_LIMIT = 128_000 + + +def _get_anthropic_max_output(model: str) -> int: + """Look up the max output token limit for an Anthropic model. + + Uses substring matching against _ANTHROPIC_OUTPUT_LIMITS so date-stamped + model IDs (claude-sonnet-4-5-20250929) and variant suffixes (:1m, :fast) + resolve correctly. Longest-prefix match wins to avoid e.g. "claude-3-5" + matching before "claude-3-5-sonnet". + """ + m = model.lower() + best_key = "" + best_val = _ANTHROPIC_DEFAULT_OUTPUT_LIMIT + for key, val in _ANTHROPIC_OUTPUT_LIMITS.items(): + if key in m and len(key) > len(best_key): + best_key = key + best_val = val + return best_val + def _supports_adaptive_thinking(model: str) -> bool: """Return True for Claude 4.6 models that support adaptive thinking.""" @@ -59,6 +107,7 @@ _OAUTH_ONLY_BETAS = [ # The version must stay reasonably current — Anthropic rejects OAuth requests # when the spoofed user-agent version is too far behind the actual release. _CLAUDE_CODE_VERSION_FALLBACK = "2.1.74" +_claude_code_version_cache: Optional[str] = None def _detect_claude_code_version() -> str: @@ -86,11 +135,18 @@ def _detect_claude_code_version() -> str: return _CLAUDE_CODE_VERSION_FALLBACK -_CLAUDE_CODE_VERSION = _detect_claude_code_version() _CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude." _MCP_TOOL_PREFIX = "mcp_" +def _get_claude_code_version() -> str: + """Lazily detect the installed Claude Code version when OAuth headers need it.""" + global _claude_code_version_cache + if _claude_code_version_cache is None: + _claude_code_version_cache = _detect_claude_code_version() + return _claude_code_version_cache + + def _is_oauth_token(key: str) -> bool: """Check if the key is an OAuth/setup token (not a regular Console API key). @@ -132,7 +188,7 @@ def build_anthropic_client(api_key: str, base_url: str = None): kwargs["auth_token"] = api_key kwargs["default_headers"] = { "anthropic-beta": ",".join(all_betas), - "user-agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)", + "user-agent": f"claude-cli/{_get_claude_code_version()} (external, cli)", "x-app": "cli", } else: @@ -241,7 +297,7 @@ def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]: headers = { "Content-Type": "application/json", - "User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)", + "User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)", } for endpoint in token_endpoints: @@ -706,14 +762,21 @@ def convert_messages_to_anthropic( result.append({"role": "user", "content": [tool_result]}) continue - # Regular user message + # Regular user message — validate non-empty content (Anthropic rejects empty) if isinstance(content, list): converted_blocks = _convert_content_to_anthropic(content) - result.append({ - "role": "user", - "content": converted_blocks or [{"type": "text", "text": ""}], - }) + # Check if all text blocks are empty + if not converted_blocks or all( + b.get("text", "").strip() == "" + for b in converted_blocks + if isinstance(b, dict) and b.get("type") == "text" + ): + converted_blocks = [{"type": "text", "text": "(empty message)"}] + result.append({"role": "user", "content": converted_blocks}) else: + # Validate string content is non-empty + if not content or (isinstance(content, str) and not content.strip()): + content = "(empty message)" result.append({"role": "user", "content": content}) # Strip orphaned tool_use blocks (no matching tool_result follows) @@ -803,9 +866,15 @@ def build_anthropic_kwargs( tool_choice: Optional[str] = None, is_oauth: bool = False, preserve_dots: bool = False, + context_length: Optional[int] = None, ) -> Dict[str, Any]: """Build kwargs for anthropic.messages.create(). + When *max_tokens* is None, the model's native output limit is used + (e.g. 128K for Opus 4.6, 64K for Sonnet 4.6). If *context_length* + is provided, the effective limit is clamped so it doesn't exceed + the context window. + When *is_oauth* is True, applies Claude Code compatibility transforms: system prompt prefix, tool name prefixing, and prompt sanitization. @@ -816,7 +885,12 @@ def build_anthropic_kwargs( anthropic_tools = convert_tools_to_anthropic(tools) if tools else [] model = normalize_model_name(model, preserve_dots=preserve_dots) - effective_max_tokens = max_tokens or 16384 + effective_max_tokens = max_tokens or _get_anthropic_max_output(model) + + # Clamp to context window if the user set a lower context_length + # (e.g. custom endpoint with limited capacity). + if context_length and effective_max_tokens > context_length: + effective_max_tokens = max(context_length - 1, 1) # ── OAuth: Claude Code identity ────────────────────────────────── if is_oauth: diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 6e01664ac..1528b3d7f 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -47,8 +47,7 @@ from typing import Any, Dict, List, Optional, Tuple from openai import OpenAI -from hermes_cli.config import get_hermes_home -from hermes_constants import OPENROUTER_BASE_URL +from hermes_constants import OPENROUTER_BASE_URL, get_hermes_home logger = logging.getLogger(__name__) @@ -627,8 +626,6 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]: custom_key = runtime.get("api_key") if not isinstance(custom_base, str) or not custom_base.strip(): return None, None - if not isinstance(custom_key, str) or not custom_key.strip(): - return None, None custom_base = custom_base.strip().rstrip("/") if "openrouter.ai" in custom_base.lower(): @@ -636,6 +633,13 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]: # configured. Treat that as "no custom endpoint" for auxiliary routing. return None, None + # Local servers (Ollama, llama.cpp, vLLM, LM Studio) don't require auth. + # Use a placeholder key — the OpenAI SDK requires a non-empty string but + # local servers ignore the Authorization header. Same fix as cli.py + # _ensure_runtime_credentials() (PR #2556). + if not isinstance(custom_key, str) or not custom_key.strip(): + custom_key = "no-key-required" + return custom_base, custom_key.strip() @@ -693,7 +697,13 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]: is_oauth = _is_oauth_token(token) model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001") logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth) - real_client = build_anthropic_client(token, base_url) + try: + real_client = build_anthropic_client(token, base_url) + except ImportError: + # The anthropic_adapter module imports fine but the SDK itself is + # missing — build_anthropic_client raises ImportError at call time + # when _anthropic_sdk is None. Treat as unavailable. + return None, None return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model @@ -731,16 +741,37 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st return None, None +_AUTO_PROVIDER_LABELS = { + "_try_openrouter": "openrouter", + "_try_nous": "nous", + "_try_custom_endpoint": "local/custom", + "_try_codex": "openai-codex", + "_resolve_api_key_provider": "api-key", +} + + def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: """Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None.""" global auxiliary_is_nous auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins + tried = [] for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider): + fn_name = getattr(try_fn, "__name__", "unknown") + label = _AUTO_PROVIDER_LABELS.get(fn_name, fn_name) client, model = try_fn() if client is not None: + if tried: + logger.info("Auxiliary auto-detect: using %s (%s) — skipped: %s", + label, model or "default", ", ".join(tried)) + else: + logger.info("Auxiliary auto-detect: using %s (%s)", label, model or "default") return client, model - logger.debug("Auxiliary client: none available") + tried.append(label) + logger.warning("Auxiliary auto-detect: no provider available (tried: %s). " + "Compression, summarization, and memory flush will not work. " + "Set OPENROUTER_API_KEY or configure a local model in config.yaml.", + ", ".join(tried)) return None, None @@ -891,11 +922,12 @@ def resolve_provider_client( custom_key = ( (explicit_api_key or "").strip() or os.getenv("OPENAI_API_KEY", "").strip() + or "no-key-required" # local servers don't need auth ) - if not custom_base or not custom_key: + if not custom_base: logger.warning( "resolve_provider_client: explicit custom endpoint requested " - "but no API key was found (set explicit_api_key or OPENAI_API_KEY)" + "but base_url is empty" ) return None, None final_model = model or _read_main_model() or "gpt-4o-mini" @@ -1131,7 +1163,13 @@ def resolve_vision_provider_client( return "custom", client, final_model if requested == "auto": - for candidate in get_available_vision_backends(): + ordered = list(_VISION_AUTO_PROVIDER_ORDER) + preferred = _preferred_main_vision_provider() + if preferred in ordered: + ordered.remove(preferred) + ordered.insert(0, preferred) + + for candidate in ordered: sync_client, default_model = _resolve_strict_vision_backend(candidate) if sync_client is not None: return _finalize(candidate, sync_client, default_model) @@ -1204,6 +1242,39 @@ _client_cache: Dict[tuple, tuple] = {} _client_cache_lock = threading.Lock() +def neuter_async_httpx_del() -> None: + """Monkey-patch ``AsyncHttpxClientWrapper.__del__`` to be a no-op. + + The OpenAI SDK's ``AsyncHttpxClientWrapper.__del__`` schedules + ``self.aclose()`` via ``asyncio.get_running_loop().create_task()``. + When an ``AsyncOpenAI`` client is garbage-collected while + prompt_toolkit's event loop is running (the common CLI idle state), + the ``aclose()`` task runs on prompt_toolkit's loop but the + underlying TCP transport is bound to a *different* loop (the worker + thread's loop that the client was originally created on). If that + loop is closed or its thread is dead, the transport's + ``self._loop.call_soon()`` raises ``RuntimeError("Event loop is + closed")``, which prompt_toolkit surfaces as "Unhandled exception + in event loop ... Press ENTER to continue...". + + Neutering ``__del__`` is safe because: + - Cached clients are explicitly cleaned via ``_force_close_async_httpx`` + on stale-loop detection and ``shutdown_cached_clients`` on exit. + - Uncached clients' TCP connections are cleaned up by the OS when the + process exits. + - The OpenAI SDK itself marks this as a TODO (``# TODO(someday): + support non asyncio runtimes here``). + + Call this once at CLI startup, before any ``AsyncOpenAI`` clients are + created. + """ + try: + from openai._base_client import AsyncHttpxClientWrapper + AsyncHttpxClientWrapper.__del__ = lambda self: None # type: ignore[assignment] + except (ImportError, AttributeError): + pass # Graceful degradation if the SDK changes its internals + + def _force_close_async_httpx(client: Any) -> None: """Mark the httpx AsyncClient inside an AsyncOpenAI client as closed. @@ -1251,6 +1322,25 @@ def shutdown_cached_clients() -> None: _client_cache.clear() +def cleanup_stale_async_clients() -> None: + """Force-close cached async clients whose event loop is closed. + + Call this after each agent turn to proactively clean up stale clients + before GC can trigger ``AsyncHttpxClientWrapper.__del__`` on them. + This is defense-in-depth — the primary fix is ``neuter_async_httpx_del`` + which disables ``__del__`` entirely. + """ + with _client_cache_lock: + stale_keys = [] + for key, entry in _client_cache.items(): + client, _default, cached_loop = entry + if cached_loop is not None and cached_loop.is_closed(): + _force_close_async_httpx(client) + stale_keys.append(key) + for key in stale_keys: + del _client_cache[key] + + def _get_cached_client( provider: str, model: str = None, @@ -1394,6 +1484,29 @@ def _resolve_task_provider_model( return "auto", resolved_model, None, None +_DEFAULT_AUX_TIMEOUT = 30.0 + + +def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float: + """Read timeout from auxiliary.{task}.timeout in config, falling back to *default*.""" + if not task: + return default + try: + from hermes_cli.config import load_config + config = load_config() + except ImportError: + return default + aux = config.get("auxiliary", {}) if isinstance(config, dict) else {} + task_config = aux.get(task, {}) if isinstance(aux, dict) else {} + raw = task_config.get("timeout") + if raw is not None: + try: + return float(raw) + except (ValueError, TypeError): + pass + return default + + def _build_call_kwargs( provider: str, model: str, @@ -1451,7 +1564,7 @@ def call_llm( temperature: float = None, max_tokens: int = None, tools: list = None, - timeout: float = 30.0, + timeout: float = None, extra_body: dict = None, ) -> Any: """Centralized synchronous LLM call. @@ -1469,7 +1582,7 @@ def call_llm( temperature: Sampling temperature (None = provider default). max_tokens: Max output tokens (handles max_tokens vs max_completion_tokens). tools: Tool definitions (for function calling). - timeout: Request timeout in seconds. + timeout: Request timeout in seconds (None = read from auxiliary.{task}.timeout config). extra_body: Additional request body fields. Returns: @@ -1525,8 +1638,8 @@ def call_llm( ) # For auto/custom, fall back to OpenRouter if not resolved_base_url: - logger.warning("Provider %s unavailable, falling back to openrouter", - resolved_provider) + logger.info("Auxiliary %s: provider %s unavailable, falling back to openrouter", + task or "call", resolved_provider) client, final_model = _get_cached_client( "openrouter", resolved_model or _OPENROUTER_MODEL) if client is None: @@ -1534,10 +1647,19 @@ def call_llm( f"No LLM provider configured for task={task} provider={resolved_provider}. " f"Run: hermes setup") + effective_timeout = timeout if timeout is not None else _get_task_timeout(task) + + # Log what we're about to do — makes auxiliary operations visible + _base_info = str(getattr(client, "base_url", resolved_base_url) or "") + if task: + logger.info("Auxiliary %s: using %s (%s)%s", + task, resolved_provider or "auto", final_model or "default", + f" at {_base_info}" if _base_info and "openrouter" not in _base_info else "") + kwargs = _build_call_kwargs( resolved_provider, final_model, messages, temperature=temperature, max_tokens=max_tokens, - tools=tools, timeout=timeout, extra_body=extra_body, + tools=tools, timeout=effective_timeout, extra_body=extra_body, base_url=resolved_base_url) # Handle max_tokens vs max_completion_tokens retry @@ -1552,6 +1674,62 @@ def call_llm( raise +def extract_content_or_reasoning(response) -> str: + """Extract content from an LLM response, falling back to reasoning fields. + + Mirrors the main agent loop's behavior when a reasoning model (DeepSeek-R1, + Qwen-QwQ, etc.) returns ``content=None`` with reasoning in structured fields. + + Resolution order: + 1. ``message.content`` — strip inline think/reasoning blocks, check for + remaining non-whitespace text. + 2. ``message.reasoning`` / ``message.reasoning_content`` — direct + structured reasoning fields (DeepSeek, Moonshot, Novita, etc.). + 3. ``message.reasoning_details`` — OpenRouter unified array format. + + Returns the best available text, or ``""`` if nothing found. + """ + import re + + msg = response.choices[0].message + content = (msg.content or "").strip() + + if content: + # Strip inline think/reasoning blocks (mirrors _strip_think_blocks) + cleaned = re.sub( + r"<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)>" + r".*?" + r"", + "", content, flags=re.DOTALL | re.IGNORECASE, + ).strip() + if cleaned: + return cleaned + + # Content is empty or reasoning-only — try structured reasoning fields + reasoning_parts: list[str] = [] + for field in ("reasoning", "reasoning_content"): + val = getattr(msg, field, None) + if val and isinstance(val, str) and val.strip() and val not in reasoning_parts: + reasoning_parts.append(val.strip()) + + details = getattr(msg, "reasoning_details", None) + if details and isinstance(details, list): + for detail in details: + if isinstance(detail, dict): + summary = ( + detail.get("summary") + or detail.get("content") + or detail.get("text") + ) + if summary and summary not in reasoning_parts: + reasoning_parts.append(summary.strip() if isinstance(summary, str) else str(summary)) + + if reasoning_parts: + return "\n\n".join(reasoning_parts) + + return "" + + async def async_call_llm( task: str = None, *, @@ -1563,7 +1741,7 @@ async def async_call_llm( temperature: float = None, max_tokens: int = None, tools: list = None, - timeout: float = 30.0, + timeout: float = None, extra_body: dict = None, ) -> Any: """Centralized asynchronous LLM call. @@ -1624,10 +1802,12 @@ async def async_call_llm( f"No LLM provider configured for task={task} provider={resolved_provider}. " f"Run: hermes setup") + effective_timeout = timeout if timeout is not None else _get_task_timeout(task) + kwargs = _build_call_kwargs( resolved_provider, final_model, messages, temperature=temperature, max_tokens=max_tokens, - tools=tools, timeout=timeout, extra_body=extra_body, + tools=tools, timeout=effective_timeout, extra_body=extra_body, base_url=resolved_base_url) try: diff --git a/agent/context_compressor.py b/agent/context_compressor.py index a39b19359..6fdb38b29 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -141,7 +141,7 @@ class ContextCompressor: "last_prompt_tokens": self.last_prompt_tokens, "threshold_tokens": self.threshold_tokens, "context_length": self.context_length, - "usage_percent": (self.last_prompt_tokens / self.context_length * 100) if self.context_length else 0, + "usage_percent": min(100, (self.last_prompt_tokens / self.context_length * 100)) if self.context_length else 0, "compression_count": self.compression_count, } @@ -347,7 +347,7 @@ Write only the summary body. Do not include any preamble or prefix.""" "messages": [{"role": "user", "content": prompt}], "temperature": 0.3, "max_tokens": summary_budget * 2, - "timeout": 45.0, + # timeout resolved from auxiliary.compression.timeout config by call_llm } if self.summary_model: call_kwargs["model"] = self.summary_model diff --git a/agent/context_references.py b/agent/context_references.py index 795e37c64..09ba982df 100644 --- a/agent/context_references.py +++ b/agent/context_references.py @@ -286,12 +286,16 @@ def _expand_git_reference( args: list[str], label: str, ) -> tuple[str | None, str | None]: - result = subprocess.run( - ["git", *args], - cwd=cwd, - capture_output=True, - text=True, - ) + try: + result = subprocess.run( + ["git", *args], + cwd=cwd, + capture_output=True, + text=True, + timeout=30, + ) + except subprocess.TimeoutExpired: + return f"{ref.raw}: git command timed out (30s)", None if result.returncode != 0: stderr = (result.stderr or "").strip() or "git command failed" return f"{ref.raw}: {stderr}", None @@ -449,9 +453,12 @@ def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None: cwd=cwd, capture_output=True, text=True, + timeout=10, ) except FileNotFoundError: return None + except subprocess.TimeoutExpired: + return None if result.returncode != 0: return None files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()] diff --git a/agent/display.py b/agent/display.py index b574c485e..de47002d0 100644 --- a/agent/display.py +++ b/agent/display.py @@ -17,6 +17,23 @@ _RESET = "\033[0m" logger = logging.getLogger(__name__) +# ========================================================================= +# Configurable tool preview length (0 = no limit) +# Set once at startup by CLI or gateway from display.tool_preview_length config. +# ========================================================================= +_tool_preview_max_len: int = 0 # 0 = unlimited + + +def set_tool_preview_max_len(n: int) -> None: + """Set the global max length for tool call previews. 0 = no limit.""" + global _tool_preview_max_len + _tool_preview_max_len = max(int(n), 0) if n else 0 + + +def get_tool_preview_max_len() -> int: + """Return the configured max preview length (0 = unlimited).""" + return _tool_preview_max_len + # ========================================================================= # Skin-aware helpers (lazy import to avoid circular deps) @@ -94,8 +111,14 @@ def _oneline(text: str) -> str: return " ".join(text.split()) -def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str | None: - """Build a short preview of a tool call's primary argument for display.""" +def build_tool_preview(tool_name: str, args: dict, max_len: int | None = None) -> str | None: + """Build a short preview of a tool call's primary argument for display. + + *max_len* controls truncation. ``None`` (default) defers to the global + ``_tool_preview_max_len`` set via config; ``0`` means unlimited. + """ + if max_len is None: + max_len = _tool_preview_max_len if not args: return None primary_args = { @@ -190,7 +213,7 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str | N preview = _oneline(str(value)) if not preview: return None - if len(preview) > max_len: + if max_len > 0 and len(preview) > max_len: preview = preview[:max_len - 3] + "..." return preview @@ -231,7 +254,7 @@ class KawaiiSpinner: "analyzing", "computing", "synthesizing", "formulating", "brainstorming", ] - def __init__(self, message: str = "", spinner_type: str = 'dots'): + def __init__(self, message: str = "", spinner_type: str = 'dots', print_fn=None): self.message = message self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots']) self.running = False @@ -239,12 +262,26 @@ class KawaiiSpinner: self.frame_idx = 0 self.start_time = None self.last_line_len = 0 + # Optional callable to route all output through (e.g. a no-op for silent + # background agents). When set, bypasses self._out entirely so that + # agents with _print_fn overridden remain fully silent. + self._print_fn = print_fn # Capture stdout NOW, before any redirect_stdout(devnull) from # child agents can replace sys.stdout with a black hole. self._out = sys.stdout def _write(self, text: str, end: str = '\n', flush: bool = False): - """Write to the stdout captured at spinner creation time.""" + """Write to the stdout captured at spinner creation time. + + If a print_fn was supplied at construction, all output is routed through + it instead — allowing callers to silence the spinner with a no-op lambda. + """ + if self._print_fn is not None: + try: + self._print_fn(text) + except Exception: + pass + return try: self._out.write(text + end) if flush: @@ -270,11 +307,11 @@ class KawaiiSpinner: The CLI already drives a TUI widget (_spinner_text) for spinner display, so KawaiiSpinner's \\r-based animation is redundant under StdoutProxy. """ - out = self._out - # StdoutProxy has a 'raw' attribute (bool) that plain file objects lack. - if hasattr(out, 'raw') and type(out).__name__ == 'StdoutProxy': - return True - return False + try: + from prompt_toolkit.patch_stdout import StdoutProxy + return isinstance(self._out, StdoutProxy) + except ImportError: + return False def _animate(self): # When stdout is not a real terminal (e.g. Docker, systemd, pipe), @@ -470,10 +507,14 @@ def get_cute_tool_message( def _trunc(s, n=40): s = str(s) + if _tool_preview_max_len == 0: + return s # no limit return (s[:n-3] + "...") if len(s) > n else s def _path(p, n=35): p = str(p) + if _tool_preview_max_len == 0: + return p # no limit return ("..." + p[-(n-3):]) if len(p) > n else p def _wrap(line: str) -> str: @@ -685,7 +726,7 @@ def format_context_pressure( threshold_percent: Compaction threshold as a fraction of context window. compression_enabled: Whether auto-compression is active. """ - pct_int = int(compaction_progress * 100) + pct_int = min(int(compaction_progress * 100), 100) filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH) bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled) @@ -715,7 +756,7 @@ def format_context_pressure_gateway( No ANSI — just Unicode and plain text suitable for Telegram/Discord/etc. The percentage shows progress toward the compaction threshold. """ - pct_int = int(compaction_progress * 100) + pct_int = min(int(compaction_progress * 100), 100) filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH) bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled) diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 2e87498e0..162295f81 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -113,6 +113,15 @@ DEFAULT_CONTEXT_LENGTHS = { "glm": 202752, # Kimi "kimi": 262144, + # Hugging Face Inference Providers — model IDs use org/name format + "Qwen/Qwen3.5-397B-A17B": 131072, + "Qwen/Qwen3.5-35B-A3B": 131072, + "deepseek-ai/DeepSeek-V3.2": 65536, + "moonshotai/Kimi-K2.5": 262144, + "moonshotai/Kimi-K2-Thinking": 262144, + "MiniMaxAI/MiniMax-M2.5": 204800, + "XiaomiMiMo/MiMo-V2-Flash": 32768, + "zai-org/GLM-5": 202752, } _CONTEXT_LENGTH_KEYS = ( diff --git a/agent/models_dev.py b/agent/models_dev.py index 0ef2b62cd..283e8018f 100644 --- a/agent/models_dev.py +++ b/agent/models_dev.py @@ -15,6 +15,8 @@ import time from pathlib import Path from typing import Any, Dict, Optional +from utils import atomic_json_write + import requests logger = logging.getLogger(__name__) @@ -64,12 +66,10 @@ def _load_disk_cache() -> Dict[str, Any]: def _save_disk_cache(data: Dict[str, Any]) -> None: - """Save models.dev data to disk cache.""" + """Save models.dev data to disk cache atomically.""" try: cache_path = _get_cache_path() - cache_path.parent.mkdir(parents=True, exist_ok=True) - with open(cache_path, "w", encoding="utf-8") as f: - json.dump(data, f, separators=(",", ":")) + atomic_json_write(cache_path, data, indent=None, separators=(",", ":")) except Exception as e: logger.debug("Failed to save models.dev disk cache: %s", e) diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 878c8658c..54a94d977 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -4,14 +4,28 @@ All functions are stateless. AIAgent._build_system_prompt() calls these to assemble pieces, then combines them with memory and ephemeral prompts. """ +import json import logging import os import re +import threading +from collections import OrderedDict from pathlib import Path from hermes_constants import get_hermes_home from typing import Optional +from agent.skill_utils import ( + extract_skill_conditions, + extract_skill_description, + get_all_skills_dirs, + get_disabled_skill_names, + iter_skill_index_files, + parse_frontmatter, + skill_matches_platform, +) +from utils import atomic_json_write + logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- @@ -156,6 +170,25 @@ SKILLS_GUIDANCE = ( "Skills that aren't maintained become liabilities." ) +TOOL_USE_ENFORCEMENT_GUIDANCE = ( + "# Tool-use enforcement\n" + "You MUST use your tools to take action — do not describe what you would do " + "or plan to do without actually doing it. When you say you will perform an " + "action (e.g. 'I will run the tests', 'Let me check the file', 'I will create " + "the project'), you MUST immediately make the corresponding tool call in the same " + "response. Never end your turn with a promise of future action — execute it now.\n" + "Keep working until the task is actually complete. Do not stop with a summary of " + "what you plan to do next time. If you have tools available that can accomplish " + "the task, use them instead of telling the user what you would do.\n" + "Every response should either (a) contain tool calls that make progress, or " + "(b) deliver a final result to the user. Responses that only describe intentions " + "without acting are not acceptable." +) + +# Model name substrings that trigger tool-use enforcement guidance. +# Add new patterns here when a model family needs explicit steering. +TOOL_USE_ENFORCEMENT_MODELS = ("gpt", "codex") + PLATFORM_HINTS = { "whatsapp": ( "You are on a text messaging communication platform, WhatsApp. " @@ -230,6 +263,111 @@ CONTEXT_TRUNCATE_HEAD_RATIO = 0.7 CONTEXT_TRUNCATE_TAIL_RATIO = 0.2 +# ========================================================================= +# Skills prompt cache +# ========================================================================= + +_SKILLS_PROMPT_CACHE_MAX = 8 +_SKILLS_PROMPT_CACHE: OrderedDict[tuple, str] = OrderedDict() +_SKILLS_PROMPT_CACHE_LOCK = threading.Lock() +_SKILLS_SNAPSHOT_VERSION = 1 + + +def _skills_prompt_snapshot_path() -> Path: + return get_hermes_home() / ".skills_prompt_snapshot.json" + + +def clear_skills_system_prompt_cache(*, clear_snapshot: bool = False) -> None: + """Drop the in-process skills prompt cache (and optionally the disk snapshot).""" + with _SKILLS_PROMPT_CACHE_LOCK: + _SKILLS_PROMPT_CACHE.clear() + if clear_snapshot: + try: + _skills_prompt_snapshot_path().unlink(missing_ok=True) + except OSError as e: + logger.debug("Could not remove skills prompt snapshot: %s", e) + + +def _build_skills_manifest(skills_dir: Path) -> dict[str, list[int]]: + """Build an mtime/size manifest of all SKILL.md and DESCRIPTION.md files.""" + manifest: dict[str, list[int]] = {} + for filename in ("SKILL.md", "DESCRIPTION.md"): + for path in iter_skill_index_files(skills_dir, filename): + try: + st = path.stat() + except OSError: + continue + manifest[str(path.relative_to(skills_dir))] = [st.st_mtime_ns, st.st_size] + return manifest + + +def _load_skills_snapshot(skills_dir: Path) -> Optional[dict]: + """Load the disk snapshot if it exists and its manifest still matches.""" + snapshot_path = _skills_prompt_snapshot_path() + if not snapshot_path.exists(): + return None + try: + snapshot = json.loads(snapshot_path.read_text(encoding="utf-8")) + except Exception: + return None + if not isinstance(snapshot, dict): + return None + if snapshot.get("version") != _SKILLS_SNAPSHOT_VERSION: + return None + if snapshot.get("manifest") != _build_skills_manifest(skills_dir): + return None + return snapshot + + +def _write_skills_snapshot( + skills_dir: Path, + manifest: dict[str, list[int]], + skill_entries: list[dict], + category_descriptions: dict[str, str], +) -> None: + """Persist skill metadata to disk for fast cold-start reuse.""" + payload = { + "version": _SKILLS_SNAPSHOT_VERSION, + "manifest": manifest, + "skills": skill_entries, + "category_descriptions": category_descriptions, + } + try: + atomic_json_write(_skills_prompt_snapshot_path(), payload) + except Exception as e: + logger.debug("Could not write skills prompt snapshot: %s", e) + + +def _build_snapshot_entry( + skill_file: Path, + skills_dir: Path, + frontmatter: dict, + description: str, +) -> dict: + """Build a serialisable metadata dict for one skill.""" + rel_path = skill_file.relative_to(skills_dir) + parts = rel_path.parts + if len(parts) >= 2: + skill_name = parts[-2] + category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0] + else: + category = "general" + skill_name = skill_file.parent.name + + platforms = frontmatter.get("platforms") or [] + if isinstance(platforms, str): + platforms = [platforms] + + return { + "skill_name": skill_name, + "category": category, + "frontmatter_name": str(frontmatter.get("name", skill_name)), + "description": description, + "platforms": [str(p).strip() for p in platforms if str(p).strip()], + "conditions": extract_skill_conditions(frontmatter), + } + + # ========================================================================= # Skills index # ========================================================================= @@ -241,22 +379,13 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]: (True, {}, "") to err on the side of showing the skill. """ try: - from tools.skills_tool import _parse_frontmatter, skill_matches_platform - raw = skill_file.read_text(encoding="utf-8")[:2000] - frontmatter, _ = _parse_frontmatter(raw) + frontmatter, _ = parse_frontmatter(raw) if not skill_matches_platform(frontmatter): - return False, {}, "" + return False, frontmatter, "" - desc = "" - raw_desc = frontmatter.get("description", "") - if raw_desc: - desc = str(raw_desc).strip().strip("'\"") - if len(desc) > 60: - desc = desc[:57] + "..." - - return True, frontmatter, desc + return True, frontmatter, extract_skill_description(frontmatter) except Exception as e: logger.debug("Failed to parse skill file %s: %s", skill_file, e) return True, {}, "" @@ -265,16 +394,9 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]: def _read_skill_conditions(skill_file: Path) -> dict: """Extract conditional activation fields from SKILL.md frontmatter.""" try: - from tools.skills_tool import _parse_frontmatter raw = skill_file.read_text(encoding="utf-8")[:2000] - frontmatter, _ = _parse_frontmatter(raw) - hermes = frontmatter.get("metadata", {}).get("hermes", {}) - return { - "fallback_for_toolsets": hermes.get("fallback_for_toolsets", []), - "requires_toolsets": hermes.get("requires_toolsets", []), - "fallback_for_tools": hermes.get("fallback_for_tools", []), - "requires_tools": hermes.get("requires_tools", []), - } + frontmatter, _ = parse_frontmatter(raw) + return extract_skill_conditions(frontmatter) except Exception as e: logger.debug("Failed to read skill conditions from %s: %s", skill_file, e) return {} @@ -317,109 +439,210 @@ def build_skills_system_prompt( ) -> str: """Build a compact skill index for the system prompt. - Scans ~/.hermes/skills/ for SKILL.md files grouped by category. - Includes per-skill descriptions from frontmatter so the model can - match skills by meaning, not just name. - Filters out skills incompatible with the current OS platform. + Two-layer cache: + 1. In-process LRU dict keyed by (skills_dir, tools, toolsets) + 2. Disk snapshot (``.skills_prompt_snapshot.json``) validated by + mtime/size manifest — survives process restarts + + Falls back to a full filesystem scan when both layers miss. + + External skill directories (``skills.external_dirs`` in config.yaml) are + scanned alongside the local ``~/.hermes/skills/`` directory. External dirs + are read-only — they appear in the index but new skills are always created + in the local dir. Local skills take precedence when names collide. """ hermes_home = get_hermes_home() skills_dir = hermes_home / "skills" + external_dirs = get_all_skills_dirs()[1:] # skip local (index 0) - if not skills_dir.exists(): + if not skills_dir.exists() and not external_dirs: return "" - # Collect skills with descriptions, grouped by category. - # Each entry: (skill_name, description) - # Supports sub-categories: skills/mlops/training/axolotl/SKILL.md - # -> category "mlops/training", skill "axolotl" - # Load disabled skill names once for the entire scan - try: - from tools.skills_tool import _get_disabled_skill_names - disabled = _get_disabled_skill_names() - except Exception: - disabled = set() + # ── Layer 1: in-process LRU cache ───────────────────────────────── + cache_key = ( + str(skills_dir.resolve()), + tuple(str(d) for d in external_dirs), + tuple(sorted(str(t) for t in (available_tools or set()))), + tuple(sorted(str(ts) for ts in (available_toolsets or set()))), + ) + with _SKILLS_PROMPT_CACHE_LOCK: + cached = _SKILLS_PROMPT_CACHE.get(cache_key) + if cached is not None: + _SKILLS_PROMPT_CACHE.move_to_end(cache_key) + return cached + + disabled = get_disabled_skill_names() + + # ── Layer 2: disk snapshot ──────────────────────────────────────── + snapshot = _load_skills_snapshot(skills_dir) skills_by_category: dict[str, list[tuple[str, str]]] = {} - for skill_file in skills_dir.rglob("SKILL.md"): - is_compatible, frontmatter, desc = _parse_skill_file(skill_file) - if not is_compatible: - continue - rel_path = skill_file.relative_to(skills_dir) - parts = rel_path.parts - if len(parts) >= 2: - skill_name = parts[-2] - category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0] - else: - category = "general" - skill_name = skill_file.parent.name - # Respect user's disabled skills config - fm_name = frontmatter.get("name", skill_name) - if fm_name in disabled or skill_name in disabled: - continue - # Extract conditions inline from already-parsed frontmatter - # (avoids redundant file re-read that _read_skill_conditions would do) - hermes_meta = (frontmatter.get("metadata") or {}).get("hermes") or {} - conditions = { - "fallback_for_toolsets": hermes_meta.get("fallback_for_toolsets", []), - "requires_toolsets": hermes_meta.get("requires_toolsets", []), - "fallback_for_tools": hermes_meta.get("fallback_for_tools", []), - "requires_tools": hermes_meta.get("requires_tools", []), + category_descriptions: dict[str, str] = {} + + if snapshot is not None: + # Fast path: use pre-parsed metadata from disk + for entry in snapshot.get("skills", []): + if not isinstance(entry, dict): + continue + skill_name = entry.get("skill_name") or "" + category = entry.get("category") or "general" + frontmatter_name = entry.get("frontmatter_name") or skill_name + platforms = entry.get("platforms") or [] + if not skill_matches_platform({"platforms": platforms}): + continue + if frontmatter_name in disabled or skill_name in disabled: + continue + if not _skill_should_show( + entry.get("conditions") or {}, + available_tools, + available_toolsets, + ): + continue + skills_by_category.setdefault(category, []).append( + (skill_name, entry.get("description", "")) + ) + category_descriptions = { + str(k): str(v) + for k, v in (snapshot.get("category_descriptions") or {}).items() } - if not _skill_should_show(conditions, available_tools, available_toolsets): - continue - skills_by_category.setdefault(category, []).append((skill_name, desc)) + else: + # Cold path: full filesystem scan + write snapshot for next time + skill_entries: list[dict] = [] + for skill_file in iter_skill_index_files(skills_dir, "SKILL.md"): + is_compatible, frontmatter, desc = _parse_skill_file(skill_file) + entry = _build_snapshot_entry(skill_file, skills_dir, frontmatter, desc) + skill_entries.append(entry) + if not is_compatible: + continue + skill_name = entry["skill_name"] + if entry["frontmatter_name"] in disabled or skill_name in disabled: + continue + if not _skill_should_show( + extract_skill_conditions(frontmatter), + available_tools, + available_toolsets, + ): + continue + skills_by_category.setdefault(entry["category"], []).append( + (skill_name, entry["description"]) + ) - if not skills_by_category: - return "" - - # Read category-level descriptions from DESCRIPTION.md - # Checks both the exact category path and parent directories - category_descriptions = {} - for category in skills_by_category: - cat_path = Path(category) - desc_file = skills_dir / cat_path / "DESCRIPTION.md" - if desc_file.exists(): + # Read category-level DESCRIPTION.md files + for desc_file in iter_skill_index_files(skills_dir, "DESCRIPTION.md"): try: content = desc_file.read_text(encoding="utf-8") - match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL) - if match: - category_descriptions[category] = match.group(1).strip() + fm, _ = parse_frontmatter(content) + cat_desc = fm.get("description") + if not cat_desc: + continue + rel = desc_file.relative_to(skills_dir) + cat = "/".join(rel.parts[:-1]) if len(rel.parts) > 1 else "general" + category_descriptions[cat] = str(cat_desc).strip().strip("'\"") except Exception as e: logger.debug("Could not read skill description %s: %s", desc_file, e) - index_lines = [] - for category in sorted(skills_by_category.keys()): - cat_desc = category_descriptions.get(category, "") - if cat_desc: - index_lines.append(f" {category}: {cat_desc}") - else: - index_lines.append(f" {category}:") - # Deduplicate and sort skills within each category - seen = set() - for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]): - if name in seen: - continue - seen.add(name) - if desc: - index_lines.append(f" - {name}: {desc}") - else: - index_lines.append(f" - {name}") + _write_skills_snapshot( + skills_dir, + _build_skills_manifest(skills_dir), + skill_entries, + category_descriptions, + ) - return ( - "## Skills (mandatory)\n" - "Before replying, scan the skills below. If one clearly matches your task, " - "load it with skill_view(name) and follow its instructions. " - "If a skill has issues, fix it with skill_manage(action='patch').\n" - "After difficult/iterative tasks, offer to save as a skill. " - "If a skill you loaded was missing steps, had wrong commands, or needed " - "pitfalls you discovered, update it before finishing.\n" - "\n" - "\n" - + "\n".join(index_lines) + "\n" - "\n" - "\n" - "If none match, proceed normally without loading a skill." - ) + # ── External skill directories ───────────────────────────────────── + # Scan external dirs directly (no snapshot caching — they're read-only + # and typically small). Local skills already in skills_by_category take + # precedence: we track seen names and skip duplicates from external dirs. + seen_skill_names: set[str] = set() + for cat_skills in skills_by_category.values(): + for name, _desc in cat_skills: + seen_skill_names.add(name) + + for ext_dir in external_dirs: + if not ext_dir.exists(): + continue + for skill_file in iter_skill_index_files(ext_dir, "SKILL.md"): + try: + is_compatible, frontmatter, desc = _parse_skill_file(skill_file) + if not is_compatible: + continue + entry = _build_snapshot_entry(skill_file, ext_dir, frontmatter, desc) + skill_name = entry["skill_name"] + if skill_name in seen_skill_names: + continue + if entry["frontmatter_name"] in disabled or skill_name in disabled: + continue + if not _skill_should_show( + extract_skill_conditions(frontmatter), + available_tools, + available_toolsets, + ): + continue + seen_skill_names.add(skill_name) + skills_by_category.setdefault(entry["category"], []).append( + (skill_name, entry["description"]) + ) + except Exception as e: + logger.debug("Error reading external skill %s: %s", skill_file, e) + + # External category descriptions + for desc_file in iter_skill_index_files(ext_dir, "DESCRIPTION.md"): + try: + content = desc_file.read_text(encoding="utf-8") + fm, _ = parse_frontmatter(content) + cat_desc = fm.get("description") + if not cat_desc: + continue + rel = desc_file.relative_to(ext_dir) + cat = "/".join(rel.parts[:-1]) if len(rel.parts) > 1 else "general" + category_descriptions.setdefault(cat, str(cat_desc).strip().strip("'\"")) + except Exception as e: + logger.debug("Could not read external skill description %s: %s", desc_file, e) + + if not skills_by_category: + result = "" + else: + index_lines = [] + for category in sorted(skills_by_category.keys()): + cat_desc = category_descriptions.get(category, "") + if cat_desc: + index_lines.append(f" {category}: {cat_desc}") + else: + index_lines.append(f" {category}:") + # Deduplicate and sort skills within each category + seen = set() + for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]): + if name in seen: + continue + seen.add(name) + if desc: + index_lines.append(f" - {name}: {desc}") + else: + index_lines.append(f" - {name}") + + result = ( + "## Skills (mandatory)\n" + "Before replying, scan the skills below. If one clearly matches your task, " + "load it with skill_view(name) and follow its instructions. " + "If a skill has issues, fix it with skill_manage(action='patch').\n" + "After difficult/iterative tasks, offer to save as a skill. " + "If a skill you loaded was missing steps, had wrong commands, or needed " + "pitfalls you discovered, update it before finishing.\n" + "\n" + "\n" + + "\n".join(index_lines) + "\n" + "\n" + "\n" + "If none match, proceed normally without loading a skill." + ) + + # ── Store in LRU cache ──────────────────────────────────────────── + with _SKILLS_PROMPT_CACHE_LOCK: + _SKILLS_PROMPT_CACHE[cache_key] = result + _SKILLS_PROMPT_CACHE.move_to_end(cache_key) + while len(_SKILLS_PROMPT_CACHE) > _SKILLS_PROMPT_CACHE_MAX: + _SKILLS_PROMPT_CACHE.popitem(last=False) + + return result def build_nous_subscription_prompt(valid_tool_names: "set[str] | None" = None) -> str: diff --git a/agent/skill_commands.py b/agent/skill_commands.py index b266ad251..8a434ea79 100644 --- a/agent/skill_commands.py +++ b/agent/skill_commands.py @@ -128,7 +128,11 @@ def _build_skill_message( supporting.append(rel) if supporting and skill_dir: - skill_view_target = str(skill_dir.relative_to(SKILLS_DIR)) + try: + skill_view_target = str(skill_dir.relative_to(SKILLS_DIR)) + except ValueError: + # Skill is from an external dir — use the skill name instead + skill_view_target = skill_dir.name parts.append("") parts.append("[This skill has supporting files you can load with the skill_view tool:]") for sf in supporting: @@ -158,38 +162,49 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]: _skill_commands = {} try: from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform, _get_disabled_skill_names - if not SKILLS_DIR.exists(): - return _skill_commands + from agent.skill_utils import get_external_skills_dirs disabled = _get_disabled_skill_names() - for skill_md in SKILLS_DIR.rglob("SKILL.md"): - if any(part in ('.git', '.github', '.hub') for part in skill_md.parts): - continue - try: - content = skill_md.read_text(encoding='utf-8') - frontmatter, body = _parse_frontmatter(content) - # Skip skills incompatible with the current OS platform - if not skill_matches_platform(frontmatter): + seen_names: set = set() + + # Scan local dir first, then external dirs + dirs_to_scan = [] + if SKILLS_DIR.exists(): + dirs_to_scan.append(SKILLS_DIR) + dirs_to_scan.extend(get_external_skills_dirs()) + + for scan_dir in dirs_to_scan: + for skill_md in scan_dir.rglob("SKILL.md"): + if any(part in ('.git', '.github', '.hub') for part in skill_md.parts): continue - name = frontmatter.get('name', skill_md.parent.name) - # Respect user's disabled skills config - if name in disabled: + try: + content = skill_md.read_text(encoding='utf-8') + frontmatter, body = _parse_frontmatter(content) + # Skip skills incompatible with the current OS platform + if not skill_matches_platform(frontmatter): + continue + name = frontmatter.get('name', skill_md.parent.name) + if name in seen_names: + continue + # Respect user's disabled skills config + if name in disabled: + continue + description = frontmatter.get('description', '') + if not description: + for line in body.strip().split('\n'): + line = line.strip() + if line and not line.startswith('#'): + description = line[:80] + break + seen_names.add(name) + cmd_name = name.lower().replace(' ', '-').replace('_', '-') + _skill_commands[f"/{cmd_name}"] = { + "name": name, + "description": description or f"Invoke the {name} skill", + "skill_md_path": str(skill_md), + "skill_dir": str(skill_md.parent), + } + except Exception: continue - description = frontmatter.get('description', '') - if not description: - for line in body.strip().split('\n'): - line = line.strip() - if line and not line.startswith('#'): - description = line[:80] - break - cmd_name = name.lower().replace(' ', '-').replace('_', '-') - _skill_commands[f"/{cmd_name}"] = { - "name": name, - "description": description or f"Invoke the {name} skill", - "skill_md_path": str(skill_md), - "skill_dir": str(skill_md.parent), - } - except Exception: - continue except Exception: pass return _skill_commands diff --git a/agent/skill_utils.py b/agent/skill_utils.py new file mode 100644 index 000000000..c11bc5e2d --- /dev/null +++ b/agent/skill_utils.py @@ -0,0 +1,270 @@ +"""Lightweight skill metadata utilities shared by prompt_builder and skills_tool. + +This module intentionally avoids importing the tool registry, CLI config, or any +heavy dependency chain. It is safe to import at module level without triggering +tool registration or provider resolution. +""" + +import logging +import os +import re +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple + +from hermes_constants import get_hermes_home + +logger = logging.getLogger(__name__) + +# ── Platform mapping ────────────────────────────────────────────────────── + +PLATFORM_MAP = { + "macos": "darwin", + "linux": "linux", + "windows": "win32", +} + +EXCLUDED_SKILL_DIRS = frozenset((".git", ".github", ".hub")) + +# ── Lazy YAML loader ───────────────────────────────────────────────────── + +_yaml_load_fn = None + + +def yaml_load(content: str): + """Parse YAML with lazy import and CSafeLoader preference.""" + global _yaml_load_fn + if _yaml_load_fn is None: + import yaml + + loader = getattr(yaml, "CSafeLoader", None) or yaml.SafeLoader + + def _load(value: str): + return yaml.load(value, Loader=loader) + + _yaml_load_fn = _load + return _yaml_load_fn(content) + + +# ── Frontmatter parsing ────────────────────────────────────────────────── + + +def parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]: + """Parse YAML frontmatter from a markdown string. + + Uses yaml with CSafeLoader for full YAML support (nested metadata, lists) + with a fallback to simple key:value splitting for robustness. + + Returns: + (frontmatter_dict, remaining_body) + """ + frontmatter: Dict[str, Any] = {} + body = content + + if not content.startswith("---"): + return frontmatter, body + + end_match = re.search(r"\n---\s*\n", content[3:]) + if not end_match: + return frontmatter, body + + yaml_content = content[3 : end_match.start() + 3] + body = content[end_match.end() + 3 :] + + try: + parsed = yaml_load(yaml_content) + if isinstance(parsed, dict): + frontmatter = parsed + except Exception: + # Fallback: simple key:value parsing for malformed YAML + for line in yaml_content.strip().split("\n"): + if ":" not in line: + continue + key, value = line.split(":", 1) + frontmatter[key.strip()] = value.strip() + + return frontmatter, body + + +# ── Platform matching ───────────────────────────────────────────────────── + + +def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool: + """Return True when the skill is compatible with the current OS. + + Skills declare platform requirements via a top-level ``platforms`` list + in their YAML frontmatter:: + + platforms: [macos] # macOS only + platforms: [macos, linux] # macOS and Linux + + If the field is absent or empty the skill is compatible with **all** + platforms (backward-compatible default). + """ + platforms = frontmatter.get("platforms") + if not platforms: + return True + if not isinstance(platforms, list): + platforms = [platforms] + current = sys.platform + for platform in platforms: + normalized = str(platform).lower().strip() + mapped = PLATFORM_MAP.get(normalized, normalized) + if current.startswith(mapped): + return True + return False + + +# ── Disabled skills ─────────────────────────────────────────────────────── + + +def get_disabled_skill_names() -> Set[str]: + """Read disabled skill names from config.yaml. + + Resolves platform from ``HERMES_PLATFORM`` env var, falls back to + the global disabled list. Reads the config file directly (no CLI + config imports) to stay lightweight. + """ + config_path = get_hermes_home() / "config.yaml" + if not config_path.exists(): + return set() + try: + parsed = yaml_load(config_path.read_text(encoding="utf-8")) + except Exception as e: + logger.debug("Could not read skill config %s: %s", config_path, e) + return set() + if not isinstance(parsed, dict): + return set() + + skills_cfg = parsed.get("skills") + if not isinstance(skills_cfg, dict): + return set() + + resolved_platform = os.getenv("HERMES_PLATFORM") + if resolved_platform: + platform_disabled = (skills_cfg.get("platform_disabled") or {}).get( + resolved_platform + ) + if platform_disabled is not None: + return _normalize_string_set(platform_disabled) + return _normalize_string_set(skills_cfg.get("disabled")) + + +def _normalize_string_set(values) -> Set[str]: + if values is None: + return set() + if isinstance(values, str): + values = [values] + return {str(v).strip() for v in values if str(v).strip()} + + +# ── External skills directories ────────────────────────────────────────── + + +def get_external_skills_dirs() -> List[Path]: + """Read ``skills.external_dirs`` from config.yaml and return validated paths. + + Each entry is expanded (``~`` and ``${VAR}``) and resolved to an absolute + path. Only directories that actually exist are returned. Duplicates and + paths that resolve to the local ``~/.hermes/skills/`` are silently skipped. + """ + config_path = get_hermes_home() / "config.yaml" + if not config_path.exists(): + return [] + try: + parsed = yaml_load(config_path.read_text(encoding="utf-8")) + except Exception: + return [] + if not isinstance(parsed, dict): + return [] + + skills_cfg = parsed.get("skills") + if not isinstance(skills_cfg, dict): + return [] + + raw_dirs = skills_cfg.get("external_dirs") + if not raw_dirs: + return [] + if isinstance(raw_dirs, str): + raw_dirs = [raw_dirs] + if not isinstance(raw_dirs, list): + return [] + + local_skills = (get_hermes_home() / "skills").resolve() + seen: Set[Path] = set() + result: List[Path] = [] + + for entry in raw_dirs: + entry = str(entry).strip() + if not entry: + continue + # Expand ~ and environment variables + expanded = os.path.expanduser(os.path.expandvars(entry)) + p = Path(expanded).resolve() + if p == local_skills: + continue + if p in seen: + continue + if p.is_dir(): + seen.add(p) + result.append(p) + else: + logger.debug("External skills dir does not exist, skipping: %s", p) + + return result + + +def get_all_skills_dirs() -> List[Path]: + """Return all skill directories: local ``~/.hermes/skills/`` first, then external. + + The local dir is always first (and always included even if it doesn't exist + yet — callers handle that). External dirs follow in config order. + """ + dirs = [get_hermes_home() / "skills"] + dirs.extend(get_external_skills_dirs()) + return dirs + + +# ── Condition extraction ────────────────────────────────────────────────── + + +def extract_skill_conditions(frontmatter: Dict[str, Any]) -> Dict[str, List]: + """Extract conditional activation fields from parsed frontmatter.""" + hermes = (frontmatter.get("metadata") or {}).get("hermes") or {} + return { + "fallback_for_toolsets": hermes.get("fallback_for_toolsets", []), + "requires_toolsets": hermes.get("requires_toolsets", []), + "fallback_for_tools": hermes.get("fallback_for_tools", []), + "requires_tools": hermes.get("requires_tools", []), + } + + +# ── Description extraction ──────────────────────────────────────────────── + + +def extract_skill_description(frontmatter: Dict[str, Any]) -> str: + """Extract a truncated description from parsed frontmatter.""" + raw_desc = frontmatter.get("description", "") + if not raw_desc: + return "" + desc = str(raw_desc).strip().strip("'\"") + if len(desc) > 60: + return desc[:57] + "..." + return desc + + +# ── File iteration ──────────────────────────────────────────────────────── + + +def iter_skill_index_files(skills_dir: Path, filename: str): + """Walk skills_dir yielding sorted paths matching *filename*. + + Excludes ``.git``, ``.github``, ``.hub`` directories. + """ + matches = [] + for root, dirs, files in os.walk(skills_dir): + dirs[:] = [d for d in dirs if d not in EXCLUDED_SKILL_DIRS] + if filename in files: + matches.append(Path(root) / filename) + for path in sorted(matches, key=lambda p: str(p.relative_to(skills_dir))): + yield path diff --git a/agent/title_generator.py b/agent/title_generator.py index 9a18aab58..741fe8b09 100644 --- a/agent/title_generator.py +++ b/agent/title_generator.py @@ -19,7 +19,7 @@ _TITLE_PROMPT = ( ) -def generate_title(user_message: str, assistant_response: str, timeout: float = 15.0) -> Optional[str]: +def generate_title(user_message: str, assistant_response: str, timeout: float = 30.0) -> Optional[str]: """Generate a session title from the first exchange. Uses the auxiliary LLM client (cheapest/fastest available model). diff --git a/cli-config.yaml.example b/cli-config.yaml.example index dc1a33199..504b2178d 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -7,17 +7,33 @@ # ============================================================================= model: # Default model to use (can be overridden with --model flag) + # Both "default" and "model" work as the key name here. default: "anthropic/claude-opus-4.6" # Inference provider selection: - # "auto" - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default) - # "nous-api" - Use Nous Portal via API key (requires: NOUS_API_KEY) - # "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY - # "nous" - Always use Nous Portal (requires: hermes login) - # "zai" - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY) - # "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY) - # "minimax" - Use MiniMax global endpoint (requires: MINIMAX_API_KEY) - # "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY) + # "auto" - Auto-detect from credentials (default) + # "openrouter" - OpenRouter (requires: OPENROUTER_API_KEY or OPENAI_API_KEY) + # "nous" - Nous Portal OAuth (requires: hermes login) + # "nous-api" - Nous Portal API key (requires: NOUS_API_KEY) + # "anthropic" - Direct Anthropic API (requires: ANTHROPIC_API_KEY) + # "openai-codex" - OpenAI Codex (requires: hermes login --provider openai-codex) + # "copilot" - GitHub Copilot / GitHub Models (requires: GITHUB_TOKEN) + # "zai" - z.ai / ZhipuAI GLM (requires: GLM_API_KEY) + # "kimi-coding" - Kimi / Moonshot AI (requires: KIMI_API_KEY) + # "minimax" - MiniMax global (requires: MINIMAX_API_KEY) + # "minimax-cn" - MiniMax China (requires: MINIMAX_CN_API_KEY) + # "huggingface" - Hugging Face Inference (requires: HF_TOKEN) + # "kilocode" - KiloCode gateway (requires: KILOCODE_API_KEY) + # "ai-gateway" - Vercel AI Gateway (requires: AI_GATEWAY_API_KEY) + # + # Local servers (LM Studio, Ollama, vLLM, llama.cpp): + # "custom" - Any OpenAI-compatible endpoint. Set base_url below. + # Aliases: "lmstudio", "ollama", "vllm", "llamacpp" all map to "custom". + # Example for LM Studio: + # provider: "lmstudio" + # base_url: "http://localhost:1234/v1" + # No API key needed — local servers typically ignore auth. + # # Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var. provider: "auto" @@ -401,6 +417,15 @@ skills: # Set to 0 to disable. creation_nudge_interval: 15 + # External skill directories — share skills across tools/agents without + # copying them into ~/.hermes/skills/. Each path is expanded (~ and ${VAR}) + # and resolved to an absolute path. External dirs are read-only: skill + # creation always writes to ~/.hermes/skills/. Local skills take precedence + # when names collide. + # external_dirs: + # - ~/.agents/skills + # - /home/shared/team-skills + # ============================================================================= # Agent Behavior # ============================================================================= @@ -688,6 +713,12 @@ display: # Toggle at runtime with /verbose in the CLI tool_progress: all + # What Enter does when Hermes is already busy in the CLI. + # interrupt: Interrupt the current run and redirect Hermes (default) + # queue: Queue your message for the next turn + # Ctrl+C always interrupts regardless of this setting. + busy_input_mode: interrupt + # Background process notifications (gateway/messaging only). # Controls how chatty the process watcher is when you use # terminal(background=true, check_interval=...) from Telegram/Discord/etc. diff --git a/cli.py b/cli.py index e6ce2a95b..a601878f2 100644 --- a/cli.py +++ b/cli.py @@ -70,7 +70,7 @@ _COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧ # Load .env from ~/.hermes/.env first, then project root as dev fallback. # User-managed env files should override stale shell exports on restart. -from hermes_constants import get_hermes_home, OPENROUTER_BASE_URL +from hermes_constants import get_hermes_home, display_hermes_home, OPENROUTER_BASE_URL from hermes_cli.env_loader import load_hermes_dotenv _hermes_home = get_hermes_home() @@ -205,6 +205,7 @@ def load_cli_config() -> Dict[str, Any]: "resume_display": "full", "show_reasoning": False, "streaming": True, + "busy_input_mode": "interrupt", "skin": "default", }, @@ -448,6 +449,25 @@ try: except Exception: pass # Skin engine is optional — default skin used if unavailable +# Initialize tool preview length from config +try: + from agent.display import set_tool_preview_max_len + _tpl = CLI_CONFIG.get("display", {}).get("tool_preview_length", 0) + set_tool_preview_max_len(int(_tpl) if _tpl else 0) +except Exception: + pass + +# Neuter AsyncHttpxClientWrapper.__del__ before any AsyncOpenAI clients are +# created. The SDK's __del__ schedules aclose() on asyncio.get_running_loop() +# which, during CLI idle time, finds prompt_toolkit's event loop and tries to +# close TCP transports bound to dead worker loops — producing +# "Event loop is closed" / "Press ENTER to continue..." errors. +try: + from agent.auxiliary_client import neuter_async_httpx_del + neuter_async_httpx_del() +except Exception: + pass + from rich import box as rich_box from rich.console import Console from rich.markup import escape as _escape @@ -1035,13 +1055,18 @@ class HermesCLI: self.config = CLI_CONFIG self.compact = compact if compact is not None else CLI_CONFIG["display"].get("compact", False) # tool_progress: "off", "new", "all", "verbose" (from config.yaml display section) - self.tool_progress_mode = CLI_CONFIG["display"].get("tool_progress", "all") + # YAML 1.1 parses bare `off` as boolean False — normalise to string. + _raw_tp = CLI_CONFIG["display"].get("tool_progress", "all") + self.tool_progress_mode = "off" if _raw_tp is False else str(_raw_tp) # resume_display: "full" (show history) | "minimal" (one-liner only) self.resume_display = CLI_CONFIG["display"].get("resume_display", "full") # bell_on_complete: play terminal bell (\a) when agent finishes a response self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False) # show_reasoning: display model thinking/reasoning before the response self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False) + # busy_input_mode: "interrupt" (Enter interrupts current run) or "queue" (Enter queues for next turn) + _bim = CLI_CONFIG["display"].get("busy_input_mode", "interrupt") + self.busy_input_mode = "queue" if str(_bim).strip().lower() == "queue" else "interrupt" self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose") @@ -1061,12 +1086,12 @@ class HermesCLI: # authoritative. This avoids conflicts in multi-agent setups where # env vars would stomp each other. _model_config = CLI_CONFIG.get("model", {}) - _config_model = _model_config.get("default", "") if isinstance(_model_config, dict) else (_model_config or "") - _FALLBACK_MODEL = "anthropic/claude-opus-4.6" - self.model = model or _config_model or _FALLBACK_MODEL - # Auto-detect model from local server if still on fallback - if self.model == _FALLBACK_MODEL: - _base_url = _model_config.get("base_url", "") if isinstance(_model_config, dict) else "" + _config_model = (_model_config.get("default") or _model_config.get("model") or "") if isinstance(_model_config, dict) else (_model_config or "") + _DEFAULT_CONFIG_MODEL = "anthropic/claude-opus-4.6" + self.model = model or _config_model or _DEFAULT_CONFIG_MODEL + # Auto-detect model from local server if still on default + if self.model == _DEFAULT_CONFIG_MODEL: + _base_url = (_model_config.get("base_url") or "") if isinstance(_model_config, dict) else "" if "localhost" in _base_url or "127.0.0.1" in _base_url: from hermes_cli.runtime_provider import _auto_detect_local_model _detected = _auto_detect_local_model(_base_url) @@ -1079,7 +1104,7 @@ class HermesCLI: # explicit choice — the user just never changed it. But a config model # like "gpt-5.3-codex" IS explicit and must be preserved. self._model_is_default = not model and ( - not _config_model or _config_model == _FALLBACK_MODEL + not _config_model or _config_model == _DEFAULT_CONFIG_MODEL ) self._explicit_api_key = api_key @@ -1165,9 +1190,13 @@ class HermesCLI: self._provider_require_params = pr.get("require_parameters", False) self._provider_data_collection = pr.get("data_collection") - # Fallback model config — tried when primary provider fails after retries - fb = CLI_CONFIG.get("fallback_model") or {} - self._fallback_model = fb if fb.get("provider") and fb.get("model") else None + # Fallback provider chain — tried in order when primary fails after retries. + # Supports new list format (fallback_providers) and legacy single-dict (fallback_model). + fb = CLI_CONFIG.get("fallback_providers") or CLI_CONFIG.get("fallback_model") or [] + # Normalize legacy single-dict to a one-element list + if isinstance(fb, dict): + fb = [fb] if fb.get("provider") and fb.get("model") else [] + self._fallback_model = fb # Optional cheap-vs-strong routing for simple turns self._smart_model_routing = CLI_CONFIG.get("smart_model_routing", {}) or {} @@ -1329,7 +1358,12 @@ class HermesCLI: def _build_status_bar_text(self, width: Optional[int] = None) -> str: try: snapshot = self._get_status_bar_snapshot() - width = width or shutil.get_terminal_size((80, 24)).columns + if width is None: + try: + from prompt_toolkit.application import get_app + width = get_app().output.get_size().columns + except Exception: + width = shutil.get_terminal_size((80, 24)).columns percent = snapshot["context_percent"] percent_label = f"{percent}%" if percent is not None else "--" duration_label = snapshot["duration"] @@ -1359,7 +1393,16 @@ class HermesCLI: return [] try: snapshot = self._get_status_bar_snapshot() - width = shutil.get_terminal_size((80, 24)).columns + # Use prompt_toolkit's own terminal width when running inside the + # TUI — shutil.get_terminal_size() can return stale or fallback + # values (especially on SSH) that differ from what prompt_toolkit + # actually renders, causing the fragments to overflow to a second + # line and produce duplicated status bar rows over long sessions. + try: + from prompt_toolkit.application import get_app + width = get_app().output.get_size().columns + except Exception: + width = shutil.get_terminal_size((80, 24)).columns duration_label = snapshot["duration"] if width < 52: @@ -1594,6 +1637,7 @@ class HermesCLI: if not text: return self._reasoning_stream_started = True + self._reasoning_shown_this_turn = True if getattr(self, "_stream_box_opened", False): return @@ -2929,6 +2973,82 @@ class HermesCLI: if not silent: print("(^_^)v New session started!") + def _handle_resume_command(self, cmd_original: str) -> None: + """Handle /resume — switch to a previous session mid-conversation.""" + parts = cmd_original.split(None, 1) + target = parts[1].strip() if len(parts) > 1 else "" + + if not target: + _cprint(" Usage: /resume ") + _cprint(" Tip: Use /history or `hermes sessions list` to find sessions.") + return + + if not self._session_db: + _cprint(" Session database not available.") + return + + # Resolve title or ID + from hermes_cli.main import _resolve_session_by_name_or_id + resolved = _resolve_session_by_name_or_id(target) + target_id = resolved or target + + session_meta = self._session_db.get_session(target_id) + if not session_meta: + _cprint(f" Session not found: {target}") + _cprint(" Use /history or `hermes sessions list` to see available sessions.") + return + + if target_id == self.session_id: + _cprint(" Already on that session.") + return + + # End current session + try: + self._session_db.end_session(self.session_id, "resumed_other") + except Exception: + pass + + # Switch to the target session + self.session_id = target_id + self._resumed = True + self._pending_title = None + + # Load conversation history + restored = self._session_db.get_messages_as_conversation(target_id) + self.conversation_history = restored or [] + + # Re-open the target session so it's not marked as ended + try: + self._session_db.reopen_session(target_id) + except Exception: + pass + + # Sync the agent if already initialised + if self.agent: + self.agent.session_id = target_id + self.agent.reset_session_state() + if hasattr(self.agent, "_last_flushed_db_idx"): + self.agent._last_flushed_db_idx = len(self.conversation_history) + if hasattr(self.agent, "_todo_store"): + try: + from tools.todo_tool import TodoStore + self.agent._todo_store = TodoStore() + except Exception: + pass + if hasattr(self.agent, "_invalidate_system_prompt"): + self.agent._invalidate_system_prompt() + + title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else "" + msg_count = len([m for m in self.conversation_history if m.get("role") == "user"]) + if self.conversation_history: + _cprint( + f" ↻ Resumed session {target_id}{title_part}" + f" ({msg_count} user message{'s' if msg_count != 1 else ''}," + f" {len(self.conversation_history)} total)" + ) + else: + _cprint(f" ↻ Resumed session {target_id}{title_part} — no messages, starting fresh.") + def reset_conversation(self): """Reset the conversation by starting a new session.""" self.new_session() @@ -3486,7 +3606,7 @@ class HermesCLI: print(" To start the gateway:") print(" python cli.py --gateway") print() - print(" Configuration file: ~/.hermes/config.yaml") + print(f" Configuration file: {display_hermes_home()}/config.yaml") print() except Exception as e: @@ -3496,7 +3616,7 @@ class HermesCLI: print(" 1. Set environment variables:") print(" TELEGRAM_BOT_TOKEN=your_token") print(" DISCORD_BOT_TOKEN=your_token") - print(" 2. Or configure settings in ~/.hermes/config.yaml") + print(f" 2. Or configure settings in {display_hermes_home()}/config.yaml") print() def process_command(self, command: str) -> bool: @@ -3647,6 +3767,8 @@ class HermesCLI: _cprint(" Session database not available.") elif canonical == "new": self.new_session() + elif canonical == "resume": + self._handle_resume_command(cmd_original) elif canonical == "provider": self._show_model_and_providers() elif canonical == "prompt": @@ -3701,7 +3823,7 @@ class HermesCLI: plugins = mgr.list_plugins() if not plugins: print("No plugins installed.") - print("Drop plugin directories into ~/.hermes/plugins/ to get started.") + print(f"Drop plugin directories into {display_hermes_home()}/plugins/ to get started.") else: print(f"Plugins ({len(plugins)}):") for p in plugins: @@ -3722,17 +3844,17 @@ class HermesCLI: elif canonical == "background": self._handle_background_command(cmd_original) elif canonical == "queue": - if not self._agent_running: - _cprint(" /queue only works while Hermes is busy. Just type your message normally.") + # Extract prompt after "/queue " or "/q " + parts = cmd_original.split(None, 1) + payload = parts[1].strip() if len(parts) > 1 else "" + if not payload: + _cprint(" Usage: /queue ") else: - # Extract prompt after "/queue " or "/q " - parts = cmd_original.split(None, 1) - payload = parts[1].strip() if len(parts) > 1 else "" - if not payload: - _cprint(" Usage: /queue ") - else: - self._pending_input.put(payload) + self._pending_input.put(payload) + if self._agent_running: _cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}") + else: + _cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}") elif canonical == "skin": self._handle_skin_command(cmd_original) elif canonical == "voice": @@ -3924,6 +4046,17 @@ class HermesCLI: provider_data_collection=self._provider_data_collection, fallback_model=self._fallback_model, ) + # Silence raw spinner; route thinking through TUI widget when no foreground agent is active. + bg_agent._print_fn = lambda *_a, **_kw: None + + def _bg_thinking(text: str) -> None: + # Concurrent bg tasks may race on _spinner_text; acceptable for best-effort UI. + if not self._agent_running: + self._spinner_text = text + if self._app: + self._app.invalidate() + + bg_agent.thinking_callback = _bg_thinking result = bg_agent.run_conversation( user_message=prompt, @@ -3986,6 +4119,9 @@ class HermesCLI: _cprint(f" ❌ Background task #{task_num} failed: {e}") finally: self._background_tasks.pop(task_id, None) + # Clear spinner only if no foreground agent owns it + if not self._agent_running: + self._spinner_text = "" if self._app: self._invalidate(min_interval=0) @@ -4216,7 +4352,7 @@ class HermesCLI: source = f" ({s['source']})" if s["source"] == "user" else "" print(f" {marker} {s['name']}{source} — {s['description']}") print("\n Usage: /skin ") - print(" Custom skins: drop a YAML file in ~/.hermes/skins/\n") + print(f" Custom skins: drop a YAML file in {display_hermes_home()}/skins/\n") return new_skin = parts[1].strip().lower() @@ -4396,7 +4532,7 @@ class HermesCLI: compressor = agent.context_compressor last_prompt = compressor.last_prompt_tokens ctx_len = compressor.context_length - pct = (last_prompt / ctx_len * 100) if ctx_len else 0 + pct = min(100, (last_prompt / ctx_len * 100)) if ctx_len else 0 compressions = compressor.compression_count msg_count = len(self.conversation_history) @@ -4654,8 +4790,10 @@ class HermesCLI: from agent.display import get_tool_emoji emoji = get_tool_emoji(function_name) label = preview or function_name - if len(label) > 50: - label = label[:47] + "..." + from agent.display import get_tool_preview_max_len + _pl = get_tool_preview_max_len() + if _pl > 0 and len(label) > _pl: + label = label[:_pl - 3] + "..." self._spinner_text = f"{emoji} {label}" self._invalidate() @@ -5424,6 +5562,13 @@ class HermesCLI: except Exception as e: logging.debug("@ context reference expansion failed: %s", e) + # Sanitize surrogate characters that can arrive via clipboard paste from + # rich-text editors (Google Docs, Word, etc.). Lone surrogates are invalid + # UTF-8 and crash JSON serialization in the OpenAI SDK. + if isinstance(message, str): + from run_agent import _sanitize_surrogates + message = _sanitize_surrogates(message) + # Add user message to history self.conversation_history.append({"role": "user", "content": message}) @@ -5436,6 +5581,10 @@ class HermesCLI: # Reset streaming display state for this turn self._reset_stream_state() + # Separate from _reset_stream_state because this must persist + # across intermediate turn boundaries (tool-calling loops) — only + # reset at the start of each user turn. + self._reasoning_shown_this_turn = False # --- Streaming TTS setup --- # When ElevenLabs is the TTS provider and sounddevice is available, @@ -5580,6 +5729,16 @@ class HermesCLI: agent_thread.join() # Ensure agent thread completes + # Proactively clean up async clients whose event loop is dead. + # The agent thread may have created AsyncOpenAI clients bound + # to a per-thread event loop; if that loop is now closed, those + # clients' __del__ would crash prompt_toolkit's loop on GC. + try: + from agent.auxiliary_client import cleanup_stale_async_clients + cleanup_stale_async_clients() + except Exception: + pass + # Flush any remaining streamed text and close the box self._flush_stream() @@ -5640,8 +5799,13 @@ class HermesCLI: response_previewed = result.get("response_previewed", False) if result else False # Display reasoning (thinking) box if enabled and available. - # Skip when streaming already showed reasoning live. - if self.show_reasoning and result and not self._reasoning_stream_started: + # Skip when streaming already showed reasoning live. Use the + # turn-persistent flag (_reasoning_shown_this_turn) instead of + # _reasoning_stream_started — the latter gets reset during + # intermediate turn boundaries (tool-calling loops), which caused + # the reasoning box to re-render after the final response. + _reasoning_already_shown = getattr(self, '_reasoning_shown_this_turn', False) + if self.show_reasoning and result and not _reasoning_already_shown: reasoning = result.get("last_reasoning") if reasoning: w = shutil.get_terminal_size().columns @@ -5762,10 +5926,22 @@ class HermesCLI: else: duration_str = f"{seconds}s" + # Look up session title for resume-by-name hint + session_title = None + if self._session_db: + try: + session_title = self._session_db.get_session_title(self.session_id) + except Exception: + pass + print("Resume this session with:") print(f" hermes --resume {self.session_id}") + if session_title: + print(f" hermes -c \"{session_title}\"") print() print(f"Session: {self.session_id}") + if session_title: + print(f"Title: {session_title}") print(f"Duration: {duration_str}") print(f"Messages: {msg_count} ({user_msgs} user, {tool_calls} tool calls)") else: @@ -5782,6 +5958,9 @@ class HermesCLI: ``normal_prompt`` is the full ``branding.prompt_symbol``. ``state_suffix`` is what special states (sudo/secret/approval/agent) should render after their leading icon. + + When a profile is active (not "default"), the profile name is + prepended to the prompt symbol: ``coder ❯`` instead of ``❯``. """ try: from hermes_cli.skin_engine import get_active_prompt_symbol @@ -5790,6 +5969,15 @@ class HermesCLI: symbol = "❯ " symbol = (symbol or "❯ ").rstrip() + " " + + # Prepend profile name when not default + try: + from hermes_cli.profiles import get_active_profile_name + profile = get_active_profile_name() + if profile not in ("default", "custom"): + symbol = f"{profile} {symbol}" + except Exception: + pass stripped = symbol.rstrip() if not stripped: return "❯ ", "❯ " @@ -5941,7 +6129,7 @@ class HermesCLI: from honcho_integration.client import HonchoClientConfig from agent.display import honcho_session_line, write_tty hcfg = HonchoClientConfig.from_global_config() - if hcfg.enabled and hcfg.api_key and hcfg.explicitly_configured: + if hcfg.enabled and (hcfg.api_key or hcfg.base_url) and hcfg.explicitly_configured: sname = hcfg.resolve_session_name(session_id=self.session_id) if sname: write_tty(honcho_session_line(hcfg.workspace_id, sname) + "\n") @@ -6028,10 +6216,18 @@ class HermesCLI: set_approval_callback(self._approval_callback) set_secret_capture_callback(self._secret_capture_callback) - # Ensure tirith security scanner is available (downloads if needed) + # Ensure tirith security scanner is available (downloads if needed). + # Warn the user if tirith is enabled in config but not available, + # so they know command security scanning is degraded. try: from tools.tirith_security import ensure_installed - ensure_installed(log_failures=False) + tirith_path = ensure_installed(log_failures=False) + if tirith_path is None: + security_cfg = self.config.get("security", {}) or {} + tirith_enabled = security_cfg.get("tirith_enabled", True) + if tirith_enabled: + _cprint(f" {_DIM}⚠ tirith security scanner enabled but not available " + f"— command scanning will use pattern matching only{_RST}") except Exception: pass # Non-fatal — fail-open at scan time if unavailable @@ -6112,16 +6308,22 @@ class HermesCLI: # Bundle text + images as a tuple when images are present payload = (text, images) if images else text if self._agent_running and not (text and text.startswith("/")): - self._interrupt_queue.put(payload) - # Debug: log to file when message enters interrupt queue - try: - _dbg = _hermes_home / "interrupt_debug.log" - with open(_dbg, "a") as _f: - import time as _t - _f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, " - f"agent_running={self._agent_running}\n") - except Exception: - pass + if self.busy_input_mode == "queue": + # Queue for the next turn instead of interrupting + self._pending_input.put(payload) + preview = text if text else f"[{len(images)} image{'s' if len(images) != 1 else ''} attached]" + _cprint(f" Queued for the next turn: {preview[:80]}{'...' if len(preview) > 80 else ''}") + else: + self._interrupt_queue.put(payload) + # Debug: log to file when message enters interrupt queue + try: + _dbg = _hermes_home / "interrupt_debug.log" + with open(_dbg, "a") as _f: + import time as _t + _f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, " + f"agent_running={self._agent_running}\n") + except Exception: + pass else: self._pending_input.put(payload) event.app.current_buffer.reset(append_to_history=True) @@ -6312,6 +6514,24 @@ class HermesCLI: self._should_exit = True event.app.exit() + @kb.add('c-z') + def handle_ctrl_z(event): + """Handle Ctrl+Z - suspend process to background (Unix only).""" + import sys + if sys.platform == 'win32': + _cprint(f"\n{_DIM}Suspend (Ctrl+Z) is not supported on Windows.{_RST}") + event.app.invalidate() + return + import os, signal as _sig + from prompt_toolkit.application import run_in_terminal + from hermes_cli.skin_engine import get_active_skin + agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent") + msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back." + def _suspend(): + os.write(1, msg.encode()) + os.kill(0, _sig.SIGTSTP) + run_in_terminal(_suspend) + # Voice push-to-talk key: configurable via config.yaml (voice.record_key) # Default: Ctrl+B (avoids conflict with Ctrl+R readline reverse-search) # Config uses "ctrl+b" format; prompt_toolkit expects "c-b" format. @@ -6501,6 +6721,7 @@ class HermesCLI: # Paste collapsing: detect large pastes and save to temp file _paste_counter = [0] _prev_text_len = [0] + _prev_newline_count = [0] _paste_just_collapsed = [False] def _on_text_changed(buf): @@ -6509,18 +6730,27 @@ class HermesCLI: When bracketed paste is available, handle_paste collapses large pastes directly. This handler is a fallback for terminals without bracketed paste support. + + Two heuristics (either triggers collapse): + 1. Many characters added at once (chars_added > 1) — works + when the terminal delivers the paste in one event-loop tick. + 2. Newline count jumped by 4+ in a single text-change event — + catches terminals that feed characters individually but + still batch newlines. Alt+Enter only adds 1 newline per + event so it never triggers this. """ text = buf.text chars_added = len(text) - _prev_text_len[0] _prev_text_len[0] = len(text) if _paste_just_collapsed[0]: _paste_just_collapsed[0] = False + _prev_newline_count[0] = text.count('\n') return line_count = text.count('\n') - # Heuristic: a real paste adds many characters at once (not just a - # single newline from Alt+Enter) AND the result has 5+ lines. - # Fallback for terminals without bracketed paste support. - if line_count >= 5 and chars_added > 1 and not text.startswith('/'): + newlines_added = line_count - _prev_newline_count[0] + _prev_newline_count[0] = line_count + is_paste = chars_added > 1 or newlines_added >= 4 + if line_count >= 5 and is_paste and not text.startswith('/'): _paste_counter[0] += 1 # Save to temp file paste_dir = _hermes_home / "pastes" @@ -6528,6 +6758,7 @@ class HermesCLI: paste_file = paste_dir / f"paste_{_paste_counter[0]}_{datetime.now().strftime('%H%M%S')}.txt" paste_file.write_text(text, encoding="utf-8") # Replace buffer with compact reference + _paste_just_collapsed[0] = True buf.text = f"[Pasted text #{_paste_counter[0]}: {line_count + 1} lines \u2192 {paste_file}]" buf.cursor_position = len(buf.text) @@ -6894,6 +7125,15 @@ class HermesCLI: Window( content=FormattedTextControl(lambda: cli_ref._get_status_bar_fragments()), height=1, + # Prevent fragments that overflow the terminal width from + # wrapping onto a second line, which causes the status bar to + # appear duplicated (one full + one partial row) during long + # sessions, especially on SSH where shutil.get_terminal_size + # may return stale values. _get_status_bar_fragments now reads + # width from prompt_toolkit's own output object, so fragments + # will always fit; wrap_lines=False is the belt-and-suspenders + # guard against any future width mismatch. + wrap_lines=False, ), filter=Condition(lambda: cli_ref._status_bar_visible), ) @@ -7128,9 +7368,28 @@ class HermesCLI: # Register atexit cleanup so resources are freed even on unexpected exit atexit.register(_run_cleanup) + # Install a custom asyncio exception handler that suppresses the + # "Event loop is closed" RuntimeError from httpx transport cleanup. + # This is defense-in-depth — the primary fix is neuter_async_httpx_del + # which disables __del__ entirely, but older clients or SDK upgrades + # could bypass it. + def _suppress_closed_loop_errors(loop, context): + exc = context.get("exception") + if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc): + return # silently suppress + # Fall back to default handler for everything else + loop.default_exception_handler(context) + # Run the application with patch_stdout for proper output handling try: with patch_stdout(): + # Set the custom handler on prompt_toolkit's event loop + try: + import asyncio as _aio + _loop = _aio.get_event_loop() + _loop.set_exception_handler(_suppress_closed_loop_errors) + except Exception: + pass app.run() except (EOFError, KeyboardInterrupt): pass diff --git a/cron/jobs.py b/cron/jobs.py index 841f56335..22c04d0c6 100644 --- a/cron/jobs.py +++ b/cron/jobs.py @@ -327,7 +327,20 @@ def load_jobs() -> List[Dict[str, Any]]: with open(JOBS_FILE, 'r', encoding='utf-8') as f: data = json.load(f) return data.get("jobs", []) - except (json.JSONDecodeError, IOError): + except json.JSONDecodeError: + # Retry with strict=False to handle bare control chars in string values + try: + with open(JOBS_FILE, 'r', encoding='utf-8') as f: + data = json.loads(f.read(), strict=False) + jobs = data.get("jobs", []) + if jobs: + # Auto-repair: rewrite with proper escaping + save_jobs(jobs) + logger.warning("Auto-repaired jobs.json (had invalid control characters)") + return jobs + except Exception: + return [] + except IOError: return [] @@ -598,6 +611,34 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None): save_jobs(jobs) +def advance_next_run(job_id: str) -> bool: + """Preemptively advance next_run_at for a recurring job before execution. + + Call this BEFORE run_job() so that if the process crashes mid-execution, + the job won't re-fire on the next gateway restart. This converts the + scheduler from at-least-once to at-most-once for recurring jobs — missing + one run is far better than firing dozens of times in a crash loop. + + One-shot jobs are left unchanged so they can still retry on restart. + + Returns True if next_run_at was advanced, False otherwise. + """ + jobs = load_jobs() + for job in jobs: + if job["id"] == job_id: + kind = job.get("schedule", {}).get("kind") + if kind not in ("cron", "interval"): + return False + now = _hermes_now().isoformat() + new_next = compute_next_run(job["schedule"], now) + if new_next and new_next != job.get("next_run_at"): + job["next_run_at"] = new_next + save_jobs(jobs) + return True + return False + return False + + def get_due_jobs() -> List[Dict[str, Any]]: """Get all jobs that are due to run now. diff --git a/cron/scheduler.py b/cron/scheduler.py index e6313cd7b..e4299836f 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -26,6 +26,7 @@ except ImportError: msvcrt = None from pathlib import Path from hermes_constants import get_hermes_home +from hermes_cli.config import load_config from typing import Optional from hermes_time import now as _hermes_now @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) -from cron.jobs import get_due_jobs, mark_job_run, save_job_output +from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run # Sentinel: when a cron agent has nothing new to report, it can start its # response with this marker to suppress delivery. Output is still saved @@ -86,6 +87,22 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]: chat_id, thread_id = rest.split(":", 1) else: chat_id, thread_id = rest, None + + # Resolve human-friendly labels like "Alice (dm)" to real IDs. + # send_message(action="list") shows labels with display suffixes + # that aren't valid platform IDs (e.g. WhatsApp JIDs). + try: + from gateway.channel_directory import resolve_channel_name + target = chat_id + # Strip display suffix like " (dm)" or " (group)" + if target.endswith(")") and " (" in target: + target = target.rsplit(" (", 1)[0].strip() + resolved = resolve_channel_name(platform_name.lower(), target) + if resolved: + chat_id = resolved + except Exception: + pass + return { "platform": platform_name, "chat_id": chat_id, @@ -145,6 +162,8 @@ def _deliver_result(job: dict, content: str) -> None: "mattermost": Platform.MATTERMOST, "homeassistant": Platform.HOMEASSISTANT, "dingtalk": Platform.DINGTALK, + "feishu": Platform.FEISHU, + "wecom": Platform.WECOM, "email": Platform.EMAIL, "sms": Platform.SMS, } @@ -164,18 +183,29 @@ def _deliver_result(job: dict, content: str) -> None: logger.warning("Job '%s': platform '%s' not configured/enabled", job["id"], platform_name) return - # Wrap the content so the user knows this is a cron delivery and that - # the interactive agent has no visibility into it. - task_name = job.get("name", job["id"]) - wrapped = ( - f"Cronjob Response: {task_name}\n" - f"-------------\n\n" - f"{content}\n\n" - f"Note: The agent cannot see this message, and therefore cannot respond to it." - ) + # Optionally wrap the content with a header/footer so the user knows this + # is a cron delivery. Wrapping is on by default; set cron.wrap_response: false + # in config.yaml for clean output. + wrap_response = True + try: + user_cfg = load_config() + wrap_response = user_cfg.get("cron", {}).get("wrap_response", True) + except Exception: + pass + + if wrap_response: + task_name = job.get("name", job["id"]) + delivery_content = ( + f"Cronjob Response: {task_name}\n" + f"-------------\n\n" + f"{content}\n\n" + f"Note: The agent cannot see this message, and therefore cannot respond to it." + ) + else: + delivery_content = content # Run the async send in a fresh event loop (safe from any thread) - coro = _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id) + coro = _send_to_platform(platform, pconfig, chat_id, delivery_content, thread_id=thread_id) try: result = asyncio.run(coro) except RuntimeError: @@ -186,7 +216,7 @@ def _deliver_result(job: dict, content: str) -> None: coro.close() import concurrent.futures with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, wrapped, thread_id=thread_id)) + future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, delivery_content, thread_id=thread_id)) result = future.result(timeout=30) except Exception as e: logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e) @@ -308,7 +338,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: if delivery_target.get("thread_id") is not None: os.environ["HERMES_CRON_AUTO_DELIVER_THREAD_ID"] = str(delivery_target["thread_id"]) - model = job.get("model") or os.getenv("HERMES_MODEL") or "anthropic/claude-opus-4.6" + model = job.get("model") or os.getenv("HERMES_MODEL") or "" # Load config.yaml for model, reasoning, prefill, toolsets, provider routing _cfg = {} @@ -524,6 +554,12 @@ def tick(verbose: bool = True) -> int: executed = 0 for job in due_jobs: try: + # For recurring jobs (cron/interval), advance next_run_at to the + # next future occurrence BEFORE execution. This way, if the + # process crashes mid-run, the job won't re-fire on restart. + # One-shot jobs are left alone so they can retry on restart. + advance_next_run(job["id"]) + success, output, final_response, error = run_job(job) output_file = save_job_output(job["id"], output) diff --git a/docker/SOUL.md b/docker/SOUL.md new file mode 100644 index 000000000..9103a6122 --- /dev/null +++ b/docker/SOUL.md @@ -0,0 +1,15 @@ +# Hermes Agent Persona + + \ No newline at end of file diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 000000000..4c6366cbe --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Docker entrypoint: bootstrap config files into the mounted volume, then run hermes. +set -e + +HERMES_HOME="/opt/data" +INSTALL_DIR="/opt/hermes" + +# Create essential directory structure. Cache and platform directories +# (cache/images, cache/audio, platforms/whatsapp, etc.) are created on +# demand by the application — don't pre-create them here so new installs +# get the consolidated layout from get_hermes_dir(). +mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills} + +# .env +if [ ! -f "$HERMES_HOME/.env" ]; then + cp "$INSTALL_DIR/.env.example" "$HERMES_HOME/.env" +fi + +# config.yaml +if [ ! -f "$HERMES_HOME/config.yaml" ]; then + cp "$INSTALL_DIR/cli-config.yaml.example" "$HERMES_HOME/config.yaml" +fi + +# SOUL.md +if [ ! -f "$HERMES_HOME/SOUL.md" ]; then + cp "$INSTALL_DIR/docker/SOUL.md" "$HERMES_HOME/SOUL.md" +fi + +# Sync bundled skills (manifest-based so user edits are preserved) +if [ -d "$INSTALL_DIR/skills" ]; then + python3 "$INSTALL_DIR/tools/skills_sync.py" +fi + +exec hermes "$@" diff --git a/environments/README.md b/environments/README.md index f2d1a7956..9677fdb70 100644 --- a/environments/README.md +++ b/environments/README.md @@ -101,21 +101,11 @@ Available methods: ### Patches (`patches.py`) -**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested. +**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested. -**Solution**: `patches.py` monkey-patches `SwerexModalEnvironment` to use a dedicated background thread (`_AsyncWorker`) with its own event loop. The calling code sees the same sync interface, but internally the async work happens on a separate thread that doesn't conflict with Atropos's loop. +**Solution**: `ModalEnvironment` uses a dedicated `_AsyncWorker` background thread with its own event loop. The calling code sees a sync interface, but internally all async Modal SDK calls happen on the worker thread so they don't conflict with Atropos's loop. This is built directly into `tools/environments/modal.py` — no monkey-patching required. -What gets patched: -- `SwerexModalEnvironment.__init__` -- creates Modal deployment on a background thread -- `SwerexModalEnvironment.execute` -- runs commands on the same background thread -- `SwerexModalEnvironment.stop` -- stops deployment on the background thread - -The patches are: -- **Idempotent** -- calling `apply_patches()` multiple times is safe -- **Transparent** -- same interface and behavior, only the internal async execution changes -- **Universal** -- works identically in normal CLI use (no running event loop) - -Applied automatically at import time by `hermes_base_env.py`. +`patches.py` is now a no-op (kept for backward compatibility with imports). ### Tool Call Parsers (`tool_call_parsers/`) diff --git a/environments/benchmarks/terminalbench_2/terminalbench2_env.py b/environments/benchmarks/terminalbench_2/terminalbench2_env.py index 1b52c15f8..3f95d4029 100644 --- a/environments/benchmarks/terminalbench_2/terminalbench2_env.py +++ b/environments/benchmarks/terminalbench_2/terminalbench2_env.py @@ -209,7 +209,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv): # Agent settings -- TB2 tasks are complex, need many turns max_agent_turns=60, - max_token_length=*** + max_token_length=16000, agent_temperature=0.6, system_prompt=None, @@ -233,7 +233,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv): steps_per_eval=1, total_steps=1, - tokenizer_name="NousRe...1-8B", + tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B", use_wandb=True, wandb_name="terminal-bench-2", ensure_scores_are_not_same=False, # Binary rewards may all be 0 or 1 @@ -245,7 +245,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv): base_url="https://openrouter.ai/api/v1", model_name="anthropic/claude-sonnet-4", server_type="openai", - api_key=os.get...EY", ""), + api_key=os.getenv("OPENROUTER_API_KEY", ""), health_check=False, ) ] @@ -513,3 +513,446 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv): reward = 0.0 else: # Run tests in a thread so the blocking ctx.terminal() calls + # don't freeze the entire event loop (which would stall all + # other tasks, tqdm updates, and timeout timers). + ctx = ToolContext(task_id) + try: + loop = asyncio.get_event_loop() + reward = await loop.run_in_executor( + None, # default thread pool + self._run_tests, eval_item, ctx, task_name, + ) + except Exception as e: + logger.error("Task %s: test verification failed: %s", task_name, e) + reward = 0.0 + finally: + ctx.cleanup() + + passed = reward == 1.0 + status = "PASS" if passed else "FAIL" + elapsed = time.time() - task_start + tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)") + logger.info( + "Task %s: reward=%.1f, turns=%d, finished=%s", + task_name, reward, result.turns_used, result.finished_naturally, + ) + + out = { + "passed": passed, + "reward": reward, + "task_name": task_name, + "category": category, + "turns_used": result.turns_used, + "finished_naturally": result.finished_naturally, + "messages": result.messages, + } + self._save_result(out) + return out + + except Exception as e: + elapsed = time.time() - task_start + logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True) + tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)") + out = { + "passed": False, "reward": 0.0, + "task_name": task_name, "category": category, + "error": str(e), + } + self._save_result(out) + return out + + finally: + # --- Cleanup: clear overrides, sandbox, and temp files --- + clear_task_env_overrides(task_id) + try: + cleanup_vm(task_id) + except Exception as e: + logger.debug("VM cleanup for %s: %s", task_id[:8], e) + if task_dir and task_dir.exists(): + shutil.rmtree(task_dir, ignore_errors=True) + + def _run_tests( + self, item: Dict[str, Any], ctx: ToolContext, task_name: str + ) -> float: + """ + Upload and execute the test suite in the agent's sandbox, then + download the verifier output locally to read the reward. + + Follows Harbor's verification pattern: + 1. Upload tests/ directory into the sandbox + 2. Execute test.sh inside the sandbox + 3. Download /logs/verifier/ directory to a local temp dir + 4. Read reward.txt locally with native Python I/O + + Downloading locally avoids issues with the file_read tool on + the Modal VM and matches how Harbor handles verification. + + TB2 test scripts (test.sh) typically: + 1. Install pytest via uv/pip + 2. Run pytest against the test files in /tests/ + 3. Write results to /logs/verifier/reward.txt + + Args: + item: The TB2 task dict (contains tests_tar, test_sh) + ctx: ToolContext scoped to this task's sandbox + task_name: For logging + + Returns: + 1.0 if tests pass, 0.0 otherwise + """ + tests_tar = item.get("tests_tar", "") + test_sh = item.get("test_sh", "") + + if not test_sh: + logger.warning("Task %s: no test_sh content, reward=0", task_name) + return 0.0 + + # Create required directories in the sandbox + ctx.terminal("mkdir -p /tests /logs/verifier") + + # Upload test files into the sandbox (binary-safe via base64) + if tests_tar: + tests_temp = Path(tempfile.mkdtemp(prefix=f"tb2-tests-{task_name}-")) + try: + _extract_base64_tar(tests_tar, tests_temp) + ctx.upload_dir(str(tests_temp), "/tests") + except Exception as e: + logger.warning("Task %s: failed to upload test files: %s", task_name, e) + finally: + shutil.rmtree(tests_temp, ignore_errors=True) + + # Write the test runner script (test.sh) + ctx.write_file("/tests/test.sh", test_sh) + ctx.terminal("chmod +x /tests/test.sh") + + # Execute the test suite + logger.info( + "Task %s: running test suite (timeout=%ds)", + task_name, self.config.test_timeout, + ) + test_result = ctx.terminal( + "bash /tests/test.sh", + timeout=self.config.test_timeout, + ) + + exit_code = test_result.get("exit_code", -1) + output = test_result.get("output", "") + + # Download the verifier output directory locally, then read reward.txt + # with native Python I/O. This avoids issues with file_read on the + # Modal VM and matches Harbor's verification pattern. + reward = 0.0 + local_verifier_dir = Path(tempfile.mkdtemp(prefix=f"tb2-verifier-{task_name}-")) + try: + ctx.download_dir("/logs/verifier", str(local_verifier_dir)) + + reward_file = local_verifier_dir / "reward.txt" + if reward_file.exists() and reward_file.stat().st_size > 0: + content = reward_file.read_text().strip() + if content == "1": + reward = 1.0 + elif content == "0": + reward = 0.0 + else: + # Unexpected content -- try parsing as float + try: + reward = float(content) + except (ValueError, TypeError): + logger.warning( + "Task %s: reward.txt content unexpected (%r), " + "falling back to exit_code=%d", + task_name, content, exit_code, + ) + reward = 1.0 if exit_code == 0 else 0.0 + else: + # reward.txt not written -- fall back to exit code + logger.warning( + "Task %s: reward.txt not found after download, " + "falling back to exit_code=%d", + task_name, exit_code, + ) + reward = 1.0 if exit_code == 0 else 0.0 + except Exception as e: + logger.warning( + "Task %s: failed to download verifier dir: %s, " + "falling back to exit_code=%d", + task_name, e, exit_code, + ) + reward = 1.0 if exit_code == 0 else 0.0 + finally: + shutil.rmtree(local_verifier_dir, ignore_errors=True) + + # Log test output for debugging failures + if reward == 0.0: + output_preview = output[-500:] if output else "(no output)" + logger.info( + "Task %s: FAIL (exit_code=%d)\n%s", + task_name, exit_code, output_preview, + ) + + return reward + + # ========================================================================= + # Evaluate -- main entry point for the eval subcommand + # ========================================================================= + + async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict: + """ + Wrap rollout_and_score_eval with a per-task wall-clock timeout. + + If the task exceeds task_timeout seconds, it's automatically scored + as FAIL. This prevents any single task from hanging indefinitely. + """ + task_name = item.get("task_name", "unknown") + category = item.get("category", "unknown") + try: + return await asyncio.wait_for( + self.rollout_and_score_eval(item), + timeout=self.config.task_timeout, + ) + except asyncio.TimeoutError: + from tqdm import tqdm + elapsed = self.config.task_timeout + tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)") + logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed) + out = { + "passed": False, "reward": 0.0, + "task_name": task_name, "category": category, + "error": f"timeout ({elapsed}s)", + } + self._save_result(out) + return out + + async def evaluate(self, *args, **kwargs) -> None: + """ + Run Terminal-Bench 2.0 evaluation over all tasks. + + This is the main entry point when invoked via: + python environments/terminalbench2_env.py evaluate + + Runs all tasks through rollout_and_score_eval() via asyncio.gather() + (same pattern as GPQA and other Atropos eval envs). Each task is + wrapped with a wall-clock timeout so hung tasks auto-fail. + + Suppresses noisy Modal/terminal output (HERMES_QUIET) so the tqdm + bar stays visible. + """ + start_time = time.time() + + # Route all logging through tqdm.write() so the progress bar stays + # pinned at the bottom while log lines scroll above it. + from tqdm import tqdm + + class _TqdmHandler(logging.Handler): + def emit(self, record): + try: + tqdm.write(self.format(record)) + except Exception: + self.handleError(record) + + handler = _TqdmHandler() + handler.setFormatter(logging.Formatter( + "%(asctime)s [%(name)s] %(levelname)s: %(message)s", + datefmt="%H:%M:%S", + )) + root = logging.getLogger() + root.handlers = [handler] # Replace any existing handlers + root.setLevel(logging.INFO) + + # Silence noisy third-party loggers that flood the output + logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request + logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries + logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment + logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds + + print(f"\n{'='*60}") + print("Starting Terminal-Bench 2.0 Evaluation") + print(f"{'='*60}") + print(f" Dataset: {self.config.dataset_name}") + print(f" Total tasks: {len(self.all_eval_items)}") + print(f" Max agent turns: {self.config.max_agent_turns}") + print(f" Task timeout: {self.config.task_timeout}s") + print(f" Terminal backend: {self.config.terminal_backend}") + print(f" Tool thread pool: {self.config.tool_pool_size}") + print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd") + print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)") + print(f" Max concurrent tasks: {self.config.max_concurrent_tasks}") + print(f"{'='*60}\n") + + # Semaphore to limit concurrent Modal sandbox creations. + # Without this, all 86 tasks fire simultaneously, each creating a Modal + # sandbox via asyncio.run() inside a thread pool worker. Modal's blocking + # calls (App.lookup, etc.) deadlock when too many are created at once. + semaphore = asyncio.Semaphore(self.config.max_concurrent_tasks) + + async def _eval_with_semaphore(item): + async with semaphore: + return await self._eval_with_timeout(item) + + # Fire all tasks with wall-clock timeout, track live accuracy on the bar + total_tasks = len(self.all_eval_items) + eval_tasks = [ + asyncio.ensure_future(_eval_with_semaphore(item)) + for item in self.all_eval_items + ] + + results = [] + passed_count = 0 + pbar = tqdm(total=total_tasks, desc="Evaluating TB2", dynamic_ncols=True) + try: + for coro in asyncio.as_completed(eval_tasks): + result = await coro + results.append(result) + if result and result.get("passed"): + passed_count += 1 + done = len(results) + pct = (passed_count / done * 100) if done else 0 + pbar.set_postfix_str(f"pass={passed_count}/{done} ({pct:.1f}%)") + pbar.update(1) + except (KeyboardInterrupt, asyncio.CancelledError): + pbar.close() + print(f"\n\nInterrupted! Cleaning up {len(eval_tasks)} tasks...") + # Cancel all pending tasks + for task in eval_tasks: + task.cancel() + # Let cancellations propagate (finally blocks run cleanup_vm) + await asyncio.gather(*eval_tasks, return_exceptions=True) + # Belt-and-suspenders: clean up any remaining sandboxes + from tools.terminal_tool import cleanup_all_environments + cleanup_all_environments() + print("All sandboxes cleaned up.") + return + finally: + pbar.close() + + end_time = time.time() + + # Filter out None results (shouldn't happen, but be safe) + valid_results = [r for r in results if r is not None] + + if not valid_results: + print("Warning: No valid evaluation results obtained") + return + + # ---- Compute metrics ---- + total = len(valid_results) + passed = sum(1 for r in valid_results if r.get("passed")) + overall_pass_rate = passed / total if total > 0 else 0.0 + + # Per-category breakdown + cat_results: Dict[str, List[Dict]] = defaultdict(list) + for r in valid_results: + cat_results[r.get("category", "unknown")].append(r) + + # Build metrics dict + eval_metrics = { + "eval/pass_rate": overall_pass_rate, + "eval/total_tasks": total, + "eval/passed_tasks": passed, + "eval/evaluation_time_seconds": end_time - start_time, + } + + # Per-category metrics + for category, cat_items in sorted(cat_results.items()): + cat_passed = sum(1 for r in cat_items if r.get("passed")) + cat_total = len(cat_items) + cat_pass_rate = cat_passed / cat_total if cat_total > 0 else 0.0 + cat_key = category.replace(" ", "_").replace("-", "_").lower() + eval_metrics[f"eval/pass_rate_{cat_key}"] = cat_pass_rate + + # Store metrics for wandb_log + self.eval_metrics = [(k, v) for k, v in eval_metrics.items()] + + # ---- Print summary ---- + print(f"\n{'='*60}") + print("Terminal-Bench 2.0 Evaluation Results") + print(f"{'='*60}") + print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})") + print(f"Evaluation Time: {end_time - start_time:.1f} seconds") + + print("\nCategory Breakdown:") + for category, cat_items in sorted(cat_results.items()): + cat_passed = sum(1 for r in cat_items if r.get("passed")) + cat_total = len(cat_items) + cat_rate = cat_passed / cat_total if cat_total > 0 else 0.0 + print(f" {category}: {cat_rate:.1%} ({cat_passed}/{cat_total})") + + # Print individual task results + print("\nTask Results:") + for r in sorted(valid_results, key=lambda x: x.get("task_name", "")): + status = "PASS" if r.get("passed") else "FAIL" + turns = r.get("turns_used", "?") + error = r.get("error", "") + extra = f" (error: {error})" if error else "" + print(f" [{status}] {r['task_name']} (turns={turns}){extra}") + + print(f"{'='*60}\n") + + # Build sample records for evaluate_log (includes full conversations) + samples = [ + { + "task_name": r.get("task_name"), + "category": r.get("category"), + "passed": r.get("passed"), + "reward": r.get("reward"), + "turns_used": r.get("turns_used"), + "error": r.get("error"), + "messages": r.get("messages"), + } + for r in valid_results + ] + + # Log evaluation results + try: + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": self.config.agent_temperature, + "max_tokens": self.config.max_token_length, + "max_agent_turns": self.config.max_agent_turns, + "terminal_backend": self.config.terminal_backend, + }, + ) + except Exception as e: + print(f"Error logging evaluation results: {e}") + + # Close streaming file + if hasattr(self, "_streaming_file") and not self._streaming_file.closed: + self._streaming_file.close() + print(f" Live results saved to: {self._streaming_path}") + + # Kill all remaining sandboxes. Timed-out tasks leave orphaned thread + # pool workers still executing commands -- cleanup_all stops them. + from tools.terminal_tool import cleanup_all_environments + print("\nCleaning up all sandboxes...") + cleanup_all_environments() + + # Shut down the tool thread pool so orphaned workers from timed-out + # tasks are killed immediately instead of retrying against dead + # sandboxes and spamming the console with TimeoutError warnings. + from environments.agent_loop import _tool_executor + _tool_executor.shutdown(wait=False, cancel_futures=True) + print("Done.") + + # ========================================================================= + # Wandb logging + # ========================================================================= + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Log TB2-specific metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + + # Add stored eval metrics + for metric_name, metric_value in self.eval_metrics: + wandb_metrics[metric_name] = metric_value + self.eval_metrics = [] + + await super().wandb_log(wandb_metrics) + + +if __name__ == "__main__": + TerminalBench2EvalEnv.cli() diff --git a/gateway/builtin_hooks/__init__.py b/gateway/builtin_hooks/__init__.py new file mode 100644 index 000000000..37da09db9 --- /dev/null +++ b/gateway/builtin_hooks/__init__.py @@ -0,0 +1 @@ +"""Built-in gateway hooks that are always registered.""" diff --git a/gateway/builtin_hooks/boot_md.py b/gateway/builtin_hooks/boot_md.py new file mode 100644 index 000000000..fced0b5e1 --- /dev/null +++ b/gateway/builtin_hooks/boot_md.py @@ -0,0 +1,86 @@ +"""Built-in boot-md hook — run ~/.hermes/BOOT.md on gateway startup. + +This hook is always registered. It silently skips if no BOOT.md exists. +To activate, create ``~/.hermes/BOOT.md`` with instructions for the +agent to execute on every gateway restart. + +Example BOOT.md:: + + # Startup Checklist + + 1. Check if any cron jobs failed overnight + 2. Send a status update to Discord #general + 3. If there are errors in /opt/app/deploy.log, summarize them + +The agent runs in a background thread so it doesn't block gateway +startup. If nothing needs attention, it replies with [SILENT] to +suppress delivery. +""" + +import logging +import os +import threading +from pathlib import Path + +logger = logging.getLogger("hooks.boot-md") + +HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) +BOOT_FILE = HERMES_HOME / "BOOT.md" + + +def _build_boot_prompt(content: str) -> str: + """Wrap BOOT.md content in a system-level instruction.""" + return ( + "You are running a startup boot checklist. Follow the BOOT.md " + "instructions below exactly.\n\n" + "---\n" + f"{content}\n" + "---\n\n" + "Execute each instruction. If you need to send a message to a " + "platform, use the send_message tool.\n" + "If nothing needs attention and there is nothing to report, " + "reply with ONLY: [SILENT]" + ) + + +def _run_boot_agent(content: str) -> None: + """Spawn a one-shot agent session to execute the boot instructions.""" + try: + from run_agent import AIAgent + + prompt = _build_boot_prompt(content) + agent = AIAgent( + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + max_iterations=20, + ) + result = agent.run_conversation(prompt) + response = result.get("final_response", "") + if response and "[SILENT]" not in response: + logger.info("boot-md completed: %s", response[:200]) + else: + logger.info("boot-md completed (nothing to report)") + except Exception as e: + logger.error("boot-md agent failed: %s", e) + + +async def handle(event_type: str, context: dict) -> None: + """Gateway startup handler — run BOOT.md if it exists.""" + if not BOOT_FILE.exists(): + return + + content = BOOT_FILE.read_text(encoding="utf-8").strip() + if not content: + return + + logger.info("Running BOOT.md (%d chars)", len(content)) + + # Run in a background thread so we don't block gateway startup. + thread = threading.Thread( + target=_run_boot_agent, + args=(content,), + name="boot-md", + daemon=True, + ) + thread.start() diff --git a/gateway/config.py b/gateway/config.py index 1f84c7689..d7383100e 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -52,6 +52,8 @@ class Platform(Enum): DINGTALK = "dingtalk" API_SERVER = "api_server" WEBHOOK = "webhook" + FEISHU = "feishu" + WECOM = "wecom" @dataclass @@ -269,6 +271,12 @@ class GatewayConfig: # Webhook uses enabled flag only (secrets are per-route) elif platform == Platform.WEBHOOK: connected.append(platform) + # Feishu uses extra dict for app credentials + elif platform == Platform.FEISHU and config.extra.get("app_id"): + connected.append(platform) + # WeCom uses extra dict for bot credentials + elif platform == Platform.WECOM and config.extra.get("bot_id"): + connected.append(platform) return connected def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: @@ -596,6 +604,14 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.TELEGRAM] = PlatformConfig() config.platforms[Platform.TELEGRAM].reply_to_mode = telegram_reply_mode + telegram_fallback_ips = os.getenv("TELEGRAM_FALLBACK_IPS", "") + if telegram_fallback_ips: + if Platform.TELEGRAM not in config.platforms: + config.platforms[Platform.TELEGRAM] = PlatformConfig() + config.platforms[Platform.TELEGRAM].extra["fallback_ips"] = [ + ip.strip() for ip in telegram_fallback_ips.split(",") if ip.strip() + ] + telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL") if telegram_home and Platform.TELEGRAM in config.platforms: config.platforms[Platform.TELEGRAM].home_channel = HomeChannel( @@ -634,14 +650,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.SLACK] = PlatformConfig() config.platforms[Platform.SLACK].enabled = True config.platforms[Platform.SLACK].token = slack_token - # Home channel - slack_home = os.getenv("SLACK_HOME_CHANNEL") - if slack_home: - config.platforms[Platform.SLACK].home_channel = HomeChannel( - platform=Platform.SLACK, - chat_id=slack_home, - name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""), - ) + slack_home = os.getenv("SLACK_HOME_CHANNEL") + if slack_home and Platform.SLACK in config.platforms: + config.platforms[Platform.SLACK].home_channel = HomeChannel( + platform=Platform.SLACK, + chat_id=slack_home, + name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""), + ) # Signal signal_url = os.getenv("SIGNAL_HTTP_URL") @@ -655,13 +670,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: "account": signal_account, "ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"), }) - signal_home = os.getenv("SIGNAL_HOME_CHANNEL") - if signal_home: - config.platforms[Platform.SIGNAL].home_channel = HomeChannel( - platform=Platform.SIGNAL, - chat_id=signal_home, - name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"), - ) + signal_home = os.getenv("SIGNAL_HOME_CHANNEL") + if signal_home and Platform.SIGNAL in config.platforms: + config.platforms[Platform.SIGNAL].home_channel = HomeChannel( + platform=Platform.SIGNAL, + chat_id=signal_home, + name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"), + ) # Mattermost mattermost_token = os.getenv("MATTERMOST_TOKEN") @@ -674,13 +689,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.MATTERMOST].enabled = True config.platforms[Platform.MATTERMOST].token = mattermost_token config.platforms[Platform.MATTERMOST].extra["url"] = mattermost_url - mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL") - if mattermost_home: - config.platforms[Platform.MATTERMOST].home_channel = HomeChannel( - platform=Platform.MATTERMOST, - chat_id=mattermost_home, - name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"), - ) + mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL") + if mattermost_home and Platform.MATTERMOST in config.platforms: + config.platforms[Platform.MATTERMOST].home_channel = HomeChannel( + platform=Platform.MATTERMOST, + chat_id=mattermost_home, + name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"), + ) # Matrix matrix_token = os.getenv("MATRIX_ACCESS_TOKEN") @@ -702,13 +717,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.MATRIX].extra["password"] = matrix_password matrix_e2ee = os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes") config.platforms[Platform.MATRIX].extra["encryption"] = matrix_e2ee - matrix_home = os.getenv("MATRIX_HOME_ROOM") - if matrix_home: - config.platforms[Platform.MATRIX].home_channel = HomeChannel( - platform=Platform.MATRIX, - chat_id=matrix_home, - name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"), - ) + matrix_home = os.getenv("MATRIX_HOME_ROOM") + if matrix_home and Platform.MATRIX in config.platforms: + config.platforms[Platform.MATRIX].home_channel = HomeChannel( + platform=Platform.MATRIX, + chat_id=matrix_home, + name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"), + ) # Home Assistant hass_token = os.getenv("HASS_TOKEN") @@ -735,13 +750,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: "imap_host": email_imap, "smtp_host": email_smtp, }) - email_home = os.getenv("EMAIL_HOME_ADDRESS") - if email_home: - config.platforms[Platform.EMAIL].home_channel = HomeChannel( - platform=Platform.EMAIL, - chat_id=email_home, - name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"), - ) + email_home = os.getenv("EMAIL_HOME_ADDRESS") + if email_home and Platform.EMAIL in config.platforms: + config.platforms[Platform.EMAIL].home_channel = HomeChannel( + platform=Platform.EMAIL, + chat_id=email_home, + name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"), + ) # SMS (Twilio) twilio_sid = os.getenv("TWILIO_ACCOUNT_SID") @@ -750,13 +765,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.SMS] = PlatformConfig() config.platforms[Platform.SMS].enabled = True config.platforms[Platform.SMS].api_key = os.getenv("TWILIO_AUTH_TOKEN", "") - sms_home = os.getenv("SMS_HOME_CHANNEL") - if sms_home: - config.platforms[Platform.SMS].home_channel = HomeChannel( - platform=Platform.SMS, - chat_id=sms_home, - name=os.getenv("SMS_HOME_CHANNEL_NAME", "Home"), - ) + sms_home = os.getenv("SMS_HOME_CHANNEL") + if sms_home and Platform.SMS in config.platforms: + config.platforms[Platform.SMS].home_channel = HomeChannel( + platform=Platform.SMS, + chat_id=sms_home, + name=os.getenv("SMS_HOME_CHANNEL_NAME", "Home"), + ) # API Server api_server_enabled = os.getenv("API_SERVER_ENABLED", "").lower() in ("true", "1", "yes") @@ -798,6 +813,55 @@ def _apply_env_overrides(config: GatewayConfig) -> None: if webhook_secret: config.platforms[Platform.WEBHOOK].extra["secret"] = webhook_secret + # Feishu / Lark + feishu_app_id = os.getenv("FEISHU_APP_ID") + feishu_app_secret = os.getenv("FEISHU_APP_SECRET") + if feishu_app_id and feishu_app_secret: + if Platform.FEISHU not in config.platforms: + config.platforms[Platform.FEISHU] = PlatformConfig() + config.platforms[Platform.FEISHU].enabled = True + config.platforms[Platform.FEISHU].extra.update({ + "app_id": feishu_app_id, + "app_secret": feishu_app_secret, + "domain": os.getenv("FEISHU_DOMAIN", "feishu"), + "connection_mode": os.getenv("FEISHU_CONNECTION_MODE", "websocket"), + }) + feishu_encrypt_key = os.getenv("FEISHU_ENCRYPT_KEY", "") + if feishu_encrypt_key: + config.platforms[Platform.FEISHU].extra["encrypt_key"] = feishu_encrypt_key + feishu_verification_token = os.getenv("FEISHU_VERIFICATION_TOKEN", "") + if feishu_verification_token: + config.platforms[Platform.FEISHU].extra["verification_token"] = feishu_verification_token + feishu_home = os.getenv("FEISHU_HOME_CHANNEL") + if feishu_home: + config.platforms[Platform.FEISHU].home_channel = HomeChannel( + platform=Platform.FEISHU, + chat_id=feishu_home, + name=os.getenv("FEISHU_HOME_CHANNEL_NAME", "Home"), + ) + + # WeCom (Enterprise WeChat) + wecom_bot_id = os.getenv("WECOM_BOT_ID") + wecom_secret = os.getenv("WECOM_SECRET") + if wecom_bot_id and wecom_secret: + if Platform.WECOM not in config.platforms: + config.platforms[Platform.WECOM] = PlatformConfig() + config.platforms[Platform.WECOM].enabled = True + config.platforms[Platform.WECOM].extra.update({ + "bot_id": wecom_bot_id, + "secret": wecom_secret, + }) + wecom_ws_url = os.getenv("WECOM_WEBSOCKET_URL", "") + if wecom_ws_url: + config.platforms[Platform.WECOM].extra["websocket_url"] = wecom_ws_url + wecom_home = os.getenv("WECOM_HOME_CHANNEL") + if wecom_home: + config.platforms[Platform.WECOM].home_channel = HomeChannel( + platform=Platform.WECOM, + chat_id=wecom_home, + name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"), + ) + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/hooks.py b/gateway/hooks.py index 15ecd3fee..c50394b20 100644 --- a/gateway/hooks.py +++ b/gateway/hooks.py @@ -51,14 +51,33 @@ class HookRegistry: """Return metadata about all loaded hooks.""" return list(self._loaded_hooks) + def _register_builtin_hooks(self) -> None: + """Register built-in hooks that are always active.""" + try: + from gateway.builtin_hooks.boot_md import handle as boot_md_handle + + self._handlers.setdefault("gateway:startup", []).append(boot_md_handle) + self._loaded_hooks.append({ + "name": "boot-md", + "description": "Run ~/.hermes/BOOT.md on gateway startup", + "events": ["gateway:startup"], + "path": "(builtin)", + }) + except Exception as e: + print(f"[hooks] Could not load built-in boot-md hook: {e}", flush=True) + def discover_and_load(self) -> None: """ Scan the hooks directory for hook directories and load their handlers. + Also registers built-in hooks that are always active. + Each hook directory must contain: - HOOK.yaml with at least 'name' and 'events' keys - handler.py with a top-level 'handle' function (sync or async) """ + self._register_builtin_hooks() + if not HOOKS_DIR.exists(): return diff --git a/gateway/pairing.py b/gateway/pairing.py index 20b64b013..34b3d9023 100644 --- a/gateway/pairing.py +++ b/gateway/pairing.py @@ -25,7 +25,7 @@ import time from pathlib import Path from typing import Optional -from hermes_cli.config import get_hermes_home +from hermes_constants import get_hermes_dir # Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion @@ -41,7 +41,7 @@ LOCKOUT_SECONDS = 3600 # Lockout duration after too many failures MAX_PENDING_PER_PLATFORM = 3 # Max pending codes per platform MAX_FAILED_ATTEMPTS = 5 # Failed approvals before lockout -PAIRING_DIR = get_hermes_home() / "pairing" +PAIRING_DIR = get_hermes_dir("platforms/pairing", "pairing") def _secure_write(path: Path, data: str) -> None: diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 425613cb5..19fa5f60d 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -166,7 +166,7 @@ class ResponseStore: _CORS_HEADERS = { "Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Authorization, Content-Type", + "Access-Control-Allow-Headers": "Authorization, Content-Type, Idempotency-Key", } @@ -223,6 +223,23 @@ if AIOHTTP_AVAILABLE: else: body_limit_middleware = None # type: ignore[assignment] +_SECURITY_HEADERS = { + "X-Content-Type-Options": "nosniff", + "Referrer-Policy": "no-referrer", +} + + +if AIOHTTP_AVAILABLE: + @web.middleware + async def security_headers_middleware(request, handler): + """Add security headers to all responses (including errors).""" + response = await handler(request) + for k, v in _SECURITY_HEADERS.items(): + response.headers.setdefault(k, v) + return response +else: + security_headers_middleware = None # type: ignore[assignment] + class _IdempotencyCache: """In-memory idempotency cache with TTL and basic LRU semantics.""" @@ -307,6 +324,7 @@ class APIServerAdapter(BasePlatformAdapter): if "*" in self._cors_origins: headers = dict(_CORS_HEADERS) headers["Access-Control-Allow-Origin"] = "*" + headers["Access-Control-Max-Age"] = "600" return headers if origin not in self._cors_origins: @@ -315,6 +333,7 @@ class APIServerAdapter(BasePlatformAdapter): headers = dict(_CORS_HEADERS) headers["Access-Control-Allow-Origin"] = origin headers["Vary"] = "Origin" + headers["Access-Control-Max-Age"] = "600" return headers def _origin_allowed(self, origin: str) -> bool: @@ -366,14 +385,20 @@ class APIServerAdapter(BasePlatformAdapter): Create an AIAgent instance using the gateway's runtime config. Uses _resolve_runtime_agent_kwargs() to pick up model, api_key, - base_url, etc. from config.yaml / env vars. + base_url, etc. from config.yaml / env vars. Toolsets are resolved + from config.yaml platform_toolsets.api_server (same as all other + gateway platforms), falling back to the hermes-api-server default. """ from run_agent import AIAgent - from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model + from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model, _load_gateway_config + from hermes_cli.tools_config import _get_platform_tools runtime_kwargs = _resolve_runtime_agent_kwargs() model = _resolve_gateway_model() + user_config = _load_gateway_config() + enabled_toolsets = sorted(_get_platform_tools(user_config, "api_server")) + max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90")) agent = AIAgent( @@ -383,6 +408,7 @@ class APIServerAdapter(BasePlatformAdapter): quiet_mode=True, verbose_logging=False, ephemeral_system_prompt=ephemeral_system_prompt or None, + enabled_toolsets=enabled_toolsets, session_id=session_id, platform="api_server", stream_delta_callback=stream_delta_callback, @@ -488,17 +514,21 @@ class APIServerAdapter(BasePlatformAdapter): if delta is not None: _stream_q.put(delta) - # Start agent in background + # Start agent in background. agent_ref is a mutable container + # so the SSE writer can interrupt the agent on client disconnect. + agent_ref = [None] agent_task = asyncio.ensure_future(self._run_agent( user_message=user_message, conversation_history=history, ephemeral_system_prompt=system_prompt, session_id=session_id, stream_delta_callback=_on_delta, + agent_ref=agent_ref, )) return await self._write_sse_chat_completion( - request, completion_id, model_name, created, _stream_q, agent_task + request, completion_id, model_name, created, _stream_q, + agent_task, agent_ref, ) # Non-streaming: run the agent (with optional Idempotency-Key) @@ -561,80 +591,107 @@ class APIServerAdapter(BasePlatformAdapter): async def _write_sse_chat_completion( self, request: "web.Request", completion_id: str, model: str, - created: int, stream_q, agent_task, + created: int, stream_q, agent_task, agent_ref=None, ) -> "web.StreamResponse": - """Write real streaming SSE from agent's stream_delta_callback queue.""" + """Write real streaming SSE from agent's stream_delta_callback queue. + + If the client disconnects mid-stream (network drop, browser tab close), + the agent is interrupted via ``agent.interrupt()`` so it stops making + LLM API calls, and the asyncio task wrapper is cancelled. + """ import queue as _q - response = web.StreamResponse( - status=200, - headers={"Content-Type": "text/event-stream", "Cache-Control": "no-cache"}, - ) + sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"} + # CORS middleware can't inject headers into StreamResponse after + # prepare() flushes them, so resolve CORS headers up front. + origin = request.headers.get("Origin", "") + cors = self._cors_headers_for_origin(origin) if origin else None + if cors: + sse_headers.update(cors) + response = web.StreamResponse(status=200, headers=sse_headers) await response.prepare(request) - # Role chunk - role_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], - } - await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) - - # Stream content chunks as they arrive from the agent - loop = asyncio.get_event_loop() - while True: - try: - delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5)) - except _q.Empty: - if agent_task.done(): - # Drain any remaining items - while True: - try: - delta = stream_q.get_nowait() - if delta is None: - break - content_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], - } - await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) - except _q.Empty: - break - break - continue - - if delta is None: # End of stream sentinel - break - - content_chunk = { + try: + # Role chunk + role_chunk = { "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], + "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], } - await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) - # Get usage from completed agent - usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} - try: - result, agent_usage = await agent_task - usage = agent_usage or usage - except Exception: - pass + # Stream content chunks as they arrive from the agent + loop = asyncio.get_event_loop() + while True: + try: + delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5)) + except _q.Empty: + if agent_task.done(): + # Drain any remaining items + while True: + try: + delta = stream_q.get_nowait() + if delta is None: + break + content_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], + } + await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + except _q.Empty: + break + break + continue - # Finish chunk - finish_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], - "usage": { - "prompt_tokens": usage.get("input_tokens", 0), - "completion_tokens": usage.get("output_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - }, - } - await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode()) - await response.write(b"data: [DONE]\n\n") + if delta is None: # End of stream sentinel + break + + content_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], + } + await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + + # Get usage from completed agent + usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + try: + result, agent_usage = await agent_task + usage = agent_usage or usage + except Exception: + pass + + # Finish chunk + finish_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": { + "prompt_tokens": usage.get("input_tokens", 0), + "completion_tokens": usage.get("output_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + } + await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode()) + await response.write(b"data: [DONE]\n\n") + except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError): + # Client disconnected mid-stream. Interrupt the agent so it + # stops making LLM API calls at the next loop iteration, then + # cancel the asyncio task wrapper. + agent = agent_ref[0] if agent_ref else None + if agent is not None: + try: + agent.interrupt("SSE client disconnected") + except Exception: + pass + if not agent_task.done(): + agent_task.cancel() + try: + await agent_task + except (asyncio.CancelledError, Exception): + pass + logger.info("SSE client disconnected; interrupted agent task %s", completion_id) return response @@ -1137,12 +1194,18 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt: Optional[str] = None, session_id: Optional[str] = None, stream_delta_callback=None, + agent_ref: Optional[list] = None, ) -> tuple: """ Create an agent and run a conversation in a thread executor. Returns ``(result_dict, usage_dict)`` where *usage_dict* contains ``input_tokens``, ``output_tokens`` and ``total_tokens``. + + If *agent_ref* is a one-element list, the AIAgent instance is stored + at ``agent_ref[0]`` before ``run_conversation`` begins. This allows + callers (e.g. the SSE writer) to call ``agent.interrupt()`` from + another thread to stop in-progress LLM calls. """ loop = asyncio.get_event_loop() @@ -1152,6 +1215,8 @@ class APIServerAdapter(BasePlatformAdapter): session_id=session_id, stream_delta_callback=stream_delta_callback, ) + if agent_ref is not None: + agent_ref[0] = agent result = agent.run_conversation( user_message=user_message, conversation_history=conversation_history, @@ -1176,10 +1241,11 @@ class APIServerAdapter(BasePlatformAdapter): return False try: - mws = [mw for mw in (cors_middleware, body_limit_middleware) if mw is not None] + mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None] self._app = web.Application(middlewares=mws) self._app["api_server_adapter"] = self self._app.router.add_get("/health", self._handle_health) + self._app.router.add_get("/v1/health", self._handle_health) self._app.router.add_get("/v1/models", self._handle_models) self._app.router.add_post("/v1/chat/completions", self._handle_chat_completions) self._app.router.add_post("/v1/responses", self._handle_responses) @@ -1195,6 +1261,17 @@ class APIServerAdapter(BasePlatformAdapter): self._app.router.add_post("/api/jobs/{job_id}/resume", self._handle_resume_job) self._app.router.add_post("/api/jobs/{job_id}/run", self._handle_run_job) + # Port conflict detection — fail fast if port is already in use + import socket as _socket + try: + with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s: + _s.settimeout(1) + _s.connect(('127.0.0.1', self._port)) + logger.error('[%s] Port %d already in use. Set a different port in config.yaml: platforms.api_server.port', self.name, self._port) + return False + except (ConnectionRefusedError, OSError): + pass # port is free + self._runner = web.AppRunner(self._app) await self._runner.setup() self._site = web.TCPSite(self._runner, self._host, self._port) diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 338859ed2..efa5ed318 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -8,6 +8,7 @@ and implement the required methods. import asyncio import logging import os +import random import re import uuid from abc import ABC, abstractmethod @@ -26,6 +27,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig from gateway.session import SessionSource, build_session_key from hermes_cli.config import get_hermes_home +from hermes_constants import get_hermes_dir GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = ( @@ -43,8 +45,8 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = ( # (e.g. Telegram file URLs expire after ~1 hour). # --------------------------------------------------------------------------- -# Default location: {HERMES_HOME}/image_cache/ -IMAGE_CACHE_DIR = get_hermes_home() / "image_cache" +# Default location: {HERMES_HOME}/cache/images/ (legacy: image_cache/) +IMAGE_CACHE_DIR = get_hermes_dir("cache/images", "image_cache") def get_image_cache_dir() -> Path: @@ -71,31 +73,51 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str: return str(filepath) -async def cache_image_from_url(url: str, ext: str = ".jpg") -> str: +async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> str: """ Download an image from a URL and save it to the local cache. - Uses httpx for async download with a reasonable timeout. + Retries on transient failures (timeouts, 429, 5xx) with exponential + backoff so a single slow CDN response doesn't lose the media. Args: url: The HTTP/HTTPS URL to download from. ext: File extension including the dot (e.g. ".jpg", ".png"). + retries: Number of retry attempts on transient failures. Returns: Absolute path to the cached image file as a string. """ + import asyncio import httpx + import logging as _logging + _log = _logging.getLogger(__name__) + last_exc = None async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: - response = await client.get( - url, - headers={ - "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", - "Accept": "image/*,*/*;q=0.8", - }, - ) - response.raise_for_status() - return cache_image_from_bytes(response.content, ext) + for attempt in range(retries + 1): + try: + response = await client.get( + url, + headers={ + "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", + "Accept": "image/*,*/*;q=0.8", + }, + ) + response.raise_for_status() + return cache_image_from_bytes(response.content, ext) + except (httpx.TimeoutException, httpx.HTTPStatusError) as exc: + last_exc = exc + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429: + raise + if attempt < retries: + wait = 1.5 * (attempt + 1) + _log.debug("Media cache retry %d/%d for %s (%.1fs): %s", + attempt + 1, retries, url[:80], wait, exc) + await asyncio.sleep(wait) + continue + raise + raise last_exc def cleanup_image_cache(max_age_hours: int = 24) -> int: @@ -126,7 +148,7 @@ def cleanup_image_cache(max_age_hours: int = 24) -> int: # here so the STT tool (OpenAI Whisper) can transcribe them from local files. # --------------------------------------------------------------------------- -AUDIO_CACHE_DIR = get_hermes_home() / "audio_cache" +AUDIO_CACHE_DIR = get_hermes_dir("cache/audio", "audio_cache") def get_audio_cache_dir() -> Path: @@ -153,29 +175,51 @@ def cache_audio_from_bytes(data: bytes, ext: str = ".ogg") -> str: return str(filepath) -async def cache_audio_from_url(url: str, ext: str = ".ogg") -> str: +async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> str: """ Download an audio file from a URL and save it to the local cache. + Retries on transient failures (timeouts, 429, 5xx) with exponential + backoff so a single slow CDN response doesn't lose the media. + Args: url: The HTTP/HTTPS URL to download from. ext: File extension including the dot (e.g. ".ogg", ".mp3"). + retries: Number of retry attempts on transient failures. Returns: Absolute path to the cached audio file as a string. """ + import asyncio import httpx + import logging as _logging + _log = _logging.getLogger(__name__) + last_exc = None async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: - response = await client.get( - url, - headers={ - "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", - "Accept": "audio/*,*/*;q=0.8", - }, - ) - response.raise_for_status() - return cache_audio_from_bytes(response.content, ext) + for attempt in range(retries + 1): + try: + response = await client.get( + url, + headers={ + "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", + "Accept": "audio/*,*/*;q=0.8", + }, + ) + response.raise_for_status() + return cache_audio_from_bytes(response.content, ext) + except (httpx.TimeoutException, httpx.HTTPStatusError) as exc: + last_exc = exc + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429: + raise + if attempt < retries: + wait = 1.5 * (attempt + 1) + _log.debug("Audio cache retry %d/%d for %s (%.1fs): %s", + attempt + 1, retries, url[:80], wait, exc) + await asyncio.sleep(wait) + continue + raise + raise last_exc # --------------------------------------------------------------------------- @@ -185,7 +229,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg") -> str: # here so the agent can reference them by local file path. # --------------------------------------------------------------------------- -DOCUMENT_CACHE_DIR = get_hermes_home() / "document_cache" +DOCUMENT_CACHE_DIR = get_hermes_dir("cache/documents", "document_cache") SUPPORTED_DOCUMENT_TYPES = { ".pdf": "application/pdf", @@ -312,7 +356,10 @@ class MessageEvent: return None # Split on space and get first word, strip the / parts = self.text.split(maxsplit=1) - return parts[0][1:].lower() if parts else None + raw = parts[0][1:].lower() if parts else None + if raw and "@" in raw: + raw = raw.split("@", 1)[0] + return raw def get_command_args(self) -> str: """Get the arguments after a command.""" @@ -329,6 +376,24 @@ class SendResult: message_id: Optional[str] = None error: Optional[str] = None raw_response: Any = None + retryable: bool = False # True for transient errors (network, timeout) — base will retry automatically + + +# Error substrings that indicate a transient network failure worth retrying +_RETRYABLE_ERROR_PATTERNS = ( + "connecterror", + "connectionerror", + "connectionreset", + "connectionrefused", + "timeout", + "timed out", + "network", + "broken pipe", + "remotedisconnected", + "eoferror", + "readtimeout", + "writetimeout", +) # Type for message handlers @@ -833,6 +898,91 @@ class BasePlatformAdapter(ABC): except Exception: pass + @staticmethod + def _is_retryable_error(error: Optional[str]) -> bool: + """Return True if the error string looks like a transient network failure.""" + if not error: + return False + lowered = error.lower() + return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS) + + async def _send_with_retry( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Any = None, + max_retries: int = 2, + base_delay: float = 2.0, + ) -> "SendResult": + """ + Send a message with automatic retry for transient network errors. + + On permanent failures (e.g. formatting / permission errors) falls back + to a plain-text version before giving up. If all attempts fail due to + network errors, sends the user a brief delivery-failure notice so they + know to retry rather than waiting indefinitely. + """ + + result = await self.send( + chat_id=chat_id, + content=content, + reply_to=reply_to, + metadata=metadata, + ) + + if result.success: + return result + + error_str = result.error or "" + is_network = result.retryable or self._is_retryable_error(error_str) + + if is_network: + # Retry with exponential backoff for transient errors + for attempt in range(1, max_retries + 1): + delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1) + logger.warning( + "[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s", + self.name, attempt, max_retries, delay, error_str, + ) + await asyncio.sleep(delay) + result = await self.send( + chat_id=chat_id, + content=content, + reply_to=reply_to, + metadata=metadata, + ) + if result.success: + logger.info("[%s] Send succeeded on retry %d", self.name, attempt) + return result + error_str = result.error or "" + if not (result.retryable or self._is_retryable_error(error_str)): + break # error switched to non-transient — fall through to plain-text fallback + else: + # All retries exhausted (loop completed without break) — notify user + logger.error("[%s] Failed to deliver response after %d retries: %s", self.name, max_retries, error_str) + notice = ( + "\u26a0\ufe0f Message delivery failed after multiple attempts. " + "Please try again \u2014 your request was processed but the response could not be sent." + ) + try: + await self.send(chat_id=chat_id, content=notice, reply_to=reply_to, metadata=metadata) + except Exception as notify_err: + logger.debug("[%s] Could not send delivery-failure notice: %s", self.name, notify_err) + return result + + # Non-network / post-retry formatting failure: try plain text as fallback + logger.warning("[%s] Send failed: %s — trying plain-text fallback", self.name, error_str) + fallback_result = await self.send( + chat_id=chat_id, + content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}", + reply_to=reply_to, + metadata=metadata, + ) + if not fallback_result.success: + logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error) + return fallback_result + async def handle_message(self, event: MessageEvent) -> None: """ Process an incoming message. @@ -855,7 +1005,7 @@ class BasePlatformAdapter(ABC): # simultaneous messages. Queue them without interrupting the active run, # then process them immediately after the current task finishes. if event.message_type == MessageType.PHOTO: - print(f"[{self.name}] 🖼️ Queuing photo follow-up for session {session_key} without interrupt") + logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key) existing = self._pending_messages.get(session_key) if existing and existing.message_type == MessageType.PHOTO: existing.media_urls.extend(event.media_urls) @@ -870,7 +1020,7 @@ class BasePlatformAdapter(ABC): return # Don't interrupt now - will run after current task completes # Default behavior for non-photo follow-ups: interrupt the running agent - print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt") + logger.debug("[%s] New message while session %s is active — triggering interrupt", self.name, session_key) self._pending_messages[session_key] = event # Signal the interrupt (the processing task checks this) self._active_sessions[session_key].set() @@ -982,26 +1132,13 @@ class BasePlatformAdapter(ABC): # Send the text portion if text_content: logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id) - result = await self.send( + result = await self._send_with_retry( chat_id=event.source.chat_id, content=text_content, reply_to=event.message_id, metadata=_thread_metadata, ) - # Log send failures (don't raise - user already saw tool progress) - if not result.success: - print(f"[{self.name}] Failed to send response: {result.error}") - # Try sending without markdown as fallback - fallback_result = await self.send( - chat_id=event.source.chat_id, - content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}", - reply_to=event.message_id, - metadata=_thread_metadata, - ) - if not fallback_result.success: - print(f"[{self.name}] Fallback send also failed: {fallback_result.error}") - # Human-like pacing delay between text and media human_delay = self._get_human_delay() @@ -1069,9 +1206,9 @@ class BasePlatformAdapter(ABC): ) if not media_result.success: - print(f"[{self.name}] Failed to send media ({ext}): {media_result.error}") + logger.warning("[%s] Failed to send media (%s): %s", self.name, ext, media_result.error) except Exception as media_err: - print(f"[{self.name}] Error sending media: {media_err}") + logger.warning("[%s] Error sending media: %s", self.name, media_err) # Send auto-detected local files as native attachments for file_path in local_files: @@ -1103,7 +1240,7 @@ class BasePlatformAdapter(ABC): # Check if there's a pending message that was queued during our processing if session_key in self._pending_messages: pending_event = self._pending_messages.pop(session_key) - print(f"[{self.name}] 📨 Processing queued message from interrupt") + logger.debug("[%s] Processing queued message from interrupt", self.name) # Clean up current session before processing pending if session_key in self._active_sessions: del self._active_sessions[session_key] @@ -1117,9 +1254,7 @@ class BasePlatformAdapter(ABC): return # Already cleaned up except Exception as e: - print(f"[{self.name}] Error handling message: {e}") - import traceback - traceback.print_exc() + logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True) # Send the error to the user so they aren't left with radio silence try: error_type = type(e).__name__ diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index cb5bab1fa..1da9925cd 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -486,6 +486,17 @@ class DiscordAdapter(BasePlatformAdapter): return False try: + # Acquire scoped lock to prevent duplicate bot token usage + from gateway.status import acquire_scoped_lock + self._token_lock_identity = self.config.token + acquired, existing = acquire_scoped_lock('discord-bot-token', self._token_lock_identity, metadata={'platform': 'discord'}) + if not acquired: + owner_pid = existing.get('pid') if isinstance(existing, dict) else None + message = f'Discord bot token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' + logger.error('[%s] %s', self.name, message) + self._set_fatal_error('discord_token_lock', message, retryable=False) + return False + # Set up intents -- members intent needed for username-to-ID resolution intents = Intents.default() intents.message_content = True @@ -550,6 +561,22 @@ class DiscordAdapter(BasePlatformAdapter): return # "all" falls through to handle_message + # If the message @mentions other users but NOT the bot, the + # sender is talking to someone else — stay silent. Only + # applies in server channels; in DMs the user is always + # talking to the bot (mentions are just references). + # Controlled by DISCORD_IGNORE_NO_MENTION (default: true). + _ignore_no_mention = os.getenv( + "DISCORD_IGNORE_NO_MENTION", "true" + ).lower() in ("true", "1", "yes") + if _ignore_no_mention and message.mentions and not isinstance(message.channel, discord.DMChannel): + _bot_mentioned = ( + self._client.user is not None + and self._client.user in message.mentions + ) + if not _bot_mentioned: + return # Talking to someone else, don't interrupt + await self._handle_message(message) @self._client.event @@ -622,6 +649,16 @@ class DiscordAdapter(BasePlatformAdapter): self._running = False self._client = None self._ready_event.clear() + + # Release the token lock + try: + from gateway.status import release_scoped_lock + if getattr(self, '_token_lock_identity', None): + release_scoped_lock('discord-bot-token', self._token_lock_identity) + self._token_lock_identity = None + except Exception: + pass + logger.info("[%s] Disconnected", self.name) async def send( @@ -1413,15 +1450,23 @@ class DiscordAdapter(BasePlatformAdapter): command_text: str, followup_msg: str | None = None, ) -> None: - """Common handler for simple slash commands that dispatch a command string.""" + """Common handler for simple slash commands that dispatch a command string. + + Defers the interaction (shows "thinking..."), dispatches the command, + then cleans up the deferred response. If *followup_msg* is provided + the "thinking..." indicator is replaced with that text; otherwise it + is deleted so the channel isn't cluttered. + """ await interaction.response.defer(ephemeral=True) event = self._build_slash_event(interaction, command_text) await self.handle_message(event) - if followup_msg: - try: - await interaction.followup.send(followup_msg, ephemeral=True) - except Exception as e: - logger.debug("Discord followup failed: %s", e) + try: + if followup_msg: + await interaction.edit_original_response(content=followup_msg) + else: + await interaction.delete_original_response() + except Exception as e: + logger.debug("Discord interaction cleanup failed: %s", e) def _register_slash_commands(self) -> None: """Register Discord slash commands on the command tree.""" @@ -1446,9 +1491,7 @@ class DiscordAdapter(BasePlatformAdapter): @tree.command(name="reasoning", description="Show or change reasoning effort") @discord.app_commands.describe(effort="Reasoning effort: xhigh, high, medium, low, minimal, or none.") async def slash_reasoning(interaction: discord.Interaction, effort: str = ""): - await interaction.response.defer(ephemeral=True) - event = self._build_slash_event(interaction, f"/reasoning {effort}".strip()) - await self.handle_message(event) + await self._run_simple_slash(interaction, f"/reasoning {effort}".strip()) @tree.command(name="personality", description="Set a personality") @discord.app_commands.describe(name="Personality name. Leave empty to list available.") @@ -1521,9 +1564,7 @@ class DiscordAdapter(BasePlatformAdapter): discord.app_commands.Choice(name="status — show current mode", value="status"), ]) async def slash_voice(interaction: discord.Interaction, mode: str = ""): - await interaction.response.defer(ephemeral=True) - event = self._build_slash_event(interaction, f"/voice {mode}".strip()) - await self.handle_message(event) + await self._run_simple_slash(interaction, f"/voice {mode}".strip()) @tree.command(name="update", description="Update Hermes Agent to the latest version") async def slash_update(interaction: discord.Interaction): @@ -2096,6 +2137,11 @@ class DiscordAdapter(BasePlatformAdapter): if pending_text_injection: event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection + # Defense-in-depth: prevent empty user messages from entering session + # (can happen when user sends @mention-only with no other text) + if not event_text or not event_text.strip(): + event_text = "(The user sent a message with no text content)" + event = MessageEvent( text=event_text, message_type=msg_type, diff --git a/gateway/platforms/email.py b/gateway/platforms/email.py index f3e153c4e..a54bd94bb 100644 --- a/gateway/platforms/email.py +++ b/gateway/platforms/email.py @@ -43,6 +43,20 @@ from gateway.platforms.base import ( from gateway.config import Platform, PlatformConfig logger = logging.getLogger(__name__) +# Automated sender patterns — emails from these are silently ignored +_NOREPLY_PATTERNS = ( + "noreply", "no-reply", "no_reply", "donotreply", "do-not-reply", + "mailer-daemon", "postmaster", "bounce", "notifications@", + "automated@", "auto-confirm", "auto-reply", "automailer", +) + +# RFC headers that indicate bulk/automated mail +_AUTOMATED_HEADERS = { + "Auto-Submitted": lambda v: v.lower() != "no", + "Precedence": lambda v: v.lower() in ("bulk", "list", "junk"), + "X-Auto-Response-Suppress": lambda v: bool(v), + "List-Unsubscribe": lambda v: bool(v), +} # Gmail-safe max length per email body MAX_MESSAGE_LENGTH = 50_000 @@ -50,7 +64,17 @@ MAX_MESSAGE_LENGTH = 50_000 # Supported image extensions for inline detection _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp"} - +def _is_automated_sender(address: str, headers: dict) -> bool: + """Return True if this email is from an automated/noreply source.""" + addr = address.lower() + if any(pattern in addr for pattern in _NOREPLY_PATTERNS): + return True + for header, check in _AUTOMATED_HEADERS.items(): + value = headers.get(header, "") + if value and check(value): + return True + return False + def check_email_requirements() -> bool: """Check if email platform dependencies are available.""" addr = os.getenv("EMAIL_ADDRESS") @@ -213,6 +237,7 @@ class EmailAdapter(BasePlatformAdapter): # Track message IDs we've already processed to avoid duplicates self._seen_uids: set = set() + self._seen_uids_max: int = 2000 # cap to prevent unbounded memory growth self._poll_task: Optional[asyncio.Task] = None # Map chat_id (sender email) -> last subject + message-id for threading @@ -220,6 +245,26 @@ class EmailAdapter(BasePlatformAdapter): logger.info("[Email] Adapter initialized for %s", self._address) + def _trim_seen_uids(self) -> None: + """Keep only the most recent UIDs to prevent unbounded memory growth. + + IMAP UIDs are monotonically increasing integers. When the set grows + beyond the cap, we keep only the highest half — old UIDs are safe to + drop because new messages always have higher UIDs and IMAP's UNSEEN + flag prevents re-delivery regardless. + """ + if len(self._seen_uids) <= self._seen_uids_max: + return + try: + # UIDs are bytes like b'1234' — sort numerically and keep top half + sorted_uids = sorted(self._seen_uids, key=lambda u: int(u)) + keep = self._seen_uids_max // 2 + self._seen_uids = set(sorted_uids[-keep:]) + logger.debug("[Email] Trimmed seen UIDs to %d entries", len(self._seen_uids)) + except (ValueError, TypeError): + # Fallback: just clear old entries if sort fails + self._seen_uids = set(list(self._seen_uids)[-self._seen_uids_max // 2:]) + async def connect(self) -> bool: """Connect to the IMAP server and start polling for new messages.""" try: @@ -232,6 +277,8 @@ class EmailAdapter(BasePlatformAdapter): if status == "OK" and data and data[0]: for uid in data[0].split(): self._seen_uids.add(uid) + # Keep only the most recent UIDs to prevent unbounded growth + self._trim_seen_uids() imap.logout() logger.info("[Email] IMAP connection test passed. %d existing messages skipped.", len(self._seen_uids)) except Exception as e: @@ -290,52 +337,63 @@ class EmailAdapter(BasePlatformAdapter): results = [] try: imap = imaplib.IMAP4_SSL(self._imap_host, self._imap_port, timeout=30) - imap.login(self._address, self._password) - imap.select("INBOX") + try: + imap.login(self._address, self._password) + imap.select("INBOX") - status, data = imap.uid("search", None, "UNSEEN") - if status != "OK" or not data or not data[0]: - imap.logout() - return results + status, data = imap.uid("search", None, "UNSEEN") + if status != "OK" or not data or not data[0]: + return results - for uid in data[0].split(): - if uid in self._seen_uids: - continue - self._seen_uids.add(uid) + for uid in data[0].split(): + if uid in self._seen_uids: + continue + self._seen_uids.add(uid) + # Trim periodically to prevent unbounded memory growth + if len(self._seen_uids) > self._seen_uids_max: + self._trim_seen_uids() - status, msg_data = imap.uid("fetch", uid, "(RFC822)") - if status != "OK": - continue + status, msg_data = imap.uid("fetch", uid, "(RFC822)") + if status != "OK": + continue - raw_email = msg_data[0][1] - msg = email_lib.message_from_bytes(raw_email) + raw_email = msg_data[0][1] + msg = email_lib.message_from_bytes(raw_email) - sender_raw = msg.get("From", "") - sender_addr = _extract_email_address(sender_raw) - sender_name = _decode_header_value(sender_raw) - # Remove email from name if present - if "<" in sender_name: - sender_name = sender_name.split("<")[0].strip().strip('"') + sender_raw = msg.get("From", "") + sender_addr = _extract_email_address(sender_raw) + sender_name = _decode_header_value(sender_raw) + # Remove email from name if present + if "<" in sender_name: + sender_name = sender_name.split("<")[0].strip().strip('"') - subject = _decode_header_value(msg.get("Subject", "(no subject)")) - message_id = msg.get("Message-ID", "") - in_reply_to = msg.get("In-Reply-To", "") - body = _extract_text_body(msg) - attachments = _extract_attachments(msg, skip_attachments=self._skip_attachments) + subject = _decode_header_value(msg.get("Subject", "(no subject)")) + message_id = msg.get("Message-ID", "") + in_reply_to = msg.get("In-Reply-To", "") + # Skip automated/noreply senders before any processing + msg_headers = dict(msg.items()) + if _is_automated_sender(sender_addr, msg_headers): + logger.debug("[Email] Skipping automated sender: %s", sender_addr) + continue + body = _extract_text_body(msg) + attachments = _extract_attachments(msg, skip_attachments=self._skip_attachments) - results.append({ - "uid": uid, - "sender_addr": sender_addr, - "sender_name": sender_name, - "subject": subject, - "message_id": message_id, - "in_reply_to": in_reply_to, - "body": body, - "attachments": attachments, - "date": msg.get("Date", ""), - }) - - imap.logout() + results.append({ + "uid": uid, + "sender_addr": sender_addr, + "sender_name": sender_name, + "subject": subject, + "message_id": message_id, + "in_reply_to": in_reply_to, + "body": body, + "attachments": attachments, + "date": msg.get("Date", ""), + }) + finally: + try: + imap.logout() + except Exception: + pass except Exception as e: logger.error("[Email] IMAP fetch error: %s", e) return results @@ -348,6 +406,11 @@ class EmailAdapter(BasePlatformAdapter): if sender_addr == self._address.lower(): return + # Never reply to automated senders + if _is_automated_sender(sender_addr, {}): + logger.debug("[Email] Dropping automated sender at dispatch: %s", sender_addr) + return + subject = msg_data["subject"] body = msg_data["body"].strip() attachments = msg_data["attachments"] @@ -443,10 +506,15 @@ class EmailAdapter(BasePlatformAdapter): msg.attach(MIMEText(body, "plain", "utf-8")) smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30) - smtp.starttls(context=ssl.create_default_context()) - smtp.login(self._address, self._password) - smtp.send_message(msg) - smtp.quit() + try: + smtp.starttls(context=ssl.create_default_context()) + smtp.login(self._address, self._password) + smtp.send_message(msg) + finally: + try: + smtp.quit() + except Exception: + smtp.close() logger.info("[Email] Sent reply to %s (subject: %s)", to_addr, subject) return msg_id @@ -530,10 +598,15 @@ class EmailAdapter(BasePlatformAdapter): msg.attach(part) smtp = smtplib.SMTP(self._smtp_host, self._smtp_port, timeout=30) - smtp.starttls(context=ssl.create_default_context()) - smtp.login(self._address, self._password) - smtp.send_message(msg) - smtp.quit() + try: + smtp.starttls(context=ssl.create_default_context()) + smtp.login(self._address, self._password) + smtp.send_message(msg) + finally: + try: + smtp.quit() + except Exception: + smtp.close() return msg_id diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py new file mode 100644 index 000000000..d9aaae9a7 --- /dev/null +++ b/gateway/platforms/feishu.py @@ -0,0 +1,3255 @@ +""" +Feishu/Lark platform adapter. + +Supports: +- WebSocket long connection and Webhook transport +- Direct-message and group @mention-gated text receive/send +- Inbound image/file/audio/media caching +- Gateway allowlist integration via FEISHU_ALLOWED_USERS +- Persistent dedup state across restarts +- Per-chat serial message processing (matches openclaw createChatQueue) +- Persistent ACK emoji reaction on inbound messages +- Reaction events routed as synthetic text events (matches openclaw) +- Interactive card button-click events routed as synthetic COMMAND events +- Webhook anomaly tracking (matches openclaw createWebhookAnomalyTracker) +- Verification token validation as second auth layer (matches openclaw) +""" + +from __future__ import annotations + +import asyncio +import hashlib +import hmac +import json +import logging +import mimetypes +import os +import re +import threading +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List, Optional + +# aiohttp/websockets are independent optional deps — import outside lark_oapi +# so they remain available for tests and webhook mode even if lark_oapi is missing. +try: + import aiohttp + from aiohttp import web +except ImportError: + aiohttp = None # type: ignore[assignment] + web = None # type: ignore[assignment] + +try: + import websockets +except ImportError: + websockets = None # type: ignore[assignment] + +try: + import lark_oapi as lark + from lark_oapi.api.application.v6 import GetApplicationRequest + from lark_oapi.api.im.v1 import ( + CreateFileRequest, + CreateFileRequestBody, + CreateImageRequest, + CreateImageRequestBody, + CreateMessageRequest, + CreateMessageRequestBody, + GetChatRequest, + GetMessageRequest, + GetImageRequest, + GetMessageResourceRequest, + P2ImMessageMessageReadV1, + ReplyMessageRequest, + ReplyMessageRequestBody, + UpdateMessageRequest, + UpdateMessageRequestBody, + ) + from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN + from lark_oapi.event.callback.model.p2_card_action_trigger import P2CardActionTriggerResponse + from lark_oapi.event.dispatcher_handler import EventDispatcherHandler + from lark_oapi.ws import Client as FeishuWSClient + + FEISHU_AVAILABLE = True +except ImportError: + FEISHU_AVAILABLE = False + lark = None # type: ignore[assignment] + P2CardActionTriggerResponse = None # type: ignore[assignment] + EventDispatcherHandler = None # type: ignore[assignment] + FeishuWSClient = None # type: ignore[assignment] + FEISHU_DOMAIN = None # type: ignore[assignment] + LARK_DOMAIN = None # type: ignore[assignment] + +FEISHU_WEBSOCKET_AVAILABLE = websockets is not None +FEISHU_WEBHOOK_AVAILABLE = aiohttp is not None + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + SUPPORTED_DOCUMENT_TYPES, + cache_document_from_bytes, + cache_image_from_url, + cache_audio_from_bytes, + cache_image_from_bytes, +) +from gateway.status import acquire_scoped_lock, release_scoped_lock +from hermes_constants import get_hermes_home + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Regex patterns +# --------------------------------------------------------------------------- + +_MARKDOWN_HINT_RE = re.compile( + r"(^#{1,6}\s)|(^\s*[-*]\s)|(^\s*\d+\.\s)|(^\s*---+\s*$)|(```)|(`[^`\n]+`)|(\*\*[^*\n].+?\*\*)|(~~[^~\n].+?~~)|(.+?)|(\*[^*\n]+\*)|(\[[^\]]+\]\([^)]+\))|(^>\s)", + re.MULTILINE, +) +_MARKDOWN_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") +_MENTION_RE = re.compile(r"@_user_\d+") +_MULTISPACE_RE = re.compile(r"[ \t]{2,}") +_POST_CONTENT_INVALID_RE = re.compile(r"content format of the post type is incorrect", re.IGNORECASE) +# --------------------------------------------------------------------------- +# Media type sets and upload constants +# --------------------------------------------------------------------------- + +_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} +_AUDIO_EXTENSIONS = {".ogg", ".mp3", ".wav", ".m4a", ".aac", ".flac", ".opus", ".webm"} +_VIDEO_EXTENSIONS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".3gp"} +_DOCUMENT_MIME_TO_EXT = {mime: ext for ext, mime in SUPPORTED_DOCUMENT_TYPES.items()} +_FEISHU_IMAGE_UPLOAD_TYPE = "message" +_FEISHU_FILE_UPLOAD_TYPE = "stream" +_FEISHU_OPUS_UPLOAD_EXTENSIONS = {".ogg", ".opus"} +_FEISHU_MEDIA_UPLOAD_EXTENSIONS = {".mp4", ".mov", ".avi", ".m4v"} +_FEISHU_DOC_UPLOAD_TYPES = { + ".pdf": "pdf", + ".doc": "doc", + ".docx": "doc", + ".xls": "xls", + ".xlsx": "xls", + ".ppt": "ppt", + ".pptx": "ppt", +} +# --------------------------------------------------------------------------- +# Connection, retry and batching tuning +# --------------------------------------------------------------------------- + +_MAX_TEXT_INJECT_BYTES = 100 * 1024 +_FEISHU_CONNECT_ATTEMPTS = 3 +_FEISHU_SEND_ATTEMPTS = 3 +_FEISHU_APP_LOCK_SCOPE = "feishu-app-id" +_DEFAULT_TEXT_BATCH_DELAY_SECONDS = 0.6 +_DEFAULT_TEXT_BATCH_MAX_MESSAGES = 8 +_DEFAULT_TEXT_BATCH_MAX_CHARS = 4000 +_DEFAULT_MEDIA_BATCH_DELAY_SECONDS = 0.8 +_DEFAULT_DEDUP_CACHE_SIZE = 2048 +_DEFAULT_WEBHOOK_HOST = "127.0.0.1" +_DEFAULT_WEBHOOK_PORT = 8765 +_DEFAULT_WEBHOOK_PATH = "/feishu/webhook" +# --------------------------------------------------------------------------- +# TTL, rate-limit and webhook security constants +# --------------------------------------------------------------------------- + +_FEISHU_DEDUP_TTL_SECONDS = 24 * 60 * 60 # 24 hours — matches openclaw +_FEISHU_SENDER_NAME_TTL_SECONDS = 10 * 60 # 10 minutes sender-name cache +_FEISHU_WEBHOOK_MAX_BODY_BYTES = 1 * 1024 * 1024 # 1 MB body limit +_FEISHU_WEBHOOK_RATE_WINDOW_SECONDS = 60 # sliding window for rate limiter +_FEISHU_WEBHOOK_RATE_LIMIT_MAX = 120 # max requests per window per IP — matches openclaw +_FEISHU_WEBHOOK_RATE_MAX_KEYS = 4096 # max tracked keys (prevents unbounded growth) +_FEISHU_WEBHOOK_BODY_TIMEOUT_SECONDS = 30 # max seconds to read request body +_FEISHU_WEBHOOK_ANOMALY_THRESHOLD = 25 # consecutive error responses before WARNING log +_FEISHU_WEBHOOK_ANOMALY_TTL_SECONDS = 6 * 60 * 60 # anomaly tracker TTL (6 hours) — matches openclaw +_FEISHU_CARD_ACTION_DEDUP_TTL_SECONDS = 15 * 60 # card action token dedup window (15 min) +_FEISHU_BOT_MSG_TRACK_SIZE = 512 # LRU size for tracking sent message IDs +_FEISHU_REPLY_FALLBACK_CODES = frozenset({230011, 231003}) # reply target withdrawn/missing → create fallback +_FEISHU_ACK_EMOJI = "OK" +# --------------------------------------------------------------------------- +# Fallback display strings +# --------------------------------------------------------------------------- + +FALLBACK_POST_TEXT = "[Rich text message]" +FALLBACK_FORWARD_TEXT = "[Merged forward message]" +FALLBACK_SHARE_CHAT_TEXT = "[Shared chat]" +FALLBACK_INTERACTIVE_TEXT = "[Interactive message]" +FALLBACK_IMAGE_TEXT = "[Image]" +FALLBACK_ATTACHMENT_TEXT = "[Attachment]" +# --------------------------------------------------------------------------- +# Post/card parsing helpers +# --------------------------------------------------------------------------- + +_PREFERRED_LOCALES = ("zh_cn", "en_us") +_MARKDOWN_SPECIAL_CHARS_RE = re.compile(r"([\\`*_{}\[\]()#+\-!|>~])") +_MENTION_PLACEHOLDER_RE = re.compile(r"@_user_\d+") +_WHITESPACE_RE = re.compile(r"\s+") +_SUPPORTED_CARD_TEXT_KEYS = ( + "title", + "text", + "content", + "label", + "value", + "name", + "summary", + "subtitle", + "description", + "placeholder", + "hint", +) +_SKIP_TEXT_KEYS = { + "tag", + "type", + "msg_type", + "message_type", + "chat_id", + "open_chat_id", + "share_chat_id", + "file_key", + "image_key", + "user_id", + "open_id", + "union_id", + "url", + "href", + "link", + "token", + "template", + "locale", +} + + +@dataclass(frozen=True) +class FeishuPostMediaRef: + file_key: str + file_name: str = "" + resource_type: str = "file" + + +@dataclass(frozen=True) +class FeishuPostParseResult: + text_content: str + image_keys: List[str] = field(default_factory=list) + media_refs: List[FeishuPostMediaRef] = field(default_factory=list) + mentioned_ids: List[str] = field(default_factory=list) + + +@dataclass(frozen=True) +class FeishuNormalizedMessage: + raw_type: str + text_content: str + preferred_message_type: str = "text" + image_keys: List[str] = field(default_factory=list) + media_refs: List[FeishuPostMediaRef] = field(default_factory=list) + mentioned_ids: List[str] = field(default_factory=list) + relation_kind: str = "plain" + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class FeishuAdapterSettings: + app_id: str + app_secret: str + domain_name: str + connection_mode: str + encrypt_key: str + verification_token: str + group_policy: str + allowed_group_users: frozenset[str] + bot_open_id: str + bot_user_id: str + bot_name: str + dedup_cache_size: int + text_batch_delay_seconds: float + text_batch_max_messages: int + text_batch_max_chars: int + media_batch_delay_seconds: float + webhook_host: str + webhook_port: int + webhook_path: str + + +@dataclass +class FeishuBatchState: + events: Dict[str, MessageEvent] = field(default_factory=dict) + tasks: Dict[str, asyncio.Task] = field(default_factory=dict) + counts: Dict[str, int] = field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Markdown rendering helpers +# --------------------------------------------------------------------------- + + +def _escape_markdown_text(text: str) -> str: + return _MARKDOWN_SPECIAL_CHARS_RE.sub(r"\\\1", text) + + +def _to_boolean(value: Any) -> bool: + return value is True or value == 1 or value == "true" + + +def _is_style_enabled(style: Dict[str, Any] | None, key: str) -> bool: + if not style: + return False + return _to_boolean(style.get(key)) + + +def _wrap_inline_code(text: str) -> str: + max_run = max([0, *[len(run) for run in re.findall(r"`+", text)]]) + fence = "`" * (max_run + 1) + body = f" {text} " if text.startswith("`") or text.endswith("`") else text + return f"{fence}{body}{fence}" + + +def _sanitize_fence_language(language: str) -> str: + return language.strip().replace("\n", " ").replace("\r", " ") + + +def _render_text_element(element: Dict[str, Any]) -> str: + text = str(element.get("text", "") or "") + style = element.get("style") + style_dict = style if isinstance(style, dict) else None + + if _is_style_enabled(style_dict, "code"): + return _wrap_inline_code(text) + + rendered = _escape_markdown_text(text) + if not rendered: + return "" + if _is_style_enabled(style_dict, "bold"): + rendered = f"**{rendered}**" + if _is_style_enabled(style_dict, "italic"): + rendered = f"*{rendered}*" + if _is_style_enabled(style_dict, "underline"): + rendered = f"{rendered}" + if _is_style_enabled(style_dict, "strikethrough"): + rendered = f"~~{rendered}~~" + return rendered + + +def _render_code_block_element(element: Dict[str, Any]) -> str: + language = _sanitize_fence_language( + str(element.get("language", "") or "") or str(element.get("lang", "") or "") + ) + code = ( + str(element.get("text", "") or "") or str(element.get("content", "") or "") + ).replace("\r\n", "\n") + trailing_newline = "" if code.endswith("\n") else "\n" + return f"```{language}\n{code}{trailing_newline}```" + + +def _strip_markdown_to_plain_text(text: str) -> str: + plain = text.replace("\r\n", "\n") + plain = _MARKDOWN_LINK_RE.sub(lambda m: f"{m.group(1)} ({m.group(2).strip()})", plain) + plain = re.sub(r"^#{1,6}\s+", "", plain, flags=re.MULTILINE) + plain = re.sub(r"^>\s?", "", plain, flags=re.MULTILINE) + plain = re.sub(r"^\s*---+\s*$", "---", plain, flags=re.MULTILINE) + plain = re.sub(r"```(?:[^\n]*\n)?([\s\S]*?)```", lambda m: m.group(1).strip("\n"), plain) + plain = re.sub(r"`([^`\n]+)`", r"\1", plain) + plain = re.sub(r"\*\*([^*\n]+)\*\*", r"\1", plain) + plain = re.sub(r"\*([^*\n]+)\*", r"\1", plain) + plain = re.sub(r"~~([^~\n]+)~~", r"\1", plain) + plain = re.sub(r"([\s\S]*?)", r"\1", plain) + plain = re.sub(r"\n{3,}", "\n\n", plain) + return plain.strip() + + +# --------------------------------------------------------------------------- +# Post payload builders and parsers +# --------------------------------------------------------------------------- + + +def _build_markdown_post_payload(content: str) -> str: + return json.dumps( + { + "zh_cn": { + "content": [ + [ + { + "tag": "md", + "text": content, + } + ] + ], + } + }, + ensure_ascii=False, + ) + + +def parse_feishu_post_content(raw_content: str) -> FeishuPostParseResult: + try: + parsed = json.loads(raw_content) if raw_content else {} + except json.JSONDecodeError: + return FeishuPostParseResult(text_content=FALLBACK_POST_TEXT) + return parse_feishu_post_payload(parsed) + + +def parse_feishu_post_payload(payload: Any) -> FeishuPostParseResult: + resolved = _resolve_post_payload(payload) + if not resolved: + return FeishuPostParseResult(text_content=FALLBACK_POST_TEXT) + + image_keys: List[str] = [] + media_refs: List[FeishuPostMediaRef] = [] + mentioned_ids: List[str] = [] + parts: List[str] = [] + + title = _normalize_feishu_text(str(resolved.get("title", "")).strip()) + if title: + parts.append(title) + + for row in resolved.get("content", []) or []: + if not isinstance(row, list): + continue + row_text = _normalize_feishu_text( + "".join(_render_post_element(item, image_keys, media_refs, mentioned_ids) for item in row) + ) + if row_text: + parts.append(row_text) + + return FeishuPostParseResult( + text_content="\n".join(parts).strip() or FALLBACK_POST_TEXT, + image_keys=image_keys, + media_refs=media_refs, + mentioned_ids=mentioned_ids, + ) + + +def _resolve_post_payload(payload: Any) -> Dict[str, Any]: + direct = _to_post_payload(payload) + if direct: + return direct + if not isinstance(payload, dict): + return {} + + wrapped = payload.get("post") + wrapped_direct = _resolve_locale_payload(wrapped) + if wrapped_direct: + return wrapped_direct + return _resolve_locale_payload(payload) + + +def _resolve_locale_payload(payload: Any) -> Dict[str, Any]: + direct = _to_post_payload(payload) + if direct: + return direct + if not isinstance(payload, dict): + return {} + + for key in _PREFERRED_LOCALES: + candidate = _to_post_payload(payload.get(key)) + if candidate: + return candidate + for value in payload.values(): + candidate = _to_post_payload(value) + if candidate: + return candidate + return {} + + +def _to_post_payload(candidate: Any) -> Dict[str, Any]: + if not isinstance(candidate, dict): + return {} + content = candidate.get("content") + if not isinstance(content, list): + return {} + return { + "title": str(candidate.get("title", "") or ""), + "content": content, + } + + +def _render_post_element( + element: Any, + image_keys: List[str], + media_refs: List[FeishuPostMediaRef], + mentioned_ids: List[str], +) -> str: + if isinstance(element, str): + return element + if not isinstance(element, dict): + return "" + + tag = str(element.get("tag", "")).strip().lower() + if tag == "text": + return _render_text_element(element) + if tag == "a": + href = str(element.get("href", "")).strip() + label = str(element.get("text", href) or "").strip() + if not label: + return "" + escaped_label = _escape_markdown_text(label) + return f"[{escaped_label}]({href})" if href else escaped_label + if tag == "at": + mentioned_id = ( + str(element.get("open_id", "")).strip() + or str(element.get("user_id", "")).strip() + ) + if mentioned_id and mentioned_id not in mentioned_ids: + mentioned_ids.append(mentioned_id) + display_name = ( + str(element.get("user_name", "")).strip() + or str(element.get("name", "")).strip() + or str(element.get("text", "")).strip() + or mentioned_id + ) + return f"@{_escape_markdown_text(display_name)}" if display_name else "@" + if tag in {"img", "image"}: + image_key = str(element.get("image_key", "")).strip() + if image_key and image_key not in image_keys: + image_keys.append(image_key) + alt = str(element.get("text", "")).strip() or str(element.get("alt", "")).strip() + return f"[Image: {alt}]" if alt else "[Image]" + if tag in {"media", "file", "audio", "video"}: + file_key = str(element.get("file_key", "")).strip() + file_name = ( + str(element.get("file_name", "")).strip() + or str(element.get("title", "")).strip() + or str(element.get("text", "")).strip() + ) + if file_key: + media_refs.append( + FeishuPostMediaRef( + file_key=file_key, + file_name=file_name, + resource_type=tag if tag in {"audio", "video"} else "file", + ) + ) + return f"[Attachment: {file_name}]" if file_name else "[Attachment]" + if tag in {"emotion", "emoji"}: + label = str(element.get("text", "")).strip() or str(element.get("emoji_type", "")).strip() + return f":{_escape_markdown_text(label)}:" if label else "[Emoji]" + if tag == "br": + return "\n" + if tag in {"hr", "divider"}: + return "\n\n---\n\n" + if tag == "code": + code = str(element.get("text", "") or "") or str(element.get("content", "") or "") + return _wrap_inline_code(code) if code else "" + if tag in {"code_block", "pre"}: + return _render_code_block_element(element) + + nested_parts: List[str] = [] + for key in ("text", "title", "content", "children", "elements"): + value = element.get(key) + extracted = _render_nested_post(value, image_keys, media_refs, mentioned_ids) + if extracted: + nested_parts.append(extracted) + return " ".join(part for part in nested_parts if part) + + +def _render_nested_post( + value: Any, + image_keys: List[str], + media_refs: List[FeishuPostMediaRef], + mentioned_ids: List[str], +) -> str: + if isinstance(value, str): + return _escape_markdown_text(value) + if isinstance(value, list): + return " ".join( + part + for item in value + for part in [_render_nested_post(item, image_keys, media_refs, mentioned_ids)] + if part + ) + if isinstance(value, dict): + direct = _render_post_element(value, image_keys, media_refs, mentioned_ids) + if direct: + return direct + return " ".join( + part + for item in value.values() + for part in [_render_nested_post(item, image_keys, media_refs, mentioned_ids)] + if part + ) + return "" + + +# --------------------------------------------------------------------------- +# Message normalization +# --------------------------------------------------------------------------- + + +def normalize_feishu_message(*, message_type: str, raw_content: str) -> FeishuNormalizedMessage: + normalized_type = str(message_type or "").strip().lower() + payload = _load_feishu_payload(raw_content) + + if normalized_type == "text": + return FeishuNormalizedMessage( + raw_type=normalized_type, + text_content=_normalize_feishu_text(str(payload.get("text", "") or "")), + ) + if normalized_type == "post": + parsed_post = parse_feishu_post_payload(payload) + return FeishuNormalizedMessage( + raw_type=normalized_type, + text_content=parsed_post.text_content, + image_keys=list(parsed_post.image_keys), + media_refs=list(parsed_post.media_refs), + mentioned_ids=list(parsed_post.mentioned_ids), + relation_kind="post", + ) + if normalized_type == "image": + image_key = str(payload.get("image_key", "") or "").strip() + alt_text = _normalize_feishu_text( + str(payload.get("text", "") or "") + or str(payload.get("alt", "") or "") + or FALLBACK_IMAGE_TEXT + ) + return FeishuNormalizedMessage( + raw_type=normalized_type, + text_content=alt_text if alt_text != FALLBACK_IMAGE_TEXT else "", + preferred_message_type="photo", + image_keys=[image_key] if image_key else [], + relation_kind="image", + ) + if normalized_type in {"file", "audio", "media"}: + media_ref = _build_media_ref_from_payload(payload, resource_type=normalized_type) + placeholder = _attachment_placeholder(media_ref.file_name) + return FeishuNormalizedMessage( + raw_type=normalized_type, + text_content="", + preferred_message_type="audio" if normalized_type == "audio" else "document", + media_refs=[media_ref] if media_ref.file_key else [], + relation_kind=normalized_type, + metadata={"placeholder_text": placeholder}, + ) + if normalized_type == "merge_forward": + return _normalize_merge_forward_message(payload) + if normalized_type == "share_chat": + return _normalize_share_chat_message(payload) + if normalized_type in {"interactive", "card"}: + return _normalize_interactive_message(normalized_type, payload) + + return FeishuNormalizedMessage(raw_type=normalized_type, text_content="") + + +def _load_feishu_payload(raw_content: str) -> Dict[str, Any]: + try: + parsed = json.loads(raw_content) if raw_content else {} + except json.JSONDecodeError: + return {"text": raw_content} + return parsed if isinstance(parsed, dict) else {"content": parsed} + + +def _normalize_merge_forward_message(payload: Dict[str, Any]) -> FeishuNormalizedMessage: + title = _first_non_empty_text( + payload.get("title"), + payload.get("summary"), + payload.get("preview"), + _find_first_text(payload, keys=("title", "summary", "preview", "description")), + ) + entries = _collect_forward_entries(payload) + lines: List[str] = [] + if title: + lines.append(title) + lines.extend(entries[:8]) + text_content = "\n".join(lines).strip() or FALLBACK_FORWARD_TEXT + return FeishuNormalizedMessage( + raw_type="merge_forward", + text_content=text_content, + relation_kind="merge_forward", + metadata={"entry_count": len(entries), "title": title}, + ) + + +def _normalize_share_chat_message(payload: Dict[str, Any]) -> FeishuNormalizedMessage: + chat_name = _first_non_empty_text( + payload.get("chat_name"), + payload.get("name"), + payload.get("title"), + _find_first_text(payload, keys=("chat_name", "name", "title")), + ) + share_id = _first_non_empty_text( + payload.get("chat_id"), + payload.get("open_chat_id"), + payload.get("share_chat_id"), + ) + lines = [] + if chat_name: + lines.append(f"Shared chat: {chat_name}") + else: + lines.append(FALLBACK_SHARE_CHAT_TEXT) + if share_id: + lines.append(f"Chat ID: {share_id}") + text_content = "\n".join(lines) + return FeishuNormalizedMessage( + raw_type="share_chat", + text_content=text_content, + relation_kind="share_chat", + metadata={"chat_id": share_id, "chat_name": chat_name}, + ) + + +def _normalize_interactive_message(message_type: str, payload: Dict[str, Any]) -> FeishuNormalizedMessage: + card_payload = payload.get("card") if isinstance(payload.get("card"), dict) else payload + title = _first_non_empty_text( + _find_header_title(card_payload), + payload.get("title"), + _find_first_text(card_payload, keys=("title", "summary", "subtitle")), + ) + body_lines = _collect_card_lines(card_payload) + actions = _collect_action_labels(card_payload) + + lines: List[str] = [] + if title: + lines.append(title) + for line in body_lines: + if line != title: + lines.append(line) + if actions: + lines.append(f"Actions: {', '.join(actions)}") + + text_content = "\n".join(lines[:12]).strip() or FALLBACK_INTERACTIVE_TEXT + return FeishuNormalizedMessage( + raw_type=message_type, + text_content=text_content, + relation_kind="interactive", + metadata={"title": title, "actions": actions}, + ) + + +# --------------------------------------------------------------------------- +# Content extraction utilities (card / forward / text walking) +# --------------------------------------------------------------------------- + + +def _collect_forward_entries(payload: Dict[str, Any]) -> List[str]: + candidates: List[Any] = [] + for key in ("messages", "items", "message_list", "records", "content"): + value = payload.get(key) + if isinstance(value, list): + candidates.extend(value) + entries: List[str] = [] + for item in candidates: + if not isinstance(item, dict): + text = _normalize_feishu_text(str(item or "")) + if text: + entries.append(f"- {text}") + continue + sender = _first_non_empty_text( + item.get("sender_name"), + item.get("user_name"), + item.get("sender"), + item.get("name"), + ) + nested_type = str(item.get("message_type", "") or item.get("msg_type", "")).strip().lower() + if nested_type == "post": + body = parse_feishu_post_payload(item.get("content") or item).text_content + else: + body = _first_non_empty_text( + item.get("text"), + item.get("summary"), + item.get("preview"), + item.get("content"), + _find_first_text(item, keys=("text", "content", "summary", "preview", "title")), + ) + body = _normalize_feishu_text(body) + if sender and body: + entries.append(f"- {sender}: {body}") + elif body: + entries.append(f"- {body}") + return _unique_lines(entries) + + +def _collect_card_lines(payload: Any) -> List[str]: + lines = _collect_text_segments(payload, in_rich_block=False) + normalized = [_normalize_feishu_text(line) for line in lines] + return _unique_lines([line for line in normalized if line]) + + +def _collect_action_labels(payload: Any) -> List[str]: + labels: List[str] = [] + for item in _walk_nodes(payload): + if not isinstance(item, dict): + continue + tag = str(item.get("tag", "") or item.get("type", "")).strip().lower() + if tag not in {"button", "select_static", "overflow", "date_picker", "picker"}: + continue + label = _first_non_empty_text( + item.get("text"), + item.get("name"), + item.get("value"), + _find_first_text(item, keys=("text", "content", "name", "value")), + ) + if label: + labels.append(label) + return _unique_lines(labels) + + +def _collect_text_segments(value: Any, *, in_rich_block: bool) -> List[str]: + if isinstance(value, str): + return [_normalize_feishu_text(value)] if in_rich_block else [] + if isinstance(value, list): + segments: List[str] = [] + for item in value: + segments.extend(_collect_text_segments(item, in_rich_block=in_rich_block)) + return segments + if not isinstance(value, dict): + return [] + + tag = str(value.get("tag", "") or value.get("type", "")).strip().lower() + next_in_rich_block = in_rich_block or tag in { + "plain_text", + "lark_md", + "markdown", + "note", + "div", + "column_set", + "column", + "action", + "button", + "select_static", + "date_picker", + } + + segments: List[str] = [] + for key in _SUPPORTED_CARD_TEXT_KEYS: + item = value.get(key) + if isinstance(item, str) and next_in_rich_block: + normalized = _normalize_feishu_text(item) + if normalized: + segments.append(normalized) + + for key, item in value.items(): + if key in _SKIP_TEXT_KEYS: + continue + segments.extend(_collect_text_segments(item, in_rich_block=next_in_rich_block)) + return segments + + +def _build_media_ref_from_payload(payload: Dict[str, Any], *, resource_type: str) -> FeishuPostMediaRef: + file_key = str(payload.get("file_key", "") or "").strip() + file_name = _first_non_empty_text( + payload.get("file_name"), + payload.get("title"), + payload.get("text"), + ) + effective_type = resource_type if resource_type in {"audio", "video"} else "file" + return FeishuPostMediaRef(file_key=file_key, file_name=file_name, resource_type=effective_type) + + +def _attachment_placeholder(file_name: str) -> str: + normalized_name = _normalize_feishu_text(file_name) + return f"[Attachment: {normalized_name}]" if normalized_name else FALLBACK_ATTACHMENT_TEXT + + +def _find_header_title(payload: Any) -> str: + if not isinstance(payload, dict): + return "" + header = payload.get("header") + if not isinstance(header, dict): + return "" + title = header.get("title") + if isinstance(title, dict): + return _first_non_empty_text(title.get("content"), title.get("text"), title.get("name")) + return _normalize_feishu_text(str(title or "")) + + +def _find_first_text(payload: Any, *, keys: tuple[str, ...]) -> str: + for node in _walk_nodes(payload): + if not isinstance(node, dict): + continue + for key in keys: + value = node.get(key) + if isinstance(value, str): + normalized = _normalize_feishu_text(value) + if normalized: + return normalized + return "" + + +def _walk_nodes(value: Any): + if isinstance(value, dict): + yield value + for item in value.values(): + yield from _walk_nodes(item) + elif isinstance(value, list): + for item in value: + yield from _walk_nodes(item) + + +def _first_non_empty_text(*values: Any) -> str: + for value in values: + if isinstance(value, str): + normalized = _normalize_feishu_text(value) + if normalized: + return normalized + elif value is not None and not isinstance(value, (dict, list)): + normalized = _normalize_feishu_text(str(value)) + if normalized: + return normalized + return "" + + +# --------------------------------------------------------------------------- +# General text utilities +# --------------------------------------------------------------------------- + + +def _normalize_feishu_text(text: str) -> str: + cleaned = _MENTION_PLACEHOLDER_RE.sub(" ", text or "") + cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n") + cleaned = "\n".join(_WHITESPACE_RE.sub(" ", line).strip() for line in cleaned.split("\n")) + cleaned = "\n".join(line for line in cleaned.split("\n") if line) + cleaned = _MULTISPACE_RE.sub(" ", cleaned) + return cleaned.strip() + + +def _unique_lines(lines: List[str]) -> List[str]: + seen: set[str] = set() + unique: List[str] = [] + for line in lines: + if not line or line in seen: + continue + seen.add(line) + unique.append(line) + return unique + + +def _run_official_feishu_ws_client(ws_client: Any) -> None: + """Run the official Lark WS client in its own thread-local event loop.""" + import lark_oapi.ws.client as ws_client_module + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + ws_client_module.loop = loop + ws_client.start() + + +def check_feishu_requirements() -> bool: + """Check if Feishu/Lark dependencies are available.""" + return FEISHU_AVAILABLE + + +class FeishuAdapter(BasePlatformAdapter): + """Feishu/Lark bot adapter.""" + + MAX_MESSAGE_LENGTH = 8000 + + # ========================================================================= + # Lifecycle — init / settings / connect / disconnect + # ========================================================================= + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.FEISHU) + + self._settings = self._load_settings(config.extra or {}) + self._apply_settings(self._settings) + self._client: Optional[Any] = None + self._ws_client: Optional[Any] = None + self._ws_future: Optional[asyncio.Future] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._webhook_runner: Optional[Any] = None + self._webhook_site: Optional[Any] = None + self._event_handler = self._build_event_handler() + self._seen_message_ids: Dict[str, float] = {} # message_id → seen_at (time.time()) + self._seen_message_order: List[str] = [] + self._dedup_state_path = get_hermes_home() / "feishu_seen_message_ids.json" + self._dedup_lock = threading.Lock() + self._sender_name_cache: Dict[str, tuple[str, float]] = {} # sender_id → (name, expire_at) + self._webhook_rate_counts: Dict[str, tuple[int, float]] = {} # rate_key → (count, window_start) + self._webhook_anomaly_counts: Dict[str, tuple[int, str, float]] = {} # ip → (count, last_status, first_seen) + self._card_action_tokens: Dict[str, float] = {} # token → first_seen_time + self._chat_locks: Dict[str, asyncio.Lock] = {} # chat_id → lock (per-chat serial processing) + self._sent_message_ids_to_chat: Dict[str, str] = {} # message_id → chat_id (for reaction routing) + self._sent_message_id_order: List[str] = [] # LRU order for _sent_message_ids_to_chat + self._chat_info_cache: Dict[str, Dict[str, Any]] = {} + self._message_text_cache: Dict[str, Optional[str]] = {} + self._app_lock_identity: Optional[str] = None + self._text_batch_state = FeishuBatchState() + self._pending_text_batches = self._text_batch_state.events + self._pending_text_batch_tasks = self._text_batch_state.tasks + self._pending_text_batch_counts = self._text_batch_state.counts + self._media_batch_state = FeishuBatchState() + self._pending_media_batches = self._media_batch_state.events + self._pending_media_batch_tasks = self._media_batch_state.tasks + self._load_seen_message_ids() + + @staticmethod + def _load_settings(extra: Dict[str, Any]) -> FeishuAdapterSettings: + return FeishuAdapterSettings( + app_id=str(extra.get("app_id") or os.getenv("FEISHU_APP_ID", "")).strip(), + app_secret=str(extra.get("app_secret") or os.getenv("FEISHU_APP_SECRET", "")).strip(), + domain_name=str(extra.get("domain") or os.getenv("FEISHU_DOMAIN", "feishu")).strip().lower(), + connection_mode=str( + extra.get("connection_mode") or os.getenv("FEISHU_CONNECTION_MODE", "websocket") + ).strip().lower(), + encrypt_key=os.getenv("FEISHU_ENCRYPT_KEY", "").strip(), + verification_token=os.getenv("FEISHU_VERIFICATION_TOKEN", "").strip(), + group_policy=os.getenv("FEISHU_GROUP_POLICY", "allowlist").strip().lower(), + allowed_group_users=frozenset( + item.strip() + for item in os.getenv("FEISHU_ALLOWED_USERS", "").split(",") + if item.strip() + ), + bot_open_id=os.getenv("FEISHU_BOT_OPEN_ID", "").strip(), + bot_user_id=os.getenv("FEISHU_BOT_USER_ID", "").strip(), + bot_name=os.getenv("FEISHU_BOT_NAME", "").strip(), + dedup_cache_size=max( + 32, + int(os.getenv("HERMES_FEISHU_DEDUP_CACHE_SIZE", str(_DEFAULT_DEDUP_CACHE_SIZE))), + ), + text_batch_delay_seconds=float( + os.getenv("HERMES_FEISHU_TEXT_BATCH_DELAY_SECONDS", str(_DEFAULT_TEXT_BATCH_DELAY_SECONDS)) + ), + text_batch_max_messages=max( + 1, + int(os.getenv("HERMES_FEISHU_TEXT_BATCH_MAX_MESSAGES", str(_DEFAULT_TEXT_BATCH_MAX_MESSAGES))), + ), + text_batch_max_chars=max( + 1, + int(os.getenv("HERMES_FEISHU_TEXT_BATCH_MAX_CHARS", str(_DEFAULT_TEXT_BATCH_MAX_CHARS))), + ), + media_batch_delay_seconds=float( + os.getenv("HERMES_FEISHU_MEDIA_BATCH_DELAY_SECONDS", str(_DEFAULT_MEDIA_BATCH_DELAY_SECONDS)) + ), + webhook_host=str( + extra.get("webhook_host") or os.getenv("FEISHU_WEBHOOK_HOST", _DEFAULT_WEBHOOK_HOST) + ).strip(), + webhook_port=int( + extra.get("webhook_port") or os.getenv("FEISHU_WEBHOOK_PORT", str(_DEFAULT_WEBHOOK_PORT)) + ), + webhook_path=( + str(extra.get("webhook_path") or os.getenv("FEISHU_WEBHOOK_PATH", _DEFAULT_WEBHOOK_PATH)).strip() + or _DEFAULT_WEBHOOK_PATH + ), + ) + + def _apply_settings(self, settings: FeishuAdapterSettings) -> None: + self._app_id = settings.app_id + self._app_secret = settings.app_secret + self._domain_name = settings.domain_name + self._connection_mode = settings.connection_mode + self._encrypt_key = settings.encrypt_key + self._verification_token = settings.verification_token + self._group_policy = settings.group_policy + self._allowed_group_users = set(settings.allowed_group_users) + self._bot_open_id = settings.bot_open_id + self._bot_user_id = settings.bot_user_id + self._bot_name = settings.bot_name + self._dedup_cache_size = settings.dedup_cache_size + self._text_batch_delay_seconds = settings.text_batch_delay_seconds + self._text_batch_max_messages = settings.text_batch_max_messages + self._text_batch_max_chars = settings.text_batch_max_chars + self._media_batch_delay_seconds = settings.media_batch_delay_seconds + self._webhook_host = settings.webhook_host + self._webhook_port = settings.webhook_port + self._webhook_path = settings.webhook_path + + def _build_event_handler(self) -> Any: + if EventDispatcherHandler is None: + return None + return ( + EventDispatcherHandler.builder( + self._encrypt_key, + self._verification_token, + ) + .register_p2_im_message_message_read_v1(self._on_message_read_event) + .register_p2_im_message_receive_v1(self._on_message_event) + .register_p2_im_message_reaction_created_v1( + lambda data: self._on_reaction_event("im.message.reaction.created_v1", data) + ) + .register_p2_im_message_reaction_deleted_v1( + lambda data: self._on_reaction_event("im.message.reaction.deleted_v1", data) + ) + .register_p2_card_action_trigger(self._on_card_action_trigger) + .build() + ) + + async def connect(self) -> bool: + """Connect to Feishu/Lark.""" + if not FEISHU_AVAILABLE: + logger.error("[Feishu] lark-oapi not installed") + return False + if not self._app_id or not self._app_secret: + logger.error("[Feishu] FEISHU_APP_ID or FEISHU_APP_SECRET not set") + return False + if self._connection_mode not in {"websocket", "webhook"}: + logger.error( + "[Feishu] Unsupported FEISHU_CONNECTION_MODE=%s. Supported modes: websocket, webhook.", + self._connection_mode, + ) + return False + + try: + self._app_lock_identity = self._app_id + acquired, existing = acquire_scoped_lock( + _FEISHU_APP_LOCK_SCOPE, + self._app_lock_identity, + metadata={"platform": self.platform.value}, + ) + if not acquired: + owner_pid = existing.get("pid") if isinstance(existing, dict) else None + message = ( + "Another local Hermes gateway is already using this Feishu app_id" + + (f" (PID {owner_pid})." if owner_pid else ".") + + " Stop the other gateway before starting a second Feishu websocket client." + ) + logger.error("[Feishu] %s", message) + self._set_fatal_error("feishu_app_lock", message, retryable=False) + return False + + self._loop = asyncio.get_running_loop() + await self._connect_with_retry() + self._mark_connected() + logger.info("[Feishu] Connected in %s mode (%s)", self._connection_mode, self._domain_name) + return True + except Exception as exc: + await self._release_app_lock() + message = f"Feishu startup failed: {exc}" + self._set_fatal_error("feishu_connect_error", message, retryable=True) + logger.error("[Feishu] Failed to connect: %s", exc, exc_info=True) + return False + + async def disconnect(self) -> None: + """Disconnect from Feishu/Lark.""" + self._running = False + await self._cancel_pending_tasks(self._pending_text_batch_tasks) + await self._cancel_pending_tasks(self._pending_media_batch_tasks) + self._reset_batch_buffers() + self._disable_websocket_auto_reconnect() + await self._stop_webhook_server() + self._ws_future = None + self._loop = None + self._persist_seen_message_ids() + await self._release_app_lock() + + self._mark_disconnected() + logger.info("[Feishu] Disconnected") + + async def _cancel_pending_tasks(self, tasks: Dict[str, asyncio.Task]) -> None: + pending = [task for task in tasks.values() if task and not task.done()] + for task in pending: + task.cancel() + if pending: + await asyncio.gather(*pending, return_exceptions=True) + tasks.clear() + + def _reset_batch_buffers(self) -> None: + self._pending_text_batches.clear() + self._pending_text_batch_counts.clear() + self._pending_media_batches.clear() + + def _disable_websocket_auto_reconnect(self) -> None: + if self._ws_client is None: + return + try: + setattr(self._ws_client, "_auto_reconnect", False) + except Exception: + pass + finally: + self._ws_client = None + + async def _stop_webhook_server(self) -> None: + if self._webhook_runner is None: + return + try: + await self._webhook_runner.cleanup() + finally: + self._webhook_runner = None + self._webhook_site = None + + # ========================================================================= + # Outbound — send / edit / send_image / send_voice / … + # ========================================================================= + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a Feishu message.""" + if not self._client: + return SendResult(success=False, error="Not connected") + + formatted = self.format_message(content) + chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) + last_response = None + + try: + for chunk in chunks: + msg_type, payload = self._build_outbound_payload(chunk) + try: + response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type=msg_type, + payload=payload, + reply_to=reply_to, + metadata=metadata, + ) + except Exception as exc: + if msg_type != "post" or not _POST_CONTENT_INVALID_RE.search(str(exc)): + raise + logger.warning("[Feishu] Invalid post payload rejected by API; falling back to plain text") + response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type="text", + payload=json.dumps({"text": _strip_markdown_to_plain_text(chunk)}, ensure_ascii=False), + reply_to=reply_to, + metadata=metadata, + ) + if ( + msg_type == "post" + and not self._response_succeeded(response) + and _POST_CONTENT_INVALID_RE.search(str(getattr(response, "msg", "") or "")) + ): + logger.warning("[Feishu] Post payload rejected by API response; falling back to plain text") + response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type="text", + payload=json.dumps({"text": _strip_markdown_to_plain_text(chunk)}, ensure_ascii=False), + reply_to=reply_to, + metadata=metadata, + ) + last_response = response + + return self._finalize_send_result(last_response, "send failed") + except Exception as exc: + logger.error("[Feishu] Send error: %s", exc, exc_info=True) + return SendResult(success=False, error=str(exc)) + + async def edit_message( + self, + chat_id: str, + message_id: str, + content: str, + ) -> SendResult: + """Edit a previously sent Feishu text/post message.""" + if not self._client: + return SendResult(success=False, error="Not connected") + + try: + msg_type, payload = self._build_outbound_payload(content) + body = self._build_update_message_body(msg_type=msg_type, content=payload) + request = self._build_update_message_request(message_id=message_id, request_body=body) + response = await asyncio.to_thread(self._client.im.v1.message.update, request) + result = self._finalize_send_result(response, "update failed") + if not result.success and msg_type == "post" and _POST_CONTENT_INVALID_RE.search(result.error or ""): + logger.warning("[Feishu] Invalid post update payload rejected by API; falling back to plain text") + fallback_body = self._build_update_message_body( + msg_type="text", + content=json.dumps({"text": _strip_markdown_to_plain_text(content)}, ensure_ascii=False), + ) + fallback_request = self._build_update_message_request(message_id=message_id, request_body=fallback_body) + fallback_response = await asyncio.to_thread(self._client.im.v1.message.update, fallback_request) + result = self._finalize_send_result(fallback_response, "update failed") + if result.success: + result.message_id = message_id + return result + except Exception as exc: + logger.error("[Feishu] Failed to edit message %s: %s", message_id, exc, exc_info=True) + return SendResult(success=False, error=str(exc)) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> SendResult: + """Send audio to Feishu as a file attachment plus optional caption.""" + return await self._send_uploaded_file_message( + chat_id=chat_id, + file_path=audio_path, + reply_to=reply_to, + metadata=metadata, + caption=caption, + outbound_message_type="audio", + ) + + async def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> SendResult: + """Send a document/file attachment to Feishu.""" + return await self._send_uploaded_file_message( + chat_id=chat_id, + file_path=file_path, + reply_to=reply_to, + metadata=metadata, + caption=caption, + file_name=file_name, + ) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> SendResult: + """Send a video file to Feishu.""" + return await self._send_uploaded_file_message( + chat_id=chat_id, + file_path=video_path, + reply_to=reply_to, + metadata=metadata, + caption=caption, + outbound_message_type="media", + ) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> SendResult: + """Send a local image file to Feishu.""" + if not self._client: + return SendResult(success=False, error="Not connected") + if not os.path.exists(image_path): + return SendResult(success=False, error=f"Image file not found: {image_path}") + + try: + with open(image_path, "rb") as image_file: + body = self._build_image_upload_body( + image_type=_FEISHU_IMAGE_UPLOAD_TYPE, + image=image_file, + ) + request = self._build_image_upload_request(body) + upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request) + image_key = self._extract_response_field(upload_response, "image_key") + if not image_key: + return self._response_error_result( + upload_response, + default_message="image upload failed", + override_error="Feishu image upload missing image_key", + ) + + if caption: + post_payload = self._build_media_post_payload( + caption=caption, + media_tag={"tag": "img", "image_key": image_key}, + ) + message_response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type="post", + payload=post_payload, + reply_to=reply_to, + metadata=metadata, + ) + else: + message_response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type="image", + payload=json.dumps({"image_key": image_key}, ensure_ascii=False), + reply_to=reply_to, + metadata=metadata, + ) + return self._finalize_send_result(message_response, "image send failed") + except Exception as exc: + logger.error("[Feishu] Failed to send image %s: %s", image_path, exc, exc_info=True) + return SendResult(success=False, error=str(exc)) + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """Feishu bot API does not expose a typing indicator.""" + return None + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Download a remote image then send it through the native Feishu image flow.""" + try: + image_path = await self._download_remote_image(image_url) + except Exception as exc: + logger.error("[Feishu] Failed to download image %s: %s", image_url, exc, exc_info=True) + return await super().send_image( + chat_id=chat_id, + image_url=image_url, + caption=caption, + reply_to=reply_to, + metadata=metadata, + ) + return await self.send_image_file( + chat_id=chat_id, + image_path=image_path, + caption=caption, + reply_to=reply_to, + metadata=metadata, + ) + + async def send_animation( + self, + chat_id: str, + animation_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Feishu has no native GIF bubble; degrade to a downloadable file.""" + try: + file_path, file_name = await self._download_remote_document( + animation_url, + default_ext=".gif", + preferred_name="animation.gif", + ) + except Exception as exc: + logger.error("[Feishu] Failed to download animation %s: %s", animation_url, exc, exc_info=True) + return await super().send_animation( + chat_id=chat_id, + animation_url=animation_url, + caption=caption, + reply_to=reply_to, + metadata=metadata, + ) + degraded_caption = f"[GIF downgraded to file]\n{caption}" if caption else "[GIF downgraded to file]" + return await self.send_document( + chat_id=chat_id, + file_path=file_path, + file_name=file_name, + caption=degraded_caption, + reply_to=reply_to, + metadata=metadata, + ) + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return real chat metadata from Feishu when available.""" + fallback = { + "chat_id": chat_id, + "name": chat_id, + "type": "dm", + } + if not self._client: + return fallback + + cached = self._chat_info_cache.get(chat_id) + if cached is not None: + return dict(cached) + + try: + request = self._build_get_chat_request(chat_id) + response = await asyncio.to_thread(self._client.im.v1.chat.get, request) + if not response or getattr(response, "success", lambda: False)() is False: + code = getattr(response, "code", "unknown") + msg = getattr(response, "msg", "chat lookup failed") + logger.warning("[Feishu] Failed to get chat info for %s: [%s] %s", chat_id, code, msg) + return fallback + + data = getattr(response, "data", None) + raw_chat_type = str(getattr(data, "chat_type", "") or "").strip().lower() + info = { + "chat_id": chat_id, + "name": str(getattr(data, "name", None) or chat_id), + "type": self._map_chat_type(raw_chat_type), + "raw_type": raw_chat_type or None, + } + self._chat_info_cache[chat_id] = info + return dict(info) + except Exception: + logger.warning("[Feishu] Failed to get chat info for %s", chat_id, exc_info=True) + return fallback + + def format_message(self, content: str) -> str: + """Feishu text messages are plain text by default.""" + return content.strip() + + # ========================================================================= + # Inbound event handlers + # ========================================================================= + + def _on_message_event(self, data: Any) -> None: + """Normalize Feishu inbound events into MessageEvent.""" + if self._loop is None: + logger.warning("[Feishu] Dropping inbound message before adapter loop is ready") + return + future = asyncio.run_coroutine_threadsafe( + self._handle_message_event_data(data), + self._loop, + ) + future.add_done_callback(self._log_background_failure) + + async def _handle_message_event_data(self, data: Any) -> None: + """Shared inbound message handling for websocket and webhook transports.""" + event = getattr(data, "event", None) + message = getattr(event, "message", None) + sender = getattr(event, "sender", None) + sender_id = getattr(sender, "sender_id", None) + if not message or not sender_id: + logger.debug("[Feishu] Dropping malformed inbound event: missing message or sender_id") + return + + message_id = getattr(message, "message_id", None) + if not message_id or self._is_duplicate(message_id): + logger.debug("[Feishu] Dropping duplicate/missing message_id: %s", message_id) + return + if getattr(sender, "sender_type", "") == "bot": + logger.debug("[Feishu] Dropping bot-originated event: %s", message_id) + return + + chat_type = getattr(message, "chat_type", "p2p") + if chat_type != "p2p" and not self._should_accept_group_message(message, sender_id): + logger.debug("[Feishu] Dropping group message that failed mention/policy gate: %s", message_id) + return + await self._process_inbound_message( + data=data, + message=message, + sender_id=sender_id, + chat_type=chat_type, + message_id=message_id, + ) + + def _on_message_read_event(self, data: P2ImMessageMessageReadV1) -> None: + """Ignore read-receipt events that Hermes does not act on.""" + event = getattr(data, "event", None) + message = getattr(event, "message", None) + message_id = getattr(message, "message_id", None) or "" + logger.debug("[Feishu] Ignoring message_read event: %s", message_id) + + def _on_bot_added_to_chat(self, data: Any) -> None: + """Handle bot being added to a group chat.""" + event = getattr(data, "event", None) + chat_id = str(getattr(event, "chat_id", "") or "") + logger.info("[Feishu] Bot added to chat: %s", chat_id) + self._chat_info_cache.pop(chat_id, None) + + def _on_bot_removed_from_chat(self, data: Any) -> None: + """Handle bot being removed from a group chat.""" + event = getattr(data, "event", None) + chat_id = str(getattr(event, "chat_id", "") or "") + logger.info("[Feishu] Bot removed from chat: %s", chat_id) + self._chat_info_cache.pop(chat_id, None) + + def _on_reaction_event(self, event_type: str, data: Any) -> None: + """Route user reactions on bot messages as synthetic text events.""" + event = getattr(data, "event", None) + message_id = str(getattr(event, "message_id", "") or "") + operator_type = str(getattr(event, "operator_type", "") or "") + reaction_type_obj = getattr(event, "reaction_type", None) + emoji_type = str(getattr(reaction_type_obj, "emoji_type", "") or "") + action = "added" if "created" in event_type else "removed" + logger.debug( + "[Feishu] Reaction %s on message %s (operator_type=%s, emoji=%s)", + action, + message_id, + operator_type, + emoji_type, + ) + # Only process reactions from real users. Ignore app/bot-generated reactions + # and Hermes' own ACK emoji to avoid feedback loops. + if ( + operator_type in {"bot", "app"} + or emoji_type == _FEISHU_ACK_EMOJI + or not message_id + or self._loop is None + ): + return + future = asyncio.run_coroutine_threadsafe( + self._handle_reaction_event(event_type, data), + self._loop, + ) + future.add_done_callback(self._log_background_failure) + + def _on_card_action_trigger(self, data: Any) -> Any: + """Schedule Feishu card actions on the adapter loop and acknowledge immediately.""" + if self._loop is None: + logger.warning("[Feishu] Dropping card action before adapter loop is ready") + else: + future = asyncio.run_coroutine_threadsafe( + self._handle_card_action_event(data), + self._loop, + ) + future.add_done_callback(self._log_background_failure) + if P2CardActionTriggerResponse is None: + return None + return P2CardActionTriggerResponse() + + async def _handle_reaction_event(self, event_type: str, data: Any) -> None: + """Fetch the reacted-to message; if it was sent by this bot, emit a synthetic text event.""" + if not self._client: + return + event = getattr(data, "event", None) + message_id = str(getattr(event, "message_id", "") or "") + if not message_id: + return + + # Fetch the target message to verify it was sent by us and to obtain chat context. + try: + request = self._build_get_message_request(message_id) + response = await asyncio.to_thread(self._client.im.v1.message.get, request) + if not response or not getattr(response, "success", lambda: False)(): + return + items = getattr(getattr(response, "data", None), "items", None) or [] + msg = items[0] if items else None + if not msg: + return + sender = getattr(msg, "sender", None) + sender_type = str(getattr(sender, "sender_type", "") or "").lower() + if sender_type != "app": + return # only route reactions on our own bot messages + chat_id = str(getattr(msg, "chat_id", "") or "") + chat_type_raw = str(getattr(msg, "chat_type", "p2p") or "p2p") + if not chat_id: + return + except Exception: + logger.debug("[Feishu] Failed to fetch message for reaction routing", exc_info=True) + return + + user_id_obj = getattr(event, "user_id", None) + reaction_type_obj = getattr(event, "reaction_type", None) + emoji_type = str(getattr(reaction_type_obj, "emoji_type", "") or "UNKNOWN") + action = "added" if "created" in event_type else "removed" + synthetic_text = f"reaction:{action}:{emoji_type}" + + sender_profile = await self._resolve_sender_profile(user_id_obj) + chat_info = await self.get_chat_info(chat_id) + source = self.build_source( + chat_id=chat_id, + chat_name=chat_info.get("name") or chat_id or "Feishu Chat", + chat_type=self._resolve_source_chat_type(chat_info=chat_info, event_chat_type=chat_type_raw), + user_id=sender_profile["user_id"], + user_name=sender_profile["user_name"], + thread_id=None, + user_id_alt=sender_profile["user_id_alt"], + ) + synthetic_event = MessageEvent( + text=synthetic_text, + message_type=MessageType.TEXT, + source=source, + raw_message=data, + message_id=message_id, + timestamp=datetime.now(), + ) + logger.info("[Feishu] Routing reaction %s:%s on bot message %s as synthetic event", action, emoji_type, message_id) + await self._handle_message_with_guards(synthetic_event) + + def _is_card_action_duplicate(self, token: str) -> bool: + """Return True if this card action token was already processed within the dedup window.""" + now = time.time() + # Prune expired tokens lazily each call. + expired = [t for t, ts in self._card_action_tokens.items() if now - ts > _FEISHU_CARD_ACTION_DEDUP_TTL_SECONDS] + for t in expired: + del self._card_action_tokens[t] + if token in self._card_action_tokens: + return True + self._card_action_tokens[token] = now + return False + + async def _handle_card_action_event(self, data: Any) -> None: + """Route Feishu interactive card button clicks as synthetic COMMAND events.""" + event = getattr(data, "event", None) + token = str(getattr(event, "token", "") or "") + if token and self._is_card_action_duplicate(token): + logger.debug("[Feishu] Dropping duplicate card action token: %s", token) + return + + context = getattr(event, "context", None) + chat_id = str(getattr(context, "open_chat_id", "") or "") + operator = getattr(event, "operator", None) + open_id = str(getattr(operator, "open_id", "") or "") + if not chat_id or not open_id: + logger.debug("[Feishu] Card action missing chat_id or operator open_id, dropping") + return + + action = getattr(event, "action", None) + action_tag = str(getattr(action, "tag", "") or "button") + action_value = getattr(action, "value", {}) or {} + synthetic_text = f"/card {action_tag}" + if action_value: + try: + synthetic_text += f" {json.dumps(action_value, ensure_ascii=False)}" + except Exception: + pass + + sender_id = SimpleNamespace(open_id=open_id, user_id=None, union_id=None) + sender_profile = await self._resolve_sender_profile(sender_id) + chat_info = await self.get_chat_info(chat_id) + source = self.build_source( + chat_id=chat_id, + chat_name=chat_info.get("name") or chat_id or "Feishu Chat", + chat_type=self._resolve_source_chat_type(chat_info=chat_info, event_chat_type="group"), + user_id=sender_profile["user_id"], + user_name=sender_profile["user_name"], + thread_id=None, + user_id_alt=sender_profile["user_id_alt"], + ) + synthetic_event = MessageEvent( + text=synthetic_text, + message_type=MessageType.COMMAND, + source=source, + raw_message=data, + message_id=token or str(uuid.uuid4()), + timestamp=datetime.now(), + ) + logger.info("[Feishu] Routing card action %r from %s in %s as synthetic command", action_tag, open_id, chat_id) + await self._handle_message_with_guards(synthetic_event) + + # ========================================================================= + # Per-chat serialization and typing indicator + # ========================================================================= + + def _get_chat_lock(self, chat_id: str) -> asyncio.Lock: + """Return (creating if needed) the per-chat asyncio.Lock for serial message processing.""" + lock = self._chat_locks.get(chat_id) + if lock is None: + lock = asyncio.Lock() + self._chat_locks[chat_id] = lock + return lock + + async def _handle_message_with_guards(self, event: MessageEvent) -> None: + """Dispatch a single event through the agent pipeline with per-chat serialization + and a persistent ACK emoji reaction before processing starts. + + - Per-chat lock: ensures messages in the same chat are processed one at a time + (matches openclaw's createChatQueue serial queue behaviour). + - ACK indicator: adds a CHECK reaction to the triggering message before handing + off to the agent and leaves it in place as a receipt marker. + """ + chat_id = getattr(event.source, "chat_id", "") or "" if event.source else "" + chat_lock = self._get_chat_lock(chat_id) + async with chat_lock: + message_id = event.message_id + if message_id: + await self._add_ack_reaction(message_id) + await self.handle_message(event) + + async def _add_ack_reaction(self, message_id: str) -> Optional[str]: + """Add a persistent ACK emoji reaction to signal the message was received.""" + if not self._client or not message_id: + return None + try: + from lark_oapi.api.im.v1 import ( # lazy import — keeps optional dep optional + CreateMessageReactionRequest, + CreateMessageReactionRequestBody, + ) + body = ( + CreateMessageReactionRequestBody.builder() + .reaction_type({"emoji_type": _FEISHU_ACK_EMOJI}) + .build() + ) + request = ( + CreateMessageReactionRequest.builder() + .message_id(message_id) + .request_body(body) + .build() + ) + response = await asyncio.to_thread(self._client.im.v1.message_reaction.create, request) + if response and getattr(response, "success", lambda: False)(): + data = getattr(response, "data", None) + return getattr(data, "reaction_id", None) + logger.warning( + "[Feishu] Failed to add ack reaction to %s: code=%s msg=%s", + message_id, + getattr(response, "code", None), + getattr(response, "msg", None), + ) + except Exception: + logger.warning("[Feishu] Failed to add ack reaction to %s", message_id, exc_info=True) + return None + + # ========================================================================= + # Webhook server and security + # ========================================================================= + + def _record_webhook_anomaly(self, remote_ip: str, status: str) -> None: + """Increment the anomaly counter for remote_ip and emit a WARNING every threshold hits. + + Mirrors openclaw's createWebhookAnomalyTracker: TTL 6 hours, log every 25 consecutive + error responses from the same IP. + """ + now = time.time() + entry = self._webhook_anomaly_counts.get(remote_ip) + if entry is not None: + count, _last_status, first_seen = entry + if now - first_seen < _FEISHU_WEBHOOK_ANOMALY_TTL_SECONDS: + count += 1 + if count % _FEISHU_WEBHOOK_ANOMALY_THRESHOLD == 0: + logger.warning( + "[Feishu] Webhook anomaly: %d consecutive error responses (%s) from %s " + "over the last %.0fs", + count, + status, + remote_ip, + now - first_seen, + ) + self._webhook_anomaly_counts[remote_ip] = (count, status, first_seen) + return + # Either first occurrence or TTL expired — start fresh. + self._webhook_anomaly_counts[remote_ip] = (1, status, now) + + def _clear_webhook_anomaly(self, remote_ip: str) -> None: + """Reset the anomaly counter for remote_ip after a successful request.""" + self._webhook_anomaly_counts.pop(remote_ip, None) + + # ========================================================================= + # Inbound processing pipeline + # ========================================================================= + + async def _process_inbound_message( + self, + *, + data: Any, + message: Any, + sender_id: Any, + chat_type: str, + message_id: str, + ) -> None: + text, inbound_type, media_urls, media_types = await self._extract_message_content(message) + if inbound_type == MessageType.TEXT and not text and not media_urls: + logger.debug("[Feishu] Ignoring unsupported or empty message type: %s", getattr(message, "message_type", "")) + return + + if inbound_type == MessageType.TEXT and text.startswith("/"): + inbound_type = MessageType.COMMAND + + reply_to_message_id = ( + getattr(message, "parent_id", None) + or getattr(message, "upper_message_id", None) + or None + ) + reply_to_text = await self._fetch_message_text(reply_to_message_id) if reply_to_message_id else None + + logger.info( + "[Feishu] Inbound %s message received: id=%s type=%s chat_id=%s text=%r media=%d", + "dm" if chat_type == "p2p" else "group", + message_id, + inbound_type.value, + getattr(message, "chat_id", "") or "", + text[:120], + len(media_urls), + ) + + chat_id = getattr(message, "chat_id", "") or "" + chat_info = await self.get_chat_info(chat_id) + sender_profile = await self._resolve_sender_profile(sender_id) + source = self.build_source( + chat_id=chat_id, + chat_name=chat_info.get("name") or chat_id or "Feishu Chat", + chat_type=self._resolve_source_chat_type(chat_info=chat_info, event_chat_type=chat_type), + user_id=sender_profile["user_id"], + user_name=sender_profile["user_name"], + thread_id=getattr(message, "thread_id", None) or None, + user_id_alt=sender_profile["user_id_alt"], + ) + normalized = MessageEvent( + text=text, + message_type=inbound_type, + source=source, + raw_message=data, + message_id=message_id, + media_urls=media_urls, + media_types=media_types, + reply_to_message_id=reply_to_message_id, + reply_to_text=reply_to_text, + timestamp=datetime.now(), + ) + await self._dispatch_inbound_event(normalized) + + async def _dispatch_inbound_event(self, event: MessageEvent) -> None: + """Apply Feishu-specific burst protection before entering the base adapter.""" + if event.message_type == MessageType.TEXT and not event.is_command(): + await self._enqueue_text_event(event) + return + if self._should_batch_media_event(event): + await self._enqueue_media_event(event) + return + await self._handle_message_with_guards(event) + + # ========================================================================= + # Media batching + # ========================================================================= + + def _should_batch_media_event(self, event: MessageEvent) -> bool: + return bool( + event.media_urls + and event.message_type in {MessageType.PHOTO, MessageType.VIDEO, MessageType.DOCUMENT, MessageType.AUDIO} + ) + + def _media_batch_key(self, event: MessageEvent) -> str: + from gateway.session import build_session_key + + session_key = build_session_key( + event.source, + group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), + ) + return f"{session_key}:media:{event.message_type.value}" + + @staticmethod + def _media_batch_is_compatible(existing: MessageEvent, incoming: MessageEvent) -> bool: + return ( + existing.message_type == incoming.message_type + and existing.reply_to_message_id == incoming.reply_to_message_id + and existing.reply_to_text == incoming.reply_to_text + and existing.source.thread_id == incoming.source.thread_id + ) + + async def _enqueue_media_event(self, event: MessageEvent) -> None: + key = self._media_batch_key(event) + existing = self._pending_media_batches.get(key) + if existing is None: + self._pending_media_batches[key] = event + self._schedule_media_batch_flush(key) + return + if not self._media_batch_is_compatible(existing, event): + await self._flush_media_batch_now(key) + self._pending_media_batches[key] = event + self._schedule_media_batch_flush(key) + return + existing.media_urls.extend(event.media_urls) + existing.media_types.extend(event.media_types) + if event.text: + if not existing.text: + existing.text = event.text + elif event.text not in existing.text.split("\n\n"): + existing.text = f"{existing.text}\n\n{event.text}" + existing.timestamp = event.timestamp + if event.message_id: + existing.message_id = event.message_id + self._schedule_media_batch_flush(key) + + def _schedule_media_batch_flush(self, key: str) -> None: + self._reschedule_batch_task( + self._pending_media_batch_tasks, + key, + self._flush_media_batch, + ) + + async def _flush_media_batch(self, key: str) -> None: + current_task = asyncio.current_task() + try: + await asyncio.sleep(self._media_batch_delay_seconds) + await self._flush_media_batch_now(key) + finally: + if self._pending_media_batch_tasks.get(key) is current_task: + self._pending_media_batch_tasks.pop(key, None) + + async def _flush_media_batch_now(self, key: str) -> None: + event = self._pending_media_batches.pop(key, None) + if not event: + return + logger.info( + "[Feishu] Flushing media batch %s with %d attachment(s)", + key, + len(event.media_urls), + ) + await self._handle_message_with_guards(event) + + async def _download_remote_image(self, image_url: str) -> str: + ext = self._guess_remote_extension(image_url, default=".jpg") + return await cache_image_from_url(image_url, ext=ext) + + async def _download_remote_document( + self, + file_url: str, + *, + default_ext: str, + preferred_name: str, + ) -> tuple[str, str]: + import httpx + + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + response = await client.get( + file_url, + headers={ + "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", + "Accept": "*/*", + }, + ) + response.raise_for_status() + filename = self._derive_remote_filename( + file_url, + content_type=str(response.headers.get("Content-Type", "")), + default_name=preferred_name, + default_ext=default_ext, + ) + cached_path = cache_document_from_bytes(response.content, filename) + return cached_path, filename + + @staticmethod + def _guess_remote_extension(url: str, *, default: str) -> str: + ext = Path((url or "").split("?", 1)[0]).suffix.lower() + return ext if ext in (_IMAGE_EXTENSIONS | _AUDIO_EXTENSIONS | _VIDEO_EXTENSIONS | set(SUPPORTED_DOCUMENT_TYPES)) else default + + @staticmethod + def _derive_remote_filename(file_url: str, *, content_type: str, default_name: str, default_ext: str) -> str: + candidate = Path((file_url or "").split("?", 1)[0]).name or default_name + ext = Path(candidate).suffix.lower() + if not ext: + guessed = mimetypes.guess_extension((content_type or "").split(";", 1)[0].strip().lower() or "") or default_ext + candidate = f"{candidate}{guessed}" + return candidate + + @staticmethod + def _namespace_from_mapping(value: Any) -> Any: + if isinstance(value, dict): + return SimpleNamespace(**{key: FeishuAdapter._namespace_from_mapping(item) for key, item in value.items()}) + if isinstance(value, list): + return [FeishuAdapter._namespace_from_mapping(item) for item in value] + return value + + async def _handle_webhook_request(self, request: Any) -> Any: + remote_ip = (getattr(request, "remote", None) or "unknown") + + # Rate limiting — composite key: app_id:path:remote_ip (matches openclaw key structure). + rate_key = f"{self._app_id}:{self._webhook_path}:{remote_ip}" + if not self._check_webhook_rate_limit(rate_key): + logger.warning("[Feishu] Webhook rate limit exceeded for %s", remote_ip) + self._record_webhook_anomaly(remote_ip, "429") + return web.Response(status=429, text="Too Many Requests") + + # Content-Type guard — Feishu always sends application/json. + headers = getattr(request, "headers", {}) or {} + content_type = str(headers.get("Content-Type", "") or "").split(";")[0].strip().lower() + if content_type and content_type != "application/json": + logger.warning("[Feishu] Webhook rejected: unexpected Content-Type %r from %s", content_type, remote_ip) + self._record_webhook_anomaly(remote_ip, "415") + return web.Response(status=415, text="Unsupported Media Type") + + # Body size guard — reject early via Content-Length when present. + content_length = getattr(request, "content_length", None) + if content_length is not None and content_length > _FEISHU_WEBHOOK_MAX_BODY_BYTES: + logger.warning("[Feishu] Webhook body too large (%d bytes) from %s", content_length, remote_ip) + self._record_webhook_anomaly(remote_ip, "413") + return web.Response(status=413, text="Request body too large") + + try: + body_bytes: bytes = await asyncio.wait_for( + request.read(), + timeout=_FEISHU_WEBHOOK_BODY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + logger.warning("[Feishu] Webhook body read timed out after %ds from %s", _FEISHU_WEBHOOK_BODY_TIMEOUT_SECONDS, remote_ip) + self._record_webhook_anomaly(remote_ip, "408") + return web.Response(status=408, text="Request Timeout") + except Exception: + self._record_webhook_anomaly(remote_ip, "400") + return web.json_response({"code": 400, "msg": "failed to read body"}, status=400) + + if len(body_bytes) > _FEISHU_WEBHOOK_MAX_BODY_BYTES: + logger.warning("[Feishu] Webhook body exceeds limit (%d bytes) from %s", len(body_bytes), remote_ip) + self._record_webhook_anomaly(remote_ip, "413") + return web.Response(status=413, text="Request body too large") + + try: + payload = json.loads(body_bytes.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + self._record_webhook_anomaly(remote_ip, "400") + return web.json_response({"code": 400, "msg": "invalid json"}, status=400) + + # URL verification challenge — respond before other checks so that Feishu's + # subscription setup works even before encrypt_key is wired. + if payload.get("type") == "url_verification": + return web.json_response({"challenge": payload.get("challenge", "")}) + + # Verification token check — second layer of defence beyond signature (matches openclaw). + if self._verification_token: + header = payload.get("header") or {} + incoming_token = str(header.get("token") or payload.get("token") or "") + if not incoming_token or not hmac.compare_digest(incoming_token, self._verification_token): + logger.warning("[Feishu] Webhook rejected: invalid verification token from %s", remote_ip) + self._record_webhook_anomaly(remote_ip, "401-token") + return web.Response(status=401, text="Invalid verification token") + + # Timing-safe signature verification (only enforced when encrypt_key is set). + if self._encrypt_key and not self._is_webhook_signature_valid(request.headers, body_bytes): + logger.warning("[Feishu] Webhook rejected: invalid signature from %s", remote_ip) + self._record_webhook_anomaly(remote_ip, "401-sig") + return web.Response(status=401, text="Invalid signature") + + if payload.get("encrypt"): + logger.error("[Feishu] Encrypted webhook payloads are not supported by Hermes webhook mode") + self._record_webhook_anomaly(remote_ip, "400-encrypted") + return web.json_response({"code": 400, "msg": "encrypted webhook payloads are not supported"}, status=400) + + self._clear_webhook_anomaly(remote_ip) + + event_type = str((payload.get("header") or {}).get("event_type") or "") + data = self._namespace_from_mapping(payload) + if event_type == "im.message.receive_v1": + await self._handle_message_event_data(data) + elif event_type == "im.message.message_read_v1": + self._on_message_read_event(data) + elif event_type == "im.chat.member.bot.added_v1": + self._on_bot_added_to_chat(data) + elif event_type == "im.chat.member.bot.deleted_v1": + self._on_bot_removed_from_chat(data) + elif event_type in ("im.message.reaction.created_v1", "im.message.reaction.deleted_v1"): + self._on_reaction_event(event_type, data) + elif event_type == "card.action.trigger": + asyncio.ensure_future(self._handle_card_action_event(data)) + else: + logger.debug("[Feishu] Ignoring webhook event type: %s", event_type or "unknown") + return web.json_response({"code": 0, "msg": "ok"}) + + def _is_webhook_signature_valid(self, headers: Any, body_bytes: bytes) -> bool: + """Verify Feishu webhook signature using timing-safe comparison. + + Feishu signature algorithm: + SHA256(timestamp + nonce + encrypt_key + body_string) + Headers checked: x-lark-request-timestamp, x-lark-request-nonce, x-lark-signature. + """ + timestamp = str(headers.get("x-lark-request-timestamp", "") or "") + nonce = str(headers.get("x-lark-request-nonce", "") or "") + signature = str(headers.get("x-lark-signature", "") or "") + if not timestamp or not nonce or not signature: + return False + try: + body_str = body_bytes.decode("utf-8", errors="replace") + content = f"{timestamp}{nonce}{self._encrypt_key}{body_str}" + computed = hashlib.sha256(content.encode("utf-8")).hexdigest() + return hmac.compare_digest(computed, signature) + except Exception: + logger.debug("[Feishu] Signature verification raised an exception", exc_info=True) + return False + + def _check_webhook_rate_limit(self, rate_key: str) -> bool: + """Return False when the composite rate_key has exceeded _FEISHU_WEBHOOK_RATE_LIMIT_MAX. + + The rate_key is composed as "{app_id}:{path}:{remote_ip}" — matching openclaw's key + structure so the limit is scoped to a specific (account, endpoint, IP) triple rather + than a bare IP, which causes fewer false-positive denials in multi-tenant setups. + + The tracking dict is capped at _FEISHU_WEBHOOK_RATE_MAX_KEYS entries to prevent unbounded + memory growth. Stale (expired) entries are pruned when the cap is reached. + """ + now = time.time() + # Fast path: existing entry within the current window. + entry = self._webhook_rate_counts.get(rate_key) + if entry is not None: + count, window_start = entry + if now - window_start < _FEISHU_WEBHOOK_RATE_WINDOW_SECONDS: + if count >= _FEISHU_WEBHOOK_RATE_LIMIT_MAX: + return False + self._webhook_rate_counts[rate_key] = (count + 1, window_start) + return True + # New window for an existing key, or a brand-new key — prune stale entries first. + if len(self._webhook_rate_counts) >= _FEISHU_WEBHOOK_RATE_MAX_KEYS: + stale_keys = [ + k for k, (_, ws) in self._webhook_rate_counts.items() + if now - ws >= _FEISHU_WEBHOOK_RATE_WINDOW_SECONDS + ] + for k in stale_keys: + del self._webhook_rate_counts[k] + # If still at capacity after pruning, allow through without tracking. + if rate_key not in self._webhook_rate_counts and len(self._webhook_rate_counts) >= _FEISHU_WEBHOOK_RATE_MAX_KEYS: + return True + self._webhook_rate_counts[rate_key] = (1, now) + return True + + # ========================================================================= + # Text batching + # ========================================================================= + + def _text_batch_key(self, event: MessageEvent) -> str: + """Return the session-scoped key used for Feishu text aggregation.""" + from gateway.session import build_session_key + + return build_session_key( + event.source, + group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), + ) + + @staticmethod + def _text_batch_is_compatible(existing: MessageEvent, incoming: MessageEvent) -> bool: + """Only merge text events when reply/thread context is identical.""" + return ( + existing.reply_to_message_id == incoming.reply_to_message_id + and existing.reply_to_text == incoming.reply_to_text + and existing.source.thread_id == incoming.source.thread_id + ) + + async def _enqueue_text_event(self, event: MessageEvent) -> None: + """Debounce rapid Feishu text bursts into a single MessageEvent.""" + key = self._text_batch_key(event) + existing = self._pending_text_batches.get(key) + if existing is None: + self._pending_text_batches[key] = event + self._pending_text_batch_counts[key] = 1 + self._schedule_text_batch_flush(key) + return + + if not self._text_batch_is_compatible(existing, event): + await self._flush_text_batch_now(key) + self._pending_text_batches[key] = event + self._pending_text_batch_counts[key] = 1 + self._schedule_text_batch_flush(key) + return + + existing_count = self._pending_text_batch_counts.get(key, 1) + next_count = existing_count + 1 + appended_text = event.text or "" + next_text = f"{existing.text}\n{appended_text}" if existing.text and appended_text else (existing.text or appended_text) + if next_count > self._text_batch_max_messages or len(next_text) > self._text_batch_max_chars: + await self._flush_text_batch_now(key) + self._pending_text_batches[key] = event + self._pending_text_batch_counts[key] = 1 + self._schedule_text_batch_flush(key) + return + + existing.text = next_text + existing.timestamp = event.timestamp + if event.message_id: + existing.message_id = event.message_id + self._pending_text_batch_counts[key] = next_count + self._schedule_text_batch_flush(key) + + def _schedule_text_batch_flush(self, key: str) -> None: + """Reset the debounce timer for a pending Feishu text batch.""" + self._reschedule_batch_task( + self._pending_text_batch_tasks, + key, + self._flush_text_batch, + ) + + @staticmethod + def _reschedule_batch_task( + task_map: Dict[str, asyncio.Task], + key: str, + flush_fn: Any, + ) -> None: + prior_task = task_map.get(key) + if prior_task and not prior_task.done(): + prior_task.cancel() + task_map[key] = asyncio.create_task(flush_fn(key)) + + async def _flush_text_batch(self, key: str) -> None: + """Flush a pending text batch after the quiet period.""" + current_task = asyncio.current_task() + try: + await asyncio.sleep(self._text_batch_delay_seconds) + await self._flush_text_batch_now(key) + finally: + if self._pending_text_batch_tasks.get(key) is current_task: + self._pending_text_batch_tasks.pop(key, None) + + async def _flush_text_batch_now(self, key: str) -> None: + """Dispatch the current text batch immediately.""" + event = self._pending_text_batches.pop(key, None) + self._pending_text_batch_counts.pop(key, None) + if not event: + return + logger.info( + "[Feishu] Flushing text batch %s (%d chars)", + key, + len(event.text or ""), + ) + await self._handle_message_with_guards(event) + + # ========================================================================= + # Message content extraction and resource download + # ========================================================================= + + async def _extract_message_content(self, message: Any) -> tuple[str, MessageType, List[str], List[str]]: + """Extract text and cached media from a normalized Feishu message.""" + raw_content = getattr(message, "content", "") or "" + raw_type = getattr(message, "message_type", "") or "" + message_id = str(getattr(message, "message_id", "") or "") + logger.info("[Feishu] Received raw message type=%s message_id=%s", raw_type, message_id) + + normalized = normalize_feishu_message(message_type=raw_type, raw_content=raw_content) + media_urls, media_types = await self._download_feishu_message_resources( + message_id=message_id, + normalized=normalized, + ) + inbound_type = self._resolve_normalized_message_type(normalized, media_types) + text = normalized.text_content + + if ( + inbound_type in {MessageType.DOCUMENT, MessageType.AUDIO, MessageType.VIDEO, MessageType.PHOTO} + and len(media_urls) == 1 + and normalized.preferred_message_type in {"document", "audio"} + ): + injected = await self._maybe_extract_text_document(media_urls[0], media_types[0]) + if injected: + text = injected + + return text, inbound_type, media_urls, media_types + + async def _download_feishu_message_resources( + self, + *, + message_id: str, + normalized: FeishuNormalizedMessage, + ) -> tuple[List[str], List[str]]: + media_urls: List[str] = [] + media_types: List[str] = [] + + for image_key in normalized.image_keys: + cached_path, media_type = await self._download_feishu_image( + message_id=message_id, + image_key=image_key, + ) + if cached_path: + media_urls.append(cached_path) + media_types.append(media_type) + + for media_ref in normalized.media_refs: + cached_path, media_type = await self._download_feishu_message_resource( + message_id=message_id, + file_key=media_ref.file_key, + resource_type=media_ref.resource_type, + fallback_filename=media_ref.file_name, + ) + if cached_path: + media_urls.append(cached_path) + media_types.append(media_type) + + return media_urls, media_types + + @staticmethod + def _resolve_media_message_type(media_type: str, *, default: MessageType) -> MessageType: + normalized = (media_type or "").lower() + if normalized.startswith("image/"): + return MessageType.PHOTO + if normalized.startswith("audio/"): + return MessageType.AUDIO + if normalized.startswith("video/"): + return MessageType.VIDEO + return default + + def _resolve_normalized_message_type( + self, + normalized: FeishuNormalizedMessage, + media_types: List[str], + ) -> MessageType: + preferred = normalized.preferred_message_type + if preferred == "photo": + return self._resolve_media_message_type(media_types[0] if media_types else "", default=MessageType.PHOTO) + if preferred == "audio": + return self._resolve_media_message_type(media_types[0] if media_types else "", default=MessageType.AUDIO) + if preferred == "document": + return self._resolve_media_message_type(media_types[0] if media_types else "", default=MessageType.DOCUMENT) + return MessageType.TEXT + + def _normalize_inbound_text(self, text: str) -> str: + """Strip Feishu mention placeholders from inbound text.""" + text = _MENTION_RE.sub(" ", text or "") + text = _MULTISPACE_RE.sub(" ", text) + return text.strip() + + async def _maybe_extract_text_document(self, cached_path: str, media_type: str) -> str: + if not cached_path or not media_type.startswith("text/"): + return "" + try: + if os.path.getsize(cached_path) > _MAX_TEXT_INJECT_BYTES: + return "" + ext = Path(cached_path).suffix.lower() + if ext not in {".txt", ".md"} and media_type not in {"text/plain", "text/markdown"}: + return "" + content = Path(cached_path).read_text(encoding="utf-8") + display_name = self._display_name_from_cached_path(cached_path) + return f"[Content of {display_name}]:\n{content}" + except (OSError, UnicodeDecodeError): + logger.warning("[Feishu] Failed to inject text document content from %s", cached_path, exc_info=True) + return "" + + async def _download_feishu_image(self, *, message_id: str, image_key: str) -> tuple[str, str]: + if not self._client or not message_id: + return "", "" + try: + request = self._build_message_resource_request( + message_id=message_id, + file_key=image_key, + resource_type="image", + ) + response = await asyncio.to_thread(self._client.im.v1.message_resource.get, request) + if not response or not response.success(): + logger.warning( + "[Feishu] Failed to download image %s: %s %s", + image_key, + getattr(response, "code", "unknown"), + getattr(response, "msg", "request failed"), + ) + return "", "" + raw_bytes = self._read_binary_response(response) + if not raw_bytes: + return "", "" + content_type = self._get_response_header(response, "Content-Type") + filename = getattr(response, "file_name", None) or f"{image_key}.jpg" + ext = self._guess_extension(filename, content_type, ".jpg", allowed=_IMAGE_EXTENSIONS) + cached_path = cache_image_from_bytes(raw_bytes, ext=ext) + media_type = self._normalize_media_type(content_type, default=self._default_image_media_type(ext)) + return cached_path, media_type + except Exception: + logger.warning("[Feishu] Failed to cache image resource %s", image_key, exc_info=True) + return "", "" + + async def _download_feishu_message_resource( + self, + *, + message_id: str, + file_key: str, + resource_type: str, + fallback_filename: str, + ) -> tuple[str, str]: + if not self._client or not message_id: + return "", "" + + request_types = [resource_type] + if resource_type in {"audio", "media"}: + request_types.append("file") + + for request_type in request_types: + try: + request = self._build_message_resource_request( + message_id=message_id, + file_key=file_key, + resource_type=request_type, + ) + response = await asyncio.to_thread(self._client.im.v1.message_resource.get, request) + if not response or not response.success(): + logger.debug( + "[Feishu] Resource download failed for %s/%s via type=%s: %s %s", + message_id, + file_key, + request_type, + getattr(response, "code", "unknown"), + getattr(response, "msg", "request failed"), + ) + continue + + raw_bytes = self._read_binary_response(response) + if not raw_bytes: + continue + content_type = self._get_response_header(response, "Content-Type") + response_filename = getattr(response, "file_name", None) or "" + filename = response_filename or fallback_filename or f"{request_type}_{file_key}" + media_type = self._normalize_media_type( + content_type, + default=self._guess_media_type_from_filename(filename), + ) + + if media_type.startswith("image/"): + ext = self._guess_extension(filename, content_type, ".jpg", allowed=_IMAGE_EXTENSIONS) + cached_path = cache_image_from_bytes(raw_bytes, ext=ext) + logger.info("[Feishu] Cached message image resource at %s", cached_path) + return cached_path, media_type or self._default_image_media_type(ext) + + if request_type == "audio" or media_type.startswith("audio/"): + ext = self._guess_extension(filename, content_type, ".ogg", allowed=_AUDIO_EXTENSIONS) + cached_path = cache_audio_from_bytes(raw_bytes, ext=ext) + logger.info("[Feishu] Cached message audio resource at %s", cached_path) + return cached_path, (media_type or f"audio/{ext.lstrip('.') or 'ogg'}") + + if media_type.startswith("video/"): + if not Path(filename).suffix: + filename = f"{filename}.mp4" + cached_path = cache_document_from_bytes(raw_bytes, filename) + logger.info("[Feishu] Cached message video resource at %s", cached_path) + return cached_path, media_type + + if not Path(filename).suffix and media_type in _DOCUMENT_MIME_TO_EXT: + filename = f"{filename}{_DOCUMENT_MIME_TO_EXT[media_type]}" + cached_path = cache_document_from_bytes(raw_bytes, filename) + logger.info("[Feishu] Cached message document resource at %s", cached_path) + return cached_path, (media_type or self._guess_document_media_type(filename)) + except Exception: + logger.warning( + "[Feishu] Failed to cache message resource %s/%s", + message_id, + file_key, + exc_info=True, + ) + return "", "" + + # ========================================================================= + # Static helpers — extension / media-type guessing + # ========================================================================= + + @staticmethod + def _read_binary_response(response: Any) -> bytes: + file_obj = getattr(response, "file", None) + if file_obj is None: + return b"" + if hasattr(file_obj, "getvalue"): + return bytes(file_obj.getvalue()) + return bytes(file_obj.read()) + + @staticmethod + def _get_response_header(response: Any, name: str) -> str: + raw = getattr(response, "raw", None) + headers = getattr(raw, "headers", {}) or {} + return str(headers.get(name, headers.get(name.lower(), "")) or "").split(";", 1)[0].strip().lower() + + @staticmethod + def _guess_extension(filename: str, content_type: str, default: str, *, allowed: set[str]) -> str: + ext = Path(filename or "").suffix.lower() + if ext in allowed: + return ext + guessed = mimetypes.guess_extension((content_type or "").split(";", 1)[0].strip().lower() or "") + if guessed in allowed: + return guessed + return default + + @staticmethod + def _normalize_media_type(content_type: str, *, default: str) -> str: + normalized = (content_type or "").split(";", 1)[0].strip().lower() + return normalized or default + + @staticmethod + def _guess_document_media_type(filename: str) -> str: + ext = Path(filename or "").suffix.lower() + return SUPPORTED_DOCUMENT_TYPES.get(ext, mimetypes.guess_type(filename or "")[0] or "application/octet-stream") + + @staticmethod + def _display_name_from_cached_path(path: str) -> str: + basename = os.path.basename(path) + parts = basename.split("_", 2) + display_name = parts[2] if len(parts) >= 3 else basename + return re.sub(r"[^\w.\- ]", "_", display_name) + + @staticmethod + def _guess_media_type_from_filename(filename: str) -> str: + guessed = (mimetypes.guess_type(filename or "")[0] or "").lower() + if guessed: + return guessed + ext = Path(filename or "").suffix.lower() + if ext in _VIDEO_EXTENSIONS: + return f"video/{ext.lstrip('.')}" + if ext in _AUDIO_EXTENSIONS: + return f"audio/{ext.lstrip('.')}" + if ext in _IMAGE_EXTENSIONS: + return FeishuAdapter._default_image_media_type(ext) + return "" + + @staticmethod + def _map_chat_type(raw_chat_type: str) -> str: + normalized = (raw_chat_type or "").strip().lower() + if normalized == "p2p": + return "dm" + if "topic" in normalized or "thread" in normalized or "forum" in normalized: + return "forum" + if normalized == "group": + return "group" + return "dm" + + @staticmethod + def _resolve_source_chat_type(*, chat_info: Dict[str, Any], event_chat_type: str) -> str: + resolved = str(chat_info.get("type") or "").strip().lower() + if resolved in {"group", "forum"}: + return resolved + if event_chat_type == "p2p": + return "dm" + return "group" + + async def _resolve_sender_profile(self, sender_id: Any) -> Dict[str, Optional[str]]: + open_id = getattr(sender_id, "open_id", None) or None + user_id = getattr(sender_id, "user_id", None) or None + union_id = getattr(sender_id, "union_id", None) or None + primary_id = open_id or user_id + display_name = await self._resolve_sender_name_from_api(primary_id or union_id) + return { + "user_id": primary_id, + "user_name": display_name, + "user_id_alt": union_id, + } + + async def _resolve_sender_name_from_api(self, sender_id: Optional[str]) -> Optional[str]: + """Fetch the sender's display name from the Feishu contact API with a 10-minute cache. + + ID-type detection mirrors openclaw: ou_ → open_id, on_ → union_id, else user_id. + Failures are silently suppressed; the message pipeline must not block on name resolution. + """ + if not sender_id or not self._client: + return None + trimmed = sender_id.strip() + if not trimmed: + return None + now = time.time() + cached = self._sender_name_cache.get(trimmed) + if cached is not None: + name, expire_at = cached + if now < expire_at: + return name + try: + from lark_oapi.api.contact.v3 import GetUserRequest # lazy import + if trimmed.startswith("ou_"): + id_type = "open_id" + elif trimmed.startswith("on_"): + id_type = "union_id" + else: + id_type = "user_id" + request = GetUserRequest.builder().user_id(trimmed).user_id_type(id_type).build() + response = await asyncio.to_thread(self._client.contact.v3.user.get, request) + if not response or not response.success(): + return None + user = getattr(getattr(response, "data", None), "user", None) + name = ( + getattr(user, "name", None) + or getattr(user, "display_name", None) + or getattr(user, "nickname", None) + or getattr(user, "en_name", None) + ) + if name and isinstance(name, str): + name = name.strip() + if name: + self._sender_name_cache[trimmed] = (name, now + _FEISHU_SENDER_NAME_TTL_SECONDS) + return name + except Exception: + logger.debug("[Feishu] Failed to resolve sender name for %s", sender_id, exc_info=True) + return None + + async def _fetch_message_text(self, message_id: str) -> Optional[str]: + if not self._client or not message_id: + return None + if message_id in self._message_text_cache: + return self._message_text_cache[message_id] + try: + request = self._build_get_message_request(message_id) + response = await asyncio.to_thread(self._client.im.v1.message.get, request) + if not response or getattr(response, "success", lambda: False)() is False: + code = getattr(response, "code", "unknown") + msg = getattr(response, "msg", "message lookup failed") + logger.warning("[Feishu] Failed to fetch parent message %s: [%s] %s", message_id, code, msg) + return None + items = getattr(getattr(response, "data", None), "items", None) or [] + parent = items[0] if items else None + body = getattr(parent, "body", None) + msg_type = getattr(parent, "msg_type", "") or "" + raw_content = getattr(body, "content", "") or "" + text = self._extract_text_from_raw_content(msg_type=msg_type, raw_content=raw_content) + self._message_text_cache[message_id] = text + return text + except Exception: + logger.warning("[Feishu] Failed to fetch parent message %s", message_id, exc_info=True) + return None + + def _extract_text_from_raw_content(self, *, msg_type: str, raw_content: str) -> Optional[str]: + normalized = normalize_feishu_message(message_type=msg_type, raw_content=raw_content) + if normalized.text_content: + return normalized.text_content + placeholder = normalized.metadata.get("placeholder_text") if isinstance(normalized.metadata, dict) else None + return str(placeholder).strip() or None + + @staticmethod + def _default_image_media_type(ext: str) -> str: + normalized_ext = (ext or "").lower() + if normalized_ext in {".jpg", ".jpeg"}: + return "image/jpeg" + return f"image/{normalized_ext.lstrip('.') or 'jpeg'}" + + @staticmethod + def _log_background_failure(future: Any) -> None: + try: + future.result() + except Exception: + logger.exception("[Feishu] Background inbound processing failed") + + # ========================================================================= + # Group policy and mention gating + # ========================================================================= + + def _allow_group_message(self, sender_id: Any) -> bool: + """Current group policy gate for non-DM traffic.""" + if self._group_policy == "disabled": + return False + sender_open_id = getattr(sender_id, "open_id", None) or getattr(sender_id, "user_id", None) + if self._group_policy == "open": + return True + return bool(sender_open_id and sender_open_id in self._allowed_group_users) + + def _should_accept_group_message(self, message: Any, sender_id: Any) -> bool: + """Require an explicit @mention before group messages enter the agent.""" + if not self._allow_group_message(sender_id): + return False + # @_all is Feishu's @everyone placeholder — always route to the bot. + raw_content = getattr(message, "content", "") or "" + if "@_all" in raw_content: + return True + mentions = getattr(message, "mentions", None) or [] + if mentions: + return self._message_mentions_bot(mentions) + normalized = normalize_feishu_message( + message_type=getattr(message, "message_type", "") or "", + raw_content=raw_content, + ) + if normalized.mentioned_ids: + return self._post_mentions_bot(normalized.mentioned_ids) + return False + + def _message_mentions_bot(self, mentions: List[Any]) -> bool: + """Check whether any mention targets the configured or inferred bot identity.""" + for mention in mentions: + mention_id = getattr(mention, "id", None) + mention_open_id = getattr(mention_id, "open_id", None) + mention_user_id = getattr(mention_id, "user_id", None) + mention_name = (getattr(mention, "name", None) or "").strip() + + if self._bot_open_id and mention_open_id == self._bot_open_id: + return True + if self._bot_user_id and mention_user_id == self._bot_user_id: + return True + if self._bot_name and mention_name == self._bot_name: + return True + + return False + + def _post_mentions_bot(self, mentioned_ids: List[str]) -> bool: + if not mentioned_ids: + return False + if self._bot_open_id and self._bot_open_id in mentioned_ids: + return True + if self._bot_user_id and self._bot_user_id in mentioned_ids: + return True + return False + + async def _hydrate_bot_identity(self) -> None: + """Best-effort discovery of bot identity for precise group mention gating.""" + if not self._client: + return + if any((self._bot_open_id, self._bot_user_id, self._bot_name)): + return + try: + request = self._build_get_application_request(app_id=self._app_id, lang="en_us") + response = await asyncio.to_thread(self._client.application.v6.application.get, request) + if not response or not response.success(): + code = getattr(response, "code", None) + if code == 99991672: + logger.warning( + "[Feishu] Unable to hydrate bot identity from application info. " + "Grant admin:app.info:readonly or application:application:self_manage " + "so group @mention gating can resolve the bot name precisely." + ) + return + app = getattr(getattr(response, "data", None), "app", None) + app_name = (getattr(app, "app_name", None) or "").strip() + if app_name: + self._bot_name = app_name + except Exception: + logger.debug("[Feishu] Failed to hydrate bot identity", exc_info=True) + + # ========================================================================= + # Deduplication — seen message ID cache (persistent) + # ========================================================================= + + def _load_seen_message_ids(self) -> None: + try: + payload = json.loads(self._dedup_state_path.read_text(encoding="utf-8")) + except FileNotFoundError: + return + except (OSError, json.JSONDecodeError): + logger.warning("[Feishu] Failed to load persisted dedup state from %s", self._dedup_state_path, exc_info=True) + return + seen_data = payload.get("message_ids", {}) if isinstance(payload, dict) else {} + now = time.time() + ttl = _FEISHU_DEDUP_TTL_SECONDS + # Backward-compat: old format stored a plain list of IDs (no timestamps). + if isinstance(seen_data, list): + entries: Dict[str, float] = {str(item).strip(): 0.0 for item in seen_data if str(item).strip()} + elif isinstance(seen_data, dict): + entries = {k: float(v) for k, v in seen_data.items() if isinstance(k, str) and k.strip()} + else: + return + # Filter out TTL-expired entries (entries saved with ts=0.0 are treated as immortal + # for one migration cycle to avoid nuking old data on first upgrade). + valid: Dict[str, float] = { + msg_id: ts for msg_id, ts in entries.items() + if ts == 0.0 or ttl <= 0 or now - ts < ttl + } + # Apply size cap; keep the most recently seen IDs. + sorted_ids = sorted(valid, key=lambda k: valid[k], reverse=True)[:self._dedup_cache_size] + self._seen_message_order = list(reversed(sorted_ids)) + self._seen_message_ids = {k: valid[k] for k in sorted_ids} + + def _persist_seen_message_ids(self) -> None: + try: + self._dedup_state_path.parent.mkdir(parents=True, exist_ok=True) + recent = self._seen_message_order[-self._dedup_cache_size:] + # Save as {msg_id: timestamp} so TTL filtering works across restarts. + payload = {"message_ids": {k: self._seen_message_ids[k] for k in recent if k in self._seen_message_ids}} + self._dedup_state_path.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") + except OSError: + logger.warning("[Feishu] Failed to persist dedup state to %s", self._dedup_state_path, exc_info=True) + + def _is_duplicate(self, message_id: str) -> bool: + now = time.time() + ttl = _FEISHU_DEDUP_TTL_SECONDS + with self._dedup_lock: + seen_at = self._seen_message_ids.get(message_id) + if seen_at is not None and (ttl <= 0 or now - seen_at < ttl): + return True + # Record with current wall-clock timestamp so TTL works across restarts. + self._seen_message_ids[message_id] = now + self._seen_message_order.append(message_id) + while len(self._seen_message_order) > self._dedup_cache_size: + stale = self._seen_message_order.pop(0) + self._seen_message_ids.pop(stale, None) + self._persist_seen_message_ids() + return False + + # ========================================================================= + # Outbound payload construction and send pipeline + # ========================================================================= + + def _build_outbound_payload(self, content: str) -> tuple[str, str]: + if _MARKDOWN_HINT_RE.search(content): + return "post", _build_markdown_post_payload(content) + text_payload = {"text": content} + return "text", json.dumps(text_payload, ensure_ascii=False) + + async def _send_uploaded_file_message( + self, + *, + chat_id: str, + file_path: str, + reply_to: Optional[str], + metadata: Optional[Dict[str, Any]], + caption: Optional[str] = None, + file_name: Optional[str] = None, + outbound_message_type: str = "file", + ) -> SendResult: + if not self._client: + return SendResult(success=False, error="Not connected") + if not os.path.exists(file_path): + return SendResult(success=False, error=f"File not found: {file_path}") + + display_name = file_name or os.path.basename(file_path) + upload_file_type, resolved_message_type = self._resolve_outbound_file_routing( + file_path=display_name, + requested_message_type=outbound_message_type, + ) + try: + with open(file_path, "rb") as file_obj: + body = self._build_file_upload_body( + file_type=upload_file_type, + file_name=display_name, + file=file_obj, + ) + request = self._build_file_upload_request(body) + upload_response = await asyncio.to_thread(self._client.im.v1.file.create, request) + file_key = self._extract_response_field(upload_response, "file_key") + if not file_key: + return self._response_error_result( + upload_response, + default_message="file upload failed", + override_error="Feishu file upload missing file_key", + ) + + if caption: + media_tag = { + "tag": "media", + "file_key": file_key, + "file_name": display_name, + } + message_response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type="post", + payload=self._build_media_post_payload(caption=caption, media_tag=media_tag), + reply_to=reply_to, + metadata=metadata, + ) + else: + message_response = await self._feishu_send_with_retry( + chat_id=chat_id, + msg_type=resolved_message_type, + payload=json.dumps({"file_key": file_key}, ensure_ascii=False), + reply_to=reply_to, + metadata=metadata, + ) + return self._finalize_send_result(message_response, "file send failed") + except Exception as exc: + logger.error("[Feishu] Failed to send file %s: %s", file_path, exc, exc_info=True) + return SendResult(success=False, error=str(exc)) + + async def _send_raw_message( + self, + *, + chat_id: str, + msg_type: str, + payload: str, + reply_to: Optional[str], + metadata: Optional[Dict[str, Any]], + ) -> Any: + reply_in_thread = bool((metadata or {}).get("thread_id")) + if reply_to: + body = self._build_reply_message_body( + content=payload, + msg_type=msg_type, + reply_in_thread=reply_in_thread, + uuid_value=str(uuid.uuid4()), + ) + request = self._build_reply_message_request(reply_to, body) + return await asyncio.to_thread(self._client.im.v1.message.reply, request) + + body = self._build_create_message_body( + receive_id=chat_id, + msg_type=msg_type, + content=payload, + uuid_value=str(uuid.uuid4()), + ) + request = self._build_create_message_request("chat_id", body) + return await asyncio.to_thread(self._client.im.v1.message.create, request) + + @staticmethod + def _response_succeeded(response: Any) -> bool: + return bool(response and getattr(response, "success", lambda: False)()) + + @staticmethod + def _extract_response_field(response: Any, field_name: str) -> Any: + if not FeishuAdapter._response_succeeded(response): + return None + data = getattr(response, "data", None) + return getattr(data, field_name, None) if data else None + + def _response_error_result( + self, + response: Any, + *, + default_message: str, + override_error: Optional[str] = None, + ) -> SendResult: + if override_error: + return SendResult(success=False, error=override_error, raw_response=response) + code = getattr(response, "code", "unknown") + msg = getattr(response, "msg", default_message) + return SendResult(success=False, error=f"[{code}] {msg}", raw_response=response) + + def _finalize_send_result(self, response: Any, default_message: str) -> SendResult: + if not self._response_succeeded(response): + return self._response_error_result(response, default_message=default_message) + return SendResult( + success=True, + message_id=self._extract_response_field(response, "message_id"), + raw_response=response, + ) + + # ========================================================================= + # Connection internals — websocket / webhook setup + # ========================================================================= + + async def _connect_with_retry(self) -> None: + for attempt in range(_FEISHU_CONNECT_ATTEMPTS): + try: + if self._connection_mode == "websocket": + await self._connect_websocket() + else: + await self._connect_webhook() + return + except Exception as exc: + self._running = False + self._disable_websocket_auto_reconnect() + self._ws_future = None + await self._stop_webhook_server() + if attempt >= _FEISHU_CONNECT_ATTEMPTS - 1: + raise + wait_seconds = 2 ** attempt + logger.warning( + "[Feishu] Connect attempt %d/%d failed; retrying in %ds: %s", + attempt + 1, + _FEISHU_CONNECT_ATTEMPTS, + wait_seconds, + exc, + ) + await asyncio.sleep(wait_seconds) + + async def _connect_websocket(self) -> None: + if not FEISHU_WEBSOCKET_AVAILABLE: + raise RuntimeError("websockets not installed; websocket mode unavailable") + domain = FEISHU_DOMAIN if self._domain_name != "lark" else LARK_DOMAIN + self._client = self._build_lark_client(domain) + await self._hydrate_bot_identity() + self._ws_client = FeishuWSClient( + app_id=self._app_id, + app_secret=self._app_secret, + log_level=lark.LogLevel.INFO, + event_handler=self._event_handler, + domain=domain, + ) + self._ws_future = self._loop.run_in_executor( + None, + _run_official_feishu_ws_client, + self._ws_client, + ) + + async def _connect_webhook(self) -> None: + if not FEISHU_WEBHOOK_AVAILABLE: + raise RuntimeError("aiohttp not installed; webhook mode unavailable") + domain = FEISHU_DOMAIN if self._domain_name != "lark" else LARK_DOMAIN + self._client = self._build_lark_client(domain) + await self._hydrate_bot_identity() + app = web.Application() + app.router.add_post(self._webhook_path, self._handle_webhook_request) + self._webhook_runner = web.AppRunner(app) + await self._webhook_runner.setup() + self._webhook_site = web.TCPSite(self._webhook_runner, self._webhook_host, self._webhook_port) + await self._webhook_site.start() + + def _build_lark_client(self, domain: Any) -> Any: + return ( + lark.Client.builder() + .app_id(self._app_id) + .app_secret(self._app_secret) + .domain(domain) + .log_level(lark.LogLevel.WARNING) + .build() + ) + + async def _feishu_send_with_retry( + self, + *, + chat_id: str, + msg_type: str, + payload: str, + reply_to: Optional[str], + metadata: Optional[Dict[str, Any]], + ) -> Any: + last_error: Optional[Exception] = None + active_reply_to = reply_to + for attempt in range(_FEISHU_SEND_ATTEMPTS): + try: + response = await self._send_raw_message( + chat_id=chat_id, + msg_type=msg_type, + payload=payload, + reply_to=active_reply_to, + metadata=metadata, + ) + # If replying to a message failed because it was withdrawn or not found, + # fall back to posting a new message directly to the chat. + if active_reply_to and not self._response_succeeded(response): + code = getattr(response, "code", None) + if code in _FEISHU_REPLY_FALLBACK_CODES: + logger.warning( + "[Feishu] Reply to %s failed (code %s — message withdrawn/missing); " + "falling back to new message in chat %s", + active_reply_to, + code, + chat_id, + ) + active_reply_to = None + response = await self._send_raw_message( + chat_id=chat_id, + msg_type=msg_type, + payload=payload, + reply_to=None, + metadata=metadata, + ) + return response + except Exception as exc: + last_error = exc + if msg_type == "post" and _POST_CONTENT_INVALID_RE.search(str(exc)): + raise + if attempt >= _FEISHU_SEND_ATTEMPTS - 1: + raise + wait_seconds = 2 ** attempt + logger.warning( + "[Feishu] Send attempt %d/%d failed for chat %s; retrying in %ds: %s", + attempt + 1, + _FEISHU_SEND_ATTEMPTS, + chat_id, + wait_seconds, + exc, + ) + await asyncio.sleep(wait_seconds) + raise last_error or RuntimeError("Feishu send failed") + + async def _release_app_lock(self) -> None: + if not self._app_lock_identity: + return + try: + release_scoped_lock(_FEISHU_APP_LOCK_SCOPE, self._app_lock_identity) + except Exception as exc: + logger.warning("[Feishu] Failed to release app lock: %s", exc, exc_info=True) + finally: + self._app_lock_identity = None + + # ========================================================================= + # Lark API request builders + # ========================================================================= + + @staticmethod + def _build_get_chat_request(chat_id: str) -> Any: + if "GetChatRequest" in globals(): + return GetChatRequest.builder().chat_id(chat_id).build() + return SimpleNamespace(chat_id=chat_id) + + @staticmethod + def _build_get_message_request(message_id: str) -> Any: + if "GetMessageRequest" in globals(): + return GetMessageRequest.builder().message_id(message_id).build() + return SimpleNamespace(message_id=message_id) + + @staticmethod + def _build_message_resource_request(*, message_id: str, file_key: str, resource_type: str) -> Any: + if "GetMessageResourceRequest" in globals(): + return ( + GetMessageResourceRequest.builder() + .message_id(message_id) + .file_key(file_key) + .type(resource_type) + .build() + ) + return SimpleNamespace(message_id=message_id, file_key=file_key, type=resource_type) + + @staticmethod + def _build_get_application_request(*, app_id: str, lang: str) -> Any: + if "GetApplicationRequest" in globals(): + return ( + GetApplicationRequest.builder() + .app_id(app_id) + .lang(lang) + .build() + ) + return SimpleNamespace(app_id=app_id, lang=lang) + + @staticmethod + def _build_reply_message_body(*, content: str, msg_type: str, reply_in_thread: bool, uuid_value: str) -> Any: + if "ReplyMessageRequestBody" in globals(): + return ( + ReplyMessageRequestBody.builder() + .content(content) + .msg_type(msg_type) + .reply_in_thread(reply_in_thread) + .uuid(uuid_value) + .build() + ) + return SimpleNamespace( + content=content, + msg_type=msg_type, + reply_in_thread=reply_in_thread, + uuid=uuid_value, + ) + + @staticmethod + def _build_reply_message_request(message_id: str, request_body: Any) -> Any: + if "ReplyMessageRequest" in globals(): + return ( + ReplyMessageRequest.builder() + .message_id(message_id) + .request_body(request_body) + .build() + ) + return SimpleNamespace(message_id=message_id, request_body=request_body) + + @staticmethod + def _build_update_message_body(*, msg_type: str, content: str) -> Any: + if "UpdateMessageRequestBody" in globals(): + return ( + UpdateMessageRequestBody.builder() + .msg_type(msg_type) + .content(content) + .build() + ) + return SimpleNamespace(msg_type=msg_type, content=content) + + @staticmethod + def _build_update_message_request(message_id: str, request_body: Any) -> Any: + if "UpdateMessageRequest" in globals(): + return ( + UpdateMessageRequest.builder() + .message_id(message_id) + .request_body(request_body) + .build() + ) + return SimpleNamespace(message_id=message_id, request_body=request_body) + + @staticmethod + def _build_create_message_body(*, receive_id: str, msg_type: str, content: str, uuid_value: str) -> Any: + if "CreateMessageRequestBody" in globals(): + return ( + CreateMessageRequestBody.builder() + .receive_id(receive_id) + .msg_type(msg_type) + .content(content) + .uuid(uuid_value) + .build() + ) + return SimpleNamespace( + receive_id=receive_id, + msg_type=msg_type, + content=content, + uuid=uuid_value, + ) + + @staticmethod + def _build_create_message_request(receive_id_type: str, request_body: Any) -> Any: + if "CreateMessageRequest" in globals(): + return ( + CreateMessageRequest.builder() + .receive_id_type(receive_id_type) + .request_body(request_body) + .build() + ) + return SimpleNamespace(receive_id_type=receive_id_type, request_body=request_body) + + @staticmethod + def _build_image_upload_body(*, image_type: str, image: Any) -> Any: + if "CreateImageRequestBody" in globals(): + return ( + CreateImageRequestBody.builder() + .image_type(image_type) + .image(image) + .build() + ) + return SimpleNamespace(image_type=image_type, image=image) + + @staticmethod + def _build_image_upload_request(request_body: Any) -> Any: + if "CreateImageRequest" in globals(): + return CreateImageRequest.builder().request_body(request_body).build() + return SimpleNamespace(request_body=request_body) + + @staticmethod + def _build_file_upload_body(*, file_type: str, file_name: str, file: Any) -> Any: + if "CreateFileRequestBody" in globals(): + return ( + CreateFileRequestBody.builder() + .file_type(file_type) + .file_name(file_name) + .file(file) + .build() + ) + return SimpleNamespace(file_type=file_type, file_name=file_name, file=file) + + @staticmethod + def _build_file_upload_request(request_body: Any) -> Any: + if "CreateFileRequest" in globals(): + return CreateFileRequest.builder().request_body(request_body).build() + return SimpleNamespace(request_body=request_body) + + def _build_post_payload(self, content: str) -> str: + return _build_markdown_post_payload(content) + + def _build_media_post_payload(self, *, caption: str, media_tag: Dict[str, str]) -> str: + payload = json.loads(self._build_post_payload(caption)) + content = payload.setdefault("zh_cn", {}).setdefault("content", []) + content.append([media_tag]) + return json.dumps(payload, ensure_ascii=False) + + @staticmethod + def _resolve_outbound_file_routing( + *, + file_path: str, + requested_message_type: str, + ) -> tuple[str, str]: + ext = Path(file_path).suffix.lower() + + if ext in _FEISHU_OPUS_UPLOAD_EXTENSIONS: + return "opus", "audio" + + if ext in _FEISHU_MEDIA_UPLOAD_EXTENSIONS: + return "mp4", "media" + + if ext in _FEISHU_DOC_UPLOAD_TYPES: + return _FEISHU_DOC_UPLOAD_TYPES[ext], "file" + + if requested_message_type == "file": + return _FEISHU_FILE_UPLOAD_TYPE, "file" + + return _FEISHU_FILE_UPLOAD_TYPE, "file" diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index d353b8294..3d6a90502 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -40,7 +40,9 @@ logger = logging.getLogger(__name__) MAX_MESSAGE_LENGTH = 4000 # Store directory for E2EE keys and sync state. -_STORE_DIR = Path.home() / ".hermes" / "matrix" / "store" +# Uses get_hermes_home() so each profile gets its own Matrix store. +from hermes_constants import get_hermes_dir as _get_hermes_dir +_STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store") # Grace period: ignore messages older than this many seconds before startup. _STARTUP_GRACE_SECONDS = 5 @@ -161,22 +163,49 @@ class MatrixAdapter(BasePlatformAdapter): # Authenticate. if self._access_token: client.access_token = self._access_token - # Resolve user_id if not set. - if not self._user_id: - resp = await client.whoami() - if isinstance(resp, nio.WhoamiResponse): - self._user_id = resp.user_id - client.user_id = resp.user_id - logger.info("Matrix: authenticated as %s", self._user_id) - else: - logger.error( - "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER" + + # With access-token auth, always resolve whoami so we validate the + # token and learn the device_id. The device_id matters for E2EE: + # without it, matrix-nio can send plain messages but may fail to + # decrypt inbound encrypted events or encrypt outbound room sends. + resp = await client.whoami() + if isinstance(resp, nio.WhoamiResponse): + resolved_user_id = getattr(resp, "user_id", "") or self._user_id + resolved_device_id = getattr(resp, "device_id", "") + if resolved_user_id: + self._user_id = resolved_user_id + + # restore_login() is the matrix-nio path that binds the access + # token to a specific device and loads the crypto store. + if resolved_device_id and hasattr(client, "restore_login"): + client.restore_login( + self._user_id or resolved_user_id, + resolved_device_id, + self._access_token, ) - await client.close() - return False + else: + if self._user_id: + client.user_id = self._user_id + if resolved_device_id: + client.device_id = resolved_device_id + client.access_token = self._access_token + if self._encryption: + logger.warning( + "Matrix: access-token login did not restore E2EE state; " + "encrypted rooms may fail until a device_id is available" + ) + + logger.info( + "Matrix: using access token for %s%s", + self._user_id or "(unknown user)", + f" (device {resolved_device_id})" if resolved_device_id else "", + ) else: - client.user_id = self._user_id - logger.info("Matrix: using access token for %s", self._user_id) + logger.error( + "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER" + ) + await client.close() + return False elif self._password and self._user_id: resp = await client.login( self._password, @@ -194,13 +223,18 @@ class MatrixAdapter(BasePlatformAdapter): return False # If E2EE is enabled, load the crypto store. - if self._encryption and hasattr(client, "olm"): + if self._encryption and getattr(client, "olm", None): try: if client.should_upload_keys: await client.keys_upload() logger.info("Matrix: E2EE crypto initialized") except Exception as exc: logger.warning("Matrix: crypto init issue: %s", exc) + elif self._encryption: + logger.warning( + "Matrix: E2EE requested but crypto store is not loaded; " + "encrypted rooms may fail" + ) # Register event callbacks. client.add_event_callback(self._on_room_message, nio.RoomMessageText) @@ -230,6 +264,7 @@ class MatrixAdapter(BasePlatformAdapter): ) # Build DM room cache from m.direct account data. await self._refresh_dm_cache() + await self._run_e2ee_maintenance() else: logger.warning("Matrix: initial sync returned %s", type(resp).__name__) @@ -301,13 +336,48 @@ class MatrixAdapter(BasePlatformAdapter): relates_to["m.in_reply_to"] = {"event_id": reply_to} msg_content["m.relates_to"] = relates_to - resp = await self._client.room_send( - chat_id, - "m.room.message", - msg_content, - ) + async def _room_send_once(*, ignore_unverified_devices: bool = False): + return await asyncio.wait_for( + self._client.room_send( + chat_id, + "m.room.message", + msg_content, + ignore_unverified_devices=ignore_unverified_devices, + ), + timeout=45, + ) + + try: + resp = await _room_send_once(ignore_unverified_devices=False) + except Exception as exc: + retryable = isinstance(exc, asyncio.TimeoutError) + olm_unverified = getattr(nio, "OlmUnverifiedDeviceError", None) + send_retry = getattr(nio, "SendRetryError", None) + if isinstance(olm_unverified, type) and isinstance(exc, olm_unverified): + retryable = True + if isinstance(send_retry, type) and isinstance(exc, send_retry): + retryable = True + + if not retryable: + logger.error("Matrix: failed to send to %s: %s", chat_id, exc) + return SendResult(success=False, error=str(exc)) + + logger.warning( + "Matrix: initial encrypted send to %s failed (%s); " + "retrying after E2EE maintenance with ignored unverified devices", + chat_id, + exc, + ) + await self._run_e2ee_maintenance() + try: + resp = await _room_send_once(ignore_unverified_devices=True) + except Exception as retry_exc: + logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc) + return SendResult(success=False, error=str(retry_exc)) + if isinstance(resp, nio.RoomSendResponse): last_event_id = resp.event_id + logger.info("Matrix: sent event %s to %s", last_event_id, chat_id) else: err = getattr(resp, "message", str(resp)) logger.error("Matrix: failed to send to %s: %s", chat_id, err) @@ -551,9 +621,23 @@ class MatrixAdapter(BasePlatformAdapter): async def _sync_loop(self) -> None: """Continuously sync with the homeserver.""" + import nio + while not self._closing: try: - await self._client.sync(timeout=30000) + resp = await self._client.sync(timeout=30000) + if isinstance(resp, nio.SyncError): + if self._closing: + return + logger.warning( + "Matrix: sync returned %s: %s — retrying in 5s", + type(resp).__name__, + getattr(resp, "message", resp), + ) + await asyncio.sleep(5) + continue + + await self._run_e2ee_maintenance() except asyncio.CancelledError: return except Exception as exc: @@ -562,6 +646,38 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: sync error: %s — retrying in 5s", exc) await asyncio.sleep(5) + async def _run_e2ee_maintenance(self) -> None: + """Run matrix-nio E2EE housekeeping between syncs. + + Hermes uses a custom sync loop instead of matrix-nio's sync_forever(), + so we need to explicitly drive the key management work that sync_forever() + normally handles for encrypted rooms. + """ + client = self._client + if not client or not self._encryption or not getattr(client, "olm", None): + return + + tasks = [asyncio.create_task(client.send_to_device_messages())] + + if client.should_upload_keys: + tasks.append(asyncio.create_task(client.keys_upload())) + + if client.should_query_keys: + tasks.append(asyncio.create_task(client.keys_query())) + + if client.should_claim_keys: + users = client.get_users_for_key_claiming() + if users: + tasks.append(asyncio.create_task(client.keys_claim(users))) + + for task in asyncio.as_completed(tasks): + try: + await task + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Matrix: E2EE maintenance task failed: %s", exc) + # ------------------------------------------------------------------ # Event callbacks # ------------------------------------------------------------------ diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py index 0f66577ff..c134bb35d 100644 --- a/gateway/platforms/mattermost.py +++ b/gateway/platforms/mattermost.py @@ -407,18 +407,38 @@ class MattermostAdapter(BasePlatformAdapter): kind: str = "file", ) -> SendResult: """Download a URL and upload it as a file attachment.""" + import asyncio import aiohttp - try: - async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: - if resp.status >= 400: - # Fall back to sending the URL as text. - return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) - file_data = await resp.read() - ct = resp.content_type or "application/octet-stream" - # Derive filename from URL. - fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png" - except Exception as exc: - logger.warning("Mattermost: failed to download %s: %s", url, exc) + + last_exc = None + file_data = None + ct = "application/octet-stream" + fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png" + + for attempt in range(3): + try: + async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + if resp.status >= 500 or resp.status == 429: + if attempt < 2: + logger.debug("Mattermost download retry %d/2 for %s (status %d)", + attempt + 1, url[:80], resp.status) + await asyncio.sleep(1.5 * (attempt + 1)) + continue + if resp.status >= 400: + return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) + file_data = await resp.read() + ct = resp.content_type or "application/octet-stream" + break + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + last_exc = exc + if attempt < 2: + await asyncio.sleep(1.5 * (attempt + 1)) + continue + logger.warning("Mattermost: failed to download %s after %d attempts: %s", url, attempt + 1, exc) + return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) + + if file_data is None: + logger.warning("Mattermost: download returned no data for %s", url) return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) file_id = await self._upload_file(chat_id, file_data, fname, ct) @@ -583,9 +603,19 @@ class MattermostAdapter(BasePlatformAdapter): # For DMs, user_id is sufficient. For channels, check for @mention. message_text = post.get("message", "") - # Mention-only mode: skip channel messages that don't @mention the bot. - # DMs (type "D") are always processed. + # Mention-gating for non-DM channels. + # Config (env vars): + # MATTERMOST_REQUIRE_MENTION: Require @mention in channels (default: true) + # MATTERMOST_FREE_RESPONSE_CHANNELS: Channel IDs where bot responds without mention if channel_type_raw != "D": + require_mention = os.getenv( + "MATTERMOST_REQUIRE_MENTION", "true" + ).lower() not in ("false", "0", "no") + + free_channels_raw = os.getenv("MATTERMOST_FREE_RESPONSE_CHANNELS", "") + free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()} + is_free_channel = channel_id in free_channels + mention_patterns = [ f"@{self._bot_username}", f"@{self._bot_user_id}", @@ -594,13 +624,21 @@ class MattermostAdapter(BasePlatformAdapter): pattern.lower() in message_text.lower() for pattern in mention_patterns ) - if not has_mention: + + if require_mention and not is_free_channel and not has_mention: logger.debug( "Mattermost: skipping non-DM message without @mention (channel=%s)", channel_id, ) return + # Strip @mention from the message text so the agent sees clean input. + if has_mention: + for pattern in mention_patterns: + message_text = re.sub( + re.escape(pattern), "", message_text, flags=re.IGNORECASE + ).strip() + # Resolve sender info. sender_id = post.get("user_id", "") sender_name = data.get("sender_name", "").lstrip("@") or sender_id diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 39c3814fb..1629e0863 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -22,7 +22,7 @@ import time from datetime import datetime, timezone from pathlib import Path from typing import Dict, List, Optional, Any -from urllib.parse import unquote +from urllib.parse import quote, unquote import httpx @@ -184,6 +184,8 @@ class SignalAdapter(BasePlatformAdapter): self._recent_sent_timestamps: set = set() self._max_recent_timestamps = 50 + self._phone_lock_identity: Optional[str] = None + logger.info("Signal adapter initialized: url=%s account=%s groups=%s", self.http_url, _redact_phone(self.account), "enabled" if self.group_allow_from else "disabled") @@ -198,6 +200,29 @@ class SignalAdapter(BasePlatformAdapter): logger.error("Signal: SIGNAL_HTTP_URL and SIGNAL_ACCOUNT are required") return False + # Acquire scoped lock to prevent duplicate Signal listeners for the same phone + try: + from gateway.status import acquire_scoped_lock + + self._phone_lock_identity = self.account + acquired, existing = acquire_scoped_lock( + "signal-phone", + self._phone_lock_identity, + metadata={"platform": self.platform.value}, + ) + if not acquired: + owner_pid = existing.get("pid") if isinstance(existing, dict) else None + message = ( + "Another local Hermes gateway is already using this Signal account" + + (f" (PID {owner_pid})." if owner_pid else ".") + + " Stop the other gateway before starting a second Signal listener." + ) + logger.error("Signal: %s", message) + self._set_fatal_error("signal_phone_lock", message, retryable=False) + return False + except Exception as e: + logger.warning("Signal: Could not acquire phone lock (non-fatal): %s", e) + self.client = httpx.AsyncClient(timeout=30.0) # Health check — verify signal-cli daemon is reachable @@ -245,6 +270,14 @@ class SignalAdapter(BasePlatformAdapter): await self.client.aclose() self.client = None + if self._phone_lock_identity: + try: + from gateway.status import release_scoped_lock + release_scoped_lock("signal-phone", self._phone_lock_identity) + except Exception as e: + logger.warning("Signal: Error releasing phone lock: %s", e, exc_info=True) + self._phone_lock_identity = None + logger.info("Signal: disconnected") # ------------------------------------------------------------------ @@ -253,7 +286,7 @@ class SignalAdapter(BasePlatformAdapter): async def _sse_listener(self) -> None: """Listen for SSE events from signal-cli daemon.""" - url = f"{self.http_url}/api/v1/events?account={self.account}" + url = f"{self.http_url}/api/v1/events?account={quote(self.account, safe='')}" backoff = SSE_RETRY_DELAY_INITIAL while self._running: @@ -279,6 +312,12 @@ class SignalAdapter(BasePlatformAdapter): line = line.strip() if not line: continue + # SSE keepalive comments (":") prove the connection + # is alive — update activity so the health monitor + # doesn't report false idle warnings. + if line.startswith(":"): + self._last_sse_activity = time.time() + continue # Parse SSE data lines if line.startswith("data:"): data_str = line[5:].strip() @@ -515,7 +554,7 @@ class SignalAdapter(BasePlatformAdapter): """Fetch an attachment via JSON-RPC and cache it. Returns (path, ext).""" result = await self._rpc("getAttachment", { "account": self.account, - "attachmentId": attachment_id, + "id": attachment_id, }) if not result: diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index e8163e26e..2a7e046f8 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -93,6 +93,17 @@ class SlackAdapter(BasePlatformAdapter): return False try: + # Acquire scoped lock to prevent duplicate app token usage + from gateway.status import acquire_scoped_lock + self._token_lock_identity = app_token + acquired, existing = acquire_scoped_lock('slack-app-token', app_token, metadata={'platform': 'slack'}) + if not acquired: + owner_pid = existing.get('pid') if isinstance(existing, dict) else None + message = f'Slack app token already in use' + (f' (PID {owner_pid})' if owner_pid else '') + '. Stop the other gateway first.' + logger.error('[%s] %s', self.name, message) + self._set_fatal_error('slack_token_lock', message, retryable=False) + return False + self._app = AsyncApp(token=bot_token) # Get our own bot user ID for mention detection @@ -138,6 +149,16 @@ class SlackAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.warning("[Slack] Error while closing Socket Mode handler: %s", e, exc_info=True) self._running = False + + # Release the token lock (use stored identity, not re-read env) + try: + from gateway.status import release_scoped_lock + if getattr(self, '_token_lock_identity', None): + release_scoped_lock('slack-app-token', self._token_lock_identity) + self._token_lock_identity = None + except Exception: + pass + logger.info("[Slack] Disconnected") async def send( @@ -819,33 +840,65 @@ class SlackAdapter(BasePlatformAdapter): await self.handle_message(event) async def _download_slack_file(self, url: str, ext: str, audio: bool = False) -> str: - """Download a Slack file using the bot token for auth.""" + """Download a Slack file using the bot token for auth, with retry.""" + import asyncio import httpx bot_token = self.config.token - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: - response = await client.get( - url, - headers={"Authorization": f"Bearer {bot_token}"}, - ) - response.raise_for_status() + last_exc = None - if audio: - from gateway.platforms.base import cache_audio_from_bytes - return cache_audio_from_bytes(response.content, ext) - else: - from gateway.platforms.base import cache_image_from_bytes - return cache_image_from_bytes(response.content, ext) + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + for attempt in range(3): + try: + response = await client.get( + url, + headers={"Authorization": f"Bearer {bot_token}"}, + ) + response.raise_for_status() + + if audio: + from gateway.platforms.base import cache_audio_from_bytes + return cache_audio_from_bytes(response.content, ext) + else: + from gateway.platforms.base import cache_image_from_bytes + return cache_image_from_bytes(response.content, ext) + except (httpx.TimeoutException, httpx.HTTPStatusError) as exc: + last_exc = exc + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429: + raise + if attempt < 2: + logger.debug("Slack file download retry %d/2 for %s: %s", + attempt + 1, url[:80], exc) + await asyncio.sleep(1.5 * (attempt + 1)) + continue + raise + raise last_exc async def _download_slack_file_bytes(self, url: str) -> bytes: - """Download a Slack file and return raw bytes.""" + """Download a Slack file and return raw bytes, with retry.""" + import asyncio import httpx bot_token = self.config.token + last_exc = None + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: - response = await client.get( - url, - headers={"Authorization": f"Bearer {bot_token}"}, - ) - response.raise_for_status() - return response.content + for attempt in range(3): + try: + response = await client.get( + url, + headers={"Authorization": f"Bearer {bot_token}"}, + ) + response.raise_for_status() + return response.content + except (httpx.TimeoutException, httpx.HTTPStatusError) as exc: + last_exc = exc + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429: + raise + if attempt < 2: + logger.debug("Slack file download retry %d/2 for %s: %s", + attempt + 1, url[:80], exc) + await asyncio.sleep(1.5 * (attempt + 1)) + continue + raise + raise last_exc diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 549d09ce3..5f497221f 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -11,7 +11,7 @@ import asyncio import logging import os import re -from typing import Dict, Optional, Any +from typing import Dict, List, Optional, Any logger = logging.getLogger(__name__) @@ -25,6 +25,7 @@ try: filters, ) from telegram.constants import ParseMode, ChatType + from telegram.request import HTTPXRequest TELEGRAM_AVAILABLE = True except ImportError: TELEGRAM_AVAILABLE = False @@ -34,6 +35,7 @@ except ImportError: Application = Any CommandHandler = Any TelegramMessageHandler = Any + HTTPXRequest = Any filters = None ParseMode = None ChatType = None @@ -59,6 +61,11 @@ from gateway.platforms.base import ( cache_document_from_bytes, SUPPORTED_DOCUMENT_TYPES, ) +from gateway.platforms.telegram_network import ( + TelegramFallbackTransport, + discover_fallback_ips, + parse_fallback_ip_env, +) def check_telegram_requirements() -> bool: @@ -138,6 +145,13 @@ class TelegramAdapter(BasePlatformAdapter): # DM Topics config from extra.dm_topics self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", []) + def _fallback_ips(self) -> list[str]: + """Return validated fallback IPs from config (populated by _apply_env_overrides).""" + configured = self.config.extra.get("fallback_ips", []) if getattr(self.config, "extra", None) else [] + if isinstance(configured, str): + configured = configured.split(",") + return parse_fallback_ip_env(",".join(str(v) for v in configured) if configured else None) + @staticmethod def _looks_like_polling_conflict(error: Exception) -> bool: text = str(error).lower() @@ -331,7 +345,8 @@ class TelegramAdapter(BasePlatformAdapter): def _persist_dm_topic_thread_id(self, chat_id: int, topic_name: str, thread_id: int) -> None: """Save a newly created thread_id back into config.yaml so it persists across restarts.""" try: - config_path = _Path.home() / ".hermes" / "config.yaml" + from hermes_constants import get_hermes_home + config_path = get_hermes_home() / "config.yaml" if not config_path.exists(): logger.warning("[%s] Config file not found at %s, cannot persist thread_id", self.name, config_path) return @@ -474,7 +489,26 @@ class TelegramAdapter(BasePlatformAdapter): return False # Build the application - self._app = Application.builder().token(self.config.token).build() + builder = Application.builder().token(self.config.token) + fallback_ips = self._fallback_ips() + if not fallback_ips: + fallback_ips = await discover_fallback_ips() + logger.info( + "[%s] Auto-discovered Telegram fallback IPs: %s", + self.name, + ", ".join(fallback_ips), + ) + if fallback_ips: + logger.warning( + "[%s] Telegram fallback IPs active: %s", + self.name, + ", ".join(fallback_ips), + ) + transport = TelegramFallbackTransport(fallback_ips) + request = HTTPXRequest(httpx_kwargs={"transport": transport}) + get_updates_request = HTTPXRequest(httpx_kwargs={"transport": transport}) + builder = builder.request(request).get_updates_request(get_updates_request) + self._app = builder.build() self._bot = self._app.bot # Register handlers @@ -674,9 +708,15 @@ class TelegramAdapter(BasePlatformAdapter): except ImportError: _NetErr = OSError # type: ignore[misc,assignment] + try: + from telegram.error import BadRequest as _BadReq + except ImportError: + _BadReq = None # type: ignore[assignment,misc] + for i, chunk in enumerate(chunks): should_thread = self._should_thread_reply(reply_to, i) reply_to_id = int(reply_to) if should_thread else None + effective_thread_id = int(thread_id) if thread_id else None msg = None for _send_attempt in range(3): @@ -688,7 +728,7 @@ class TelegramAdapter(BasePlatformAdapter): text=chunk, parse_mode=ParseMode.MARKDOWN_V2, reply_to_message_id=reply_to_id, - message_thread_id=int(thread_id) if thread_id else None, + message_thread_id=effective_thread_id, ) except Exception as md_error: # Markdown parsing failed, try plain text @@ -700,12 +740,40 @@ class TelegramAdapter(BasePlatformAdapter): text=plain_chunk, parse_mode=None, reply_to_message_id=reply_to_id, - message_thread_id=int(thread_id) if thread_id else None, + message_thread_id=effective_thread_id, ) else: raise break # success except _NetErr as send_err: + # BadRequest is a subclass of NetworkError in + # python-telegram-bot but represents permanent errors + # (not transient network issues). Detect and handle + # specific cases instead of blindly retrying. + if _BadReq and isinstance(send_err, _BadReq): + err_lower = str(send_err).lower() + if "thread not found" in err_lower and effective_thread_id is not None: + # Thread doesn't exist — retry without + # message_thread_id so the message still + # reaches the chat. + logger.warning( + "[%s] Thread %s not found, retrying without message_thread_id", + self.name, effective_thread_id, + ) + effective_thread_id = None + continue + if "message to be replied not found" in err_lower and reply_to_id is not None: + # Original message was deleted before we + # could reply — clear reply target and retry + # so the response is still delivered. + logger.warning( + "[%s] Reply target deleted, retrying without reply_to: %s", + self.name, send_err, + ) + reply_to_id = None + continue + # Other BadRequest errors are permanent — don't retry + raise if _send_attempt < 2: wait = 2 ** _send_attempt logger.warning("[%s] Network error on send (attempt %d/3), retrying in %ds: %s", @@ -1700,7 +1768,8 @@ class TelegramAdapter(BasePlatformAdapter): recognized without a gateway restart. """ try: - config_path = _Path.home() / ".hermes" / "config.yaml" + from hermes_constants import get_hermes_home + config_path = get_hermes_home() / "config.yaml" if not config_path.exists(): return diff --git a/gateway/platforms/telegram_network.py b/gateway/platforms/telegram_network.py new file mode 100644 index 000000000..93f1f0fb5 --- /dev/null +++ b/gateway/platforms/telegram_network.py @@ -0,0 +1,245 @@ +"""Telegram-specific network helpers. + +Provides a hostname-preserving fallback transport for networks where +api.telegram.org resolves to an endpoint that is unreachable from the current +host. The transport keeps the logical request host and TLS SNI as +api.telegram.org while retrying the TCP connection against one or more fallback +IPv4 addresses. +""" + +from __future__ import annotations + +import asyncio +import ipaddress +import logging +import os +import socket +from typing import Iterable, Optional + +import httpx + +logger = logging.getLogger(__name__) + +_TELEGRAM_API_HOST = "api.telegram.org" + +# DNS-over-HTTPS providers used to discover Telegram API IPs that may differ +# from the (potentially unreachable) IP returned by the local system resolver. +_DOH_TIMEOUT = 4.0 # seconds — bounded so connect() isn't noticeably delayed + +_DOH_PROVIDERS: list[dict] = [ + { + "url": "https://dns.google/resolve", + "params": {"name": _TELEGRAM_API_HOST, "type": "A"}, + "headers": {}, + }, + { + "url": "https://cloudflare-dns.com/dns-query", + "params": {"name": _TELEGRAM_API_HOST, "type": "A"}, + "headers": {"Accept": "application/dns-json"}, + }, +] + +# Last-resort IPs when DoH is also blocked. These are stable Telegram Bot API +# endpoints in the 149.154.160.0/20 block (same seed used by OpenClaw). +_SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"] + + +def _resolve_proxy_url() -> str | None: + for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy"): + value = (os.environ.get(key) or "").strip() + if value: + return value + return None + + +class TelegramFallbackTransport(httpx.AsyncBaseTransport): + """Retry Telegram Bot API requests via fallback IPs while preserving TLS/SNI. + + Requests continue to target https://api.telegram.org/... logically, but on + connect failures the underlying TCP connection is retried against a known + reachable IP. This is effectively the programmatic equivalent of + ``curl --resolve api.telegram.org:443:``. + """ + + def __init__(self, fallback_ips: Iterable[str], **transport_kwargs): + self._fallback_ips = [ip for ip in dict.fromkeys(_normalize_fallback_ips(fallback_ips))] + proxy_url = _resolve_proxy_url() + if proxy_url and "proxy" not in transport_kwargs: + transport_kwargs["proxy"] = proxy_url + self._primary = httpx.AsyncHTTPTransport(**transport_kwargs) + self._fallbacks = { + ip: httpx.AsyncHTTPTransport(**transport_kwargs) for ip in self._fallback_ips + } + self._sticky_ip: Optional[str] = None + self._sticky_lock = asyncio.Lock() + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + if request.url.host != _TELEGRAM_API_HOST or not self._fallback_ips: + return await self._primary.handle_async_request(request) + + sticky_ip = self._sticky_ip + attempt_order: list[Optional[str]] = [sticky_ip] if sticky_ip else [None] + for ip in self._fallback_ips: + if ip != sticky_ip: + attempt_order.append(ip) + + last_error: Exception | None = None + for ip in attempt_order: + candidate = request if ip is None else _rewrite_request_for_ip(request, ip) + transport = self._primary if ip is None else self._fallbacks[ip] + try: + response = await transport.handle_async_request(candidate) + if ip is not None and self._sticky_ip != ip: + async with self._sticky_lock: + if self._sticky_ip != ip: + self._sticky_ip = ip + logger.warning( + "[Telegram] Primary api.telegram.org path unreachable; using sticky fallback IP %s", + ip, + ) + return response + except Exception as exc: + last_error = exc + if not _is_retryable_connect_error(exc): + raise + if ip is None: + logger.warning( + "[Telegram] Primary api.telegram.org connection failed (%s); trying fallback IPs %s", + exc, + ", ".join(self._fallback_ips), + ) + continue + logger.warning("[Telegram] Fallback IP %s failed: %s", ip, exc) + continue + + assert last_error is not None + raise last_error + + async def aclose(self) -> None: + await self._primary.aclose() + for transport in self._fallbacks.values(): + await transport.aclose() + + +def _normalize_fallback_ips(values: Iterable[str]) -> list[str]: + normalized: list[str] = [] + for value in values: + raw = str(value).strip() + if not raw: + continue + try: + addr = ipaddress.ip_address(raw) + except ValueError: + logger.warning("Ignoring invalid Telegram fallback IP: %r", raw) + continue + if addr.version != 4: + logger.warning("Ignoring non-IPv4 Telegram fallback IP: %s", raw) + continue + normalized.append(str(addr)) + return normalized + + +def parse_fallback_ip_env(value: str | None) -> list[str]: + if not value: + return [] + parts = [part.strip() for part in value.split(",")] + return _normalize_fallback_ips(parts) + + +def _resolve_system_dns() -> set[str]: + """Return the IPv4 addresses that the OS resolver gives for api.telegram.org.""" + try: + results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET) + return {addr[4][0] for addr in results} + except Exception: + return set() + + +async def _query_doh_provider( + client: httpx.AsyncClient, provider: dict +) -> list[str]: + """Query one DoH provider and return A-record IPs.""" + try: + resp = await client.get( + provider["url"], params=provider["params"], headers=provider["headers"] + ) + resp.raise_for_status() + data = resp.json() + ips: list[str] = [] + for answer in data.get("Answer", []): + if answer.get("type") != 1: # A record + continue + raw = answer.get("data", "").strip() + try: + ipaddress.ip_address(raw) + ips.append(raw) + except ValueError: + continue + return ips + except Exception as exc: + logger.debug("DoH query to %s failed: %s", provider["url"], exc) + return [] + + +async def discover_fallback_ips() -> list[str]: + """Auto-discover Telegram API IPs via DNS-over-HTTPS. + + Resolves api.telegram.org through Google and Cloudflare DoH, collects all + unique IPs, and excludes the system-DNS-resolved IP (which is presumably + unreachable on this network). Falls back to a hardcoded seed list when DoH + is also unavailable. + """ + async with httpx.AsyncClient(timeout=httpx.Timeout(_DOH_TIMEOUT)) as client: + doh_tasks = [_query_doh_provider(client, p) for p in _DOH_PROVIDERS] + system_dns_task = asyncio.to_thread(_resolve_system_dns) + results = await asyncio.gather(system_dns_task, *doh_tasks, return_exceptions=True) + + # results[0] = system DNS IPs (set), results[1:] = DoH IP lists + system_ips: set[str] = results[0] if isinstance(results[0], set) else set() + + doh_ips: list[str] = [] + for r in results[1:]: + if isinstance(r, list): + doh_ips.extend(r) + + # Deduplicate preserving order, exclude system-DNS IPs + seen: set[str] = set() + candidates: list[str] = [] + for ip in doh_ips: + if ip not in seen and ip not in system_ips: + seen.add(ip) + candidates.append(ip) + + # Validate through existing normalization + validated = _normalize_fallback_ips(candidates) + + if validated: + logger.debug("Discovered Telegram fallback IPs via DoH: %s", ", ".join(validated)) + return validated + + logger.info( + "DoH discovery yielded no new IPs (system DNS: %s); using seed fallback IPs %s", + ", ".join(system_ips) or "unknown", + ", ".join(_SEED_FALLBACK_IPS), + ) + return list(_SEED_FALLBACK_IPS) + + +def _rewrite_request_for_ip(request: httpx.Request, ip: str) -> httpx.Request: + original_host = request.url.host or _TELEGRAM_API_HOST + url = request.url.copy_with(host=ip) + headers = request.headers.copy() + headers["host"] = original_host + extensions = dict(request.extensions) + extensions["sni_hostname"] = original_host + return httpx.Request( + method=request.method, + url=url, + headers=headers, + stream=request.stream, + extensions=extensions, + ) + + +def _is_retryable_connect_error(exc: Exception) -> bool: + return isinstance(exc, (httpx.ConnectTimeout, httpx.ConnectError)) diff --git a/gateway/platforms/webhook.py b/gateway/platforms/webhook.py index 2d75879b5..5f7c78cfa 100644 --- a/gateway/platforms/webhook.py +++ b/gateway/platforms/webhook.py @@ -27,6 +27,7 @@ import hashlib import hmac import json import logging +import os import re import subprocess import time @@ -53,6 +54,7 @@ logger = logging.getLogger(__name__) DEFAULT_HOST = "0.0.0.0" DEFAULT_PORT = 8644 _INSECURE_NO_AUTH = "INSECURE_NO_AUTH" +_DYNAMIC_ROUTES_FILENAME = "webhook_subscriptions.json" def check_webhook_requirements() -> bool: @@ -68,7 +70,10 @@ class WebhookAdapter(BasePlatformAdapter): self._host: str = config.extra.get("host", DEFAULT_HOST) self._port: int = int(config.extra.get("port", DEFAULT_PORT)) self._global_secret: str = config.extra.get("secret", "") - self._routes: Dict[str, dict] = config.extra.get("routes", {}) + self._static_routes: Dict[str, dict] = config.extra.get("routes", {}) + self._dynamic_routes: Dict[str, dict] = {} + self._dynamic_routes_mtime: float = 0.0 + self._routes: Dict[str, dict] = dict(self._static_routes) self._runner = None # Delivery info keyed by session chat_id — consumed by send() @@ -96,6 +101,9 @@ class WebhookAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def connect(self) -> bool: + # Load agent-created subscriptions before validating + self._reload_dynamic_routes() + # Validate routes at startup — secret is required per route for name, route in self._routes.items(): secret = route.get("secret", self._global_secret) @@ -110,6 +118,17 @@ class WebhookAdapter(BasePlatformAdapter): app.router.add_get("/health", self._handle_health) app.router.add_post("/webhooks/{route_name}", self._handle_webhook) + # Port conflict detection — fail fast if port is already in use + import socket as _socket + try: + with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s: + _s.settimeout(1) + _s.connect(('127.0.0.1', self._port)) + logger.error('[webhook] Port %d already in use. Set a different port in config.yaml: platforms.webhook.port', self._port) + return False + except (ConnectionRefusedError, OSError): + pass # port is free + self._runner = web.AppRunner(app) await self._runner.setup() site = web.TCPSite(self._runner, self._host, self._port) @@ -182,8 +201,46 @@ class WebhookAdapter(BasePlatformAdapter): """GET /health — simple health check.""" return web.json_response({"status": "ok", "platform": "webhook"}) + def _reload_dynamic_routes(self) -> None: + """Reload agent-created subscriptions from disk if the file changed.""" + from pathlib import Path as _Path + hermes_home = _Path( + os.getenv("HERMES_HOME", str(_Path.home() / ".hermes")) + ).expanduser() + subs_path = hermes_home / _DYNAMIC_ROUTES_FILENAME + if not subs_path.exists(): + if self._dynamic_routes: + self._dynamic_routes = {} + self._routes = dict(self._static_routes) + logger.debug("[webhook] Dynamic subscriptions file removed, cleared dynamic routes") + return + try: + mtime = subs_path.stat().st_mtime + if mtime <= self._dynamic_routes_mtime: + return # No change + data = json.loads(subs_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + return + # Merge: static routes take precedence over dynamic ones + self._dynamic_routes = { + k: v for k, v in data.items() + if k not in self._static_routes + } + self._routes = {**self._dynamic_routes, **self._static_routes} + self._dynamic_routes_mtime = mtime + logger.info( + "[webhook] Reloaded %d dynamic route(s): %s", + len(self._dynamic_routes), + ", ".join(self._dynamic_routes.keys()) or "(none)", + ) + except Exception as e: + logger.warning("[webhook] Failed to reload dynamic routes: %s", e) + async def _handle_webhook(self, request: "web.Request") -> "web.Response": """POST /webhooks/{route_name} — receive and process a webhook event.""" + # Hot-reload dynamic subscriptions on each request (mtime-gated, cheap) + self._reload_dynamic_routes() + route_name = request.match_info.get("route_name", "") route_config = self._routes.get(route_name) diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py new file mode 100644 index 000000000..d40b651c5 --- /dev/null +++ b/gateway/platforms/wecom.py @@ -0,0 +1,1338 @@ +""" +WeCom (Enterprise WeChat) platform adapter. + +Uses the WeCom AI Bot WebSocket gateway for inbound and outbound messages. +The adapter focuses on the core gateway path: + +- authenticate via ``aibot_subscribe`` +- receive inbound ``aibot_msg_callback`` events +- send outbound markdown messages via ``aibot_send_msg`` +- upload outbound media via ``aibot_upload_media_*`` and send native attachments +- best-effort download of inbound image/file attachments for agent context + +Configuration in config.yaml: + platforms: + wecom: + enabled: true + extra: + bot_id: "your-bot-id" # or WECOM_BOT_ID env var + secret: "your-secret" # or WECOM_SECRET env var + websocket_url: "wss://openws.work.weixin.qq.com" + dm_policy: "open" # open | allowlist | disabled | pairing + allow_from: ["user_id_1"] + group_policy: "open" # open | allowlist | disabled + group_allow_from: ["group_id_1"] + groups: + group_id_1: + allow_from: ["user_id_1"] +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import json +import logging +import mimetypes +import os +import re +import time +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import unquote, urlparse + +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + aiohttp = None # type: ignore[assignment] + +try: + import httpx + HTTPX_AVAILABLE = True +except ImportError: + HTTPX_AVAILABLE = False + httpx = None # type: ignore[assignment] + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + cache_document_from_bytes, + cache_image_from_bytes, +) + +logger = logging.getLogger(__name__) + +DEFAULT_WS_URL = "wss://openws.work.weixin.qq.com" + +APP_CMD_SUBSCRIBE = "aibot_subscribe" +APP_CMD_CALLBACK = "aibot_msg_callback" +APP_CMD_LEGACY_CALLBACK = "aibot_callback" +APP_CMD_EVENT_CALLBACK = "aibot_event_callback" +APP_CMD_SEND = "aibot_send_msg" +APP_CMD_RESPONSE = "aibot_respond_msg" +APP_CMD_PING = "ping" +APP_CMD_UPLOAD_MEDIA_INIT = "aibot_upload_media_init" +APP_CMD_UPLOAD_MEDIA_CHUNK = "aibot_upload_media_chunk" +APP_CMD_UPLOAD_MEDIA_FINISH = "aibot_upload_media_finish" + +CALLBACK_COMMANDS = {APP_CMD_CALLBACK, APP_CMD_LEGACY_CALLBACK} +NON_RESPONSE_COMMANDS = CALLBACK_COMMANDS | {APP_CMD_EVENT_CALLBACK} + +MAX_MESSAGE_LENGTH = 4000 +CONNECT_TIMEOUT_SECONDS = 20.0 +REQUEST_TIMEOUT_SECONDS = 15.0 +HEARTBEAT_INTERVAL_SECONDS = 30.0 +RECONNECT_BACKOFF = [2, 5, 10, 30, 60] + +DEDUP_WINDOW_SECONDS = 300 +DEDUP_MAX_SIZE = 1000 + +IMAGE_MAX_BYTES = 10 * 1024 * 1024 +VIDEO_MAX_BYTES = 10 * 1024 * 1024 +VOICE_MAX_BYTES = 2 * 1024 * 1024 +FILE_MAX_BYTES = 20 * 1024 * 1024 +ABSOLUTE_MAX_BYTES = FILE_MAX_BYTES +UPLOAD_CHUNK_SIZE = 512 * 1024 +MAX_UPLOAD_CHUNKS = 100 +VOICE_SUPPORTED_MIMES = {"audio/amr"} + + +def check_wecom_requirements() -> bool: + """Check if WeCom runtime dependencies are available.""" + return AIOHTTP_AVAILABLE and HTTPX_AVAILABLE + + +def _coerce_list(value: Any) -> List[str]: + """Coerce config values into a trimmed string list.""" + if value is None: + return [] + if isinstance(value, str): + return [item.strip() for item in value.split(",") if item.strip()] + if isinstance(value, (list, tuple, set)): + return [str(item).strip() for item in value if str(item).strip()] + return [str(value).strip()] if str(value).strip() else [] + + +def _normalize_entry(raw: str) -> str: + """Normalize allowlist entries such as ``wecom:user:foo``.""" + value = str(raw).strip() + value = re.sub(r"^wecom:", "", value, flags=re.IGNORECASE) + value = re.sub(r"^(user|group):", "", value, flags=re.IGNORECASE) + return value.strip() + + +def _entry_matches(entries: List[str], target: str) -> bool: + """Case-insensitive allowlist match with ``*`` support.""" + normalized_target = str(target).strip().lower() + for entry in entries: + normalized = _normalize_entry(entry).lower() + if normalized == "*" or normalized == normalized_target: + return True + return False + + +class WeComAdapter(BasePlatformAdapter): + """WeCom AI Bot adapter backed by a persistent WebSocket connection.""" + + MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.WECOM) + + extra = config.extra or {} + self._bot_id = str(extra.get("bot_id") or os.getenv("WECOM_BOT_ID", "")).strip() + self._secret = str(extra.get("secret") or os.getenv("WECOM_SECRET", "")).strip() + self._ws_url = str( + extra.get("websocket_url") + or extra.get("websocketUrl") + or os.getenv("WECOM_WEBSOCKET_URL", DEFAULT_WS_URL) + ).strip() or DEFAULT_WS_URL + + self._dm_policy = str(extra.get("dm_policy") or os.getenv("WECOM_DM_POLICY", "open")).strip().lower() + self._allow_from = _coerce_list(extra.get("allow_from") or extra.get("allowFrom")) + + self._group_policy = str(extra.get("group_policy") or os.getenv("WECOM_GROUP_POLICY", "open")).strip().lower() + self._group_allow_from = _coerce_list(extra.get("group_allow_from") or extra.get("groupAllowFrom")) + self._groups = extra.get("groups") if isinstance(extra.get("groups"), dict) else {} + + self._session: Optional["aiohttp.ClientSession"] = None + self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None + self._http_client: Optional["httpx.AsyncClient"] = None + self._listen_task: Optional[asyncio.Task] = None + self._heartbeat_task: Optional[asyncio.Task] = None + self._pending_responses: Dict[str, asyncio.Future] = {} + self._seen_messages: Dict[str, float] = {} + self._reply_req_ids: Dict[str, str] = {} + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Connect to the WeCom AI Bot gateway.""" + if not AIOHTTP_AVAILABLE: + message = "WeCom startup failed: aiohttp not installed" + self._set_fatal_error("wecom_missing_dependency", message, retryable=True) + logger.warning("[%s] %s. Run: pip install aiohttp", self.name, message) + return False + if not HTTPX_AVAILABLE: + message = "WeCom startup failed: httpx not installed" + self._set_fatal_error("wecom_missing_dependency", message, retryable=True) + logger.warning("[%s] %s. Run: pip install httpx", self.name, message) + return False + if not self._bot_id or not self._secret: + message = "WeCom startup failed: WECOM_BOT_ID and WECOM_SECRET are required" + self._set_fatal_error("wecom_missing_credentials", message, retryable=True) + logger.warning("[%s] %s", self.name, message) + return False + + try: + self._http_client = httpx.AsyncClient(timeout=30.0, follow_redirects=True) + await self._open_connection() + self._mark_connected() + self._listen_task = asyncio.create_task(self._listen_loop()) + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + logger.info("[%s] Connected to %s", self.name, self._ws_url) + return True + except Exception as exc: + message = f"WeCom startup failed: {exc}" + self._set_fatal_error("wecom_connect_error", message, retryable=True) + logger.error("[%s] Failed to connect: %s", self.name, exc, exc_info=True) + await self._cleanup_ws() + if self._http_client: + await self._http_client.aclose() + self._http_client = None + return False + + async def disconnect(self) -> None: + """Disconnect from WeCom.""" + self._running = False + self._mark_disconnected() + + if self._listen_task: + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + pass + self._listen_task = None + + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + self._fail_pending_responses(RuntimeError("WeCom adapter disconnected")) + await self._cleanup_ws() + + if self._http_client: + await self._http_client.aclose() + self._http_client = None + + self._seen_messages.clear() + logger.info("[%s] Disconnected", self.name) + + async def _cleanup_ws(self) -> None: + """Close the live websocket/session, if any.""" + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def _open_connection(self) -> None: + """Open and authenticate a websocket connection.""" + await self._cleanup_ws() + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect( + self._ws_url, + heartbeat=HEARTBEAT_INTERVAL_SECONDS * 2, + timeout=CONNECT_TIMEOUT_SECONDS, + ) + + req_id = self._new_req_id("subscribe") + await self._send_json( + { + "cmd": APP_CMD_SUBSCRIBE, + "headers": {"req_id": req_id}, + "body": {"bot_id": self._bot_id, "secret": self._secret}, + } + ) + + auth_payload = await self._wait_for_handshake(req_id) + errcode = auth_payload.get("errcode", 0) + if errcode not in (0, None): + errmsg = auth_payload.get("errmsg", "authentication failed") + raise RuntimeError(f"{errmsg} (errcode={errcode})") + + async def _wait_for_handshake(self, req_id: str) -> Dict[str, Any]: + """Wait for the subscribe acknowledgement.""" + if not self._ws: + raise RuntimeError("WebSocket not initialized") + + deadline = asyncio.get_running_loop().time() + CONNECT_TIMEOUT_SECONDS + while True: + remaining = deadline - asyncio.get_running_loop().time() + if remaining <= 0: + raise TimeoutError("Timed out waiting for WeCom subscribe acknowledgement") + + msg = await asyncio.wait_for(self._ws.receive(), timeout=remaining) + if msg.type == aiohttp.WSMsgType.TEXT: + payload = self._parse_json(msg.data) + if not payload: + continue + if payload.get("cmd") == APP_CMD_PING: + continue + if self._payload_req_id(payload) == req_id: + return payload + logger.debug("[%s] Ignoring pre-auth payload: %s", self.name, payload.get("cmd")) + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.ERROR): + raise RuntimeError("WeCom websocket closed during authentication") + + async def _listen_loop(self) -> None: + """Read websocket events forever, reconnecting on errors.""" + backoff_idx = 0 + while self._running: + try: + await self._read_events() + backoff_idx = 0 + except asyncio.CancelledError: + return + except Exception as exc: + if not self._running: + return + logger.warning("[%s] WebSocket error: %s", self.name, exc) + self._fail_pending_responses(RuntimeError("WeCom connection interrupted")) + + delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)] + backoff_idx += 1 + await asyncio.sleep(delay) + + try: + await self._open_connection() + backoff_idx = 0 + logger.info("[%s] Reconnected", self.name) + except Exception as reconnect_exc: + logger.warning("[%s] Reconnect failed: %s", self.name, reconnect_exc) + + async def _read_events(self) -> None: + """Read websocket frames until the connection closes.""" + if not self._ws: + raise RuntimeError("WebSocket not connected") + + while self._running and self._ws and not self._ws.closed: + msg = await self._ws.receive() + if msg.type == aiohttp.WSMsgType.TEXT: + payload = self._parse_json(msg.data) + if payload: + await self._dispatch_payload(payload) + elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR): + raise RuntimeError("WeCom websocket closed") + + async def _heartbeat_loop(self) -> None: + """Send lightweight application-level pings.""" + try: + while self._running: + await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS) + if not self._ws or self._ws.closed: + continue + try: + await self._send_json( + { + "cmd": APP_CMD_PING, + "headers": {"req_id": self._new_req_id("ping")}, + "body": {}, + } + ) + except Exception as exc: + logger.debug("[%s] Heartbeat send failed: %s", self.name, exc) + except asyncio.CancelledError: + pass + + async def _dispatch_payload(self, payload: Dict[str, Any]) -> None: + """Route inbound websocket payloads.""" + req_id = self._payload_req_id(payload) + cmd = str(payload.get("cmd") or "") + + if req_id and req_id in self._pending_responses and cmd not in NON_RESPONSE_COMMANDS: + future = self._pending_responses.get(req_id) + if future and not future.done(): + future.set_result(payload) + return + + if cmd in CALLBACK_COMMANDS: + await self._on_message(payload) + return + if cmd in {APP_CMD_PING, APP_CMD_EVENT_CALLBACK}: + return + + logger.debug("[%s] Ignoring websocket payload: %s", self.name, cmd or payload) + + def _fail_pending_responses(self, exc: Exception) -> None: + """Fail all outstanding request futures.""" + for req_id, future in list(self._pending_responses.items()): + if not future.done(): + future.set_exception(exc) + self._pending_responses.pop(req_id, None) + + async def _send_json(self, payload: Dict[str, Any]) -> None: + """Send a raw JSON frame over the active websocket.""" + if not self._ws or self._ws.closed: + raise RuntimeError("WeCom websocket is not connected") + await self._ws.send_json(payload) + + async def _send_request(self, cmd: str, body: Dict[str, Any], timeout: float = REQUEST_TIMEOUT_SECONDS) -> Dict[str, Any]: + """Send a JSON request and await the correlated response.""" + if not self._ws or self._ws.closed: + raise RuntimeError("WeCom websocket is not connected") + + req_id = self._new_req_id(cmd) + future = asyncio.get_running_loop().create_future() + self._pending_responses[req_id] = future + try: + await self._send_json({"cmd": cmd, "headers": {"req_id": req_id}, "body": body}) + response = await asyncio.wait_for(future, timeout=timeout) + return response + finally: + self._pending_responses.pop(req_id, None) + + async def _send_reply_request( + self, + reply_req_id: str, + body: Dict[str, Any], + cmd: str = APP_CMD_RESPONSE, + timeout: float = REQUEST_TIMEOUT_SECONDS, + ) -> Dict[str, Any]: + """Send a reply frame correlated to an inbound callback req_id.""" + if not self._ws or self._ws.closed: + raise RuntimeError("WeCom websocket is not connected") + + normalized_req_id = str(reply_req_id or "").strip() + if not normalized_req_id: + raise ValueError("reply_req_id is required") + + future = asyncio.get_running_loop().create_future() + self._pending_responses[normalized_req_id] = future + try: + await self._send_json( + {"cmd": cmd, "headers": {"req_id": normalized_req_id}, "body": body} + ) + response = await asyncio.wait_for(future, timeout=timeout) + return response + finally: + self._pending_responses.pop(normalized_req_id, None) + + @staticmethod + def _new_req_id(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex}" + + @staticmethod + def _payload_req_id(payload: Dict[str, Any]) -> str: + headers = payload.get("headers") + if isinstance(headers, dict): + return str(headers.get("req_id") or "") + return "" + + @staticmethod + def _parse_json(raw: Any) -> Optional[Dict[str, Any]]: + try: + payload = json.loads(raw) + except Exception: + logger.debug("Failed to parse WeCom payload: %r", raw) + return None + return payload if isinstance(payload, dict) else None + + # ------------------------------------------------------------------ + # Inbound message parsing + # ------------------------------------------------------------------ + + async def _on_message(self, payload: Dict[str, Any]) -> None: + """Process an inbound WeCom message callback event.""" + body = payload.get("body") + if not isinstance(body, dict): + return + + msg_id = str(body.get("msgid") or self._payload_req_id(payload) or uuid.uuid4().hex) + if self._is_duplicate(msg_id): + logger.debug("[%s] Duplicate message %s ignored", self.name, msg_id) + return + self._remember_reply_req_id(msg_id, self._payload_req_id(payload)) + + sender = body.get("from") if isinstance(body.get("from"), dict) else {} + sender_id = str(sender.get("userid") or "").strip() + chat_id = str(body.get("chatid") or sender_id).strip() + if not chat_id: + logger.debug("[%s] Missing chat id, skipping message", self.name) + return + + is_group = str(body.get("chattype") or "").lower() == "group" + if is_group: + if not self._is_group_allowed(chat_id, sender_id): + logger.debug("[%s] Group %s / sender %s blocked by policy", self.name, chat_id, sender_id) + return + elif not self._is_dm_allowed(sender_id): + logger.debug("[%s] DM sender %s blocked by policy", self.name, sender_id) + return + + text, reply_text = self._extract_text(body) + media_urls, media_types = await self._extract_media(body) + message_type = self._derive_message_type(body, text, media_types) + has_reply_context = bool(reply_text and (text or media_urls)) + + if not text and reply_text and not media_urls: + text = reply_text + + if not text and not media_urls: + logger.debug("[%s] Empty WeCom message skipped", self.name) + return + + source = self.build_source( + chat_id=chat_id, + chat_type="group" if is_group else "dm", + user_id=sender_id or None, + user_name=sender_id or None, + ) + + event = MessageEvent( + text=text, + message_type=message_type, + source=source, + raw_message=payload, + message_id=msg_id, + media_urls=media_urls, + media_types=media_types, + reply_to_message_id=f"quote:{msg_id}" if has_reply_context else None, + reply_to_text=reply_text if has_reply_context else None, + timestamp=datetime.now(tz=timezone.utc), + ) + + await self.handle_message(event) + + @staticmethod + def _extract_text(body: Dict[str, Any]) -> Tuple[str, Optional[str]]: + """Extract plain text and quoted text from a callback payload.""" + text_parts: List[str] = [] + reply_text: Optional[str] = None + msgtype = str(body.get("msgtype") or "").lower() + + if msgtype == "mixed": + mixed = body.get("mixed") if isinstance(body.get("mixed"), dict) else {} + items = mixed.get("msg_item") if isinstance(mixed.get("msg_item"), list) else [] + for item in items: + if not isinstance(item, dict): + continue + if str(item.get("msgtype") or "").lower() == "text": + text_block = item.get("text") if isinstance(item.get("text"), dict) else {} + content = str(text_block.get("content") or "").strip() + if content: + text_parts.append(content) + else: + text_block = body.get("text") if isinstance(body.get("text"), dict) else {} + content = str(text_block.get("content") or "").strip() + if content: + text_parts.append(content) + + if msgtype == "voice": + voice_block = body.get("voice") if isinstance(body.get("voice"), dict) else {} + voice_text = str(voice_block.get("content") or "").strip() + if voice_text: + text_parts.append(voice_text) + + quote = body.get("quote") if isinstance(body.get("quote"), dict) else {} + quote_type = str(quote.get("msgtype") or "").lower() + if quote_type == "text": + quote_text = quote.get("text") if isinstance(quote.get("text"), dict) else {} + reply_text = str(quote_text.get("content") or "").strip() or None + elif quote_type == "voice": + quote_voice = quote.get("voice") if isinstance(quote.get("voice"), dict) else {} + reply_text = str(quote_voice.get("content") or "").strip() or None + + return "\n".join(part for part in text_parts if part).strip(), reply_text + + async def _extract_media(self, body: Dict[str, Any]) -> Tuple[List[str], List[str]]: + """Best-effort extraction of inbound media to local cache paths.""" + media_paths: List[str] = [] + media_types: List[str] = [] + refs: List[Tuple[str, Dict[str, Any]]] = [] + msgtype = str(body.get("msgtype") or "").lower() + + if msgtype == "mixed": + mixed = body.get("mixed") if isinstance(body.get("mixed"), dict) else {} + items = mixed.get("msg_item") if isinstance(mixed.get("msg_item"), list) else [] + for item in items: + if not isinstance(item, dict): + continue + item_type = str(item.get("msgtype") or "").lower() + if item_type == "image" and isinstance(item.get("image"), dict): + refs.append(("image", item["image"])) + else: + if isinstance(body.get("image"), dict): + refs.append(("image", body["image"])) + if msgtype == "file" and isinstance(body.get("file"), dict): + refs.append(("file", body["file"])) + + quote = body.get("quote") if isinstance(body.get("quote"), dict) else {} + quote_type = str(quote.get("msgtype") or "").lower() + if quote_type == "image" and isinstance(quote.get("image"), dict): + refs.append(("image", quote["image"])) + elif quote_type == "file" and isinstance(quote.get("file"), dict): + refs.append(("file", quote["file"])) + + for kind, ref in refs: + cached = await self._cache_media(kind, ref) + if cached: + path, content_type = cached + media_paths.append(path) + media_types.append(content_type) + + return media_paths, media_types + + async def _cache_media(self, kind: str, media: Dict[str, Any]) -> Optional[Tuple[str, str]]: + """Cache an inbound image/file/media reference to local storage.""" + if "base64" in media and media.get("base64"): + try: + raw = self._decode_base64(media["base64"]) + except Exception as exc: + logger.debug("[%s] Failed to decode %s base64 media: %s", self.name, kind, exc) + return None + + if kind == "image": + ext = self._detect_image_ext(raw) + return cache_image_from_bytes(raw, ext), self._mime_for_ext(ext, fallback="image/jpeg") + + filename = str(media.get("filename") or media.get("name") or "wecom_file") + return cache_document_from_bytes(raw, filename), mimetypes.guess_type(filename)[0] or "application/octet-stream" + + url = str(media.get("url") or "").strip() + if not url: + return None + + try: + raw, headers = await self._download_remote_bytes(url, max_bytes=ABSOLUTE_MAX_BYTES) + except Exception as exc: + logger.debug("[%s] Failed to download %s from %s: %s", self.name, kind, url, exc) + return None + + aes_key = str(media.get("aeskey") or "").strip() + if aes_key: + try: + raw = self._decrypt_file_bytes(raw, aes_key) + except Exception as exc: + logger.debug("[%s] Failed to decrypt %s from %s: %s", self.name, kind, url, exc) + return None + + content_type = str(headers.get("content-type") or "").split(";", 1)[0].strip() or "application/octet-stream" + if kind == "image": + ext = self._guess_extension(url, content_type, fallback=self._detect_image_ext(raw)) + return cache_image_from_bytes(raw, ext), content_type or self._mime_for_ext(ext, fallback="image/jpeg") + + filename = self._guess_filename(url, headers.get("content-disposition"), content_type) + return cache_document_from_bytes(raw, filename), content_type + + @staticmethod + def _decode_base64(data: str) -> bytes: + payload = data.split(",", 1)[-1].strip() + return base64.b64decode(payload) + + @staticmethod + def _detect_image_ext(data: bytes) -> str: + if data.startswith(b"\x89PNG\r\n\x1a\n"): + return ".png" + if data.startswith(b"\xff\xd8\xff"): + return ".jpg" + if data.startswith(b"GIF87a") or data.startswith(b"GIF89a"): + return ".gif" + if data.startswith(b"RIFF") and data[8:12] == b"WEBP": + return ".webp" + return ".jpg" + + @staticmethod + def _mime_for_ext(ext: str, fallback: str = "application/octet-stream") -> str: + return mimetypes.types_map.get(ext.lower(), fallback) + + @staticmethod + def _guess_extension(url: str, content_type: str, fallback: str) -> str: + ext = mimetypes.guess_extension(content_type) if content_type else None + if ext: + return ext + path_ext = Path(urlparse(url).path).suffix + if path_ext: + return path_ext + return fallback + + @staticmethod + def _guess_filename(url: str, content_disposition: Optional[str], content_type: str) -> str: + if content_disposition: + match = re.search(r'filename="?([^";]+)"?', content_disposition) + if match: + return match.group(1) + + name = Path(urlparse(url).path).name or "document" + if "." not in name: + ext = mimetypes.guess_extension(content_type) or ".bin" + name = f"{name}{ext}" + return name + + @staticmethod + def _derive_message_type(body: Dict[str, Any], text: str, media_types: List[str]) -> MessageType: + """Choose the normalized inbound message type.""" + if any(mtype.startswith("application/") or mtype.startswith("text/") for mtype in media_types): + return MessageType.DOCUMENT + if any(mtype.startswith("image/") for mtype in media_types): + return MessageType.TEXT if text else MessageType.PHOTO + if str(body.get("msgtype") or "").lower() == "voice": + return MessageType.VOICE + return MessageType.TEXT + + # ------------------------------------------------------------------ + # Policy helpers + # ------------------------------------------------------------------ + + def _is_dm_allowed(self, sender_id: str) -> bool: + if self._dm_policy == "disabled": + return False + if self._dm_policy == "allowlist": + return _entry_matches(self._allow_from, sender_id) + return True + + def _is_group_allowed(self, chat_id: str, sender_id: str) -> bool: + if self._group_policy == "disabled": + return False + if self._group_policy == "allowlist" and not _entry_matches(self._group_allow_from, chat_id): + return False + + group_cfg = self._resolve_group_cfg(chat_id) + sender_allow = _coerce_list(group_cfg.get("allow_from") or group_cfg.get("allowFrom")) + if sender_allow: + return _entry_matches(sender_allow, sender_id) + return True + + def _resolve_group_cfg(self, chat_id: str) -> Dict[str, Any]: + if not isinstance(self._groups, dict): + return {} + if chat_id in self._groups and isinstance(self._groups[chat_id], dict): + return self._groups[chat_id] + lowered = chat_id.lower() + for key, value in self._groups.items(): + if isinstance(key, str) and key.lower() == lowered and isinstance(value, dict): + return value + wildcard = self._groups.get("*") + return wildcard if isinstance(wildcard, dict) else {} + + def _is_duplicate(self, msg_id: str) -> bool: + now = time.time() + if len(self._seen_messages) > DEDUP_MAX_SIZE: + cutoff = now - DEDUP_WINDOW_SECONDS + self._seen_messages = { + key: ts for key, ts in self._seen_messages.items() if ts > cutoff + } + if self._reply_req_ids: + self._reply_req_ids = { + key: value for key, value in self._reply_req_ids.items() if key in self._seen_messages + } + + if msg_id in self._seen_messages: + return True + + self._seen_messages[msg_id] = now + return False + + def _remember_reply_req_id(self, message_id: str, req_id: str) -> None: + normalized_message_id = str(message_id or "").strip() + normalized_req_id = str(req_id or "").strip() + if not normalized_message_id or not normalized_req_id: + return + self._reply_req_ids[normalized_message_id] = normalized_req_id + while len(self._reply_req_ids) > DEDUP_MAX_SIZE: + self._reply_req_ids.pop(next(iter(self._reply_req_ids))) + + def _reply_req_id_for_message(self, reply_to: Optional[str]) -> Optional[str]: + normalized = str(reply_to or "").strip() + if not normalized or normalized.startswith("quote:"): + return None + return self._reply_req_ids.get(normalized) + + # ------------------------------------------------------------------ + # Outbound messaging + # ------------------------------------------------------------------ + + @staticmethod + def _guess_mime_type(filename: str) -> str: + mime_type = mimetypes.guess_type(filename)[0] + if mime_type: + return mime_type + if Path(filename).suffix.lower() == ".amr": + return "audio/amr" + return "application/octet-stream" + + @staticmethod + def _normalize_content_type(content_type: str, filename: str) -> str: + normalized = str(content_type or "").split(";", 1)[0].strip().lower() + guessed = WeComAdapter._guess_mime_type(filename) + if not normalized: + return guessed + if normalized in {"application/octet-stream", "text/plain"}: + return guessed + return normalized + + @staticmethod + def _detect_wecom_media_type(content_type: str) -> str: + mime_type = str(content_type or "").strip().lower() + if mime_type.startswith("image/"): + return "image" + if mime_type.startswith("video/"): + return "video" + if mime_type.startswith("audio/") or mime_type == "application/ogg": + return "voice" + return "file" + + @staticmethod + def _apply_file_size_limits(file_size: int, detected_type: str, content_type: Optional[str] = None) -> Dict[str, Any]: + file_size_mb = file_size / (1024 * 1024) + normalized_type = str(detected_type or "file").lower() + normalized_content_type = str(content_type or "").strip().lower() + + if file_size > ABSOLUTE_MAX_BYTES: + return { + "final_type": normalized_type, + "rejected": True, + "reject_reason": ( + f"文件大小 {file_size_mb:.2f}MB 超过了企业微信允许的最大限制 20MB,无法发送。" + "请尝试压缩文件或减小文件大小。" + ), + "downgraded": False, + "downgrade_note": None, + } + + if normalized_type == "image" and file_size > IMAGE_MAX_BYTES: + return { + "final_type": "file", + "rejected": False, + "reject_reason": None, + "downgraded": True, + "downgrade_note": f"图片大小 {file_size_mb:.2f}MB 超过 10MB 限制,已转为文件格式发送", + } + + if normalized_type == "video" and file_size > VIDEO_MAX_BYTES: + return { + "final_type": "file", + "rejected": False, + "reject_reason": None, + "downgraded": True, + "downgrade_note": f"视频大小 {file_size_mb:.2f}MB 超过 10MB 限制,已转为文件格式发送", + } + + if normalized_type == "voice": + if normalized_content_type and normalized_content_type not in VOICE_SUPPORTED_MIMES: + return { + "final_type": "file", + "rejected": False, + "reject_reason": None, + "downgraded": True, + "downgrade_note": ( + f"语音格式 {normalized_content_type} 不支持,企微仅支持 AMR 格式,已转为文件格式发送" + ), + } + if file_size > VOICE_MAX_BYTES: + return { + "final_type": "file", + "rejected": False, + "reject_reason": None, + "downgraded": True, + "downgrade_note": f"语音大小 {file_size_mb:.2f}MB 超过 2MB 限制,已转为文件格式发送", + } + + return { + "final_type": normalized_type, + "rejected": False, + "reject_reason": None, + "downgraded": False, + "downgrade_note": None, + } + + @staticmethod + def _response_error(response: Dict[str, Any]) -> Optional[str]: + errcode = response.get("errcode", 0) + if errcode in (0, None): + return None + errmsg = str(response.get("errmsg") or "unknown error") + return f"WeCom errcode {errcode}: {errmsg}" + + @classmethod + def _raise_for_wecom_error(cls, response: Dict[str, Any], operation: str) -> None: + error = cls._response_error(response) + if error: + raise RuntimeError(f"{operation} failed: {error}") + + @staticmethod + def _decrypt_file_bytes(encrypted_data: bytes, aes_key: str) -> bytes: + if not encrypted_data: + raise ValueError("encrypted_data is empty") + if not aes_key: + raise ValueError("aes_key is required") + + key = base64.b64decode(aes_key) + if len(key) != 32: + raise ValueError(f"Invalid WeCom AES key length: expected 32 bytes, got {len(key)}") + + try: + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + except ImportError as exc: # pragma: no cover - dependency is environment-specific + raise RuntimeError("cryptography is required for WeCom media decryption") from exc + + cipher = Cipher(algorithms.AES(key), modes.CBC(key[:16])) + decryptor = cipher.decryptor() + decrypted = decryptor.update(encrypted_data) + decryptor.finalize() + + pad_len = decrypted[-1] + if pad_len < 1 or pad_len > 32 or pad_len > len(decrypted): + raise ValueError(f"Invalid PKCS#7 padding value: {pad_len}") + if any(byte != pad_len for byte in decrypted[-pad_len:]): + raise ValueError("Invalid PKCS#7 padding: padding bytes mismatch") + + return decrypted[:-pad_len] + + async def _download_remote_bytes( + self, + url: str, + max_bytes: int, + ) -> Tuple[bytes, Dict[str, str]]: + if not HTTPX_AVAILABLE: + raise RuntimeError("httpx is required for WeCom media download") + + client = self._http_client or httpx.AsyncClient(timeout=30.0, follow_redirects=True) + created_client = client is not self._http_client + try: + async with client.stream( + "GET", + url, + headers={ + "User-Agent": "HermesAgent/1.0", + "Accept": "*/*", + }, + ) as response: + response.raise_for_status() + headers = {key.lower(): value for key, value in response.headers.items()} + content_length = headers.get("content-length") + if content_length and content_length.isdigit() and int(content_length) > max_bytes: + raise ValueError( + f"Remote media exceeds WeCom limit: {int(content_length)} bytes > {max_bytes} bytes" + ) + + data = bytearray() + async for chunk in response.aiter_bytes(): + data.extend(chunk) + if len(data) > max_bytes: + raise ValueError( + f"Remote media exceeds WeCom limit while downloading: {len(data)} bytes > {max_bytes} bytes" + ) + + return bytes(data), headers + finally: + if created_client: + await client.aclose() + + @staticmethod + def _looks_like_url(media_source: str) -> bool: + parsed = urlparse(str(media_source or "")) + return parsed.scheme in {"http", "https"} + + async def _load_outbound_media( + self, + media_source: str, + file_name: Optional[str] = None, + ) -> Tuple[bytes, str, str]: + source = str(media_source or "").strip() + if not source: + raise ValueError("media source is required") + if re.fullmatch(r"<[^>\n]+>", source): + raise ValueError(f"Media placeholder was not replaced with a real file path: {source}") + + parsed = urlparse(source) + if parsed.scheme in {"http", "https"}: + data, headers = await self._download_remote_bytes(source, max_bytes=ABSOLUTE_MAX_BYTES) + content_disposition = headers.get("content-disposition") + resolved_name = file_name or self._guess_filename(source, content_disposition, headers.get("content-type", "")) + content_type = self._normalize_content_type(headers.get("content-type", ""), resolved_name) + return data, content_type, resolved_name + + if parsed.scheme == "file": + local_path = Path(unquote(parsed.path)).expanduser() + else: + local_path = Path(source).expanduser() + + if not local_path.is_absolute(): + local_path = (Path.cwd() / local_path).resolve() + + if not local_path.exists() or not local_path.is_file(): + raise FileNotFoundError(f"Media file not found: {local_path}") + + data = local_path.read_bytes() + resolved_name = file_name or local_path.name + content_type = self._normalize_content_type("", resolved_name) + return data, content_type, resolved_name + + async def _prepare_outbound_media( + self, + media_source: str, + file_name: Optional[str] = None, + ) -> Dict[str, Any]: + data, content_type, resolved_name = await self._load_outbound_media(media_source, file_name=file_name) + detected_type = self._detect_wecom_media_type(content_type) + size_check = self._apply_file_size_limits(len(data), detected_type, content_type) + return { + "data": data, + "content_type": content_type, + "file_name": resolved_name, + "detected_type": detected_type, + **size_check, + } + + async def _upload_media_bytes(self, data: bytes, media_type: str, filename: str) -> Dict[str, Any]: + if not data: + raise ValueError("Cannot upload empty media") + + total_size = len(data) + total_chunks = (total_size + UPLOAD_CHUNK_SIZE - 1) // UPLOAD_CHUNK_SIZE + if total_chunks > MAX_UPLOAD_CHUNKS: + raise ValueError( + f"File too large: {total_chunks} chunks exceeds maximum of {MAX_UPLOAD_CHUNKS} chunks" + ) + + init_response = await self._send_request( + APP_CMD_UPLOAD_MEDIA_INIT, + { + "type": media_type, + "filename": filename, + "total_size": total_size, + "total_chunks": total_chunks, + "md5": hashlib.md5(data).hexdigest(), + }, + ) + self._raise_for_wecom_error(init_response, "media upload init") + + init_body = init_response.get("body") if isinstance(init_response.get("body"), dict) else {} + upload_id = str(init_body.get("upload_id") or "").strip() + if not upload_id: + raise RuntimeError(f"media upload init failed: missing upload_id in response {init_response}") + + for chunk_index, start in enumerate(range(0, total_size, UPLOAD_CHUNK_SIZE)): + chunk = data[start : start + UPLOAD_CHUNK_SIZE] + chunk_response = await self._send_request( + APP_CMD_UPLOAD_MEDIA_CHUNK, + { + "upload_id": upload_id, + # Match the official SDK implementation, which currently uses 0-based chunk indexes. + "chunk_index": chunk_index, + "base64_data": base64.b64encode(chunk).decode("ascii"), + }, + ) + self._raise_for_wecom_error(chunk_response, f"media upload chunk {chunk_index}") + + finish_response = await self._send_request( + APP_CMD_UPLOAD_MEDIA_FINISH, + {"upload_id": upload_id}, + ) + self._raise_for_wecom_error(finish_response, "media upload finish") + + finish_body = finish_response.get("body") if isinstance(finish_response.get("body"), dict) else {} + media_id = str(finish_body.get("media_id") or "").strip() + if not media_id: + raise RuntimeError(f"media upload finish failed: missing media_id in response {finish_response}") + + return { + "type": str(finish_body.get("type") or media_type), + "media_id": media_id, + "created_at": finish_body.get("created_at"), + } + + async def _send_media_message(self, chat_id: str, media_type: str, media_id: str) -> Dict[str, Any]: + response = await self._send_request( + APP_CMD_SEND, + { + "chatid": chat_id, + "msgtype": media_type, + media_type: {"media_id": media_id}, + }, + ) + self._raise_for_wecom_error(response, "send media message") + return response + + async def _send_reply_stream(self, reply_req_id: str, content: str) -> Dict[str, Any]: + response = await self._send_reply_request( + reply_req_id, + { + "msgtype": "stream", + "stream": { + "id": self._new_req_id("stream"), + "finish": True, + "content": content[:self.MAX_MESSAGE_LENGTH], + }, + }, + ) + self._raise_for_wecom_error(response, "send reply stream") + return response + + async def _send_reply_media_message( + self, + reply_req_id: str, + media_type: str, + media_id: str, + ) -> Dict[str, Any]: + response = await self._send_reply_request( + reply_req_id, + { + "msgtype": media_type, + media_type: {"media_id": media_id}, + }, + ) + self._raise_for_wecom_error(response, "send reply media message") + return response + + async def _send_followup_markdown( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + ) -> Optional[SendResult]: + if not content: + return None + result = await self.send(chat_id=chat_id, content=content, reply_to=reply_to) + if not result.success: + logger.warning("[%s] Follow-up markdown send failed: %s", self.name, result.error) + return result + + async def _send_media_source( + self, + chat_id: str, + media_source: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + ) -> SendResult: + if not chat_id: + return SendResult(success=False, error="chat_id is required") + + try: + prepared = await self._prepare_outbound_media(media_source, file_name=file_name) + except FileNotFoundError as exc: + return SendResult(success=False, error=str(exc)) + except Exception as exc: + logger.error("[%s] Failed to prepare outbound media %s: %s", self.name, media_source, exc) + return SendResult(success=False, error=str(exc)) + + if prepared["rejected"]: + await self._send_followup_markdown( + chat_id, + f"⚠️ {prepared['reject_reason']}", + reply_to=reply_to, + ) + return SendResult(success=False, error=prepared["reject_reason"]) + + reply_req_id = self._reply_req_id_for_message(reply_to) + try: + upload_result = await self._upload_media_bytes( + prepared["data"], + prepared["final_type"], + prepared["file_name"], + ) + if reply_req_id: + media_response = await self._send_reply_media_message( + reply_req_id, + prepared["final_type"], + upload_result["media_id"], + ) + else: + media_response = await self._send_media_message( + chat_id, + prepared["final_type"], + upload_result["media_id"], + ) + except asyncio.TimeoutError: + return SendResult(success=False, error="Timeout sending media to WeCom") + except Exception as exc: + logger.error("[%s] Failed to send media %s: %s", self.name, media_source, exc) + return SendResult(success=False, error=str(exc)) + + caption_result = None + downgrade_result = None + if caption: + caption_result = await self._send_followup_markdown( + chat_id, + caption, + reply_to=reply_to, + ) + if prepared["downgraded"] and prepared["downgrade_note"]: + downgrade_result = await self._send_followup_markdown( + chat_id, + f"ℹ️ {prepared['downgrade_note']}", + reply_to=reply_to, + ) + + return SendResult( + success=True, + message_id=self._payload_req_id(media_response) or uuid.uuid4().hex[:12], + raw_response={ + "upload": upload_result, + "media": media_response, + "caption": caption_result.raw_response if caption_result else None, + "caption_error": caption_result.error if caption_result and not caption_result.success else None, + "downgrade": downgrade_result.raw_response if downgrade_result else None, + "downgrade_error": downgrade_result.error if downgrade_result and not downgrade_result.success else None, + }, + ) + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send markdown to a WeCom chat via proactive ``aibot_send_msg``.""" + del metadata + + if not chat_id: + return SendResult(success=False, error="chat_id is required") + + try: + reply_req_id = self._reply_req_id_for_message(reply_to) + if reply_req_id: + response = await self._send_reply_stream(reply_req_id, content) + else: + response = await self._send_request( + APP_CMD_SEND, + { + "chatid": chat_id, + "msgtype": "markdown", + "markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]}, + }, + ) + except asyncio.TimeoutError: + return SendResult(success=False, error="Timeout sending message to WeCom") + except Exception as exc: + logger.error("[%s] Send failed: %s", self.name, exc) + return SendResult(success=False, error=str(exc)) + + error = self._response_error(response) + if error: + return SendResult(success=False, error=error) + + return SendResult( + success=True, + message_id=self._payload_req_id(response) or uuid.uuid4().hex[:12], + raw_response=response, + ) + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + del metadata + + result = await self._send_media_source( + chat_id=chat_id, + media_source=image_url, + caption=caption, + reply_to=reply_to, + ) + if result.success or not self._looks_like_url(image_url): + return result + + logger.warning("[%s] Falling back to text send for image URL %s: %s", self.name, image_url, result.error) + fallback_text = f"{caption}\n{image_url}" if caption else image_url + return await self.send(chat_id=chat_id, content=fallback_text, reply_to=reply_to) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + del kwargs + return await self._send_media_source( + chat_id=chat_id, + media_source=image_path, + caption=caption, + reply_to=reply_to, + ) + + async def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + del kwargs + return await self._send_media_source( + chat_id=chat_id, + media_source=file_path, + caption=caption, + file_name=file_name, + reply_to=reply_to, + ) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + del kwargs + return await self._send_media_source( + chat_id=chat_id, + media_source=audio_path, + caption=caption, + reply_to=reply_to, + ) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + del kwargs + return await self._send_media_source( + chat_id=chat_id, + media_source=video_path, + caption=caption, + reply_to=reply_to, + ) + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """WeCom does not expose typing indicators in this adapter.""" + del chat_id, metadata + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return minimal chat info.""" + return { + "name": chat_id, + "type": "group" if chat_id and chat_id.lower().startswith("group") else "dm", + } diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index b83657401..02448a6dd 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -26,6 +26,7 @@ from pathlib import Path from typing import Dict, Optional, Any from hermes_cli.config import get_hermes_home +from hermes_constants import get_hermes_dir logger = logging.getLogger(__name__) @@ -134,13 +135,15 @@ class WhatsAppAdapter(BasePlatformAdapter): ) self._session_path: Path = Path(config.extra.get( "session_path", - get_hermes_home() / "whatsapp" / "session" + get_hermes_dir("platforms/whatsapp/session", "whatsapp/session") )) self._reply_prefix: Optional[str] = config.extra.get("reply_prefix") self._message_queue: asyncio.Queue = asyncio.Queue() self._bridge_log_fh = None self._bridge_log: Optional[Path] = None self._poll_task: Optional[asyncio.Task] = None + self._http_session: Optional["aiohttp.ClientSession"] = None + self._session_lock_identity: Optional[str] = None async def connect(self) -> bool: """ @@ -159,6 +162,29 @@ class WhatsAppAdapter(BasePlatformAdapter): logger.info("[%s] Bridge found at %s", self.name, bridge_path) + # Acquire scoped lock to prevent duplicate sessions + try: + from gateway.status import acquire_scoped_lock + + self._session_lock_identity = str(self._session_path) + acquired, existing = acquire_scoped_lock( + "whatsapp-session", + self._session_lock_identity, + metadata={"platform": self.platform.value}, + ) + if not acquired: + owner_pid = existing.get("pid") if isinstance(existing, dict) else None + message = ( + "Another local Hermes gateway is already using this WhatsApp session" + + (f" (PID {owner_pid})." if owner_pid else ".") + + " Stop the other gateway before starting a second WhatsApp bridge." + ) + logger.error("[%s] %s", self.name, message) + self._set_fatal_error("whatsapp_session_lock", message, retryable=False) + return False + except Exception as e: + logger.warning("[%s] Could not acquire session lock (non-fatal): %s", self.name, e) + # Auto-install npm dependencies if node_modules doesn't exist bridge_dir = bridge_path.parent if not (bridge_dir / "node_modules").exists(): @@ -199,6 +225,7 @@ class WhatsAppAdapter(BasePlatformAdapter): print(f"[{self.name}] Using existing bridge (status: {bridge_status})") self._mark_connected() self._bridge_process = None # Not managed by us + self._http_session = aiohttp.ClientSession() self._poll_task = asyncio.create_task(self._poll_messages()) return True else: @@ -304,6 +331,9 @@ class WhatsAppAdapter(BasePlatformAdapter): print(f"[{self.name}] Bridge log: {self._bridge_log}") print(f"[{self.name}] If session expired, re-pair: hermes whatsapp") + # Create a persistent HTTP session for all bridge communication + self._http_session = aiohttp.ClientSession() + # Start message polling task self._poll_task = asyncio.create_task(self._poll_messages()) @@ -312,6 +342,12 @@ class WhatsAppAdapter(BasePlatformAdapter): return True except Exception as e: + if self._session_lock_identity: + try: + from gateway.status import release_scoped_lock + release_scoped_lock("whatsapp-session", self._session_lock_identity) + except Exception: + pass logger.error("[%s] Failed to start bridge: %s", self.name, e, exc_info=True) self._close_bridge_log() return False @@ -369,10 +405,32 @@ class WhatsAppAdapter(BasePlatformAdapter): else: # Bridge was not started by us, don't kill it print(f"[{self.name}] Disconnecting (external bridge left running)") - + + # Cancel the poll task explicitly + if self._poll_task and not self._poll_task.done(): + self._poll_task.cancel() + try: + await self._poll_task + except (asyncio.CancelledError, Exception): + pass + self._poll_task = None + + # Close the persistent HTTP session + if self._http_session and not self._http_session.closed: + await self._http_session.close() + self._http_session = None + + if self._session_lock_identity: + try: + from gateway.status import release_scoped_lock + release_scoped_lock("whatsapp-session", self._session_lock_identity) + except Exception as e: + logger.warning("[%s] Error releasing WhatsApp session lock: %s", self.name, e, exc_info=True) + self._mark_disconnected() self._bridge_process = None self._close_bridge_log() + self._session_lock_identity = None print(f"[{self.name}] Disconnected") async def send( @@ -383,7 +441,7 @@ class WhatsAppAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None ) -> SendResult: """Send a message via the WhatsApp bridge.""" - if not self._running: + if not self._running or not self._http_session: return SendResult(success=False, error="Not connected") bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: @@ -391,36 +449,29 @@ class WhatsAppAdapter(BasePlatformAdapter): try: import aiohttp + + payload = { + "chatId": chat_id, + "message": content, + } + if reply_to: + payload["replyTo"] = reply_to - async with aiohttp.ClientSession() as session: - payload = { - "chatId": chat_id, - "message": content, - } - if reply_to: - payload["replyTo"] = reply_to - - async with session.post( - f"http://127.0.0.1:{self._bridge_port}/send", - json=payload, - timeout=aiohttp.ClientTimeout(total=30) - ) as resp: - if resp.status == 200: - data = await resp.json() - return SendResult( - success=True, - message_id=data.get("messageId"), - raw_response=data - ) - else: - error = await resp.text() - return SendResult(success=False, error=error) - - except ImportError: - return SendResult( - success=False, - error="aiohttp not installed. Run: pip install aiohttp" - ) + async with self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/send", + json=payload, + timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + if resp.status == 200: + data = await resp.json() + return SendResult( + success=True, + message_id=data.get("messageId"), + raw_response=data + ) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -431,28 +482,27 @@ class WhatsAppAdapter(BasePlatformAdapter): content: str, ) -> SendResult: """Edit a previously sent message via the WhatsApp bridge.""" - if not self._running: + if not self._running or not self._http_session: return SendResult(success=False, error="Not connected") bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: return SendResult(success=False, error=bridge_exit) try: import aiohttp - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://127.0.0.1:{self._bridge_port}/edit", - json={ - "chatId": chat_id, - "messageId": message_id, - "message": content, - }, - timeout=aiohttp.ClientTimeout(total=15) - ) as resp: - if resp.status == 200: - return SendResult(success=True, message_id=message_id) - else: - error = await resp.text() - return SendResult(success=False, error=error) + async with self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/edit", + json={ + "chatId": chat_id, + "messageId": message_id, + "message": content, + }, + timeout=aiohttp.ClientTimeout(total=15) + ) as resp: + if resp.status == 200: + return SendResult(success=True, message_id=message_id) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -465,7 +515,7 @@ class WhatsAppAdapter(BasePlatformAdapter): file_name: Optional[str] = None, ) -> SendResult: """Send any media file via bridge /send-media endpoint.""" - if not self._running: + if not self._running or not self._http_session: return SendResult(success=False, error="Not connected") bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: @@ -486,22 +536,21 @@ class WhatsAppAdapter(BasePlatformAdapter): if file_name: payload["fileName"] = file_name - async with aiohttp.ClientSession() as session: - async with session.post( - f"http://127.0.0.1:{self._bridge_port}/send-media", - json=payload, - timeout=aiohttp.ClientTimeout(total=120), - ) as resp: - if resp.status == 200: - data = await resp.json() - return SendResult( - success=True, - message_id=data.get("messageId"), - raw_response=data, - ) - else: - error = await resp.text() - return SendResult(success=False, error=error) + async with self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/send-media", + json=payload, + timeout=aiohttp.ClientTimeout(total=120), + ) as resp: + if resp.status == 200: + data = await resp.json() + return SendResult( + success=True, + message_id=data.get("messageId"), + raw_response=data, + ) + else: + error = await resp.text() + return SendResult(success=False, error=error) except Exception as e: return SendResult(success=False, error=str(e)) @@ -526,6 +575,7 @@ class WhatsAppAdapter(BasePlatformAdapter): image_path: str, caption: Optional[str] = None, reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a local image file natively via bridge.""" return await self._send_media_to_bridge(chat_id, image_path, "image", caption) @@ -536,6 +586,7 @@ class WhatsAppAdapter(BasePlatformAdapter): video_path: str, caption: Optional[str] = None, reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a video natively via bridge — plays inline in WhatsApp.""" return await self._send_media_to_bridge(chat_id, video_path, "video", caption) @@ -547,6 +598,7 @@ class WhatsAppAdapter(BasePlatformAdapter): caption: Optional[str] = None, file_name: Optional[str] = None, reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a document/file as a downloadable attachment via bridge.""" return await self._send_media_to_bridge( @@ -556,45 +608,43 @@ class WhatsAppAdapter(BasePlatformAdapter): async def send_typing(self, chat_id: str, metadata=None) -> None: """Send typing indicator via bridge.""" - if not self._running: + if not self._running or not self._http_session: return if await self._check_managed_bridge_exit(): return try: import aiohttp - - async with aiohttp.ClientSession() as session: - await session.post( - f"http://127.0.0.1:{self._bridge_port}/typing", - json={"chatId": chat_id}, - timeout=aiohttp.ClientTimeout(total=5) - ) + + await self._http_session.post( + f"http://127.0.0.1:{self._bridge_port}/typing", + json={"chatId": chat_id}, + timeout=aiohttp.ClientTimeout(total=5) + ) except Exception: pass # Ignore typing indicator failures async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: """Get information about a WhatsApp chat.""" - if not self._running: + if not self._running or not self._http_session: return {"name": "Unknown", "type": "dm"} if await self._check_managed_bridge_exit(): return {"name": chat_id, "type": "dm"} try: import aiohttp - - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}", - timeout=aiohttp.ClientTimeout(total=10) - ) as resp: - if resp.status == 200: - data = await resp.json() - return { - "name": data.get("name", chat_id), - "type": "group" if data.get("isGroup") else "dm", - "participants": data.get("participants", []), - } + + async with self._http_session.get( + f"http://127.0.0.1:{self._bridge_port}/chat/{chat_id}", + timeout=aiohttp.ClientTimeout(total=10) + ) as resp: + if resp.status == 200: + data = await resp.json() + return { + "name": data.get("name", chat_id), + "type": "group" if data.get("isGroup") else "dm", + "participants": data.get("participants", []), + } except Exception as e: logger.debug("Could not get WhatsApp chat info for %s: %s", chat_id, e) @@ -602,29 +652,26 @@ class WhatsAppAdapter(BasePlatformAdapter): async def _poll_messages(self) -> None: """Poll the bridge for incoming messages.""" - try: - import aiohttp - except ImportError: - print(f"[{self.name}] aiohttp not installed, message polling disabled") - return - + import aiohttp + while self._running: + if not self._http_session: + break bridge_exit = await self._check_managed_bridge_exit() if bridge_exit: print(f"[{self.name}] {bridge_exit}") break try: - async with aiohttp.ClientSession() as session: - async with session.get( - f"http://127.0.0.1:{self._bridge_port}/messages", - timeout=aiohttp.ClientTimeout(total=30) - ) as resp: - if resp.status == 200: - messages = await resp.json() - for msg_data in messages: - event = await self._build_message_event(msg_data) - if event: - await self.handle_message(event) + async with self._http_session.get( + f"http://127.0.0.1:{self._bridge_port}/messages", + timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + if resp.status == 200: + messages = await resp.json() + for msg_data in messages: + event = await self._build_message_event(msg_data) + if event: + await self.handle_message(event) except asyncio.CancelledError: break except Exception as e: diff --git a/gateway/run.py b/gateway/run.py index fd0d60042..3c0ca1818 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -77,6 +77,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) # Resolve Hermes home directory (respects HERMES_HOME override) from hermes_constants import get_hermes_home +from utils import atomic_yaml_write _hermes_home = get_hermes_home() # Load environment variables from ~/.hermes/.env first. @@ -224,6 +225,49 @@ from gateway.session import ( from gateway.delivery import DeliveryRouter from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType + +def _normalize_whatsapp_identifier(value: str) -> str: + """Strip WhatsApp JID/LID syntax down to its stable numeric identifier.""" + return ( + str(value or "") + .strip() + .replace("+", "", 1) + .split(":", 1)[0] + .split("@", 1)[0] + ) + + +def _expand_whatsapp_auth_aliases(identifier: str) -> set: + """Resolve WhatsApp phone/LID aliases using bridge session mapping files.""" + normalized = _normalize_whatsapp_identifier(identifier) + if not normalized: + return set() + + session_dir = _hermes_home / "whatsapp" / "session" + resolved = set() + queue = [normalized] + + while queue: + current = queue.pop(0) + if not current or current in resolved: + continue + + resolved.add(current) + for suffix in ("", "_reverse"): + mapping_path = session_dir / f"lid-mapping-{current}{suffix}.json" + if not mapping_path.exists(): + continue + try: + mapped = _normalize_whatsapp_identifier( + json.loads(mapping_path.read_text(encoding="utf-8")) + ) + except Exception: + continue + if mapped and mapped not in resolved: + queue.append(mapped) + + return resolved + logger = logging.getLogger(__name__) # Sentinel placed into _running_agents immediately when a session starts @@ -279,16 +323,16 @@ def _resolve_gateway_model(config: dict | None = None) -> str: """Read model from env/config — mirrors the resolution in _run_agent_sync. Without this, temporary AIAgent instances (memory flush, /compress) fall - back to the hardcoded default ("anthropic/claude-opus-4.6") which fails - when the active provider is openai-codex. + back to the hardcoded default which fails when the active provider is + openai-codex. """ - model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6" + model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "" cfg = config if config is not None else _load_gateway_config() model_cfg = cfg.get("model", {}) if isinstance(model_cfg, str): model = model_cfg elif isinstance(model_cfg, dict): - model = model_cfg.get("default", model) + model = model_cfg.get("default") or model_cfg.get("model") or model return model @@ -432,7 +476,7 @@ class GatewayRunner: from honcho_integration.session import HonchoSessionManager hcfg = HonchoClientConfig.from_global_config() - if not hcfg.enabled or not hcfg.api_key: + if not hcfg.enabled or not (hcfg.api_key or hcfg.base_url): return None, hcfg client = get_honcho_client(hcfg) @@ -573,6 +617,10 @@ class GatewayRunner: session_id=old_session_id, honcho_session_key=honcho_session_key, ) + # Fully silence the flush agent — quiet_mode only suppresses init + # messages; tool call output still leaks to the terminal through + # _safe_print → _print_fn. Set a no-op to prevent that. + tmp_agent._print_fn = lambda *a, **kw: None # Build conversation history from transcript msgs = [ @@ -741,10 +789,22 @@ class GatewayRunner: logger.error("No connected messaging platforms remain. Shutting down gateway cleanly.") await self.stop() elif not self.adapters and self._failed_platforms: - logger.warning( - "No connected messaging platforms remain, but %d platform(s) queued for reconnection", - len(self._failed_platforms), - ) + # All platforms are down and queued for background reconnection. + # If the error is retryable, exit with failure so systemd Restart=on-failure + # can restart the process. Otherwise stay alive and keep retrying in background. + if adapter.fatal_error_retryable: + self._exit_reason = adapter.fatal_error_message or "All messaging platforms failed with retryable errors" + self._exit_with_failure = True + logger.error( + "All messaging platforms failed with retryable errors. " + "Shutting down gateway for service restart (systemd will retry)." + ) + await self.stop() + else: + logger.warning( + "No connected messaging platforms remain, but %d platform(s) queued for reconnection", + len(self._failed_platforms), + ) def _request_clean_exit(self, reason: str) -> None: self._exit_cleanly = True @@ -902,11 +962,12 @@ class GatewayRunner: return {} @staticmethod - def _load_fallback_model() -> dict | None: - """Load fallback model config from config.yaml. + def _load_fallback_model() -> list | dict | None: + """Load fallback provider chain from config.yaml. - Returns a dict with 'provider' and 'model' keys, or None if - not configured / both fields empty. + Returns a list of provider dicts (``fallback_providers``), a single + dict (legacy ``fallback_model``), or None if not configured. + AIAgent.__init__ normalizes both formats into a chain. """ try: import yaml as _y @@ -914,8 +975,8 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - fb = cfg.get("fallback_model", {}) or {} - if fb.get("provider") and fb.get("model"): + fb = cfg.get("fallback_providers") or cfg.get("fallback_model") or None + if fb: return fb except Exception: pass @@ -943,6 +1004,13 @@ class GatewayRunner: """ logger.info("Starting Hermes Gateway...") logger.info("Session storage: %s", self.config.sessions_dir) + try: + from hermes_cli.profiles import get_active_profile_name + _profile = get_active_profile_name() + if _profile and _profile != "default": + logger.info("Active profile: %s", _profile) + except Exception: + pass try: from gateway.status import write_runtime_status write_runtime_status(gateway_state="starting", exit_reason=None) @@ -954,12 +1022,24 @@ class GatewayRunner: os.getenv(v) for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS", "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", - "SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS", + "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS", + "EMAIL_ALLOWED_USERS", "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS", "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", + "FEISHU_ALLOWED_USERS", + "WECOM_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") ) - _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") + _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( + os.getenv(v, "").lower() in ("true", "1", "yes") + for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS", + "WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS", + "SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS", + "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", + "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", + "FEISHU_ALLOW_ALL_USERS", + "WECOM_ALLOW_ALL_USERS") + ) if not _any_allowlist and not _allow_all: logger.warning( "No user allowlists configured. All unauthorized users will be denied. " @@ -1401,6 +1481,20 @@ class GatewayRunner: return None return DingTalkAdapter(config) + elif platform == Platform.FEISHU: + from gateway.platforms.feishu import FeishuAdapter, check_feishu_requirements + if not check_feishu_requirements(): + logger.warning("Feishu: lark-oapi not installed or FEISHU_APP_ID/SECRET not set") + return None + return FeishuAdapter(config) + + elif platform == Platform.WECOM: + from gateway.platforms.wecom import WeComAdapter, check_wecom_requirements + if not check_wecom_requirements(): + logger.warning("WeCom: aiohttp not installed or WECOM_BOT_ID/SECRET not set") + return None + return WeComAdapter(config) + elif platform == Platform.MATTERMOST: from gateway.platforms.mattermost import MattermostAdapter, check_mattermost_requirements if not check_mattermost_requirements(): @@ -1467,6 +1561,8 @@ class GatewayRunner: Platform.MATTERMOST: "MATTERMOST_ALLOWED_USERS", Platform.MATRIX: "MATRIX_ALLOWED_USERS", Platform.DINGTALK: "DINGTALK_ALLOWED_USERS", + Platform.FEISHU: "FEISHU_ALLOWED_USERS", + Platform.WECOM: "WECOM_ALLOWED_USERS", } platform_allow_all_map = { Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS", @@ -1479,6 +1575,8 @@ class GatewayRunner: Platform.MATTERMOST: "MATTERMOST_ALLOW_ALL_USERS", Platform.MATRIX: "MATRIX_ALLOW_ALL_USERS", Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS", + Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS", + Platform.WECOM: "WECOM_ALLOW_ALL_USERS", } # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) @@ -1506,10 +1604,23 @@ class GatewayRunner: if global_allowlist: allowed_ids.update(uid.strip() for uid in global_allowlist.split(",") if uid.strip()) - # WhatsApp JIDs have @s.whatsapp.net suffix — strip it for comparison check_ids = {user_id} if "@" in user_id: check_ids.add(user_id.split("@")[0]) + + # WhatsApp: resolve phone↔LID aliases from bridge session mapping files + if source.platform == Platform.WHATSAPP: + normalized_allowed_ids = set() + for allowed_id in allowed_ids: + normalized_allowed_ids.update(_expand_whatsapp_auth_aliases(allowed_id)) + if normalized_allowed_ids: + allowed_ids = normalized_allowed_ids + + check_ids.update(_expand_whatsapp_auth_aliases(user_id)) + normalized_user_id = _normalize_whatsapp_identifier(user_id) + if normalized_user_id: + check_ids.add(normalized_user_id) + return bool(check_ids & allowed_ids) def _get_unauthorized_dm_behavior(self, platform: Optional[Platform]) -> str: @@ -1970,6 +2081,12 @@ class GatewayRunner: f"Use /resume to browse and restore a previous session.\n" f"Adjust reset timing in config.yaml under session_reset." ) + try: + session_info = self._format_session_info() + if session_info: + notice = f"{notice}\n\n{session_info}" + except Exception: + pass await adapter.send( source.chat_id, notice, metadata=getattr(event, 'metadata', None), @@ -2063,7 +2180,7 @@ class GatewayRunner: if isinstance(_model_cfg, str): _hyg_model = _model_cfg elif isinstance(_model_cfg, dict): - _hyg_model = _model_cfg.get("default", _hyg_model) + _hyg_model = _model_cfg.get("default") or _model_cfg.get("model") or _hyg_model # Read explicit context_length override from model config # (same as run_agent.py lines 995-1005) _raw_ctx = _model_cfg.get("context_length") @@ -2175,6 +2292,7 @@ class GatewayRunner: enabled_toolsets=["memory"], session_id=session_entry.session_id, ) + _hyg_agent._print_fn = lambda *a, **kw: None loop = asyncio.get_event_loop() _compressed, _ = await loop.run_in_executor( @@ -2185,6 +2303,15 @@ class GatewayRunner: ), ) + # _compress_context ends the old session and creates + # a new session_id. Write compressed messages into + # the NEW session so the old transcript stays intact + # and searchable via session_search. + _hyg_new_sid = _hyg_agent.session_id + if _hyg_new_sid != session_entry.session_id: + session_entry.session_id = _hyg_new_sid + self.session_store._save() + self.session_store.rewrite_transcript( session_entry.session_id, _compressed ) @@ -2736,6 +2863,85 @@ class GatewayRunner: # Clear session env self._clear_session_env() + def _format_session_info(self) -> str: + """Resolve current model config and return a formatted info block. + + Surfaces model, provider, context length, and endpoint so gateway + users can immediately see if context detection went wrong (e.g. + local models falling to the 128K default). + """ + from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT + + model = _resolve_gateway_model() + config_context_length = None + provider = None + base_url = None + api_key = None + + try: + cfg_path = _hermes_home / "config.yaml" + if cfg_path.exists(): + import yaml as _info_yaml + with open(cfg_path, encoding="utf-8") as f: + data = _info_yaml.safe_load(f) or {} + model_cfg = data.get("model", {}) + if isinstance(model_cfg, dict): + raw_ctx = model_cfg.get("context_length") + if raw_ctx is not None: + try: + config_context_length = int(raw_ctx) + except (TypeError, ValueError): + pass + provider = model_cfg.get("provider") or None + base_url = model_cfg.get("base_url") or None + except Exception: + pass + + # Resolve runtime credentials for probing + try: + runtime = _resolve_runtime_agent_kwargs() + provider = provider or runtime.get("provider") + base_url = base_url or runtime.get("base_url") + api_key = runtime.get("api_key") + except Exception: + pass + + context_length = get_model_context_length( + model, + base_url=base_url or "", + api_key=api_key or "", + config_context_length=config_context_length, + provider=provider or "", + ) + + # Format context source hint + if config_context_length is not None: + ctx_source = "config" + elif context_length == DEFAULT_FALLBACK_CONTEXT: + ctx_source = "default — set model.context_length in config to override" + else: + ctx_source = "detected" + + # Format context length for display + if context_length >= 1_000_000: + ctx_display = f"{context_length / 1_000_000:.1f}M" + elif context_length >= 1_000: + ctx_display = f"{context_length // 1_000}K" + else: + ctx_display = str(context_length) + + lines = [ + f"◆ Model: `{model}`", + f"◆ Provider: {provider or 'openrouter'}", + f"◆ Context: {ctx_display} tokens ({ctx_source})", + ] + + # Show endpoint for local/custom setups + if base_url and ("localhost" in base_url or "127.0.0.1" in base_url or "0.0.0.0" in base_url): + lines.append(f"◆ Endpoint: {base_url}") + + return "\n".join(lines) + async def _handle_reset_command(self, event: MessageEvent) -> str: """Handle /new or /reset command.""" source = event.source @@ -2776,12 +2982,22 @@ class GatewayRunner: "session_key": session_key, }) + # Resolve session config info to surface to the user + try: + session_info = self._format_session_info() + except Exception: + session_info = "" + if new_entry: - return "✨ Session reset! I've started fresh with no memory of our previous conversation." + header = "✨ Session reset! Starting fresh." else: # No existing session, just create one self.session_store.get_or_create_session(source, force_new=True) - return "✨ New session started!" + header = "✨ New session started!" + + if session_info: + return f"{header}\n\n{session_info}" + return header async def _handle_status_command(self, event: MessageEvent) -> str: """Handle /status command.""" @@ -2959,8 +3175,7 @@ class GatewayRunner: if "agent" not in config or not isinstance(config.get("agent"), dict): config["agent"] = {} config["agent"]["system_prompt"] = "" - with open(config_path, "w") as f: - yaml.dump(config, f, default_flow_style=False, sort_keys=False) + atomic_yaml_write(config_path, config) except Exception as e: return f"⚠️ Failed to save personality change: {e}" self._ephemeral_system_prompt = "" @@ -2973,8 +3188,7 @@ class GatewayRunner: if "agent" not in config or not isinstance(config.get("agent"), dict): config["agent"] = {} config["agent"]["system_prompt"] = new_prompt - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(config, f, default_flow_style=False, sort_keys=False) + atomic_yaml_write(config_path, config) except Exception as e: return f"⚠️ Failed to save personality change: {e}" @@ -3064,8 +3278,7 @@ class GatewayRunner: with open(config_path, encoding="utf-8") as f: user_config = yaml.safe_load(f) or {} user_config[env_key] = chat_id - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False) + atomic_yaml_write(config_path, user_config) # Also set in the current environment so it takes effect immediately os.environ[env_key] = str(chat_id) except Exception as e: @@ -3733,8 +3946,7 @@ class GatewayRunner: current[k] = {} current = current[k] current[keys[-1]] = value - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + atomic_yaml_write(config_path, user_config) return True except Exception as e: logger.error("Failed to save config key %s: %s", key_path, e) @@ -3842,8 +4054,7 @@ class GatewayRunner: if "display" not in user_config or not isinstance(user_config.get("display"), dict): user_config["display"] = {} user_config["display"]["tool_progress"] = new_mode - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + atomic_yaml_write(config_path, user_config) return f"{descriptions[new_mode]}\n_(saved to config — takes effect on next message)_" except Exception as e: logger.warning("Failed to save tool_progress mode: %s", e) @@ -3885,17 +4096,27 @@ class GatewayRunner: enabled_toolsets=["memory"], session_id=session_entry.session_id, ) + tmp_agent._print_fn = lambda *a, **kw: None loop = asyncio.get_event_loop() compressed, _ = await loop.run_in_executor( None, - lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens), + lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens) ) - self.session_store.rewrite_transcript(session_entry.session_id, compressed) + # _compress_context already calls end_session() on the old session + # (preserving its full transcript in SQLite) and creates a new + # session_id for the continuation. Write the compressed messages + # into the NEW session so the original history stays searchable. + new_session_id = tmp_agent.session_id + if new_session_id != session_entry.session_id: + session_entry.session_id = new_session_id + self.session_store._save() + + self.session_store.rewrite_transcript(new_session_id, compressed) # Reset stored token count — transcript changed, old value is stale self.session_store.update_session( - session_entry.session_key, last_prompt_tokens=0, + session_entry.session_key, last_prompt_tokens=0 ) new_count = len(compressed) new_tokens = estimate_messages_tokens_rough(compressed) @@ -4051,7 +4272,7 @@ class GatewayRunner: ] ctx = agent.context_compressor if ctx.last_prompt_tokens: - pct = ctx.last_prompt_tokens / ctx.context_length * 100 if ctx.context_length else 0 + pct = min(100, ctx.last_prompt_tokens / ctx.context_length * 100) if ctx.context_length else 0 lines.append(f"Context: {ctx.last_prompt_tokens:,} / {ctx.context_length:,} ({pct:.0f}%)") if ctx.compression_count: lines.append(f"Compressions: {ctx.compression_count}") @@ -4798,10 +5019,23 @@ class GatewayRunner: from hermes_cli.tools_config import _get_platform_tools enabled_toolsets = sorted(_get_platform_tools(user_config, platform_key)) + # Apply tool preview length config (0 = no limit) + try: + from agent.display import set_tool_preview_max_len + _tpl = user_config.get("display", {}).get("tool_preview_length", 0) + set_tool_preview_max_len(int(_tpl) if _tpl else 0) + except Exception: + pass + # Tool progress mode from config.yaml: "all", "new", "verbose", "off" - # Falls back to env vars for backward compatibility + # Falls back to env vars for backward compatibility. + # YAML 1.1 parses bare `off` as boolean False — normalise before + # the `or` chain so it doesn't silently fall through to "all". + _raw_tp = user_config.get("display", {}).get("tool_progress") + if _raw_tp is False: + _raw_tp = "off" progress_mode = ( - user_config.get("display", {}).get("tool_progress") + _raw_tp or os.getenv("HERMES_TOOL_PROGRESS_MODE") or "all" ) @@ -4838,9 +5072,11 @@ class GatewayRunner: return if preview: - # Truncate preview to keep messages clean - if len(preview) > 80: - preview = preview[:77] + "..." + # Truncate preview unless config says unlimited + from agent.display import get_tool_preview_max_len + _pl = get_tool_preview_max_len() + if _pl > 0 and len(preview) > _pl: + preview = preview[:_pl - 3] + "..." msg = f"{emoji} {tool_name}: \"{preview}\"" else: msg = f"{emoji} {tool_name}..." @@ -4860,12 +5096,17 @@ class GatewayRunner: progress_queue.put(msg) # Background task to send progress messages - # Accumulates tool lines into a single message that gets edited - # For DM top-level Slack messages, source.thread_id is None but the - # final reply will be threaded under the original message via reply_to. - # Use event_message_id as fallback so progress messages land in the - # same thread as the final response instead of going to the DM root. - _progress_thread_id = source.thread_id or event_message_id + # Accumulates tool lines into a single message that gets edited. + # + # Threading metadata is platform-specific: + # - Slack DM threading needs event_message_id fallback (reply thread) + # - Telegram uses message_thread_id only for forum topics; passing a + # normal DM/group message id as thread_id causes send failures + # - Other platforms should use explicit source.thread_id only + if source.platform == Platform.SLACK: + _progress_thread_id = source.thread_id or event_message_id + else: + _progress_thread_id = source.thread_id _progress_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None async def send_progress_messages(): @@ -5128,7 +5369,25 @@ class GatewayRunner: agent.stream_delta_callback = _stream_delta_cb agent.status_callback = _status_callback_sync agent.reasoning_config = reasoning_config - + + # Background review delivery — send "💾 Memory updated" etc. to user + def _bg_review_send(message: str) -> None: + if not _status_adapter: + return + try: + asyncio.run_coroutine_threadsafe( + _status_adapter.send( + _status_chat_id, + message, + metadata=_status_thread_metadata, + ), + _loop_for_step, + ) + except Exception as _e: + logger.debug("background_review_callback error: %s", _e) + + agent.background_review_callback = _bg_review_send + # Store agent reference for interrupt support agent_holder[0] = agent # Capture the full tool definitions for transcript logging diff --git a/gateway/session.py b/gateway/session.py index b85ac3e3a..5aefb6c01 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -762,14 +762,16 @@ class SessionStore: if session_key in self._entries: entry = self._entries[session_key] entry.updated_at = _now() - entry.input_tokens += input_tokens - entry.output_tokens += output_tokens - entry.cache_read_tokens += cache_read_tokens - entry.cache_write_tokens += cache_write_tokens + # Direct assignment — the gateway receives cumulative totals + # from the cached agent, not per-call deltas. + entry.input_tokens = input_tokens + entry.output_tokens = output_tokens + entry.cache_read_tokens = cache_read_tokens + entry.cache_write_tokens = cache_write_tokens if last_prompt_tokens is not None: entry.last_prompt_tokens = last_prompt_tokens if estimated_cost_usd is not None: - entry.estimated_cost_usd += estimated_cost_usd + entry.estimated_cost_usd = estimated_cost_usd if cost_status: entry.cost_status = cost_status entry.total_tokens = ( @@ -783,7 +785,7 @@ class SessionStore: if self._db and db_session_id: try: - self._db.update_token_counts( + self._db.set_token_counts( db_session_id, input_tokens=input_tokens, output_tokens=output_tokens, @@ -795,6 +797,7 @@ class SessionStore: billing_provider=provider, billing_base_url=base_url, model=model, + absolute=True, ) except Exception as e: logger.debug("Session DB operation failed: %s", e) @@ -955,13 +958,17 @@ class SessionStore: try: self._db.clear_messages(session_id) for msg in messages: + role = msg.get("role", "unknown") self._db.append_message( session_id=session_id, - role=msg.get("role", "unknown"), + role=role, content=msg.get("content"), tool_name=msg.get("tool_name"), tool_calls=msg.get("tool_calls"), tool_call_id=msg.get("tool_call_id"), + reasoning=msg.get("reasoning") if role == "assistant" else None, + reasoning_details=msg.get("reasoning_details") if role == "assistant" else None, + codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None, ) except Exception as e: logger.debug("Failed to rewrite transcript in DB: %s", e) diff --git a/hermes_cli/__init__.py b/hermes_cli/__init__.py index 047783202..797c7e8d6 100644 --- a/hermes_cli/__init__.py +++ b/hermes_cli/__init__.py @@ -11,5 +11,5 @@ Provides subcommands for: - hermes cron - Manage cron jobs """ -__version__ = "0.4.0" -__release_date__ = "2026.3.23" +__version__ = "0.5.0" +__release_date__ = "2026.3.28" diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 9eb867352..faebe9443 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -160,7 +160,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { id="alibaba", name="Alibaba Cloud (DashScope)", auth_type="api_key", - inference_base_url="https://dashscope-intl.aliyuncs.com/apps/anthropic", + inference_base_url="https://coding-intl.dashscope.aliyuncs.com/v1", api_key_env_vars=("DASHSCOPE_API_KEY",), base_url_env_var="DASHSCOPE_BASE_URL", ), @@ -212,6 +212,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { api_key_env_vars=("KILOCODE_API_KEY",), base_url_env_var="KILOCODE_BASE_URL", ), + "huggingface": ProviderConfig( + id="huggingface", + name="Hugging Face", + auth_type="api_key", + inference_base_url="https://router.huggingface.co/v1", + api_key_env_vars=("HF_TOKEN",), + base_url_env_var="HF_BASE_URL", + ), } @@ -685,8 +693,13 @@ def resolve_provider( "github-copilot-acp": "copilot-acp", "copilot-acp-agent": "copilot-acp", "aigateway": "ai-gateway", "vercel": "ai-gateway", "vercel-ai-gateway": "ai-gateway", "opencode": "opencode-zen", "zen": "opencode-zen", + "hf": "huggingface", "hugging-face": "huggingface", "huggingface-hub": "huggingface", "go": "opencode-go", "opencode-go-sub": "opencode-go", "kilo": "kilocode", "kilo-code": "kilocode", "kilo-gateway": "kilocode", + # Local server aliases — route through the generic custom provider + "lmstudio": "custom", "lm-studio": "custom", "lm_studio": "custom", + "ollama": "custom", "vllm": "custom", "llamacpp": "custom", + "llama.cpp": "custom", "llama-cpp": "custom", } normalized = _PROVIDER_ALIASES.get(normalized, normalized) @@ -733,7 +746,12 @@ def resolve_provider( if has_usable_secret(os.getenv(env_var, "")): return pid - return "openrouter" + raise AuthError( + "No inference provider configured. Run 'hermes model' to choose a " + "provider and model, or set an API key (OPENROUTER_API_KEY, " + "OPENAI_API_KEY, etc.) in ~/.hermes/.env.", + code="no_provider_configured", + ) # ============================================================================= @@ -2095,7 +2113,8 @@ def _login_openai_codex(args, pconfig: ProviderConfig) -> None: config_path = _update_config_for_provider("openai-codex", creds.get("base_url", DEFAULT_CODEX_BASE_URL)) print() print("Login successful!") - print(" Auth state: ~/.hermes/auth.json") + from hermes_constants import display_hermes_home as _dhh + print(f" Auth state: {_dhh()}/auth.json") print(f" Config updated: {config_path} (model.provider=openai-codex)") diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index c4eb827e2..5ecc94acf 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -258,7 +258,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, get_toolset_for_tool: Callable to map tool name -> toolset name. context_length: Model's context window size in tokens. """ - from model_tools import check_tool_availability + from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS if get_toolset_for_tool is None: from model_tools import get_toolset_for_tool @@ -267,8 +267,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str, _, unavailable_toolsets = check_tool_availability(quiet=True) disabled_tools = set() + # Tools whose toolset has a check_fn are lazy-initialized (e.g. honcho, + # homeassistant) — they show as unavailable at banner time because the + # check hasn't run yet, but they aren't misconfigured. + lazy_tools = set() for item in unavailable_toolsets: - disabled_tools.update(item.get("tools", [])) + toolset_name = item.get("name", "") + ts_req = TOOLSET_REQUIREMENTS.get(toolset_name, {}) + tools_in_ts = item.get("tools", []) + if ts_req.get("check_fn"): + lazy_tools.update(tools_in_ts) + else: + disabled_tools.update(tools_in_ts) layout_table = Table.grid(padding=(0, 2)) layout_table.add_column("left", justify="center") @@ -328,6 +338,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, for name in sorted(tool_names): if name in disabled_tools: colored_names.append(f"[red]{name}[/]") + elif name in lazy_tools: + colored_names.append(f"[yellow]{name}[/]") else: colored_names.append(f"[{text}]{name}[/]") @@ -347,6 +359,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, colored_names.append("[dim]...[/]") elif name in disabled_tools: colored_names.append(f"[red]{name}[/]") + elif name in lazy_tools: + colored_names.append(f"[yellow]{name}[/]") else: colored_names.append(f"[{text}]{name}[/]") tools_str = ", ".join(colored_names) @@ -403,6 +417,15 @@ def build_welcome_banner(console: Console, model: str, cwd: str, if mcp_connected: summary_parts.append(f"{mcp_connected} MCP servers") summary_parts.append("/help for commands") + # Show active profile name when not 'default' + try: + from hermes_cli.profiles import get_active_profile_name + _profile_name = get_active_profile_name() + if _profile_name and _profile_name != "default": + right_lines.append(f"[bold {accent}]Profile:[/] [{text}]{_profile_name}[/]") + except Exception: + pass # Never break the banner over a profiles.py bug + right_lines.append(f"[dim {dim}]{' · '.join(summary_parts)}[/]") # Update check — use prefetched result if available diff --git a/hermes_cli/callbacks.py b/hermes_cli/callbacks.py index 88a97511c..fa51ee157 100644 --- a/hermes_cli/callbacks.py +++ b/hermes_cli/callbacks.py @@ -12,6 +12,7 @@ import getpass from hermes_cli.banner import cprint, _DIM, _RST from hermes_cli.config import save_env_value_secure +from hermes_constants import display_hermes_home def clarify_callback(cli, question, choices): @@ -131,7 +132,8 @@ def prompt_for_secret(cli, var_name: str, prompt: str, metadata=None) -> dict: } stored = save_env_value_secure(var_name, value) - cprint(f"\n{_DIM} ✓ Stored secret in ~/.hermes/.env as {var_name}{_RST}") + _dhh = display_hermes_home() + cprint(f"\n{_DIM} ✓ Stored secret in {_dhh}/.env as {var_name}{_RST}") return { **stored, "skipped": False, @@ -183,7 +185,8 @@ def prompt_for_secret(cli, var_name: str, prompt: str, metadata=None) -> dict: } stored = save_env_value_secure(var_name, value) - cprint(f"\n{_DIM} ✓ Stored secret in ~/.hermes/.env as {var_name}{_RST}") + _dhh = display_hermes_home() + cprint(f"\n{_DIM} ✓ Stored secret in {_dhh}/.env as {var_name}{_RST}") return { **stored, "skipped": False, diff --git a/hermes_cli/codex_models.py b/hermes_cli/codex_models.py index 169c63e8a..f5616b68d 100644 --- a/hermes_cli/codex_models.py +++ b/hermes_cli/codex_models.py @@ -12,6 +12,8 @@ import os logger = logging.getLogger(__name__) DEFAULT_CODEX_MODELS: List[str] = [ + "gpt-5.4-mini", + "gpt-5.4", "gpt-5.3-codex", "gpt-5.2-codex", "gpt-5.1-codex-max", @@ -19,8 +21,9 @@ DEFAULT_CODEX_MODELS: List[str] = [ ] _FORWARD_COMPAT_TEMPLATE_MODELS: List[tuple[str, tuple[str, ...]]] = [ - ("gpt-5.3-codex", ("gpt-5.2-codex",)), + ("gpt-5.4-mini", ("gpt-5.3-codex", "gpt-5.2-codex")), ("gpt-5.4", ("gpt-5.3-codex", "gpt-5.2-codex")), + ("gpt-5.3-codex", ("gpt-5.2-codex",)), ("gpt-5.3-codex-spark", ("gpt-5.3-codex", "gpt-5.2-codex")), ] diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 211e264e4..0041cf665 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -36,6 +36,8 @@ _EXTRA_ENV_KEYS = frozenset({ "SIGNAL_ACCOUNT", "SIGNAL_HTTP_URL", "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS", "DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET", + "FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN", + "WECOM_BOT_ID", "WECOM_SECRET", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", "MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE", @@ -136,9 +138,16 @@ def ensure_hermes_home(): DEFAULT_CONFIG = { "model": "anthropic/claude-opus-4.6", + "fallback_providers": [], "toolsets": ["hermes-cli"], "agent": { "max_turns": 90, + # Tool-use enforcement: injects system prompt guidance that tells the + # model to actually call tools instead of describing intended actions. + # Values: "auto" (default — applies to gpt/codex models), true/false + # (force on/off for all models), or a list of model-name substrings + # to match (e.g. ["gpt", "codex", "gemini", "qwen"]). + "tool_use_enforcement": "auto", }, "terminal": { @@ -223,42 +232,49 @@ DEFAULT_CONFIG = { "model": "", "base_url": "", "api_key": "", + "timeout": 30, # seconds — increase for slow local models }, "compression": { "provider": "auto", "model": "", "base_url": "", "api_key": "", + "timeout": 120, # seconds — compression summarises large contexts; increase for local models }, "session_search": { "provider": "auto", "model": "", "base_url": "", "api_key": "", + "timeout": 30, }, "skills_hub": { "provider": "auto", "model": "", "base_url": "", "api_key": "", + "timeout": 30, }, "approval": { "provider": "auto", "model": "", # fast/cheap model recommended (e.g. gemini-flash, haiku) "base_url": "", "api_key": "", + "timeout": 30, }, "mcp": { "provider": "auto", "model": "", "base_url": "", "api_key": "", + "timeout": 30, }, "flush_memories": { "provider": "auto", "model": "", "base_url": "", "api_key": "", + "timeout": 30, }, }, @@ -266,12 +282,14 @@ DEFAULT_CONFIG = { "compact": False, "personality": "kawaii", "resume_display": "full", + "busy_input_mode": "interrupt", "bell_on_complete": False, "show_reasoning": False, "streaming": False, "show_cost": False, # Show $ cost in the status bar (off by default) "skin": "default", "tool_progress_command": False, # Enable /verbose command in messaging gateway + "tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands) }, # Privacy settings @@ -354,6 +372,13 @@ DEFAULT_CONFIG = { # Never saved to sessions, logs, or trajectories. "prefill_messages_file": "", + # Skills — external skill directories for sharing skills across tools/agents. + # Each path is expanded (~, ${VAR}) and resolved. Read-only — skill creation + # always goes to ~/.hermes/skills/. + "skills": { + "external_dirs": [], # e.g. ["~/.agents/skills", "/shared/team-skills"] + }, + # Honcho AI-native memory -- reads ~/.honcho/config.json as single source of truth. # This section is only needed for hermes-specific overrides; everything else # (apiKey, workspace, peerName, sessions, enabled) comes from the global config. @@ -409,6 +434,12 @@ DEFAULT_CONFIG = { }, }, + "cron": { + # Wrap delivered cron responses with a header (task name) and footer + # ("The agent cannot see this message"). Set to false for clean output. + "wrap_response": True, + }, + # Config schema version - bump this when adding new required fields "_config_version": 11, } @@ -549,14 +580,14 @@ OPTIONAL_ENV_VARS = { "category": "provider", }, "DASHSCOPE_API_KEY": { - "description": "Alibaba Cloud DashScope API key for Qwen models", + "description": "Alibaba Cloud DashScope API key (Qwen + multi-provider models)", "prompt": "DashScope API Key", "url": "https://modelstudio.console.alibabacloud.com/", "password": True, "category": "provider", }, "DASHSCOPE_BASE_URL": { - "description": "Custom DashScope base URL (default: international endpoint)", + "description": "Custom DashScope base URL (default: coding-intl OpenAI-compat endpoint)", "prompt": "DashScope Base URL", "url": "", "password": False, @@ -595,8 +626,31 @@ OPTIONAL_ENV_VARS = { "category": "provider", "advanced": True, }, + "HF_TOKEN": { + "description": "Hugging Face token for Inference Providers (20+ open models via router.huggingface.co)", + "prompt": "Hugging Face Token", + "url": "https://huggingface.co/settings/tokens", + "password": True, + "category": "provider", + }, + "HF_BASE_URL": { + "description": "Hugging Face Inference Providers base URL override", + "prompt": "HF base URL (leave empty for default)", + "url": None, + "password": False, + "category": "provider", + "advanced": True, + }, # ── Tool API keys ── + "EXA_API_KEY": { + "description": "Exa API key for AI-native web search and contents", + "prompt": "Exa API key", + "url": "https://exa.ai/", + "tools": ["web_search", "web_extract"], + "password": True, + "category": "tool", + }, "PARALLEL_API_KEY": { "description": "Parallel API key for AI-native web search and extract", "prompt": "Parallel API key", @@ -815,6 +869,20 @@ OPTIONAL_ENV_VARS = { "password": False, "category": "messaging", }, + "MATTERMOST_REQUIRE_MENTION": { + "description": "Require @mention in Mattermost channels (default: true). Set to false to respond to all messages.", + "prompt": "Require @mention in channels", + "url": None, + "password": False, + "category": "messaging", + }, + "MATTERMOST_FREE_RESPONSE_CHANNELS": { + "description": "Comma-separated Mattermost channel IDs where bot responds without @mention", + "prompt": "Free-response channel IDs (comma-separated)", + "url": None, + "password": False, + "category": "messaging", + }, "MATRIX_HOMESERVER": { "description": "Matrix homeserver URL (e.g. https://matrix.example.org)", "prompt": "Matrix homeserver URL", @@ -1694,6 +1762,7 @@ def show_config(): keys = [ ("OPENROUTER_API_KEY", "OpenRouter"), ("VOICE_TOOLS_OPENAI_KEY", "OpenAI (STT/TTS)"), + ("EXA_API_KEY", "Exa"), ("PARALLEL_API_KEY", "Parallel"), ("FIRECRAWL_API_KEY", "Firecrawl"), ("TAVILY_API_KEY", "Tavily"), @@ -1853,7 +1922,7 @@ def set_config_value(key: str, value: str): # Check if it's an API key (goes to .env) api_keys = [ 'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY', - 'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', + 'EXA_API_KEY', 'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'FIRECRAWL_GATEWAY_URL', 'TOOL_GATEWAY_DOMAIN', 'TOOL_GATEWAY_SCHEME', 'TOOL_GATEWAY_USER_TOKEN', 'TAVILY_API_KEY', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY', diff --git a/hermes_cli/curses_ui.py b/hermes_cli/curses_ui.py index f819b1ffd..dce620b8c 100644 --- a/hermes_cli/curses_ui.py +++ b/hermes_cli/curses_ui.py @@ -4,7 +4,7 @@ Used by `hermes tools` and `hermes skills` for interactive checklists. Provides a curses multi-select with keyboard navigation, plus a text-based numbered fallback for terminals without curses support. """ -from typing import List, Set +from typing import Callable, List, Optional, Set from hermes_cli.colors import Colors, color @@ -15,6 +15,7 @@ def curses_checklist( selected: Set[int], *, cancel_returns: Set[int] | None = None, + status_fn: Optional[Callable[[Set[int]], str]] = None, ) -> Set[int]: """Curses multi-select checklist. Returns set of selected indices. @@ -23,6 +24,9 @@ def curses_checklist( items: Display labels for each row. selected: Indices that start checked (pre-selected). cancel_returns: Returned on ESC/q. Defaults to the original *selected*. + status_fn: Optional callback ``f(chosen_indices) -> str`` whose return + value is rendered on the bottom row of the terminal. Use this for + live aggregate info (e.g. estimated token counts). """ if cancel_returns is None: cancel_returns = set(selected) @@ -47,6 +51,9 @@ def curses_checklist( stdscr.clear() max_y, max_x = stdscr.getmaxyx() + # Reserve bottom row for status bar when status_fn provided + footer_rows = 1 if status_fn else 0 + # Header try: hattr = curses.A_BOLD @@ -62,7 +69,7 @@ def curses_checklist( pass # Scrollable item list - visible_rows = max_y - 3 + visible_rows = max_y - 3 - footer_rows if cursor < scroll_offset: scroll_offset = cursor elif cursor >= scroll_offset + visible_rows: @@ -72,7 +79,7 @@ def curses_checklist( range(scroll_offset, min(len(items), scroll_offset + visible_rows)) ): y = draw_i + 3 - if y >= max_y - 1: + if y >= max_y - 1 - footer_rows: break check = "✓" if i in chosen else " " arrow = "→" if i == cursor else " " @@ -87,6 +94,20 @@ def curses_checklist( except curses.error: pass + # Status bar (bottom row, right-aligned) + if status_fn: + try: + status_text = status_fn(chosen) + if status_text: + # Right-align on the bottom row + sx = max(0, max_x - len(status_text) - 1) + sattr = curses.A_DIM + if curses.has_colors(): + sattr |= curses.color_pair(3) + stdscr.addnstr(max_y - 1, sx, status_text, max_x - sx - 1, sattr) + except curses.error: + pass + stdscr.refresh() key = stdscr.getch() @@ -107,7 +128,7 @@ def curses_checklist( return result_holder[0] if result_holder[0] is not None else cancel_returns except Exception: - return _numbered_fallback(title, items, selected, cancel_returns) + return _numbered_fallback(title, items, selected, cancel_returns, status_fn) def _numbered_fallback( @@ -115,6 +136,7 @@ def _numbered_fallback( items: List[str], selected: Set[int], cancel_returns: Set[int], + status_fn: Optional[Callable[[Set[int]], str]] = None, ) -> Set[int]: """Text-based toggle fallback for terminals without curses.""" chosen = set(selected) @@ -125,6 +147,10 @@ def _numbered_fallback( for i, label in enumerate(items): marker = color("[✓]", Colors.GREEN) if i in chosen else "[ ]" print(f" {marker} {i + 1:>2}. {label}") + if status_fn: + status_text = status_fn(chosen) + if status_text: + print(color(f"\n {status_text}", Colors.DIM)) print() try: val = input(color(" Toggle # (or Enter to confirm): ", Colors.DIM)).strip() diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index 053f92a27..a0a841905 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -10,9 +10,11 @@ import subprocess import shutil from hermes_cli.config import get_project_root, get_hermes_home, get_env_path +from hermes_constants import display_hermes_home PROJECT_ROOT = get_project_root() HERMES_HOME = get_hermes_home() +_DHH = display_hermes_home() # user-facing display path (e.g. ~/.hermes or ~/.hermes/profiles/coder) # Load environment variables from ~/.hermes/.env so API key checks work from dotenv import load_dotenv @@ -56,7 +58,7 @@ def _honcho_is_configured_for_doctor() -> bool: from honcho_integration.client import HonchoClientConfig cfg = HonchoClientConfig.from_global_config() - return bool(cfg.enabled and cfg.api_key) + return bool(cfg.enabled and (cfg.api_key or cfg.base_url)) except Exception: return False @@ -209,14 +211,14 @@ def run_doctor(args): # Check ~/.hermes/.env (primary location for user config) env_path = HERMES_HOME / '.env' if env_path.exists(): - check_ok("~/.hermes/.env file exists") + check_ok(f"{_DHH}/.env file exists") # Check for common issues content = env_path.read_text() if _has_provider_env_config(content): check_ok("API key or custom endpoint configured") else: - check_warn("No API key found in ~/.hermes/.env") + check_warn(f"No API key found in {_DHH}/.env") issues.append("Run 'hermes setup' to configure API keys") else: # Also check project root as fallback @@ -224,11 +226,11 @@ def run_doctor(args): if fallback_env.exists(): check_ok(".env file exists (in project directory)") else: - check_fail("~/.hermes/.env file missing") + check_fail(f"{_DHH}/.env file missing") if should_fix: env_path.parent.mkdir(parents=True, exist_ok=True) env_path.touch() - check_ok("Created empty ~/.hermes/.env") + check_ok(f"Created empty {_DHH}/.env") check_info("Run 'hermes setup' to configure API keys") fixed_count += 1 else: @@ -238,7 +240,7 @@ def run_doctor(args): # Check ~/.hermes/config.yaml (primary) or project cli-config.yaml (fallback) config_path = HERMES_HOME / 'config.yaml' if config_path.exists(): - check_ok("~/.hermes/config.yaml exists") + check_ok(f"{_DHH}/config.yaml exists") else: fallback_config = PROJECT_ROOT / 'cli-config.yaml' if fallback_config.exists(): @@ -248,11 +250,11 @@ def run_doctor(args): if should_fix and example_config.exists(): config_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(str(example_config), str(config_path)) - check_ok("Created ~/.hermes/config.yaml from cli-config.yaml.example") + check_ok(f"Created {_DHH}/config.yaml from cli-config.yaml.example") fixed_count += 1 elif should_fix: check_warn("config.yaml not found and no example to copy from") - manual_issues.append("Create ~/.hermes/config.yaml manually") + manual_issues.append(f"Create {_DHH}/config.yaml manually") else: check_warn("config.yaml not found", "(using defaults)") @@ -294,28 +296,28 @@ def run_doctor(args): hermes_home = HERMES_HOME if hermes_home.exists(): - check_ok("~/.hermes directory exists") + check_ok(f"{_DHH} directory exists") else: if should_fix: hermes_home.mkdir(parents=True, exist_ok=True) - check_ok("Created ~/.hermes directory") + check_ok(f"Created {_DHH} directory") fixed_count += 1 else: - check_warn("~/.hermes not found", "(will be created on first use)") + check_warn(f"{_DHH} not found", "(will be created on first use)") # Check expected subdirectories expected_subdirs = ["cron", "sessions", "logs", "skills", "memories"] for subdir_name in expected_subdirs: subdir_path = hermes_home / subdir_name if subdir_path.exists(): - check_ok(f"~/.hermes/{subdir_name}/ exists") + check_ok(f"{_DHH}/{subdir_name}/ exists") else: if should_fix: subdir_path.mkdir(parents=True, exist_ok=True) - check_ok(f"Created ~/.hermes/{subdir_name}/") + check_ok(f"Created {_DHH}/{subdir_name}/") fixed_count += 1 else: - check_warn(f"~/.hermes/{subdir_name}/ not found", "(will be created on first use)") + check_warn(f"{_DHH}/{subdir_name}/ not found", "(will be created on first use)") # Check for SOUL.md persona file soul_path = hermes_home / "SOUL.md" @@ -324,11 +326,11 @@ def run_doctor(args): # Check if it's just the template comments (no real content) lines = [l for l in content.splitlines() if l.strip() and not l.strip().startswith(("", "#"))] if lines: - check_ok("~/.hermes/SOUL.md exists (persona configured)") + check_ok(f"{_DHH}/SOUL.md exists (persona configured)") else: - check_info("~/.hermes/SOUL.md exists but is empty — edit it to customize personality") + check_info(f"{_DHH}/SOUL.md exists but is empty — edit it to customize personality") else: - check_warn("~/.hermes/SOUL.md not found", "(create it to give Hermes a custom personality)") + check_warn(f"{_DHH}/SOUL.md not found", "(create it to give Hermes a custom personality)") if should_fix: soul_path.parent.mkdir(parents=True, exist_ok=True) soul_path.write_text( @@ -337,13 +339,13 @@ def run_doctor(args): "You are Hermes, a helpful AI assistant.\n", encoding="utf-8", ) - check_ok("Created ~/.hermes/SOUL.md with basic template") + check_ok(f"Created {_DHH}/SOUL.md with basic template") fixed_count += 1 # Check memory directory memories_dir = hermes_home / "memories" if memories_dir.exists(): - check_ok("~/.hermes/memories/ directory exists") + check_ok(f"{_DHH}/memories/ directory exists") memory_file = memories_dir / "MEMORY.md" user_file = memories_dir / "USER.md" if memory_file.exists(): @@ -357,10 +359,10 @@ def run_doctor(args): else: check_info("USER.md not created yet (will be created when the agent first writes a memory)") else: - check_warn("~/.hermes/memories/ not found", "(will be created on first use)") + check_warn(f"{_DHH}/memories/ not found", "(will be created on first use)") if should_fix: memories_dir.mkdir(parents=True, exist_ok=True) - check_ok("Created ~/.hermes/memories/") + check_ok(f"Created {_DHH}/memories/") fixed_count += 1 # Check SQLite session store @@ -372,11 +374,11 @@ def run_doctor(args): cursor = conn.execute("SELECT COUNT(*) FROM sessions") count = cursor.fetchone()[0] conn.close() - check_ok(f"~/.hermes/state.db exists ({count} sessions)") + check_ok(f"{_DHH}/state.db exists ({count} sessions)") except Exception as e: - check_warn(f"~/.hermes/state.db exists but has issues: {e}") + check_warn(f"{_DHH}/state.db exists but has issues: {e}") else: - check_info("~/.hermes/state.db not created yet (will be created on first session)") + check_info(f"{_DHH}/state.db not created yet (will be created on first session)") _check_gateway_service_linger(issues) @@ -691,7 +693,7 @@ def run_doctor(args): if github_token: check_ok("GitHub token configured (authenticated API access)") else: - check_warn("No GITHUB_TOKEN", "(60 req/hr rate limit — set in ~/.hermes/.env for better rates)") + check_warn("No GITHUB_TOKEN", f"(60 req/hr rate limit — set in {_DHH}/.env for better rates)") # ========================================================================= # Honcho memory @@ -708,8 +710,8 @@ def run_doctor(args): check_warn("Honcho config not found", "run: hermes honcho setup") elif not hcfg.enabled: check_info(f"Honcho disabled (set enabled: true in {_honcho_cfg_path} to activate)") - elif not hcfg.api_key: - check_fail("Honcho API key not set", "run: hermes honcho setup") + elif not (hcfg.api_key or hcfg.base_url): + check_fail("Honcho API key or base URL not set", "run: hermes honcho setup") issues.append("No Honcho API key — run 'hermes honcho setup'") else: from honcho_integration.client import get_honcho_client, reset_honcho_client @@ -728,6 +730,53 @@ def run_doctor(args): except Exception as _e: check_warn("Honcho check failed", str(_e)) + # ========================================================================= + # Profiles + # ========================================================================= + try: + from hermes_cli.profiles import list_profiles, _get_wrapper_dir, profile_exists + import re as _re + + named_profiles = [p for p in list_profiles() if not p.is_default] + if named_profiles: + print() + print(color("◆ Profiles", Colors.CYAN, Colors.BOLD)) + check_ok(f"{len(named_profiles)} profile(s) found") + wrapper_dir = _get_wrapper_dir() + for p in named_profiles: + parts = [] + if p.gateway_running: + parts.append("gateway running") + if p.model: + parts.append(p.model[:30]) + if not (p.path / "config.yaml").exists(): + parts.append("⚠ missing config") + if not (p.path / ".env").exists(): + parts.append("no .env") + wrapper = wrapper_dir / p.name + if not wrapper.exists(): + parts.append("no alias") + status = ", ".join(parts) if parts else "configured" + check_ok(f" {p.name}: {status}") + + # Check for orphan wrappers + if wrapper_dir.is_dir(): + for wrapper in wrapper_dir.iterdir(): + if not wrapper.is_file(): + continue + try: + content = wrapper.read_text() + if "hermes -p" in content: + _m = _re.search(r"hermes -p (\S+)", content) + if _m and not profile_exists(_m.group(1)): + check_warn(f"Orphan alias: {wrapper.name} → profile '{_m.group(1)}' no longer exists") + except Exception: + pass + except ImportError: + pass + except Exception as _e: + logger.debug("Profile health check failed: %s", _e) + # ========================================================================= # Summary # ========================================================================= diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index ba13cb4e8..ba2922771 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -15,6 +15,8 @@ from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() from hermes_cli.config import get_env_value, get_hermes_home, save_env_value, is_managed, managed_error +# display_hermes_home is imported lazily at call sites to avoid ImportError +# when hermes_constants is cached from a pre-update version during `hermes update`. from hermes_cli.setup import ( print_header, print_info, print_success, print_warning, print_error, prompt, prompt_choice, prompt_yes_no, @@ -125,20 +127,43 @@ _SERVICE_BASE = "hermes-gateway" SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration" +def _profile_suffix() -> str: + """Derive a service-name suffix from the current HERMES_HOME. + + Returns ``""`` for the default ``~/.hermes``, the profile name for + ``~/.hermes/profiles/``, or a short hash for any other custom + HERMES_HOME path. + """ + import hashlib + import re + from pathlib import Path as _Path + home = get_hermes_home().resolve() + default = (_Path.home() / ".hermes").resolve() + if home == default: + return "" + # Detect ~/.hermes/profiles/ pattern → use the profile name + profiles_root = (default / "profiles").resolve() + try: + rel = home.relative_to(profiles_root) + parts = rel.parts + if len(parts) == 1 and re.match(r"^[a-z0-9][a-z0-9_-]{0,63}$", parts[0]): + return parts[0] + except ValueError: + pass + # Fallback: short hash for arbitrary HERMES_HOME paths + return hashlib.sha256(str(home).encode()).hexdigest()[:8] + + def get_service_name() -> str: """Derive a systemd service name scoped to this HERMES_HOME. Default ``~/.hermes`` returns ``hermes-gateway`` (backward compatible). - Any other HERMES_HOME appends a short hash so multiple installations - can each have their own systemd service without conflicting. + Profile ``~/.hermes/profiles/coder`` returns ``hermes-gateway-coder``. + Any other HERMES_HOME appends a short hash for uniqueness. """ - import hashlib - from pathlib import Path as _Path # local import to avoid monkeypatch interference - home = get_hermes_home().resolve() - default = (_Path.home() / ".hermes").resolve() - if home == default: + suffix = _profile_suffix() + if not suffix: return _SERVICE_BASE - suffix = hashlib.sha256(str(home).encode()).hexdigest()[:8] return f"{_SERVICE_BASE}-{suffix}" @@ -369,7 +394,14 @@ def print_systemd_linger_guidance() -> None: print(" sudo loginctl enable-linger $USER") def get_launchd_plist_path() -> Path: - return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist" + """Return the launchd plist path, scoped per profile. + + Default ``~/.hermes`` → ``ai.hermes.gateway.plist`` (backward compatible). + Profile ``~/.hermes/profiles/coder`` → ``ai.hermes.gateway-coder.plist``. + """ + suffix = _profile_suffix() + name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway" + return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist" def _detect_venv_dir() -> Path | None: """Detect the active virtualenv directory. @@ -420,6 +452,17 @@ def get_hermes_cli_path() -> str: # Systemd (Linux) # ============================================================================= +def _build_user_local_paths(home: Path, path_entries: list[str]) -> list[str]: + """Return user-local bin dirs that exist and aren't already in *path_entries*.""" + candidates = [ + str(home / ".local" / "bin"), # uv, uvx, pip-installed CLIs + str(home / ".cargo" / "bin"), # Rust/cargo tools + str(home / "go" / "bin"), # Go tools + str(home / ".npm-global" / "bin"), # npm global packages + ] + return [p for p in candidates if p not in path_entries and Path(p).exists()] + + def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) -> str: python_path = get_python_path() working_dir = str(PROJECT_ROOT) @@ -434,13 +477,16 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) resolved_node_dir = str(Path(resolved_node).resolve().parent) if resolved_node_dir not in path_entries: path_entries.append(resolved_node_dir) - path_entries.extend(["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]) - sane_path = ":".join(path_entries) hermes_home = str(get_hermes_home().resolve()) + common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"] + if system: username, group_name, home_dir = _system_service_identity(run_as_user) + path_entries.extend(_build_user_local_paths(Path(home_dir), path_entries)) + path_entries.extend(common_bin_paths) + sane_path = ":".join(path_entries) return f"""[Unit] Description={SERVICE_DESCRIPTION} After=network-online.target @@ -472,6 +518,9 @@ StandardError=journal WantedBy=multi-user.target """ + path_entries.extend(_build_user_local_paths(Path.home(), path_entries)) + path_entries.extend(common_bin_paths) + sane_path = ":".join(path_entries) return f"""[Unit] Description={SERVICE_DESCRIPTION} After=network.target @@ -752,18 +801,46 @@ def systemd_status(deep: bool = False, system: bool = False): # Launchd (macOS) # ============================================================================= +def get_launchd_label() -> str: + """Return the launchd service label, scoped per profile.""" + suffix = _profile_suffix() + return f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway" + + def generate_launchd_plist() -> str: python_path = get_python_path() working_dir = str(PROJECT_ROOT) + hermes_home = str(get_hermes_home().resolve()) log_dir = get_hermes_home() / "logs" log_dir.mkdir(parents=True, exist_ok=True) - + label = get_launchd_label() + # Build a sane PATH for the launchd plist. launchd provides only a + # minimal default (/usr/bin:/bin:/usr/sbin:/sbin) which misses Homebrew, + # nvm, cargo, etc. We prepend venv/bin and node_modules/.bin (matching + # the systemd unit), then capture the user's full shell PATH so every + # user-installed tool (node, ffmpeg, …) is reachable. + detected_venv = _detect_venv_dir() + venv_bin = str(detected_venv / "bin") if detected_venv else str(PROJECT_ROOT / "venv" / "bin") + venv_dir = str(detected_venv) if detected_venv else str(PROJECT_ROOT / "venv") + node_bin = str(PROJECT_ROOT / "node_modules" / ".bin") + # Resolve the directory containing the node binary (e.g. Homebrew, nvm) + # so it's explicitly in PATH even if the user's shell PATH changes later. + priority_dirs = [venv_bin, node_bin] + resolved_node = shutil.which("node") + if resolved_node: + resolved_node_dir = str(Path(resolved_node).resolve().parent) + if resolved_node_dir not in priority_dirs: + priority_dirs.append(resolved_node_dir) + sane_path = ":".join( + dict.fromkeys(priority_dirs + [p for p in os.environ.get("PATH", "").split(":") if p]) + ) + return f""" Label - ai.hermes.gateway + {label} ProgramArguments @@ -778,6 +855,16 @@ def generate_launchd_plist() -> str: WorkingDirectory {working_dir} + EnvironmentVariables + + PATH + {sane_path} + VIRTUAL_ENV + {venv_dir} + HERMES_HOME + {hermes_home} + + RunAtLoad @@ -850,7 +937,8 @@ def launchd_install(force: bool = False): print() print("Next steps:") print(" hermes gateway status # Check status") - print(" tail -f ~/.hermes/logs/gateway.log # View logs") + from hermes_constants import display_hermes_home as _dhh + print(f" tail -f {_dhh()}/logs/gateway.log # View logs") def launchd_uninstall(): plist_path = get_launchd_plist_path() @@ -863,20 +951,33 @@ def launchd_uninstall(): print("✓ Service uninstalled") def launchd_start(): - refresh_launchd_plist_if_needed() plist_path = get_launchd_plist_path() + label = get_launchd_label() + + # Self-heal if the plist is missing entirely (e.g., manual cleanup, failed upgrade) + if not plist_path.exists(): + print("↻ launchd plist missing; regenerating service definition") + plist_path.parent.mkdir(parents=True, exist_ok=True) + plist_path.write_text(generate_launchd_plist(), encoding="utf-8") + subprocess.run(["launchctl", "load", str(plist_path)], check=True) + subprocess.run(["launchctl", "start", label], check=True) + print("✓ Service started") + return + + refresh_launchd_plist_if_needed() try: - subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True) + subprocess.run(["launchctl", "start", label], check=True) except subprocess.CalledProcessError as e: - if e.returncode != 3 or not plist_path.exists(): + if e.returncode != 3: raise print("↻ launchd job was unloaded; reloading service definition") subprocess.run(["launchctl", "load", str(plist_path)], check=True) - subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True) + subprocess.run(["launchctl", "start", label], check=True) print("✓ Service started") def launchd_stop(): - subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True) + label = get_launchd_label() + subprocess.run(["launchctl", "stop", label], check=True) print("✓ Service stopped") def _wait_for_gateway_exit(timeout: float = 10.0, force_after: float = 5.0): @@ -931,8 +1032,9 @@ def launchd_restart(): def launchd_status(deep: bool = False): plist_path = get_launchd_plist_path() + label = get_launchd_label() result = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], + ["launchctl", "list", label], capture_output=True, text=True ) @@ -1220,6 +1322,59 @@ _PLATFORMS = [ "help": "The AppSecret from your DingTalk application credentials."}, ], }, + { + "key": "feishu", + "label": "Feishu / Lark", + "emoji": "🪽", + "token_var": "FEISHU_APP_ID", + "setup_instructions": [ + "1. Go to https://open.feishu.cn/ (or https://open.larksuite.com/ for Lark)", + "2. Create an app and copy the App ID and App Secret", + "3. Enable the Bot capability for the app", + "4. Choose WebSocket (recommended) or Webhook connection mode", + "5. Add the bot to a group chat or message it directly", + "6. Restrict access with FEISHU_ALLOWED_USERS for production use", + ], + "vars": [ + {"name": "FEISHU_APP_ID", "prompt": "App ID", "password": False, + "help": "The App ID from your Feishu/Lark application."}, + {"name": "FEISHU_APP_SECRET", "prompt": "App Secret", "password": True, + "help": "The App Secret from your Feishu/Lark application."}, + {"name": "FEISHU_DOMAIN", "prompt": "Domain — feishu or lark (default: feishu)", "password": False, + "help": "Use 'feishu' for Feishu China, or 'lark' for Lark international."}, + {"name": "FEISHU_CONNECTION_MODE", "prompt": "Connection mode — websocket or webhook (default: websocket)", "password": False, + "help": "websocket is recommended unless you specifically need webhook mode."}, + {"name": "FEISHU_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, or empty)", "password": False, + "is_allowlist": True, + "help": "Restrict which Feishu/Lark users can interact with the bot."}, + {"name": "FEISHU_HOME_CHANNEL", "prompt": "Home chat ID (optional, for cron/notifications)", "password": False, + "help": "Chat ID for scheduled results and notifications."}, + ], + }, + { + "key": "wecom", + "label": "WeCom (Enterprise WeChat)", + "emoji": "💬", + "token_var": "WECOM_BOT_ID", + "setup_instructions": [ + "1. Go to WeCom Admin Console → Applications → Create AI Bot", + "2. Copy the Bot ID and Secret from the bot's credentials page", + "3. The bot connects via WebSocket — no public endpoint needed", + "4. Add the bot to a group chat or message it directly in WeCom", + "5. Restrict access with WECOM_ALLOWED_USERS for production use", + ], + "vars": [ + {"name": "WECOM_BOT_ID", "prompt": "Bot ID", "password": False, + "help": "The Bot ID from your WeCom AI Bot."}, + {"name": "WECOM_SECRET", "prompt": "Secret", "password": True, + "help": "The secret from your WeCom AI Bot."}, + {"name": "WECOM_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, or empty)", "password": False, + "is_allowlist": True, + "help": "Restrict which WeCom users can interact with the bot."}, + {"name": "WECOM_HOME_CHANNEL", "prompt": "Home chat ID (optional, for cron/notifications)", "password": False, + "help": "Chat ID for scheduled results and notifications."}, + ], + }, ] @@ -1437,7 +1592,7 @@ def _is_service_running() -> bool: return False elif is_macos() and get_launchd_plist_path().exists(): result = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], + ["launchctl", "list", get_launchd_label()], capture_output=True, text=True ) return result.returncode == 0 diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 1e2eec481..95932013f 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -54,6 +54,71 @@ from typing import Optional PROJECT_ROOT = Path(__file__).parent.parent.resolve() sys.path.insert(0, str(PROJECT_ROOT)) +# --------------------------------------------------------------------------- +# Profile override — MUST happen before any hermes module import. +# +# Many modules cache HERMES_HOME at import time (module-level constants). +# We intercept --profile/-p from sys.argv here and set the env var so that +# every subsequent ``os.getenv("HERMES_HOME", ...)`` resolves correctly. +# The flag is stripped from sys.argv so argparse never sees it. +# Falls back to ~/.hermes/active_profile for sticky default. +# --------------------------------------------------------------------------- +def _apply_profile_override() -> None: + """Pre-parse --profile/-p and set HERMES_HOME before module imports.""" + argv = sys.argv[1:] + profile_name = None + consume = 0 + + # 1. Check for explicit -p / --profile flag + for i, arg in enumerate(argv): + if arg in ("--profile", "-p") and i + 1 < len(argv): + profile_name = argv[i + 1] + consume = 2 + break + elif arg.startswith("--profile="): + profile_name = arg.split("=", 1)[1] + consume = 1 + break + + # 2. If no flag, check ~/.hermes/active_profile + if profile_name is None: + try: + active_path = Path.home() / ".hermes" / "active_profile" + if active_path.exists(): + name = active_path.read_text().strip() + if name and name != "default": + profile_name = name + consume = 0 # don't strip anything from argv + except (UnicodeDecodeError, OSError): + pass # corrupted file, skip + + # 3. If we found a profile, resolve and set HERMES_HOME + if profile_name is not None: + try: + from hermes_cli.profiles import resolve_profile_env + hermes_home = resolve_profile_env(profile_name) + except (ValueError, FileNotFoundError) as exc: + print(f"Error: {exc}", file=sys.stderr) + sys.exit(1) + except Exception as exc: + # A bug in profiles.py must NEVER prevent hermes from starting + print(f"Warning: profile override failed ({exc}), using default", file=sys.stderr) + return + os.environ["HERMES_HOME"] = hermes_home + # Strip the flag from argv so argparse doesn't choke + if consume > 0: + for i, arg in enumerate(argv): + if arg in ("--profile", "-p"): + start = i + 1 # +1 because argv is sys.argv[1:] + sys.argv = sys.argv[:start] + sys.argv[start + consume:] + break + elif arg.startswith("--profile="): + start = i + 1 + sys.argv = sys.argv[:start] + sys.argv[start + 1:] + break + +_apply_profile_override() + # Load .env from ~/.hermes/.env first, then project root as dev fallback. # User-managed env files should override stale shell exports on restart. from hermes_cli.config import get_hermes_home @@ -795,6 +860,7 @@ def cmd_model(args): "ai-gateway": "AI Gateway", "kilocode": "Kilo Code", "alibaba": "Alibaba Cloud (DashScope)", + "huggingface": "Hugging Face", "custom": "Custom endpoint", } active_label = provider_labels.get(active, active) @@ -820,7 +886,8 @@ def cmd_model(args): ("opencode-zen", "OpenCode Zen (35+ curated models, pay-as-you-go)"), ("opencode-go", "OpenCode Go (open models, $10/month subscription)"), ("ai-gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"), - ("alibaba", "Alibaba Cloud / DashScope (Qwen models, Anthropic-compatible)"), + ("alibaba", "Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"), + ("huggingface", "Hugging Face Inference Providers (20+ open models)"), ] # Add user-defined custom providers from config.yaml @@ -830,8 +897,8 @@ def cmd_model(args): for entry in custom_providers_cfg: if not isinstance(entry, dict): continue - name = entry.get("name", "").strip() - base_url = entry.get("base_url", "").strip() + name = (entry.get("name") or "").strip() + base_url = (entry.get("base_url") or "").strip() if not name or not base_url: continue # Generate a stable key from the name @@ -893,7 +960,7 @@ def cmd_model(args): _model_flow_anthropic(config, current_model) elif selected_provider == "kimi-coding": _model_flow_kimi(config, current_model) - elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba"): + elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba", "huggingface"): _model_flow_api_key_provider(config, selected_provider, current_model) @@ -978,6 +1045,7 @@ def _model_flow_openrouter(config, current_model=""): cfg["model"] = model model["provider"] = "openrouter" model["base_url"] = OPENROUTER_BASE_URL + model["api_mode"] = "chat_completions" save_config(cfg) deactivate_provider() print(f"Default model set to: {selected} (via OpenRouter)") @@ -1238,6 +1306,7 @@ def _model_flow_custom(config): cfg["model"] = model model["provider"] = "custom" model["base_url"] = effective_url + model.pop("api_mode", None) # let runtime auto-detect from URL save_config(cfg) deactivate_provider() @@ -1539,6 +1608,18 @@ _PROVIDER_MODELS = { "google/gemini-3-pro-preview", "google/gemini-3-flash-preview", ], + # Curated HF model list — only agentic models that map to OpenRouter defaults. + # Format: HF model ID → OpenRouter equivalent noted in comment + "huggingface": [ + "Qwen/Qwen3.5-397B-A17B", # ↔ qwen/qwen3.5-plus + "Qwen/Qwen3.5-35B-A3B", # ↔ qwen/qwen3.5-35b-a3b + "deepseek-ai/DeepSeek-V3.2", # ↔ deepseek/deepseek-chat + "moonshotai/Kimi-K2.5", # ↔ moonshotai/kimi-k2.5 + "MiniMaxAI/MiniMax-M2.5", # ↔ minimax/minimax-m2.5 + "zai-org/GLM-5", # ↔ z-ai/glm-5 + "XiaomiMiMo/MiMo-V2-Flash", # ↔ xiaomi/mimo-v2-pro + "moonshotai/Kimi-K2-Thinking", # ↔ moonshotai/kimi-k2-thinking + ], } @@ -2007,6 +2088,7 @@ def _model_flow_kimi(config, current_model=""): cfg["model"] = model model["provider"] = provider_id model["base_url"] = effective_base + model.pop("api_mode", None) # let runtime auto-detect from URL save_config(cfg) deactivate_provider() @@ -2068,19 +2150,25 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""): save_env_value(base_url_env, override) effective_base = override - # Model selection — try live /models endpoint first, fall back to defaults - from hermes_cli.models import fetch_api_models - api_key_for_probe = existing_key or (get_env_value(key_env) if key_env else "") - live_models = fetch_api_models(api_key_for_probe, effective_base) + # Model selection — try live /models endpoint first, fall back to defaults. + # Providers with large live catalogs (100+ models) use a curated list instead + # so users see familiar model names rather than an overwhelming dump. + curated = _PROVIDER_MODELS.get(provider_id, []) + if curated and len(curated) >= 8: + # Curated list is substantial — use it directly, skip live probe + live_models = None + else: + from hermes_cli.models import fetch_api_models + api_key_for_probe = existing_key or (get_env_value(key_env) if key_env else "") + live_models = fetch_api_models(api_key_for_probe, effective_base) - if live_models: + if live_models and len(live_models) >= len(curated): model_list = live_models print(f" Found {len(model_list)} model(s) from {pconfig.name} API") else: - model_list = _PROVIDER_MODELS.get(provider_id, []) + model_list = curated if model_list: - print(" ⚠ Could not auto-detect models from API — showing defaults.") - print(" Use \"Enter custom model name\" if you don't see your model.") + print(f" Showing {len(model_list)} curated models — use \"Enter custom model name\" for others.") # else: no defaults either, will fall through to raw input if model_list: @@ -2107,6 +2195,7 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""): cfg["model"] = model model["provider"] = provider_id model["base_url"] = effective_base + model.pop("api_mode", None) # let runtime auto-detect from URL save_config(cfg) deactivate_provider() @@ -2138,7 +2227,8 @@ def _run_anthropic_oauth_flow(save_env_value): ): use_anthropic_claude_code_credentials(save_fn=save_env_value) print(" ✓ Claude Code credentials linked.") - print(" Hermes will use Claude's credential store directly instead of copying a setup-token into ~/.hermes/.env.") + from hermes_constants import display_hermes_home as _dhh_fn + print(f" Hermes will use Claude's credential store directly instead of copying a setup-token into {_dhh_fn()}/.env.") return True return False @@ -2356,6 +2446,12 @@ def cmd_cron(args): cron_command(args) +def cmd_webhook(args): + """Webhook subscription management.""" + from hermes_cli.webhook import webhook_command + webhook_command(args) + + def cmd_doctor(args): """Check configuration and dependencies.""" from hermes_cli.doctor import run_doctor @@ -2402,6 +2498,34 @@ def cmd_uninstall(args): run_uninstall(args) +def _clear_bytecode_cache(root: Path) -> int: + """Remove all __pycache__ directories under *root*. + + Stale .pyc files can cause ImportError after code updates when Python + loads a cached bytecode file that references names that no longer exist + (or don't yet exist) in the updated source. Clearing them forces Python + to recompile from the .py source on next import. + + Returns the number of directories removed. + """ + removed = 0 + for dirpath, dirnames, _ in os.walk(root): + # Skip venv / node_modules / .git entirely + dirnames[:] = [ + d for d in dirnames + if d not in ("venv", ".venv", "node_modules", ".git", ".worktrees") + ] + if os.path.basename(dirpath) == "__pycache__": + try: + import shutil as _shutil + _shutil.rmtree(dirpath) + removed += 1 + except OSError: + pass + dirnames.clear() # nothing left to recurse into + return removed + + def _update_via_zip(args): """Update Hermes Agent by downloading a ZIP archive. @@ -2443,7 +2567,7 @@ def _update_via_zip(args): break # Copy updated files over existing installation, preserving venv/node_modules/.git - preserve = {'venv', 'node_modules', '.git', '__pycache__', '.env'} + preserve = {'venv', 'node_modules', '.git', '.env'} update_count = 0 for item in os.listdir(extracted): if item in preserve: @@ -2466,6 +2590,11 @@ def _update_via_zip(args): except Exception as e: print(f"✗ ZIP update failed: {e}") sys.exit(1) + + # Clear stale bytecode after ZIP extraction + removed = _clear_bytecode_cache(PROJECT_ROOT) + if removed: + print(f" ✓ Cleared {removed} stale __pycache__ director{'y' if removed == 1 else 'ies'}") # Reinstall Python dependencies (try .[all] first for optional extras, # fall back to . if extras fail — mirrors the install script behavior) @@ -2487,8 +2616,18 @@ def _update_via_zip(args): ) else: # Use sys.executable to explicitly call the venv's pip module, - # avoiding PEP 668 'externally-managed-environment' errors on Debian/Ubuntu + # avoiding PEP 668 'externally-managed-environment' errors on Debian/Ubuntu. + # Some environments lose pip inside the venv; bootstrap it back with + # ensurepip before trying the editable install. pip_cmd = [sys.executable, "-m", "pip"] + try: + subprocess.run(pip_cmd + ["--version"], cwd=PROJECT_ROOT, check=True, capture_output=True) + except subprocess.CalledProcessError: + subprocess.run( + [sys.executable, "-m", "ensurepip", "--upgrade", "--default-pip"], + cwd=PROJECT_ROOT, + check=True, + ) try: subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True) except subprocess.CalledProcessError: @@ -2649,7 +2788,12 @@ def _restore_stashed_changes( print("Resolve conflicts manually, then run: git stash drop") print(f"Restore your changes with: git stash apply {stash_ref}") - sys.exit(1) + # In non-interactive mode (gateway /update), don't abort — the code + # update itself succeeded, only the stash restore had conflicts. + # Aborting would report the entire update as failed. + if prompt_user: + sys.exit(1) + return False stash_selector = _resolve_stash_selector(git_cmd, cwd, stash_ref) if stash_selector is None: @@ -2723,30 +2867,60 @@ def cmd_update(args): # Fetch and pull try: - print("→ Fetching updates...") git_cmd = ["git"] if sys.platform == "win32": git_cmd = ["git", "-c", "windows.appendAtomically=false"] - - subprocess.run(git_cmd + ["fetch", "origin"], cwd=PROJECT_ROOT, check=True) - - # Get current branch + + print("→ Fetching updates...") + fetch_result = subprocess.run( + git_cmd + ["fetch", "origin"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + if fetch_result.returncode != 0: + stderr = fetch_result.stderr.strip() + if "Could not resolve host" in stderr or "unable to access" in stderr: + print("✗ Network error — cannot reach the remote repository.") + print(f" {stderr.splitlines()[0]}" if stderr else "") + elif "Authentication failed" in stderr or "could not read Username" in stderr: + print("✗ Authentication failed — check your git credentials or SSH key.") + else: + print(f"✗ Failed to fetch updates from origin.") + if stderr: + print(f" {stderr.splitlines()[0]}") + sys.exit(1) + + # Get current branch (returns literal "HEAD" when detached) result = subprocess.run( git_cmd + ["rev-parse", "--abbrev-ref", "HEAD"], cwd=PROJECT_ROOT, capture_output=True, text=True, - check=True + check=True, ) - branch = result.stdout.strip() + current_branch = result.stdout.strip() - # Fall back to main if the current branch doesn't exist on the remote - verify = subprocess.run( - git_cmd + ["rev-parse", "--verify", f"origin/{branch}"], - cwd=PROJECT_ROOT, capture_output=True, text=True, - ) - if verify.returncode != 0: - branch = "main" + # Always update against main + branch = "main" + + # If user is on a non-main branch or detached HEAD, switch to main + if current_branch != "main": + label = "detached HEAD" if current_branch == "HEAD" else f"branch '{current_branch}'" + print(f" ⚠ Currently on {label} — switching to main for update...") + # Stash before checkout so uncommitted work isn't lost + auto_stash_ref = _stash_local_changes_if_needed(git_cmd, PROJECT_ROOT) + subprocess.run( + git_cmd + ["checkout", "main"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + check=True, + ) + else: + auto_stash_ref = _stash_local_changes_if_needed(git_cmd, PROJECT_ROOT) + + prompt_for_restore = auto_stash_ref is not None and sys.stdin.isatty() and sys.stdout.isatty() # Check if there are updates result = subprocess.run( @@ -2754,33 +2928,78 @@ def cmd_update(args): cwd=PROJECT_ROOT, capture_output=True, text=True, - check=True + check=True, ) commit_count = int(result.stdout.strip()) - + if commit_count == 0: _invalidate_update_cache() - print("✓ Already up to date!") - return - - print(f"→ Found {commit_count} new commit(s)") - - auto_stash_ref = _stash_local_changes_if_needed(git_cmd, PROJECT_ROOT) - prompt_for_restore = auto_stash_ref is not None and sys.stdin.isatty() and sys.stdout.isatty() - - print("→ Pulling updates...") - try: - subprocess.run(git_cmd + ["pull", "--ff-only", "origin", branch], cwd=PROJECT_ROOT, check=True) - finally: + # Restore stash and switch back to original branch if we moved if auto_stash_ref is not None: _restore_stashed_changes( - git_cmd, - PROJECT_ROOT, - auto_stash_ref, + git_cmd, PROJECT_ROOT, auto_stash_ref, prompt_user=prompt_for_restore, ) + if current_branch not in ("main", "HEAD"): + subprocess.run( + git_cmd + ["checkout", current_branch], + cwd=PROJECT_ROOT, capture_output=True, text=True, check=False, + ) + print("✓ Already up to date!") + return + + print(f"→ Found {commit_count} new commit(s)") + + print("→ Pulling updates...") + update_succeeded = False + try: + pull_result = subprocess.run( + git_cmd + ["pull", "--ff-only", "origin", branch], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + if pull_result.returncode != 0: + # ff-only failed — local and remote have diverged (e.g. upstream + # force-pushed or rebase). Since local changes are already + # stashed, reset to match the remote exactly. + print(" ⚠ Fast-forward not possible (history diverged), resetting to match remote...") + reset_result = subprocess.run( + git_cmd + ["reset", "--hard", f"origin/{branch}"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + if reset_result.returncode != 0: + print(f"✗ Failed to reset to origin/{branch}.") + if reset_result.stderr.strip(): + print(f" {reset_result.stderr.strip()}") + print(" Try manually: git fetch origin && git reset --hard origin/main") + sys.exit(1) + update_succeeded = True + finally: + if auto_stash_ref is not None: + # Don't attempt stash restore if the code update itself failed — + # working tree is in an unknown state. + if not update_succeeded: + print(f" ℹ️ Local changes preserved in stash (ref: {auto_stash_ref})") + print(f" Restore manually with: git stash apply") + else: + _restore_stashed_changes( + git_cmd, + PROJECT_ROOT, + auto_stash_ref, + prompt_user=prompt_for_restore, + ) _invalidate_update_cache() + + # Clear stale .pyc bytecode cache — prevents ImportError on gateway + # restart when updated source references names that didn't exist in + # the old bytecode (e.g. get_hermes_home added to hermes_constants). + removed = _clear_bytecode_cache(PROJECT_ROOT) + if removed: + print(f" ✓ Cleared {removed} stale __pycache__ director{'y' if removed == 1 else 'ies'}") # Reinstall Python dependencies (try .[all] first for optional extras, # fall back to . if extras fail — mirrors the install script behavior) @@ -2801,8 +3020,18 @@ def cmd_update(args): ) else: # Use sys.executable to explicitly call the venv's pip module, - # avoiding PEP 668 'externally-managed-environment' errors on Debian/Ubuntu + # avoiding PEP 668 'externally-managed-environment' errors on Debian/Ubuntu. + # Some environments lose pip inside the venv; bootstrap it back with + # ensurepip before trying the editable install. pip_cmd = [sys.executable, "-m", "pip"] + try: + subprocess.run(pip_cmd + ["--version"], cwd=PROJECT_ROOT, check=True, capture_output=True) + except subprocess.CalledProcessError: + subprocess.run( + [sys.executable, "-m", "ensurepip", "--upgrade", "--default-pip"], + cwd=PROJECT_ROOT, + check=True, + ) try: subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True) except subprocess.CalledProcessError: @@ -2819,6 +3048,17 @@ def cmd_update(args): print() print("✓ Code updated!") + # After git pull, source files on disk are newer than cached Python + # modules in this process. Reload hermes_constants so that any lazy + # import executed below (skills sync, gateway restart) sees new + # attributes like display_hermes_home() added since the last release. + try: + import importlib + import hermes_constants as _hc + importlib.reload(_hc) + except Exception: + pass # non-fatal — worst case a lazy import fails gracefully + # Sync bundled skills (copies new, updates changed, respects user deletions) try: from tools.skills_sync import sync_skills @@ -2837,7 +3077,35 @@ def cmd_update(args): print(" ✓ Skills are up to date") except Exception as e: logger.debug("Skills sync during update failed: %s", e) - + + # Sync bundled skills to all other profiles + try: + from hermes_cli.profiles import list_profiles, get_active_profile_name, seed_profile_skills + active = get_active_profile_name() + other_profiles = [p for p in list_profiles() if not p.is_default and p.name != active] + if other_profiles: + print() + print("→ Syncing bundled skills to other profiles...") + for p in other_profiles: + try: + r = seed_profile_skills(p.path, quiet=True) + if r: + copied = len(r.get("copied", [])) + updated = len(r.get("updated", [])) + modified = len(r.get("user_modified", [])) + parts = [] + if copied: parts.append(f"+{copied} new") + if updated: parts.append(f"↑{updated} updated") + if modified: parts.append(f"~{modified} user-modified") + status = ", ".join(parts) if parts else "up to date" + else: + status = "sync failed" + print(f" {p.name}: {status}") + except Exception as pe: + print(f" {p.name}: error ({pe})") + except Exception: + pass # profiles module not available or no profiles + # Check for config migrations print() print("→ Checking configuration for new options...") @@ -2861,10 +3129,15 @@ def cmd_update(args): print(f" ℹ️ {len(missing_config)} new config option(s) available") print() - if sys.stdin.isatty(): - response = input("Would you like to configure them now? [Y/n]: ").strip().lower() - else: + if not (sys.stdin.isatty() and sys.stdout.isatty()): + print(" ℹ Non-interactive session — skipping config migration prompt.") + print(" Run 'hermes config migrate' later to apply any new config/env options.") response = "n" + else: + try: + response = input("Would you like to configure them now? [Y/n]: ").strip().lower() + except EOFError: + response = "n" if response in ('', 'y', 'yes'): print() @@ -2912,10 +3185,11 @@ def cmd_update(args): # Check for macOS launchd service if is_macos(): try: + from hermes_cli.gateway import get_launchd_label plist_path = get_launchd_plist_path() if plist_path.exists(): check = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], + ["launchctl", "list", get_launchd_label()], capture_output=True, text=True, timeout=5, ) has_launchd_service = check.returncode == 0 @@ -2971,12 +3245,13 @@ def cmd_update(args): # after a manual SIGTERM, which would race with the # PID file cleanup. print("→ Restarting gateway service...") + _launchd_label = get_launchd_label() stop = subprocess.run( - ["launchctl", "stop", "ai.hermes.gateway"], + ["launchctl", "stop", _launchd_label], capture_output=True, text=True, timeout=10, ) start = subprocess.run( - ["launchctl", "start", "ai.hermes.gateway"], + ["launchctl", "start", _launchd_label], capture_output=True, text=True, timeout=10, ) if start.returncode == 0: @@ -3028,6 +3303,7 @@ def _coalesce_session_name_args(argv: list) -> list: "chat", "model", "gateway", "setup", "whatsapp", "login", "logout", "status", "cron", "doctor", "config", "pairing", "skills", "tools", "mcp", "sessions", "insights", "version", "update", "uninstall", + "profile", } _SESSION_FLAGS = {"-c", "--continue", "-r", "--resume"} @@ -3051,6 +3327,253 @@ def _coalesce_session_name_args(argv: list) -> list: return result +def cmd_profile(args): + """Profile management — create, delete, list, switch, alias.""" + from hermes_cli.profiles import ( + list_profiles, create_profile, delete_profile, seed_profile_skills, + get_active_profile, set_active_profile, get_active_profile_name, + check_alias_collision, create_wrapper_script, remove_wrapper_script, + _is_wrapper_dir_in_path, _get_wrapper_dir, + ) + from hermes_constants import display_hermes_home + + action = getattr(args, "profile_action", None) + + if action is None: + # Bare `hermes profile` — show current profile status + profile_name = get_active_profile_name() + dhh = display_hermes_home() + print(f"\nActive profile: {profile_name}") + print(f"Path: {dhh}") + + profiles = list_profiles() + for p in profiles: + if p.name == profile_name or (profile_name == "default" and p.is_default): + if p.model: + print(f"Model: {p.model}" + (f" ({p.provider})" if p.provider else "")) + print(f"Gateway: {'running' if p.gateway_running else 'stopped'}") + print(f"Skills: {p.skill_count} installed") + if p.alias_path: + print(f"Alias: {p.name} → hermes -p {p.name}") + break + print() + return + + if action == "list": + profiles = list_profiles() + active = get_active_profile_name() + + if not profiles: + print("No profiles found.") + return + + # Header + print(f"\n {'Profile':<16} {'Model':<28} {'Gateway':<12} {'Alias'}") + print(f" {'─' * 15} {'─' * 27} {'─' * 11} {'─' * 12}") + + for p in profiles: + marker = " ◆" if (p.name == active or (active == "default" and p.is_default)) else " " + name = p.name + model = (p.model or "—")[:26] + gw = "running" if p.gateway_running else "stopped" + alias = p.name if p.alias_path else "—" + if p.is_default: + alias = "—" + print(f"{marker}{name:<15} {model:<28} {gw:<12} {alias}") + print() + + elif action == "use": + name = args.profile_name + try: + set_active_profile(name) + if name == "default": + print(f"Switched to: default (~/.hermes)") + else: + print(f"Switched to: {name}") + except (ValueError, FileNotFoundError) as e: + print(f"Error: {e}") + sys.exit(1) + + elif action == "create": + name = args.profile_name + clone = getattr(args, "clone", False) + clone_all = getattr(args, "clone_all", False) + no_alias = getattr(args, "no_alias", False) + + try: + clone_from = getattr(args, "clone_from", None) + + profile_dir = create_profile( + name=name, + clone_from=clone_from, + clone_all=clone_all, + clone_config=clone, + no_alias=no_alias, + ) + print(f"\nProfile '{name}' created at {profile_dir}") + + if clone or clone_all: + source_label = getattr(args, "clone_from", None) or get_active_profile_name() + if clone_all: + print(f"Full copy from {source_label}.") + else: + print(f"Cloned config, .env, SOUL.md from {source_label}.") + + # Seed bundled skills (skip if --clone-all already copied them) + if not clone_all: + result = seed_profile_skills(profile_dir) + if result: + copied = len(result.get("copied", [])) + print(f"{copied} bundled skills synced.") + else: + print("⚠ Skills could not be seeded. Run `{} update` to retry.".format(name)) + + # Create wrapper alias + if not no_alias: + collision = check_alias_collision(name) + if collision: + print(f"\n⚠ Cannot create alias '{name}' — {collision}") + print(f" Choose a custom alias: hermes profile alias {name} --name ") + print(f" Or access via flag: hermes -p {name} chat") + else: + wrapper_path = create_wrapper_script(name) + if wrapper_path: + print(f"Wrapper created: {wrapper_path}") + if not _is_wrapper_dir_in_path(): + print(f"\n⚠ {_get_wrapper_dir()} is not in your PATH.") + print(f' Add to your shell config (~/.bashrc or ~/.zshrc):') + print(f' export PATH="$HOME/.local/bin:$PATH"') + + # Next steps + print(f"\nNext steps:") + print(f" {name} setup Configure API keys and model") + print(f" {name} chat Start chatting") + print(f" {name} gateway start Start the messaging gateway") + if clone or clone_all: + from hermes_constants import get_hermes_home + profile_dir_display = f"~/.hermes/profiles/{name}" + print(f"\n Edit {profile_dir_display}/.env for different API keys") + print(f" Edit {profile_dir_display}/SOUL.md for different personality") + print() + + except (ValueError, FileExistsError, FileNotFoundError) as e: + print(f"Error: {e}") + sys.exit(1) + + elif action == "delete": + name = args.profile_name + yes = getattr(args, "yes", False) + try: + delete_profile(name, yes=yes) + except (ValueError, FileNotFoundError) as e: + print(f"Error: {e}") + sys.exit(1) + + elif action == "show": + name = args.profile_name + from hermes_cli.profiles import get_profile_dir, profile_exists, _read_config_model, _check_gateway_running, _count_skills + if not profile_exists(name): + print(f"Error: Profile '{name}' does not exist.") + sys.exit(1) + profile_dir = get_profile_dir(name) + model, provider = _read_config_model(profile_dir) + gw = _check_gateway_running(profile_dir) + skills = _count_skills(profile_dir) + wrapper = _get_wrapper_dir() / name + + print(f"\nProfile: {name}") + print(f"Path: {profile_dir}") + if model: + print(f"Model: {model}" + (f" ({provider})" if provider else "")) + print(f"Gateway: {'running' if gw else 'stopped'}") + print(f"Skills: {skills}") + print(f".env: {'exists' if (profile_dir / '.env').exists() else 'not configured'}") + print(f"SOUL.md: {'exists' if (profile_dir / 'SOUL.md').exists() else 'not configured'}") + if wrapper.exists(): + print(f"Alias: {wrapper}") + print() + + elif action == "alias": + name = args.profile_name + remove = getattr(args, "remove", False) + custom_name = getattr(args, "alias_name", None) + + from hermes_cli.profiles import profile_exists + if not profile_exists(name): + print(f"Error: Profile '{name}' does not exist.") + sys.exit(1) + + alias_name = custom_name or name + + if remove: + if remove_wrapper_script(alias_name): + print(f"✓ Removed alias '{alias_name}'") + else: + print(f"No alias '{alias_name}' found to remove.") + else: + collision = check_alias_collision(alias_name) + if collision: + print(f"Error: {collision}") + sys.exit(1) + wrapper_path = create_wrapper_script(alias_name) + if wrapper_path: + # If custom name, write the profile name into the wrapper + if custom_name: + wrapper_path.write_text(f'#!/bin/sh\nexec hermes -p {name} "$@"\n') + print(f"✓ Alias created: {wrapper_path}") + if not _is_wrapper_dir_in_path(): + print(f"⚠ {_get_wrapper_dir()} is not in your PATH.") + + elif action == "rename": + from hermes_cli.profiles import rename_profile + try: + new_dir = rename_profile(args.old_name, args.new_name) + print(f"\nProfile renamed: {args.old_name} → {args.new_name}") + print(f"Path: {new_dir}\n") + except (ValueError, FileExistsError, FileNotFoundError) as e: + print(f"Error: {e}") + sys.exit(1) + + elif action == "export": + from hermes_cli.profiles import export_profile + name = args.profile_name + output = args.output or f"{name}.tar.gz" + try: + result_path = export_profile(name, output) + print(f"✓ Exported '{name}' to {result_path}") + except (ValueError, FileNotFoundError) as e: + print(f"Error: {e}") + sys.exit(1) + + elif action == "import": + from hermes_cli.profiles import import_profile + try: + profile_dir = import_profile(args.archive, name=getattr(args, "import_name", None)) + name = profile_dir.name + print(f"✓ Imported profile '{name}' at {profile_dir}") + + # Offer to create alias + collision = check_alias_collision(name) + if not collision: + wrapper_path = create_wrapper_script(name) + if wrapper_path: + print(f" Wrapper created: {wrapper_path}") + print() + except (ValueError, FileExistsError, FileNotFoundError) as e: + print(f"Error: {e}") + sys.exit(1) + + +def cmd_completion(args): + """Print shell completion script.""" + from hermes_cli.profiles import generate_bash_completion, generate_zsh_completion + shell = getattr(args, "shell", "bash") + if shell == "zsh": + print(generate_zsh_completion()) + else: + print(generate_bash_completion()) + + def main(): """Main entry point for hermes CLI.""" parser = argparse.ArgumentParser( @@ -3159,7 +3682,7 @@ For more help on a command: ) chat_parser.add_argument( "--provider", - choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode"], + choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode"], default=None, help="Inference provider (default: auto)" ) @@ -3498,7 +4021,38 @@ For more help on a command: cron_subparsers.add_parser("tick", help="Run due jobs once and exit") cron_parser.set_defaults(func=cmd_cron) - + + # ========================================================================= + # webhook command + # ========================================================================= + webhook_parser = subparsers.add_parser( + "webhook", + help="Manage dynamic webhook subscriptions", + description="Create, list, and remove webhook subscriptions for event-driven agent activation", + ) + webhook_subparsers = webhook_parser.add_subparsers(dest="webhook_action") + + wh_sub = webhook_subparsers.add_parser("subscribe", aliases=["add"], help="Create a webhook subscription") + wh_sub.add_argument("name", help="Route name (used in URL: /webhooks/)") + wh_sub.add_argument("--prompt", default="", help="Prompt template with {dot.notation} payload refs") + wh_sub.add_argument("--events", default="", help="Comma-separated event types to accept") + wh_sub.add_argument("--description", default="", help="What this subscription does") + wh_sub.add_argument("--skills", default="", help="Comma-separated skill names to load") + wh_sub.add_argument("--deliver", default="log", help="Delivery target: log, telegram, discord, slack, etc.") + wh_sub.add_argument("--deliver-chat-id", default="", help="Target chat ID for cross-platform delivery") + wh_sub.add_argument("--secret", default="", help="HMAC secret (auto-generated if omitted)") + + webhook_subparsers.add_parser("list", aliases=["ls"], help="List all dynamic subscriptions") + + wh_rm = webhook_subparsers.add_parser("remove", aliases=["rm"], help="Remove a subscription") + wh_rm.add_argument("name", help="Subscription name to remove") + + wh_test = webhook_subparsers.add_parser("test", help="Send a test POST to a webhook route") + wh_test.add_argument("name", help="Subscription name to test") + wh_test.add_argument("--payload", default="", help="JSON payload to send (default: test payload)") + + webhook_parser.set_defaults(func=cmd_webhook) + # ========================================================================= # doctor command # ========================================================================= @@ -3631,7 +4185,7 @@ For more help on a command: skills_snapshot = skills_subparsers.add_parser("snapshot", help="Export/import skill configurations") snapshot_subparsers = skills_snapshot.add_subparsers(dest="snapshot_action") snap_export = snapshot_subparsers.add_parser("export", help="Export installed skills to a file") - snap_export.add_argument("output", help="Output JSON file path") + snap_export.add_argument("output", help="Output JSON file path (use - for stdout)") snap_import = snapshot_subparsers.add_parser("import", help="Import and install skills from a file") snap_import.add_argument("input", help="Input JSON file path") snap_import.add_argument("--force", action="store_true", help="Force install despite caution verdict") @@ -3692,6 +4246,16 @@ For more help on a command: plugins_subparsers.add_parser("list", aliases=["ls"], help="List installed plugins") + plugins_enable = plugins_subparsers.add_parser( + "enable", help="Enable a disabled plugin" + ) + plugins_enable.add_argument("name", help="Plugin name to enable") + + plugins_disable = plugins_subparsers.add_parser( + "disable", help="Disable a plugin without removing it" + ) + plugins_disable.add_argument("name", help="Plugin name to disable") + def cmd_plugins(args): from hermes_cli.plugins_cmd import plugins_command plugins_command(args) @@ -3859,16 +4423,25 @@ For more help on a command: # ========================================================================= mcp_parser = subparsers.add_parser( "mcp", - help="Manage MCP server connections", + help="Manage MCP servers and run Hermes as an MCP server", description=( - "Add, remove, list, test, and configure MCP server connections.\n\n" + "Manage MCP server connections and run Hermes as an MCP server.\n\n" "MCP servers provide additional tools via the Model Context Protocol.\n" - "Use 'hermes mcp add' to connect to a new server with interactive\n" - "tool discovery. Run 'hermes mcp' with no subcommand to list servers." + "Use 'hermes mcp add' to connect to a new server, or\n" + "'hermes mcp serve' to expose Hermes conversations over MCP." ), ) mcp_sub = mcp_parser.add_subparsers(dest="mcp_action") + mcp_serve_p = mcp_sub.add_parser( + "serve", + help="Run Hermes as an MCP server (expose conversations to other agents)", + ) + mcp_serve_p.add_argument( + "-v", "--verbose", action="store_true", + help="Enable verbose logging on stderr", + ) + mcp_add_p = mcp_sub.add_parser("add", help="Add an MCP server (discovery-first install)") mcp_add_p.add_argument("name", help="Server name (used as config key)") mcp_add_p.add_argument("--url", help="HTTP/SSE endpoint URL") @@ -3908,7 +4481,7 @@ For more help on a command: sessions_list.add_argument("--limit", type=int, default=20, help="Max sessions to show") sessions_export = sessions_subparsers.add_parser("export", help="Export sessions to a JSONL file") - sessions_export.add_argument("output", help="Output JSONL file path") + sessions_export.add_argument("output", help="Output JSONL file path (use - for stdout)") sessions_export.add_argument("--source", help="Filter by source") sessions_export.add_argument("--session-id", help="Export a specific session") @@ -3989,15 +4562,25 @@ For more help on a command: if not data: print(f"Session '{args.session_id}' not found.") return - with open(args.output, "w", encoding="utf-8") as f: - f.write(_json.dumps(data, ensure_ascii=False) + "\n") - print(f"Exported 1 session to {args.output}") + line = _json.dumps(data, ensure_ascii=False) + "\n" + if args.output == "-": + import sys + sys.stdout.write(line) + else: + with open(args.output, "w", encoding="utf-8") as f: + f.write(line) + print(f"Exported 1 session to {args.output}") else: sessions = db.export_all(source=args.source) - with open(args.output, "w", encoding="utf-8") as f: + if args.output == "-": + import sys for s in sessions: - f.write(_json.dumps(s, ensure_ascii=False) + "\n") - print(f"Exported {len(sessions)} sessions to {args.output}") + sys.stdout.write(_json.dumps(s, ensure_ascii=False) + "\n") + else: + with open(args.output, "w", encoding="utf-8") as f: + for s in sessions: + f.write(_json.dumps(s, ensure_ascii=False) + "\n") + print(f"Exported {len(sessions)} sessions to {args.output}") elif action == "delete": resolved_session_id = db.resolve_session_id(args.session_id) @@ -4235,7 +4818,75 @@ For more help on a command: sys.exit(1) acp_parser.set_defaults(func=cmd_acp) - + + # ========================================================================= + # profile command + # ========================================================================= + profile_parser = subparsers.add_parser( + "profile", + help="Manage profiles — multiple isolated Hermes instances", + ) + profile_subparsers = profile_parser.add_subparsers(dest="profile_action") + + profile_list = profile_subparsers.add_parser("list", help="List all profiles") + profile_use = profile_subparsers.add_parser("use", help="Set sticky default profile") + profile_use.add_argument("profile_name", help="Profile name (or 'default')") + + profile_create = profile_subparsers.add_parser("create", help="Create a new profile") + profile_create.add_argument("profile_name", help="Profile name (lowercase, alphanumeric)") + profile_create.add_argument("--clone", action="store_true", + help="Copy config.yaml, .env, SOUL.md from active profile") + profile_create.add_argument("--clone-all", action="store_true", + help="Full copy of active profile (all state)") + profile_create.add_argument("--clone-from", metavar="SOURCE", + help="Source profile to clone from (default: active)") + profile_create.add_argument("--no-alias", action="store_true", + help="Skip wrapper script creation") + + profile_delete = profile_subparsers.add_parser("delete", help="Delete a profile") + profile_delete.add_argument("profile_name", help="Profile to delete") + profile_delete.add_argument("-y", "--yes", action="store_true", + help="Skip confirmation prompt") + + profile_show = profile_subparsers.add_parser("show", help="Show profile details") + profile_show.add_argument("profile_name", help="Profile to show") + + profile_alias = profile_subparsers.add_parser("alias", help="Manage wrapper scripts") + profile_alias.add_argument("profile_name", help="Profile name") + profile_alias.add_argument("--remove", action="store_true", + help="Remove the wrapper script") + profile_alias.add_argument("--name", dest="alias_name", metavar="NAME", + help="Custom alias name (default: profile name)") + + profile_rename = profile_subparsers.add_parser("rename", help="Rename a profile") + profile_rename.add_argument("old_name", help="Current profile name") + profile_rename.add_argument("new_name", help="New profile name") + + profile_export = profile_subparsers.add_parser("export", help="Export a profile to archive") + profile_export.add_argument("profile_name", help="Profile to export") + profile_export.add_argument("-o", "--output", default=None, + help="Output file (default: .tar.gz)") + + profile_import = profile_subparsers.add_parser("import", help="Import a profile from archive") + profile_import.add_argument("archive", help="Path to .tar.gz archive") + profile_import.add_argument("--name", dest="import_name", metavar="NAME", + help="Profile name (default: inferred from archive)") + + profile_parser.set_defaults(func=cmd_profile) + + # ========================================================================= + # completion command + # ========================================================================= + completion_parser = subparsers.add_parser( + "completion", + help="Print shell completion script (bash or zsh)", + ) + completion_parser.add_argument( + "shell", nargs="?", default="bash", choices=["bash", "zsh"], + help="Shell type (default: bash)", + ) + completion_parser.set_defaults(func=cmd_completion) + # ========================================================================= # Parse and execute # ========================================================================= diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index 025bfd627..0f08e4673 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -24,6 +24,7 @@ from hermes_cli.config import ( get_hermes_home, # noqa: F401 — used by test mocks ) from hermes_cli.colors import Colors, color +from hermes_constants import display_hermes_home logger = logging.getLogger(__name__) @@ -244,7 +245,7 @@ def cmd_mcp_add(args): api_key = _prompt("API key / Bearer token", password=True) if api_key: save_env_value(env_key, api_key) - _success(f"Saved to ~/.hermes/.env as {env_key}") + _success(f"Saved to {display_hermes_home()}/.env as {env_key}") # Set header with env var interpolation if api_key or existing_key: @@ -332,7 +333,7 @@ def cmd_mcp_add(args): _save_mcp_server(name, server_config) print() - _success(f"Saved '{name}' to ~/.hermes/config.yaml ({tool_count}/{total} tools enabled)") + _success(f"Saved '{name}' to {display_hermes_home()}/config.yaml ({tool_count}/{total} tools enabled)") _info("Start a new session to use these tools.") @@ -607,6 +608,11 @@ def mcp_command(args): """Main dispatcher for ``hermes mcp`` subcommands.""" action = getattr(args, "mcp_action", None) + if action == "serve": + from mcp_serve import run_mcp_server + run_mcp_server(verbose=getattr(args, "verbose", False)) + return + handlers = { "add": cmd_mcp_add, "remove": cmd_mcp_remove, @@ -625,6 +631,7 @@ def mcp_command(args): # No subcommand — show list cmd_mcp_list() print(color(" Commands:", Colors.CYAN)) + _info("hermes mcp serve Run as MCP server") _info("hermes mcp add --url Add an MCP server") _info("hermes mcp add --command Add a stdio server") _info("hermes mcp remove Remove a server") diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 4aa2a3d21..ef2b3deb4 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -35,6 +35,8 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [ ("openai/gpt-5.3-codex", ""), ("google/gemini-3-pro-preview", ""), ("google/gemini-3-flash-preview", ""), + ("google/gemini-3.1-pro-preview", ""), + ("google/gemini-3.1-flash-lite-preview", ""), ("qwen/qwen3.5-plus-02-15", ""), ("qwen/qwen3.5-35b-a3b", ""), ("stepfun/step-3.5-flash", ""), @@ -62,6 +64,8 @@ _PROVIDER_MODELS: dict[str, list[str]] = { "openai/gpt-5.3-codex", "google/gemini-3-pro-preview", "google/gemini-3-flash-preview", + "google/gemini-3.1-pro-preview", + "google/gemini-3.1-flash-lite-preview", "qwen/qwen3.5-plus-02-15", "qwen/qwen3.5-35b-a3b", "stepfun/step-3.5-flash", @@ -208,14 +212,31 @@ _PROVIDER_MODELS: dict[str, list[str]] = { "google/gemini-3-pro-preview", "google/gemini-3-flash-preview", ], + # Alibaba DashScope Coding platform (coding-intl) — default endpoint. + # Supports Qwen models + third-party providers (GLM, Kimi, MiniMax). + # Users with classic DashScope keys should override DASHSCOPE_BASE_URL + # to https://dashscope-intl.aliyuncs.com/compatible-mode/v1 (OpenAI-compat) + # or https://dashscope-intl.aliyuncs.com/apps/anthropic (Anthropic-compat). "alibaba": [ "qwen3.5-plus", - "qwen3-max", "qwen3-coder-plus", "qwen3-coder-next", - "qwen-plus-latest", - "qwen3.5-flash", - "qwen-vl-max", + # Third-party models available on coding-intl + "glm-5", + "glm-4.7", + "kimi-k2.5", + "MiniMax-M2.5", + ], + # Curated HF model list — only agentic models that map to OpenRouter defaults. + "huggingface": [ + "Qwen/Qwen3.5-397B-A17B", + "Qwen/Qwen3.5-35B-A3B", + "deepseek-ai/DeepSeek-V3.2", + "moonshotai/Kimi-K2.5", + "MiniMaxAI/MiniMax-M2.5", + "zai-org/GLM-5", + "XiaomiMiMo/MiMo-V2-Flash", + "moonshotai/Kimi-K2-Thinking", ], } @@ -236,6 +257,7 @@ _PROVIDER_LABELS = { "ai-gateway": "AI Gateway", "kilocode": "Kilo Code", "alibaba": "Alibaba Cloud (DashScope)", + "huggingface": "Hugging Face", "custom": "Custom endpoint", } @@ -271,6 +293,9 @@ _PROVIDER_ALIASES = { "aliyun": "alibaba", "qwen": "alibaba", "alibaba-cloud": "alibaba", + "hf": "huggingface", + "hugging-face": "huggingface", + "huggingface-hub": "huggingface", } @@ -304,7 +329,7 @@ def list_available_providers() -> list[dict[str, str]]: # Canonical providers in display order _PROVIDER_ORDER = [ "openrouter", "nous", "openai-codex", "copilot", "copilot-acp", - "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba", + "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba", "opencode-zen", "opencode-go", "ai-gateway", "deepseek", "custom", ] diff --git a/hermes_cli/nous_subscription.py b/hermes_cli/nous_subscription.py index 063732235..5e4b36c30 100644 --- a/hermes_cli/nous_subscription.py +++ b/hermes_cli/nous_subscription.py @@ -14,6 +14,7 @@ from tools.tool_backend_helpers import ( managed_nous_tools_enabled, normalize_browser_cloud_provider, normalize_modal_mode, + resolve_modal_backend_state, resolve_openai_audio_api_key, ) @@ -185,6 +186,7 @@ def get_nous_subscription_features( else None ) + direct_exa = bool(get_env_value("EXA_API_KEY")) direct_firecrawl = bool(get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL")) direct_parallel = bool(get_env_value("PARALLEL_API_KEY")) direct_tavily = bool(get_env_value("TAVILY_API_KEY")) @@ -200,19 +202,25 @@ def get_nous_subscription_features( managed_tts_available = managed_tools_flag and nous_auth_present and is_managed_tool_gateway_ready("openai-audio") managed_browser_available = managed_tools_flag and nous_auth_present and is_managed_tool_gateway_ready("browserbase") managed_modal_available = managed_tools_flag and nous_auth_present and is_managed_tool_gateway_ready("modal") + modal_state = resolve_modal_backend_state( + modal_mode, + has_direct=direct_modal, + managed_ready=managed_modal_available, + ) web_managed = web_backend == "firecrawl" and managed_web_available and not direct_firecrawl web_active = bool( web_tool_enabled and ( web_managed + or (web_backend == "exa" and direct_exa) or (web_backend == "firecrawl" and direct_firecrawl) or (web_backend == "parallel" and direct_parallel) or (web_backend == "tavily" and direct_tavily) ) ) web_available = bool( - managed_web_available or direct_firecrawl or direct_parallel or direct_tavily + managed_web_available or direct_exa or direct_firecrawl or direct_parallel or direct_tavily ) image_managed = image_tool_enabled and managed_image_available and not direct_fal @@ -260,25 +268,31 @@ def get_nous_subscription_features( modal_available = True modal_active = bool(modal_tool_enabled) modal_direct_override = False + elif modal_state["selected_backend"] == "managed": + modal_managed = bool(modal_tool_enabled) + modal_available = True + modal_active = bool(modal_tool_enabled) + modal_direct_override = False + elif modal_state["selected_backend"] == "direct": + modal_managed = False + modal_available = True + modal_active = bool(modal_tool_enabled) + modal_direct_override = bool(modal_tool_enabled) elif modal_mode == "managed": - modal_managed = bool(modal_tool_enabled and managed_modal_available) + modal_managed = False modal_available = bool(managed_modal_available) - modal_active = bool(modal_tool_enabled and managed_modal_available) + modal_active = False modal_direct_override = False elif modal_mode == "direct": modal_managed = False modal_available = bool(direct_modal) - modal_active = bool(modal_tool_enabled and direct_modal) - modal_direct_override = bool(direct_modal) + modal_active = False + modal_direct_override = False else: - modal_managed = bool( - modal_tool_enabled - and managed_modal_available - and not direct_modal - ) + modal_managed = False modal_available = bool(managed_modal_available or direct_modal) - modal_active = bool(modal_tool_enabled and (direct_modal or managed_modal_available)) - modal_direct_override = bool(direct_modal) + modal_active = False + modal_direct_override = False tts_explicit_configured = False raw_tts_cfg = config.get("tts") diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index c5195ffa7..7779ab964 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -70,6 +70,17 @@ def _env_enabled(name: str) -> bool: return env_var_enabled(name) +def _get_disabled_plugins() -> set: + """Read the disabled plugins list from config.yaml.""" + try: + from hermes_cli.config import load_config + config = load_config() + disabled = config.get("plugins", {}).get("disabled", []) + return set(disabled) if isinstance(disabled, list) else set() + except Exception: + return set() + + # --------------------------------------------------------------------------- # Data classes # --------------------------------------------------------------------------- @@ -201,8 +212,15 @@ class PluginManager: # 3. Pip / entry-point plugins manifests.extend(self._scan_entry_points()) - # Load each manifest + # Load each manifest (skip user-disabled plugins) + disabled = _get_disabled_plugins() for manifest in manifests: + if manifest.name in disabled: + loaded = LoadedPlugin(manifest=manifest, enabled=False) + loaded.error = "disabled via config" + self._plugins[manifest.name] = loaded + logger.debug("Skipping disabled plugin '%s'", manifest.name) + continue self._load_plugin(manifest) if manifests: @@ -387,16 +405,23 @@ class PluginManager: # Hook invocation # ----------------------------------------------------------------------- - def invoke_hook(self, hook_name: str, **kwargs: Any) -> None: + def invoke_hook(self, hook_name: str, **kwargs: Any) -> List[Any]: """Call all registered callbacks for *hook_name*. Each callback is wrapped in its own try/except so a misbehaving plugin cannot break the core agent loop. + + Returns a list of non-``None`` return values from callbacks. + This allows hooks like ``pre_llm_call`` to contribute context + that the agent core can collect and inject. """ callbacks = self._hooks.get(hook_name, []) + results: List[Any] = [] for cb in callbacks: try: - cb(**kwargs) + ret = cb(**kwargs) + if ret is not None: + results.append(ret) except Exception as exc: logger.warning( "Hook '%s' callback %s raised: %s", @@ -404,6 +429,7 @@ class PluginManager: getattr(cb, "__name__", repr(cb)), exc, ) + return results # ----------------------------------------------------------------------- # Introspection @@ -448,9 +474,12 @@ def discover_plugins() -> None: get_plugin_manager().discover_and_load() -def invoke_hook(hook_name: str, **kwargs: Any) -> None: - """Invoke a lifecycle hook on all loaded plugins.""" - get_plugin_manager().invoke_hook(hook_name, **kwargs) +def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]: + """Invoke a lifecycle hook on all loaded plugins. + + Returns a list of non-``None`` return values from plugin callbacks. + """ + return get_plugin_manager().invoke_hook(hook_name, **kwargs) def get_plugin_tool_names() -> Set[str]: diff --git a/hermes_cli/plugins_cmd.py b/hermes_cli/plugins_cmd.py index e20c1e1b0..e53f5c94b 100644 --- a/hermes_cli/plugins_cmd.py +++ b/hermes_cli/plugins_cmd.py @@ -374,6 +374,73 @@ def cmd_remove(name: str) -> None: _display_removed(name, plugins_dir) +def _get_disabled_set() -> set: + """Read the disabled plugins set from config.yaml.""" + try: + from hermes_cli.config import load_config + config = load_config() + disabled = config.get("plugins", {}).get("disabled", []) + return set(disabled) if isinstance(disabled, list) else set() + except Exception: + return set() + + +def _save_disabled_set(disabled: set) -> None: + """Write the disabled plugins list to config.yaml.""" + from hermes_cli.config import load_config, save_config + config = load_config() + if "plugins" not in config: + config["plugins"] = {} + config["plugins"]["disabled"] = sorted(disabled) + save_config(config) + + +def cmd_enable(name: str) -> None: + """Enable a previously disabled plugin.""" + from rich.console import Console + + console = Console() + plugins_dir = _plugins_dir() + + # Verify the plugin exists + target = plugins_dir / name + if not target.is_dir(): + console.print(f"[red]Plugin '{name}' is not installed.[/red]") + sys.exit(1) + + disabled = _get_disabled_set() + if name not in disabled: + console.print(f"[dim]Plugin '{name}' is already enabled.[/dim]") + return + + disabled.discard(name) + _save_disabled_set(disabled) + console.print(f"[green]✓[/green] Plugin [bold]{name}[/bold] enabled. Takes effect on next session.") + + +def cmd_disable(name: str) -> None: + """Disable a plugin without removing it.""" + from rich.console import Console + + console = Console() + plugins_dir = _plugins_dir() + + # Verify the plugin exists + target = plugins_dir / name + if not target.is_dir(): + console.print(f"[red]Plugin '{name}' is not installed.[/red]") + sys.exit(1) + + disabled = _get_disabled_set() + if name in disabled: + console.print(f"[dim]Plugin '{name}' is already disabled.[/dim]") + return + + disabled.add(name) + _save_disabled_set(disabled) + console.print(f"[yellow]⊘[/yellow] Plugin [bold]{name}[/bold] disabled. Takes effect on next session.") + + def cmd_list() -> None: """List installed plugins.""" from rich.console import Console @@ -393,8 +460,11 @@ def cmd_list() -> None: console.print("[dim]Install with:[/dim] hermes plugins install owner/repo") return + disabled = _get_disabled_set() + table = Table(title="Installed Plugins", show_lines=False) table.add_column("Name", style="bold") + table.add_column("Status") table.add_column("Version", style="dim") table.add_column("Description") table.add_column("Source", style="dim") @@ -420,11 +490,86 @@ def cmd_list() -> None: if (d / ".git").exists(): source = "git" - table.add_row(name, str(version), description, source) + is_disabled = name in disabled or d.name in disabled + status = "[red]disabled[/red]" if is_disabled else "[green]enabled[/green]" + table.add_row(name, status, str(version), description, source) console.print() console.print(table) console.print() + console.print("[dim]Interactive toggle:[/dim] hermes plugins") + console.print("[dim]Enable/disable:[/dim] hermes plugins enable/disable ") + + +def cmd_toggle() -> None: + """Interactive curses checklist to enable/disable installed plugins.""" + from rich.console import Console + + try: + import yaml + except ImportError: + yaml = None + + console = Console() + plugins_dir = _plugins_dir() + + dirs = sorted(d for d in plugins_dir.iterdir() if d.is_dir()) + if not dirs: + console.print("[dim]No plugins installed.[/dim]") + console.print("[dim]Install with:[/dim] hermes plugins install owner/repo") + return + + disabled = _get_disabled_set() + + # Build items list: "name — description" for display + names = [] + labels = [] + selected = set() + + for i, d in enumerate(dirs): + manifest_file = d / "plugin.yaml" + name = d.name + description = "" + + if manifest_file.exists() and yaml: + try: + with open(manifest_file) as f: + manifest = yaml.safe_load(f) or {} + name = manifest.get("name", d.name) + description = manifest.get("description", "") + except Exception: + pass + + names.append(name) + label = f"{name} — {description}" if description else name + labels.append(label) + + if name not in disabled and d.name not in disabled: + selected.add(i) + + from hermes_cli.curses_ui import curses_checklist + + result = curses_checklist( + title="Plugins — toggle enabled/disabled", + items=labels, + selected=selected, + ) + + # Compute new disabled set from deselected items + new_disabled = set() + for i, name in enumerate(names): + if i not in result: + new_disabled.add(name) + + if new_disabled != disabled: + _save_disabled_set(new_disabled) + enabled_count = len(names) - len(new_disabled) + console.print( + f"\n[green]✓[/green] {enabled_count} enabled, {len(new_disabled)} disabled. " + f"Takes effect on next session." + ) + else: + console.print("\n[dim]No changes.[/dim]") def plugins_command(args) -> None: @@ -437,8 +582,14 @@ def plugins_command(args) -> None: cmd_update(args.name) elif action in ("remove", "rm", "uninstall"): cmd_remove(args.name) - elif action in ("list", "ls") or action is None: + elif action == "enable": + cmd_enable(args.name) + elif action == "disable": + cmd_disable(args.name) + elif action in ("list", "ls"): cmd_list() + elif action is None: + cmd_toggle() else: from rich.console import Console diff --git a/hermes_cli/profiles.py b/hermes_cli/profiles.py new file mode 100644 index 000000000..7ef39d105 --- /dev/null +++ b/hermes_cli/profiles.py @@ -0,0 +1,906 @@ +""" +Profile management for multiple isolated Hermes instances. + +Each profile is a fully independent HERMES_HOME directory with its own +config.yaml, .env, memory, sessions, skills, gateway, cron, and logs. +Profiles live under ``~/.hermes/profiles//`` by default. + +The "default" profile is ``~/.hermes`` itself — backward compatible, +zero migration needed. + +Usage:: + + hermes profile create coder # fresh profile + bundled skills + hermes profile create coder --clone # also copy config, .env, SOUL.md + hermes profile create coder --clone-all # full copy of source profile + coder chat # use via wrapper alias + hermes -p coder chat # or via flag + hermes profile use coder # set as sticky default + hermes profile delete coder # remove profile + alias + service +""" + +import json +import os +import re +import shutil +import stat +import subprocess +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +_PROFILE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$") + +# Directories bootstrapped inside every new profile +_PROFILE_DIRS = [ + "memories", + "sessions", + "skills", + "skins", + "logs", + "plans", + "workspace", + "cron", +] + +# Files copied during --clone (if they exist in the source) +_CLONE_CONFIG_FILES = [ + "config.yaml", + ".env", + "SOUL.md", +] + +# Runtime files stripped after --clone-all (shouldn't carry over) +_CLONE_ALL_STRIP = [ + "gateway.pid", + "gateway_state.json", + "processes.json", +] + +# Names that cannot be used as profile aliases +_RESERVED_NAMES = frozenset({ + "hermes", "default", "test", "tmp", "root", "sudo", +}) + +# Hermes subcommands that cannot be used as profile names/aliases +_HERMES_SUBCOMMANDS = frozenset({ + "chat", "model", "gateway", "setup", "whatsapp", "login", "logout", + "status", "cron", "doctor", "config", "pairing", "skills", "tools", + "mcp", "sessions", "insights", "version", "update", "uninstall", + "profile", "plugins", "honcho", "acp", +}) + + +# --------------------------------------------------------------------------- +# Path helpers +# --------------------------------------------------------------------------- + +def _get_profiles_root() -> Path: + """Return the directory where named profiles are stored. + + Always ``~/.hermes/profiles/`` — anchored to the user's home, + NOT to the current HERMES_HOME (which may itself be a profile). + This ensures ``coder profile list`` can see all profiles. + """ + return Path.home() / ".hermes" / "profiles" + + +def _get_default_hermes_home() -> Path: + """Return the default (pre-profile) HERMES_HOME path.""" + return Path.home() / ".hermes" + + +def _get_active_profile_path() -> Path: + """Return the path to the sticky active_profile file.""" + return _get_default_hermes_home() / "active_profile" + + +def _get_wrapper_dir() -> Path: + """Return the directory for wrapper scripts.""" + return Path.home() / ".local" / "bin" + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + +def validate_profile_name(name: str) -> None: + """Raise ``ValueError`` if *name* is not a valid profile identifier.""" + if name == "default": + return # special alias for ~/.hermes + if not _PROFILE_ID_RE.match(name): + raise ValueError( + f"Invalid profile name {name!r}. Must match " + f"[a-z0-9][a-z0-9_-]{{0,63}}" + ) + + +def get_profile_dir(name: str) -> Path: + """Resolve a profile name to its HERMES_HOME directory.""" + if name == "default": + return _get_default_hermes_home() + return _get_profiles_root() / name + + +def profile_exists(name: str) -> bool: + """Check whether a profile directory exists.""" + if name == "default": + return True + return get_profile_dir(name).is_dir() + + +# --------------------------------------------------------------------------- +# Alias / wrapper script management +# --------------------------------------------------------------------------- + +def check_alias_collision(name: str) -> Optional[str]: + """Return a human-readable collision message, or None if the name is safe. + + Checks: reserved names, hermes subcommands, existing binaries in PATH. + """ + if name in _RESERVED_NAMES: + return f"'{name}' is a reserved name" + if name in _HERMES_SUBCOMMANDS: + return f"'{name}' conflicts with a hermes subcommand" + + # Check existing commands in PATH + wrapper_dir = _get_wrapper_dir() + try: + result = subprocess.run( + ["which", name], capture_output=True, text=True, timeout=5, + ) + if result.returncode == 0: + existing_path = result.stdout.strip() + # Allow overwriting our own wrappers + if existing_path == str(wrapper_dir / name): + try: + content = (wrapper_dir / name).read_text() + if "hermes -p" in content: + return None # it's our wrapper, safe to overwrite + except Exception: + pass + return f"'{name}' conflicts with an existing command ({existing_path})" + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + + return None # safe + + +def _is_wrapper_dir_in_path() -> bool: + """Check if ~/.local/bin is in PATH.""" + wrapper_dir = str(_get_wrapper_dir()) + return wrapper_dir in os.environ.get("PATH", "").split(os.pathsep) + + +def create_wrapper_script(name: str) -> Optional[Path]: + """Create a shell wrapper script at ~/.local/bin/. + + Returns the path to the created wrapper, or None if creation failed. + """ + wrapper_dir = _get_wrapper_dir() + try: + wrapper_dir.mkdir(parents=True, exist_ok=True) + except OSError as e: + print(f"⚠ Could not create {wrapper_dir}: {e}") + return None + + wrapper_path = wrapper_dir / name + try: + wrapper_path.write_text(f'#!/bin/sh\nexec hermes -p {name} "$@"\n') + wrapper_path.chmod(wrapper_path.stat().st_mode | stat.S_IEXEC | stat.S_IXGRP | stat.S_IXOTH) + return wrapper_path + except OSError as e: + print(f"⚠ Could not create wrapper at {wrapper_path}: {e}") + return None + + +def remove_wrapper_script(name: str) -> bool: + """Remove the wrapper script for a profile. Returns True if removed.""" + wrapper_path = _get_wrapper_dir() / name + if wrapper_path.exists(): + try: + # Verify it's our wrapper before removing + content = wrapper_path.read_text() + if "hermes -p" in content: + wrapper_path.unlink() + return True + except Exception: + pass + return False + + +# --------------------------------------------------------------------------- +# ProfileInfo +# --------------------------------------------------------------------------- + +@dataclass +class ProfileInfo: + """Summary information about a profile.""" + name: str + path: Path + is_default: bool + gateway_running: bool + model: Optional[str] = None + provider: Optional[str] = None + has_env: bool = False + skill_count: int = 0 + alias_path: Optional[Path] = None + + +def _read_config_model(profile_dir: Path) -> tuple: + """Read model/provider from a profile's config.yaml. Returns (model, provider).""" + config_path = profile_dir / "config.yaml" + if not config_path.exists(): + return None, None + try: + import yaml + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) or {} + model_cfg = cfg.get("model", {}) + if isinstance(model_cfg, str): + return model_cfg, None + if isinstance(model_cfg, dict): + return model_cfg.get("model"), model_cfg.get("provider") + return None, None + except Exception: + return None, None + + +def _check_gateway_running(profile_dir: Path) -> bool: + """Check if a gateway is running for a given profile directory.""" + pid_file = profile_dir / "gateway.pid" + if not pid_file.exists(): + return False + try: + raw = pid_file.read_text().strip() + if not raw: + return False + data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)} + pid = int(data["pid"]) + os.kill(pid, 0) # existence check + return True + except (json.JSONDecodeError, KeyError, ValueError, TypeError, + ProcessLookupError, PermissionError, OSError): + return False + + +def _count_skills(profile_dir: Path) -> int: + """Count installed skills in a profile.""" + skills_dir = profile_dir / "skills" + if not skills_dir.is_dir(): + return 0 + count = 0 + for md in skills_dir.rglob("SKILL.md"): + if "/.hub/" not in str(md) and "/.git/" not in str(md): + count += 1 + return count + + +# --------------------------------------------------------------------------- +# CRUD operations +# --------------------------------------------------------------------------- + +def list_profiles() -> List[ProfileInfo]: + """Return info for all profiles, including the default.""" + profiles = [] + wrapper_dir = _get_wrapper_dir() + + # Default profile + default_home = _get_default_hermes_home() + if default_home.is_dir(): + model, provider = _read_config_model(default_home) + profiles.append(ProfileInfo( + name="default", + path=default_home, + is_default=True, + gateway_running=_check_gateway_running(default_home), + model=model, + provider=provider, + has_env=(default_home / ".env").exists(), + skill_count=_count_skills(default_home), + )) + + # Named profiles + profiles_root = _get_profiles_root() + if profiles_root.is_dir(): + for entry in sorted(profiles_root.iterdir()): + if not entry.is_dir(): + continue + name = entry.name + if not _PROFILE_ID_RE.match(name): + continue + model, provider = _read_config_model(entry) + alias_path = wrapper_dir / name + profiles.append(ProfileInfo( + name=name, + path=entry, + is_default=False, + gateway_running=_check_gateway_running(entry), + model=model, + provider=provider, + has_env=(entry / ".env").exists(), + skill_count=_count_skills(entry), + alias_path=alias_path if alias_path.exists() else None, + )) + + return profiles + + +def create_profile( + name: str, + clone_from: Optional[str] = None, + clone_all: bool = False, + clone_config: bool = False, + no_alias: bool = False, +) -> Path: + """Create a new profile directory. + + Parameters + ---------- + name: + Profile identifier (lowercase, alphanumeric, hyphens, underscores). + clone_from: + Source profile to clone from. If ``None`` and clone_config/clone_all + is True, defaults to the currently active profile. + clone_all: + If True, do a full copytree of the source (all state). + clone_config: + If True, copy only config files (config.yaml, .env, SOUL.md). + no_alias: + If True, skip wrapper script creation. + + Returns + ------- + Path + The newly created profile directory. + """ + validate_profile_name(name) + + if name == "default": + raise ValueError( + "Cannot create a profile named 'default' — it is the built-in profile (~/.hermes)." + ) + + profile_dir = get_profile_dir(name) + if profile_dir.exists(): + raise FileExistsError(f"Profile '{name}' already exists at {profile_dir}") + + # Resolve clone source + source_dir = None + if clone_from is not None or clone_all or clone_config: + if clone_from is None: + # Default: clone from active profile + from hermes_constants import get_hermes_home + source_dir = get_hermes_home() + else: + validate_profile_name(clone_from) + source_dir = get_profile_dir(clone_from) + if not source_dir.is_dir(): + raise FileNotFoundError( + f"Source profile '{clone_from or 'active'}' does not exist at {source_dir}" + ) + + if clone_all and source_dir: + # Full copy of source profile + shutil.copytree(source_dir, profile_dir) + # Strip runtime files + for stale in _CLONE_ALL_STRIP: + (profile_dir / stale).unlink(missing_ok=True) + else: + # Bootstrap directory structure + profile_dir.mkdir(parents=True, exist_ok=True) + for subdir in _PROFILE_DIRS: + (profile_dir / subdir).mkdir(parents=True, exist_ok=True) + + # Clone config files from source + if source_dir is not None: + for filename in _CLONE_CONFIG_FILES: + src = source_dir / filename + if src.exists(): + shutil.copy2(src, profile_dir / filename) + + return profile_dir + + +def seed_profile_skills(profile_dir: Path, quiet: bool = False) -> Optional[dict]: + """Seed bundled skills into a profile via subprocess. + + Uses subprocess because sync_skills() caches HERMES_HOME at module level. + Returns the sync result dict, or None on failure. + """ + project_root = Path(__file__).parent.parent.resolve() + try: + result = subprocess.run( + [sys.executable, "-c", + "import json; from tools.skills_sync import sync_skills; " + "r = sync_skills(quiet=True); print(json.dumps(r))"], + env={**os.environ, "HERMES_HOME": str(profile_dir)}, + cwd=str(project_root), + capture_output=True, text=True, timeout=60, + ) + if result.returncode == 0 and result.stdout.strip(): + return json.loads(result.stdout.strip()) + if not quiet: + print(f"⚠ Skill seeding returned exit code {result.returncode}") + if result.stderr.strip(): + print(f" {result.stderr.strip()[:200]}") + return None + except subprocess.TimeoutExpired: + if not quiet: + print("⚠ Skill seeding timed out (60s)") + return None + except Exception as e: + if not quiet: + print(f"⚠ Skill seeding failed: {e}") + return None + + +def delete_profile(name: str, yes: bool = False) -> Path: + """Delete a profile, its wrapper script, and its gateway service. + + Stops the gateway if running. Disables systemd/launchd service first + to prevent auto-restart. + + Returns the path that was removed. + """ + validate_profile_name(name) + + if name == "default": + raise ValueError( + "Cannot delete the default profile (~/.hermes).\n" + "To remove everything, use: hermes uninstall" + ) + + profile_dir = get_profile_dir(name) + if not profile_dir.is_dir(): + raise FileNotFoundError(f"Profile '{name}' does not exist.") + + # Show what will be deleted + model, provider = _read_config_model(profile_dir) + gw_running = _check_gateway_running(profile_dir) + skill_count = _count_skills(profile_dir) + + print(f"\nProfile: {name}") + print(f"Path: {profile_dir}") + if model: + print(f"Model: {model}" + (f" ({provider})" if provider else "")) + if skill_count: + print(f"Skills: {skill_count}") + + items = [ + "All config, API keys, memories, sessions, skills, cron jobs", + ] + + # Check for service + from hermes_cli.gateway import _profile_suffix, get_service_name + wrapper_path = _get_wrapper_dir() / name + has_wrapper = wrapper_path.exists() + if has_wrapper: + items.append(f"Command alias ({wrapper_path})") + + print(f"\nThis will permanently delete:") + for item in items: + print(f" • {item}") + if gw_running: + print(f" ⚠ Gateway is running — it will be stopped.") + + # Confirmation + if not yes: + print() + try: + confirm = input(f"Type '{name}' to confirm: ").strip() + except (KeyboardInterrupt, EOFError): + print("\nCancelled.") + return profile_dir + if confirm != name: + print("Cancelled.") + return profile_dir + + # 1. Disable service (prevents auto-restart) + _cleanup_gateway_service(name, profile_dir) + + # 2. Stop running gateway + if gw_running: + _stop_gateway_process(profile_dir) + + # 3. Remove wrapper script + if has_wrapper: + if remove_wrapper_script(name): + print(f"✓ Removed {wrapper_path}") + + # 4. Remove profile directory + try: + shutil.rmtree(profile_dir) + print(f"✓ Removed {profile_dir}") + except Exception as e: + print(f"⚠ Could not remove {profile_dir}: {e}") + + # 5. Clear active_profile if it pointed to this profile + try: + active = get_active_profile() + if active == name: + set_active_profile("default") + print("✓ Active profile reset to default") + except Exception: + pass + + print(f"\nProfile '{name}' deleted.") + return profile_dir + + +def _cleanup_gateway_service(name: str, profile_dir: Path) -> None: + """Disable and remove systemd/launchd service for a profile.""" + import platform as _platform + + # Derive service name for this profile + # Temporarily set HERMES_HOME so _profile_suffix resolves correctly + old_home = os.environ.get("HERMES_HOME") + try: + os.environ["HERMES_HOME"] = str(profile_dir) + from hermes_cli.gateway import get_service_name, get_launchd_plist_path + + if _platform.system() == "Linux": + svc_name = get_service_name() + svc_file = Path.home() / ".config" / "systemd" / "user" / f"{svc_name}.service" + if svc_file.exists(): + subprocess.run( + ["systemctl", "--user", "disable", svc_name], + capture_output=True, check=False, timeout=10, + ) + subprocess.run( + ["systemctl", "--user", "stop", svc_name], + capture_output=True, check=False, timeout=10, + ) + svc_file.unlink(missing_ok=True) + subprocess.run( + ["systemctl", "--user", "daemon-reload"], + capture_output=True, check=False, timeout=10, + ) + print(f"✓ Service {svc_name} removed") + + elif _platform.system() == "Darwin": + plist_path = get_launchd_plist_path() + if plist_path.exists(): + subprocess.run( + ["launchctl", "unload", str(plist_path)], + capture_output=True, check=False, timeout=10, + ) + plist_path.unlink(missing_ok=True) + print(f"✓ Launchd service removed") + except Exception as e: + print(f"⚠ Service cleanup: {e}") + finally: + if old_home is not None: + os.environ["HERMES_HOME"] = old_home + elif "HERMES_HOME" in os.environ: + del os.environ["HERMES_HOME"] + + +def _stop_gateway_process(profile_dir: Path) -> None: + """Stop a running gateway process via its PID file.""" + import signal as _signal + import time as _time + + pid_file = profile_dir / "gateway.pid" + if not pid_file.exists(): + return + + try: + raw = pid_file.read_text().strip() + data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)} + pid = int(data["pid"]) + os.kill(pid, _signal.SIGTERM) + # Wait up to 10s for graceful shutdown + for _ in range(20): + _time.sleep(0.5) + try: + os.kill(pid, 0) + except ProcessLookupError: + print(f"✓ Gateway stopped (PID {pid})") + return + # Force kill + try: + os.kill(pid, _signal.SIGKILL) + except ProcessLookupError: + pass + print(f"✓ Gateway force-stopped (PID {pid})") + except (ProcessLookupError, PermissionError): + print("✓ Gateway already stopped") + except Exception as e: + print(f"⚠ Could not stop gateway: {e}") + + +# --------------------------------------------------------------------------- +# Active profile (sticky default) +# --------------------------------------------------------------------------- + +def get_active_profile() -> str: + """Read the sticky active profile name. + + Returns ``"default"`` if no active_profile file exists or it's empty. + """ + path = _get_active_profile_path() + try: + name = path.read_text().strip() + if not name: + return "default" + return name + except (FileNotFoundError, UnicodeDecodeError, OSError): + return "default" + + +def set_active_profile(name: str) -> None: + """Set the sticky active profile. + + Writes to ``~/.hermes/active_profile``. Use ``"default"`` to clear. + """ + validate_profile_name(name) + if name != "default" and not profile_exists(name): + raise FileNotFoundError( + f"Profile '{name}' does not exist. " + f"Create it with: hermes profile create {name}" + ) + + path = _get_active_profile_path() + path.parent.mkdir(parents=True, exist_ok=True) + if name == "default": + # Remove the file to indicate default + path.unlink(missing_ok=True) + else: + # Atomic write + tmp = path.with_suffix(".tmp") + tmp.write_text(name + "\n") + tmp.replace(path) + + +def get_active_profile_name() -> str: + """Infer the current profile name from HERMES_HOME. + + Returns ``"default"`` if HERMES_HOME is not set or points to ``~/.hermes``. + Returns the profile name if HERMES_HOME points into ``~/.hermes/profiles/``. + Returns ``"custom"`` if HERMES_HOME is set to an unrecognized path. + """ + from hermes_constants import get_hermes_home + hermes_home = get_hermes_home() + resolved = hermes_home.resolve() + + default_resolved = _get_default_hermes_home().resolve() + if resolved == default_resolved: + return "default" + + profiles_root = _get_profiles_root().resolve() + try: + rel = resolved.relative_to(profiles_root) + parts = rel.parts + if len(parts) == 1 and _PROFILE_ID_RE.match(parts[0]): + return parts[0] + except ValueError: + pass + + return "custom" + + +# --------------------------------------------------------------------------- +# Export / Import +# --------------------------------------------------------------------------- + +def export_profile(name: str, output_path: str) -> Path: + """Export a profile to a tar.gz archive. + + Returns the output file path. + """ + validate_profile_name(name) + profile_dir = get_profile_dir(name) + if not profile_dir.is_dir(): + raise FileNotFoundError(f"Profile '{name}' does not exist.") + + output = Path(output_path) + # shutil.make_archive wants the base name without extension + base = str(output).removesuffix(".tar.gz").removesuffix(".tgz") + result = shutil.make_archive(base, "gztar", str(profile_dir.parent), name) + return Path(result) + + +def import_profile(archive_path: str, name: Optional[str] = None) -> Path: + """Import a profile from a tar.gz archive. + + If *name* is not given, infers it from the archive's top-level directory. + Returns the imported profile directory. + """ + import tarfile + + archive = Path(archive_path) + if not archive.exists(): + raise FileNotFoundError(f"Archive not found: {archive}") + + # Peek at the archive to find the top-level directory name + with tarfile.open(archive, "r:gz") as tf: + top_dirs = {m.name.split("/")[0] for m in tf.getmembers() if "/" in m.name} + if not top_dirs: + top_dirs = {m.name for m in tf.getmembers() if m.isdir()} + + inferred_name = name or (top_dirs.pop() if len(top_dirs) == 1 else None) + if not inferred_name: + raise ValueError( + "Cannot determine profile name from archive. " + "Specify it explicitly: hermes profile import --name " + ) + + validate_profile_name(inferred_name) + profile_dir = get_profile_dir(inferred_name) + if profile_dir.exists(): + raise FileExistsError(f"Profile '{inferred_name}' already exists at {profile_dir}") + + profiles_root = _get_profiles_root() + profiles_root.mkdir(parents=True, exist_ok=True) + + shutil.unpack_archive(str(archive), str(profiles_root)) + + # If the archive extracted under a different name, rename + extracted = profiles_root / (top_dirs.pop() if top_dirs else inferred_name) + if extracted != profile_dir and extracted.exists(): + extracted.rename(profile_dir) + + return profile_dir + + +# --------------------------------------------------------------------------- +# Rename +# --------------------------------------------------------------------------- + +def rename_profile(old_name: str, new_name: str) -> Path: + """Rename a profile: directory, wrapper script, service, active_profile. + + Returns the new profile directory. + """ + validate_profile_name(old_name) + validate_profile_name(new_name) + + if old_name == "default": + raise ValueError("Cannot rename the default profile.") + if new_name == "default": + raise ValueError("Cannot rename to 'default' — it is reserved.") + + old_dir = get_profile_dir(old_name) + new_dir = get_profile_dir(new_name) + + if not old_dir.is_dir(): + raise FileNotFoundError(f"Profile '{old_name}' does not exist.") + if new_dir.exists(): + raise FileExistsError(f"Profile '{new_name}' already exists.") + + # 1. Stop gateway if running + if _check_gateway_running(old_dir): + _cleanup_gateway_service(old_name, old_dir) + _stop_gateway_process(old_dir) + + # 2. Rename directory + old_dir.rename(new_dir) + print(f"✓ Renamed {old_dir.name} → {new_dir.name}") + + # 3. Update wrapper script + remove_wrapper_script(old_name) + collision = check_alias_collision(new_name) + if not collision: + create_wrapper_script(new_name) + print(f"✓ Alias updated: {new_name}") + else: + print(f"⚠ Cannot create alias '{new_name}' — {collision}") + + # 4. Update active_profile if it pointed to old name + try: + if get_active_profile() == old_name: + set_active_profile(new_name) + print(f"✓ Active profile updated: {new_name}") + except Exception: + pass + + return new_dir + + +# --------------------------------------------------------------------------- +# Tab completion +# --------------------------------------------------------------------------- + +def generate_bash_completion() -> str: + """Generate a bash completion script for hermes profile names.""" + return '''# Hermes Agent profile completion +# Add to ~/.bashrc: eval "$(hermes completion bash)" + +_hermes_profiles() { + local profiles_dir="$HOME/.hermes/profiles" + local profiles="default" + if [ -d "$profiles_dir" ]; then + profiles="$profiles $(ls "$profiles_dir" 2>/dev/null)" + fi + echo "$profiles" +} + +_hermes_completion() { + local cur prev + cur="${COMP_WORDS[COMP_CWORD]}" + prev="${COMP_WORDS[COMP_CWORD-1]}" + + # Complete profile names after -p / --profile + if [[ "$prev" == "-p" || "$prev" == "--profile" ]]; then + COMPREPLY=($(compgen -W "$(_hermes_profiles)" -- "$cur")) + return + fi + + # Complete profile subcommands + if [[ "${COMP_WORDS[1]}" == "profile" ]]; then + case "$prev" in + profile) + COMPREPLY=($(compgen -W "list use create delete show alias rename export import" -- "$cur")) + return + ;; + use|delete|show|alias|rename|export) + COMPREPLY=($(compgen -W "$(_hermes_profiles)" -- "$cur")) + return + ;; + esac + fi + + # Top-level subcommands + if [[ "$COMP_CWORD" == 1 ]]; then + local commands="chat model gateway setup status cron doctor config skills tools mcp sessions profile update version" + COMPREPLY=($(compgen -W "$commands" -- "$cur")) + fi +} + +complete -F _hermes_completion hermes +''' + + +def generate_zsh_completion() -> str: + """Generate a zsh completion script for hermes profile names.""" + return '''#compdef hermes +# Hermes Agent profile completion +# Add to ~/.zshrc: eval "$(hermes completion zsh)" + +_hermes() { + local -a profiles + profiles=(default) + if [[ -d "$HOME/.hermes/profiles" ]]; then + profiles+=("${(@f)$(ls $HOME/.hermes/profiles 2>/dev/null)}") + fi + + _arguments \\ + '-p[Profile name]:profile:($profiles)' \\ + '--profile[Profile name]:profile:($profiles)' \\ + '1:command:(chat model gateway setup status cron doctor config skills tools mcp sessions profile update version)' \\ + '*::arg:->args' + + case $words[1] in + profile) + _arguments '1:action:(list use create delete show alias rename export import)' \\ + '2:profile:($profiles)' + ;; + esac +} + +_hermes "$@" +''' + + +# --------------------------------------------------------------------------- +# Profile env resolution (called from _apply_profile_override) +# --------------------------------------------------------------------------- + +def resolve_profile_env(profile_name: str) -> str: + """Resolve a profile name to a HERMES_HOME path string. + + Called early in the CLI entry point, before any hermes modules + are imported, to set the HERMES_HOME environment variable. + """ + validate_profile_name(profile_name) + profile_dir = get_profile_dir(profile_name) + + if profile_name != "default" and not profile_dir.is_dir(): + raise FileNotFoundError( + f"Profile '{profile_name}' does not exist. " + f"Create it with: hermes profile create {profile_name}" + ) + + return str(profile_dir) diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index 760775c4c..0c82805d5 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -63,8 +63,11 @@ def _get_model_config() -> Dict[str, Any]: model_cfg = config.get("model") if isinstance(model_cfg, dict): cfg = dict(model_cfg) - default = cfg.get("default", "").strip() - base_url = cfg.get("base_url", "").strip() + # Accept "model" as alias for "default" (users intuitively write model.model) + if not cfg.get("default") and cfg.get("model"): + cfg["default"] = cfg["model"] + default = (cfg.get("default") or "").strip() + base_url = (cfg.get("base_url") or "").strip() is_local = "localhost" in base_url or "127.0.0.1" in base_url is_fallback = not default or default == "anthropic/claude-opus-4.6" if is_local and is_fallback and base_url: @@ -203,7 +206,7 @@ def _resolve_named_custom_runtime( or _detect_api_mode_for_url(base_url) or "chat_completions", "base_url": base_url, - "api_key": api_key, + "api_key": api_key or "no-key-required", "source": f"custom_provider:{custom_provider.get('name', requested_provider)}", } @@ -407,12 +410,6 @@ def resolve_runtime_provider( # (e.g. https://api.minimax.io/anthropic, https://dashscope.../anthropic) elif base_url.rstrip("/").endswith("/anthropic"): api_mode = "anthropic_messages" - # MiniMax providers always use Anthropic Messages API. - # Auto-correct stale /v1 URLs (from old .env or config) to /anthropic. - elif provider in ("minimax", "minimax-cn"): - api_mode = "anthropic_messages" - if base_url.rstrip("/").endswith("/v1"): - base_url = base_url.rstrip("/")[:-3] + "/anthropic" return { "provider": provider, "api_mode": api_mode, diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 1abf37610..fc34c0b7b 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -98,6 +98,11 @@ _DEFAULT_PROVIDER_MODELS = { "minimax-cn": ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"], "ai-gateway": ["anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.6", "openai/gpt-5", "google/gemini-3-flash"], "kilocode": ["anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.6", "openai/gpt-5.4", "google/gemini-3-pro-preview", "google/gemini-3-flash-preview"], + "huggingface": [ + "Qwen/Qwen3.5-397B-A17B", "Qwen/Qwen3-235B-A22B-Thinking-2507", + "Qwen/Qwen3-Coder-480B-A35B-Instruct", "deepseek-ai/DeepSeek-R1-0528", + "deepseek-ai/DeepSeek-V3.2", "moonshotai/Kimi-K2.5", + ], } @@ -302,6 +307,7 @@ from hermes_cli.config import ( get_env_value, ensure_hermes_home, ) +# display_hermes_home imported lazily at call sites (stale-module safety during hermes update) from hermes_cli.colors import Colors, color @@ -599,7 +605,7 @@ def _print_setup_summary(config: dict, hermes_home): else: tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY")) - # Web tools (Parallel, Firecrawl, or Tavily) + # Web tools (Exa, Parallel, Firecrawl, or Tavily) if subscription_features.web.managed_by_nous: tool_status.append(("Web Search & Extract (Nous subscription)", True, None)) elif subscription_features.web.available: @@ -608,7 +614,7 @@ def _print_setup_summary(config: dict, hermes_home): label = f"Web Search & Extract ({subscription_features.web.current_provider})" tool_status.append((label, True, None)) else: - tool_status.append(("Web Search & Extract", False, "PARALLEL_API_KEY, FIRECRAWL_API_KEY, or TAVILY_API_KEY")) + tool_status.append(("Web Search & Extract", False, "EXA_API_KEY, PARALLEL_API_KEY, FIRECRAWL_API_KEY/FIRECRAWL_API_URL, or TAVILY_API_KEY")) # Browser tools (local Chromium or Browserbase cloud) import shutil @@ -720,7 +726,8 @@ def _print_setup_summary(config: dict, hermes_home): print_warning( "Some tools are disabled. Run 'hermes setup tools' to configure them," ) - print_warning("or edit ~/.hermes/.env directly to add the missing API keys.") + from hermes_constants import display_hermes_home as _dhh + print_warning(f"or edit {_dhh()}/.env directly to add the missing API keys.") print() # Done banner @@ -743,7 +750,8 @@ def _print_setup_summary(config: dict, hermes_home): print() # Show file locations prominently - print(color("📁 All your files are in ~/.hermes/:", Colors.CYAN, Colors.BOLD)) + from hermes_constants import display_hermes_home as _dhh + print(color(f"📁 All your files are in {_dhh()}/:", Colors.CYAN, Colors.BOLD)) print() print(f" {color('Settings:', Colors.YELLOW)} {get_config_path()}") print(f" {color('API Keys:', Colors.YELLOW)} {get_env_path()}") @@ -926,6 +934,7 @@ def setup_model_provider(config: dict): "OpenCode Go (open models, $10/month subscription)", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)", "GitHub Copilot ACP (spawns `copilot --acp --stdio`)", + "Hugging Face Inference Providers (20+ open models)", ] if keep_label: provider_choices.append(keep_label) @@ -1574,7 +1583,26 @@ def setup_model_provider(config: dict): _set_model_provider(config, "copilot-acp", pconfig.inference_base_url) selected_base_url = pconfig.inference_base_url - # else: provider_idx == 16 (Keep current) — only shown when a provider already exists + elif provider_idx == 16: # Hugging Face Inference Providers + selected_provider = "huggingface" + print() + print_header("Hugging Face API Token") + pconfig = PROVIDER_REGISTRY["huggingface"] + print_info(f"Provider: {pconfig.name}") + print_info("Get your token at: https://huggingface.co/settings/tokens") + print_info("Required permission: 'Make calls to Inference Providers'") + print() + + api_key = prompt(" HF Token", password=True) + if api_key: + save_env_value("HF_TOKEN", api_key) + # Clear OpenRouter env vars to prevent routing confusion + save_env_value("OPENAI_BASE_URL", "") + save_env_value("OPENAI_API_KEY", "") + _set_model_provider(config, "huggingface", pconfig.inference_base_url) + selected_base_url = pconfig.inference_base_url + + # else: provider_idx == 17 (Keep current) — only shown when a provider already exists # Normalize "keep current" to an explicit provider so downstream logic # doesn't fall back to the generic OpenRouter/static-model path. if selected_provider is None: @@ -2178,11 +2206,11 @@ def setup_terminal_backend(config: dict): config["terminal"]["modal_mode"] = "direct" print_info("Requires a Modal account: https://modal.com") - # Check if swe-rex[modal] is installed + # Check if modal SDK is installed try: - __import__("swe_rex") + __import__("modal") except ImportError: - print_info("Installing swe-rex[modal]...") + print_info("Installing modal SDK...") import subprocess uv_bin = shutil.which("uv") @@ -2194,23 +2222,21 @@ def setup_terminal_backend(config: dict): "install", "--python", sys.executable, - "swe-rex[modal]", + "modal", ], capture_output=True, text=True, ) else: result = subprocess.run( - [sys.executable, "-m", "pip", "install", "swe-rex[modal]"], + [sys.executable, "-m", "pip", "install", "modal"], capture_output=True, text=True, ) if result.returncode == 0: - print_success("swe-rex[modal] installed") + print_success("modal SDK installed") else: - print_warning( - "Install failed — run manually: pip install 'swe-rex[modal]'" - ) + print_warning("Install failed — run manually: pip install modal") # Modal token print() @@ -2925,7 +2951,8 @@ def setup_gateway(config: dict): save_env_value("WEBHOOK_ENABLED", "true") print() print_success("Webhooks enabled! Next steps:") - print_info(" 1. Define webhook routes in ~/.hermes/config.yaml") + from hermes_constants import display_hermes_home as _dhh + print_info(f" 1. Define webhook routes in {_dhh()}/config.yaml") print_info(" 2. Point your service (GitHub, GitLab, etc.) at:") print_info(" http://your-server:8644/webhooks/") print() @@ -3081,6 +3108,95 @@ def setup_tools(config: dict, first_install: bool = False): tools_command(first_install=first_install, config=config) +# ============================================================================= +# Post-Migration Section Skip Logic +# ============================================================================= + + +def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]: + """Return a short summary if a setup section is already configured, else None. + + Used after OpenClaw migration to detect which sections can be skipped. + ``get_env_value`` is the module-level import from hermes_cli.config + so that test patches on ``setup_mod.get_env_value`` take effect. + """ + if section_key == "model": + has_key = bool( + get_env_value("OPENROUTER_API_KEY") + or get_env_value("OPENAI_API_KEY") + or get_env_value("ANTHROPIC_API_KEY") + ) + if not has_key: + # Check for OAuth providers + try: + from hermes_cli.auth import get_active_provider + if get_active_provider(): + has_key = True + except Exception: + pass + if not has_key: + return None + model = config.get("model") + if isinstance(model, str) and model.strip(): + return model.strip() + if isinstance(model, dict): + return str(model.get("default") or model.get("model") or "configured") + return "configured" + + elif section_key == "terminal": + backend = config.get("terminal", {}).get("backend", "local") + return f"backend: {backend}" + + elif section_key == "agent": + max_turns = config.get("agent", {}).get("max_turns", 90) + return f"max turns: {max_turns}" + + elif section_key == "gateway": + platforms = [] + if get_env_value("TELEGRAM_BOT_TOKEN"): + platforms.append("Telegram") + if get_env_value("DISCORD_BOT_TOKEN"): + platforms.append("Discord") + if get_env_value("SLACK_BOT_TOKEN"): + platforms.append("Slack") + if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): + platforms.append("WhatsApp") + if get_env_value("SIGNAL_ACCOUNT"): + platforms.append("Signal") + if platforms: + return ", ".join(platforms) + return None # No platforms configured — section must run + + elif section_key == "tools": + tools = [] + if get_env_value("ELEVENLABS_API_KEY"): + tools.append("TTS/ElevenLabs") + if get_env_value("BROWSERBASE_API_KEY"): + tools.append("Browser") + if get_env_value("FIRECRAWL_API_KEY"): + tools.append("Firecrawl") + if tools: + return ", ".join(tools) + return None + + return None + + +def _skip_configured_section( + config: dict, section_key: str, label: str +) -> bool: + """Show an already-configured section summary and offer to skip. + + Returns True if the user chose to skip, False if the section should run. + """ + summary = _get_section_config_summary(config, section_key) + if not summary: + return False + print() + print_success(f" {label}: {summary}") + return not prompt_yes_no(f" Reconfigure {label.lower()}?", default=False) + + # ============================================================================= # OpenClaw Migration # ============================================================================= @@ -3152,7 +3268,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool: target_root=hermes_home.resolve(), execute=True, workspace_target=None, - overwrite=False, + overwrite=True, migrate_secrets=True, output_dir=None, selected_options=selected, @@ -3319,6 +3435,8 @@ def run_setup_wizard(args): ) ) + migration_ran = False + if is_existing: # ── Returning User Menu ── print() @@ -3387,7 +3505,8 @@ def run_setup_wizard(args): return # Offer OpenClaw migration before configuration begins - if _offer_openclaw_migration(hermes_home): + migration_ran = _offer_openclaw_migration(hermes_home) + if migration_ran: # Reload config in case migration wrote to it config = load_config() @@ -3400,20 +3519,31 @@ def run_setup_wizard(args): print() print_info("You can edit these files directly or use 'hermes config edit'") + if migration_ran: + print() + print_info("Settings were imported from OpenClaw.") + print_info("Each section below will show what was imported — press Enter to keep,") + print_info("or choose to reconfigure if needed.") + # Section 1: Model & Provider - setup_model_provider(config) + if not (migration_ran and _skip_configured_section(config, "model", "Model & Provider")): + setup_model_provider(config) # Section 2: Terminal Backend - setup_terminal_backend(config) + if not (migration_ran and _skip_configured_section(config, "terminal", "Terminal Backend")): + setup_terminal_backend(config) # Section 3: Agent Settings - setup_agent_settings(config) + if not (migration_ran and _skip_configured_section(config, "agent", "Agent Settings")): + setup_agent_settings(config) # Section 4: Messaging Platforms - setup_gateway(config) + if not (migration_ran and _skip_configured_section(config, "gateway", "Messaging Platforms")): + setup_gateway(config) # Section 5: Tools - setup_tools(config, first_install=not is_existing) + if not (migration_ran and _skip_configured_section(config, "tools", "Tools")): + setup_tools(config, first_install=not is_existing) # Save and show summary save_config(config) diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index d1d8d50a3..07ccd0af9 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -24,6 +24,12 @@ PLATFORMS = { "whatsapp": "📱 WhatsApp", "signal": "📡 Signal", "email": "📧 Email", + "homeassistant": "🏠 Home Assistant", + "mattermost": "💬 Mattermost", + "matrix": "💬 Matrix", + "dingtalk": "💬 DingTalk", + "feishu": "🪽 Feishu", + "wecom": "💬 WeCom", } # ─── Config Helpers ─────────────────────────────────────────────────────────── diff --git a/hermes_cli/skills_hub.py b/hermes_cli/skills_hub.py index a36ee78ce..359e8b912 100644 --- a/hermes_cli/skills_hub.py +++ b/hermes_cli/skills_hub.py @@ -21,6 +21,7 @@ from rich.table import Table # Lazy imports to avoid circular dependencies and slow startup. # tools.skills_hub and tools.skills_guard are imported inside functions. +from hermes_constants import display_hermes_home _console = Console() @@ -304,7 +305,8 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all", def do_install(identifier: str, category: str = "", force: bool = False, - console: Optional[Console] = None, skip_confirm: bool = False) -> None: + console: Optional[Console] = None, skip_confirm: bool = False, + invalidate_cache: bool = True) -> None: """Fetch, quarantine, scan, confirm, and install a skill.""" from tools.skills_hub import ( GitHubAuth, create_source_router, ensure_hub_dirs, @@ -387,7 +389,7 @@ def do_install(identifier: str, category: str = "", force: bool = False, "[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n" "It ships with hermes-agent but is not activated by default.\n" "Installing will copy it to your skills directory where the agent can use it.\n\n" - f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]", + f"Files will be at: [cyan]{display_hermes_home()}/skills/{category + '/' if category else ''}{bundle.name}/[/]", title="Official Skill", border_style="bright_cyan", )) @@ -397,7 +399,7 @@ def do_install(identifier: str, category: str = "", force: bool = False, "External skills can contain instructions that influence agent behavior,\n" "shell commands, and scripts. Even after automated scanning, you should\n" "review the installed files before use.\n\n" - f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]", + f"Files will be at: [cyan]{display_hermes_home()}/skills/{category + '/' if category else ''}{bundle.name}/[/]", title="Disclaimer", border_style="yellow", )) @@ -417,6 +419,17 @@ def do_install(identifier: str, category: str = "", force: bool = False, c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}") c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n") + if invalidate_cache: + # Invalidate the skills prompt cache so the new skill appears immediately + try: + from agent.prompt_builder import clear_skills_system_prompt_cache + clear_skills_system_prompt_cache(clear_snapshot=True) + except Exception: + pass + else: + c.print("[dim]Skill will be available in your next session.[/]") + c.print("[dim]Use /reset to start a new session now, or --now to activate immediately (invalidates prompt cache).[/]\n") + def do_inspect(identifier: str, console: Optional[Console] = None) -> None: """Preview a skill's SKILL.md content without installing.""" @@ -603,7 +616,8 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N def do_uninstall(name: str, console: Optional[Console] = None, - skip_confirm: bool = False) -> None: + skip_confirm: bool = False, + invalidate_cache: bool = True) -> None: """Remove a hub-installed skill with confirmation.""" from tools.skills_hub import uninstall_skill @@ -623,6 +637,15 @@ def do_uninstall(name: str, console: Optional[Console] = None, success, msg = uninstall_skill(name) if success: c.print(f"[bold green]{msg}[/]\n") + if invalidate_cache: + try: + from agent.prompt_builder import clear_skills_system_prompt_cache + clear_skills_system_prompt_cache(clear_snapshot=True) + except Exception: + pass + else: + c.print("[dim]Change will take effect in your next session.[/]") + c.print("[dim]Use /reset to start a new session now, or --now to apply immediately (invalidates prompt cache).[/]\n") else: c.print(f"[bold red]Error:[/] {msg}\n") @@ -722,7 +745,7 @@ def do_publish(skill_path: str, target: str = "github", repo: str = "", auth = GitHubAuth() if not auth.is_authenticated(): c.print("[bold red]Error:[/] GitHub authentication required.\n" - "Set GITHUB_TOKEN in ~/.hermes/.env or run 'gh auth login'.\n") + f"Set GITHUB_TOKEN in {display_hermes_home()}/.env or run 'gh auth login'.\n") return c.print(f"[bold]Publishing '{name}' to {repo}...[/]") @@ -865,10 +888,15 @@ def do_snapshot_export(output_path: str, console: Optional[Console] = None) -> N "taps": tap_list, } - out = Path(output_path) - out.write_text(json.dumps(snapshot, indent=2, ensure_ascii=False) + "\n") - c.print(f"[bold green]Snapshot exported:[/] {out}") - c.print(f"[dim]{len(installed)} skill(s), {len(tap_list)} tap(s)[/]\n") + payload = json.dumps(snapshot, indent=2, ensure_ascii=False) + "\n" + if output_path == "-": + import sys + sys.stdout.write(payload) + else: + out = Path(output_path) + out.write_text(payload) + c.print(f"[bold green]Snapshot exported:[/] {out}") + c.print(f"[dim]{len(installed)} skill(s), {len(tap_list)} tap(s)[/]\n") def do_snapshot_import(input_path: str, force: bool = False, @@ -1059,19 +1087,23 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None: elif action == "install": if not args: - c.print("[bold red]Usage:[/] /skills install [--category ] [--force|--yes]\n") + c.print("[bold red]Usage:[/] /skills install [--category ] [--force] [--now]\n") return identifier = args[0] category = "" - # --yes / -y bypasses confirmation prompt (needed in TUI mode) - # --force handles reinstall override - skip_confirm = any(flag in args for flag in ("--yes", "-y")) + # Slash commands run inside prompt_toolkit where input() hangs. + # Always skip confirmation — the user typing the command is implicit consent. + skip_confirm = True force = "--force" in args + # --now invalidates prompt cache immediately (costs more money). + # Default: defer to next session to preserve cache. + invalidate_cache = "--now" in args for i, a in enumerate(args): if a == "--category" and i + 1 < len(args): category = args[i + 1] do_install(identifier, category=category, force=force, - skip_confirm=skip_confirm, console=c) + skip_confirm=skip_confirm, invalidate_cache=invalidate_cache, + console=c) elif action == "inspect": if not args: @@ -1101,10 +1133,13 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None: elif action == "uninstall": if not args: - c.print("[bold red]Usage:[/] /skills uninstall [--yes]\n") + c.print("[bold red]Usage:[/] /skills uninstall [--now]\n") return - skip_confirm = any(flag in args for flag in ("--yes", "-y")) - do_uninstall(args[0], console=c, skip_confirm=skip_confirm) + # Slash commands run inside prompt_toolkit where input() hangs. + skip_confirm = True + invalidate_cache = "--now" in args + do_uninstall(args[0], console=c, skip_confirm=skip_confirm, + invalidate_cache=invalidate_cache) elif action == "publish": if not args: diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 4b68c084b..14d0a9c08 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -281,6 +281,9 @@ def show_status(args): "Slack": ("SLACK_BOT_TOKEN", None), "Email": ("EMAIL_ADDRESS", "EMAIL_HOME_ADDRESS"), "SMS": ("TWILIO_ACCOUNT_SID", "SMS_HOME_CHANNEL"), + "DingTalk": ("DINGTALK_CLIENT_ID", None), + "Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"), + "WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): @@ -319,8 +322,9 @@ def show_status(args): print(" Manager: systemd (user)") elif sys.platform == 'darwin': + from hermes_cli.gateway import get_launchd_label result = subprocess.run( - ["launchctl", "list", "ai.hermes.gateway"], + ["launchctl", "list", get_launchd_label()], capture_output=True, text=True ) diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 4046f40ac..f9e6d28ee 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -9,6 +9,8 @@ Saves per-platform tool configuration to ~/.hermes/config.yaml under the `platform_toolsets` key. """ +import json as _json +import logging import sys from pathlib import Path from typing import Dict, List, Optional, Set @@ -24,6 +26,8 @@ from hermes_cli.nous_subscription import ( ) from tools.tool_backend_helpers import managed_nous_tools_enabled +logger = logging.getLogger(__name__) + PROJECT_ROOT = Path(__file__).parent.parent.resolve() @@ -113,7 +117,8 @@ def _get_effective_configurable_toolsets(): """ result = list(CONFIGURABLE_TOOLSETS) try: - from hermes_cli.plugins import get_plugin_toolsets + from hermes_cli.plugins import discover_plugins, get_plugin_toolsets + discover_plugins() # idempotent — ensures plugins are loaded result.extend(get_plugin_toolsets()) except Exception: pass @@ -123,7 +128,8 @@ def _get_effective_configurable_toolsets(): def _get_plugin_toolset_keys() -> set: """Return the set of toolset keys provided by plugins.""" try: - from hermes_cli.plugins import get_plugin_toolsets + from hermes_cli.plugins import discover_plugins, get_plugin_toolsets + discover_plugins() # idempotent — ensures plugins are loaded return {ts_key for ts_key, _, _ in get_plugin_toolsets()} except Exception: return set() @@ -138,7 +144,12 @@ PLATFORMS = { "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, "homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"}, "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, - "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"}, + "matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"}, + "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"}, + "feishu": {"label": "🪽 Feishu", "default_toolset": "hermes-feishu"}, + "wecom": {"label": "💬 WeCom", "default_toolset": "hermes-wecom"}, + "api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"}, + "mattermost": {"label": "💬 Mattermost", "default_toolset": "hermes-mattermost"}, } @@ -208,6 +219,14 @@ TOOL_CATEGORIES = { {"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"}, ], }, + { + "name": "Exa", + "tag": "AI-native search and contents", + "web_backend": "exa", + "env_vars": [ + {"key": "EXA_API_KEY", "prompt": "Exa API key", "url": "https://exa.ai"}, + ], + }, { "name": "Parallel", "tag": "AI-native search and extract", @@ -354,7 +373,8 @@ def _run_post_setup(post_setup_key: str): if result.returncode == 0: _print_success(" Node.js dependencies installed") else: - _print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install") + from hermes_constants import display_hermes_home + _print_warning(f" npm install failed - run manually: cd {display_hermes_home()}/hermes-agent && npm install") elif not node_modules.exists(): _print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)") @@ -689,9 +709,61 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int: return default +# ─── Token Estimation ──────────────────────────────────────────────────────── + +# Module-level cache so discovery + tokenization runs at most once per process. +_tool_token_cache: Optional[Dict[str, int]] = None + + +def _estimate_tool_tokens() -> Dict[str, int]: + """Return estimated token counts per individual tool name. + + Uses tiktoken (cl100k_base) to count tokens in the JSON-serialised + OpenAI-format tool schema. Triggers tool discovery on first call, + then caches the result for the rest of the process. + + Returns an empty dict when tiktoken or the registry is unavailable. + """ + global _tool_token_cache + if _tool_token_cache is not None: + return _tool_token_cache + + try: + import tiktoken + enc = tiktoken.get_encoding("cl100k_base") + except Exception: + logger.debug("tiktoken unavailable; skipping tool token estimation") + _tool_token_cache = {} + return _tool_token_cache + + try: + # Trigger full tool discovery (imports all tool modules). + import model_tools # noqa: F401 + from tools.registry import registry + except Exception: + logger.debug("Tool registry unavailable; skipping token estimation") + _tool_token_cache = {} + return _tool_token_cache + + counts: Dict[str, int] = {} + for name in registry.get_all_tool_names(): + schema = registry.get_schema(name) + if schema: + # Mirror what gets sent to the API: + # {"type": "function", "function": } + text = _json.dumps({"type": "function", "function": schema}) + counts[name] = len(enc.encode(text)) + _tool_token_cache = counts + return _tool_token_cache + + def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]: """Multi-select checklist of toolsets. Returns set of selected toolset keys.""" from hermes_cli.curses_ui import curses_checklist + from toolsets import resolve_toolset + + # Pre-compute per-tool token counts (cached after first call). + tool_tokens = _estimate_tool_tokens() effective = _get_effective_configurable_toolsets() @@ -707,11 +779,27 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str if ts_key in enabled } + # Build a live status function that shows deduplicated total token cost. + status_fn = None + if tool_tokens: + ts_keys = [ts_key for ts_key, _, _ in effective] + + def status_fn(chosen: set) -> str: + # Collect unique tool names across all selected toolsets + all_tools: set = set() + for idx in chosen: + all_tools.update(resolve_toolset(ts_keys[idx])) + total = sum(tool_tokens.get(name, 0) for name in all_tools) + if total >= 1000: + return f"Est. tool context: ~{total / 1000:.1f}k tokens" + return f"Est. tool context: ~{total} tokens" + chosen = curses_checklist( f"Tools for {platform_label}", labels, pre_selected, cancel_returns=pre_selected, + status_fn=status_fn, ) return {effective[i][0] for i in chosen} @@ -1399,7 +1487,8 @@ def tools_command(args=None, first_install: bool = False, config: dict = None): platform_choices[idx] = f"Configure {pinfo['label']} ({new_count}/{total} enabled)" print() - print(color(" Tool configuration saved to ~/.hermes/config.yaml", Colors.DIM)) + from hermes_constants import display_hermes_home + print(color(f" Tool configuration saved to {display_hermes_home()}/config.yaml", Colors.DIM)) print(color(" Changes take effect on next 'hermes' or gateway restart.", Colors.DIM)) print() diff --git a/hermes_cli/webhook.py b/hermes_cli/webhook.py new file mode 100644 index 000000000..264e7f842 --- /dev/null +++ b/hermes_cli/webhook.py @@ -0,0 +1,260 @@ +"""hermes webhook — manage dynamic webhook subscriptions from the CLI. + +Usage: + hermes webhook subscribe [options] + hermes webhook list + hermes webhook remove + hermes webhook test [--payload '{"key": "value"}'] + +Subscriptions persist to ~/.hermes/webhook_subscriptions.json and are +hot-reloaded by the webhook adapter without a gateway restart. +""" + +import json +import os +import re +import secrets +import time +from pathlib import Path +from typing import Dict, Optional + +from hermes_constants import display_hermes_home + + +_SUBSCRIPTIONS_FILENAME = "webhook_subscriptions.json" + + +def _hermes_home() -> Path: + return Path( + os.getenv("HERMES_HOME", str(Path.home() / ".hermes")) + ).expanduser() + + +def _subscriptions_path() -> Path: + return _hermes_home() / _SUBSCRIPTIONS_FILENAME + + +def _load_subscriptions() -> Dict[str, dict]: + path = _subscriptions_path() + if not path.exists(): + return {} + try: + data = json.loads(path.read_text(encoding="utf-8")) + return data if isinstance(data, dict) else {} + except Exception: + return {} + + +def _save_subscriptions(subs: Dict[str, dict]) -> None: + path = _subscriptions_path() + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(".tmp") + tmp_path.write_text( + json.dumps(subs, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + os.replace(str(tmp_path), str(path)) + + +def _get_webhook_config() -> dict: + """Load webhook platform config. Returns {} if not configured.""" + try: + from hermes_cli.config import load_config + cfg = load_config() + return cfg.get("platforms", {}).get("webhook", {}) + except Exception: + return {} + + +def _is_webhook_enabled() -> bool: + return bool(_get_webhook_config().get("enabled")) + + +def _get_webhook_base_url() -> str: + wh = _get_webhook_config().get("extra", {}) + host = wh.get("host", "0.0.0.0") + port = wh.get("port", 8644) + display_host = "localhost" if host == "0.0.0.0" else host + return f"http://{display_host}:{port}" + + +def _setup_hint() -> str: + _dhh = display_hermes_home() + return f""" + Webhook platform is not enabled. To set it up: + + 1. Run the gateway setup wizard: + hermes gateway setup + + 2. Or manually add to {_dhh}/config.yaml: + platforms: + webhook: + enabled: true + extra: + host: "0.0.0.0" + port: 8644 + secret: "your-global-hmac-secret" + + 3. Or set environment variables in {_dhh}/.env: + WEBHOOK_ENABLED=true + WEBHOOK_PORT=8644 + WEBHOOK_SECRET=your-global-secret + + Then start the gateway: hermes gateway run +""" + + +def _require_webhook_enabled() -> bool: + """Check webhook is enabled. Print setup guide and return False if not.""" + if _is_webhook_enabled(): + return True + print(_setup_hint()) + return False + + +def webhook_command(args): + """Entry point for 'hermes webhook' subcommand.""" + sub = getattr(args, "webhook_action", None) + + if not sub: + print("Usage: hermes webhook {subscribe|list|remove|test}") + print("Run 'hermes webhook --help' for details.") + return + + if not _require_webhook_enabled(): + return + + if sub in ("subscribe", "add"): + _cmd_subscribe(args) + elif sub in ("list", "ls"): + _cmd_list(args) + elif sub in ("remove", "rm"): + _cmd_remove(args) + elif sub == "test": + _cmd_test(args) + + +def _cmd_subscribe(args): + name = args.name.strip().lower().replace(" ", "-") + if not re.match(r'^[a-z0-9][a-z0-9_-]*$', name): + print(f"Error: Invalid name '{name}'. Use lowercase alphanumeric with hyphens/underscores.") + return + + subs = _load_subscriptions() + is_update = name in subs + + secret = args.secret or secrets.token_urlsafe(32) + events = [e.strip() for e in args.events.split(",")] if args.events else [] + + route = { + "description": args.description or f"Agent-created subscription: {name}", + "events": events, + "secret": secret, + "prompt": args.prompt or "", + "skills": [s.strip() for s in args.skills.split(",")] if args.skills else [], + "deliver": args.deliver or "log", + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + + if args.deliver_chat_id: + route["deliver_extra"] = {"chat_id": args.deliver_chat_id} + + subs[name] = route + _save_subscriptions(subs) + + base_url = _get_webhook_base_url() + status = "Updated" if is_update else "Created" + + print(f"\n {status} webhook subscription: {name}") + print(f" URL: {base_url}/webhooks/{name}") + print(f" Secret: {secret}") + if events: + print(f" Events: {', '.join(events)}") + else: + print(" Events: (all)") + print(f" Deliver: {route['deliver']}") + if route.get("prompt"): + prompt_preview = route["prompt"][:80] + ("..." if len(route["prompt"]) > 80 else "") + print(f" Prompt: {prompt_preview}") + print(f"\n Configure your service to POST to the URL above.") + print(f" Use the secret for HMAC-SHA256 signature validation.") + print(f" The gateway must be running to receive events (hermes gateway run).\n") + + +def _cmd_list(args): + subs = _load_subscriptions() + if not subs: + print(" No dynamic webhook subscriptions.") + print(" Create one with: hermes webhook subscribe ") + return + + base_url = _get_webhook_base_url() + print(f"\n {len(subs)} webhook subscription(s):\n") + for name, route in subs.items(): + events = ", ".join(route.get("events", [])) or "(all)" + deliver = route.get("deliver", "log") + desc = route.get("description", "") + print(f" ◆ {name}") + if desc: + print(f" {desc}") + print(f" URL: {base_url}/webhooks/{name}") + print(f" Events: {events}") + print(f" Deliver: {deliver}") + print() + + +def _cmd_remove(args): + name = args.name.strip().lower() + subs = _load_subscriptions() + + if name not in subs: + print(f" No subscription named '{name}'.") + print(" Note: Static routes from config.yaml cannot be removed here.") + return + + del subs[name] + _save_subscriptions(subs) + print(f" Removed webhook subscription: {name}") + + +def _cmd_test(args): + """Send a test POST to a webhook route.""" + name = args.name.strip().lower() + subs = _load_subscriptions() + + if name not in subs: + print(f" No subscription named '{name}'.") + return + + route = subs[name] + secret = route.get("secret", "") + base_url = _get_webhook_base_url() + url = f"{base_url}/webhooks/{name}" + + payload = args.payload or '{"test": true, "event_type": "test", "message": "Hello from hermes webhook test"}' + + import hmac + import hashlib + sig = "sha256=" + hmac.new( + secret.encode(), payload.encode(), hashlib.sha256 + ).hexdigest() + + print(f" Sending test POST to {url}") + try: + import urllib.request + req = urllib.request.Request( + url, + data=payload.encode(), + headers={ + "Content-Type": "application/json", + "X-Hub-Signature-256": sig, + "X-GitHub-Event": "test", + }, + method="POST", + ) + with urllib.request.urlopen(req, timeout=10) as resp: + body = resp.read().decode() + print(f" Response ({resp.status}): {body}") + except Exception as e: + print(f" Error: {e}") + print(" Is the gateway running? (hermes gateway run)") diff --git a/hermes_constants.py b/hermes_constants.py index 518472023..2bfc0a8c7 100644 --- a/hermes_constants.py +++ b/hermes_constants.py @@ -17,6 +17,47 @@ def get_hermes_home() -> Path: return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) +def get_hermes_dir(new_subpath: str, old_name: str) -> Path: + """Resolve a Hermes subdirectory with backward compatibility. + + New installs get the consolidated layout (e.g. ``cache/images``). + Existing installs that already have the old path (e.g. ``image_cache``) + keep using it — no migration required. + + Args: + new_subpath: Preferred path relative to HERMES_HOME (e.g. ``"cache/images"``). + old_name: Legacy path relative to HERMES_HOME (e.g. ``"image_cache"``). + + Returns: + Absolute ``Path`` — old location if it exists on disk, otherwise the new one. + """ + home = get_hermes_home() + old_path = home / old_name + if old_path.exists(): + return old_path + return home / new_subpath + + +def display_hermes_home() -> str: + """Return a user-friendly display string for the current HERMES_HOME. + + Uses ``~/`` shorthand for readability:: + + default: ``~/.hermes`` + profile: ``~/.hermes/profiles/coder`` + custom: ``/opt/hermes-custom`` + + Use this in **user-facing** print/log messages instead of hardcoding + ``~/.hermes``. For code that needs a real ``Path``, use + :func:`get_hermes_home` instead. + """ + home = get_hermes_home() + try: + return "~/" + str(home.relative_to(Path.home())) + except ValueError: + return str(home) + + VALID_REASONING_EFFORTS = ("xhigh", "high", "medium", "low", "minimal") diff --git a/hermes_state.py b/hermes_state.py index d3088fce6..af74ed6ff 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -15,15 +15,20 @@ Key design decisions: """ import json +import logging import os +import random import re import sqlite3 import threading import time from pathlib import Path from hermes_constants import get_hermes_home -from typing import Dict, Any, List, Optional +from typing import Any, Callable, Dict, List, Optional, TypeVar +logger = logging.getLogger(__name__) + +T = TypeVar("T") DEFAULT_DB_PATH = get_hermes_home() / "state.db" @@ -116,18 +121,38 @@ class SessionDB: single writer via WAL mode). Each method opens its own cursor. """ + # ── Write-contention tuning ── + # With multiple hermes processes (gateway + CLI sessions + worktree agents) + # all sharing one state.db, WAL write-lock contention causes visible TUI + # freezes. SQLite's built-in busy handler uses a deterministic sleep + # schedule that causes convoy effects under high concurrency. + # + # Instead, we keep the SQLite timeout short (1s) and handle retries at the + # application level with random jitter, which naturally staggers competing + # writers and avoids the convoy. + _WRITE_MAX_RETRIES = 15 + _WRITE_RETRY_MIN_S = 0.020 # 20ms + _WRITE_RETRY_MAX_S = 0.150 # 150ms + # Attempt a PASSIVE WAL checkpoint every N successful writes. + _CHECKPOINT_EVERY_N_WRITES = 50 + def __init__(self, db_path: Path = None): self.db_path = db_path or DEFAULT_DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) self._lock = threading.Lock() + self._write_count = 0 self._conn = sqlite3.connect( str(self.db_path), check_same_thread=False, - # 30s gives the WAL writer (CLI or gateway) time to finish a batch - # flush before the concurrent reader/writer gives up. 10s was too - # short when the CLI is doing frequent memory flushes. - timeout=30.0, + # Short timeout — application-level retry with random jitter + # handles contention instead of sitting in SQLite's internal + # busy handler for up to 30s. + timeout=1.0, + # Autocommit mode: Python's default isolation_level="" auto-starts + # transactions on DML, which conflicts with our explicit + # BEGIN IMMEDIATE. None = we manage transactions ourselves. + isolation_level=None, ) self._conn.row_factory = sqlite3.Row self._conn.execute("PRAGMA journal_mode=WAL") @@ -135,6 +160,96 @@ class SessionDB: self._init_schema() + # ── Core write helper ── + + def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T: + """Execute a write transaction with BEGIN IMMEDIATE and jitter retry. + + *fn* receives the connection and should perform INSERT/UPDATE/DELETE + statements. The caller must NOT call ``commit()`` — that's handled + here after *fn* returns. + + BEGIN IMMEDIATE acquires the WAL write lock at transaction start + (not at commit time), so lock contention surfaces immediately. + On ``database is locked``, we release the Python lock, sleep a + random 20-150ms, and retry — breaking the convoy pattern that + SQLite's built-in deterministic backoff creates. + + Returns whatever *fn* returns. + """ + last_err: Optional[Exception] = None + for attempt in range(self._WRITE_MAX_RETRIES): + try: + with self._lock: + self._conn.execute("BEGIN IMMEDIATE") + try: + result = fn(self._conn) + self._conn.commit() + except BaseException: + try: + self._conn.rollback() + except Exception: + pass + raise + # Success — periodic best-effort checkpoint. + self._write_count += 1 + if self._write_count % self._CHECKPOINT_EVERY_N_WRITES == 0: + self._try_wal_checkpoint() + return result + except sqlite3.OperationalError as exc: + err_msg = str(exc).lower() + if "locked" in err_msg or "busy" in err_msg: + last_err = exc + if attempt < self._WRITE_MAX_RETRIES - 1: + jitter = random.uniform( + self._WRITE_RETRY_MIN_S, + self._WRITE_RETRY_MAX_S, + ) + time.sleep(jitter) + continue + # Non-lock error or retries exhausted — propagate. + raise + # Retries exhausted (shouldn't normally reach here). + raise last_err or sqlite3.OperationalError( + "database is locked after max retries" + ) + + def _try_wal_checkpoint(self) -> None: + """Best-effort PASSIVE WAL checkpoint. Never blocks, never raises. + + Flushes committed WAL frames back into the main DB file for any + frames that no other connection currently needs. Keeps the WAL + from growing unbounded when many processes hold persistent + connections. + """ + try: + with self._lock: + result = self._conn.execute( + "PRAGMA wal_checkpoint(PASSIVE)" + ).fetchone() + if result and result[1] > 0: + logger.debug( + "WAL checkpoint: %d/%d pages checkpointed", + result[2], result[1], + ) + except Exception: + pass # Best effort — never fatal. + + def close(self): + """Close the database connection. + + Attempts a PASSIVE WAL checkpoint first so that exiting processes + help keep the WAL file from growing unbounded. + """ + with self._lock: + if self._conn: + try: + self._conn.execute("PRAGMA wal_checkpoint(PASSIVE)") + except Exception: + pass + self._conn.close() + self._conn = None + def _init_schema(self): """Create tables and FTS if they don't exist, run migrations.""" cursor = self._conn.cursor() @@ -256,8 +371,8 @@ class SessionDB: parent_session_id: str = None, ) -> str: """Create a new session record. Returns the session_id.""" - with self._lock: - self._conn.execute( + def _do(conn): + conn.execute( """INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config, system_prompt, parent_session_id, started_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", @@ -272,26 +387,35 @@ class SessionDB: time.time(), ), ) - self._conn.commit() + self._execute_write(_do) return session_id def end_session(self, session_id: str, end_reason: str) -> None: """Mark a session as ended.""" - with self._lock: - self._conn.execute( + def _do(conn): + conn.execute( "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", (time.time(), end_reason, session_id), ) - self._conn.commit() + self._execute_write(_do) + + def reopen_session(self, session_id: str) -> None: + """Clear ended_at/end_reason so a session can be resumed.""" + def _do(conn): + conn.execute( + "UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?", + (session_id,), + ) + self._execute_write(_do) def update_system_prompt(self, session_id: str, system_prompt: str) -> None: """Store the full assembled system prompt snapshot.""" - with self._lock: - self._conn.execute( + def _do(conn): + conn.execute( "UPDATE sessions SET system_prompt = ? WHERE id = ?", (system_prompt, session_id), ) - self._conn.commit() + self._execute_write(_do) def update_token_counts( self, @@ -310,11 +434,39 @@ class SessionDB: billing_provider: Optional[str] = None, billing_base_url: Optional[str] = None, billing_mode: Optional[str] = None, + absolute: bool = False, ) -> None: - """Increment token counters and backfill model if not already set.""" - with self._lock: - self._conn.execute( - """UPDATE sessions SET + """Update token counters and backfill model if not already set. + + When *absolute* is False (default), values are **incremented** — use + this for per-API-call deltas (CLI path). + + When *absolute* is True, values are **set directly** — use this when + the caller already holds cumulative totals (gateway path, where the + cached agent accumulates across messages). + """ + if absolute: + sql = """UPDATE sessions SET + input_tokens = ?, + output_tokens = ?, + cache_read_tokens = ?, + cache_write_tokens = ?, + reasoning_tokens = ?, + estimated_cost_usd = COALESCE(?, 0), + actual_cost_usd = CASE + WHEN ? IS NULL THEN actual_cost_usd + ELSE ? + END, + cost_status = COALESCE(?, cost_status), + cost_source = COALESCE(?, cost_source), + pricing_version = COALESCE(?, pricing_version), + billing_provider = COALESCE(billing_provider, ?), + billing_base_url = COALESCE(billing_base_url, ?), + billing_mode = COALESCE(billing_mode, ?), + model = COALESCE(model, ?) + WHERE id = ?""" + else: + sql = """UPDATE sessions SET input_tokens = input_tokens + ?, output_tokens = output_tokens + ?, cache_read_tokens = cache_read_tokens + ?, @@ -332,6 +484,94 @@ class SessionDB: billing_base_url = COALESCE(billing_base_url, ?), billing_mode = COALESCE(billing_mode, ?), model = COALESCE(model, ?) + WHERE id = ?""" + params = ( + input_tokens, + output_tokens, + cache_read_tokens, + cache_write_tokens, + reasoning_tokens, + estimated_cost_usd, + actual_cost_usd, + actual_cost_usd, + cost_status, + cost_source, + pricing_version, + billing_provider, + billing_base_url, + billing_mode, + model, + session_id, + ) + def _do(conn): + conn.execute(sql, params) + self._execute_write(_do) + + def ensure_session( + self, + session_id: str, + source: str = "unknown", + model: str = None, + ) -> None: + """Ensure a session row exists, creating it with minimal metadata if absent. + + Used by _flush_messages_to_session_db to recover from a failed + create_session() call (e.g. transient SQLite lock at agent startup). + INSERT OR IGNORE is safe to call even when the row already exists. + """ + def _do(conn): + conn.execute( + """INSERT OR IGNORE INTO sessions + (id, source, model, started_at) + VALUES (?, ?, ?, ?)""", + (session_id, source, model, time.time()), + ) + self._execute_write(_do) + + def set_token_counts( + self, + session_id: str, + input_tokens: int = 0, + output_tokens: int = 0, + model: str = None, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, + reasoning_tokens: int = 0, + estimated_cost_usd: Optional[float] = None, + actual_cost_usd: Optional[float] = None, + cost_status: Optional[str] = None, + cost_source: Optional[str] = None, + pricing_version: Optional[str] = None, + billing_provider: Optional[str] = None, + billing_base_url: Optional[str] = None, + billing_mode: Optional[str] = None, + ) -> None: + """Set token counters to absolute values (not increment). + + Use this when the caller provides cumulative totals from a completed + conversation run (e.g. the gateway, where the cached agent's + session_prompt_tokens already reflects the running total). + """ + def _do(conn): + conn.execute( + """UPDATE sessions SET + input_tokens = ?, + output_tokens = ?, + cache_read_tokens = ?, + cache_write_tokens = ?, + reasoning_tokens = ?, + estimated_cost_usd = ?, + actual_cost_usd = CASE + WHEN ? IS NULL THEN actual_cost_usd + ELSE ? + END, + cost_status = COALESCE(?, cost_status), + cost_source = COALESCE(?, cost_source), + pricing_version = COALESCE(?, pricing_version), + billing_provider = COALESCE(billing_provider, ?), + billing_base_url = COALESCE(billing_base_url, ?), + billing_mode = COALESCE(billing_mode, ?), + model = COALESCE(model, ?) WHERE id = ?""", ( input_tokens, @@ -352,28 +592,7 @@ class SessionDB: session_id, ), ) - self._conn.commit() - - def ensure_session( - self, - session_id: str, - source: str = "unknown", - model: str = None, - ) -> None: - """Ensure a session row exists, creating it with minimal metadata if absent. - - Used by _flush_messages_to_session_db to recover from a failed - create_session() call (e.g. transient SQLite lock at agent startup). - INSERT OR IGNORE is safe to call even when the row already exists. - """ - with self._lock: - self._conn.execute( - """INSERT OR IGNORE INTO sessions - (id, source, model, started_at) - VALUES (?, ?, ?, ?)""", - (session_id, source, model, time.time()), - ) - self._conn.commit() + self._execute_write(_do) def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """Get a session by ID.""" @@ -467,10 +686,10 @@ class SessionDB: Empty/whitespace-only strings are normalized to None (clearing the title). """ title = self.sanitize_title(title) - with self._lock: + def _do(conn): if title: # Check uniqueness (allow the same session to keep its own title) - cursor = self._conn.execute( + cursor = conn.execute( "SELECT id FROM sessions WHERE title = ? AND id != ?", (title, session_id), ) @@ -479,12 +698,12 @@ class SessionDB: raise ValueError( f"Title '{title}' is already in use by session {conflict['id']}" ) - cursor = self._conn.execute( + cursor = conn.execute( "UPDATE sessions SET title = ? WHERE id = ?", (title, session_id), ) - self._conn.commit() - rowcount = cursor.rowcount + return cursor.rowcount + rowcount = self._execute_write(_do) return rowcount > 0 def get_session_title(self, session_id: str) -> Optional[str]: @@ -656,17 +875,24 @@ class SessionDB: Also increments the session's message_count (and tool_call_count if role is 'tool' or tool_calls is present). """ - with self._lock: - # Serialize structured fields to JSON for storage - reasoning_details_json = ( - json.dumps(reasoning_details) - if reasoning_details else None - ) - codex_items_json = ( - json.dumps(codex_reasoning_items) - if codex_reasoning_items else None - ) - cursor = self._conn.execute( + # Serialize structured fields to JSON before entering the write txn + reasoning_details_json = ( + json.dumps(reasoning_details) + if reasoning_details else None + ) + codex_items_json = ( + json.dumps(codex_reasoning_items) + if codex_reasoning_items else None + ) + tool_calls_json = json.dumps(tool_calls) if tool_calls else None + + # Pre-compute tool call count + num_tool_calls = 0 + if tool_calls is not None: + num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 + + def _do(conn): + cursor = conn.execute( """INSERT INTO messages (session_id, role, content, tool_call_id, tool_calls, tool_name, timestamp, token_count, finish_reason, reasoning, reasoning_details, codex_reasoning_items) @@ -676,7 +902,7 @@ class SessionDB: role, content, tool_call_id, - json.dumps(tool_calls) if tool_calls else None, + tool_calls_json, tool_name, time.time(), token_count, @@ -689,25 +915,20 @@ class SessionDB: msg_id = cursor.lastrowid # Update counters - # Count actual tool calls from the tool_calls list (not from tool responses). - # A single assistant message can contain multiple parallel tool calls. - num_tool_calls = 0 - if tool_calls is not None: - num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 if num_tool_calls > 0: - self._conn.execute( + conn.execute( """UPDATE sessions SET message_count = message_count + 1, tool_call_count = tool_call_count + ? WHERE id = ?""", (num_tool_calls, session_id), ) else: - self._conn.execute( + conn.execute( "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", (session_id,), ) + return msg_id - self._conn.commit() - return msg_id + return self._execute_write(_do) def get_messages(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages for a session, ordered by timestamp.""" @@ -1001,54 +1222,53 @@ class SessionDB: def clear_messages(self, session_id: str) -> None: """Delete all messages for a session and reset its counters.""" - with self._lock: - self._conn.execute( + def _do(conn): + conn.execute( "DELETE FROM messages WHERE session_id = ?", (session_id,) ) - self._conn.execute( + conn.execute( "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", (session_id,), ) - self._conn.commit() + self._execute_write(_do) def delete_session(self, session_id: str) -> bool: """Delete a session and all its messages. Returns True if found.""" - with self._lock: - cursor = self._conn.execute( + def _do(conn): + cursor = conn.execute( "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) ) if cursor.fetchone()[0] == 0: return False - self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) - self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) - self._conn.commit() + conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) + conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) return True + return self._execute_write(_do) def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int: """ Delete sessions older than N days. Returns count of deleted sessions. Only prunes ended sessions (not active ones). """ - import time as _time - cutoff = _time.time() - (older_than_days * 86400) + cutoff = time.time() - (older_than_days * 86400) - with self._lock: + def _do(conn): if source: - cursor = self._conn.execute( + cursor = conn.execute( """SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""", (cutoff, source), ) else: - cursor = self._conn.execute( + cursor = conn.execute( "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL", (cutoff,), ) session_ids = [row["id"] for row in cursor.fetchall()] for sid in session_ids: - self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) - self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) + conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) + conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) + return len(session_ids) - self._conn.commit() - return len(session_ids) + return self._execute_write(_do) diff --git a/honcho_integration/cli.py b/honcho_integration/cli.py index 78a0d4b78..ae09c3713 100644 --- a/honcho_integration/cli.py +++ b/honcho_integration/cli.py @@ -270,7 +270,7 @@ def cmd_status(args) -> None: print(f" {peer}: {mode}") print(f" Write freq: {hcfg.write_frequency}") - if hcfg.enabled and hcfg.api_key: + if hcfg.enabled and (hcfg.api_key or hcfg.base_url): print("\n Connection... ", end="", flush=True) try: get_honcho_client(hcfg) @@ -278,7 +278,7 @@ def cmd_status(args) -> None: except Exception as e: print(f"FAILED ({e})\n") else: - reason = "disabled" if not hcfg.enabled else "no API key" + reason = "disabled" if not hcfg.enabled else "no API key or base URL" print(f"\n Not connected ({reason})\n") diff --git a/honcho_integration/client.py b/honcho_integration/client.py index 385974d12..50f7af30a 100644 --- a/honcho_integration/client.py +++ b/honcho_integration/client.py @@ -417,9 +417,18 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho: else: logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) + # Local Honcho instances don't require an API key, but the SDK + # expects a non-empty string. Use a placeholder for local URLs. + _is_local = resolved_base_url and ( + "localhost" in resolved_base_url + or "127.0.0.1" in resolved_base_url + or "::1" in resolved_base_url + ) + effective_api_key = config.api_key or ("local" if _is_local else None) + kwargs: dict = { "workspace_id": config.workspace_id, - "api_key": config.api_key, + "api_key": effective_api_key, "environment": config.environment, } if resolved_base_url: diff --git a/mcp_serve.py b/mcp_serve.py new file mode 100644 index 000000000..93c439795 --- /dev/null +++ b/mcp_serve.py @@ -0,0 +1,868 @@ +""" +Hermes MCP Server — expose messaging conversations as MCP tools. + +Starts a stdio MCP server that lets any MCP client (Claude Code, Cursor, Codex, +etc.) list conversations, read message history, send messages, poll for live +events, and manage approval requests across all connected platforms. + +Matches OpenClaw's 9-tool MCP channel bridge surface: + conversations_list, conversation_get, messages_read, attachments_fetch, + events_poll, events_wait, messages_send, permissions_list_open, + permissions_respond + +Plus: channels_list (Hermes-specific extra) + +Usage: + hermes mcp serve + hermes mcp serve --verbose + +MCP client config (e.g. claude_desktop_config.json): + { + "mcpServers": { + "hermes": { + "command": "hermes", + "args": ["mcp", "serve"] + } + } + } +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import sys +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger("hermes.mcp_serve") + +# --------------------------------------------------------------------------- +# Lazy MCP SDK import +# --------------------------------------------------------------------------- + +_MCP_SERVER_AVAILABLE = False +try: + from mcp.server.fastmcp import FastMCP + + _MCP_SERVER_AVAILABLE = True +except ImportError: + FastMCP = None # type: ignore[assignment,misc] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _get_sessions_dir() -> Path: + """Return the sessions directory using HERMES_HOME.""" + try: + from hermes_constants import get_hermes_home + return get_hermes_home() / "sessions" + except ImportError: + return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "sessions" + + +def _get_session_db(): + """Get a SessionDB instance for reading message transcripts.""" + try: + from hermes_state import SessionDB + return SessionDB() + except Exception as e: + logger.debug("SessionDB unavailable: %s", e) + return None + + +def _load_sessions_index() -> dict: + """Load the gateway sessions.json index directly. + + Returns a dict of session_key -> entry_dict with platform routing info. + This avoids importing the full SessionStore which needs GatewayConfig. + """ + sessions_file = _get_sessions_dir() / "sessions.json" + if not sessions_file.exists(): + return {} + try: + with open(sessions_file, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.debug("Failed to load sessions.json: %s", e) + return {} + + +def _load_channel_directory() -> dict: + """Load the cached channel directory for available targets.""" + try: + from hermes_constants import get_hermes_home + directory_file = get_hermes_home() / "channel_directory.json" + except ImportError: + directory_file = Path( + os.environ.get("HERMES_HOME", Path.home() / ".hermes") + ) / "channel_directory.json" + + if not directory_file.exists(): + return {} + try: + with open(directory_file, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.debug("Failed to load channel_directory.json: %s", e) + return {} + + +def _extract_message_content(msg: dict) -> str: + """Extract text content from a message, handling multi-part content.""" + content = msg.get("content", "") + if isinstance(content, list): + text_parts = [ + p.get("text", "") for p in content + if isinstance(p, dict) and p.get("type") == "text" + ] + return "\n".join(text_parts) + return str(content) if content else "" + + +def _extract_attachments(msg: dict) -> List[dict]: + """Extract non-text attachments from a message. + + Finds: multi-part image/file content blocks, MEDIA: tags in text, + image URLs, and file references. + """ + attachments = [] + content = msg.get("content", "") + + # Multi-part content blocks (image_url, file, etc.) + if isinstance(content, list): + for part in content: + if not isinstance(part, dict): + continue + ptype = part.get("type", "") + if ptype == "image_url": + url = part.get("image_url", {}).get("url", "") if isinstance(part.get("image_url"), dict) else "" + if url: + attachments.append({"type": "image", "url": url}) + elif ptype == "image": + url = part.get("url", part.get("source", {}).get("url", "")) + if url: + attachments.append({"type": "image", "url": url}) + elif ptype not in ("text",): + # Unknown non-text content type + attachments.append({"type": ptype, "data": part}) + + # MEDIA: tags in text content + text = _extract_message_content(msg) + if text: + media_pattern = re.compile(r'MEDIA:\s*(\S+)') + for match in media_pattern.finditer(text): + path = match.group(1) + attachments.append({"type": "media", "path": path}) + + return attachments + + +# --------------------------------------------------------------------------- +# Event Bridge — polls SessionDB for new messages, maintains event queue +# --------------------------------------------------------------------------- + +QUEUE_LIMIT = 1000 +POLL_INTERVAL = 0.2 # seconds between DB polls (200ms) + + +@dataclass +class QueueEvent: + """An event in the bridge's in-memory queue.""" + cursor: int + type: str # "message", "approval_requested", "approval_resolved" + session_key: str = "" + data: dict = field(default_factory=dict) + + +class EventBridge: + """Background poller that watches SessionDB for new messages and + maintains an in-memory event queue with waiter support. + + This is the Hermes equivalent of OpenClaw's WebSocket gateway bridge. + Instead of WebSocket events, we poll the SQLite database for changes. + """ + + def __init__(self): + self._queue: List[QueueEvent] = [] + self._cursor = 0 + self._lock = threading.Lock() + self._new_event = threading.Event() + self._running = False + self._thread: Optional[threading.Thread] = None + self._last_poll_timestamps: Dict[str, float] = {} # session_key -> unix timestamp + # In-memory approval tracking (populated from events) + self._pending_approvals: Dict[str, dict] = {} + # mtime cache — skip expensive work when files haven't changed + self._sessions_json_mtime: float = 0.0 + self._state_db_mtime: float = 0.0 + self._cached_sessions_index: dict = {} + + def start(self): + """Start the background polling thread.""" + if self._running: + return + self._running = True + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + logger.debug("EventBridge started") + + def stop(self): + """Stop the background polling thread.""" + self._running = False + self._new_event.set() # Wake any waiters + if self._thread: + self._thread.join(timeout=5) + logger.debug("EventBridge stopped") + + def poll_events( + self, + after_cursor: int = 0, + session_key: Optional[str] = None, + limit: int = 20, + ) -> dict: + """Return events since after_cursor, optionally filtered by session_key.""" + with self._lock: + events = [ + e for e in self._queue + if e.cursor > after_cursor + and (not session_key or e.session_key == session_key) + ][:limit] + + next_cursor = events[-1].cursor if events else after_cursor + return { + "events": [ + {"cursor": e.cursor, "type": e.type, + "session_key": e.session_key, **e.data} + for e in events + ], + "next_cursor": next_cursor, + } + + def wait_for_event( + self, + after_cursor: int = 0, + session_key: Optional[str] = None, + timeout_ms: int = 30000, + ) -> Optional[dict]: + """Block until a matching event arrives or timeout expires.""" + deadline = time.monotonic() + (timeout_ms / 1000.0) + + while time.monotonic() < deadline: + with self._lock: + for e in self._queue: + if e.cursor > after_cursor and ( + not session_key or e.session_key == session_key + ): + return { + "cursor": e.cursor, "type": e.type, + "session_key": e.session_key, **e.data, + } + + remaining = deadline - time.monotonic() + if remaining <= 0: + break + self._new_event.clear() + self._new_event.wait(timeout=min(remaining, POLL_INTERVAL)) + + return None + + def list_pending_approvals(self) -> List[dict]: + """List approval requests observed during this bridge session.""" + with self._lock: + return sorted( + self._pending_approvals.values(), + key=lambda a: a.get("created_at", ""), + ) + + def respond_to_approval(self, approval_id: str, decision: str) -> dict: + """Resolve a pending approval (best-effort without gateway IPC).""" + with self._lock: + approval = self._pending_approvals.pop(approval_id, None) + + if not approval: + return {"error": f"Approval not found: {approval_id}"} + + self._enqueue(QueueEvent( + cursor=0, # Will be set by _enqueue + type="approval_resolved", + session_key=approval.get("session_key", ""), + data={"approval_id": approval_id, "decision": decision}, + )) + + return {"resolved": True, "approval_id": approval_id, "decision": decision} + + def _enqueue(self, event: QueueEvent) -> None: + """Add an event to the queue and wake any waiters.""" + with self._lock: + self._cursor += 1 + event.cursor = self._cursor + self._queue.append(event) + # Trim queue to limit + while len(self._queue) > QUEUE_LIMIT: + self._queue.pop(0) + self._new_event.set() + + def _poll_loop(self): + """Background loop: poll SessionDB for new messages.""" + db = _get_session_db() + if not db: + logger.warning("EventBridge: SessionDB unavailable, event polling disabled") + return + + while self._running: + try: + self._poll_once(db) + except Exception as e: + logger.debug("EventBridge poll error: %s", e) + time.sleep(POLL_INTERVAL) + + def _poll_once(self, db): + """Check for new messages across all sessions. + + Uses mtime checks on sessions.json and state.db to skip work + when nothing has changed — makes 200ms polling essentially free. + """ + # Check if sessions.json has changed (mtime check is ~1μs) + sessions_file = _get_sessions_dir() / "sessions.json" + try: + sj_mtime = sessions_file.stat().st_mtime if sessions_file.exists() else 0.0 + except OSError: + sj_mtime = 0.0 + + if sj_mtime != self._sessions_json_mtime: + self._sessions_json_mtime = sj_mtime + self._cached_sessions_index = _load_sessions_index() + + # Check if state.db has changed + try: + from hermes_constants import get_hermes_home + db_file = get_hermes_home() / "state.db" + except ImportError: + db_file = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) / "state.db" + + try: + db_mtime = db_file.stat().st_mtime if db_file.exists() else 0.0 + except OSError: + db_mtime = 0.0 + + if db_mtime == self._state_db_mtime and sj_mtime == self._sessions_json_mtime: + return # Nothing changed since last poll — skip entirely + + self._state_db_mtime = db_mtime + entries = self._cached_sessions_index + + for session_key, entry in entries.items(): + session_id = entry.get("session_id", "") + if not session_id: + continue + + last_seen = self._last_poll_timestamps.get(session_key, 0.0) + + try: + messages = db.get_messages(session_id) + except Exception: + continue + + if not messages: + continue + + # Normalize timestamps to float for comparison + def _ts_float(ts) -> float: + if isinstance(ts, (int, float)): + return float(ts) + if isinstance(ts, str) and ts: + try: + return float(ts) + except ValueError: + # ISO string — parse to epoch + try: + from datetime import datetime + return datetime.fromisoformat(ts).timestamp() + except Exception: + return 0.0 + return 0.0 + + # Find messages newer than our last seen timestamp + new_messages = [] + for msg in messages: + ts = _ts_float(msg.get("timestamp", 0)) + role = msg.get("role", "") + if role not in ("user", "assistant"): + continue + if ts > last_seen: + new_messages.append(msg) + + for msg in new_messages: + content = _extract_message_content(msg) + if not content: + continue + self._enqueue(QueueEvent( + cursor=0, + type="message", + session_key=session_key, + data={ + "role": msg.get("role", ""), + "content": content[:500], + "timestamp": str(msg.get("timestamp", "")), + "message_id": str(msg.get("id", "")), + }, + )) + + # Update last seen to the most recent message timestamp + all_ts = [_ts_float(m.get("timestamp", 0)) for m in messages] + if all_ts: + latest = max(all_ts) + if latest > last_seen: + self._last_poll_timestamps[session_key] = latest + + +# --------------------------------------------------------------------------- +# MCP Server +# --------------------------------------------------------------------------- + +def create_mcp_server(event_bridge: Optional[EventBridge] = None) -> "FastMCP": + """Create and return the Hermes MCP server with all tools registered.""" + if not _MCP_SERVER_AVAILABLE: + raise ImportError( + "MCP server requires the 'mcp' package. " + "Install with: pip install 'hermes-agent[mcp]'" + ) + + mcp = FastMCP( + "hermes", + instructions=( + "Hermes Agent messaging bridge. Use these tools to interact with " + "conversations across Telegram, Discord, Slack, WhatsApp, Signal, " + "Matrix, and other connected platforms." + ), + ) + + bridge = event_bridge or EventBridge() + + # -- conversations_list ------------------------------------------------ + + @mcp.tool() + def conversations_list( + platform: Optional[str] = None, + limit: int = 50, + search: Optional[str] = None, + ) -> str: + """List active messaging conversations across connected platforms. + + Returns conversations with their session keys (needed for messages_read), + platform, chat type, display name, and last activity time. + + Args: + platform: Filter by platform name (telegram, discord, slack, etc.) + limit: Maximum number of conversations to return (default 50) + search: Optional text to filter conversations by name + """ + entries = _load_sessions_index() + conversations = [] + + for key, entry in entries.items(): + origin = entry.get("origin", {}) + entry_platform = entry.get("platform") or origin.get("platform", "") + + if platform and entry_platform.lower() != platform.lower(): + continue + + display_name = entry.get("display_name", "") + chat_name = origin.get("chat_name", "") + if search: + search_lower = search.lower() + if (search_lower not in display_name.lower() + and search_lower not in chat_name.lower() + and search_lower not in key.lower()): + continue + + conversations.append({ + "session_key": key, + "session_id": entry.get("session_id", ""), + "platform": entry_platform, + "chat_type": entry.get("chat_type", origin.get("chat_type", "")), + "display_name": display_name, + "chat_name": chat_name, + "user_name": origin.get("user_name", ""), + "updated_at": entry.get("updated_at", ""), + }) + + conversations.sort(key=lambda c: c.get("updated_at", ""), reverse=True) + conversations = conversations[:limit] + + return json.dumps({ + "count": len(conversations), + "conversations": conversations, + }, indent=2) + + # -- conversation_get -------------------------------------------------- + + @mcp.tool() + def conversation_get(session_key: str) -> str: + """Get detailed info about one conversation by its session key. + + Args: + session_key: The session key from conversations_list + """ + entries = _load_sessions_index() + entry = entries.get(session_key) + + if not entry: + return json.dumps({"error": f"Conversation not found: {session_key}"}) + + origin = entry.get("origin", {}) + return json.dumps({ + "session_key": session_key, + "session_id": entry.get("session_id", ""), + "platform": entry.get("platform") or origin.get("platform", ""), + "chat_type": entry.get("chat_type", origin.get("chat_type", "")), + "display_name": entry.get("display_name", ""), + "user_name": origin.get("user_name", ""), + "chat_name": origin.get("chat_name", ""), + "chat_id": origin.get("chat_id", ""), + "thread_id": origin.get("thread_id"), + "updated_at": entry.get("updated_at", ""), + "created_at": entry.get("created_at", ""), + "input_tokens": entry.get("input_tokens", 0), + "output_tokens": entry.get("output_tokens", 0), + "total_tokens": entry.get("total_tokens", 0), + }, indent=2) + + # -- messages_read ----------------------------------------------------- + + @mcp.tool() + def messages_read( + session_key: str, + limit: int = 50, + ) -> str: + """Read recent messages from a conversation. + + Returns the message history in chronological order with role, content, + and timestamp for each message. + + Args: + session_key: The session key from conversations_list + limit: Maximum number of messages to return (default 50, most recent) + """ + entries = _load_sessions_index() + entry = entries.get(session_key) + if not entry: + return json.dumps({"error": f"Conversation not found: {session_key}"}) + + session_id = entry.get("session_id", "") + if not session_id: + return json.dumps({"error": "No session ID for this conversation"}) + + db = _get_session_db() + if not db: + return json.dumps({"error": "Session database unavailable"}) + + try: + all_messages = db.get_messages(session_id) + except Exception as e: + return json.dumps({"error": f"Failed to read messages: {e}"}) + + filtered = [] + for msg in all_messages: + role = msg.get("role", "") + if role in ("user", "assistant"): + content = _extract_message_content(msg) + if content: + filtered.append({ + "id": str(msg.get("id", "")), + "role": role, + "content": content[:2000], + "timestamp": msg.get("timestamp", ""), + }) + + messages = filtered[-limit:] + + return json.dumps({ + "session_key": session_key, + "count": len(messages), + "total_in_session": len(filtered), + "messages": messages, + }, indent=2) + + # -- attachments_fetch ------------------------------------------------- + + @mcp.tool() + def attachments_fetch( + session_key: str, + message_id: str, + ) -> str: + """List non-text attachments for a message in a conversation. + + Extracts images, media files, and other non-text content blocks + from the specified message. + + Args: + session_key: The session key from conversations_list + message_id: The message ID from messages_read + """ + entries = _load_sessions_index() + entry = entries.get(session_key) + if not entry: + return json.dumps({"error": f"Conversation not found: {session_key}"}) + + session_id = entry.get("session_id", "") + if not session_id: + return json.dumps({"error": "No session ID for this conversation"}) + + db = _get_session_db() + if not db: + return json.dumps({"error": "Session database unavailable"}) + + try: + all_messages = db.get_messages(session_id) + except Exception as e: + return json.dumps({"error": f"Failed to read messages: {e}"}) + + # Find the target message + target_msg = None + for msg in all_messages: + if str(msg.get("id", "")) == message_id: + target_msg = msg + break + + if not target_msg: + return json.dumps({"error": f"Message not found: {message_id}"}) + + attachments = _extract_attachments(target_msg) + + return json.dumps({ + "message_id": message_id, + "count": len(attachments), + "attachments": attachments, + }, indent=2) + + # -- events_poll ------------------------------------------------------- + + @mcp.tool() + def events_poll( + after_cursor: int = 0, + session_key: Optional[str] = None, + limit: int = 20, + ) -> str: + """Poll for new conversation events since a cursor position. + + Returns events that have occurred since the given cursor. Use the + returned next_cursor value for subsequent polls. + + Event types: message, approval_requested, approval_resolved + + Args: + after_cursor: Return events after this cursor (0 for all) + session_key: Optional filter to one conversation + limit: Maximum events to return (default 20) + """ + result = bridge.poll_events( + after_cursor=after_cursor, + session_key=session_key, + limit=limit, + ) + return json.dumps(result, indent=2) + + # -- events_wait ------------------------------------------------------- + + @mcp.tool() + def events_wait( + after_cursor: int = 0, + session_key: Optional[str] = None, + timeout_ms: int = 30000, + ) -> str: + """Wait for the next conversation event (long-poll). + + Blocks until a matching event arrives or the timeout expires. + Use this for near-real-time event delivery without polling. + + Args: + after_cursor: Wait for events after this cursor + session_key: Optional filter to one conversation + timeout_ms: Maximum wait time in milliseconds (default 30000) + """ + event = bridge.wait_for_event( + after_cursor=after_cursor, + session_key=session_key, + timeout_ms=min(timeout_ms, 300000), # Cap at 5 minutes + ) + if event: + return json.dumps({"event": event}, indent=2) + return json.dumps({"event": None, "reason": "timeout"}, indent=2) + + # -- messages_send ----------------------------------------------------- + + @mcp.tool() + def messages_send( + target: str, + message: str, + ) -> str: + """Send a message to a platform conversation. + + The target format is "platform:chat_id" — same format used by the + channels_list tool. You can also use human-friendly channel names + that will be resolved automatically. + + Examples: + target="telegram:6308981865" + target="discord:#general" + target="slack:#engineering" + + Args: + target: Platform target in "platform:identifier" format + message: The message text to send + """ + if not target or not message: + return json.dumps({"error": "Both target and message are required"}) + + try: + from tools.send_message_tool import send_message_tool + result_str = send_message_tool( + {"action": "send", "target": target, "message": message} + ) + return result_str + except ImportError: + return json.dumps({"error": "Send message tool not available"}) + except Exception as e: + return json.dumps({"error": f"Send failed: {e}"}) + + # -- channels_list ----------------------------------------------------- + + @mcp.tool() + def channels_list(platform: Optional[str] = None) -> str: + """List available messaging channels and targets across platforms. + + Returns channels that you can send messages to. The target strings + returned here can be used directly with the messages_send tool. + + Args: + platform: Filter by platform name (telegram, discord, slack, etc.) + """ + directory = _load_channel_directory() + if not directory: + entries = _load_sessions_index() + targets = [] + seen = set() + for key, entry in entries.items(): + origin = entry.get("origin", {}) + p = entry.get("platform") or origin.get("platform", "") + chat_id = origin.get("chat_id", "") + if not p or not chat_id: + continue + if platform and p.lower() != platform.lower(): + continue + target_str = f"{p}:{chat_id}" + if target_str in seen: + continue + seen.add(target_str) + targets.append({ + "target": target_str, + "platform": p, + "name": entry.get("display_name") or origin.get("chat_name", ""), + "chat_type": entry.get("chat_type", origin.get("chat_type", "")), + }) + return json.dumps({"count": len(targets), "channels": targets}, indent=2) + + channels = [] + for plat, entries_list in directory.items(): + if platform and plat.lower() != platform.lower(): + continue + if isinstance(entries_list, list): + for ch in entries_list: + if isinstance(ch, dict): + chat_id = ch.get("id", ch.get("chat_id", "")) + channels.append({ + "target": f"{plat}:{chat_id}" if chat_id else plat, + "platform": plat, + "name": ch.get("name", ch.get("display_name", "")), + "chat_type": ch.get("type", ""), + }) + + return json.dumps({"count": len(channels), "channels": channels}, indent=2) + + # -- permissions_list_open --------------------------------------------- + + @mcp.tool() + def permissions_list_open() -> str: + """List pending approval requests observed during this bridge session. + + Returns exec and plugin approval requests that the bridge has seen + since it started. Approvals are live-session only — older approvals + from before the bridge connected are not included. + """ + approvals = bridge.list_pending_approvals() + return json.dumps({ + "count": len(approvals), + "approvals": approvals, + }, indent=2) + + # -- permissions_respond ----------------------------------------------- + + @mcp.tool() + def permissions_respond( + id: str, + decision: str, + ) -> str: + """Respond to a pending approval request. + + Args: + id: The approval ID from permissions_list_open + decision: One of "allow-once", "allow-always", or "deny" + """ + if decision not in ("allow-once", "allow-always", "deny"): + return json.dumps({ + "error": f"Invalid decision: {decision}. " + f"Must be allow-once, allow-always, or deny" + }) + + result = bridge.respond_to_approval(id, decision) + return json.dumps(result, indent=2) + + return mcp + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def run_mcp_server(verbose: bool = False) -> None: + """Start the Hermes MCP server on stdio.""" + if not _MCP_SERVER_AVAILABLE: + print( + "Error: MCP server requires the 'mcp' package.\n" + "Install with: pip install 'hermes-agent[mcp]'", + file=sys.stderr, + ) + sys.exit(1) + + if verbose: + logging.basicConfig(level=logging.DEBUG, stream=sys.stderr) + else: + logging.basicConfig(level=logging.WARNING, stream=sys.stderr) + + bridge = EventBridge() + bridge.start() + + server = create_mcp_server(event_bridge=bridge) + + import asyncio + + async def _run(): + try: + await server.run_stdio_async() + finally: + bridge.stop() + + try: + asyncio.run(_run()) + except KeyboardInterrupt: + bridge.stop() diff --git a/nix/nixosModules.nix b/nix/nixosModules.nix index 178305a28..0e15c6f53 100644 --- a/nix/nixosModules.nix +++ b/nix/nixosModules.nix @@ -10,6 +10,12 @@ # container recreation. Environment variables are written to $HERMES_HOME/.env # and read by hermes at startup — no container recreation needed for env changes. # +# Tool resolution: the hermes wrapper uses --suffix PATH for nix store tools, +# so apt/uv-installed versions take priority. The container entrypoint provisions +# extensible tools on first boot: nodejs/npm via apt, uv via curl, and a Python +# 3.11 venv (bootstrapped entirely by uv) at ~/.venv with pip seeded. Agents get +# writable tool prefixes for npm i -g, pip install, uv tool install, etc. +# # Usage: # services.hermes-agent = { # enable = true; @@ -105,22 +111,52 @@ fi mkdir -p "$TARGET_HOME" chown "$HERMES_UID:$HERMES_GID" "$TARGET_HOME" + chmod 0750 "$TARGET_HOME" # Ensure HERMES_HOME is owned by the target user if [ -n "''${HERMES_HOME:-}" ] && [ -d "$HERMES_HOME" ]; then chown -R "$HERMES_UID:$HERMES_GID" "$HERMES_HOME" fi - # Install sudo on Debian/Ubuntu if missing (first boot only, cached in writable layer) - if command -v apt-get >/dev/null 2>&1 && ! command -v sudo >/dev/null 2>&1; then - apt-get update -qq >/dev/null 2>&1 && apt-get install -y -qq sudo >/dev/null 2>&1 || true + # ── Provision apt packages (first boot only, cached in writable layer) ── + # sudo: agent self-modification + # nodejs/npm: writable node so npm i -g works (nix store copies are read-only) + # curl: needed for uv installer + if [ ! -f /var/lib/hermes-tools-provisioned ] && command -v apt-get >/dev/null 2>&1; then + echo "First boot: provisioning agent tools..." + apt-get update -qq + apt-get install -y -qq sudo nodejs npm curl + touch /var/lib/hermes-tools-provisioned fi + if command -v sudo >/dev/null 2>&1 && [ ! -f /etc/sudoers.d/hermes ]; then mkdir -p /etc/sudoers.d echo "$TARGET_USER ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/hermes chmod 0440 /etc/sudoers.d/hermes fi + # uv (Python manager) — not in Ubuntu repos, retry-safe outside the sentinel + if ! command -v uv >/dev/null 2>&1 && [ ! -x "$TARGET_HOME/.local/bin/uv" ] && command -v curl >/dev/null 2>&1; then + su -s /bin/sh "$TARGET_USER" -c 'curl -LsSf https://astral.sh/uv/install.sh | sh' || true + fi + + # Python 3.11 venv — gives the agent a writable Python with pip. + # Uses uv to install Python 3.11 (Ubuntu 24.04 ships 3.12). + # --seed includes pip/setuptools so bare `pip install` works. + _UV_BIN="$TARGET_HOME/.local/bin/uv" + if [ ! -d "$TARGET_HOME/.venv" ] && [ -x "$_UV_BIN" ]; then + su -s /bin/sh "$TARGET_USER" -c " + export PATH=\"\$HOME/.local/bin:\$PATH\" + uv python install 3.11 + uv venv --python 3.11 --seed \"\$HOME/.venv\" + " || true + fi + + # Put the agent venv first on PATH so python/pip resolve to writable copies + if [ -d "$TARGET_HOME/.venv/bin" ]; then + export PATH="$TARGET_HOME/.venv/bin:$PATH" + fi + if command -v setpriv >/dev/null 2>&1; then exec setpriv --reuid="$HERMES_UID" --regid="$HERMES_GID" --init-groups "$@" elif command -v su >/dev/null 2>&1; then @@ -516,8 +552,8 @@ # ── Directories ─────────────────────────────────────────────────── { systemd.tmpfiles.rules = [ - "d ${cfg.stateDir} 0755 ${cfg.user} ${cfg.group} - -" - "d ${cfg.stateDir}/.hermes 0755 ${cfg.user} ${cfg.group} - -" + "d ${cfg.stateDir} 0750 ${cfg.user} ${cfg.group} - -" + "d ${cfg.stateDir}/.hermes 0750 ${cfg.user} ${cfg.group} - -" "d ${cfg.stateDir}/home 0750 ${cfg.user} ${cfg.group} - -" "d ${cfg.workingDirectory} 0750 ${cfg.user} ${cfg.group} - -" ]; @@ -531,21 +567,23 @@ mkdir -p ${cfg.stateDir}/home mkdir -p ${cfg.workingDirectory} chown ${cfg.user}:${cfg.group} ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory} + chmod 0750 ${cfg.stateDir} ${cfg.stateDir}/.hermes ${cfg.stateDir}/home ${cfg.workingDirectory} # Merge Nix settings into existing config.yaml. # Preserves user-added keys (skills, streaming, etc.); Nix keys win. # If configFile is user-provided (not generated), overwrite instead of merge. ${if cfg.configFile != null then '' - install -o ${cfg.user} -g ${cfg.group} -m 0644 -D ${configFile} ${cfg.stateDir}/.hermes/config.yaml + install -o ${cfg.user} -g ${cfg.group} -m 0640 -D ${configFile} ${cfg.stateDir}/.hermes/config.yaml '' else '' ${configMergeScript} ${generatedConfigFile} ${cfg.stateDir}/.hermes/config.yaml chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/config.yaml - chmod 0644 ${cfg.stateDir}/.hermes/config.yaml + chmod 0640 ${cfg.stateDir}/.hermes/config.yaml ''} # Managed mode marker (so interactive shells also detect NixOS management) touch ${cfg.stateDir}/.hermes/.managed chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.managed + chmod 0644 ${cfg.stateDir}/.hermes/.managed # Seed auth file if provided ${lib.optionalString (cfg.authFile != null) '' @@ -577,7 +615,7 @@ HERMES_NIX_ENV_EOF # Link documents into workspace ${lib.concatStringsSep "\n" (lib.mapAttrsToList (name: _value: '' - install -o ${cfg.user} -g ${cfg.group} -m 0644 ${documentDerivation}/${name} ${cfg.workingDirectory}/${name} + install -o ${cfg.user} -g ${cfg.group} -m 0640 ${documentDerivation}/${name} ${cfg.workingDirectory}/${name} '') cfg.documents)} ''; } diff --git a/nix/packages.nix b/nix/packages.nix index 8c2b7cbd9..805f76605 100644 --- a/nix/packages.nix +++ b/nix/packages.nix @@ -35,7 +35,7 @@ ${pkgs.lib.concatMapStringsSep "\n" (name: '' makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \ - --prefix PATH : "${runtimePath}" \ + --suffix PATH : "${runtimePath}" \ --set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills '') [ "hermes" "hermes-agent" "hermes-acp" ]} diff --git a/optional-skills/communication/DESCRIPTION.md b/optional-skills/communication/DESCRIPTION.md new file mode 100644 index 000000000..2f44b53b8 --- /dev/null +++ b/optional-skills/communication/DESCRIPTION.md @@ -0,0 +1 @@ +Communication and decision-making frameworks — structured response formats for proposals, trade-off analysis, and stakeholder-ready recommendations. diff --git a/optional-skills/communication/one-three-one-rule/SKILL.md b/optional-skills/communication/one-three-one-rule/SKILL.md new file mode 100644 index 000000000..ca0ccd449 --- /dev/null +++ b/optional-skills/communication/one-three-one-rule/SKILL.md @@ -0,0 +1,103 @@ +--- +name: one-three-one-rule +description: > + Structured decision-making framework for technical proposals and trade-off analysis. + When the user faces a choice between multiple approaches (architecture decisions, + tool selection, refactoring strategies, migration paths), this skill produces a + 1-3-1 format: one clear problem statement, three distinct options with pros/cons, + and one concrete recommendation with definition of done and implementation plan. + Use when the user asks for a "1-3-1", says "give me options", or needs help + choosing between competing approaches. +version: 1.0.0 +author: Willard Moore +license: MIT +category: communication +metadata: + hermes: + tags: [communication, decision-making, proposals, trade-offs] +--- + +# 1-3-1 Communication Rule + +Structured decision-making format for when a task has multiple viable approaches and the user needs a clear recommendation. Produces a concise problem framing, three options with trade-offs, and an actionable plan for the recommended path. + +## When to Use + +- The user explicitly asks for a "1-3-1" response. +- The user says "give me options" or "what are my choices" for a technical decision. +- A task has multiple viable approaches with meaningful trade-offs (architecture, tooling, migration strategy). +- The user needs a proposal they can forward to a team or stakeholder. + +Do NOT use for simple questions with one obvious answer, debugging sessions, or tasks where the user has already decided on an approach. + +## Procedure + +1. **Problem** (one sentence) + - State the core decision or desired outcome in a single concise sentence. + - Focus on the *what*, not the *how* — no implementation details, tool names, or specific technologies. + - Keep it tight. If you need "and", you're describing two problems. + +2. **Options** (exactly three) + - Present three distinct, viable approaches labeled A, B, C. + - Each option gets a brief description, pros, and cons. + - Options should represent genuinely different strategies, not minor variations of the same approach. + +3. **Recommendation** (one option) + - State which option you recommend and why, based on the user's context and priorities. + - Be direct — this is your professional judgment, not a hedge. + +4. **Definition of Done** + - List the specific success criteria for the recommended option. + - These are concrete, verifiable outcomes — not vague aspirations. + - If the user picks a different option, revise this section to match. + +5. **Implementation Plan** + - Concrete steps to execute the recommended option. + - Include specific commands, tools, or actions where applicable. + - If the user picks a different option, revise this section to match. + +## Verification + +- Response has exactly one Problem sentence. +- Response has exactly three Options (A, B, C) with pros and cons for each. +- Response has a single Recommendation that picks one option with reasoning. +- Definition of Done and Implementation Plan align with the recommended option. +- If the user selects a different option, Recommendation, DoD, and Implementation Plan update accordingly. + +## Example + +User: "Give me a 1-3-1 for adding retry logic to our API client." + +Agent: + +**Problem:** API calls to external services fail intermittently under load and need graceful retry handling. + +**Options:** + +- **Option A: Exponential backoff with jitter in a shared utility.** + Pros: Simple, well-understood pattern; single implementation shared across all callers. + Cons: No circuit-breaking; callers still block during retries. + +- **Option B: Retry library (tenacity) as a decorator on each endpoint call.** + Pros: Configurable per-endpoint; built-in support for circuit breakers and rate limits. + Cons: Adds a dependency; decorator pattern can obscure control flow. + +- **Option C: Queue-based async retry with dead letter handling.** + Pros: Non-blocking; failed calls are tracked and replayable; scales well. + Cons: Significant complexity increase; requires a queue backend; overkill for low-volume services. + +**Recommendation:** Option A. The call volume doesn't justify queue infrastructure, and a shared utility with exponential backoff + jitter covers the failure mode cleanly without new dependencies. + +**Definition of Done:** +- `api_retry()` utility handles transient HTTP errors (429, 502, 503, 504). +- Retries up to 3 times with exponential backoff and jitter. +- All external API calls route through the retry utility. +- Logs each retry attempt with endpoint and attempt number. +- Tests cover retry exhaustion, immediate success, and transient-then-success scenarios. + +**Implementation Plan:** +1. Create `utils/api_retry.py` with configurable max retries, base delay, and retryable status codes. +2. Add jitter using `random.uniform(0, base_delay)` to prevent thundering herd. +3. Wrap existing API calls in `api_client.py` with the retry utility. +4. Add unit tests mocking HTTP responses for each retry scenario. +5. Verify under load with a simple stress test against a flaky endpoint mock. diff --git a/optional-skills/productivity/canvas/SKILL.md b/optional-skills/productivity/canvas/SKILL.md new file mode 100644 index 000000000..88299d0ab --- /dev/null +++ b/optional-skills/productivity/canvas/SKILL.md @@ -0,0 +1,97 @@ +--- +name: canvas +description: Canvas LMS integration — fetch enrolled courses and assignments using API token authentication. +version: 1.0.0 +author: community +license: MIT +prerequisites: + env_vars: [CANVAS_API_TOKEN, CANVAS_BASE_URL] +metadata: + hermes: + tags: [Canvas, LMS, Education, Courses, Assignments] +--- + +# Canvas LMS — Course & Assignment Access + +Read-only access to Canvas LMS for listing courses and assignments. + +## Scripts + +- `scripts/canvas_api.py` — Python CLI for Canvas API calls + +## Setup + +1. Log in to your Canvas instance in a browser +2. Go to **Account → Settings** (click your profile icon, then Settings) +3. Scroll to **Approved Integrations** and click **+ New Access Token** +4. Name the token (e.g., "Hermes Agent"), set an optional expiry, and click **Generate Token** +5. Copy the token and add to `~/.hermes/.env`: + +``` +CANVAS_API_TOKEN=your_token_here +CANVAS_BASE_URL=https://yourschool.instructure.com +``` + +The base URL is whatever appears in your browser when you're logged into Canvas (no trailing slash). + +## Usage + +```bash +CANVAS="python $HERMES_HOME/skills/productivity/canvas/scripts/canvas_api.py" + +# List all active courses +$CANVAS list_courses --enrollment-state active + +# List all courses (any state) +$CANVAS list_courses + +# List assignments for a specific course +$CANVAS list_assignments 12345 + +# List assignments ordered by due date +$CANVAS list_assignments 12345 --order-by due_at +``` + +## Output Format + +**list_courses** returns: +```json +[{"id": 12345, "name": "Intro to CS", "course_code": "CS101", "workflow_state": "available", "start_at": "...", "end_at": "..."}] +``` + +**list_assignments** returns: +```json +[{"id": 67890, "name": "Homework 1", "due_at": "2025-02-15T23:59:00Z", "points_possible": 100, "submission_types": ["online_upload"], "html_url": "...", "description": "...", "course_id": 12345}] +``` + +Note: Assignment descriptions are truncated to 500 characters. The `html_url` field links to the full assignment page in Canvas. + +## API Reference (curl) + +```bash +# List courses +curl -s -H "Authorization: Bearer $CANVAS_API_TOKEN" \ + "$CANVAS_BASE_URL/api/v1/courses?enrollment_state=active&per_page=10" + +# List assignments for a course +curl -s -H "Authorization: Bearer $CANVAS_API_TOKEN" \ + "$CANVAS_BASE_URL/api/v1/courses/COURSE_ID/assignments?per_page=10&order_by=due_at" +``` + +Canvas uses `Link` headers for pagination. The Python script handles pagination automatically. + +## Rules + +- This skill is **read-only** — it only fetches data, never modifies courses or assignments +- On first use, verify auth by running `$CANVAS list_courses` — if it fails with 401, guide the user through setup +- Canvas rate-limits to ~700 requests per 10 minutes; check `X-Rate-Limit-Remaining` header if hitting limits + +## Troubleshooting + +| Problem | Fix | +|---------|-----| +| 401 Unauthorized | Token invalid or expired — regenerate in Canvas Settings | +| 403 Forbidden | Token lacks permission for this course | +| Empty course list | Try `--enrollment-state active` or omit the flag to see all states | +| Wrong institution | Verify `CANVAS_BASE_URL` matches the URL in your browser | +| Timeout errors | Check network connectivity to your Canvas instance | diff --git a/optional-skills/productivity/canvas/scripts/canvas_api.py b/optional-skills/productivity/canvas/scripts/canvas_api.py new file mode 100644 index 000000000..13599c575 --- /dev/null +++ b/optional-skills/productivity/canvas/scripts/canvas_api.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +"""Canvas LMS API CLI for Hermes Agent. + +A thin CLI wrapper around the Canvas REST API. +Authenticates using a personal access token from environment variables. + +Usage: + python canvas_api.py list_courses [--per-page N] [--enrollment-state STATE] + python canvas_api.py list_assignments COURSE_ID [--per-page N] [--order-by FIELD] +""" + +import argparse +import json +import os +import sys + +import requests + +CANVAS_API_TOKEN = os.environ.get("CANVAS_API_TOKEN", "") +CANVAS_BASE_URL = os.environ.get("CANVAS_BASE_URL", "").rstrip("/") + + +def _check_config(): + """Validate required environment variables are set.""" + missing = [] + if not CANVAS_API_TOKEN: + missing.append("CANVAS_API_TOKEN") + if not CANVAS_BASE_URL: + missing.append("CANVAS_BASE_URL") + if missing: + print( + f"Missing required environment variables: {', '.join(missing)}\n" + "Set them in ~/.hermes/.env or export them in your shell.\n" + "See the canvas skill SKILL.md for setup instructions.", + file=sys.stderr, + ) + sys.exit(1) + + +def _headers(): + return {"Authorization": f"Bearer {CANVAS_API_TOKEN}"} + + +def _paginated_get(url, params=None, max_items=200): + """Fetch all pages up to max_items, following Canvas Link headers.""" + results = [] + while url and len(results) < max_items: + resp = requests.get(url, headers=_headers(), params=params, timeout=30) + resp.raise_for_status() + results.extend(resp.json()) + params = None # params are included in the Link URL for subsequent pages + url = None + link = resp.headers.get("Link", "") + for part in link.split(","): + if 'rel="next"' in part: + url = part.split(";")[0].strip().strip("<>") + return results[:max_items] + + +# ========================================================================= +# Commands +# ========================================================================= + + +def list_courses(args): + """List enrolled courses.""" + _check_config() + url = f"{CANVAS_BASE_URL}/api/v1/courses" + params = {"per_page": args.per_page} + if args.enrollment_state: + params["enrollment_state"] = args.enrollment_state + try: + courses = _paginated_get(url, params) + except requests.HTTPError as e: + print(f"API error: {e.response.status_code} {e.response.text}", file=sys.stderr) + sys.exit(1) + output = [ + { + "id": c["id"], + "name": c.get("name", ""), + "course_code": c.get("course_code", ""), + "enrollment_term_id": c.get("enrollment_term_id"), + "start_at": c.get("start_at"), + "end_at": c.get("end_at"), + "workflow_state": c.get("workflow_state", ""), + } + for c in courses + ] + print(json.dumps(output, indent=2)) + + +def list_assignments(args): + """List assignments for a course.""" + _check_config() + url = f"{CANVAS_BASE_URL}/api/v1/courses/{args.course_id}/assignments" + params = {"per_page": args.per_page} + if args.order_by: + params["order_by"] = args.order_by + try: + assignments = _paginated_get(url, params) + except requests.HTTPError as e: + print(f"API error: {e.response.status_code} {e.response.text}", file=sys.stderr) + sys.exit(1) + output = [ + { + "id": a["id"], + "name": a.get("name", ""), + "description": (a.get("description") or "")[:500], + "due_at": a.get("due_at"), + "points_possible": a.get("points_possible"), + "submission_types": a.get("submission_types", []), + "html_url": a.get("html_url", ""), + "course_id": a.get("course_id"), + } + for a in assignments + ] + print(json.dumps(output, indent=2)) + + +# ========================================================================= +# CLI parser +# ========================================================================= + + +def main(): + parser = argparse.ArgumentParser( + description="Canvas LMS API CLI for Hermes Agent" + ) + sub = parser.add_subparsers(dest="command", required=True) + + # --- list_courses --- + p = sub.add_parser("list_courses", help="List enrolled courses") + p.add_argument("--per-page", type=int, default=50, help="Results per page (default 50)") + p.add_argument( + "--enrollment-state", + default="", + help="Filter by enrollment state (active, invited_or_pending, completed)", + ) + p.set_defaults(func=list_courses) + + # --- list_assignments --- + p = sub.add_parser("list_assignments", help="List assignments for a course") + p.add_argument("course_id", help="Canvas course ID") + p.add_argument("--per-page", type=int, default=50, help="Results per page (default 50)") + p.add_argument( + "--order-by", + default="", + help="Order by field (due_at, name, position)", + ) + p.set_defaults(func=list_assignments) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/optional-skills/productivity/memento-flashcards/SKILL.md b/optional-skills/productivity/memento-flashcards/SKILL.md new file mode 100644 index 000000000..40eb174d9 --- /dev/null +++ b/optional-skills/productivity/memento-flashcards/SKILL.md @@ -0,0 +1,324 @@ +--- +name: memento-flashcards +description: >- + Spaced-repetition flashcard system. Create cards from facts or text, + chat with flashcards using free-text answers graded by the agent, + generate quizzes from YouTube transcripts, review due cards with + adaptive scheduling, and export/import decks as CSV. +version: 1.0.0 +author: Memento AI +license: MIT +platforms: [macos, linux] +metadata: + hermes: + tags: [Education, Flashcards, Spaced Repetition, Learning, Quiz, YouTube] + requires_toolsets: [terminal] + category: productivity +--- + +# Memento Flashcards — Spaced-Repetition Flashcard Skill + +## Overview + +Memento gives you a local, file-based flashcard system with spaced-repetition scheduling. +Users can chat with their flashcards by answering in free text and having the agent grade the response before scheduling the next review. +Use it whenever the user wants to: + +- **Remember a fact** — turn any statement into a Q/A flashcard +- **Study with spaced repetition** — review due cards with adaptive intervals and agent-graded free-text answers +- **Quiz from a YouTube video** — fetch a transcript and generate a 5-question quiz +- **Manage decks** — organise cards into collections, export/import CSV + +All card data lives in a single JSON file. No external API keys are required — you (the agent) generate flashcard content and quiz questions directly. + +User-facing response style for Memento Flashcards: +- Use plain text only. Do not use Markdown formatting in replies to the user. +- Keep review and quiz feedback brief and neutral. Avoid extra praise, pep, or long explanations. + +## When to Use + +Use this skill when the user wants to: +- Save facts as flashcards for later review +- Review due cards with spaced repetition +- Generate a quiz from a YouTube video transcript +- Import, export, inspect, or delete flashcard data + +Do not use this skill for general Q&A, coding help, or non-memory tasks. + +## Quick Reference + +| User intent | Action | +|---|---| +| "Remember that X" / "save this as a flashcard" | Generate a Q/A card, call `memento_cards.py add` | +| Sends a fact without mentioning flashcards | Ask "Want me to save this as a Memento flashcard?" — only create if confirmed | +| "Create a flashcard" | Ask for Q, A, collection; call `memento_cards.py add` | +| "Review my cards" | Call `memento_cards.py due`, present cards one-by-one | +| "Quiz me on [YouTube URL]" | Call `youtube_quiz.py fetch VIDEO_ID`, generate 5 questions, call `memento_cards.py add-quiz` | +| "Export my cards" | Call `memento_cards.py export --output PATH` | +| "Import cards from CSV" | Call `memento_cards.py import --file PATH --collection NAME` | +| "Show my stats" | Call `memento_cards.py stats` | +| "Delete a card" | Call `memento_cards.py delete --id ID` | +| "Delete a collection" | Call `memento_cards.py delete-collection --collection NAME` | + +## Card Storage + +Cards are stored in a JSON file at: + +``` +~/.hermes/skills/productivity/memento-flashcards/data/cards.json +``` + +**Never edit this file directly.** Always use `memento_cards.py` subcommands. The script handles atomic writes (write to temp file, then rename) to prevent corruption. + +The file is created automatically on first use. + +## Procedure + +### Creating Cards from Facts + +### Activation Rules + +Not every factual statement should become a flashcard. Use this three-tier check: + +1. **Explicit intent** — the user mentions "memento", "flashcard", "remember this", "save this card", "add a card", or similar phrasing that clearly requests a flashcard → **create the card directly**, no confirmation needed. +2. **Implicit intent** — the user sends a factual statement without mentioning flashcards (e.g. "The speed of light is 299,792 km/s") → **ask first**: "Want me to save this as a Memento flashcard?" Only create the card if the user confirms. +3. **No intent** — the message is a coding task, a question, instructions, normal conversation, or anything that is clearly not a fact to memorize → **do NOT activate this skill at all**. Let other skills or default behavior handle it. + +When activation is confirmed (tier 1 directly, tier 2 after confirmation), generate a flashcard: + +**Step 1:** Turn the statement into a Q/A pair. Use this format internally: + +``` +Turn the factual statement into a front-back pair. +Return exactly two lines: +Q: +A: + +Statement: "{statement}" +``` + +Rules: +- The question should test recall of the key fact +- The answer should be concise and direct + +**Step 2:** Call the script to store the card: + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py add \ + --question "What year did World War 2 end?" \ + --answer "1945" \ + --collection "History" +``` + +If the user doesn't specify a collection, use `"General"` as the default. + +The script outputs JSON confirming the created card. + +### Manual Card Creation + +When the user explicitly asks to create a flashcard, ask them for: +1. The question (front of card) +2. The answer (back of card) +3. The collection name (optional — default to `"General"`) + +Then call `memento_cards.py add` as above. + +### Reviewing Due Cards + +When the user wants to review, fetch all due cards: + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py due +``` + +This returns a JSON array of cards where `next_review_at <= now`. If a collection filter is needed: + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py due --collection "History" +``` + +**Review flow (free-text grading):** + +Here is an example of the EXACT interaction pattern you must follow. The user answers, you grade them, tell them the correct answer, then rate the card. + +**Example interaction:** + +> **Agent:** What year did the Berlin Wall fall? +> +> **User:** 1991 +> +> **Agent:** Not quite. The Berlin Wall fell in 1989. Next review is tomorrow. +> *(agent calls: memento_cards.py rate --id ABC --rating hard --user-answer "1991")* +> +> Next question: Who was the first person to walk on the moon? + +**The rules:** + +1. Show only the question. Wait for the user to answer. +2. After receiving their answer, compare it to the expected answer and grade it: + - **correct** → user got the key fact right (even if worded differently) + - **partial** → right track but missing the core detail + - **incorrect** → wrong or off-topic +3. **You MUST tell the user the correct answer and how they did.** Keep it short and plain-text. Use this format: + - correct: "Correct. Answer: {answer}. Next review in 7 days." + - partial: "Close. Answer: {answer}. {what they missed}. Next review in 3 days." + - incorrect: "Not quite. Answer: {answer}. Next review tomorrow." +4. Then call the rate command: correct→easy, partial→good, incorrect→hard. +5. Then show the next question. + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py rate \ + --id CARD_ID --rating easy --user-answer "what the user said" +``` + +**Never skip step 3.** The user must always see the correct answer and feedback before you move on. + +If no cards are due, tell the user: "No cards due for review right now. Check back later!" + +**Retire override:** At any point the user can say "retire this card" to permanently remove it from reviews. Use `--rating retire` for this. + +### Spaced Repetition Algorithm + +The rating determines the next review interval: + +| Rating | Interval | ease_streak | Status change | +|---|---|---|---| +| **hard** | +1 day | reset to 0 | stays learning | +| **good** | +3 days | reset to 0 | stays learning | +| **easy** | +7 days | +1 | if ease_streak >= 3 → retired | +| **retire** | permanent | reset to 0 | → retired | + +- **learning**: card is actively in rotation +- **retired**: card won't appear in reviews (user has mastered it or manually retired it) +- Three consecutive "easy" ratings automatically retire a card + +### YouTube Quiz Generation + +When the user sends a YouTube URL and wants a quiz: + +**Step 1:** Extract the video ID from the URL (e.g. `dQw4w9WgXcQ` from `https://www.youtube.com/watch?v=dQw4w9WgXcQ`). + +**Step 2:** Fetch the transcript: + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/youtube_quiz.py fetch VIDEO_ID +``` + +This returns `{"title": "...", "transcript": "..."}` or an error. + +If the script reports `missing_dependency`, tell the user to install it: +```bash +pip install youtube-transcript-api +``` + +**Step 3:** Generate 5 quiz questions from the transcript. Use these rules: + +``` +You are creating a 5-question quiz for a podcast episode. +Return ONLY a JSON array with exactly 5 objects. +Each object must contain keys 'question' and 'answer'. + +Selection criteria: +- Prioritize important, surprising, or foundational facts. +- Skip filler, obvious details, and facts that require heavy context. +- Never return true/false questions. +- Never ask only for a date. + +Question rules: +- Each question must test exactly one discrete fact. +- Use clear, unambiguous wording. +- Prefer What, Who, How many, Which. +- Avoid open-ended Describe or Explain prompts. + +Answer rules: +- Each answer must be under 240 characters. +- Lead with the answer itself, not preamble. +- Add only minimal clarifying detail if needed. +``` + +Use the first 15,000 characters of the transcript as context. Generate the questions yourself (you are the LLM). + +**Step 4:** Validate the output is valid JSON with exactly 5 items, each having non-empty `question` and `answer` strings. If validation fails, retry once. + +**Step 5:** Store quiz cards: + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py add-quiz \ + --video-id "VIDEO_ID" \ + --questions '[{"question":"...","answer":"..."},...]' \ + --collection "Quiz - Episode Title" +``` + +The script deduplicates by `video_id` — if cards for that video already exist, it skips creation and reports the existing cards. + +**Step 6:** Present questions one-by-one using the same free-text grading flow: +1. Show "Question 1/5: ..." and wait for the user's answer. Never include the answer or any hint about revealing it. +2. Wait for the user to answer in their own words +3. Grade their answer using the grading prompt (see "Reviewing Due Cards" section) +4. **IMPORTANT: You MUST reply to the user with feedback before doing anything else.** Show the grade, the correct answer, and when the card is next due. Do NOT silently skip to the next question. Keep it short and plain-text. Example: "Not quite. Answer: {answer}. Next review tomorrow." +5. **After showing feedback**, call the rate command and then show the next question in the same message: +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py rate \ + --id CARD_ID --rating easy --user-answer "what the user said" +``` +6. Repeat. Every answer MUST receive visible feedback before the next question. + +### Export/Import CSV + +**Export:** +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py export \ + --output ~/flashcards.csv +``` + +Produces a 3-column CSV: `question,answer,collection` (no header row). + +**Import:** +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py import \ + --file ~/flashcards.csv \ + --collection "Imported" +``` + +Reads a CSV with columns: question, answer, and optionally collection (column 3). If the collection column is missing, uses the `--collection` argument. + +### Statistics + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py stats +``` + +Returns JSON with: +- `total`: total card count +- `learning`: cards in active rotation +- `retired`: mastered cards +- `due_now`: cards due for review right now +- `collections`: breakdown by collection name + +## Pitfalls + +- **Never edit `cards.json` directly** — always use the script subcommands to avoid corruption +- **Transcript failures** — some YouTube videos have no English transcript or have transcripts disabled; inform the user and suggest another video +- **Optional dependency** — `youtube_quiz.py` needs `youtube-transcript-api`; if missing, tell the user to run `pip install youtube-transcript-api` +- **Large imports** — CSV imports with thousands of rows work fine but the JSON output may be verbose; summarize the result for the user +- **Video ID extraction** — support both `youtube.com/watch?v=ID` and `youtu.be/ID` URL formats + +## Verification + +Verify the helper scripts directly: + +```bash +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py stats +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py add --question "Capital of France?" --answer "Paris" --collection "General" +python3 ~/.hermes/skills/productivity/memento-flashcards/scripts/memento_cards.py due +``` + +If you are testing from the repo checkout, run: + +```bash +pytest tests/skills/test_memento_cards.py tests/skills/test_youtube_quiz.py -q +``` + +Agent-level verification: +- Start a review and confirm feedback is plain text, brief, and always includes the correct answer before the next card +- Run a YouTube quiz flow and confirm each answer receives visible feedback before the next question diff --git a/optional-skills/productivity/memento-flashcards/scripts/memento_cards.py b/optional-skills/productivity/memento-flashcards/scripts/memento_cards.py new file mode 100644 index 000000000..47e41dd3a --- /dev/null +++ b/optional-skills/productivity/memento-flashcards/scripts/memento_cards.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +"""Memento card storage, spaced-repetition engine, and CSV I/O. + +Stdlib-only. All output is JSON for agent parsing. +Data file: $HERMES_HOME/skills/productivity/memento-flashcards/data/cards.json +""" + +import argparse +import csv +import json +import os +import sys +import tempfile +import uuid +from datetime import datetime, timedelta, timezone +from pathlib import Path + +_HERMES_HOME = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) +DATA_DIR = _HERMES_HOME / "skills" / "productivity" / "memento-flashcards" / "data" +CARDS_FILE = DATA_DIR / "cards.json" + +RETIRED_SENTINEL = "9999-12-31T23:59:59+00:00" + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +def _iso(dt: datetime) -> str: + return dt.isoformat() + + +def _parse_iso(s: str) -> datetime: + return datetime.fromisoformat(s) + + +def _empty_store() -> dict: + return {"cards": [], "version": 1} + + +def _load() -> dict: + if not CARDS_FILE.exists(): + return _empty_store() + try: + with open(CARDS_FILE, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict) or "cards" not in data: + return _empty_store() + return data + except (json.JSONDecodeError, OSError): + return _empty_store() + + +def _save(data: dict) -> None: + DATA_DIR.mkdir(parents=True, exist_ok=True) + fd, tmp = tempfile.mkstemp(dir=DATA_DIR, suffix=".tmp") + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + f.write("\n") + os.replace(tmp, CARDS_FILE) + except BaseException: + try: + os.unlink(tmp) + except OSError: + pass + raise + + +def _out(obj: object) -> None: + json.dump(obj, sys.stdout, indent=2, ensure_ascii=False) + sys.stdout.write("\n") + + +# ── Subcommands ────────────────────────────────────────────────────────────── + +def cmd_add(args: argparse.Namespace) -> None: + data = _load() + now = _now() + card = { + "id": str(uuid.uuid4()), + "question": args.question, + "answer": args.answer, + "collection": args.collection or "General", + "status": "learning", + "ease_streak": 0, + "next_review_at": _iso(now), + "created_at": _iso(now), + "video_id": None, + "last_user_answer": None, + } + data["cards"].append(card) + _save(data) + _out({"ok": True, "card": card}) + + +def cmd_add_quiz(args: argparse.Namespace) -> None: + data = _load() + now = _now() + + try: + questions = json.loads(args.questions) + except json.JSONDecodeError as exc: + _out({"ok": False, "error": f"Invalid JSON for --questions: {exc}"}) + sys.exit(1) + + # Dedup: skip if cards with this video_id already exist + existing_ids = {c["video_id"] for c in data["cards"] if c.get("video_id")} + if args.video_id in existing_ids: + existing = [c for c in data["cards"] if c.get("video_id") == args.video_id] + _out({"ok": True, "skipped": True, "reason": "duplicate_video_id", "existing_count": len(existing), "cards": existing}) + return + + created = [] + for qa in questions: + card = { + "id": str(uuid.uuid4()), + "question": qa["question"], + "answer": qa["answer"], + "collection": args.collection or "Quiz", + "status": "learning", + "ease_streak": 0, + "next_review_at": _iso(now), + "created_at": _iso(now), + "video_id": args.video_id, + "last_user_answer": None, + } + data["cards"].append(card) + created.append(card) + + _save(data) + _out({"ok": True, "created_count": len(created), "cards": created}) + + +def cmd_due(args: argparse.Namespace) -> None: + data = _load() + now = _now() + due = [] + for card in data["cards"]: + if card["status"] == "retired": + continue + review_at = _parse_iso(card["next_review_at"]) + if review_at <= now: + if args.collection and card["collection"] != args.collection: + continue + due.append(card) + _out({"ok": True, "count": len(due), "cards": due}) + + +def cmd_rate(args: argparse.Namespace) -> None: + data = _load() + now = _now() + card = None + for c in data["cards"]: + if c["id"] == args.id: + card = c + break + if not card: + _out({"ok": False, "error": f"Card not found: {args.id}"}) + sys.exit(1) + + rating = args.rating + user_answer = getattr(args, "user_answer", None) + if user_answer is not None: + card["last_user_answer"] = user_answer + + if rating == "retire": + card["status"] = "retired" + card["next_review_at"] = RETIRED_SENTINEL + card["ease_streak"] = 0 + elif rating == "hard": + card["next_review_at"] = _iso(now + timedelta(days=1)) + card["ease_streak"] = 0 + elif rating == "good": + card["next_review_at"] = _iso(now + timedelta(days=3)) + card["ease_streak"] = 0 + elif rating == "easy": + card["next_review_at"] = _iso(now + timedelta(days=7)) + card["ease_streak"] = card.get("ease_streak", 0) + 1 + if card["ease_streak"] >= 3: + card["status"] = "retired" + + _save(data) + _out({"ok": True, "card": card}) + + +def cmd_list(args: argparse.Namespace) -> None: + data = _load() + cards = data["cards"] + if args.collection: + cards = [c for c in cards if c["collection"] == args.collection] + if args.status: + cards = [c for c in cards if c["status"] == args.status] + _out({"ok": True, "count": len(cards), "cards": cards}) + + +def cmd_stats(args: argparse.Namespace) -> None: + data = _load() + now = _now() + total = len(data["cards"]) + learning = sum(1 for c in data["cards"] if c["status"] == "learning") + retired = sum(1 for c in data["cards"] if c["status"] == "retired") + due_now = 0 + for c in data["cards"]: + if c["status"] != "retired" and _parse_iso(c["next_review_at"]) <= now: + due_now += 1 + + collections: dict[str, int] = {} + for c in data["cards"]: + name = c["collection"] + collections[name] = collections.get(name, 0) + 1 + + _out({ + "ok": True, + "total": total, + "learning": learning, + "retired": retired, + "due_now": due_now, + "collections": collections, + }) + + +def cmd_export(args: argparse.Namespace) -> None: + data = _load() + output_path = Path(args.output).expanduser() + with open(output_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f, lineterminator="\n") + for card in data["cards"]: + writer.writerow([card["question"], card["answer"], card["collection"]]) + _out({"ok": True, "exported": len(data["cards"]), "path": str(output_path)}) + + +def cmd_import(args: argparse.Namespace) -> None: + data = _load() + now = _now() + file_path = Path(args.file).expanduser() + + if not file_path.exists(): + _out({"ok": False, "error": f"File not found: {file_path}"}) + sys.exit(1) + + created = 0 + with open(file_path, "r", encoding="utf-8") as f: + reader = csv.reader(f) + for row in reader: + if len(row) < 2: + continue + question = row[0].strip() + answer = row[1].strip() + collection = row[2].strip() if len(row) >= 3 and row[2].strip() else (args.collection or "Imported") + if not question or not answer: + continue + card = { + "id": str(uuid.uuid4()), + "question": question, + "answer": answer, + "collection": collection, + "status": "learning", + "ease_streak": 0, + "next_review_at": _iso(now), + "created_at": _iso(now), + "video_id": None, + "last_user_answer": None, + } + data["cards"].append(card) + created += 1 + + _save(data) + _out({"ok": True, "imported": created}) + + +def cmd_delete(args: argparse.Namespace) -> None: + data = _load() + original = len(data["cards"]) + data["cards"] = [c for c in data["cards"] if c["id"] != args.id] + removed = original - len(data["cards"]) + if removed == 0: + _out({"ok": False, "error": f"Card not found: {args.id}"}) + sys.exit(1) + _save(data) + _out({"ok": True, "deleted": args.id}) + + +def cmd_delete_collection(args: argparse.Namespace) -> None: + data = _load() + original = len(data["cards"]) + data["cards"] = [c for c in data["cards"] if c["collection"] != args.collection] + removed = original - len(data["cards"]) + _save(data) + _out({"ok": True, "deleted_count": removed, "collection": args.collection}) + + +# ── CLI ────────────────────────────────────────────────────────────────────── + +def main() -> None: + parser = argparse.ArgumentParser(description="Memento flashcard manager") + sub = parser.add_subparsers(dest="command", required=True) + + p_add = sub.add_parser("add", help="Create one card") + p_add.add_argument("--question", required=True) + p_add.add_argument("--answer", required=True) + p_add.add_argument("--collection", default="General") + + p_quiz = sub.add_parser("add-quiz", help="Batch-add quiz cards") + p_quiz.add_argument("--video-id", required=True) + p_quiz.add_argument("--questions", required=True, help="JSON array of {question, answer}") + p_quiz.add_argument("--collection", default="Quiz") + + p_due = sub.add_parser("due", help="List due cards") + p_due.add_argument("--collection", default=None) + + p_rate = sub.add_parser("rate", help="Rate a card") + p_rate.add_argument("--id", required=True) + p_rate.add_argument("--rating", required=True, choices=["easy", "good", "hard", "retire"]) + p_rate.add_argument("--user-answer", default=None) + + p_list = sub.add_parser("list", help="List cards") + p_list.add_argument("--collection", default=None) + p_list.add_argument("--status", default=None, choices=["learning", "retired"]) + + sub.add_parser("stats", help="Show statistics") + + p_export = sub.add_parser("export", help="Export cards to CSV") + p_export.add_argument("--output", required=True) + + p_import = sub.add_parser("import", help="Import cards from CSV") + p_import.add_argument("--file", required=True) + p_import.add_argument("--collection", default="Imported") + + p_del = sub.add_parser("delete", help="Delete one card") + p_del.add_argument("--id", required=True) + + p_delcol = sub.add_parser("delete-collection", help="Delete all cards in a collection") + p_delcol.add_argument("--collection", required=True) + + args = parser.parse_args() + cmd_map = { + "add": cmd_add, + "add-quiz": cmd_add_quiz, + "due": cmd_due, + "rate": cmd_rate, + "list": cmd_list, + "stats": cmd_stats, + "export": cmd_export, + "import": cmd_import, + "delete": cmd_delete, + "delete-collection": cmd_delete_collection, + } + cmd_map[args.command](args) + + +if __name__ == "__main__": + main() diff --git a/optional-skills/productivity/memento-flashcards/scripts/youtube_quiz.py b/optional-skills/productivity/memento-flashcards/scripts/youtube_quiz.py new file mode 100644 index 000000000..5b6f44ca7 --- /dev/null +++ b/optional-skills/productivity/memento-flashcards/scripts/youtube_quiz.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +"""Fetch YouTube transcripts for Memento quiz generation. + +Requires: pip install youtube-transcript-api +The quiz question *generation* is done by the agent's LLM — this script only fetches transcripts. +""" + +import argparse +import json +import re +import sys + + +def _out(obj: object) -> None: + json.dump(obj, sys.stdout, indent=2, ensure_ascii=False) + sys.stdout.write("\n") + + +def _normalize_segments(segments: list) -> str: + parts = [] + for seg in segments: + text = str(seg.get("text", "")).strip() + if text: + parts.append(text) + return re.sub(r"\s+", " ", " ".join(parts)).strip() + + +def cmd_fetch(args: argparse.Namespace) -> None: + try: + import youtube_transcript_api # noqa: F811 + except ImportError: + _out({ + "ok": False, + "error": "missing_dependency", + "message": "Run: pip install youtube-transcript-api", + }) + sys.exit(1) + + video_id = args.video_id + languages = ["en", "en-US", "en-GB", "en-CA", "en-AU"] + + api = youtube_transcript_api.YouTubeTranscriptApi() + try: + raw = api.fetch(video_id, languages=languages) + except Exception as exc: + error_type = type(exc).__name__ + _out({ + "ok": False, + "error": "transcript_unavailable", + "error_type": error_type, + "message": f"Could not fetch transcript for {video_id}: {exc}", + }) + sys.exit(1) + + segments = raw + if hasattr(raw, "to_raw_data"): + segments = raw.to_raw_data() + + text = _normalize_segments(segments) + if not text: + _out({ + "ok": False, + "error": "empty_transcript", + "message": f"Transcript for {video_id} contained no usable text.", + }) + sys.exit(1) + + _out({ + "ok": True, + "video_id": video_id, + "transcript": text, + }) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Memento YouTube transcript fetcher") + sub = parser.add_subparsers(dest="command", required=True) + + p_fetch = sub.add_parser("fetch", help="Fetch transcript for a video") + p_fetch.add_argument("video_id", help="YouTube video ID") + + args = parser.parse_args() + if args.command == "fetch": + cmd_fetch(args) + + +if __name__ == "__main__": + main() diff --git a/optional-skills/productivity/siyuan/SKILL.md b/optional-skills/productivity/siyuan/SKILL.md new file mode 100644 index 000000000..49c5d6185 --- /dev/null +++ b/optional-skills/productivity/siyuan/SKILL.md @@ -0,0 +1,297 @@ +--- +name: siyuan +description: SiYuan Note API for searching, reading, creating, and managing blocks and documents in a self-hosted knowledge base via curl. +version: 1.0.0 +author: FEUAZUR +license: MIT +metadata: + hermes: + tags: [SiYuan, Notes, Knowledge Base, PKM, API] + related_skills: [obsidian, notion] + homepage: https://github.com/siyuan-note/siyuan +prerequisites: + env_vars: [SIYUAN_TOKEN] + commands: [curl, jq] +required_environment_variables: + - name: SIYUAN_TOKEN + prompt: SiYuan API token + help: "Settings > About in SiYuan desktop app" + - name: SIYUAN_URL + prompt: SiYuan instance URL (default http://127.0.0.1:6806) + required_for: remote instances +--- + +# SiYuan Note API + +Use the [SiYuan](https://github.com/siyuan-note/siyuan) kernel API via curl to search, read, create, update, and delete blocks and documents in a self-hosted knowledge base. No extra tools needed -- just curl and an API token. + +## Prerequisites + +1. Install and run SiYuan (desktop or Docker) +2. Get your API token: **Settings > About > API token** +3. Store it in `~/.hermes/.env`: + ``` + SIYUAN_TOKEN=your_token_here + SIYUAN_URL=http://127.0.0.1:6806 + ``` + `SIYUAN_URL` defaults to `http://127.0.0.1:6806` if not set. + +## API Basics + +All SiYuan API calls are **POST with JSON body**. Every request follows this pattern: + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/..." \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"param": "value"}' +``` + +Responses are JSON with this structure: +```json +{"code": 0, "msg": "", "data": { ... }} +``` +`code: 0` means success. Any other value is an error -- check `msg` for details. + +**ID format:** SiYuan IDs look like `20210808180117-6v0mkxr` (14-digit timestamp + 7 alphanumeric chars). + +## Quick Reference + +| Operation | Endpoint | +|-----------|----------| +| Full-text search | `/api/search/fullTextSearchBlock` | +| SQL query | `/api/query/sql` | +| Read block | `/api/block/getBlockKramdown` | +| Read children | `/api/block/getChildBlocks` | +| Get path | `/api/filetree/getHPathByID` | +| Get attributes | `/api/attr/getBlockAttrs` | +| List notebooks | `/api/notebook/lsNotebooks` | +| List documents | `/api/filetree/listDocsByPath` | +| Create notebook | `/api/notebook/createNotebook` | +| Create document | `/api/filetree/createDocWithMd` | +| Append block | `/api/block/appendBlock` | +| Update block | `/api/block/updateBlock` | +| Rename document | `/api/filetree/renameDocByID` | +| Set attributes | `/api/attr/setBlockAttrs` | +| Delete block | `/api/block/deleteBlock` | +| Delete document | `/api/filetree/removeDocByID` | +| Export as Markdown | `/api/export/exportMdContent` | + +## Common Operations + +### Search (Full-Text) + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/search/fullTextSearchBlock" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"query": "meeting notes", "page": 0}' | jq '.data.blocks[:5]' +``` + +### Search (SQL) + +Query the blocks database directly. Only SELECT statements are safe. + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/query/sql" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"stmt": "SELECT id, content, type, box FROM blocks WHERE content LIKE '\''%keyword%'\'' AND type='\''p'\'' LIMIT 20"}' | jq '.data' +``` + +Useful columns: `id`, `parent_id`, `root_id`, `box` (notebook ID), `path`, `content`, `type`, `subtype`, `created`, `updated`. + +### Read Block Content + +Returns block content in Kramdown (Markdown-like) format. + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/getBlockKramdown" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "20210808180117-6v0mkxr"}' | jq '.data.kramdown' +``` + +### Read Child Blocks + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/getChildBlocks" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "20210808180117-6v0mkxr"}' | jq '.data' +``` + +### Get Human-Readable Path + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/getHPathByID" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "20210808180117-6v0mkxr"}' | jq '.data' +``` + +### Get Block Attributes + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/attr/getBlockAttrs" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "20210808180117-6v0mkxr"}' | jq '.data' +``` + +### List Notebooks + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/notebook/lsNotebooks" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{}' | jq '.data.notebooks[] | {id, name, closed}' +``` + +### List Documents in a Notebook + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/listDocsByPath" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"notebook": "NOTEBOOK_ID", "path": "/"}' | jq '.data.files[] | {id, name}' +``` + +### Create a Document + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/createDocWithMd" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "notebook": "NOTEBOOK_ID", + "path": "/Meeting Notes/2026-03-22", + "markdown": "# Meeting Notes\n\n- Discussed project timeline\n- Assigned tasks" + }' | jq '.data' +``` + +### Create a Notebook + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/notebook/createNotebook" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"name": "My New Notebook"}' | jq '.data.notebook.id' +``` + +### Append Block to Document + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/appendBlock" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "parentID": "DOCUMENT_OR_BLOCK_ID", + "data": "New paragraph added at the end.", + "dataType": "markdown" + }' | jq '.data' +``` + +Also available: `/api/block/prependBlock` (same params, inserts at the beginning) and `/api/block/insertBlock` (uses `previousID` instead of `parentID` to insert after a specific block). + +### Update Block Content + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/updateBlock" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "id": "BLOCK_ID", + "data": "Updated content here.", + "dataType": "markdown" + }' | jq '.data' +``` + +### Rename a Document + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/filetree/renameDocByID" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "DOCUMENT_ID", "title": "New Title"}' +``` + +### Set Block Attributes + +Custom attributes must be prefixed with `custom-`: + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/attr/setBlockAttrs" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "id": "BLOCK_ID", + "attrs": { + "custom-status": "reviewed", + "custom-priority": "high" + } + }' +``` + +### Delete a Block + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/block/deleteBlock" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "BLOCK_ID"}' +``` + +To delete a whole document: use `/api/filetree/removeDocByID` with `{"id": "DOC_ID"}`. +To delete a notebook: use `/api/notebook/removeNotebook` with `{"notebook": "NOTEBOOK_ID"}`. + +### Export Document as Markdown + +```bash +curl -s -X POST "${SIYUAN_URL:-http://127.0.0.1:6806}/api/export/exportMdContent" \ + -H "Authorization: Token $SIYUAN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"id": "DOCUMENT_ID"}' | jq -r '.data.content' +``` + +## Block Types + +Common `type` values in SQL queries: + +| Type | Description | +|------|-------------| +| `d` | Document (root block) | +| `p` | Paragraph | +| `h` | Heading | +| `l` | List | +| `i` | List item | +| `c` | Code block | +| `m` | Math block | +| `t` | Table | +| `b` | Blockquote | +| `s` | Super block | +| `html` | HTML block | + +## Pitfalls + +- **All endpoints are POST** -- even read-only operations. Do not use GET. +- **SQL safety**: only use SELECT queries. INSERT/UPDATE/DELETE/DROP are dangerous and should never be sent. +- **ID validation**: IDs match the pattern `YYYYMMDDHHmmss-xxxxxxx`. Reject anything else. +- **Error responses**: always check `code != 0` in responses before processing `data`. +- **Large documents**: block content and export results can be very large. Use `LIMIT` in SQL and pipe through `jq` to extract only what you need. +- **Notebook IDs**: when working with a specific notebook, get its ID first via `lsNotebooks`. + +## Alternative: MCP Server + +If you prefer a native integration instead of curl, install the SiYuan MCP server: + +```yaml +# In ~/.hermes/config.yaml under mcp_servers: +mcp_servers: + siyuan: + command: npx + args: ["-y", "@porkll/siyuan-mcp"] + env: + SIYUAN_TOKEN: "your_token" + SIYUAN_URL: "http://127.0.0.1:6806" +``` diff --git a/skills/research/parallel-cli/SKILL.md b/optional-skills/research/parallel-cli/SKILL.md similarity index 100% rename from skills/research/parallel-cli/SKILL.md rename to optional-skills/research/parallel-cli/SKILL.md diff --git a/optional-skills/research/scrapling/SKILL.md b/optional-skills/research/scrapling/SKILL.md new file mode 100644 index 000000000..aaa38c90a --- /dev/null +++ b/optional-skills/research/scrapling/SKILL.md @@ -0,0 +1,335 @@ +--- +name: scrapling +description: Web scraping with Scrapling - HTTP fetching, stealth browser automation, Cloudflare bypass, and spider crawling via CLI and Python. +version: 1.0.0 +author: FEUAZUR +license: MIT +metadata: + hermes: + tags: [Web Scraping, Browser, Cloudflare, Stealth, Crawling, Spider] + related_skills: [duckduckgo-search, domain-intel] + homepage: https://github.com/D4Vinci/Scrapling +prerequisites: + commands: [scrapling, python] +--- + +# Scrapling + +[Scrapling](https://github.com/D4Vinci/Scrapling) is a web scraping framework with anti-bot bypass, stealth browser automation, and a spider framework. It provides three fetching strategies (HTTP, dynamic JS, stealth/Cloudflare) and a full CLI. + +**This skill is for educational and research purposes only.** Users must comply with local/international data scraping laws and respect website Terms of Service. + +## When to Use + +- Scraping static HTML pages (faster than browser tools) +- Scraping JS-rendered pages that need a real browser +- Bypassing Cloudflare Turnstile or bot detection +- Crawling multiple pages with a spider +- When the built-in `web_extract` tool does not return the data you need + +## Installation + +```bash +pip install "scrapling[all]" +scrapling install +``` + +Minimal install (HTTP only, no browser): +```bash +pip install scrapling +``` + +With browser automation only: +```bash +pip install "scrapling[fetchers]" +scrapling install +``` + +## Quick Reference + +| Approach | Class | Use When | +|----------|-------|----------| +| HTTP | `Fetcher` / `FetcherSession` | Static pages, APIs, fast bulk requests | +| Dynamic | `DynamicFetcher` / `DynamicSession` | JS-rendered content, SPAs | +| Stealth | `StealthyFetcher` / `StealthySession` | Cloudflare, anti-bot protected sites | +| Spider | `Spider` | Multi-page crawling with link following | + +## CLI Usage + +### Extract Static Page + +```bash +scrapling extract get 'https://example.com' output.md +``` + +With CSS selector and browser impersonation: + +```bash +scrapling extract get 'https://example.com' output.md \ + --css-selector '.content' \ + --impersonate 'chrome' +``` + +### Extract JS-Rendered Page + +```bash +scrapling extract fetch 'https://example.com' output.md \ + --css-selector '.dynamic-content' \ + --disable-resources \ + --network-idle +``` + +### Extract Cloudflare-Protected Page + +```bash +scrapling extract stealthy-fetch 'https://protected-site.com' output.html \ + --solve-cloudflare \ + --block-webrtc \ + --hide-canvas +``` + +### POST Request + +```bash +scrapling extract post 'https://example.com/api' output.json \ + --json '{"query": "search term"}' +``` + +### Output Formats + +The output format is determined by the file extension: +- `.html` -- raw HTML +- `.md` -- converted to Markdown +- `.txt` -- plain text +- `.json` / `.jsonl` -- JSON + +## Python: HTTP Scraping + +### Single Request + +```python +from scrapling.fetchers import Fetcher + +page = Fetcher.get('https://quotes.toscrape.com/') +quotes = page.css('.quote .text::text').getall() +for q in quotes: + print(q) +``` + +### Session (Persistent Cookies) + +```python +from scrapling.fetchers import FetcherSession + +with FetcherSession(impersonate='chrome') as session: + page = session.get('https://example.com/', stealthy_headers=True) + links = page.css('a::attr(href)').getall() + for link in links[:5]: + sub = session.get(link) + print(sub.css('h1::text').get()) +``` + +### POST / PUT / DELETE + +```python +page = Fetcher.post('https://api.example.com/data', json={"key": "value"}) +page = Fetcher.put('https://api.example.com/item/1', data={"name": "updated"}) +page = Fetcher.delete('https://api.example.com/item/1') +``` + +### With Proxy + +```python +page = Fetcher.get('https://example.com', proxy='http://user:pass@proxy:8080') +``` + +## Python: Dynamic Pages (JS-Rendered) + +For pages that require JavaScript execution (SPAs, lazy-loaded content): + +```python +from scrapling.fetchers import DynamicFetcher + +page = DynamicFetcher.fetch('https://example.com', headless=True) +data = page.css('.js-loaded-content::text').getall() +``` + +### Wait for Specific Element + +```python +page = DynamicFetcher.fetch( + 'https://example.com', + wait_selector=('.results', 'visible'), + network_idle=True, +) +``` + +### Disable Resources for Speed + +Blocks fonts, images, media, stylesheets (~25% faster): + +```python +from scrapling.fetchers import DynamicSession + +with DynamicSession(headless=True, disable_resources=True, network_idle=True) as session: + page = session.fetch('https://example.com') + items = page.css('.item::text').getall() +``` + +### Custom Page Automation + +```python +from playwright.sync_api import Page +from scrapling.fetchers import DynamicFetcher + +def scroll_and_click(page: Page): + page.mouse.wheel(0, 3000) + page.wait_for_timeout(1000) + page.click('button.load-more') + page.wait_for_selector('.extra-results') + +page = DynamicFetcher.fetch('https://example.com', page_action=scroll_and_click) +results = page.css('.extra-results .item::text').getall() +``` + +## Python: Stealth Mode (Anti-Bot Bypass) + +For Cloudflare-protected or heavily fingerprinted sites: + +```python +from scrapling.fetchers import StealthyFetcher + +page = StealthyFetcher.fetch( + 'https://protected-site.com', + headless=True, + solve_cloudflare=True, + block_webrtc=True, + hide_canvas=True, +) +content = page.css('.protected-content::text').getall() +``` + +### Stealth Session + +```python +from scrapling.fetchers import StealthySession + +with StealthySession(headless=True, solve_cloudflare=True) as session: + page1 = session.fetch('https://protected-site.com/page1') + page2 = session.fetch('https://protected-site.com/page2') +``` + +## Element Selection + +All fetchers return a `Selector` object with these methods: + +### CSS Selectors + +```python +page.css('h1::text').get() # First h1 text +page.css('a::attr(href)').getall() # All link hrefs +page.css('.quote .text::text').getall() # Nested selection +``` + +### XPath + +```python +page.xpath('//div[@class="content"]/text()').getall() +page.xpath('//a/@href').getall() +``` + +### Find Methods + +```python +page.find_all('div', class_='quote') # By tag + attribute +page.find_by_text('Read more', tag='a') # By text content +page.find_by_regex(r'\$\d+\.\d{2}') # By regex pattern +``` + +### Similar Elements + +Find elements with similar structure (useful for product listings, etc.): + +```python +first_product = page.css('.product')[0] +all_similar = first_product.find_similar() +``` + +### Navigation + +```python +el = page.css('.target')[0] +el.parent # Parent element +el.children # Child elements +el.next_sibling # Next sibling +el.prev_sibling # Previous sibling +``` + +## Python: Spider Framework + +For multi-page crawling with link following: + +```python +from scrapling.spiders import Spider, Request, Response + +class QuotesSpider(Spider): + name = "quotes" + start_urls = ["https://quotes.toscrape.com/"] + concurrent_requests = 10 + download_delay = 1 + + async def parse(self, response: Response): + for quote in response.css('.quote'): + yield { + "text": quote.css('.text::text').get(), + "author": quote.css('.author::text').get(), + "tags": quote.css('.tag::text').getall(), + } + + next_page = response.css('.next a::attr(href)').get() + if next_page: + yield response.follow(next_page) + +result = QuotesSpider().start() +print(f"Scraped {len(result.items)} quotes") +result.items.to_json("quotes.json") +``` + +### Multi-Session Spider + +Route requests to different fetcher types: + +```python +from scrapling.fetchers import FetcherSession, AsyncStealthySession + +class SmartSpider(Spider): + name = "smart" + start_urls = ["https://example.com/"] + + def configure_sessions(self, manager): + manager.add("fast", FetcherSession(impersonate="chrome")) + manager.add("stealth", AsyncStealthySession(headless=True), lazy=True) + + async def parse(self, response: Response): + for link in response.css('a::attr(href)').getall(): + if "protected" in link: + yield Request(link, sid="stealth") + else: + yield Request(link, sid="fast", callback=self.parse) +``` + +### Pause/Resume Crawling + +```python +spider = QuotesSpider(crawldir="./crawl_checkpoint") +spider.start() # Ctrl+C to pause, re-run to resume from checkpoint +``` + +## Pitfalls + +- **Browser install required**: run `scrapling install` after pip install -- without it, `DynamicFetcher` and `StealthyFetcher` will fail +- **Timeouts**: DynamicFetcher/StealthyFetcher timeout is in **milliseconds** (default 30000), Fetcher timeout is in **seconds** +- **Cloudflare bypass**: `solve_cloudflare=True` adds 5-15 seconds to fetch time -- only enable when needed +- **Resource usage**: StealthyFetcher runs a real browser -- limit concurrent usage +- **Legal**: always check robots.txt and website ToS before scraping. This library is for educational and research purposes +- **Python version**: requires Python 3.10+ diff --git a/pyproject.toml b/pyproject.toml index bd5fa6481..d044739a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "hermes-agent" -version = "0.4.0" +version = "0.5.0" description = "The self-improving AI agent — creates skills from experience, improves them during use, and runs anywhere" readme = "README.md" requires-python = ">=3.11" @@ -26,6 +26,7 @@ dependencies = [ # Interactive CLI (prompt_toolkit is used directly by cli.py) "prompt_toolkit>=3.0.52,<4", # Tools + "exa-py>=2.9.0,<3", "firecrawl-py>=4.16.0,<5", "parallel-web>=0.4.2,<1", "fal-client>=0.13.1,<1", @@ -37,7 +38,7 @@ dependencies = [ ] [project.optional-dependencies] -modal = ["swe-rex[modal]>=1.4.0,<2"] +modal = ["modal>=1.0.0,<2"] daytona = ["daytona>=0.148.0,<1"] dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2"] messaging = ["python-telegram-bot>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"] @@ -55,8 +56,9 @@ honcho = ["honcho-ai>=2.0.1,<3"] mcp = ["mcp>=1.2.0,<2"] homeassistant = ["aiohttp>=3.9.0,<4"] sms = ["aiohttp>=3.9.0,<4"] -acp = ["agent-client-protocol>=0.8.1,<1.0"] +acp = ["agent-client-protocol>=0.8.1,<0.9"] dingtalk = ["dingtalk-stream>=0.1.0,<1"] +feishu = ["lark-oapi>=1.5.3,<2"] rl = [ "atroposlib @ git+https://github.com/NousResearch/atropos.git", "tinker @ git+https://github.com/thinking-machines-lab/tinker.git", @@ -82,6 +84,7 @@ all = [ "hermes-agent[acp]", "hermes-agent[voice]", "hermes-agent[dingtalk]", + "hermes-agent[feishu]", ] [project.scripts] diff --git a/run_agent.py b/run_agent.py index cd3884c52..090fea978 100644 --- a/run_agent.py +++ b/run_agent.py @@ -62,7 +62,12 @@ else: # Import our tool system -from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements +from model_tools import ( + get_tool_definitions, + get_toolset_for_tool, + handle_function_call, + check_toolset_requirements, +) from tools.terminal_tool import cleanup_vm from tools.interrupt import set_interrupt as _set_interrupt from tools.browser_tool import cleanup_browser @@ -84,7 +89,7 @@ from agent.model_metadata import ( ) from agent.context_compressor import ContextCompressor from agent.prompt_caching import apply_anthropic_cache_control -from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, load_soul_md +from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, load_soul_md, TOOL_USE_ENFORCEMENT_GUIDANCE, TOOL_USE_ENFORCEMENT_MODELS from agent.usage_pricing import estimate_usage_cost, normalize_usage from agent.display import ( KawaiiSpinner, build_tool_preview as _build_tool_preview, @@ -357,6 +362,85 @@ def _inject_honcho_turn_context(content, turn_context: str): return f"{text}\n\n{note}" +# Budget warning text patterns injected by _get_budget_warning(). +_BUDGET_WARNING_RE = re.compile( + r"\[BUDGET(?:\s+WARNING)?:\s+Iteration\s+\d+/\d+\..*?\]", + re.DOTALL, +) + + +# Regex to match lone surrogate code points (U+D800..U+DFFF). +# These are invalid in UTF-8 and cause UnicodeEncodeError when the OpenAI SDK +# serialises messages to JSON. Common source: clipboard paste from Google Docs +# or other rich-text editors on some platforms. +_SURROGATE_RE = re.compile(r'[\ud800-\udfff]') + + +def _sanitize_surrogates(text: str) -> str: + """Replace lone surrogate code points with U+FFFD (replacement character). + + Surrogates are invalid in UTF-8 and will crash ``json.dumps()`` inside the + OpenAI SDK. This is a fast no-op when the text contains no surrogates. + """ + if _SURROGATE_RE.search(text): + return _SURROGATE_RE.sub('\ufffd', text) + return text + + +def _sanitize_messages_surrogates(messages: list) -> bool: + """Sanitize surrogate characters from all string content in a messages list. + + Walks message dicts in-place. Returns True if any surrogates were found + and replaced, False otherwise. + """ + found = False + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if isinstance(content, str) and _SURROGATE_RE.search(content): + msg["content"] = _SURROGATE_RE.sub('\ufffd', content) + found = True + elif isinstance(content, list): + for part in content: + if isinstance(part, dict): + text = part.get("text") + if isinstance(text, str) and _SURROGATE_RE.search(text): + part["text"] = _SURROGATE_RE.sub('\ufffd', text) + found = True + return found + + +def _strip_budget_warnings_from_history(messages: list) -> None: + """Remove budget pressure warnings from tool-result messages in-place. + + Budget warnings are turn-scoped signals that must not leak into replayed + history. They live in tool-result ``content`` either as a JSON key + (``_budget_warning``) or appended plain text. + """ + for msg in messages: + if not isinstance(msg, dict) or msg.get("role") != "tool": + continue + content = msg.get("content") + if not isinstance(content, str) or "_budget_warning" not in content and "[BUDGET" not in content: + continue + + # Try JSON first (the common case: _budget_warning key in a dict) + try: + parsed = json.loads(content) + if isinstance(parsed, dict) and "_budget_warning" in parsed: + del parsed["_budget_warning"] + msg["content"] = json.dumps(parsed, ensure_ascii=False) + continue + except (json.JSONDecodeError, TypeError): + pass + + # Fallback: strip the text pattern from plain-text tool results + cleaned = _BUDGET_WARNING_RE.sub("", content).strip() + if cleaned != content: + msg["content"] = cleaned + + class AIAgent: """ AI Agent with tool calling capabilities. @@ -487,6 +571,7 @@ class AIAgent: # instead of going directly to stdout where patch_stdout's StdoutProxy # would mangle the escape sequences. None = use builtins.print. self._print_fn = None + self.background_review_callback = None # Optional sync callback for gateway delivery self.skip_context_files = skip_context_files self.pass_session_id = pass_session_id self.log_prefix_chars = log_prefix_chars @@ -534,6 +619,7 @@ class AIAgent: self.tool_progress_callback = tool_progress_callback self.thinking_callback = thinking_callback self.reasoning_callback = reasoning_callback + self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call self.clarify_callback = clarify_callback self.step_callback = step_callback self.stream_delta_callback = stream_delta_callback @@ -776,6 +862,25 @@ class AIAgent: } self._client_kwargs = client_kwargs # stored for rebuilding after interrupt + + # Enable fine-grained tool streaming for Claude on OpenRouter. + # Without this, Anthropic buffers the entire tool call and goes + # silent for minutes while thinking — OpenRouter's upstream proxy + # times out during the silence. The beta header makes Anthropic + # stream tool call arguments token-by-token, keeping the + # connection alive. + _effective_base = str(client_kwargs.get("base_url", "")).lower() + if "openrouter" in _effective_base and "claude" in (self.model or "").lower(): + headers = client_kwargs.get("default_headers") or {} + existing_beta = headers.get("x-anthropic-beta", "") + _FINE_GRAINED = "fine-grained-tool-streaming-2025-05-14" + if _FINE_GRAINED not in existing_beta: + if existing_beta: + headers["x-anthropic-beta"] = f"{existing_beta},{_FINE_GRAINED}" + else: + headers["x-anthropic-beta"] = _FINE_GRAINED + client_kwargs["default_headers"] = headers + self.api_key = client_kwargs.get("api_key", "") try: self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True) @@ -792,16 +897,30 @@ class AIAgent: except Exception as e: raise RuntimeError(f"Failed to initialize OpenAI client: {e}") - # Provider fallback — a single backup model/provider tried when the - # primary is exhausted (rate-limit, overload, connection failure). - # Config shape: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"} - self._fallback_model = fallback_model if isinstance(fallback_model, dict) else None + # Provider fallback chain — ordered list of backup providers tried + # when the primary is exhausted (rate-limit, overload, connection + # failure). Supports both legacy single-dict ``fallback_model`` and + # new list ``fallback_providers`` format. + if isinstance(fallback_model, list): + self._fallback_chain = [ + f for f in fallback_model + if isinstance(f, dict) and f.get("provider") and f.get("model") + ] + elif isinstance(fallback_model, dict) and fallback_model.get("provider") and fallback_model.get("model"): + self._fallback_chain = [fallback_model] + else: + self._fallback_chain = [] + self._fallback_index = 0 self._fallback_activated = False - if self._fallback_model: - fb_p = self._fallback_model.get("provider", "") - fb_m = self._fallback_model.get("model", "") - if fb_p and fb_m and not self.quiet_mode: - print(f"🔄 Fallback model: {fb_m} ({fb_p})") + # Legacy attribute kept for backward compat (tests, external callers) + self._fallback_model = self._fallback_chain[0] if self._fallback_chain else None + if self._fallback_chain and not self.quiet_mode: + if len(self._fallback_chain) == 1: + fb = self._fallback_chain[0] + print(f"🔄 Fallback model: {fb['model']} ({fb['provider']})") + else: + print(f"🔄 Fallback chain ({len(self._fallback_chain)} providers): " + + " → ".join(f"{f['model']} ({f['provider']})" for f in self._fallback_chain)) # Get available tools with filtering self.tools = get_tool_definitions( @@ -980,8 +1099,8 @@ class AIAgent: else: if not hcfg.enabled: logger.debug("Honcho disabled in global config") - elif not hcfg.api_key: - logger.debug("Honcho enabled but no API key configured") + elif not (hcfg.api_key or hcfg.base_url): + logger.debug("Honcho enabled but no API key or base URL configured") else: logger.debug("Honcho enabled but missing API key or disabled in config") except Exception as e: @@ -1018,6 +1137,13 @@ class AIAgent: except Exception: pass + # Tool-use enforcement config: "auto" (default — matches hardcoded + # model list), true (always), false (never), or list of substrings. + _agent_section = _agent_cfg.get("agent", {}) + if not isinstance(_agent_section, dict): + _agent_section = {} + self._tool_use_enforcement = _agent_section.get("tool_use_enforcement", "auto") + # Initialize context compressor for automatic context management # Compresses conversation when approaching model's context limit # Configuration via config.yaml (compression section) @@ -1160,7 +1286,7 @@ class AIAgent: try: fn = self._print_fn or print fn(*args, **kwargs) - except OSError: + except (OSError, ValueError): pass def _vprint(self, *args, force: bool = False, **kwargs): @@ -1526,6 +1652,12 @@ class AIAgent: if actions: summary = " · ".join(dict.fromkeys(actions)) self._safe_print(f" 💾 {summary}") + _bg_cb = self.background_review_callback + if _bg_cb: + try: + _bg_cb(f"💾 {summary}") + except Exception: + pass except Exception as e: logger.debug("Background memory/skill review failed: %s", e) @@ -2049,6 +2181,23 @@ class AIAgent: msg["content"] = self._clean_session_content(msg["content"]) cleaned.append(msg) + # Guard: never overwrite a larger session log with fewer messages. + # This protects against data loss when --resume loads a session whose + # messages weren't fully written to SQLite — the resumed agent starts + # with partial history and would otherwise clobber the full JSON log. + if self.session_log_file.exists(): + try: + existing = json.loads(self.session_log_file.read_text(encoding="utf-8")) + existing_count = existing.get("message_count", len(existing.get("messages", []))) + if existing_count > len(cleaned): + logging.debug( + "Skipping session log overwrite: existing has %d messages, current has %d", + existing_count, len(cleaned), + ) + return + except Exception: + pass # corrupted existing file — allow the overwrite + entry = { "session_id": self.session_id, "model": self.model, @@ -2158,8 +2307,14 @@ class AIAgent: # ── Honcho integration helpers ── def _honcho_should_activate(self, hcfg) -> bool: - """Return True when remote Honcho should be active.""" - if not hcfg or not hcfg.enabled or not hcfg.api_key: + """Return True when Honcho should be active. + + Self-hosted Honcho may be configured with a base_url and no API key, + so activation should accept either credential style. + """ + if not hcfg or not hcfg.enabled: + return False + if not (hcfg.api_key or hcfg.base_url): return False return True @@ -2428,6 +2583,29 @@ class AIAgent: nous_subscription_prompt = build_nous_subscription_prompt(self.valid_tool_names) if nous_subscription_prompt: prompt_parts.append(nous_subscription_prompt) + # Tool-use enforcement: tells the model to actually call tools instead + # of describing intended actions. Controlled by config.yaml + # agent.tool_use_enforcement: + # "auto" (default) — matches TOOL_USE_ENFORCEMENT_MODELS + # true — always inject (all models) + # false — never inject + # list — custom model-name substrings to match + if self.valid_tool_names: + _enforce = self._tool_use_enforcement + _inject = False + if _enforce is True or (isinstance(_enforce, str) and _enforce.lower() in ("true", "always", "yes", "on")): + _inject = True + elif _enforce is False or (isinstance(_enforce, str) and _enforce.lower() in ("false", "never", "no", "off")): + _inject = False + elif isinstance(_enforce, list): + model_lower = (self.model or "").lower() + _inject = any(p.lower() in model_lower for p in _enforce if isinstance(p, str)) + else: + # "auto" or any unrecognised value — use hardcoded defaults + model_lower = (self.model or "").lower() + _inject = any(p in model_lower for p in TOOL_USE_ENFORCEMENT_MODELS) + if _inject: + prompt_parts.append(TOOL_USE_ENFORCEMENT_GUIDANCE) # Honcho CLI awareness: tell Hermes about its own management commands # so it can refer the user to them rather than reinventing answers. @@ -2501,7 +2679,13 @@ class AIAgent: has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view', 'skill_manage']) if has_skills_tools: - avail_toolsets = {ts for ts, avail in check_toolset_requirements().items() if avail} + avail_toolsets = { + toolset + for toolset in ( + get_toolset_for_tool(tool_name) for tool_name in self.valid_tool_names + ) + if toolset + } skills_prompt = build_skills_system_prompt( available_tools=self.valid_tool_names, available_toolsets=avail_toolsets, @@ -3385,6 +3569,7 @@ class AIAgent: max_stream_retries = 1 has_tool_calls = False first_delta_fired = False + self._reasoning_deltas_fired = False for attempt in range(max_stream_retries + 1): try: with active_client.responses.stream(**api_kwargs) as stream: @@ -3661,6 +3846,7 @@ class AIAgent: def _fire_reasoning_delta(self, text: str) -> None: """Fire reasoning callback if registered.""" + self._reasoning_deltas_fired = True cb = self.reasoning_callback if cb is not None: try: @@ -3739,7 +3925,7 @@ class AIAgent: def _call_chat_completions(): """Stream a chat completions response.""" import httpx as _httpx - _base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 900.0)) + _base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0)) _stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 60.0)) stream_kwargs = { **api_kwargs, @@ -3755,16 +3941,28 @@ class AIAgent: request_client_holder["client"] = self._create_request_openai_client( reason="chat_completion_stream_request" ) + # Reset stale-stream timer so the detector measures from this + # attempt's start, not a previous attempt's last chunk. + last_chunk_time["t"] = time.time() stream = request_client_holder["client"].chat.completions.create(**stream_kwargs) content_parts: list = [] tool_calls_acc: dict = {} tool_gen_notified: set = set() + # Ollama-compatible endpoints reuse index 0 for every tool call + # in a parallel batch, distinguishing them only by id. Track + # the last seen id per raw index so we can detect a new tool + # call starting at the same index and redirect it to a fresh slot. + _last_id_at_idx: dict = {} # raw_index -> last seen non-empty id + _active_slot_by_idx: dict = {} # raw_index -> current slot in tool_calls_acc finish_reason = None model_name = None role = "assistant" reasoning_parts: list = [] usage_obj = None + # Reset per-call reasoning tracking so _build_assistant_message + # knows whether reasoning was already displayed during streaming. + self._reasoning_deltas_fired = False for chunk in stream: last_chunk_time["t"] = time.time() @@ -3798,11 +3996,45 @@ class AIAgent: _fire_first_delta() self._fire_stream_delta(delta.content) deltas_were_sent["yes"] = True + else: + # Tool calls suppress regular content streaming (avoids + # displaying chatty "I'll use the tool..." text alongside + # tool calls). But reasoning tags embedded in suppressed + # content should still reach the display — otherwise the + # reasoning box only appears as a post-response fallback, + # rendering it confusingly after the already-streamed + # response. Route suppressed content through the stream + # delta callback so its tag extraction can fire the + # reasoning display. Non-reasoning text is harmlessly + # suppressed by the CLI's _stream_delta when the stream + # box is already closed (tool boundary flush). + if self.stream_delta_callback: + try: + self.stream_delta_callback(delta.content) + except Exception: + pass # Accumulate tool call deltas — notify display on first name if delta and delta.tool_calls: for tc_delta in delta.tool_calls: - idx = tc_delta.index if tc_delta.index is not None else 0 + raw_idx = tc_delta.index if tc_delta.index is not None else 0 + delta_id = tc_delta.id or "" + + # Ollama fix: detect a new tool call reusing the same + # raw index (different id) and redirect to a fresh slot. + if raw_idx not in _active_slot_by_idx: + _active_slot_by_idx[raw_idx] = raw_idx + if ( + delta_id + and raw_idx in _last_id_at_idx + and delta_id != _last_id_at_idx[raw_idx] + ): + new_slot = max(tool_calls_acc, default=-1) + 1 + _active_slot_by_idx[raw_idx] = new_slot + if delta_id: + _last_id_at_idx[raw_idx] = delta_id + idx = _active_slot_by_idx[raw_idx] + if idx not in tool_calls_acc: tool_calls_acc[idx] = { "id": tc_delta.id or "", @@ -3884,7 +4116,10 @@ class AIAgent: works unchanged. """ has_tool_use = False + self._reasoning_deltas_fired = False + # Reset stale-stream timer for this attempt + last_chunk_time["t"] = time.time() # Use the Anthropic SDK's streaming context manager with self._anthropic_client.messages.stream(**api_kwargs) as stream: for event in stream: @@ -3952,7 +4187,37 @@ class AIAgent: e, (_httpx.ConnectError, _httpx.RemoteProtocolError, ConnectionError) ) - if _is_timeout or _is_conn_err: + # SSE error events from proxies (e.g. OpenRouter sends + # {"error":{"message":"Network connection lost."}}) are + # raised as APIError by the OpenAI SDK. These are + # semantically identical to httpx connection drops — + # the upstream stream died — and should be retried with + # a fresh connection. Distinguish from HTTP errors: + # APIError from SSE has no status_code, while + # APIStatusError (4xx/5xx) always has one. + _is_sse_conn_err = False + if not _is_timeout and not _is_conn_err: + from openai import APIError as _APIError + if isinstance(e, _APIError) and not getattr(e, "status_code", None): + _err_lower_sse = str(e).lower() + _SSE_CONN_PHRASES = ( + "connection lost", + "connection reset", + "connection closed", + "connection terminated", + "network error", + "network connection", + "terminated", + "peer closed", + "broken pipe", + "upstream connect error", + ) + _is_sse_conn_err = any( + phrase in _err_lower_sse + for phrase in _SSE_CONN_PHRASES + ) + + if _is_timeout or _is_conn_err or _is_sse_conn_err: # Transient network / timeout error. Retry the # streaming request with a fresh connection first. if _stream_attempt < _max_stream_retries: @@ -3997,6 +4262,10 @@ class AIAgent: ) try: + # Reset stale timer — the non-streaming fallback + # uses its own client; prevent the stale detector + # from firing on stale timestamps from failed streams. + last_chunk_time["t"] = time.time() result["response"] = self._interruptible_api_call(api_kwargs) except Exception as fallback_err: result["error"] = fallback_err @@ -4006,7 +4275,19 @@ class AIAgent: if request_client is not None: self._close_request_openai_client(request_client, reason="stream_request_complete") - _stream_stale_timeout = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 90.0)) + _stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0)) + # Scale the stale timeout for large contexts: slow models (like Opus) + # can legitimately think for minutes before producing the first token + # when the context is large. Without this, the stale detector kills + # healthy connections during the model's thinking phase, producing + # spurious RemoteProtocolError ("peer closed connection"). + _est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4 + if _est_tokens > 100_000: + _stream_stale_timeout = max(_stream_stale_timeout_base, 300.0) + elif _est_tokens > 50_000: + _stream_stale_timeout = max(_stream_stale_timeout_base, 240.0) + else: + _stream_stale_timeout = _stream_stale_timeout_base t = threading.Thread(target=_call, daemon=True) t.start() @@ -4055,25 +4336,26 @@ class AIAgent: # ── Provider fallback ────────────────────────────────────────────────── def _try_activate_fallback(self) -> bool: - """Switch to the configured fallback model/provider. + """Switch to the next fallback model/provider in the chain. - Called when the primary model is failing after retries. Swaps the + Called when the current model is failing after retries. Swaps the OpenAI client, model slug, and provider in-place so the retry loop - can continue with the new backend. One-shot: returns False if - already activated or not configured. + can continue with the new backend. Advances through the chain on + each call; returns False when exhausted. Uses the centralized provider router (resolve_provider_client) for auth resolution and client construction — no duplicated provider→key mappings. """ - if self._fallback_activated or not self._fallback_model: + if self._fallback_index >= len(self._fallback_chain): return False - fb = self._fallback_model + fb = self._fallback_chain[self._fallback_index] + self._fallback_index += 1 fb_provider = (fb.get("provider") or "").strip().lower() fb_model = (fb.get("model") or "").strip() if not fb_provider or not fb_model: - return False + return self._try_activate_fallback() # skip invalid, try next # Use centralized router for client construction. # raw_codex=True because the main agent needs direct responses.stream() @@ -4086,7 +4368,7 @@ class AIAgent: logging.warning( "Fallback to %s failed: provider not configured", fb_provider) - return False + return self._try_activate_fallback() # try next in chain # Determine api_mode from provider / base URL fb_api_mode = "chat_completions" @@ -4132,6 +4414,25 @@ class AIAgent: or is_native_anthropic ) + # Update context compressor limits for the fallback model. + # Without this, compression decisions use the primary model's + # context window (e.g. 200K) instead of the fallback's (e.g. 32K), + # causing oversized sessions to overflow the fallback. + if hasattr(self, 'context_compressor') and self.context_compressor: + from agent.model_metadata import get_model_context_length + fb_context_length = get_model_context_length( + self.model, base_url=self.base_url, + api_key=self.api_key, provider=self.provider, + ) + self.context_compressor.model = self.model + self.context_compressor.base_url = self.base_url + self.context_compressor.api_key = self.api_key + self.context_compressor.provider = self.provider + self.context_compressor.context_length = fb_context_length + self.context_compressor.threshold_tokens = int( + fb_context_length * self.context_compressor.threshold_percent + ) + self._emit_status( f"🔄 Primary model failed — switching to fallback: " f"{fb_model} via {fb_provider}" @@ -4142,8 +4443,8 @@ class AIAgent: ) return True except Exception as e: - logging.error("Failed to activate fallback model: %s", e) - return False + logging.error("Failed to activate fallback %s: %s", fb_model, e) + return self._try_activate_fallback() # try next in chain # ── End provider fallback ────────────────────────────────────────────── @@ -4301,6 +4602,10 @@ class AIAgent: if self.api_mode == "anthropic_messages": from agent.anthropic_adapter import build_anthropic_kwargs anthropic_messages = self._prepare_anthropic_messages_for_api(api_messages) + # Pass context_length so the adapter can clamp max_tokens if the + # user configured a smaller context window than the model's output limit. + ctx_len = getattr(self, "context_compressor", None) + ctx_len = ctx_len.context_length if ctx_len else None return build_anthropic_kwargs( model=self.model, messages=anthropic_messages, @@ -4309,6 +4614,7 @@ class AIAgent: reasoning_config=self.reasoning_config, is_oauth=self._is_anthropic_oauth, preserve_dots=self._anthropic_preserve_dots(), + context_length=ctx_len, ) if self.api_mode == "codex_responses": @@ -4419,12 +4725,27 @@ class AIAgent: api_kwargs = { "model": self.model, "messages": sanitized_messages, - "tools": self.tools if self.tools else None, - "timeout": float(os.getenv("HERMES_API_TIMEOUT", 900.0)), + "timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)), } + if self.tools: + api_kwargs["tools"] = self.tools if self.max_tokens is not None: api_kwargs.update(self._max_tokens_param(self.max_tokens)) + elif self._is_openrouter_url() and "claude" in (self.model or "").lower(): + # OpenRouter translates requests to Anthropic's Messages API, + # which requires max_tokens as a mandatory field. When we omit + # it, OpenRouter picks a default that can be too low — the model + # spends its output budget on thinking and has almost nothing + # left for the actual response (especially large tool calls like + # write_file). Sending the model's real output limit ensures + # full capacity. Other providers handle the default fine. + try: + from agent.anthropic_adapter import _get_anthropic_max_output + _model_output_limit = _get_anthropic_max_output(self.model) + api_kwargs["max_tokens"] = _model_output_limit + except Exception: + pass # fail open — let OpenRouter pick its default extra_body = {} @@ -4560,11 +4881,15 @@ class AIAgent: logging.debug(f"Captured reasoning ({len(reasoning_text)} chars): {reasoning_text}") if reasoning_text and self.reasoning_callback: - # Skip callback for -extracted reasoning when streaming is active. - # _stream_delta() already displayed blocks during streaming; - # firing the callback again would cause duplicate display. - # Structured reasoning (from reasoning_content field) always fires. - if _from_structured or not self.stream_delta_callback: + # Skip callback when streaming is active — reasoning was already + # displayed during the stream via one of two paths: + # (a) _fire_reasoning_delta (structured reasoning_content deltas) + # (b) _stream_delta tag extraction (/) + # When streaming is NOT active, always fire so non-streaming modes + # (gateway, batch, quiet) still get reasoning. + # Any reasoning that wasn't shown during streaming is caught by the + # CLI post-response display fallback (cli.py _reasoning_shown_this_turn). + if not self.stream_delta_callback: try: self.reasoning_callback(reasoning_text) except Exception: @@ -4862,6 +5187,8 @@ class AIAgent: self._session_db.end_session(self.session_id, "compression") old_session_id = self.session_id self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" + # Update session_log_file to point to the new session's JSON file + self.session_log_file = self.logs_dir / f"session_{self.session_id}.json" self._session_db.create_session( session_id=self.session_id, source=self.platform or os.environ.get("HERMES_SESSION_SOURCE", "cli"), @@ -5085,7 +5412,7 @@ class AIAgent: spinner = None if self.quiet_mode and not self.tool_progress_callback: face = random.choice(KawaiiSpinner.KAWAII_WAITING) - spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots') + spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots', print_fn=self._print_fn) spinner.start() try: @@ -5126,7 +5453,7 @@ class AIAgent: # Print cute message per tool if self.quiet_mode: cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result) - print(f" {cute_msg}") + self._safe_print(f" {cute_msg}") elif not self.quiet_mode: if self.verbose_logging: print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s") @@ -5311,7 +5638,7 @@ class AIAgent: spinner = None if self.quiet_mode and not self.tool_progress_callback: face = random.choice(KawaiiSpinner.KAWAII_WAITING) - spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots') + spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots', print_fn=self._print_fn) spinner.start() self._delegate_spinner = spinner _delegate_result = None @@ -5339,9 +5666,7 @@ class AIAgent: face = random.choice(KawaiiSpinner.KAWAII_WAITING) emoji = _get_tool_emoji(function_name) preview = _build_tool_preview(function_name, function_args) or function_name - if len(preview) > 30: - preview = preview[:27] + "..." - spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots') + spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots', print_fn=self._print_fn) spinner.start() _spinner_result = None try: @@ -5702,6 +6027,14 @@ class AIAgent: # Installed once, transparent when streams are healthy, prevents crash on write. _install_safe_stdio() + # Sanitize surrogate characters from user input. Clipboard paste from + # rich-text editors (Google Docs, Word, etc.) can inject lone surrogates + # that are invalid UTF-8 and crash JSON serialization in the OpenAI SDK. + if isinstance(user_message, str): + user_message = _sanitize_surrogates(user_message) + if isinstance(persist_user_message, str): + persist_user_message = _sanitize_surrogates(persist_user_message) + # Store stream callback for _interruptible_api_call to pick up self._stream_callback = stream_callback self._persist_user_message_idx = None @@ -5718,6 +6051,7 @@ class AIAgent: self._codex_incomplete_retries = 0 self._last_content_with_tools = None self._mute_post_response = False + self._surrogate_sanitized = False # NOTE: _turns_since_memory and _iters_since_skill are NOT reset here. # They are initialized in __init__ and must persist across run_conversation # calls so that nudge logic accumulates correctly in CLI mode. @@ -5725,6 +6059,14 @@ class AIAgent: # Initialize conversation (copy to avoid mutating the caller's list) messages = list(conversation_history) if conversation_history else [] + + # Strip budget pressure warnings from previous turns. These are + # turn-scoped signals injected by _get_budget_warning() into tool + # result content. If left in the replayed history, models (especially + # GPT-family) interpret them as still-active instructions and avoid + # making tool calls in ALL subsequent turns. + if messages: + _strip_budget_warnings_from_history(messages) # Hydrate todo store from conversation history (gateway creates a fresh # AIAgent per message, so the in-memory store is empty -- we need to @@ -5820,6 +6162,22 @@ class AIAgent: self._cached_system_prompt = ( self._cached_system_prompt + "\n\n" + self._honcho_context ).strip() + + # Plugin hook: on_session_start + # Fired once when a brand-new session is created (not on + # continuation). Plugins can use this to initialise + # session-scoped state (e.g. warm a memory cache). + try: + from hermes_cli.plugins import invoke_hook as _invoke_hook + _invoke_hook( + "on_session_start", + session_id=self.session_id, + model=self.model, + platform=getattr(self, "platform", None) or "", + ) + except Exception as exc: + logger.warning("on_session_start hook failed: %s", exc) + # Store the system prompt snapshot in SQLite if self._session_db: try: @@ -5881,6 +6239,34 @@ class AIAgent: if _preflight_tokens < self.context_compressor.threshold_tokens: break # Under threshold + # Plugin hook: pre_llm_call + # Fired once per turn before the tool-calling loop. Plugins can + # return a dict with a ``context`` key whose value is a string + # that will be appended to the ephemeral system prompt for every + # API call in this turn (not persisted to session DB or cache). + _plugin_turn_context = "" + try: + from hermes_cli.plugins import invoke_hook as _invoke_hook + _pre_results = _invoke_hook( + "pre_llm_call", + session_id=self.session_id, + user_message=original_user_message, + conversation_history=list(messages), + is_first_turn=(not bool(conversation_history)), + model=self.model, + platform=getattr(self, "platform", None) or "", + ) + _ctx_parts = [] + for r in _pre_results: + if isinstance(r, dict) and r.get("context"): + _ctx_parts.append(str(r["context"])) + elif isinstance(r, str) and r.strip(): + _ctx_parts.append(r) + if _ctx_parts: + _plugin_turn_context = "\n\n".join(_ctx_parts) + except Exception as exc: + logger.warning("pre_llm_call hook failed: %s", exc) + # Main conversation loop api_call_count = 0 final_response = None @@ -5978,6 +6364,9 @@ class AIAgent: effective_system = active_system_prompt or "" if self.ephemeral_system_prompt: effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip() + # Plugin context from pre_llm_call hooks — ephemeral, not cached. + if _plugin_turn_context: + effective_system = (effective_system + "\n\n" + _plugin_turn_context).strip() if effective_system: api_messages = [{"role": "system", "content": effective_system}] + api_messages @@ -6024,7 +6413,7 @@ class AIAgent: # Raw KawaiiSpinner only when no streaming consumers # (would conflict with streamed token output) spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star']) - thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type) + thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type, print_fn=self._print_fn) thinking_spinner.start() # Log request details if verbose @@ -6159,9 +6548,9 @@ class AIAgent: # Eager fallback: empty/malformed responses are a common # rate-limit symptom. Switch to fallback immediately # rather than retrying with extended backoff. - if not self._fallback_activated: + if self._fallback_index < len(self._fallback_chain): self._emit_status("⚠️ Empty/malformed response — switching to fallback...") - if not self._fallback_activated and self._try_activate_fallback(): + if self._try_activate_fallback(): retry_count = 0 continue @@ -6254,6 +6643,62 @@ class AIAgent: if finish_reason == "length": self._vprint(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens", force=True) + # ── Detect thinking-budget exhaustion ────────────── + # When the model spends ALL output tokens on reasoning + # and has none left for the response, continuation + # retries are pointless. Detect this early and give a + # targeted error instead of wasting 3 API calls. + _trunc_content = None + if self.api_mode == "chat_completions": + _trunc_msg = response.choices[0].message if (hasattr(response, "choices") and response.choices) else None + _trunc_content = getattr(_trunc_msg, "content", None) if _trunc_msg else None + elif self.api_mode == "anthropic_messages": + # Anthropic response.content is a list of blocks + _text_parts = [] + for _blk in getattr(response, "content", []): + if getattr(_blk, "type", None) == "text": + _text_parts.append(getattr(_blk, "text", "")) + _trunc_content = "\n".join(_text_parts) if _text_parts else None + + _thinking_exhausted = ( + _trunc_content is not None + and not self._has_content_after_think_block(_trunc_content) + ) or _trunc_content is None + + if _thinking_exhausted: + _exhaust_error = ( + "Model used all output tokens on reasoning with none left " + "for the response. Try lowering reasoning effort or " + "increasing max_tokens." + ) + self._vprint( + f"{self.log_prefix}💭 Reasoning exhausted the output token budget — " + f"no visible response was produced.", + force=True, + ) + # Return a user-friendly message as the response so + # CLI (response box) and gateway (chat message) both + # display it naturally instead of a suppressed error. + _exhaust_response = ( + "⚠️ **Thinking Budget Exhausted**\n\n" + "The model used all its output tokens on reasoning " + "and had none left for the actual response.\n\n" + "To fix this:\n" + "→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n" + "→ Increase the output token limit: " + "set `model.max_tokens` in config.yaml" + ) + self._cleanup_task_resources(effective_task_id) + self._persist_session(messages, conversation_history) + return { + "final_response": _exhaust_response, + "messages": messages, + "api_calls": api_call_count, + "completed": False, + "partial": True, + "error": _exhaust_error, + } + if self.api_mode == "chat_completions": assistant_message = response.choices[0].message if not assistant_message.tool_calls: @@ -6442,6 +6887,24 @@ class AIAgent: if self.thinking_callback: self.thinking_callback("") + # ----------------------------------------------------------- + # Surrogate character recovery. UnicodeEncodeError happens + # when the messages contain lone surrogates (U+D800..U+DFFF) + # that are invalid UTF-8. Common source: clipboard paste + # from Google Docs or similar rich-text editors. We sanitize + # the entire messages list in-place and retry once. + # ----------------------------------------------------------- + if isinstance(api_error, UnicodeEncodeError) and not getattr(self, '_surrogate_sanitized', False): + self._surrogate_sanitized = True + if _sanitize_messages_surrogates(messages): + self._vprint( + f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...", + force=True, + ) + continue + # Surrogates weren't in messages — might be in system + # prompt or prefill. Fall through to normal error path. + status_code = getattr(api_error, "status_code", None) if ( self.api_mode == "codex_responses" @@ -6481,8 +6944,10 @@ class AIAgent: print(f"{self.log_prefix} Auth method: {auth_method}") print(f"{self.log_prefix} Token prefix: {key[:12]}..." if key and len(key) > 12 else f"{self.log_prefix} Token: (empty or short)") print(f"{self.log_prefix} Troubleshooting:") - print(f"{self.log_prefix} • Check ANTHROPIC_TOKEN in ~/.hermes/.env for Hermes-managed OAuth/setup tokens") - print(f"{self.log_prefix} • Check ANTHROPIC_API_KEY in ~/.hermes/.env for API keys or legacy token values") + from hermes_constants import display_hermes_home as _dhh_fn + _dhh = _dhh_fn() + print(f"{self.log_prefix} • Check ANTHROPIC_TOKEN in {_dhh}/.env for Hermes-managed OAuth/setup tokens") + print(f"{self.log_prefix} • Check ANTHROPIC_API_KEY in {_dhh}/.env for API keys or legacy token values") print(f"{self.log_prefix} • For API keys: verify at https://console.anthropic.com/settings/keys") print(f"{self.log_prefix} • For Claude Code: run 'claude /login' to refresh, then retry") print(f"{self.log_prefix} • Clear stale keys: hermes config set ANTHROPIC_TOKEN \"\"") @@ -6548,7 +7013,7 @@ class AIAgent: or "usage limit" in error_msg or "quota" in error_msg ) - if is_rate_limited and not self._fallback_activated: + if is_rate_limited and self._fallback_index < len(self._fallback_chain): self._emit_status("⚠️ Rate limited — switching to fallback provider...") if self._try_activate_fallback(): retry_count = 0 @@ -6710,8 +7175,13 @@ class AIAgent: # 529 (Anthropic overloaded) is also transient. # Also catch local validation errors (ValueError, TypeError) — these # are programming bugs, not transient failures. + # Exclude UnicodeEncodeError — it's a ValueError subclass but is + # handled separately by the surrogate sanitization path above. _RETRYABLE_STATUS_CODES = {413, 429, 529} - is_local_validation_error = isinstance(api_error, (ValueError, TypeError)) + is_local_validation_error = ( + isinstance(api_error, (ValueError, TypeError)) + and not isinstance(api_error, UnicodeEncodeError) + ) # Detect generic 400s from Anthropic OAuth (transient server-side failures). # Real invalid_request_error responses include a descriptive message; # transient ones contain only "Error" or are empty. (ref: issue #1608) @@ -6779,8 +7249,41 @@ class AIAgent: retry_count = 0 continue _final_summary = self._summarize_api_error(api_error) - self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True) + if is_rate_limited: + self._vprint(f"{self.log_prefix}❌ Rate limit persisted after {max_retries} retries. Please try again later.", force=True) + else: + self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True) self._vprint(f"{self.log_prefix} 💀 Final error: {_final_summary}", force=True) + + # Detect SSE stream-drop pattern (e.g. "Network + # connection lost") and surface actionable guidance. + # This typically happens when the model generates a + # very large tool call (write_file with huge content) + # and the proxy/CDN drops the stream mid-response. + _is_stream_drop = ( + not getattr(api_error, "status_code", None) + and any(p in error_msg for p in ( + "connection lost", "connection reset", + "connection closed", "network connection", + "network error", "terminated", + )) + ) + if _is_stream_drop: + self._vprint( + f"{self.log_prefix} 💡 The provider's stream " + f"connection keeps dropping. This often happens " + f"when the model tries to write a very large " + f"file in a single tool call.", + force=True, + ) + self._vprint( + f"{self.log_prefix} Try asking the model " + f"to use execute_code with Python's open() for " + f"large files, or to write the file in smaller " + f"sections.", + force=True, + ) + logging.error( "%sAPI call failed after %s retries. %s | provider=%s model=%s msgs=%s tokens=~%s", self.log_prefix, max_retries, _final_summary, @@ -6790,8 +7293,18 @@ class AIAgent: api_kwargs, reason="max_retries_exhausted", error=api_error, ) self._persist_session(messages, conversation_history) + _final_response = f"API call failed after {max_retries} retries: {_final_summary}" + if _is_stream_drop: + _final_response += ( + "\n\nThe provider's stream connection keeps " + "dropping — this often happens when generating " + "very large tool call responses (e.g. write_file " + "with long content). Try asking me to use " + "execute_code with Python's open() for large " + "files, or to write in smaller sections." + ) return { - "final_response": f"API call failed after {max_retries} retries: {_final_summary}", + "final_response": _final_response, "messages": messages, "api_calls": api_call_count, "completed": False, @@ -6799,8 +7312,22 @@ class AIAgent: "error": _final_summary, } - wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s - self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...") + # For rate limits, respect the Retry-After header if present + _retry_after = None + if is_rate_limited: + _resp_headers = getattr(getattr(api_error, "response", None), "headers", None) + if _resp_headers and hasattr(_resp_headers, "get"): + _ra_raw = _resp_headers.get("retry-after") or _resp_headers.get("Retry-After") + if _ra_raw: + try: + _retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes + except (TypeError, ValueError): + pass + wait_time = _retry_after if _retry_after else min(2 ** retry_count, 60) + if is_rate_limited: + self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...") + else: + self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...") logger.warning( "Retrying API call in %ss (attempt %s/%s) %s error=%s", wait_time, @@ -7169,7 +7696,6 @@ class AIAgent: except Exception: pass - _msg_count_before_tools = len(messages) self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count) # Signal that a paragraph break is needed before the next @@ -7187,18 +7713,18 @@ class AIAgent: if _tc_names == {"execute_code"}: self.iteration_budget.refund() - # Estimate next prompt size using real token counts from the - # last API response + rough estimate of newly appended tool - # results. This catches cases where tool results push the - # context past the limit that last_prompt_tokens alone misses - # (e.g. large file reads, web extractions). + # Use real token counts from the API response to decide + # compression. prompt_tokens + completion_tokens is the + # actual context size the provider reported plus the + # assistant turn — a tight lower bound for the next prompt. + # Tool results appended above aren't counted yet, but the + # threshold (default 50%) leaves ample headroom; if tool + # results push past it, the next API call will report the + # real total and trigger compression then. _compressor = self.context_compressor - _new_tool_msgs = messages[_msg_count_before_tools:] - _new_chars = sum(len(str(m.get("content", "") or "")) for m in _new_tool_msgs) - _estimated_next_prompt = ( + _real_tokens = ( _compressor.last_prompt_tokens + _compressor.last_completion_tokens - + _new_chars // 3 # conservative: JSON-heavy tool results ≈ 3 chars/token ) # ── Context pressure warnings (user-facing only) ────────── @@ -7208,12 +7734,12 @@ class AIAgent: # Does not inject into messages — just prints to CLI output # and fires status_callback for gateway platforms. if _compressor.threshold_tokens > 0: - _compaction_progress = _estimated_next_prompt / _compressor.threshold_tokens + _compaction_progress = _real_tokens / _compressor.threshold_tokens if _compaction_progress >= 0.85 and not self._context_pressure_warned: self._context_pressure_warned = True self._emit_context_pressure(_compaction_progress, _compressor) - if self.compression_enabled and _compressor.should_compress(_estimated_next_prompt): + if self.compression_enabled and _compressor.should_compress(_real_tokens): messages, active_system_prompt = self._compress_context( messages, system_message, approx_tokens=self.context_compressor.last_prompt_tokens, @@ -7387,7 +7913,7 @@ class AIAgent: error_msg = f"Error during OpenAI-compatible API call #{api_call_count}: {str(e)}" try: print(f"❌ {error_msg}") - except OSError: + except (OSError, ValueError): logger.error(error_msg) if self.verbose_logging: @@ -7460,6 +7986,25 @@ class AIAgent: self._honcho_sync(original_user_message, final_response) self._queue_honcho_prefetch(original_user_message) + # Plugin hook: post_llm_call + # Fired once per turn after the tool-calling loop completes. + # Plugins can use this to persist conversation data (e.g. sync + # to an external memory system). + if final_response and not interrupted: + try: + from hermes_cli.plugins import invoke_hook as _invoke_hook + _invoke_hook( + "post_llm_call", + session_id=self.session_id, + user_message=original_user_message, + assistant_response=final_response, + conversation_history=list(messages), + model=self.model, + platform=getattr(self, "platform", None) or "", + ) + except Exception as exc: + logger.warning("post_llm_call hook failed: %s", exc) + # Extract reasoning from the last assistant message (if any) last_reasoning = None for msg in reversed(messages): @@ -7525,6 +8070,22 @@ class AIAgent: except Exception: pass # Background review is best-effort + # Plugin hook: on_session_end + # Fired at the very end of every run_conversation call. + # Plugins can use this for cleanup, flushing buffers, etc. + try: + from hermes_cli.plugins import invoke_hook as _invoke_hook + _invoke_hook( + "on_session_end", + session_id=self.session_id, + completed=completed, + interrupted=interrupted, + model=self.model, + platform=getattr(self, "platform", None) or "", + ) + except Exception as exc: + logger.warning("on_session_end hook failed: %s", exc) + return result def chat(self, message: str, stream_callback: Optional[callable] = None) -> str: diff --git a/scripts/kill_modal.sh b/scripts/kill_modal.sh index aae3f63e2..1e9a33128 100755 --- a/scripts/kill_modal.sh +++ b/scripts/kill_modal.sh @@ -2,7 +2,7 @@ # Kill all running Modal apps (sandboxes, deployments, etc.) # # Usage: -# bash scripts/kill_modal.sh # Stop swe-rex (the sandbox app) +# bash scripts/kill_modal.sh # Stop hermes-agent sandboxes # bash scripts/kill_modal.sh --all # Stop ALL Modal apps set -uo pipefail @@ -17,10 +17,10 @@ if [[ "${1:-}" == "--all" ]]; then modal app stop "$app_id" 2>/dev/null || true done else - echo "Stopping swe-rex sandboxes..." - APPS=$(echo "$APP_LIST" | grep 'swe-rex' | grep -oE 'ap-[A-Za-z0-9]+' || true) + echo "Stopping hermes-agent sandboxes..." + APPS=$(echo "$APP_LIST" | grep 'hermes-agent' | grep -oE 'ap-[A-Za-z0-9]+' || true) if [[ -z "$APPS" ]]; then - echo " No swe-rex apps found." + echo " No hermes-agent apps found." else echo "$APPS" | while read app_id; do echo " Stopping $app_id" @@ -30,5 +30,5 @@ else fi echo "" -echo "Current swe-rex status:" -modal app list 2>/dev/null | grep -E 'State|swe-rex' || echo " (none)" +echo "Current hermes-agent status:" +modal app list 2>/dev/null | grep -E 'State|hermes-agent' || echo " (none)" diff --git a/scripts/whatsapp-bridge/allowlist.js b/scripts/whatsapp-bridge/allowlist.js new file mode 100644 index 000000000..760e413f2 --- /dev/null +++ b/scripts/whatsapp-bridge/allowlist.js @@ -0,0 +1,79 @@ +import path from 'path'; +import { existsSync, readFileSync } from 'fs'; + +export function normalizeWhatsAppIdentifier(value) { + return String(value || '') + .trim() + .replace(/:.*@/, '@') + .replace(/@.*/, '') + .replace(/^\+/, ''); +} + +export function parseAllowedUsers(rawValue) { + return new Set( + String(rawValue || '') + .split(',') + .map((value) => normalizeWhatsAppIdentifier(value)) + .filter(Boolean) + ); +} + +function readMappingFile(sessionDir, identifier, suffix = '') { + const filePath = path.join(sessionDir, `lid-mapping-${identifier}${suffix}.json`); + if (!existsSync(filePath)) { + return null; + } + + try { + const parsed = JSON.parse(readFileSync(filePath, 'utf8')); + const normalized = normalizeWhatsAppIdentifier(parsed); + return normalized || null; + } catch { + return null; + } +} + +export function expandWhatsAppIdentifiers(identifier, sessionDir) { + const normalized = normalizeWhatsAppIdentifier(identifier); + if (!normalized) { + return new Set(); + } + + // Walk both phone->LID and LID->phone mapping files so allowlists can use + // either form transparently in bot mode. + const resolved = new Set(); + const queue = [normalized]; + + while (queue.length > 0) { + const current = queue.shift(); + if (!current || resolved.has(current)) { + continue; + } + + resolved.add(current); + + for (const suffix of ['', '_reverse']) { + const mapped = readMappingFile(sessionDir, current, suffix); + if (mapped && !resolved.has(mapped)) { + queue.push(mapped); + } + } + } + + return resolved; +} + +export function matchesAllowedUser(senderId, allowedUsers, sessionDir) { + if (!allowedUsers || allowedUsers.size === 0) { + return true; + } + + const aliases = expandWhatsAppIdentifiers(senderId, sessionDir); + for (const alias of aliases) { + if (allowedUsers.has(alias)) { + return true; + } + } + + return false; +} diff --git a/scripts/whatsapp-bridge/allowlist.test.mjs b/scripts/whatsapp-bridge/allowlist.test.mjs new file mode 100644 index 000000000..7eea7399c --- /dev/null +++ b/scripts/whatsapp-bridge/allowlist.test.mjs @@ -0,0 +1,47 @@ +import test from 'node:test'; +import assert from 'node:assert/strict'; +import os from 'node:os'; +import path from 'node:path'; +import { mkdtempSync, rmSync, writeFileSync } from 'node:fs'; + +import { + expandWhatsAppIdentifiers, + matchesAllowedUser, + normalizeWhatsAppIdentifier, + parseAllowedUsers, +} from './allowlist.js'; + +test('normalizeWhatsAppIdentifier strips jid syntax and plus prefix', () => { + assert.equal(normalizeWhatsAppIdentifier('+19175395595@s.whatsapp.net'), '19175395595'); + assert.equal(normalizeWhatsAppIdentifier('267383306489914@lid'), '267383306489914'); + assert.equal(normalizeWhatsAppIdentifier('19175395595:12@s.whatsapp.net'), '19175395595'); +}); + +test('expandWhatsAppIdentifiers resolves phone and lid aliases from session files', () => { + const sessionDir = mkdtempSync(path.join(os.tmpdir(), 'hermes-wa-allowlist-')); + + try { + writeFileSync(path.join(sessionDir, 'lid-mapping-19175395595.json'), JSON.stringify('267383306489914')); + writeFileSync(path.join(sessionDir, 'lid-mapping-267383306489914_reverse.json'), JSON.stringify('19175395595')); + + const aliases = expandWhatsAppIdentifiers('267383306489914@lid', sessionDir); + assert.deepEqual([...aliases].sort(), ['19175395595', '267383306489914']); + } finally { + rmSync(sessionDir, { recursive: true, force: true }); + } +}); + +test('matchesAllowedUser accepts mapped lid sender when allowlist only contains phone number', () => { + const sessionDir = mkdtempSync(path.join(os.tmpdir(), 'hermes-wa-allowlist-')); + + try { + writeFileSync(path.join(sessionDir, 'lid-mapping-19175395595.json'), JSON.stringify('267383306489914')); + writeFileSync(path.join(sessionDir, 'lid-mapping-267383306489914_reverse.json'), JSON.stringify('19175395595')); + + const allowedUsers = parseAllowedUsers('+19175395595'); + assert.equal(matchesAllowedUser('267383306489914@lid', allowedUsers, sessionDir), true); + assert.equal(matchesAllowedUser('188012763865257@lid', allowedUsers, sessionDir), false); + } finally { + rmSync(sessionDir, { recursive: true, force: true }); + } +}); diff --git a/scripts/whatsapp-bridge/bridge.js b/scripts/whatsapp-bridge/bridge.js index 0dff8c2e2..46cc5c339 100644 --- a/scripts/whatsapp-bridge/bridge.js +++ b/scripts/whatsapp-bridge/bridge.js @@ -26,6 +26,7 @@ import path from 'path'; import { mkdirSync, readFileSync, writeFileSync, existsSync, readdirSync } from 'fs'; import { randomBytes } from 'crypto'; import qrcode from 'qrcode-terminal'; +import { matchesAllowedUser, parseAllowedUsers } from './allowlist.js'; // Parse CLI args const args = process.argv.slice(2); @@ -47,7 +48,7 @@ const DOCUMENT_CACHE_DIR = path.join(process.env.HOME || '~', '.hermes', 'docume const AUDIO_CACHE_DIR = path.join(process.env.HOME || '~', '.hermes', 'audio_cache'); const PAIR_ONLY = args.includes('--pair-only'); const WHATSAPP_MODE = getArg('mode', process.env.WHATSAPP_MODE || 'self-chat'); // "bot" or "self-chat" -const ALLOWED_USERS = (process.env.WHATSAPP_ALLOWED_USERS || '').split(',').map(s => s.trim()).filter(Boolean); +const ALLOWED_USERS = parseAllowedUsers(process.env.WHATSAPP_ALLOWED_USERS || ''); const DEFAULT_REPLY_PREFIX = '⚕ *Hermes Agent*\n────────────\n'; const REPLY_PREFIX = process.env.WHATSAPP_REPLY_PREFIX === undefined ? DEFAULT_REPLY_PREFIX @@ -190,10 +191,9 @@ async function startSocket() { if (!isSelfChat) continue; } - // Check allowlist for messages from others (resolve LID → phone if needed) - if (!msg.key.fromMe && ALLOWED_USERS.length > 0) { - const resolvedNumber = lidToPhone[senderNumber] || senderNumber; - if (!ALLOWED_USERS.includes(resolvedNumber)) continue; + // Check allowlist for messages from others (resolve LID ↔ phone aliases) + if (!msg.key.fromMe && !matchesAllowedUser(senderId, ALLOWED_USERS, SESSION_DIR)) { + continue; } // Extract message body @@ -515,8 +515,8 @@ if (PAIR_ONLY) { app.listen(PORT, '127.0.0.1', () => { console.log(`🌉 WhatsApp bridge listening on port ${PORT} (mode: ${WHATSAPP_MODE})`); console.log(`📁 Session stored in: ${SESSION_DIR}`); - if (ALLOWED_USERS.length > 0) { - console.log(`🔒 Allowed users: ${ALLOWED_USERS.join(', ')}`); + if (ALLOWED_USERS.size > 0) { + console.log(`🔒 Allowed users: ${Array.from(ALLOWED_USERS).join(', ')}`); } else { console.log(`⚠️ No WHATSAPP_ALLOWED_USERS set — all messages will be processed`); } diff --git a/skills/creative/songwriting-and-ai-music/SKILL.md b/skills/creative/songwriting-and-ai-music/SKILL.md new file mode 100644 index 000000000..2f1fc7282 --- /dev/null +++ b/skills/creative/songwriting-and-ai-music/SKILL.md @@ -0,0 +1,289 @@ +--- +name: songwriting-and-ai-music +description: > + Songwriting craft, AI music generation prompts (Suno focus), parody/adaptation + techniques, phonetic tricks, and lessons learned. These are tools and ideas, + not rules. Break any of them when the art calls for it. +tags: [songwriting, music, suno, parody, lyrics, creative] +triggers: + - writing a song + - song lyrics + - music prompt + - suno prompt + - parody song + - adapting a song + - AI music generation +--- + +# Songwriting & AI Music Generation + +Everything here is a GUIDELINE, not a rule. Art breaks rules on purpose. +Use what serves the song. Ignore what doesn't. + +--- + +## 1. Song Structure (Pick One or Invent Your Own) + +Common skeletons — mix, modify, or throw out as needed: + +``` +ABABCB Verse/Chorus/Verse/Chorus/Bridge/Chorus (most pop/rock) +AABA Verse/Verse/Bridge/Verse (refrain-based) (jazz standards, ballads) +ABAB Verse/Chorus alternating (simple, direct) +AAA Verse/Verse/Verse (strophic, no chorus) (folk, storytelling) +``` + +The six building blocks: +- Intro — set the mood, pull the listener in +- Verse — the story, the details, the world-building +- Pre-Chorus — optional tension ramp before the payoff +- Chorus — the emotional core, the part people remember +- Bridge — a detour, a shift in perspective or key +- Outro — the farewell, can echo or subvert the rest + +You don't need all of these. Some great songs are just one section +that evolves. Structure serves the emotion, not the other way around. + +--- + +## 2. Rhyme, Meter, and Sound + +RHYME TYPES (from tight to loose): +- Perfect: lean/mean +- Family: crate/braid +- Assonance: had/glass (same vowels, different endings) +- Consonance: scene/when (different vowels, similar endings) +- Near/slant: enough to suggest connection without locking it down + +Mix them. All perfect rhymes can sound like a nursery rhyme. +All slant rhymes can sound lazy. The blend is where it lives. + +INTERNAL RHYME: Rhyming within a line, not just at the ends. + "We pruned the lies from bleeding trees / Distilled the storm + from entropy" — "lies/flies," "trees/entropy" create internal echoes. + +METER: The rhythm of stressed vs unstressed syllables. +- Matching syllable counts between parallel lines helps singability +- The STRESSED syllables matter more than total count +- Say it out loud. If you stumble, the meter needs work. +- Intentionally breaking meter can create emphasis or surprise + +--- + +## 3. Emotional Arc and Dynamics + +Think of a song as a journey, not a flat road. + +ENERGY MAPPING (rough idea, not prescription): + Intro: 2-3 | Verse: 5-6 | Pre-Chorus: 7 + Chorus: 8-9 | Bridge: varies | Final Chorus: 9-10 + +The most powerful dynamic trick: CONTRAST. +- Whisper before a scream hits harder than just screaming +- Sparse before dense. Slow before fast. Low before high. +- The drop only works because of the buildup +- Silence is an instrument + +"Whisper to roar to whisper" — start intimate, build to full power, +strip back to vulnerability. Works for ballads, epics, anthems. + +--- + +## 4. Writing Lyrics That Work + +SHOW, DON'T TELL (usually): +- "I was sad" = flat +- "Your hoodie's still on the hook by the door" = alive +- But sometimes "I give my life" said plainly IS the power + +THE HOOK: +- The line people remember, hum, repeat +- Usually the title or core phrase +- Works best when melody + lyric + emotion all align +- Place it where it lands hardest (often first/last line of chorus) + +PROSODY — lyrics and music supporting each other: +- Stable feelings (resolution, peace) pair with settled melodies, + perfect rhymes, resolved chords +- Unstable feelings (longing, doubt) pair with wandering melodies, + near-rhymes, unresolved chords +- Verse melody typically sits lower, chorus goes higher +- But flip this if it serves the song + +AVOID (unless you're doing it on purpose): +- Cliches on autopilot ("heart of gold" without earning it) +- Forcing word order to hit a rhyme ("Yoda-speak") +- Same energy in every section (flat dynamics) +- Treating your first draft as sacred — revision is creation + +--- + +## 5. Parody and Adaptation + +When rewriting an existing song with new lyrics: + +THE SKELETON: Map the original's structure first. +- Count syllables per line +- Mark the rhyme scheme (ABAB, AABB, etc.) +- Identify which syllables are STRESSED +- Note where held/sustained notes fall + +FITTING NEW WORDS: +- Match stressed syllables to the same beats as the original +- Total syllable count can flex by 1-2 unstressed syllables +- On long held notes, try to match the VOWEL SOUND of the original + (if original holds "LOOOVE" with an "oo" vowel, "FOOOD" fits + better than "LIFE") +- Monosyllabic swaps in key spots keep rhythm intact + (Crime -> Code, Snake -> Noose) +- Sing your new words over the original — if you stumble, revise + +CONCEPT: +- Pick a concept strong enough to sustain the whole song +- Start from the title/hook and build outward +- Generate lots of raw material (puns, phrases, images) FIRST, + then fit the best ones into the structure +- If you need a specific line somewhere, reverse-engineer the + rhyme scheme backward to set it up + +KEEP SOME ORIGINALS: Leaving a few original lines or structures +intact adds recognizability and lets the audience feel the connection. + +--- + +## 6. Suno AI Prompt Engineering + +### Style/Genre Description Field + +FORMULA (adapt as needed): + Genre + Mood + Era + Instruments + Vocal Style + Production + Dynamics + +``` +BAD: "sad rock song" +GOOD: "Cinematic orchestral spy thriller, 1960s Cold War era, smoky + sultry female vocalist, big band jazz, brass section with + trumpets and french horns, sweeping strings, minor key, + vintage analog warmth" +``` + +DESCRIBE THE JOURNEY, not just the genre: +``` +"Begins as a haunting whisper over sparse piano. Gradually layers + in muted brass. Builds through the chorus with full orchestra. + Second verse erupts with raw belting intensity. Outro strips back + to a lone piano and a fragile whisper fading to silence." +``` + +TIPS: +- V4.5+ supports up to 1,000 chars in Style field — use them +- NO artist names or trademarks. Describe the sound instead. + "1960s Cold War spy thriller brass" not "James Bond style" + "90s grunge" not "Nirvana-style" +- Specify BPM and key when you have a preference +- Use Exclude Styles field for what you DON'T want +- Unexpected genre combos can be gold: "bossa nova trap", + "Appalachian gothic", "chiptune jazz" +- Build a vocal PERSONA, not just a gender: + "A weathered torch singer with a smoky alto, slight rasp, + who starts vulnerable and builds to devastating power" + +### Metatags (place in [brackets] inside lyrics field) + +STRUCTURE: + [Intro] [Verse] [Verse 1] [Pre-Chorus] [Chorus] + [Post-Chorus] [Hook] [Bridge] [Interlude] + [Instrumental] [Instrumental Break] [Guitar Solo] + [Breakdown] [Build-up] [Outro] [Silence] [End] + +VOCAL PERFORMANCE: + [Whispered] [Spoken Word] [Belted] [Falsetto] [Powerful] + [Soulful] [Raspy] [Breathy] [Smooth] [Gritty] + [Staccato] [Legato] [Vibrato] [Melismatic] + [Harmonies] [Choir] [Harmonized Chorus] + +DYNAMICS: + [High Energy] [Low Energy] [Building Energy] [Explosive] + [Emotional Climax] [Gradual swell] [Orchestral swell] + [Quiet arrangement] [Falling tension] [Slow Down] + +GENDER: + [Female Vocals] [Male Vocals] + +ATMOSPHERE: + [Melancholic] [Euphoric] [Nostalgic] [Aggressive] + [Dreamy] [Intimate] [Dark Atmosphere] + +SFX: + [Vinyl Crackle] [Rain] [Applause] [Static] [Thunder] + +Put tags in BOTH style field AND lyrics for reinforcement. +Keep to 5-8 tags per section max — too many confuses the AI. +Don't contradict yourself ([Calm] + [Aggressive] in same section). + +### Custom Mode +- Always use Custom Mode for serious work (separate Style + Lyrics) +- Lyrics field limit: ~3,000 chars (~40-60 lines) +- Always add structural tags — without them Suno defaults to + flat verse/chorus/verse with no emotional arc + +--- + +## 7. Phonetic Tricks for AI Singers + +AI vocalists don't read — they pronounce. Help them: + +PHONETIC RESPELLING: +- Spell words as they SOUND: "through" -> "thru" +- Proper nouns are highest failure rate — test early +- "Nous" -> "Noose" (forces correct pronunciation) +- Hyphenate to guide syllables: "Re-search", "bio-engineering" + +DELIVERY CONTROL: +- ALL CAPS = louder, more intense +- Vowel extension: "lo-o-o-ove" = sustained/melisma +- Ellipses: "I... need... you" = dramatic pauses +- Hyphenated stretch: "ne-e-ed" = emotional stretch + +ALWAYS: +- Spell out numbers: "24/7" -> "twenty four seven" +- Space acronyms: "AI" -> "A I" or "A-I" +- Test proper nouns/unusual words in a short 30-second clip first +- Once generated, pronunciation is baked in — fix in lyrics BEFORE + +--- + +## 8. Workflow + +1. Write the concept/hook first — what's the emotional core? +2. If adapting, map the original structure (syllables, rhyme, stress) +3. Generate raw material — brainstorm freely before structuring +4. Draft lyrics into the structure +5. Read/sing aloud — catch stumbles, fix meter +6. Build the Suno style description — paint the dynamic journey +7. Add metatags to lyrics for performance direction +8. Generate 3-5 variations minimum — treat them like recording takes +9. Pick the best, use Extend/Continue to build on promising sections +10. If something great happens by accident, keep it + +EXPECT: ~3-5 generations per 1 good result. Revision is normal. +Style can drift in extensions — restate genre/mood when extending. + +--- + +## 9. Lessons Learned + +- Describing the dynamic ARC in the style field matters way more + than just listing genres. "Whisper to roar to whisper" gives + Suno a performance map. +- Keeping some original lines intact in a parody adds recognizability + and emotional weight — the audience feels the ghost of the original. +- The bridge slot in a song is where you can transform imagery. + Swap the original's specific references for your theme's metaphors + while keeping the emotional function (reflection, shift, revelation). +- Monosyllabic word swaps in hooks/tags are the cleanest way to + maintain rhythm while changing meaning. +- A strong vocal persona description in the style field makes a + bigger difference than any single metatag. +- Don't be precious about rules. If a line breaks meter but hits + harder, keep it. The feeling is what matters. Craft serves art, + not the other way around. diff --git a/skills/devops/webhook-subscriptions/SKILL.md b/skills/devops/webhook-subscriptions/SKILL.md new file mode 100644 index 000000000..e5ab6d588 --- /dev/null +++ b/skills/devops/webhook-subscriptions/SKILL.md @@ -0,0 +1,180 @@ +--- +name: webhook-subscriptions +description: Create and manage webhook subscriptions for event-driven agent activation. Use when the user wants external services to trigger agent runs automatically. +version: 1.0.0 +metadata: + hermes: + tags: [webhook, events, automation, integrations] +--- + +# Webhook Subscriptions + +Create dynamic webhook subscriptions so external services (GitHub, GitLab, Stripe, CI/CD, IoT sensors, monitoring tools) can trigger Hermes agent runs by POSTing events to a URL. + +## Setup (Required First) + +The webhook platform must be enabled before subscriptions can be created. Check with: +```bash +hermes webhook list +``` + +If it says "Webhook platform is not enabled", set it up: + +### Option 1: Setup wizard +```bash +hermes gateway setup +``` +Follow the prompts to enable webhooks, set the port, and set a global HMAC secret. + +### Option 2: Manual config +Add to `~/.hermes/config.yaml`: +```yaml +platforms: + webhook: + enabled: true + extra: + host: "0.0.0.0" + port: 8644 + secret: "generate-a-strong-secret-here" +``` + +### Option 3: Environment variables +Add to `~/.hermes/.env`: +```bash +WEBHOOK_ENABLED=true +WEBHOOK_PORT=8644 +WEBHOOK_SECRET=generate-a-strong-secret-here +``` + +After configuration, start (or restart) the gateway: +```bash +hermes gateway run +# Or if using systemd: +systemctl --user restart hermes-gateway +``` + +Verify it's running: +```bash +curl http://localhost:8644/health +``` + +## Commands + +All management is via the `hermes webhook` CLI command: + +### Create a subscription +```bash +hermes webhook subscribe \ + --prompt "Prompt template with {payload.fields}" \ + --events "event1,event2" \ + --description "What this does" \ + --skills "skill1,skill2" \ + --deliver telegram \ + --deliver-chat-id "12345" \ + --secret "optional-custom-secret" +``` + +Returns the webhook URL and HMAC secret. The user configures their service to POST to that URL. + +### List subscriptions +```bash +hermes webhook list +``` + +### Remove a subscription +```bash +hermes webhook remove +``` + +### Test a subscription +```bash +hermes webhook test +hermes webhook test --payload '{"key": "value"}' +``` + +## Prompt Templates + +Prompts support `{dot.notation}` for accessing nested payload fields: + +- `{issue.title}` — GitHub issue title +- `{pull_request.user.login}` — PR author +- `{data.object.amount}` — Stripe payment amount +- `{sensor.temperature}` — IoT sensor reading + +If no prompt is specified, the full JSON payload is dumped into the agent prompt. + +## Common Patterns + +### GitHub: new issues +```bash +hermes webhook subscribe github-issues \ + --events "issues" \ + --prompt "New GitHub issue #{issue.number}: {issue.title}\n\nAction: {action}\nAuthor: {issue.user.login}\nBody:\n{issue.body}\n\nPlease triage this issue." \ + --deliver telegram \ + --deliver-chat-id "-100123456789" +``` + +Then in GitHub repo Settings → Webhooks → Add webhook: +- Payload URL: the returned webhook_url +- Content type: application/json +- Secret: the returned secret +- Events: "Issues" + +### GitHub: PR reviews +```bash +hermes webhook subscribe github-prs \ + --events "pull_request" \ + --prompt "PR #{pull_request.number} {action}: {pull_request.title}\nBy: {pull_request.user.login}\nBranch: {pull_request.head.ref}\n\n{pull_request.body}" \ + --skills "github-code-review" \ + --deliver github_comment +``` + +### Stripe: payment events +```bash +hermes webhook subscribe stripe-payments \ + --events "payment_intent.succeeded,payment_intent.payment_failed" \ + --prompt "Payment {data.object.status}: {data.object.amount} cents from {data.object.receipt_email}" \ + --deliver telegram \ + --deliver-chat-id "-100123456789" +``` + +### CI/CD: build notifications +```bash +hermes webhook subscribe ci-builds \ + --events "pipeline" \ + --prompt "Build {object_attributes.status} on {project.name} branch {object_attributes.ref}\nCommit: {commit.message}" \ + --deliver discord \ + --deliver-chat-id "1234567890" +``` + +### Generic monitoring alert +```bash +hermes webhook subscribe alerts \ + --prompt "Alert: {alert.name}\nSeverity: {alert.severity}\nMessage: {alert.message}\n\nPlease investigate and suggest remediation." \ + --deliver origin +``` + +## Security + +- Each subscription gets an auto-generated HMAC-SHA256 secret (or provide your own with `--secret`) +- The webhook adapter validates signatures on every incoming POST +- Static routes from config.yaml cannot be overwritten by dynamic subscriptions +- Subscriptions persist to `~/.hermes/webhook_subscriptions.json` + +## How It Works + +1. `hermes webhook subscribe` writes to `~/.hermes/webhook_subscriptions.json` +2. The webhook adapter hot-reloads this file on each incoming request (mtime-gated, negligible overhead) +3. When a POST arrives matching a route, the adapter formats the prompt and triggers an agent run +4. The agent's response is delivered to the configured target (Telegram, Discord, GitHub comment, etc.) + +## Troubleshooting + +If webhooks aren't working: + +1. **Is the gateway running?** Check with `systemctl --user status hermes-gateway` or `ps aux | grep gateway` +2. **Is the webhook server listening?** `curl http://localhost:8644/health` should return `{"status": "ok"}` +3. **Check gateway logs:** `grep webhook ~/.hermes/logs/gateway.log | tail -20` +4. **Signature mismatch?** Verify the secret in your service matches the one from `hermes webhook list`. GitHub sends `X-Hub-Signature-256`, GitLab sends `X-Gitlab-Token`. +5. **Firewall/NAT?** The webhook URL must be reachable from the service. For local development, use a tunnel (ngrok, cloudflared). +6. **Wrong event type?** Check `--events` filter matches what the service sends. Use `hermes webhook test ` to verify the route works. diff --git a/skills/github/github-auth/SKILL.md b/skills/github/github-auth/SKILL.md index 10c2560d0..ea8f369c4 100644 --- a/skills/github/github-auth/SKILL.md +++ b/skills/github/github-auth/SKILL.md @@ -219,6 +219,9 @@ if command -v gh &>/dev/null && gh auth status &>/dev/null; then echo "AUTH_METHOD=gh" elif [ -n "$GITHUB_TOKEN" ]; then echo "AUTH_METHOD=curl" +elif [ -f ~/.hermes/.env ] && grep -q "^GITHUB_TOKEN=" ~/.hermes/.env; then + export GITHUB_TOKEN=$(grep "^GITHUB_TOKEN=" ~/.hermes/.env | head -1 | cut -d= -f2 | tr -d '\n\r') + echo "AUTH_METHOD=curl" elif grep -q "github.com" ~/.git-credentials 2>/dev/null; then export GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') echo "AUTH_METHOD=curl" diff --git a/skills/github/github-auth/scripts/gh-env.sh b/skills/github/github-auth/scripts/gh-env.sh index c66e78ad3..043c6b555 100755 --- a/skills/github/github-auth/scripts/gh-env.sh +++ b/skills/github/github-auth/scripts/gh-env.sh @@ -23,6 +23,11 @@ if command -v gh &>/dev/null && gh auth status &>/dev/null 2>&1; then GH_USER=$(gh api user --jq '.login' 2>/dev/null) elif [ -n "$GITHUB_TOKEN" ]; then GH_AUTH_METHOD="curl" +elif [ -f "$HOME/.hermes/.env" ] && grep -q "^GITHUB_TOKEN=" "$HOME/.hermes/.env" 2>/dev/null; then + GITHUB_TOKEN=$(grep "^GITHUB_TOKEN=" "$HOME/.hermes/.env" | head -1 | cut -d= -f2 | tr -d '\n\r') + if [ -n "$GITHUB_TOKEN" ]; then + GH_AUTH_METHOD="curl" + fi elif [ -f "$HOME/.git-credentials" ] && grep -q "github.com" "$HOME/.git-credentials" 2>/dev/null; then GITHUB_TOKEN=$(grep "github.com" "$HOME/.git-credentials" | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') if [ -n "$GITHUB_TOKEN" ]; then diff --git a/skills/github/github-code-review/SKILL.md b/skills/github/github-code-review/SKILL.md index 64b02328e..52d8e4a07 100644 --- a/skills/github/github-code-review/SKILL.md +++ b/skills/github/github-code-review/SKILL.md @@ -27,7 +27,11 @@ if command -v gh &>/dev/null && gh auth status &>/dev/null; then else AUTH="git" if [ -z "$GITHUB_TOKEN" ]; then - GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + if [ -f ~/.hermes/.env ] && grep -q "^GITHUB_TOKEN=" ~/.hermes/.env; then + GITHUB_TOKEN=$(grep "^GITHUB_TOKEN=" ~/.hermes/.env | head -1 | cut -d= -f2 | tr -d '\n\r') + elif grep -q "github.com" ~/.git-credentials 2>/dev/null; then + GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + fi fi fi diff --git a/skills/github/github-issues/SKILL.md b/skills/github/github-issues/SKILL.md index 019c08a0f..a3bceb8e3 100644 --- a/skills/github/github-issues/SKILL.md +++ b/skills/github/github-issues/SKILL.md @@ -27,7 +27,11 @@ if command -v gh &>/dev/null && gh auth status &>/dev/null; then else AUTH="git" if [ -z "$GITHUB_TOKEN" ]; then - GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + if [ -f ~/.hermes/.env ] && grep -q "^GITHUB_TOKEN=" ~/.hermes/.env; then + GITHUB_TOKEN=$(grep "^GITHUB_TOKEN=" ~/.hermes/.env | head -1 | cut -d= -f2 | tr -d '\n\r') + elif grep -q "github.com" ~/.git-credentials 2>/dev/null; then + GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + fi fi fi diff --git a/skills/github/github-pr-workflow/SKILL.md b/skills/github/github-pr-workflow/SKILL.md index d09911e52..48f15ed7a 100644 --- a/skills/github/github-pr-workflow/SKILL.md +++ b/skills/github/github-pr-workflow/SKILL.md @@ -29,7 +29,11 @@ else AUTH="git" # Ensure we have a token for API calls if [ -z "$GITHUB_TOKEN" ]; then - GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + if [ -f ~/.hermes/.env ] && grep -q "^GITHUB_TOKEN=" ~/.hermes/.env; then + GITHUB_TOKEN=$(grep "^GITHUB_TOKEN=" ~/.hermes/.env | head -1 | cut -d= -f2 | tr -d '\n\r') + elif grep -q "github.com" ~/.git-credentials 2>/dev/null; then + GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + fi fi fi echo "Using: $AUTH" diff --git a/skills/github/github-repo-management/SKILL.md b/skills/github/github-repo-management/SKILL.md index 7ef95eb2d..b3732f29a 100644 --- a/skills/github/github-repo-management/SKILL.md +++ b/skills/github/github-repo-management/SKILL.md @@ -26,7 +26,11 @@ if command -v gh &>/dev/null && gh auth status &>/dev/null; then else AUTH="git" if [ -z "$GITHUB_TOKEN" ]; then - GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + if [ -f ~/.hermes/.env ] && grep -q "^GITHUB_TOKEN=" ~/.hermes/.env; then + GITHUB_TOKEN=$(grep "^GITHUB_TOKEN=" ~/.hermes/.env | head -1 | cut -d= -f2 | tr -d '\n\r') + elif grep -q "github.com" ~/.git-credentials 2>/dev/null; then + GITHUB_TOKEN=$(grep "github.com" ~/.git-credentials 2>/dev/null | head -1 | sed 's|https://[^:]*:\([^@]*\)@.*|\1|') + fi fi fi diff --git a/skills/music-creation/DESCRIPTION.md b/skills/music-creation/DESCRIPTION.md deleted file mode 100644 index 04ad703c9..000000000 --- a/skills/music-creation/DESCRIPTION.md +++ /dev/null @@ -1,3 +0,0 @@ ---- -description: Skills for generating, editing, and processing music and audio using AI models and audio tools. ---- diff --git a/skills/productivity/google-workspace/SKILL.md b/skills/productivity/google-workspace/SKILL.md index 00d91de90..5d1c71bfb 100644 --- a/skills/productivity/google-workspace/SKILL.md +++ b/skills/productivity/google-workspace/SKILL.md @@ -4,6 +4,11 @@ description: Gmail, Calendar, Drive, Contacts, Sheets, and Docs integration via version: 1.0.0 author: Nous Research license: MIT +required_credential_files: + - path: google_token.json + description: Google OAuth2 token (created by setup script) + - path: google_client_secret.json + description: Google OAuth2 client credentials (downloaded from Google Cloud Console) metadata: hermes: tags: [Google, Gmail, Calendar, Drive, Sheets, Docs, Contacts, Email, OAuth] diff --git a/skills/research/duckduckgo-search/SKILL.md b/skills/research/duckduckgo-search/SKILL.md index 0bfc64739..ea14e6b30 100644 --- a/skills/research/duckduckgo-search/SKILL.md +++ b/skills/research/duckduckgo-search/SKILL.md @@ -1,7 +1,7 @@ --- name: duckduckgo-search -description: Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Use the Python DDGS library or CLI to search, then web_extract for full content. -version: 1.2.0 +description: Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Prefer the `ddgs` CLI when installed; use the Python DDGS library only after verifying that `ddgs` is available in the current runtime. +version: 1.3.0 author: gamedevCloudy license: MIT metadata: @@ -9,26 +9,96 @@ metadata: tags: [search, duckduckgo, web-search, free, fallback] related_skills: [arxiv] fallback_for_toolsets: [web] -prerequisites: - commands: [ddgs] --- # DuckDuckGo Search Free web search using DuckDuckGo. **No API key required.** -Preferred when `web_search` tool is unavailable or unsuitable (no `FIRECRAWL_API_KEY` set). Can also be used as a standalone search tool. +Preferred when `web_search` is unavailable or unsuitable (for example when `FIRECRAWL_API_KEY` is not set). Can also be used as a standalone search path when DuckDuckGo results are specifically desired. -## Setup +## Detection Flow + +Check what is actually available before choosing an approach: ```bash -# Install the ddgs package (one-time) -pip install ddgs +# Check CLI availability +command -v ddgs >/dev/null && echo "DDGS_CLI=installed" || echo "DDGS_CLI=missing" ``` -## Python API (Primary) +Decision tree: +1. If `ddgs` CLI is installed, prefer `terminal` + `ddgs` +2. If `ddgs` CLI is missing, do not assume `execute_code` can import `ddgs` +3. If the user wants DuckDuckGo specifically, install `ddgs` first in the relevant environment +4. Otherwise fall back to built-in web/browser tools -Use the `DDGS` class in `execute_code` for structured results with typed fields. +Important runtime note: +- Terminal and `execute_code` are separate runtimes +- A successful shell install does not guarantee `execute_code` can import `ddgs` +- Never assume third-party Python packages are preinstalled inside `execute_code` + +## Installation + +Install `ddgs` only when DuckDuckGo search is specifically needed and the runtime does not already provide it. + +```bash +# Python package + CLI entrypoint +pip install ddgs + +# Verify CLI +ddgs --help +``` + +If a workflow depends on Python imports, verify that same runtime can import `ddgs` before using `from ddgs import DDGS`. + +## Method 1: CLI Search (Preferred) + +Use the `ddgs` command via `terminal` when it exists. This is the preferred path because it avoids assuming the `execute_code` sandbox has the `ddgs` Python package installed. + +```bash +# Text search +ddgs text -k "python async programming" -m 5 + +# News search +ddgs news -k "artificial intelligence" -m 5 + +# Image search +ddgs images -k "landscape photography" -m 10 + +# Video search +ddgs videos -k "python tutorial" -m 5 + +# With region filter +ddgs text -k "best restaurants" -m 5 -r us-en + +# Recent results only (d=day, w=week, m=month, y=year) +ddgs text -k "latest AI news" -m 5 -t w + +# JSON output for parsing +ddgs text -k "fastapi tutorial" -m 5 -o json +``` + +### CLI Flags + +| Flag | Description | Example | +|------|-------------|---------| +| `-k` | Keywords (query) — **required** | `-k "search terms"` | +| `-m` | Max results | `-m 5` | +| `-r` | Region | `-r us-en` | +| `-t` | Time limit | `-t w` (week) | +| `-s` | Safe search | `-s off` | +| `-o` | Output format | `-o json` | + +## Method 2: Python API (Only After Verification) + +Use the `DDGS` class in `execute_code` or another Python runtime only after verifying that `ddgs` is installed there. Do not assume `execute_code` includes third-party packages by default. + +Safe wording: +- "Use `execute_code` with `ddgs` after installing or verifying the package if needed" + +Avoid saying: +- "`execute_code` includes `ddgs`" +- "DuckDuckGo search works by default in `execute_code`" **Important:** `max_results` must always be passed as a **keyword argument** — positional usage raises an error on all methods. @@ -76,7 +146,7 @@ from ddgs import DDGS with DDGS() as ddgs: for r in ddgs.images("semiconductor chip", max_results=5): print(r["title"]) - print(r["image"]) # direct image URL + print(r["image"]) print(r.get("thumbnail", "")) print(r.get("source", "")) print() @@ -94,9 +164,9 @@ from ddgs import DDGS with DDGS() as ddgs: for r in ddgs.videos("FastAPI tutorial", max_results=5): print(r["title"]) - print(r.get("content", "")) # video URL - print(r.get("duration", "")) # e.g. "26:03" - print(r.get("provider", "")) # YouTube, etc. + print(r.get("content", "")) + print(r.get("duration", "")) + print(r.get("provider", "")) print(r.get("published", "")) print() ``` @@ -112,50 +182,17 @@ Returns: `title`, `content`, `description`, `duration`, `provider`, `published`, | `images()` | Visuals, diagrams | title, image, thumbnail, url | | `videos()` | Tutorials, demos | title, content, duration, provider | -## CLI (Alternative) - -Use the `ddgs` command via terminal when you don't need structured field access. - -```bash -# Text search -ddgs text -k "python async programming" -m 5 - -# News search -ddgs news -k "artificial intelligence" -m 5 - -# Image search -ddgs images -k "landscape photography" -m 10 - -# Video search -ddgs videos -k "python tutorial" -m 5 - -# With region filter -ddgs text -k "best restaurants" -m 5 -r us-en - -# Recent results only (d=day, w=week, m=month, y=year) -ddgs text -k "latest AI news" -m 5 -t w - -# JSON output for parsing -ddgs text -k "fastapi tutorial" -m 5 -o json -``` - -### CLI Flags - -| Flag | Description | Example | -|------|-------------|---------| -| `-k` | Keywords (query) — **required** | `-k "search terms"` | -| `-m` | Max results | `-m 5` | -| `-r` | Region | `-r us-en` | -| `-t` | Time limit | `-t w` (week) | -| `-s` | Safe search | `-s off` | -| `-o` | Output format | `-o json` | - ## Workflow: Search then Extract -DuckDuckGo returns titles, URLs, and snippets — not full page content. To get full content, follow up with `web_extract`: +DuckDuckGo returns titles, URLs, and snippets — not full page content. To get full page content, search first and then extract the most relevant URL with `web_extract`, browser tools, or curl. -1. **Search** with ddgs to find relevant URLs -2. **Extract** content using the `web_extract` tool (if available) or curl +CLI example: + +```bash +ddgs text -k "fastapi deployment guide" -m 3 -o json +``` + +Python example, only after verifying `ddgs` is installed in that runtime: ```python from ddgs import DDGS @@ -164,25 +201,37 @@ with DDGS() as ddgs: results = list(ddgs.text("fastapi deployment guide", max_results=3)) for r in results: print(r["title"], "->", r["href"]) - -# Then use web_extract tool on the best URL ``` +Then extract the best URL with `web_extract` or another content-retrieval tool. + ## Limitations - **Rate limiting**: DuckDuckGo may throttle after many rapid requests. Add a short delay between searches if needed. -- **No content extraction**: ddgs returns snippets, not full page content. Use `web_extract` or curl for that. +- **No content extraction**: `ddgs` returns snippets, not full page content. Use `web_extract`, browser tools, or curl for the full article/page. - **Results quality**: Generally good but less configurable than Firecrawl's search. - **Availability**: DuckDuckGo may block requests from some cloud IPs. If searches return empty, try different keywords or wait a few seconds. -- **Field variability**: Return fields may vary between results or ddgs versions. Use `.get()` for optional fields to avoid KeyError. +- **Field variability**: Return fields may vary between results or `ddgs` versions. Use `.get()` for optional fields to avoid `KeyError`. +- **Separate runtimes**: A successful `ddgs` install in terminal does not automatically mean `execute_code` can import it. + +## Troubleshooting + +| Problem | Likely Cause | What To Do | +|---------|--------------|------------| +| `ddgs: command not found` | CLI not installed in the shell environment | Install `ddgs`, or use built-in web/browser tools instead | +| `ModuleNotFoundError: No module named 'ddgs'` | Python runtime does not have the package installed | Do not use Python DDGS there until that runtime is prepared | +| Search returns nothing | Temporary rate limiting or poor query | Wait a few seconds, retry, or adjust the query | +| CLI works but `execute_code` import fails | Terminal and `execute_code` are different runtimes | Keep using CLI, or separately prepare the Python runtime | ## Pitfalls - **`max_results` is keyword-only**: `ddgs.text("query", 5)` raises an error. Use `ddgs.text("query", max_results=5)`. +- **Do not assume the CLI exists**: Check `command -v ddgs` before using it. +- **Do not assume `execute_code` can import `ddgs`**: `from ddgs import DDGS` may fail with `ModuleNotFoundError` unless that runtime was prepared separately. +- **Package name**: The package is `ddgs` (previously `duckduckgo-search`). Install with `pip install ddgs`. - **Don't confuse `-k` and `-m`** (CLI): `-k` is for keywords, `-m` is for max results count. -- **Package name**: The package is `ddgs` (was previously `duckduckgo-search`). Install with `pip install ddgs`. -- **Empty results**: If ddgs returns nothing, it may be rate-limited. Wait a few seconds and retry. +- **Empty results**: If `ddgs` returns nothing, it may be rate-limited. Wait a few seconds and retry. ## Validated With -Smoke-tested with `ddgs==9.11.2` on Python 3.13. All four methods (text, news, images, videos) confirmed working with keyword `max_results`. +Validated examples against `ddgs==9.11.2` semantics. Skill guidance now treats CLI availability and Python import availability as separate concerns so the documented workflow matches actual runtime behavior. diff --git a/tests/acp/test_entry.py b/tests/acp/test_entry.py new file mode 100644 index 000000000..760522c31 --- /dev/null +++ b/tests/acp/test_entry.py @@ -0,0 +1,20 @@ +"""Tests for acp_adapter.entry startup wiring.""" + +import acp + +from acp_adapter import entry + + +def test_main_enables_unstable_protocol(monkeypatch): + calls = {} + + async def fake_run_agent(agent, **kwargs): + calls["kwargs"] = kwargs + + monkeypatch.setattr(entry, "_setup_logging", lambda: None) + monkeypatch.setattr(entry, "_load_env", lambda: None) + monkeypatch.setattr(acp, "run_agent", fake_run_agent) + + entry.main() + + assert calls["kwargs"]["use_unstable_protocol"] is True diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index 5b9d3de62..fc6d53dd8 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, AsyncMock, patch import pytest import acp +from acp.agent.router import build_agent_router from acp.schema import ( AgentCapabilities, AuthenticateResponse, @@ -18,6 +19,8 @@ from acp.schema import ( NewSessionResponse, PromptResponse, ResumeSessionResponse, + SetSessionConfigOptionResponse, + SetSessionModeResponse, SessionInfo, TextContentBlock, Usage, @@ -168,6 +171,74 @@ class TestListAndFork: assert fork_resp.session_id != new_resp.session_id +# --------------------------------------------------------------------------- +# session configuration / model routing +# --------------------------------------------------------------------------- + + +class TestSessionConfiguration: + @pytest.mark.asyncio + async def test_set_session_mode_returns_response(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + resp = await agent.set_session_mode(mode_id="chat", session_id=new_resp.session_id) + state = agent.session_manager.get_session(new_resp.session_id) + + assert isinstance(resp, SetSessionModeResponse) + assert getattr(state, "mode", None) == "chat" + + @pytest.mark.asyncio + async def test_set_config_option_returns_response(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + resp = await agent.set_config_option( + config_id="approval_mode", + session_id=new_resp.session_id, + value="auto", + ) + state = agent.session_manager.get_session(new_resp.session_id) + + assert isinstance(resp, SetSessionConfigOptionResponse) + assert getattr(state, "config_options", {}) == {"approval_mode": "auto"} + assert resp.config_options == [] + + @pytest.mark.asyncio + async def test_router_accepts_stable_session_config_methods(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + router = build_agent_router(agent) + + mode_result = await router( + "session/set_mode", + {"modeId": "chat", "sessionId": new_resp.session_id}, + False, + ) + config_result = await router( + "session/set_config_option", + { + "configId": "approval_mode", + "sessionId": new_resp.session_id, + "value": "auto", + }, + False, + ) + + assert mode_result == {} + assert config_result == {"configOptions": []} + + @pytest.mark.asyncio + async def test_router_accepts_unstable_model_switch_when_enabled(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + router = build_agent_router(agent, use_unstable_protocol=True) + + result = await router( + "session/set_model", + {"modelId": "gpt-5.4", "sessionId": new_resp.session_id}, + False, + ) + state = agent.session_manager.get_session(new_resp.session_id) + + assert result == {} + assert state.model == "gpt-5.4" + + # --------------------------------------------------------------------------- # prompt # --------------------------------------------------------------------------- diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index 10e14518a..35dcee7ad 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -11,6 +11,7 @@ from agent.auxiliary_client import ( get_text_auxiliary_client, get_vision_auxiliary_client, get_available_vision_backends, + resolve_vision_provider_client, resolve_provider_client, auxiliary_max_tokens_param, _read_codex_access_token, @@ -490,15 +491,17 @@ class TestGetTextAuxiliaryClient: assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1" assert mock_openai.call_args.kwargs["api_key"] == "task-key" - def test_task_direct_endpoint_without_openai_key_does_not_fall_back(self, monkeypatch): + def test_task_direct_endpoint_without_openai_key_uses_placeholder(self, monkeypatch): + """Local endpoints without an API key should use 'no-key-required' placeholder.""" monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_BASE_URL", "http://localhost:2345/v1") monkeypatch.setenv("AUXILIARY_WEB_EXTRACT_MODEL", "task-model") with patch("agent.auxiliary_client.OpenAI") as mock_openai: client, model = get_text_auxiliary_client("web_extract") - assert client is None - assert model is None - mock_openai.assert_not_called() + assert client is not None + assert model == "task-model" + assert mock_openai.call_args.kwargs["api_key"] == "no-key-required" + assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:2345/v1" def test_custom_endpoint_uses_config_saved_base_url(self, monkeypatch): config = { @@ -638,6 +641,30 @@ class TestVisionClientFallback: assert client.__class__.__name__ == "AnthropicAuxiliaryClient" assert model == "claude-haiku-4-5-20251001" + def test_selected_codex_provider_short_circuits_vision_auto(self, monkeypatch): + def fake_load_config(): + return {"model": {"provider": "openai-codex", "default": "gpt-5.2-codex"}} + + codex_client = MagicMock() + with ( + patch("hermes_cli.config.load_config", fake_load_config), + patch("agent.auxiliary_client._try_codex", return_value=(codex_client, "gpt-5.2-codex")) as mock_codex, + patch("agent.auxiliary_client._try_openrouter") as mock_openrouter, + patch("agent.auxiliary_client._try_nous") as mock_nous, + patch("agent.auxiliary_client._try_anthropic") as mock_anthropic, + patch("agent.auxiliary_client._try_custom_endpoint") as mock_custom, + ): + provider, client, model = resolve_vision_provider_client() + + assert provider == "openai-codex" + assert client is codex_client + assert model == "gpt-5.2-codex" + mock_codex.assert_called_once() + mock_openrouter.assert_not_called() + mock_nous.assert_not_called() + mock_anthropic.assert_not_called() + mock_custom.assert_not_called() + def test_vision_auto_includes_codex(self, codex_auth_dir): """Codex supports vision (gpt-5.3-codex), so auto mode should use it.""" with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \ @@ -671,15 +698,16 @@ class TestVisionClientFallback: assert mock_openai.call_args.kwargs["base_url"] == "http://localhost:4567/v1" assert mock_openai.call_args.kwargs["api_key"] == "vision-key" - def test_vision_direct_endpoint_requires_openai_api_key(self, monkeypatch): + def test_vision_direct_endpoint_without_key_uses_placeholder(self, monkeypatch): + """Vision endpoint without API key should use 'no-key-required' placeholder.""" monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") monkeypatch.setenv("AUXILIARY_VISION_BASE_URL", "http://localhost:4567/v1") monkeypatch.setenv("AUXILIARY_VISION_MODEL", "vision-model") with patch("agent.auxiliary_client.OpenAI") as mock_openai: client, model = get_vision_auxiliary_client() - assert client is None - assert model is None - mock_openai.assert_not_called() + assert client is not None + assert model == "vision-model" + assert mock_openai.call_args.kwargs["api_key"] == "no-key-required" def test_vision_uses_openrouter_when_available(self, monkeypatch): monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") diff --git a/tests/agent/test_external_skills.py b/tests/agent/test_external_skills.py new file mode 100644 index 000000000..1a9cd63d5 --- /dev/null +++ b/tests/agent/test_external_skills.py @@ -0,0 +1,157 @@ +"""Tests for external skill directories (skills.external_dirs config).""" + +import json +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def external_skills_dir(tmp_path): + """Create a temp dir with a sample external skill.""" + ext_dir = tmp_path / "external-skills" + skill_dir = ext_dir / "my-external-skill" + skill_dir.mkdir(parents=True) + (skill_dir / "SKILL.md").write_text( + "---\nname: my-external-skill\ndescription: A skill from an external directory\n---\n\n# My External Skill\n\nDo external things.\n" + ) + return ext_dir + + +@pytest.fixture +def hermes_home(tmp_path): + """Create a minimal HERMES_HOME with config.""" + home = tmp_path / ".hermes" + home.mkdir() + (home / "skills").mkdir() + return home + + +class TestGetExternalSkillsDirs: + def test_empty_config(self, hermes_home): + (hermes_home / "config.yaml").write_text("skills:\n external_dirs: []\n") + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert result == [] + + def test_nonexistent_dir_skipped(self, hermes_home): + (hermes_home / "config.yaml").write_text( + "skills:\n external_dirs:\n - /nonexistent/path\n" + ) + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert result == [] + + def test_valid_dir_returned(self, hermes_home, external_skills_dir): + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {external_skills_dir}\n" + ) + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert len(result) == 1 + assert result[0] == external_skills_dir.resolve() + + def test_duplicate_dirs_deduplicated(self, hermes_home, external_skills_dir): + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {external_skills_dir}\n - {external_skills_dir}\n" + ) + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert len(result) == 1 + + def test_local_skills_dir_excluded(self, hermes_home): + local_skills = hermes_home / "skills" + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {local_skills}\n" + ) + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert result == [] + + def test_no_config_file(self, hermes_home): + # No config.yaml at all + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert result == [] + + def test_string_value_converted_to_list(self, hermes_home, external_skills_dir): + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs: {external_skills_dir}\n" + ) + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_external_skills_dirs + result = get_external_skills_dirs() + assert len(result) == 1 + + +class TestGetAllSkillsDirs: + def test_local_always_first(self, hermes_home, external_skills_dir): + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {external_skills_dir}\n" + ) + with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}): + from agent.skill_utils import get_all_skills_dirs + result = get_all_skills_dirs() + assert result[0] == hermes_home / "skills" + assert result[1] == external_skills_dir.resolve() + + +class TestExternalSkillsInFindAll: + def test_external_skills_found(self, hermes_home, external_skills_dir): + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {external_skills_dir}\n" + ) + local_skills = hermes_home / "skills" + with ( + patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}), + patch("tools.skills_tool.SKILLS_DIR", local_skills), + ): + from tools.skills_tool import _find_all_skills + skills = _find_all_skills() + names = [s["name"] for s in skills] + assert "my-external-skill" in names + + def test_local_takes_precedence(self, hermes_home, external_skills_dir): + """If the same skill name exists locally and externally, local wins.""" + local_skills = hermes_home / "skills" + local_skill = local_skills / "my-external-skill" + local_skill.mkdir(parents=True) + (local_skill / "SKILL.md").write_text( + "---\nname: my-external-skill\ndescription: Local version\n---\n\nLocal.\n" + ) + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {external_skills_dir}\n" + ) + with ( + patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}), + patch("tools.skills_tool.SKILLS_DIR", local_skills), + ): + from tools.skills_tool import _find_all_skills + skills = _find_all_skills() + matching = [s for s in skills if s["name"] == "my-external-skill"] + assert len(matching) == 1 + assert matching[0]["description"] == "Local version" + + +class TestExternalSkillView: + def test_skill_view_finds_external(self, hermes_home, external_skills_dir): + (hermes_home / "config.yaml").write_text( + f"skills:\n external_dirs:\n - {external_skills_dir}\n" + ) + local_skills = hermes_home / "skills" + with ( + patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}), + patch("tools.skills_tool.SKILLS_DIR", local_skills), + ): + from tools.skills_tool import skill_view + result = json.loads(skill_view("my-external-skill")) + assert result["success"] is True + assert "external things" in result["content"] diff --git a/tests/agent/test_prompt_builder.py b/tests/agent/test_prompt_builder.py index deeac8990..791f7ea0e 100644 --- a/tests/agent/test_prompt_builder.py +++ b/tests/agent/test_prompt_builder.py @@ -21,6 +21,8 @@ from agent.prompt_builder import ( build_context_files_prompt, CONTEXT_FILE_MAX_CHARS, DEFAULT_AGENT_IDENTITY, + TOOL_USE_ENFORCEMENT_GUIDANCE, + TOOL_USE_ENFORCEMENT_MODELS, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, PLATFORM_HINTS, @@ -196,7 +198,7 @@ class TestParseSkillFile: ) from unittest.mock import patch - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "linux" is_compat, _, _ = _parse_skill_file(skill_file) assert is_compat is False @@ -237,6 +239,14 @@ class TestPromptBuilderImports: class TestBuildSkillsSystemPrompt: + @pytest.fixture(autouse=True) + def _clear_skills_cache(self): + """Ensure the in-process skills prompt cache doesn't leak between tests.""" + from agent.prompt_builder import clear_skills_system_prompt_cache + clear_skills_system_prompt_cache(clear_snapshot=True) + yield + clear_skills_system_prompt_cache(clear_snapshot=True) + def test_empty_when_no_skills_dir(self, monkeypatch, tmp_path): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) result = build_skills_system_prompt() @@ -287,7 +297,7 @@ class TestBuildSkillsSystemPrompt: from unittest.mock import patch - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "linux" result = build_skills_system_prompt() @@ -306,7 +316,7 @@ class TestBuildSkillsSystemPrompt: from unittest.mock import patch - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "darwin" result = build_skills_system_prompt() @@ -334,7 +344,7 @@ class TestBuildSkillsSystemPrompt: from unittest.mock import patch with patch( - "tools.skills_tool._get_disabled_skill_names", + "agent.prompt_builder.get_disabled_skill_names", return_value={"old-tool"}, ): result = build_skills_system_prompt() @@ -621,6 +631,10 @@ class TestBuildContextFilesPrompt: result = build_context_files_prompt(cwd=str(tmp_path)) assert "Lowercase claude rules" in result + @pytest.mark.skipif( + sys.platform == "darwin", + reason="APFS default volume is case-insensitive; CLAUDE.md and claude.md alias the same path", + ) def test_claude_md_uppercase_takes_priority(self, tmp_path): uppercase = tmp_path / "CLAUDE.md" lowercase = tmp_path / "claude.md" @@ -868,6 +882,13 @@ class TestSkillShouldShow: class TestBuildSkillsSystemPromptConditional: + @pytest.fixture(autouse=True) + def _clear_skills_cache(self): + from agent.prompt_builder import clear_skills_system_prompt_cache + clear_skills_system_prompt_cache(clear_snapshot=True) + yield + clear_skills_system_prompt_cache(clear_snapshot=True) + def test_fallback_skill_hidden_when_primary_available(self, monkeypatch, tmp_path): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) skill_dir = tmp_path / "skills" / "search" / "duckduckgo" @@ -972,3 +993,98 @@ class TestBuildSkillsSystemPromptConditional: available_toolsets=set(), ) assert "nested-null" in result + + +# ========================================================================= +# Tool-use enforcement guidance +# ========================================================================= + + +class TestToolUseEnforcementGuidance: + def test_guidance_mentions_tool_calls(self): + assert "tool call" in TOOL_USE_ENFORCEMENT_GUIDANCE.lower() + + def test_guidance_forbids_description_only(self): + assert "describe" in TOOL_USE_ENFORCEMENT_GUIDANCE.lower() + assert "promise" in TOOL_USE_ENFORCEMENT_GUIDANCE.lower() + + def test_guidance_requires_action(self): + assert "MUST" in TOOL_USE_ENFORCEMENT_GUIDANCE + + def test_enforcement_models_includes_gpt(self): + assert "gpt" in TOOL_USE_ENFORCEMENT_MODELS + + def test_enforcement_models_includes_codex(self): + assert "codex" in TOOL_USE_ENFORCEMENT_MODELS + + def test_enforcement_models_is_tuple(self): + assert isinstance(TOOL_USE_ENFORCEMENT_MODELS, tuple) + + +# ========================================================================= +# Budget warning history stripping +# ========================================================================= + + +class TestStripBudgetWarningsFromHistory: + def test_strips_json_budget_warning_key(self): + import json + from run_agent import _strip_budget_warnings_from_history + + messages = [ + {"role": "tool", "tool_call_id": "c1", "content": json.dumps({ + "output": "hello", + "exit_code": 0, + "_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]", + })}, + ] + _strip_budget_warnings_from_history(messages) + parsed = json.loads(messages[0]["content"]) + assert "_budget_warning" not in parsed + assert parsed["output"] == "hello" + assert parsed["exit_code"] == 0 + + def test_strips_text_budget_warning(self): + from run_agent import _strip_budget_warnings_from_history + + messages = [ + {"role": "tool", "tool_call_id": "c1", + "content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"}, + ] + _strip_budget_warnings_from_history(messages) + assert messages[0]["content"] == "some result" + + def test_leaves_non_tool_messages_unchanged(self): + from run_agent import _strip_budget_warnings_from_history + + messages = [ + {"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"}, + {"role": "user", "content": "hello"}, + ] + original_contents = [m["content"] for m in messages] + _strip_budget_warnings_from_history(messages) + assert [m["content"] for m in messages] == original_contents + + def test_handles_empty_and_missing_content(self): + from run_agent import _strip_budget_warnings_from_history + + messages = [ + {"role": "tool", "tool_call_id": "c1", "content": ""}, + {"role": "tool", "tool_call_id": "c2"}, + ] + _strip_budget_warnings_from_history(messages) + assert messages[0]["content"] == "" + + def test_strips_caution_variant(self): + import json + from run_agent import _strip_budget_warnings_from_history + + messages = [ + {"role": "tool", "tool_call_id": "c1", "content": json.dumps({ + "output": "ok", + "_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]", + })}, + ] + _strip_budget_warnings_from_history(messages) + parsed = json.loads(messages[0]["content"]) + assert "_budget_warning" not in parsed diff --git a/tests/agent/test_skill_commands.py b/tests/agent/test_skill_commands.py index f6a114db6..6b3e551e1 100644 --- a/tests/agent/test_skill_commands.py +++ b/tests/agent/test_skill_commands.py @@ -54,7 +54,7 @@ class TestScanSkillCommands: """macOS-only skills should not register slash commands on Linux.""" with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): mock_sys.platform = "linux" _make_skill(tmp_path, "imessage", frontmatter_extra="platforms: [macos]\n") @@ -67,7 +67,7 @@ class TestScanSkillCommands: """macOS-only skills should register slash commands on macOS.""" with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): mock_sys.platform = "darwin" _make_skill(tmp_path, "imessage", frontmatter_extra="platforms: [macos]\n") @@ -78,7 +78,7 @@ class TestScanSkillCommands: """Skills without platforms field should register on any platform.""" with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): mock_sys.platform = "win32" _make_skill(tmp_path, "generic-tool") @@ -246,20 +246,10 @@ Generate some audio. def test_preserves_remaining_remote_setup_warning(self, tmp_path, monkeypatch): monkeypatch.setenv("TERMINAL_ENV", "ssh") monkeypatch.delenv("TENOR_API_KEY", raising=False) - - def fake_secret_callback(var_name, prompt, metadata=None): - os.environ[var_name] = "stored-in-test" - return { - "success": True, - "stored_as": var_name, - "validated": False, - "skipped": False, - } - monkeypatch.setattr( skills_tool_module, "_secret_capture_callback", - fake_secret_callback, + None, raising=False, ) diff --git a/tests/cron/test_jobs.py b/tests/cron/test_jobs.py index 71883d158..cca460100 100644 --- a/tests/cron/test_jobs.py +++ b/tests/cron/test_jobs.py @@ -20,6 +20,7 @@ from cron.jobs import ( resume_job, remove_job, mark_job_run, + advance_next_run, get_due_jobs, save_job_output, ) @@ -339,6 +340,90 @@ class TestMarkJobRun: assert updated["last_error"] == "timeout" +class TestAdvanceNextRun: + """Tests for advance_next_run() — crash-safety for recurring jobs.""" + + def test_advances_interval_job(self, tmp_cron_dir): + """Interval jobs should have next_run_at bumped to the next future occurrence.""" + job = create_job(prompt="Recurring check", schedule="every 1h") + # Force next_run_at to 5 minutes ago (i.e. the job is due) + jobs = load_jobs() + old_next = (datetime.now() - timedelta(minutes=5)).isoformat() + jobs[0]["next_run_at"] = old_next + save_jobs(jobs) + + result = advance_next_run(job["id"]) + assert result is True + + updated = get_job(job["id"]) + from cron.jobs import _ensure_aware, _hermes_now + new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"])) + assert new_next_dt > _hermes_now(), "next_run_at should be in the future after advance" + + def test_advances_cron_job(self, tmp_cron_dir): + """Cron-expression jobs should have next_run_at bumped to the next occurrence.""" + pytest.importorskip("croniter") + job = create_job(prompt="Daily wakeup", schedule="15 6 * * *") + # Force next_run_at to 30 minutes ago + jobs = load_jobs() + old_next = (datetime.now() - timedelta(minutes=30)).isoformat() + jobs[0]["next_run_at"] = old_next + save_jobs(jobs) + + result = advance_next_run(job["id"]) + assert result is True + + updated = get_job(job["id"]) + from cron.jobs import _ensure_aware, _hermes_now + new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"])) + assert new_next_dt > _hermes_now(), "next_run_at should be in the future after advance" + + def test_skips_oneshot_job(self, tmp_cron_dir): + """One-shot jobs should NOT be advanced — they need to retry on restart.""" + job = create_job(prompt="Run once", schedule="30m") + original_next = get_job(job["id"])["next_run_at"] + + result = advance_next_run(job["id"]) + assert result is False + + updated = get_job(job["id"]) + assert updated["next_run_at"] == original_next, "one-shot next_run_at should be unchanged" + + def test_nonexistent_job_returns_false(self, tmp_cron_dir): + result = advance_next_run("nonexistent-id") + assert result is False + + def test_already_future_stays_future(self, tmp_cron_dir): + """If next_run_at is already in the future, advance keeps it in the future (no harm).""" + job = create_job(prompt="Future job", schedule="every 1h") + # next_run_at is already set to ~1h from now by create_job + advance_next_run(job["id"]) + # Regardless of return value, the job should still be in the future + updated = get_job(job["id"]) + from cron.jobs import _ensure_aware, _hermes_now + new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"])) + assert new_next_dt > _hermes_now(), "next_run_at should remain in the future" + + def test_crash_safety_scenario(self, tmp_cron_dir): + """Simulate the crash-loop scenario: after advance, the job should NOT be due.""" + job = create_job(prompt="Crash test", schedule="every 1h") + # Force next_run_at to 5 minutes ago (job is due) + jobs = load_jobs() + jobs[0]["next_run_at"] = (datetime.now() - timedelta(minutes=5)).isoformat() + save_jobs(jobs) + + # Job should be due before advance + due_before = get_due_jobs() + assert len(due_before) == 1 + + # Advance (simulating what tick() does before run_job) + advance_next_run(job["id"]) + + # Now the job should NOT be due (simulates restart after crash) + due_after = get_due_jobs() + assert len(due_after) == 0, "Job should not be due after advance_next_run" + + class TestGetDueJobs: def test_past_due_within_window_returned(self, tmp_cron_dir): """Jobs within the dynamic grace window are still considered due (not stale). diff --git a/tests/cron/test_scheduler.py b/tests/cron/test_scheduler.py index c5f244a1f..afec21ce7 100644 --- a/tests/cron/test_scheduler.py +++ b/tests/cron/test_scheduler.py @@ -84,6 +84,48 @@ class TestResolveDeliveryTarget: "thread_id": None, } + def test_human_friendly_label_resolved_via_channel_directory(self): + """deliver: 'whatsapp:Alice (dm)' resolves to the real JID.""" + job = {"deliver": "whatsapp:Alice (dm)"} + with patch( + "gateway.channel_directory.resolve_channel_name", + return_value="12345678901234@lid", + ): + result = _resolve_delivery_target(job) + assert result == { + "platform": "whatsapp", + "chat_id": "12345678901234@lid", + "thread_id": None, + } + + def test_human_friendly_label_without_suffix_resolved(self): + """deliver: 'telegram:My Group' resolves without display suffix.""" + job = {"deliver": "telegram:My Group"} + with patch( + "gateway.channel_directory.resolve_channel_name", + return_value="-1009999", + ): + result = _resolve_delivery_target(job) + assert result == { + "platform": "telegram", + "chat_id": "-1009999", + "thread_id": None, + } + + def test_raw_id_not_mangled_when_directory_returns_none(self): + """deliver: 'whatsapp:12345@lid' passes through when directory has no match.""" + job = {"deliver": "whatsapp:12345@lid"} + with patch( + "gateway.channel_directory.resolve_channel_name", + return_value=None, + ): + result = _resolve_delivery_target(job) + assert result == { + "platform": "whatsapp", + "chat_id": "12345@lid", + "thread_id": None, + } + def test_bare_platform_uses_matching_origin_chat(self): job = { "deliver": "telegram", @@ -167,6 +209,32 @@ class TestDeliverResultWrapping: sent_content = send_mock.call_args.kwargs.get("content") or send_mock.call_args[0][-1] assert "Cronjob Response: abc-123" in sent_content + def test_delivery_skips_wrapping_when_config_disabled(self): + """When cron.wrap_response is false, deliver raw content without header/footer.""" + from gateway.config import Platform + + pconfig = MagicMock() + pconfig.enabled = True + mock_cfg = MagicMock() + mock_cfg.platforms = {Platform.TELEGRAM: pconfig} + + with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \ + patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})) as send_mock, \ + patch("cron.scheduler.load_config", return_value={"cron": {"wrap_response": False}}): + job = { + "id": "test-job", + "name": "daily-report", + "deliver": "origin", + "origin": {"platform": "telegram", "chat_id": "123"}, + } + _deliver_result(job, "Clean output only.") + + send_mock.assert_called_once() + sent_content = send_mock.call_args.kwargs.get("content") or send_mock.call_args[0][-1] + assert sent_content == "Clean output only." + assert "Cronjob Response" not in sent_content + assert "The agent cannot see" not in sent_content + def test_no_mirror_to_session_call(self): """Cron deliveries should NOT mirror into the gateway session.""" from gateway.config import Platform @@ -687,3 +755,41 @@ class TestBuildJobPromptMissingSkill: result = _build_job_prompt({"skills": ["ghost-skill", "real-skill"], "prompt": "go"}) assert "Real skill content." in result assert "go" in result + + +class TestTickAdvanceBeforeRun: + """Verify that tick() calls advance_next_run before run_job for crash safety.""" + + def test_advance_called_before_run_job(self, tmp_path): + """advance_next_run must be called before run_job to prevent crash-loop re-fires.""" + call_order = [] + + def fake_advance(job_id): + call_order.append(("advance", job_id)) + return True + + def fake_run_job(job): + call_order.append(("run", job["id"])) + return True, "output", "response", None + + fake_job = { + "id": "test-advance", + "name": "test", + "prompt": "hello", + "enabled": True, + "schedule": {"kind": "cron", "expr": "15 6 * * *"}, + } + + with patch("cron.scheduler.get_due_jobs", return_value=[fake_job]), \ + patch("cron.scheduler.advance_next_run", side_effect=fake_advance) as adv_mock, \ + patch("cron.scheduler.run_job", side_effect=fake_run_job), \ + patch("cron.scheduler.save_job_output", return_value=tmp_path / "out.md"), \ + patch("cron.scheduler.mark_job_run"), \ + patch("cron.scheduler._deliver_result"): + from cron.scheduler import tick + executed = tick(verbose=False) + + assert executed == 1 + adv_mock.assert_called_once_with("test-advance") + # advance must happen before run + assert call_order == [("advance", "test-advance"), ("run", "test-advance")] diff --git a/tests/gateway/test_allowlist_startup_check.py b/tests/gateway/test_allowlist_startup_check.py new file mode 100644 index 000000000..96441c052 --- /dev/null +++ b/tests/gateway/test_allowlist_startup_check.py @@ -0,0 +1,46 @@ +"""Tests for the startup allowlist warning check in gateway/run.py.""" + +import os +from unittest.mock import patch + + +def _would_warn(): + """Replicate the startup allowlist warning logic. Returns True if warning fires.""" + _any_allowlist = any( + os.getenv(v) + for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS", + "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", + "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS", + "EMAIL_ALLOWED_USERS", + "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS", + "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", "FEISHU_ALLOWED_USERS", "WECOM_ALLOWED_USERS", + "GATEWAY_ALLOWED_USERS") + ) + _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( + os.getenv(v, "").lower() in ("true", "1", "yes") + for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS", + "WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS", + "SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS", + "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", + "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", "FEISHU_ALLOW_ALL_USERS", "WECOM_ALLOW_ALL_USERS") + ) + return not _any_allowlist and not _allow_all + + +class TestAllowlistStartupCheck: + + def test_no_config_emits_warning(self): + with patch.dict(os.environ, {}, clear=True): + assert _would_warn() is True + + def test_signal_group_allowed_users_suppresses_warning(self): + with patch.dict(os.environ, {"SIGNAL_GROUP_ALLOWED_USERS": "user1"}, clear=True): + assert _would_warn() is False + + def test_telegram_allow_all_users_suppresses_warning(self): + with patch.dict(os.environ, {"TELEGRAM_ALLOW_ALL_USERS": "true"}, clear=True): + assert _would_warn() is False + + def test_gateway_allow_all_users_suppresses_warning(self): + with patch.dict(os.environ, {"GATEWAY_ALLOW_ALL_USERS": "yes"}, clear=True): + assert _would_warn() is False diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 2ee928484..772dd8b1c 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -28,6 +28,7 @@ from gateway.platforms.api_server import ( _CORS_HEADERS, check_api_server_requirements, cors_middleware, + security_headers_middleware, ) @@ -214,9 +215,11 @@ def _make_adapter(api_key: str = "", cors_origins=None) -> APIServerAdapter: def _create_app(adapter: APIServerAdapter) -> web.Application: """Create the aiohttp app from the adapter (without starting the full server).""" - app = web.Application(middlewares=[cors_middleware]) + mws = [mw for mw in (cors_middleware, security_headers_middleware) if mw is not None] + app = web.Application(middlewares=mws) app["api_server_adapter"] = adapter app.router.add_get("/health", adapter._handle_health) + app.router.add_get("/v1/health", adapter._handle_health) app.router.add_get("/v1/models", adapter._handle_models) app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions) app.router.add_post("/v1/responses", adapter._handle_responses) @@ -241,6 +244,16 @@ def auth_adapter(): class TestHealthEndpoint: + @pytest.mark.asyncio + async def test_security_headers_present(self, adapter): + """Responses should include basic security headers.""" + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/health") + assert resp.status == 200 + assert resp.headers.get("X-Content-Type-Options") == "nosniff" + assert resp.headers.get("Referrer-Policy") == "no-referrer" + @pytest.mark.asyncio async def test_health_returns_ok(self, adapter): app = _create_app(adapter) @@ -251,6 +264,17 @@ class TestHealthEndpoint: assert data["status"] == "ok" assert data["platform"] == "hermes-agent" + @pytest.mark.asyncio + async def test_v1_health_alias_returns_ok(self, adapter): + """GET /v1/health should return the same response as /health.""" + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/health") + assert resp.status == 200 + data = await resp.json() + assert data["status"] == "ok" + assert data["platform"] == "hermes-agent" + # --------------------------------------------------------------------------- # /v1/models endpoint @@ -1300,6 +1324,31 @@ class TestCORS: assert "POST" in resp.headers.get("Access-Control-Allow-Methods", "") assert "DELETE" in resp.headers.get("Access-Control-Allow-Methods", "") + @pytest.mark.asyncio + async def test_cors_allows_idempotency_key_header(self): + adapter = _make_adapter(cors_origins=["http://localhost:3000"]) + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.options( + "/v1/chat/completions", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Idempotency-Key", + }, + ) + assert resp.status == 200 + assert "Idempotency-Key" in resp.headers.get("Access-Control-Allow-Headers", "") + + @pytest.mark.asyncio + async def test_cors_sets_vary_origin_header(self): + adapter = _make_adapter(cors_origins=["http://localhost:3000"]) + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/health", headers={"Origin": "http://localhost:3000"}) + assert resp.status == 200 + assert resp.headers.get("Vary") == "Origin" + @pytest.mark.asyncio async def test_cors_options_preflight_allowed_for_configured_origin(self): """Configured origins can complete browser preflight.""" @@ -1319,6 +1368,21 @@ class TestCORS: assert "Authorization" in resp.headers.get("Access-Control-Allow-Headers", "") + @pytest.mark.asyncio + async def test_cors_preflight_sets_max_age(self): + adapter = _make_adapter(cors_origins=["http://localhost:3000"]) + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.options( + "/v1/chat/completions", + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization, Content-Type", + }, + ) + assert resp.status == 200 + assert resp.headers.get("Access-Control-Max-Age") == "600" # --------------------------------------------------------------------------- # Conversation parameter # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_api_server_toolset.py b/tests/gateway/test_api_server_toolset.py new file mode 100644 index 000000000..3b4ff254d --- /dev/null +++ b/tests/gateway/test_api_server_toolset.py @@ -0,0 +1,129 @@ +"""Tests for hermes-api-server toolset and API server tool availability.""" +import os +import json +from unittest.mock import patch, MagicMock + +import pytest + +from toolsets import resolve_toolset, get_toolset, validate_toolset + + +class TestHermesApiServerToolset: + """Tests for the hermes-api-server toolset definition.""" + + def test_toolset_exists(self): + ts = get_toolset("hermes-api-server") + assert ts is not None + + def test_toolset_validates(self): + assert validate_toolset("hermes-api-server") + + def test_toolset_includes_web_tools(self): + tools = resolve_toolset("hermes-api-server") + assert "web_search" in tools + assert "web_extract" in tools + + def test_toolset_includes_core_tools(self): + tools = resolve_toolset("hermes-api-server") + expected = [ + "terminal", "process", + "read_file", "write_file", "patch", "search_files", + "vision_analyze", "image_generate", + "execute_code", "delegate_task", + "todo", "memory", "session_search", "cronjob", + ] + for tool in expected: + assert tool in tools, f"Missing expected tool: {tool}" + + def test_toolset_includes_browser_tools(self): + tools = resolve_toolset("hermes-api-server") + for tool in ["browser_navigate", "browser_snapshot", "browser_click", + "browser_type", "browser_scroll", "browser_back", + "browser_press", "browser_close"]: + assert tool in tools, f"Missing browser tool: {tool}" + + def test_toolset_includes_homeassistant_tools(self): + tools = resolve_toolset("hermes-api-server") + for tool in ["ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service"]: + assert tool in tools, f"Missing HA tool: {tool}" + + def test_toolset_excludes_clarify(self): + tools = resolve_toolset("hermes-api-server") + assert "clarify" not in tools + + def test_toolset_excludes_send_message(self): + tools = resolve_toolset("hermes-api-server") + assert "send_message" not in tools + + def test_toolset_excludes_text_to_speech(self): + tools = resolve_toolset("hermes-api-server") + assert "text_to_speech" not in tools + + +class TestApiServerPlatformConfig: + def test_platforms_dict_includes_api_server(self): + from hermes_cli.tools_config import PLATFORMS + assert "api_server" in PLATFORMS + assert PLATFORMS["api_server"]["default_toolset"] == "hermes-api-server" + + +class TestApiServerAdapterToolset: + @patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True) + def test_create_agent_reads_config_toolsets(self): + """API server resolves toolsets from config like all other platforms.""" + from gateway.platforms.api_server import APIServerAdapter + from gateway.config import PlatformConfig + + adapter = APIServerAdapter(PlatformConfig()) + + with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \ + patch("gateway.run._resolve_gateway_model") as mock_model, \ + patch("gateway.run._load_gateway_config") as mock_config, \ + patch("run_agent.AIAgent") as mock_agent_cls: + + mock_kwargs.return_value = {"api_key": "test-key", "base_url": None, + "provider": None, "api_mode": None, + "command": None, "args": []} + mock_model.return_value = "test/model" + # No platform_toolsets override — should fall back to hermes-api-server default + mock_config.return_value = {} + mock_agent_cls.return_value = MagicMock() + + adapter._create_agent() + + mock_agent_cls.assert_called_once() + call_kwargs = mock_agent_cls.call_args + toolsets = call_kwargs.kwargs.get("enabled_toolsets") + assert isinstance(toolsets, list) + assert len(toolsets) > 0 + assert call_kwargs.kwargs.get("platform") == "api_server" + + @patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True) + def test_create_agent_respects_config_override(self): + """User can override API server toolsets via platform_toolsets in config.yaml.""" + from gateway.platforms.api_server import APIServerAdapter + from gateway.config import PlatformConfig + + adapter = APIServerAdapter(PlatformConfig()) + + with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \ + patch("gateway.run._resolve_gateway_model") as mock_model, \ + patch("gateway.run._load_gateway_config") as mock_config, \ + patch("run_agent.AIAgent") as mock_agent_cls: + + mock_kwargs.return_value = {"api_key": "test-key", "base_url": None, + "provider": None, "api_mode": None, + "command": None, "args": []} + mock_model.return_value = "test/model" + # User overrides with just web and terminal + mock_config.return_value = { + "platform_toolsets": {"api_server": ["web", "terminal"]} + } + mock_agent_cls.return_value = MagicMock() + + adapter._create_agent() + + mock_agent_cls.assert_called_once() + call_kwargs = mock_agent_cls.call_args + toolsets = call_kwargs.kwargs.get("enabled_toolsets") + assert sorted(toolsets) == ["terminal", "web"] diff --git a/tests/gateway/test_config.py b/tests/gateway/test_config.py index 8dbb725d8..8f24faa99 100644 --- a/tests/gateway/test_config.py +++ b/tests/gateway/test_config.py @@ -1,11 +1,15 @@ """Tests for gateway configuration management.""" +import os +from unittest.mock import patch + from gateway.config import ( GatewayConfig, HomeChannel, Platform, PlatformConfig, SessionResetPolicy, + _apply_env_overrides, load_gateway_config, ) @@ -192,3 +196,75 @@ class TestLoadGatewayConfig: assert config.unauthorized_dm_behavior == "ignore" assert config.platforms[Platform.WHATSAPP].extra["unauthorized_dm_behavior"] == "pair" + + +class TestHomeChannelEnvOverrides: + """Home channel env vars should apply even when the platform was already + configured via config.yaml (not just when credential env vars create it).""" + + def test_existing_platform_configs_accept_home_channel_env_overrides(self): + cases = [ + ( + Platform.SLACK, + PlatformConfig(enabled=True, token="xoxb-from-config"), + {"SLACK_HOME_CHANNEL": "C123", "SLACK_HOME_CHANNEL_NAME": "Ops"}, + ("C123", "Ops"), + ), + ( + Platform.SIGNAL, + PlatformConfig( + enabled=True, + extra={"http_url": "http://localhost:9090", "account": "+15551234567"}, + ), + {"SIGNAL_HOME_CHANNEL": "+1555000", "SIGNAL_HOME_CHANNEL_NAME": "Phone"}, + ("+1555000", "Phone"), + ), + ( + Platform.MATTERMOST, + PlatformConfig( + enabled=True, + token="mm-token", + extra={"url": "https://mm.example.com"}, + ), + {"MATTERMOST_HOME_CHANNEL": "ch_abc123", "MATTERMOST_HOME_CHANNEL_NAME": "General"}, + ("ch_abc123", "General"), + ), + ( + Platform.MATRIX, + PlatformConfig( + enabled=True, + token="syt_abc123", + extra={"homeserver": "https://matrix.example.org"}, + ), + {"MATRIX_HOME_ROOM": "!room123:example.org", "MATRIX_HOME_ROOM_NAME": "Bot Room"}, + ("!room123:example.org", "Bot Room"), + ), + ( + Platform.EMAIL, + PlatformConfig( + enabled=True, + extra={ + "address": "hermes@test.com", + "imap_host": "imap.test.com", + "smtp_host": "smtp.test.com", + }, + ), + {"EMAIL_HOME_ADDRESS": "user@test.com", "EMAIL_HOME_ADDRESS_NAME": "Inbox"}, + ("user@test.com", "Inbox"), + ), + ( + Platform.SMS, + PlatformConfig(enabled=True, api_key="token_abc"), + {"SMS_HOME_CHANNEL": "+15559876543", "SMS_HOME_CHANNEL_NAME": "My Phone"}, + ("+15559876543", "My Phone"), + ), + ] + + for platform, platform_config, env, expected in cases: + config = GatewayConfig(platforms={platform: platform_config}) + with patch.dict(os.environ, env, clear=True): + _apply_env_overrides(config) + + home = config.platforms[platform].home_channel + assert home is not None, f"{platform.value}: home_channel should not be None" + assert (home.chat_id, home.name) == expected, platform.value diff --git a/tests/gateway/test_dm_topics.py b/tests/gateway/test_dm_topics.py index 98c6d4c06..e71d3f82c 100644 --- a/tests/gateway/test_dm_topics.py +++ b/tests/gateway/test_dm_topics.py @@ -10,6 +10,7 @@ Covers: """ import asyncio +import os import sys from pathlib import Path from types import SimpleNamespace @@ -32,7 +33,7 @@ def _ensure_telegram_mock(): telegram_mod.constants.ChatType.CHANNEL = "channel" telegram_mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, telegram_mod) @@ -227,7 +228,8 @@ def test_persist_dm_topic_thread_id_writes_config(tmp_path): adapter = _make_adapter() - with patch.object(Path, "home", return_value=tmp_path): + with patch.object(Path, "home", return_value=tmp_path), \ + patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}): adapter._persist_dm_topic_thread_id(111, "General", 999) with open(config_file) as f: @@ -366,7 +368,8 @@ def test_get_dm_topic_info_hot_reloads_from_config(tmp_path): with open(config_file, "w") as f: yaml.dump(config_data, f) - with patch.object(Path, "home", return_value=tmp_path): + with patch.object(Path, "home", return_value=tmp_path), \ + patch.dict(os.environ, {"HERMES_HOME": str(tmp_path / ".hermes")}): result = adapter._get_dm_topic_info("111", "555") assert result is not None diff --git a/tests/gateway/test_email.py b/tests/gateway/test_email.py index 16a418da8..b6da07921 100644 --- a/tests/gateway/test_email.py +++ b/tests/gateway/test_email.py @@ -1057,5 +1057,122 @@ class TestSendEmailStandalone(unittest.TestCase): self.assertIn("not configured", result["error"]) +class TestSmtpConnectionCleanup(unittest.TestCase): + """Verify SMTP connections are closed even when send_message raises.""" + + @patch.dict(os.environ, { + "EMAIL_ADDRESS": "hermes@test.com", + "EMAIL_PASSWORD": "secret", + "EMAIL_IMAP_HOST": "imap.test.com", + "EMAIL_SMTP_HOST": "smtp.test.com", + "EMAIL_SMTP_PORT": "587", + }, clear=False) + def _make_adapter(self): + from gateway.config import PlatformConfig + from gateway.platforms.email import EmailAdapter + return EmailAdapter(PlatformConfig(enabled=True)) + + @patch.dict(os.environ, { + "EMAIL_ADDRESS": "hermes@test.com", + "EMAIL_PASSWORD": "secret", + "EMAIL_IMAP_HOST": "imap.test.com", + "EMAIL_SMTP_HOST": "smtp.test.com", + "EMAIL_SMTP_PORT": "587", + }, clear=False) + def test_smtp_quit_called_on_send_message_failure(self): + """SMTP quit() must be called even when send_message() raises.""" + adapter = self._make_adapter() + mock_smtp = MagicMock() + mock_smtp.send_message.side_effect = Exception("send failed") + + with patch("smtplib.SMTP", return_value=mock_smtp): + with self.assertRaises(Exception): + adapter._send_email("user@test.com", "Hello") + + mock_smtp.quit.assert_called_once() + + @patch.dict(os.environ, { + "EMAIL_ADDRESS": "hermes@test.com", + "EMAIL_PASSWORD": "secret", + "EMAIL_IMAP_HOST": "imap.test.com", + "EMAIL_SMTP_HOST": "smtp.test.com", + "EMAIL_SMTP_PORT": "587", + }, clear=False) + def test_smtp_close_called_when_quit_also_fails(self): + """If both send_message() and quit() fail, close() is the fallback.""" + adapter = self._make_adapter() + mock_smtp = MagicMock() + mock_smtp.send_message.side_effect = Exception("send failed") + mock_smtp.quit.side_effect = Exception("quit failed") + + with patch("smtplib.SMTP", return_value=mock_smtp): + with self.assertRaises(Exception): + adapter._send_email("user@test.com", "Hello") + + mock_smtp.close.assert_called_once() + + +class TestImapConnectionCleanup(unittest.TestCase): + """Verify IMAP connections are closed even when fetch raises.""" + + @patch.dict(os.environ, { + "EMAIL_ADDRESS": "hermes@test.com", + "EMAIL_PASSWORD": "secret", + "EMAIL_IMAP_HOST": "imap.test.com", + "EMAIL_IMAP_PORT": "993", + "EMAIL_SMTP_HOST": "smtp.test.com", + }, clear=False) + def _make_adapter(self): + from gateway.config import PlatformConfig + from gateway.platforms.email import EmailAdapter + return EmailAdapter(PlatformConfig(enabled=True)) + + @patch.dict(os.environ, { + "EMAIL_ADDRESS": "hermes@test.com", + "EMAIL_PASSWORD": "secret", + "EMAIL_IMAP_HOST": "imap.test.com", + "EMAIL_IMAP_PORT": "993", + "EMAIL_SMTP_HOST": "smtp.test.com", + }, clear=False) + def test_imap_logout_called_on_uid_fetch_failure(self): + """IMAP logout() must be called even when uid fetch raises.""" + adapter = self._make_adapter() + mock_imap = MagicMock() + + def uid_handler(command, *args): + if command == "search": + return ("OK", [b"1"]) + if command == "fetch": + raise Exception("fetch failed") + return ("NO", []) + + mock_imap.uid.side_effect = uid_handler + + with patch("imaplib.IMAP4_SSL", return_value=mock_imap): + results = adapter._fetch_new_messages() + + self.assertEqual(results, []) + mock_imap.logout.assert_called_once() + + @patch.dict(os.environ, { + "EMAIL_ADDRESS": "hermes@test.com", + "EMAIL_PASSWORD": "secret", + "EMAIL_IMAP_HOST": "imap.test.com", + "EMAIL_IMAP_PORT": "993", + "EMAIL_SMTP_HOST": "smtp.test.com", + }, clear=False) + def test_imap_logout_called_on_early_return(self): + """IMAP logout() must be called even when returning early (no unseen).""" + adapter = self._make_adapter() + mock_imap = MagicMock() + mock_imap.uid.return_value = ("OK", [b""]) + + with patch("imaplib.IMAP4_SSL", return_value=mock_imap): + results = adapter._fetch_new_messages() + + self.assertEqual(results, []) + mock_imap.logout.assert_called_once() + + if __name__ == "__main__": unittest.main() diff --git a/tests/gateway/test_feishu.py b/tests/gateway/test_feishu.py new file mode 100644 index 000000000..5344cda52 --- /dev/null +++ b/tests/gateway/test_feishu.py @@ -0,0 +1,2580 @@ +"""Tests for the Feishu gateway integration.""" + +import asyncio +import json +import os +import tempfile +import time +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +try: + import lark_oapi + _HAS_LARK_OAPI = True +except ImportError: + _HAS_LARK_OAPI = False + + +class TestPlatformEnum(unittest.TestCase): + def test_feishu_in_platform_enum(self): + from gateway.config import Platform + + self.assertEqual(Platform.FEISHU.value, "feishu") + + +class TestConfigEnvOverrides(unittest.TestCase): + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_xxx", + "FEISHU_APP_SECRET": "secret_xxx", + "FEISHU_CONNECTION_MODE": "websocket", + "FEISHU_DOMAIN": "feishu", + }, clear=False) + def test_feishu_config_loaded_from_env(self): + from gateway.config import GatewayConfig, Platform, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + + self.assertIn(Platform.FEISHU, config.platforms) + self.assertTrue(config.platforms[Platform.FEISHU].enabled) + self.assertEqual(config.platforms[Platform.FEISHU].extra["app_id"], "cli_xxx") + self.assertEqual(config.platforms[Platform.FEISHU].extra["connection_mode"], "websocket") + + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_xxx", + "FEISHU_APP_SECRET": "secret_xxx", + "FEISHU_HOME_CHANNEL": "oc_xxx", + }, clear=False) + def test_feishu_home_channel_loaded(self): + from gateway.config import GatewayConfig, Platform, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + + home = config.platforms[Platform.FEISHU].home_channel + self.assertIsNotNone(home) + self.assertEqual(home.chat_id, "oc_xxx") + + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_xxx", + "FEISHU_APP_SECRET": "secret_xxx", + }, clear=False) + def test_feishu_in_connected_platforms(self): + from gateway.config import GatewayConfig, Platform, _apply_env_overrides + + config = GatewayConfig() + _apply_env_overrides(config) + + self.assertIn(Platform.FEISHU, config.get_connected_platforms()) + + +class TestGatewayIntegration(unittest.TestCase): + def test_feishu_in_adapter_factory(self): + source = Path("gateway/run.py").read_text(encoding="utf-8") + self.assertIn("Platform.FEISHU", source) + self.assertIn("FeishuAdapter", source) + + def test_feishu_in_authorization_maps(self): + source = Path("gateway/run.py").read_text(encoding="utf-8") + self.assertIn("FEISHU_ALLOWED_USERS", source) + self.assertIn("FEISHU_ALLOW_ALL_USERS", source) + + def test_feishu_toolset_exists(self): + from toolsets import TOOLSETS + + self.assertIn("hermes-feishu", TOOLSETS) + self.assertIn("hermes-feishu", TOOLSETS["hermes-gateway"]["includes"]) + + +class TestFeishuPostParsing(unittest.TestCase): + def test_parse_post_content_extracts_text_mentions_and_media_refs(self): + from gateway.platforms.feishu import parse_feishu_post_content + + result = parse_feishu_post_content( + json.dumps( + { + "en_us": { + "title": "Rich message", + "content": [ + [{"tag": "img", "image_key": "img_1", "alt": "diagram"}], + [{"tag": "at", "user_name": "Alice", "open_id": "ou_alice"}], + [{"tag": "media", "file_key": "file_1", "file_name": "spec.pdf"}], + ], + } + } + ) + ) + + self.assertEqual(result.text_content, "Rich message\n[Image: diagram]\n@Alice\n[Attachment: spec.pdf]") + self.assertEqual(result.image_keys, ["img_1"]) + self.assertEqual(result.mentioned_ids, ["ou_alice"]) + self.assertEqual(len(result.media_refs), 1) + self.assertEqual(result.media_refs[0].file_key, "file_1") + self.assertEqual(result.media_refs[0].file_name, "spec.pdf") + self.assertEqual(result.media_refs[0].resource_type, "file") + + def test_parse_post_content_uses_fallback_when_invalid(self): + from gateway.platforms.feishu import FALLBACK_POST_TEXT, parse_feishu_post_content + + result = parse_feishu_post_content("not-json") + + self.assertEqual(result.text_content, FALLBACK_POST_TEXT) + self.assertEqual(result.image_keys, []) + self.assertEqual(result.media_refs, []) + self.assertEqual(result.mentioned_ids, []) + + def test_parse_post_content_preserves_rich_text_semantics(self): + from gateway.platforms.feishu import parse_feishu_post_content + + result = parse_feishu_post_content( + json.dumps( + { + "en_us": { + "title": "Plan *v2*", + "content": [ + [ + {"tag": "text", "text": "Bold", "style": {"bold": True}}, + {"tag": "text", "text": " "}, + {"tag": "text", "text": "Italic", "style": {"italic": True}}, + {"tag": "text", "text": " "}, + {"tag": "text", "text": "Code", "style": {"code": True}}, + ], + [{"tag": "text", "text": "line1"}, {"tag": "br"}, {"tag": "text", "text": "line2"}], + [{"tag": "hr"}], + [{"tag": "code_block", "language": "python", "text": "print('hi')"}], + ], + } + } + ) + ) + + self.assertEqual( + result.text_content, + "Plan *v2*\n**Bold** *Italic* `Code`\nline1\nline2\n---\n```python\nprint('hi')\n```", + ) + + +class TestFeishuMessageNormalization(unittest.TestCase): + def test_normalize_merge_forward_preserves_summary_lines(self): + from gateway.platforms.feishu import normalize_feishu_message + + normalized = normalize_feishu_message( + message_type="merge_forward", + raw_content=json.dumps( + { + "title": "Sprint recap", + "messages": [ + {"sender_name": "Alice", "text": "Please review PR-128"}, + { + "sender_name": "Bob", + "message_type": "post", + "content": { + "en_us": { + "content": [[{"tag": "text", "text": "Ship it"}]], + } + }, + }, + ], + } + ), + ) + + self.assertEqual(normalized.relation_kind, "merge_forward") + self.assertEqual( + normalized.text_content, + "Sprint recap\n- Alice: Please review PR-128\n- Bob: Ship it", + ) + + def test_normalize_share_chat_exposes_summary_and_metadata(self): + from gateway.platforms.feishu import normalize_feishu_message + + normalized = normalize_feishu_message( + message_type="share_chat", + raw_content=json.dumps( + { + "chat_id": "oc_chat_shared", + "chat_name": "Backend Guild", + } + ), + ) + + self.assertEqual(normalized.relation_kind, "share_chat") + self.assertEqual(normalized.text_content, "Shared chat: Backend Guild\nChat ID: oc_chat_shared") + self.assertEqual(normalized.metadata["chat_id"], "oc_chat_shared") + self.assertEqual(normalized.metadata["chat_name"], "Backend Guild") + + def test_normalize_interactive_card_preserves_title_body_and_actions(self): + from gateway.platforms.feishu import normalize_feishu_message + + normalized = normalize_feishu_message( + message_type="interactive", + raw_content=json.dumps( + { + "card": { + "header": {"title": {"tag": "plain_text", "content": "Build Failed"}}, + "elements": [ + {"tag": "div", "text": {"tag": "lark_md", "content": "Service: payments-api"}}, + {"tag": "div", "text": {"tag": "plain_text", "content": "Branch: main"}}, + { + "tag": "action", + "actions": [ + {"tag": "button", "text": {"tag": "plain_text", "content": "View Logs"}}, + {"tag": "button", "text": {"tag": "plain_text", "content": "Retry"}}, + ], + }, + ], + } + } + ), + ) + + self.assertEqual(normalized.relation_kind, "interactive") + self.assertEqual( + normalized.text_content, + "Build Failed\nService: payments-api\nBranch: main\nView Logs\nRetry\nActions: View Logs, Retry", + ) + + +class TestFeishuAdapterMessaging(unittest.TestCase): + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_app", + "FEISHU_APP_SECRET": "secret_app", + "FEISHU_CONNECTION_MODE": "webhook", + "FEISHU_WEBHOOK_HOST": "127.0.0.1", + "FEISHU_WEBHOOK_PORT": "9001", + "FEISHU_WEBHOOK_PATH": "/hook", + }, clear=True) + def test_connect_webhook_mode_starts_local_server(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + runner = AsyncMock() + site = AsyncMock() + web_module = SimpleNamespace( + Application=lambda: SimpleNamespace(router=SimpleNamespace(add_post=lambda *_args, **_kwargs: None)), + AppRunner=lambda _app: runner, + TCPSite=lambda _runner, host, port: SimpleNamespace(start=site.start, host=host, port=port), + ) + + with ( + patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True), + patch("gateway.platforms.feishu.FEISHU_WEBHOOK_AVAILABLE", True), + patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)), + patch("gateway.platforms.feishu.release_scoped_lock"), + patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()), + patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()), + patch("gateway.platforms.feishu.web", web_module), + ): + connected = asyncio.run(adapter.connect()) + + self.assertTrue(connected) + runner.setup.assert_awaited_once() + site.start.assert_awaited_once() + + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_app", + "FEISHU_APP_SECRET": "secret_app", + }, clear=True) + def test_connect_acquires_scoped_lock_and_disconnect_releases_it(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + ws_client = object() + + with ( + patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True), + patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True), + patch("gateway.platforms.feishu.lark", SimpleNamespace(LogLevel=SimpleNamespace(INFO="INFO", WARNING="WARNING"))), + patch("gateway.platforms.feishu.EventDispatcherHandler", object()), + patch("gateway.platforms.feishu.FeishuWSClient", return_value=ws_client), + patch("gateway.platforms.feishu._run_official_feishu_ws_client"), + patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)) as acquire_lock, + patch("gateway.platforms.feishu.release_scoped_lock") as release_lock, + patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()), + patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()), + ): + loop = asyncio.new_event_loop() + future = loop.create_future() + future.set_result(None) + + class _Loop: + def run_in_executor(self, *_args, **_kwargs): + return future + + try: + with patch("gateway.platforms.feishu.asyncio.get_running_loop", return_value=_Loop()): + connected = asyncio.run(adapter.connect()) + asyncio.run(adapter.disconnect()) + finally: + loop.close() + + self.assertTrue(connected) + acquire_lock.assert_called_once_with( + "feishu-app-id", + "cli_app", + metadata={"platform": "feishu"}, + ) + release_lock.assert_called_once_with("feishu-app-id", "cli_app") + + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_app", + "FEISHU_APP_SECRET": "secret_app", + }, clear=True) + def test_connect_rejects_existing_app_lock(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + + with ( + patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True), + patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True), + patch( + "gateway.platforms.feishu.acquire_scoped_lock", + return_value=(False, {"pid": 4321}), + ), + ): + connected = asyncio.run(adapter.connect()) + + self.assertFalse(connected) + self.assertEqual(adapter.fatal_error_code, "feishu_app_lock") + self.assertFalse(adapter.fatal_error_retryable) + self.assertIn("PID 4321", adapter.fatal_error_message) + + @patch.dict(os.environ, { + "FEISHU_APP_ID": "cli_app", + "FEISHU_APP_SECRET": "secret_app", + }, clear=True) + def test_connect_retries_transient_startup_failure(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + ws_client = object() + sleeps = [] + + with ( + patch("gateway.platforms.feishu.FEISHU_AVAILABLE", True), + patch("gateway.platforms.feishu.FEISHU_WEBSOCKET_AVAILABLE", True), + patch("gateway.platforms.feishu.lark", SimpleNamespace(LogLevel=SimpleNamespace(INFO="INFO", WARNING="WARNING"))), + patch("gateway.platforms.feishu.EventDispatcherHandler", object()), + patch("gateway.platforms.feishu.FeishuWSClient", return_value=ws_client), + patch("gateway.platforms.feishu.acquire_scoped_lock", return_value=(True, None)), + patch("gateway.platforms.feishu.release_scoped_lock"), + patch.object(adapter, "_hydrate_bot_identity", new=AsyncMock()), + patch("gateway.platforms.feishu.asyncio.sleep", side_effect=lambda delay: sleeps.append(delay)), + patch.object(adapter, "_build_lark_client", return_value=SimpleNamespace()), + ): + loop = asyncio.new_event_loop() + future = loop.create_future() + future.set_result(None) + + class _Loop: + def __init__(self): + self.calls = 0 + + def run_in_executor(self, *_args, **_kwargs): + self.calls += 1 + if self.calls == 1: + raise OSError("temporary websocket failure") + return future + + fake_loop = _Loop() + try: + with patch("gateway.platforms.feishu.asyncio.get_running_loop", return_value=fake_loop): + connected = asyncio.run(adapter.connect()) + finally: + loop.close() + + self.assertTrue(connected) + self.assertEqual(sleeps, [1]) + self.assertEqual(fake_loop.calls, 2) + + @patch.dict(os.environ, {}, clear=True) + def test_edit_message_updates_existing_feishu_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _MessageAPI: + def update(self, request): + captured["request"] = request + return SimpleNamespace(success=lambda: True) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.edit_message( + chat_id="oc_chat", + message_id="om_progress", + content="📖 read_file: \"/tmp/image.png\"", + ) + ) + + self.assertTrue(result.success) + self.assertEqual(result.message_id, "om_progress") + self.assertEqual(captured["request"].message_id, "om_progress") + self.assertEqual(captured["request"].request_body.msg_type, "text") + self.assertEqual( + captured["request"].request_body.content, + json.dumps({"text": "📖 read_file: \"/tmp/image.png\""}, ensure_ascii=False), + ) + + @patch.dict(os.environ, {}, clear=True) + def test_edit_message_falls_back_to_text_when_post_update_is_rejected(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {"calls": []} + + class _MessageAPI: + def update(self, request): + captured["calls"].append(request) + if len(captured["calls"]) == 1: + return SimpleNamespace(success=lambda: False, code=230001, msg="content format of the post type is incorrect") + return SimpleNamespace(success=lambda: True) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.edit_message( + chat_id="oc_chat", + message_id="om_progress", + content="可以用 **粗体** 和 *斜体*。", + ) + ) + + self.assertTrue(result.success) + self.assertEqual(captured["calls"][0].request_body.msg_type, "post") + self.assertEqual(captured["calls"][1].request_body.msg_type, "text") + self.assertEqual( + captured["calls"][1].request_body.content, + json.dumps({"text": "可以用 粗体 和 斜体。"}, ensure_ascii=False), + ) + + @patch.dict(os.environ, {}, clear=True) + def test_get_chat_info_uses_real_feishu_chat_api(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + + class _ChatAPI: + def get(self, request): + self.request = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(name="Hermes Group", chat_type="group"), + ) + + chat_api = _ChatAPI() + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + chat=chat_api, + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + info = asyncio.run(adapter.get_chat_info("oc_chat")) + + self.assertEqual(chat_api.request.chat_id, "oc_chat") + self.assertEqual(info["chat_id"], "oc_chat") + self.assertEqual(info["name"], "Hermes Group") + self.assertEqual(info["type"], "group") + +class TestAdapterModule(unittest.TestCase): + def test_adapter_requirement_helper_exists(self): + source = Path("gateway/platforms/feishu.py").read_text(encoding="utf-8") + self.assertIn("def check_feishu_requirements()", source) + self.assertIn("FEISHU_AVAILABLE", source) + + def test_adapter_declares_websocket_scope(self): + source = Path("gateway/platforms/feishu.py").read_text(encoding="utf-8") + self.assertIn("Supported modes: websocket, webhook", source) + self.assertIn("FEISHU_CONNECTION_MODE", source) + + def test_adapter_registers_message_read_noop_handler(self): + source = Path("gateway/platforms/feishu.py").read_text(encoding="utf-8") + self.assertIn("register_p2_im_message_message_read_v1", source) + self.assertIn("def _on_message_read_event", source) + + def test_adapter_registers_reaction_and_card_handlers_for_websocket(self): + source = Path("gateway/platforms/feishu.py").read_text(encoding="utf-8") + self.assertIn("register_p2_im_message_reaction_created_v1", source) + self.assertIn("register_p2_im_message_reaction_deleted_v1", source) + self.assertIn("register_p2_card_action_trigger", source) + + +class TestAdapterBehavior(unittest.TestCase): + @patch.dict(os.environ, {}, clear=True) + def test_build_event_handler_registers_reaction_and_card_processors(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + calls = [] + + class _Builder: + def register_p2_im_message_message_read_v1(self, _handler): + calls.append("message_read") + return self + + def register_p2_im_message_receive_v1(self, _handler): + calls.append("message_receive") + return self + + def register_p2_im_message_reaction_created_v1(self, _handler): + calls.append("reaction_created") + return self + + def register_p2_im_message_reaction_deleted_v1(self, _handler): + calls.append("reaction_deleted") + return self + + def register_p2_card_action_trigger(self, _handler): + calls.append("card_action") + return self + + def build(self): + calls.append("build") + return "handler" + + class _Dispatcher: + @staticmethod + def builder(_encrypt_key, _verification_token): + calls.append("builder") + return _Builder() + + with patch("gateway.platforms.feishu.EventDispatcherHandler", _Dispatcher): + handler = adapter._build_event_handler() + + self.assertEqual(handler, "handler") + self.assertEqual( + calls, + [ + "builder", + "message_read", + "message_receive", + "reaction_created", + "reaction_deleted", + "card_action", + "build", + ], + ) + + @patch.dict(os.environ, {}, clear=True) + @unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed") + def test_add_ack_reaction_uses_ok_emoji(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _ReactionAPI: + def create(self, request): + captured["request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(reaction_id="r_typing"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace(v1=SimpleNamespace(message_reaction=_ReactionAPI())) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + reaction_id = asyncio.run(adapter._add_ack_reaction("om_msg")) + + self.assertEqual(reaction_id, "r_typing") + self.assertEqual(captured["request"].request_body.reaction_type["emoji_type"], "OK") + + @patch.dict(os.environ, {}, clear=True) + def test_add_ack_reaction_logs_warning_on_failure(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + + class _ReactionAPI: + def create(self, request): + raise RuntimeError("boom") + + adapter._client = SimpleNamespace( + im=SimpleNamespace(v1=SimpleNamespace(message_reaction=_ReactionAPI())) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with ( + patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct), + self.assertLogs("gateway.platforms.feishu", level="WARNING") as logs, + ): + reaction_id = asyncio.run(adapter._add_ack_reaction("om_msg")) + + self.assertIsNone(reaction_id) + self.assertTrue( + any("Failed to add ack reaction to om_msg" in entry for entry in logs.output), + logs.output, + ) + + @patch.dict(os.environ, {}, clear=True) + def test_ack_reaction_events_are_ignored_to_avoid_feedback_loops(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._loop = object() + event = SimpleNamespace( + message_id="om_msg", + operator_type="user", + reaction_type=SimpleNamespace(emoji_type="OK"), + ) + data = SimpleNamespace(event=event) + + with patch("gateway.platforms.feishu.asyncio.run_coroutine_threadsafe") as run_threadsafe: + adapter._on_reaction_event("im.message.reaction.created_v1", data) + + run_threadsafe.assert_not_called() + + @patch.dict(os.environ, {}, clear=True) + def test_normalize_inbound_text_strips_feishu_mentions(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + cleaned = adapter._normalize_inbound_text("hi @_user_1 there @_user_2") + self.assertEqual(cleaned, "hi there") + + @patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True) + def test_group_message_requires_mentions_even_when_policy_open(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace(mentions=[]) + sender_id = SimpleNamespace(open_id="ou_any", user_id=None) + self.assertFalse(adapter._should_accept_group_message(message, sender_id)) + + message_with_mention = SimpleNamespace(mentions=[SimpleNamespace(key="@_user_1")]) + self.assertFalse(adapter._should_accept_group_message(message_with_mention, sender_id)) + + @patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True) + def test_group_message_with_other_user_mention_is_rejected_when_bot_identity_unknown(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + sender_id = SimpleNamespace(open_id="ou_any", user_id=None) + other_mention = SimpleNamespace( + name="Other User", + id=SimpleNamespace(open_id="ou_other", user_id="u_other"), + ) + + self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id)) + + @patch.dict( + os.environ, + { + "FEISHU_GROUP_POLICY": "allowlist", + "FEISHU_ALLOWED_USERS": "ou_allowed", + "FEISHU_BOT_NAME": "Hermes Bot", + }, + clear=True, + ) + def test_group_message_allowlist_and_mention_both_required(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + mentioned = SimpleNamespace( + mentions=[ + SimpleNamespace( + name="Hermes Bot", + id=SimpleNamespace(open_id="ou_other", user_id="u_other"), + ) + ] + ) + + self.assertTrue( + adapter._should_accept_group_message( + mentioned, + SimpleNamespace(open_id="ou_allowed", user_id=None), + ) + ) + self.assertFalse( + adapter._should_accept_group_message( + mentioned, + SimpleNamespace(open_id="ou_blocked", user_id=None), + ) + ) + + @patch.dict( + os.environ, + { + "FEISHU_GROUP_POLICY": "open", + "FEISHU_BOT_OPEN_ID": "ou_bot", + }, + clear=True, + ) + def test_group_message_matches_bot_open_id_when_configured(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + sender_id = SimpleNamespace(open_id="ou_any", user_id=None) + + bot_mention = SimpleNamespace( + name="Hermes", + id=SimpleNamespace(open_id="ou_bot", user_id="u_bot"), + ) + other_mention = SimpleNamespace( + name="Other", + id=SimpleNamespace(open_id="ou_other", user_id="u_other"), + ) + + self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[bot_mention]), sender_id)) + self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[other_mention]), sender_id)) + + @patch.dict( + os.environ, + { + "FEISHU_GROUP_POLICY": "open", + "FEISHU_BOT_NAME": "Hermes Bot", + }, + clear=True, + ) + def test_group_message_matches_bot_name_when_only_name_available(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + sender_id = SimpleNamespace(open_id="ou_any", user_id=None) + + named_mention = SimpleNamespace( + name="Hermes Bot", + id=SimpleNamespace(open_id="ou_other", user_id="u_other"), + ) + different_mention = SimpleNamespace( + name="Another Bot", + id=SimpleNamespace(open_id="ou_other", user_id="u_other"), + ) + + self.assertTrue(adapter._should_accept_group_message(SimpleNamespace(mentions=[named_mention]), sender_id)) + self.assertFalse(adapter._should_accept_group_message(SimpleNamespace(mentions=[different_mention]), sender_id)) + + @patch.dict( + os.environ, + { + "FEISHU_GROUP_POLICY": "open", + "FEISHU_BOT_OPEN_ID": "ou_bot", + }, + clear=True, + ) + def test_group_post_message_uses_parsed_mentions_when_sdk_mentions_missing(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + sender_id = SimpleNamespace(open_id="ou_any", user_id=None) + message = SimpleNamespace( + message_type="post", + mentions=[], + content='{"en_us":{"content":[[{"tag":"at","user_name":"Hermes","open_id":"ou_bot"}]]}}', + ) + + self.assertTrue(adapter._should_accept_group_message(message, sender_id)) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_post_message_as_text(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + message_type="post", + content='{"zh_cn":{"title":"Title","content":[[{"tag":"text","text":"hello "}],[{"tag":"a","text":"doc","href":"https://example.com"}]]}}', + message_id="om_post", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "Title\nhello\n[doc](https://example.com)") + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, []) + self.assertEqual(media_types, []) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_post_message_uses_first_available_language_block(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + message_type="post", + content='{"fr_fr":{"title":"Subject","content":[[{"tag":"text","text":"bonjour"}]]}}', + message_id="om_post_fr", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "Subject\nbonjour") + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, []) + self.assertEqual(media_types, []) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_post_message_with_rich_elements_does_not_drop_content(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + message_type="post", + content=( + '{"en_us":{"title":"Rich message","content":[' + '[{"tag":"img","alt":"diagram"}],' + '[{"tag":"at","user_name":"Alice"},{"tag":"text","text":" please check the attachment"}],' + '[{"tag":"media","file_name":"spec.pdf"}],' + '[{"tag":"emotion","emoji_type":"smile"}]' + ']}}' + ), + message_id="om_post_rich", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "Rich message\n[Image: diagram]\n@Alice please check the attachment\n[Attachment: spec.pdf]\n:smile:") + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, []) + self.assertEqual(media_types, []) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_post_message_downloads_embedded_resources(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._download_feishu_image = AsyncMock(return_value=("/tmp/feishu-image.png", "image/png")) + adapter._download_feishu_message_resource = AsyncMock(return_value=("/tmp/spec.pdf", "application/pdf")) + message = SimpleNamespace( + message_type="post", + content=( + '{"en_us":{"title":"Rich message","content":[' + '[{"tag":"img","image_key":"img_123","alt":"diagram"}],' + '[{"tag":"media","file_key":"file_123","file_name":"spec.pdf"}]' + ']}}' + ), + message_id="om_post_media", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "Rich message\n[Image: diagram]\n[Attachment: spec.pdf]") + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, ["/tmp/feishu-image.png", "/tmp/spec.pdf"]) + self.assertEqual(media_types, ["image/png", "application/pdf"]) + adapter._download_feishu_image.assert_awaited_once_with( + message_id="om_post_media", + image_key="img_123", + ) + adapter._download_feishu_message_resource.assert_awaited_once_with( + message_id="om_post_media", + file_key="file_123", + resource_type="file", + fallback_filename="spec.pdf", + ) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_merge_forward_message_as_text_summary(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + message_type="merge_forward", + content=json.dumps( + { + "title": "Forwarded updates", + "messages": [ + {"sender_name": "Alice", "text": "Investigating the incident"}, + {"sender_name": "Bob", "text": "ETA 10 minutes"}, + ], + } + ), + message_id="om_merge_forward", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual( + text, + "Forwarded updates\n- Alice: Investigating the incident\n- Bob: ETA 10 minutes", + ) + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, []) + self.assertEqual(media_types, []) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_share_chat_message_as_text_summary(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + message_type="share_chat", + content='{"chat_id":"oc_shared","chat_name":"Platform Ops"}', + message_id="om_share_chat", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "Shared chat: Platform Ops\nChat ID: oc_shared") + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, []) + self.assertEqual(media_types, []) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_interactive_message_as_text_summary(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + message_type="interactive", + content=json.dumps( + { + "card": { + "header": {"title": {"tag": "plain_text", "content": "Approval Request"}}, + "elements": [ + {"tag": "div", "text": {"tag": "plain_text", "content": "Requester: Alice"}}, + { + "tag": "action", + "actions": [ + {"tag": "button", "text": {"tag": "plain_text", "content": "Approve"}}, + ], + }, + ], + } + } + ), + message_id="om_interactive", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "Approval Request\nRequester: Alice\nApprove\nActions: Approve") + self.assertEqual(msg_type.value, "text") + self.assertEqual(media_urls, []) + self.assertEqual(media_types, []) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_image_message_downloads_and_caches(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._download_feishu_image = AsyncMock(return_value=("/tmp/feishu-image.png", "image/png")) + message = SimpleNamespace( + message_type="image", + content='{"image_key":"img_123"}', + message_id="om_image", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "") + self.assertEqual(msg_type.value, "photo") + self.assertEqual(media_urls, ["/tmp/feishu-image.png"]) + self.assertEqual(media_types, ["image/png"]) + adapter._download_feishu_image.assert_awaited_once_with( + message_id="om_image", + image_key="img_123", + ) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_audio_message_downloads_and_caches(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._download_feishu_message_resource = AsyncMock( + return_value=("/tmp/feishu-audio.ogg", "audio/ogg") + ) + message = SimpleNamespace( + message_type="audio", + content='{"file_key":"file_audio","file_name":"voice.ogg"}', + message_id="om_audio", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "") + self.assertEqual(msg_type.value, "audio") + self.assertEqual(media_urls, ["/tmp/feishu-audio.ogg"]) + self.assertEqual(media_types, ["audio/ogg"]) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_file_message_downloads_and_caches(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._download_feishu_message_resource = AsyncMock( + return_value=("/tmp/doc_123_report.pdf", "application/pdf") + ) + message = SimpleNamespace( + message_type="file", + content='{"file_key":"file_doc","file_name":"report.pdf"}', + message_id="om_file", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "") + self.assertEqual(msg_type.value, "document") + self.assertEqual(media_urls, ["/tmp/doc_123_report.pdf"]) + self.assertEqual(media_types, ["application/pdf"]) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_media_message_with_image_mime_becomes_photo(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._download_feishu_message_resource = AsyncMock( + return_value=("/tmp/feishu-media.jpg", "image/jpeg") + ) + message = SimpleNamespace( + message_type="media", + content='{"file_key":"file_media","file_name":"photo.jpg"}', + message_id="om_media", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "") + self.assertEqual(msg_type.value, "photo") + self.assertEqual(media_urls, ["/tmp/feishu-media.jpg"]) + self.assertEqual(media_types, ["image/jpeg"]) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_media_message_with_video_mime_becomes_video(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._download_feishu_message_resource = AsyncMock( + return_value=("/tmp/feishu-video.mp4", "video/mp4") + ) + message = SimpleNamespace( + message_type="media", + content='{"file_key":"file_video","file_name":"clip.mp4"}', + message_id="om_video", + ) + + text, msg_type, media_urls, media_types = asyncio.run(adapter._extract_message_content(message)) + + self.assertEqual(text, "") + self.assertEqual(msg_type.value, "video") + self.assertEqual(media_urls, ["/tmp/feishu-video.mp4"]) + self.assertEqual(media_types, ["video/mp4"]) + + @patch.dict(os.environ, {}, clear=True) + def test_extract_text_from_raw_content_uses_relation_message_fallbacks(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + + shared = adapter._extract_text_from_raw_content( + msg_type="share_chat", + raw_content='{"chat_id":"oc_shared","chat_name":"Platform Ops"}', + ) + attachment = adapter._extract_text_from_raw_content( + msg_type="file", + raw_content='{"file_key":"file_1","file_name":"report.pdf"}', + ) + + self.assertEqual(shared, "Shared chat: Platform Ops\nChat ID: oc_shared") + self.assertEqual(attachment, "[Attachment: report.pdf]") + + @patch.dict(os.environ, {}, clear=True) + def test_extract_text_message_starting_with_slash_becomes_command(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._dispatch_inbound_event = AsyncMock() + adapter.get_chat_info = AsyncMock( + return_value={"chat_id": "oc_chat", "name": "Feishu DM", "type": "dm"} + ) + adapter._resolve_sender_profile = AsyncMock( + return_value={"user_id": "ou_user", "user_name": "张三", "user_id_alt": None} + ) + message = SimpleNamespace( + chat_id="oc_chat", + thread_id=None, + parent_id=None, + upper_message_id=None, + message_type="text", + content='{"text":"/help test"}', + message_id="om_command", + ) + + asyncio.run( + adapter._process_inbound_message( + data=SimpleNamespace(event=SimpleNamespace(message=message)), + message=message, + sender_id=SimpleNamespace(open_id="ou_user", user_id=None, union_id=None), + chat_type="p2p", + message_id="om_command", + ) + ) + + event = adapter._dispatch_inbound_event.await_args.args[0] + self.assertEqual(event.message_type.value, "command") + self.assertEqual(event.text, "/help test") + + @patch.dict(os.environ, {}, clear=True) + def test_extract_text_file_injects_content(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp: + tmp.write("hello from feishu") + path = tmp.name + + try: + text = asyncio.run(adapter._maybe_extract_text_document(path, "text/plain")) + finally: + os.unlink(path) + + self.assertIn("hello from feishu", text) + self.assertIn("[Content of", text) + + @patch.dict(os.environ, {}, clear=True) + def test_message_event_submits_to_adapter_loop(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._loop = object() + + message = SimpleNamespace( + message_id="om_text", + chat_type="p2p", + chat_id="oc_chat", + message_type="text", + content='{"text":"hello"}', + ) + sender_id = SimpleNamespace(open_id="ou_user", user_id=None, union_id=None) + sender = SimpleNamespace(sender_id=sender_id, sender_type="user") + data = SimpleNamespace(event=SimpleNamespace(message=message, sender=sender)) + + future = SimpleNamespace(add_done_callback=lambda *_args, **_kwargs: None) + def _submit(coro, _loop): + coro.close() + return future + + with patch("gateway.platforms.feishu.asyncio.run_coroutine_threadsafe", side_effect=_submit) as submit: + adapter._on_message_event(data) + + self.assertTrue(submit.called) + + @patch.dict(os.environ, {}, clear=True) + def test_process_inbound_message_uses_event_sender_identity_only(self): + from gateway.config import PlatformConfig + from gateway.platforms.base import MessageType + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._dispatch_inbound_event = AsyncMock() + # Sender name now comes from the contact API; mock it to return a known value. + adapter._resolve_sender_name_from_api = AsyncMock(return_value="张三") + adapter.get_chat_info = AsyncMock( + return_value={"chat_id": "oc_chat", "name": "Feishu DM", "type": "dm"} + ) + message = SimpleNamespace( + chat_id="oc_chat", + thread_id=None, + message_type="text", + content='{"text":"hello"}', + message_id="om_text", + ) + sender_id = SimpleNamespace( + open_id="ou_user", + user_id="u_user", + union_id="on_union", + ) + data = SimpleNamespace(event=SimpleNamespace(message=message, sender=SimpleNamespace(sender_id=sender_id))) + + asyncio.run( + adapter._process_inbound_message( + data=data, + message=message, + sender_id=sender_id, + chat_type="p2p", + message_id="om_text", + ) + ) + + adapter._dispatch_inbound_event.assert_awaited_once() + event = adapter._dispatch_inbound_event.await_args.args[0] + self.assertEqual(event.message_type, MessageType.TEXT) + self.assertEqual(event.source.user_id, "ou_user") + self.assertEqual(event.source.user_name, "张三") + self.assertEqual(event.source.user_id_alt, "on_union") + self.assertEqual(event.source.chat_name, "Feishu DM") + + @patch.dict(os.environ, {}, clear=True) + def test_text_batch_merges_rapid_messages_into_single_event(self): + from gateway.config import PlatformConfig + from gateway.platforms.base import MessageEvent, MessageType + from gateway.platforms.feishu import FeishuAdapter + from gateway.session import SessionSource + + adapter = FeishuAdapter(PlatformConfig()) + adapter.handle_message = AsyncMock() + source = SessionSource( + platform=adapter.platform, + chat_id="oc_chat", + chat_name="Feishu DM", + chat_type="dm", + user_id="ou_user", + user_name="张三", + ) + + async def _sleep(_delay): + return None + + async def _run() -> None: + with patch("gateway.platforms.feishu.asyncio.sleep", side_effect=_sleep): + await adapter._dispatch_inbound_event( + MessageEvent(text="A", message_type=MessageType.TEXT, source=source, message_id="om_1") + ) + await adapter._dispatch_inbound_event( + MessageEvent(text="B", message_type=MessageType.TEXT, source=source, message_id="om_2") + ) + pending = list(adapter._pending_text_batch_tasks.values()) + self.assertEqual(len(pending), 1) + await asyncio.gather(*pending, return_exceptions=True) + + asyncio.run(_run()) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + self.assertEqual(event.text, "A\nB") + self.assertEqual(event.message_type, MessageType.TEXT) + + @patch.dict( + os.environ, + { + "HERMES_FEISHU_TEXT_BATCH_MAX_MESSAGES": "2", + }, + clear=True, + ) + def test_text_batch_flushes_when_message_count_limit_is_hit(self): + from gateway.config import PlatformConfig + from gateway.platforms.base import MessageEvent, MessageType + from gateway.platforms.feishu import FeishuAdapter + from gateway.session import SessionSource + + adapter = FeishuAdapter(PlatformConfig()) + adapter.handle_message = AsyncMock() + source = SessionSource( + platform=adapter.platform, + chat_id="oc_chat", + chat_name="Feishu DM", + chat_type="dm", + user_id="ou_user", + user_name="张三", + ) + + async def _sleep(_delay): + return None + + async def _run() -> None: + with patch("gateway.platforms.feishu.asyncio.sleep", side_effect=_sleep): + await adapter._dispatch_inbound_event( + MessageEvent(text="A", message_type=MessageType.TEXT, source=source, message_id="om_1") + ) + await adapter._dispatch_inbound_event( + MessageEvent(text="B", message_type=MessageType.TEXT, source=source, message_id="om_2") + ) + await adapter._dispatch_inbound_event( + MessageEvent(text="C", message_type=MessageType.TEXT, source=source, message_id="om_3") + ) + pending = list(adapter._pending_text_batch_tasks.values()) + self.assertEqual(len(pending), 1) + await asyncio.gather(*pending, return_exceptions=True) + + asyncio.run(_run()) + + self.assertEqual(adapter.handle_message.await_count, 2) + first = adapter.handle_message.await_args_list[0].args[0] + second = adapter.handle_message.await_args_list[1].args[0] + self.assertEqual(first.text, "A\nB") + self.assertEqual(second.text, "C") + + @patch.dict(os.environ, {}, clear=True) + def test_media_batch_merges_rapid_photo_messages(self): + from gateway.config import PlatformConfig + from gateway.platforms.base import MessageEvent, MessageType + from gateway.platforms.feishu import FeishuAdapter + from gateway.session import SessionSource + + adapter = FeishuAdapter(PlatformConfig()) + adapter.handle_message = AsyncMock() + source = SessionSource( + platform=adapter.platform, + chat_id="oc_chat", + chat_name="Feishu DM", + chat_type="dm", + user_id="ou_user", + user_name="张三", + ) + + async def _sleep(_delay): + return None + + async def _run() -> None: + with patch("gateway.platforms.feishu.asyncio.sleep", side_effect=_sleep): + await adapter._dispatch_inbound_event( + MessageEvent( + text="第一张", + message_type=MessageType.PHOTO, + source=source, + message_id="om_p1", + media_urls=["/tmp/a.png"], + media_types=["image/png"], + ) + ) + await adapter._dispatch_inbound_event( + MessageEvent( + text="第二张", + message_type=MessageType.PHOTO, + source=source, + message_id="om_p2", + media_urls=["/tmp/b.png"], + media_types=["image/png"], + ) + ) + pending = list(adapter._pending_media_batch_tasks.values()) + self.assertEqual(len(pending), 1) + await asyncio.gather(*pending, return_exceptions=True) + + asyncio.run(_run()) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + self.assertEqual(event.media_urls, ["/tmp/a.png", "/tmp/b.png"]) + self.assertIn("第一张", event.text) + self.assertIn("第二张", event.text) + + @patch.dict(os.environ, {}, clear=True) + def test_send_image_downloads_then_uses_native_image_send(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter.send_image_file = AsyncMock(return_value=SimpleNamespace(success=True, message_id="om_img")) + + async def _run(): + with patch("gateway.platforms.feishu.cache_image_from_url", new=AsyncMock(return_value="/tmp/cached.png")): + return await adapter.send_image("oc_chat", "https://example.com/cat.png", caption="cat") + + result = asyncio.run(_run()) + + self.assertTrue(result.success) + adapter.send_image_file.assert_awaited_once() + self.assertEqual(adapter.send_image_file.await_args.kwargs["image_path"], "/tmp/cached.png") + + @patch.dict(os.environ, {}, clear=True) + def test_send_animation_degrades_to_document_send(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter.send_document = AsyncMock(return_value=SimpleNamespace(success=True, message_id="om_gif")) + + async def _run(): + with patch.object( + adapter, + "_download_remote_document", + new=AsyncMock(return_value=("/tmp/anim.gif", "anim.gif")), + ): + return await adapter.send_animation("oc_chat", "https://example.com/anim.gif", caption="look") + + result = asyncio.run(_run()) + + self.assertTrue(result.success) + adapter.send_document.assert_awaited_once() + caption = adapter.send_document.await_args.kwargs["caption"] + self.assertIn("GIF downgraded to file", caption) + self.assertIn("look", caption) + + def test_dedup_state_persists_across_adapter_restart(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + with tempfile.TemporaryDirectory() as temp_home: + with patch.dict(os.environ, {"HERMES_HOME": temp_home}, clear=False): + first = FeishuAdapter(PlatformConfig()) + self.assertFalse(first._is_duplicate("om_same")) + second = FeishuAdapter(PlatformConfig()) + self.assertTrue(second._is_duplicate("om_same")) + + @patch.dict(os.environ, {}, clear=True) + def test_process_inbound_group_message_keeps_group_type_when_chat_lookup_falls_back(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._dispatch_inbound_event = AsyncMock() + adapter.get_chat_info = AsyncMock( + return_value={"chat_id": "oc_group", "name": "oc_group", "type": "dm"} + ) + adapter._resolve_sender_profile = AsyncMock( + return_value={"user_id": "ou_user", "user_name": "张三", "user_id_alt": None} + ) + message = SimpleNamespace( + chat_id="oc_group", + thread_id=None, + message_type="text", + content='{"text":"hello group"}', + message_id="om_group_text", + ) + sender_id = SimpleNamespace(open_id="ou_user", user_id=None, union_id=None) + data = SimpleNamespace(event=SimpleNamespace(message=message)) + + asyncio.run( + adapter._process_inbound_message( + data=data, + message=message, + sender_id=sender_id, + chat_type="group", + message_id="om_group_text", + ) + ) + + event = adapter._dispatch_inbound_event.await_args.args[0] + self.assertEqual(event.source.chat_type, "group") + + @patch.dict(os.environ, {}, clear=True) + def test_process_inbound_message_fetches_reply_to_text(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._dispatch_inbound_event = AsyncMock() + adapter.get_chat_info = AsyncMock( + return_value={"chat_id": "oc_chat", "name": "Feishu DM", "type": "dm"} + ) + adapter._resolve_sender_profile = AsyncMock( + return_value={"user_id": "ou_user", "user_name": "张三", "user_id_alt": None} + ) + adapter._fetch_message_text = AsyncMock(return_value="父消息内容") + message = SimpleNamespace( + chat_id="oc_chat", + thread_id=None, + parent_id="om_parent", + upper_message_id=None, + message_type="text", + content='{"text":"reply"}', + message_id="om_reply", + ) + + asyncio.run( + adapter._process_inbound_message( + data=SimpleNamespace(event=SimpleNamespace(message=message)), + message=message, + sender_id=SimpleNamespace(open_id="ou_user", user_id=None, union_id=None), + chat_type="p2p", + message_id="om_reply", + ) + ) + + event = adapter._dispatch_inbound_event.await_args.args[0] + self.assertEqual(event.reply_to_message_id, "om_parent") + self.assertEqual(event.reply_to_text, "父消息内容") + + @patch.dict(os.environ, {}, clear=True) + def test_send_replies_in_thread_when_thread_metadata_present(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _ReplyAPI: + def reply(self, request): + captured["request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_reply"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_ReplyAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send( + chat_id="oc_chat", + content="hello", + reply_to="om_parent", + metadata={"thread_id": "omt-thread"}, + ) + ) + + self.assertTrue(result.success) + self.assertEqual(result.message_id, "om_reply") + self.assertTrue(captured["request"].request_body.reply_in_thread) + + @patch.dict(os.environ, {}, clear=True) + def test_send_retries_transient_failure(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {"attempts": 0} + sleeps = [] + + class _MessageAPI: + def create(self, request): + captured["attempts"] += 1 + captured["request"] = request + if captured["attempts"] == 1: + raise OSError("temporary send failure") + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_retry"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + async def _sleep(delay): + sleeps.append(delay) + + with ( + patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct), + patch("gateway.platforms.feishu.asyncio.sleep", side_effect=_sleep), + ): + result = asyncio.run(adapter.send(chat_id="oc_chat", content="hello retry")) + + self.assertTrue(result.success) + self.assertEqual(result.message_id, "om_retry") + self.assertEqual(captured["attempts"], 2) + self.assertEqual(sleeps, [1]) + + @patch.dict(os.environ, {}, clear=True) + def test_send_does_not_retry_deterministic_api_failure(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {"attempts": 0} + sleeps = [] + + class _MessageAPI: + def create(self, request): + captured["attempts"] += 1 + return SimpleNamespace( + success=lambda: False, + code=400, + msg="bad request", + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + async def _sleep(delay): + sleeps.append(delay) + + with ( + patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct), + patch("gateway.platforms.feishu.asyncio.sleep", side_effect=_sleep), + ): + result = asyncio.run(adapter.send(chat_id="oc_chat", content="bad payload")) + + self.assertFalse(result.success) + self.assertEqual(result.error, "[400] bad request") + self.assertEqual(captured["attempts"], 1) + self.assertEqual(sleeps, []) + + @patch.dict(os.environ, {}, clear=True) + def test_send_document_reply_uses_thread_flag(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _FileAPI: + def create(self, request): + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(file_key="file_123"), + ) + + class _MessageAPI: + def reply(self, request): + captured["request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_file_reply"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + file=_FileAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".pdf", delete=False) as tmp: + tmp.write(b"%PDF-1.4 test") + file_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send_document( + chat_id="oc_chat", + file_path=file_path, + reply_to="om_parent", + metadata={"thread_id": "omt-thread"}, + ) + ) + finally: + os.unlink(file_path) + + self.assertTrue(result.success) + self.assertTrue(captured["request"].request_body.reply_in_thread) + + @patch.dict(os.environ, {}, clear=True) + def test_send_document_uploads_file_and_sends_file_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _FileAPI: + def create(self, request): + captured["upload_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(file_key="file_123"), + ) + + class _MessageAPI: + def create(self, request): + captured["message_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_file_msg"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + file=_FileAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".pdf", delete=False) as tmp: + tmp.write(b"%PDF-1.4 test") + file_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter.send_document(chat_id="oc_chat", file_path=file_path)) + finally: + os.unlink(file_path) + + self.assertTrue(result.success) + self.assertEqual(result.message_id, "om_file_msg") + self.assertEqual(captured["upload_request"].request_body.file_type, "pdf") + self.assertEqual( + captured["message_request"].request_body.content, + '{"file_key": "file_123"}', + ) + + @patch.dict(os.environ, {}, clear=True) + def test_send_document_with_caption_uses_single_post_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _FileAPI: + def create(self, request): + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(file_key="file_123"), + ) + + class _MessageAPI: + def create(self, request): + captured["message_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_post_msg"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + file=_FileAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".pdf", delete=False) as tmp: + tmp.write(b"%PDF-1.4 test") + file_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send_document(chat_id="oc_chat", file_path=file_path, caption="报告请看") + ) + finally: + os.unlink(file_path) + + self.assertTrue(result.success) + self.assertEqual(captured["message_request"].request_body.msg_type, "post") + self.assertIn('"tag": "media"', captured["message_request"].request_body.content) + self.assertIn('"file_key": "file_123"', captured["message_request"].request_body.content) + self.assertIn("报告请看", captured["message_request"].request_body.content) + + @patch.dict(os.environ, {}, clear=True) + def test_send_image_file_uploads_image_and_sends_image_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _ImageAPI: + def create(self, request): + captured["upload_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(image_key="img_123"), + ) + + class _MessageAPI: + def create(self, request): + captured["message_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_image_msg"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + image=_ImageAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False) as tmp: + tmp.write(b"\x89PNG\r\n\x1a\n") + image_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter.send_image_file(chat_id="oc_chat", image_path=image_path)) + finally: + os.unlink(image_path) + + self.assertTrue(result.success) + self.assertEqual(result.message_id, "om_image_msg") + self.assertEqual(captured["upload_request"].request_body.image_type, "message") + self.assertEqual( + captured["message_request"].request_body.content, + '{"image_key": "img_123"}', + ) + + @patch.dict(os.environ, {}, clear=True) + def test_send_image_file_with_caption_uses_single_post_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _ImageAPI: + def create(self, request): + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(image_key="img_123"), + ) + + class _MessageAPI: + def create(self, request): + captured["message_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_post_img"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + image=_ImageAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False) as tmp: + tmp.write(b"\x89PNG\r\n\x1a\n") + image_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send_image_file(chat_id="oc_chat", image_path=image_path, caption="截图说明") + ) + finally: + os.unlink(image_path) + + self.assertTrue(result.success) + self.assertEqual(captured["message_request"].request_body.msg_type, "post") + self.assertIn('"tag": "img"', captured["message_request"].request_body.content) + self.assertIn('"image_key": "img_123"', captured["message_request"].request_body.content) + self.assertIn("截图说明", captured["message_request"].request_body.content) + + @patch.dict(os.environ, {}, clear=True) + def test_send_video_uploads_file_and_sends_media_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _FileAPI: + def create(self, request): + captured["upload_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(file_key="file_video_123"), + ) + + class _MessageAPI: + def create(self, request): + captured["message_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_video_msg"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + file=_FileAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".mp4", delete=False) as tmp: + tmp.write(b"\x00\x00\x00\x18ftypmp42") + video_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter.send_video(chat_id="oc_chat", video_path=video_path)) + finally: + os.unlink(video_path) + + self.assertTrue(result.success) + self.assertEqual(captured["upload_request"].request_body.file_type, "mp4") + self.assertEqual(captured["message_request"].request_body.msg_type, "media") + self.assertEqual(captured["message_request"].request_body.content, '{"file_key": "file_video_123"}') + + @patch.dict(os.environ, {}, clear=True) + def test_send_voice_uploads_opus_and_sends_audio_message(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _FileAPI: + def create(self, request): + captured["upload_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(file_key="file_audio_123"), + ) + + class _MessageAPI: + def create(self, request): + captured["message_request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_audio_msg"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + file=_FileAPI(), + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with tempfile.NamedTemporaryFile("wb", suffix=".opus", delete=False) as tmp: + tmp.write(b"opus") + audio_path = tmp.name + + try: + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter.send_voice(chat_id="oc_chat", audio_path=audio_path)) + finally: + os.unlink(audio_path) + + self.assertTrue(result.success) + self.assertEqual(captured["upload_request"].request_body.file_type, "opus") + self.assertEqual(captured["message_request"].request_body.msg_type, "audio") + self.assertEqual(captured["message_request"].request_body.content, '{"file_key": "file_audio_123"}') + + @patch.dict(os.environ, {}, clear=True) + def test_build_post_payload_extracts_title_and_links(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + payload = json.loads(adapter._build_post_payload("# 标题\n访问 [文档](https://example.com)")) + + elements = payload["zh_cn"]["content"][0] + self.assertEqual(elements, [{"tag": "md", "text": "# 标题\n访问 [文档](https://example.com)"}]) + + @patch.dict(os.environ, {}, clear=True) + def test_build_post_payload_wraps_markdown_in_md_tag(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + payload = json.loads( + adapter._build_post_payload("支持 **粗体**、*斜体* 和 `代码`") + ) + + elements = payload["zh_cn"]["content"][0] + self.assertEqual( + elements, + [ + {"tag": "md", "text": "支持 **粗体**、*斜体* 和 `代码`"}, + ], + ) + + @patch.dict(os.environ, {}, clear=True) + def test_build_post_payload_keeps_full_markdown_text(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + payload = json.loads( + adapter._build_post_payload( + "---\n1. 第一项\n 2. 子项\n- 外层\n - 内层\n下划线 和 ~~删除线~~" + ) + ) + + rows = payload["zh_cn"]["content"] + self.assertEqual( + rows, + [[{"tag": "md", "text": "---\n1. 第一项\n 2. 子项\n- 外层\n - 内层\n下划线 和 ~~删除线~~"}]], + ) + + @patch.dict(os.environ, {}, clear=True) + def test_send_uses_post_for_inline_markdown(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _MessageAPI: + def create(self, request): + captured["request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_markdown"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send( + chat_id="oc_chat", + content="可以用 **粗体** 和 *斜体*。", + ) + ) + + self.assertTrue(result.success) + self.assertEqual(captured["request"].request_body.msg_type, "post") + payload = json.loads(captured["request"].request_body.content) + elements = payload["zh_cn"]["content"][0] + self.assertEqual(elements, [{"tag": "md", "text": "可以用 **粗体** 和 *斜体*。"}]) + + @patch.dict(os.environ, {}, clear=True) + def test_send_falls_back_to_text_when_post_payload_is_rejected(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {"calls": []} + + class _MessageAPI: + def create(self, request): + captured["calls"].append(request) + if len(captured["calls"]) == 1: + raise RuntimeError("content format of the post type is incorrect") + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_plain"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send( + chat_id="oc_chat", + content="可以用 **粗体** 和 *斜体*。", + ) + ) + + self.assertTrue(result.success) + self.assertEqual(captured["calls"][0].request_body.msg_type, "post") + self.assertEqual(captured["calls"][1].request_body.msg_type, "text") + self.assertEqual( + captured["calls"][1].request_body.content, + json.dumps({"text": "可以用 粗体 和 斜体。"}, ensure_ascii=False), + ) + + @patch.dict(os.environ, {}, clear=True) + def test_send_falls_back_to_text_when_post_response_is_unsuccessful(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {"calls": []} + + class _MessageAPI: + def create(self, request): + captured["calls"].append(request) + if len(captured["calls"]) == 1: + return SimpleNamespace(success=lambda: False, code=230001, msg="content format of the post type is incorrect") + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_plain_response"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send( + chat_id="oc_chat", + content="可以用 **粗体** 和 *斜体*。", + ) + ) + + self.assertTrue(result.success) + self.assertEqual(captured["calls"][0].request_body.msg_type, "post") + self.assertEqual(captured["calls"][1].request_body.msg_type, "text") + self.assertEqual( + captured["calls"][1].request_body.content, + json.dumps({"text": "可以用 粗体 和 斜体。"}, ensure_ascii=False), + ) + + @patch.dict(os.environ, {}, clear=True) + def test_send_uses_post_for_advanced_markdown_lines(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + captured = {} + + class _MessageAPI: + def create(self, request): + captured["request"] = request + return SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(message_id="om_markdown_advanced"), + ) + + adapter._client = SimpleNamespace( + im=SimpleNamespace( + v1=SimpleNamespace( + message=_MessageAPI(), + ) + ) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run( + adapter.send( + chat_id="oc_chat", + content="---\n1. 第一项\n下划线\n~~删除线~~", + ) + ) + + self.assertTrue(result.success) + self.assertEqual(captured["request"].request_body.msg_type, "post") + payload = json.loads(captured["request"].request_body.content) + rows = payload["zh_cn"]["content"] + self.assertEqual( + rows, + [[{"tag": "md", "text": "---\n1. 第一项\n下划线\n~~删除线~~"}]], + ) + + +@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed") +class TestWebhookSecurity(unittest.TestCase): + """Tests for webhook signature verification, rate limiting, and body size limits.""" + + def _make_adapter(self, encrypt_key: str = "") -> "FeishuAdapter": + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + with patch.dict(os.environ, {"FEISHU_APP_ID": "cli", "FEISHU_APP_SECRET": "sec", "FEISHU_ENCRYPT_KEY": encrypt_key}, clear=True): + return FeishuAdapter(PlatformConfig()) + + def test_signature_valid_passes(self): + import hashlib + from gateway.platforms.feishu import FeishuAdapter + from gateway.config import PlatformConfig + + encrypt_key = "test_secret" + adapter = self._make_adapter(encrypt_key) + body = b'{"type":"event"}' + timestamp = "1700000000" + nonce = "abc123" + content = f"{timestamp}{nonce}{encrypt_key}" + body.decode("utf-8") + sig = hashlib.sha256(content.encode("utf-8")).hexdigest() + headers = {"x-lark-request-timestamp": timestamp, "x-lark-request-nonce": nonce, "x-lark-signature": sig} + self.assertTrue(adapter._is_webhook_signature_valid(headers, body)) + + def test_signature_invalid_rejected(self): + adapter = self._make_adapter("test_secret") + headers = { + "x-lark-request-timestamp": "1700000000", + "x-lark-request-nonce": "abc", + "x-lark-signature": "deadbeef" * 8, + } + self.assertFalse(adapter._is_webhook_signature_valid(headers, b'{"type":"event"}')) + + def test_signature_missing_headers_rejected(self): + adapter = self._make_adapter("test_secret") + self.assertFalse(adapter._is_webhook_signature_valid({}, b'{}')) + + def test_rate_limit_allows_requests_within_window(self): + adapter = self._make_adapter() + for _ in range(5): + self.assertTrue(adapter._check_webhook_rate_limit("10.0.0.1")) + + def test_rate_limit_blocks_after_exceeding_max(self): + from gateway.platforms.feishu import _FEISHU_WEBHOOK_RATE_LIMIT_MAX + adapter = self._make_adapter() + for _ in range(_FEISHU_WEBHOOK_RATE_LIMIT_MAX): + adapter._check_webhook_rate_limit("10.0.0.2") + self.assertFalse(adapter._check_webhook_rate_limit("10.0.0.2")) + + def test_rate_limit_resets_after_window_expires(self): + from gateway.platforms.feishu import _FEISHU_WEBHOOK_RATE_LIMIT_MAX, _FEISHU_WEBHOOK_RATE_WINDOW_SECONDS + adapter = self._make_adapter() + ip = "10.0.0.3" + for _ in range(_FEISHU_WEBHOOK_RATE_LIMIT_MAX): + adapter._check_webhook_rate_limit(ip) + self.assertFalse(adapter._check_webhook_rate_limit(ip)) + # Simulate window expiry by backdating the stored entry. + count, window_start = adapter._webhook_rate_counts[ip] + adapter._webhook_rate_counts[ip] = (count, window_start - _FEISHU_WEBHOOK_RATE_WINDOW_SECONDS - 1) + self.assertTrue(adapter._check_webhook_rate_limit(ip)) + + @patch.dict(os.environ, {}, clear=True) + def test_webhook_request_rejects_oversized_body(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter, _FEISHU_WEBHOOK_MAX_BODY_BYTES + + adapter = FeishuAdapter(PlatformConfig()) + # Simulate a request whose Content-Length already signals oversize. + request = SimpleNamespace( + remote="127.0.0.1", + content_length=_FEISHU_WEBHOOK_MAX_BODY_BYTES + 1, + ) + response = asyncio.run(adapter._handle_webhook_request(request)) + self.assertEqual(response.status, 413) + + @patch.dict(os.environ, {}, clear=True) + def test_webhook_request_rejects_invalid_json(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + request = SimpleNamespace( + remote="127.0.0.1", + content_length=None, + read=AsyncMock(return_value=b"not-json"), + ) + response = asyncio.run(adapter._handle_webhook_request(request)) + self.assertEqual(response.status, 400) + + @patch.dict(os.environ, {"FEISHU_ENCRYPT_KEY": "secret"}, clear=True) + def test_webhook_request_rejects_bad_signature(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + body = json.dumps({"header": {"event_type": "im.message.receive_v1"}}).encode() + request = SimpleNamespace( + remote="127.0.0.1", + content_length=None, + headers={"x-lark-request-timestamp": "123", "x-lark-request-nonce": "abc", "x-lark-signature": "bad"}, + read=AsyncMock(return_value=body), + ) + response = asyncio.run(adapter._handle_webhook_request(request)) + self.assertEqual(response.status, 401) + + @patch.dict(os.environ, {}, clear=True) + def test_webhook_url_verification_challenge_passes_without_signature(self): + """Challenge requests must succeed even when no encrypt_key is set.""" + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + body = json.dumps({"type": "url_verification", "challenge": "test_challenge_token"}).encode() + request = SimpleNamespace( + remote="127.0.0.1", + content_length=None, + read=AsyncMock(return_value=body), + ) + response = asyncio.run(adapter._handle_webhook_request(request)) + self.assertEqual(response.status, 200) + self.assertIn(b"test_challenge_token", response.body) + + +class TestDedupTTL(unittest.TestCase): + """Tests for TTL-aware deduplication.""" + + @patch.dict(os.environ, {}, clear=True) + def test_duplicate_within_ttl_is_rejected(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + with patch.object(adapter, "_persist_seen_message_ids"): + adapter._seen_message_ids = {"om_dup": time.time()} + adapter._seen_message_order = ["om_dup"] + self.assertTrue(adapter._is_duplicate("om_dup")) + + @patch.dict(os.environ, {}, clear=True) + def test_expired_entry_is_not_considered_duplicate(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter, _FEISHU_DEDUP_TTL_SECONDS + + adapter = FeishuAdapter(PlatformConfig()) + # Plant an entry that expired well past the TTL. + stale_ts = time.time() - _FEISHU_DEDUP_TTL_SECONDS - 60 + adapter._seen_message_ids = {"om_old": stale_ts} + adapter._seen_message_order = ["om_old"] + with patch.object(adapter, "_persist_seen_message_ids"): + self.assertFalse(adapter._is_duplicate("om_old")) + + @patch.dict(os.environ, {}, clear=True) + def test_persist_saves_timestamps_as_dict(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + ts = time.time() + adapter._seen_message_ids = {"om_ts1": ts} + adapter._seen_message_order = ["om_ts1"] + with tempfile.TemporaryDirectory() as tmpdir: + adapter._dedup_state_path = Path(tmpdir) / "dedup.json" + adapter._persist_seen_message_ids() + saved = json.loads(adapter._dedup_state_path.read_text()) + self.assertIsInstance(saved["message_ids"], dict) + self.assertAlmostEqual(saved["message_ids"]["om_ts1"], ts, places=1) + + @patch.dict(os.environ, {}, clear=True) + def test_load_backward_compat_list_format(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "dedup.json" + path.write_text(json.dumps({"message_ids": ["om_a", "om_b"]}), encoding="utf-8") + adapter._dedup_state_path = path + adapter._load_seen_message_ids() + self.assertIn("om_a", adapter._seen_message_ids) + self.assertIn("om_b", adapter._seen_message_ids) + + +class TestGroupMentionAtAll(unittest.TestCase): + """Tests for @_all (Feishu @everyone) group mention routing.""" + + @patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "open"}, clear=True) + def test_at_all_in_content_accepts_without_explicit_bot_mention(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace( + content='{"text":"@_all 请注意"}', + mentions=[], + ) + sender_id = SimpleNamespace(open_id="ou_any", user_id=None) + self.assertTrue(adapter._should_accept_group_message(message, sender_id)) + + @patch.dict(os.environ, {"FEISHU_GROUP_POLICY": "allowlist", "FEISHU_ALLOWED_USERS": "ou_allowed"}, clear=True) + def test_at_all_still_requires_policy_gate(self): + """@_all bypasses mention gating but NOT the allowlist policy.""" + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + message = SimpleNamespace(content='{"text":"@_all attention"}', mentions=[]) + # Non-allowlisted user — should be blocked even with @_all. + blocked_sender = SimpleNamespace(open_id="ou_blocked", user_id=None) + self.assertFalse(adapter._should_accept_group_message(message, blocked_sender)) + # Allowlisted user — should pass. + allowed_sender = SimpleNamespace(open_id="ou_allowed", user_id=None) + self.assertTrue(adapter._should_accept_group_message(message, allowed_sender)) + + +@unittest.skipUnless(_HAS_LARK_OAPI, "lark-oapi not installed") +class TestSenderNameResolution(unittest.TestCase): + """Tests for _resolve_sender_name_from_api.""" + + @patch.dict(os.environ, {}, clear=True) + def test_returns_none_when_client_is_none(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._client = None + result = asyncio.run(adapter._resolve_sender_name_from_api("ou_abc")) + self.assertIsNone(result) + + @patch.dict(os.environ, {}, clear=True) + def test_returns_cached_name_within_ttl(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + adapter._client = SimpleNamespace() + future_expire = time.time() + 600 + adapter._sender_name_cache["ou_cached"] = ("Alice", future_expire) + result = asyncio.run(adapter._resolve_sender_name_from_api("ou_cached")) + self.assertEqual(result, "Alice") + + @patch.dict(os.environ, {}, clear=True) + def test_fetches_and_caches_name_from_api(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + user_obj = SimpleNamespace(name="Bob", display_name=None, nickname=None, en_name=None) + mock_response = SimpleNamespace( + success=lambda: True, + data=SimpleNamespace(user=user_obj), + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + class _ContactAPI: + def get(self, request): + return mock_response + + adapter._client = SimpleNamespace( + contact=SimpleNamespace(v3=SimpleNamespace(user=_ContactAPI())) + ) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter._resolve_sender_name_from_api("ou_bob")) + + self.assertEqual(result, "Bob") + self.assertIn("ou_bob", adapter._sender_name_cache) + + @patch.dict(os.environ, {}, clear=True) + def test_expired_cache_triggers_new_api_call(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + # Expired cache entry. + adapter._sender_name_cache["ou_expired"] = ("OldName", time.time() - 1) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + user_obj = SimpleNamespace(name="NewName", display_name=None, nickname=None, en_name=None) + + class _ContactAPI: + def get(self, request): + return SimpleNamespace(success=lambda: True, data=SimpleNamespace(user=user_obj)) + + adapter._client = SimpleNamespace( + contact=SimpleNamespace(v3=SimpleNamespace(user=_ContactAPI())) + ) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter._resolve_sender_name_from_api("ou_expired")) + + self.assertEqual(result, "NewName") + + @patch.dict(os.environ, {}, clear=True) + def test_api_failure_returns_none_without_raising(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + + class _BrokenContactAPI: + def get(self, _request): + raise RuntimeError("API down") + + adapter._client = SimpleNamespace( + contact=SimpleNamespace(v3=SimpleNamespace(user=_BrokenContactAPI())) + ) + + async def _direct(func, *args, **kwargs): + return func(*args, **kwargs) + + with patch("gateway.platforms.feishu.asyncio.to_thread", side_effect=_direct): + result = asyncio.run(adapter._resolve_sender_name_from_api("ou_broken")) + + self.assertIsNone(result) diff --git a/tests/gateway/test_flush_memory_stale_guard.py b/tests/gateway/test_flush_memory_stale_guard.py index ee1405243..495ba90ba 100644 --- a/tests/gateway/test_flush_memory_stale_guard.py +++ b/tests/gateway/test_flush_memory_stale_guard.py @@ -7,11 +7,21 @@ Verifies that: 3. The flush still works normally when memory files don't exist """ +import sys +import types import pytest from pathlib import Path from unittest.mock import MagicMock, patch, call +@pytest.fixture(autouse=True) +def _mock_dotenv(monkeypatch): + """gateway.run imports dotenv at module level; stub it so tests run without the package.""" + fake = types.ModuleType("dotenv") + fake.load_dotenv = lambda *a, **kw: None + monkeypatch.setitem(sys.modules, "dotenv", fake) + + def _make_runner(): from gateway.run import GatewayRunner @@ -57,105 +67,151 @@ class TestCronSessionBypass: runner.session_store.load_transcript.assert_called_once_with("session_abc123") +def _make_flush_context(monkeypatch, memory_dir=None): + """Return (runner, tmp_agent, fake_run_agent) with run_agent mocked in sys.modules.""" + tmp_agent = MagicMock() + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = MagicMock(return_value=tmp_agent) + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + runner = _make_runner() + runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS + return runner, tmp_agent, memory_dir + + class TestMemoryInjection: """The flush prompt should include current memory state from disk.""" - def test_memory_content_injected_into_flush_prompt(self, tmp_path): + def test_memory_content_injected_into_flush_prompt(self, tmp_path, monkeypatch): """When memory files exist, their content appears in the flush prompt.""" - runner = _make_runner() - runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS - - tmp_agent = MagicMock() memory_dir = tmp_path / "memories" memory_dir.mkdir() (memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode") (memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST") + runner, tmp_agent, _ = _make_flush_context(monkeypatch, memory_dir) + with ( patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}), patch("gateway.run._resolve_gateway_model", return_value="test-model"), - patch("run_agent.AIAgent", return_value=tmp_agent), - # Intercept `from tools.memory_tool import MEMORY_DIR` inside the function patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}), ): runner._flush_memories_for_session("session_123") tmp_agent.run_conversation.assert_called_once() - call_kwargs = tmp_agent.run_conversation.call_args.kwargs - flush_prompt = call_kwargs.get("user_message", "") - - # Verify both memory sections appear in the prompt + flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "") + assert "Agent knows Python" in flush_prompt assert "User prefers dark mode" in flush_prompt assert "Name: Alice" in flush_prompt assert "Timezone: PST" in flush_prompt - # Verify the stale-overwrite warning is present assert "Do NOT overwrite or remove entries" in flush_prompt assert "current live state of memory" in flush_prompt - def test_flush_works_without_memory_files(self, tmp_path): + def test_flush_works_without_memory_files(self, tmp_path, monkeypatch): """When no memory files exist, flush still runs without the guard.""" - runner = _make_runner() - runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS - - tmp_agent = MagicMock() empty_dir = tmp_path / "no_memories" empty_dir.mkdir() + runner, tmp_agent, _ = _make_flush_context(monkeypatch) + with ( patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}), patch("gateway.run._resolve_gateway_model", return_value="test-model"), - patch("run_agent.AIAgent", return_value=tmp_agent), patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}), ): runner._flush_memories_for_session("session_456") - # Should still run, just without the memory guard section tmp_agent.run_conversation.assert_called_once() flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "") assert "Do NOT overwrite or remove entries" not in flush_prompt assert "Review the conversation above" in flush_prompt - def test_empty_memory_files_no_injection(self, tmp_path): + def test_empty_memory_files_no_injection(self, tmp_path, monkeypatch): """Empty memory files should not trigger the guard section.""" - runner = _make_runner() - runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS - - tmp_agent = MagicMock() memory_dir = tmp_path / "memories" memory_dir.mkdir() (memory_dir / "MEMORY.md").write_text("") (memory_dir / "USER.md").write_text(" \n ") # whitespace only + runner, tmp_agent, _ = _make_flush_context(monkeypatch) + with ( patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}), patch("gateway.run._resolve_gateway_model", return_value="test-model"), - patch("run_agent.AIAgent", return_value=tmp_agent), patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}), ): runner._flush_memories_for_session("session_789") tmp_agent.run_conversation.assert_called_once() flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "") - # No memory content → no guard section assert "current live state of memory" not in flush_prompt +class TestFlushAgentSilenced: + """The flush agent must not produce any terminal output.""" + + def test_print_fn_set_to_noop(self, tmp_path, monkeypatch): + """_print_fn on the flush agent must be a no-op so tool output never leaks.""" + runner = _make_runner() + runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS + + captured_agent = {} + + def _fake_ai_agent(*args, **kwargs): + agent = MagicMock() + captured_agent["instance"] = agent + return agent + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = _fake_ai_agent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + with ( + patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}), + patch("gateway.run._resolve_gateway_model", return_value="test-model"), + patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=tmp_path)}), + ): + runner._flush_memories_for_session("session_silent") + + agent = captured_agent["instance"] + assert agent._print_fn is not None, "_print_fn should be overridden to suppress output" + # Confirm it is callable and produces no output (no exception) + agent._print_fn("should be silenced") + + def test_kawaii_spinner_respects_print_fn(self): + """KawaiiSpinner must route all output through print_fn when supplied.""" + from agent.display import KawaiiSpinner + + written = [] + spinner = KawaiiSpinner("test", print_fn=lambda *a, **kw: written.append(a)) + spinner._write("hello") + assert written == [("hello",)], "spinner should route through print_fn" + + # A no-op print_fn must produce no output to stdout + import io, sys + buf = io.StringIO() + old_stdout = sys.stdout + sys.stdout = buf + try: + silent_spinner = KawaiiSpinner("silent", print_fn=lambda *a, **kw: None) + silent_spinner._write("should not appear") + silent_spinner.stop("done") + finally: + sys.stdout = old_stdout + assert buf.getvalue() == "", "no-op print_fn spinner must not write to stdout" + + class TestFlushPromptStructure: """Verify the flush prompt retains its core instructions.""" - def test_core_instructions_present(self): + def test_core_instructions_present(self, monkeypatch): """The flush prompt should still contain the original guidance.""" - runner = _make_runner() - runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS - - tmp_agent = MagicMock() + runner, tmp_agent, _ = _make_flush_context(monkeypatch) with ( patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}), patch("gateway.run._resolve_gateway_model", return_value="test-model"), - patch("run_agent.AIAgent", return_value=tmp_agent), - # Make the import fail gracefully so we test without memory files patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}), ): runner._flush_memories_for_session("session_struct") diff --git a/tests/gateway/test_hooks.py b/tests/gateway/test_hooks.py index 039ce6b2e..1301aebae 100644 --- a/tests/gateway/test_hooks.py +++ b/tests/gateway/test_hooks.py @@ -29,13 +29,18 @@ class TestHookRegistryInit: assert reg._handlers == {} +def _patch_no_builtins(reg): + """Suppress built-in hook registration so tests only exercise user-hook discovery.""" + return patch.object(reg, "_register_builtin_hooks") + + class TestDiscoverAndLoad: def test_loads_valid_hook(self, tmp_path): _create_hook(tmp_path, "my-hook", '["agent:start"]', "def handle(event_type, context):\n pass\n") reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path): + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 1 @@ -48,7 +53,7 @@ class TestDiscoverAndLoad: (hook_dir / "handler.py").write_text("def handle(e, c): pass\n") reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path): + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 0 @@ -59,7 +64,7 @@ class TestDiscoverAndLoad: (hook_dir / "HOOK.yaml").write_text("name: bad\nevents: ['agent:start']\n") reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path): + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 0 @@ -71,7 +76,7 @@ class TestDiscoverAndLoad: (hook_dir / "handler.py").write_text("def handle(e, c): pass\n") reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path): + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 0 @@ -83,14 +88,14 @@ class TestDiscoverAndLoad: (hook_dir / "handler.py").write_text("def something_else(): pass\n") reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path): + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 0 def test_nonexistent_hooks_dir(self, tmp_path): reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path / "nonexistent"): + with patch("gateway.hooks.HOOKS_DIR", tmp_path / "nonexistent"), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 0 @@ -102,7 +107,7 @@ class TestDiscoverAndLoad: "def handle(e, c): pass\n") reg = HookRegistry() - with patch("gateway.hooks.HOOKS_DIR", tmp_path): + with patch("gateway.hooks.HOOKS_DIR", tmp_path), _patch_no_builtins(reg): reg.discover_and_load() assert len(reg.loaded_hooks) == 2 diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 31e59caeb..5a9879f60 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -1,4 +1,5 @@ """Tests for Matrix platform adapter.""" +import asyncio import json import re import pytest @@ -446,3 +447,199 @@ class TestMatrixRequirements: monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) from gateway.platforms.matrix import check_matrix_requirements assert check_matrix_requirements() is False + + +# --------------------------------------------------------------------------- +# Access-token auth / E2EE bootstrap +# --------------------------------------------------------------------------- + +class TestMatrixAccessTokenAuth: + @pytest.mark.asyncio + async def test_connect_fetches_device_id_from_whoami_for_access_token(self): + from gateway.platforms.matrix import MatrixAdapter + + config = PlatformConfig( + enabled=True, + token="syt_test_access_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + "encryption": True, + }, + ) + adapter = MatrixAdapter(config) + + class FakeWhoamiResponse: + def __init__(self, user_id, device_id): + self.user_id = user_id + self.device_id = device_id + + class FakeSyncResponse: + def __init__(self): + self.rooms = MagicMock(join={}) + + fake_client = MagicMock() + fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) + fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) + fake_client.keys_upload = AsyncMock() + fake_client.keys_query = AsyncMock() + fake_client.keys_claim = AsyncMock() + fake_client.send_to_device_messages = AsyncMock(return_value=[]) + fake_client.get_users_for_key_claiming = MagicMock(return_value={}) + fake_client.close = AsyncMock() + fake_client.add_event_callback = MagicMock() + fake_client.rooms = {} + fake_client.account_data = {} + fake_client.olm = object() + fake_client.should_upload_keys = False + fake_client.should_query_keys = False + fake_client.should_claim_keys = False + + def _restore_login(user_id, device_id, access_token): + fake_client.user_id = user_id + fake_client.device_id = device_id + fake_client.access_token = access_token + fake_client.olm = object() + + fake_client.restore_login = MagicMock(side_effect=_restore_login) + + fake_nio = MagicMock() + fake_nio.AsyncClient = MagicMock(return_value=fake_client) + fake_nio.WhoamiResponse = FakeWhoamiResponse + fake_nio.SyncResponse = FakeSyncResponse + fake_nio.LoginResponse = type("LoginResponse", (), {}) + fake_nio.RoomMessageText = type("RoomMessageText", (), {}) + fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) + fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) + fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) + fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {}) + fake_nio.MegolmEvent = type("MegolmEvent", (), {}) + + with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): + with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): + assert await adapter.connect() is True + + fake_client.restore_login.assert_called_once_with( + "@bot:example.org", "DEV123", "syt_test_access_token" + ) + assert fake_client.access_token == "syt_test_access_token" + assert fake_client.user_id == "@bot:example.org" + assert fake_client.device_id == "DEV123" + fake_client.whoami.assert_awaited_once() + + await adapter.disconnect() + + +class TestMatrixE2EEMaintenance: + @pytest.mark.asyncio + async def test_sync_loop_runs_e2ee_maintenance_requests(self): + adapter = _make_adapter() + adapter._encryption = True + adapter._closing = False + + class FakeSyncError: + pass + + async def _sync_once(timeout=30000): + adapter._closing = True + return MagicMock() + + fake_client = MagicMock() + fake_client.sync = AsyncMock(side_effect=_sync_once) + fake_client.send_to_device_messages = AsyncMock(return_value=[]) + fake_client.keys_upload = AsyncMock() + fake_client.keys_query = AsyncMock() + fake_client.get_users_for_key_claiming = MagicMock( + return_value={"@alice:example.org": ["DEVICE1"]} + ) + fake_client.keys_claim = AsyncMock() + fake_client.olm = object() + fake_client.should_upload_keys = True + fake_client.should_query_keys = True + fake_client.should_claim_keys = True + + adapter._client = fake_client + + fake_nio = MagicMock() + fake_nio.SyncError = FakeSyncError + + with patch.dict("sys.modules", {"nio": fake_nio}): + await adapter._sync_loop() + + fake_client.sync.assert_awaited_once_with(timeout=30000) + fake_client.send_to_device_messages.assert_awaited_once() + fake_client.keys_upload.assert_awaited_once() + fake_client.keys_query.assert_awaited_once() + fake_client.keys_claim.assert_awaited_once_with( + {"@alice:example.org": ["DEVICE1"]} + ) + + +class TestMatrixEncryptedSendFallback: + @pytest.mark.asyncio + async def test_send_retries_with_ignored_unverified_devices(self): + adapter = _make_adapter() + adapter._encryption = True + + class FakeRoomSendResponse: + def __init__(self, event_id): + self.event_id = event_id + + class FakeOlmUnverifiedDeviceError(Exception): + pass + + fake_client = MagicMock() + fake_client.room_send = AsyncMock(side_effect=[ + FakeOlmUnverifiedDeviceError("unverified"), + FakeRoomSendResponse("$event123"), + ]) + adapter._client = fake_client + adapter._run_e2ee_maintenance = AsyncMock() + + fake_nio = MagicMock() + fake_nio.RoomSendResponse = FakeRoomSendResponse + fake_nio.OlmUnverifiedDeviceError = FakeOlmUnverifiedDeviceError + + with patch.dict("sys.modules", {"nio": fake_nio}): + result = await adapter.send("!room:example.org", "hello") + + assert result.success is True + assert result.message_id == "$event123" + adapter._run_e2ee_maintenance.assert_awaited_once() + assert fake_client.room_send.await_count == 2 + first_call = fake_client.room_send.await_args_list[0] + second_call = fake_client.room_send.await_args_list[1] + assert first_call.kwargs.get("ignore_unverified_devices") is False + assert second_call.kwargs.get("ignore_unverified_devices") is True + + @pytest.mark.asyncio + async def test_send_retries_after_timeout_in_encrypted_room(self): + adapter = _make_adapter() + adapter._encryption = True + + class FakeRoomSendResponse: + def __init__(self, event_id): + self.event_id = event_id + + fake_client = MagicMock() + fake_client.room_send = AsyncMock(side_effect=[ + asyncio.TimeoutError(), + FakeRoomSendResponse("$event456"), + ]) + adapter._client = fake_client + adapter._run_e2ee_maintenance = AsyncMock() + + fake_nio = MagicMock() + fake_nio.RoomSendResponse = FakeRoomSendResponse + + with patch.dict("sys.modules", {"nio": fake_nio}): + result = await adapter.send("!room:example.org", "hello") + + assert result.success is True + assert result.message_id == "$event456" + adapter._run_e2ee_maintenance.assert_awaited_once() + assert fake_client.room_send.await_count == 2 + second_call = fake_client.room_send.await_args_list[1] + assert second_call.kwargs.get("ignore_unverified_devices") is True diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py index 238506b05..a7a586ff5 100644 --- a/tests/gateway/test_mattermost.py +++ b/tests/gateway/test_mattermost.py @@ -1,5 +1,6 @@ """Tests for Mattermost platform adapter.""" import json +import os import time import pytest from unittest.mock import MagicMock, patch, AsyncMock @@ -269,6 +270,7 @@ class TestMattermostWebSocketParsing: def setup_method(self): self.adapter = _make_adapter() self.adapter._bot_user_id = "bot_user_id" + self.adapter._bot_username = "hermes-bot" # Mock handle_message to capture the MessageEvent without processing self.adapter.handle_message = AsyncMock() @@ -293,7 +295,8 @@ class TestMattermostWebSocketParsing: await self.adapter._handle_ws_event(event) assert self.adapter.handle_message.called msg_event = self.adapter.handle_message.call_args[0][0] - assert msg_event.text == "@bot_user_id Hello from Matrix!" + # @mention is stripped from the message text + assert msg_event.text == "Hello from Matrix!" assert msg_event.message_id == "post_abc" @pytest.mark.asyncio @@ -410,6 +413,87 @@ class TestMattermostWebSocketParsing: assert not self.adapter.handle_message.called +# --------------------------------------------------------------------------- +# Mention behavior (require_mention + free_response_channels) +# --------------------------------------------------------------------------- + +class TestMattermostMentionBehavior: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._bot_user_id = "bot_user_id" + self.adapter._bot_username = "hermes-bot" + self.adapter.handle_message = AsyncMock() + + def _make_event(self, message, channel_type="O", channel_id="chan_456"): + post_data = { + "id": "post_mention", + "user_id": "user_123", + "channel_id": channel_id, + "message": message, + } + return { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": channel_type, + "sender_name": "@alice", + }, + } + + @pytest.mark.asyncio + async def test_require_mention_true_skips_without_mention(self): + """Default: messages without @mention in channels are skipped.""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("MATTERMOST_REQUIRE_MENTION", None) + os.environ.pop("MATTERMOST_FREE_RESPONSE_CHANNELS", None) + await self.adapter._handle_ws_event(self._make_event("hello")) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_require_mention_false_responds_to_all(self): + """MATTERMOST_REQUIRE_MENTION=false: respond to all channel messages.""" + with patch.dict(os.environ, {"MATTERMOST_REQUIRE_MENTION": "false"}): + await self.adapter._handle_ws_event(self._make_event("hello")) + assert self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_free_response_channel_responds_without_mention(self): + """Messages in free-response channels don't need @mention.""" + with patch.dict(os.environ, {"MATTERMOST_FREE_RESPONSE_CHANNELS": "chan_456,chan_789"}): + os.environ.pop("MATTERMOST_REQUIRE_MENTION", None) + await self.adapter._handle_ws_event(self._make_event("hello", channel_id="chan_456")) + assert self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_non_free_channel_still_requires_mention(self): + """Channels NOT in free-response list still require @mention.""" + with patch.dict(os.environ, {"MATTERMOST_FREE_RESPONSE_CHANNELS": "chan_789"}): + os.environ.pop("MATTERMOST_REQUIRE_MENTION", None) + await self.adapter._handle_ws_event(self._make_event("hello", channel_id="chan_456")) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_dm_always_responds(self): + """DMs (channel_type=D) always respond regardless of mention settings.""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("MATTERMOST_REQUIRE_MENTION", None) + await self.adapter._handle_ws_event(self._make_event("hello", channel_type="D")) + assert self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_mention_stripped_from_text(self): + """@mention is stripped from message text.""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("MATTERMOST_REQUIRE_MENTION", None) + await self.adapter._handle_ws_event( + self._make_event("@hermes-bot what is 2+2") + ) + assert self.adapter.handle_message.called + msg = self.adapter.handle_message.call_args[0][0] + assert "@hermes-bot" not in msg.text + assert "2+2" in msg.text + + # --------------------------------------------------------------------------- # File upload (send_image) # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_media_download_retry.py b/tests/gateway/test_media_download_retry.py new file mode 100644 index 000000000..ad00da246 --- /dev/null +++ b/tests/gateway/test_media_download_retry.py @@ -0,0 +1,722 @@ +""" +Tests for media download retry logic added in PR #2982. + +Covers: +- gateway/platforms/base.py: cache_image_from_url +- gateway/platforms/slack.py: SlackAdapter._download_slack_file + SlackAdapter._download_slack_file_bytes +- gateway/platforms/mattermost.py: MattermostAdapter._send_url_as_file + +All async tests use asyncio.run() directly — pytest-asyncio is not installed +in this environment. +""" + +import asyncio +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import httpx + +# --------------------------------------------------------------------------- +# Helpers for building httpx exceptions +# --------------------------------------------------------------------------- + +def _make_http_status_error(status_code: int) -> httpx.HTTPStatusError: + request = httpx.Request("GET", "http://example.com/img.jpg") + response = httpx.Response(status_code=status_code, request=request) + return httpx.HTTPStatusError( + f"HTTP {status_code}", request=request, response=response + ) + + +def _make_timeout_error() -> httpx.TimeoutException: + return httpx.TimeoutException("timed out") + + +# --------------------------------------------------------------------------- +# cache_image_from_url (base.py) +# --------------------------------------------------------------------------- + +class TestCacheImageFromUrl: + """Tests for gateway.platforms.base.cache_image_from_url""" + + def test_success_on_first_attempt(self, tmp_path, monkeypatch): + """A clean 200 response caches the image and returns a path.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + fake_response = MagicMock() + fake_response.content = b"\xff\xd8\xff fake jpeg" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=fake_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client): + from gateway.platforms.base import cache_image_from_url + return await cache_image_from_url( + "http://example.com/img.jpg", ext=".jpg" + ) + + path = asyncio.run(run()) + assert path.endswith(".jpg") + mock_client.get.assert_called_once() + + def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch): + """A timeout on the first attempt is retried; second attempt succeeds.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + fake_response = MagicMock() + fake_response.content = b"image data" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_timeout_error(), fake_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + mock_sleep = AsyncMock() + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", mock_sleep): + from gateway.platforms.base import cache_image_from_url + return await cache_image_from_url( + "http://example.com/img.jpg", ext=".jpg", retries=2 + ) + + path = asyncio.run(run()) + assert path.endswith(".jpg") + assert mock_client.get.call_count == 2 + mock_sleep.assert_called_once() + + def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch): + """A 429 response on the first attempt is retried; second attempt succeeds.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + ok_response = MagicMock() + ok_response.content = b"image data" + ok_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_http_status_error(429), ok_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + from gateway.platforms.base import cache_image_from_url + return await cache_image_from_url( + "http://example.com/img.jpg", ext=".jpg", retries=2 + ) + + path = asyncio.run(run()) + assert path.endswith(".jpg") + assert mock_client.get.call_count == 2 + + def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch): + """Timeout on every attempt raises after all retries are consumed.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_timeout_error()) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + from gateway.platforms.base import cache_image_from_url + await cache_image_from_url( + "http://example.com/img.jpg", ext=".jpg", retries=2 + ) + + with pytest.raises(httpx.TimeoutException): + asyncio.run(run()) + + # 3 total calls: initial + 2 retries + assert mock_client.get.call_count == 3 + + def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch): + """A 404 (non-retryable) is raised immediately without any retry.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + mock_sleep = AsyncMock() + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_http_status_error(404)) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", mock_sleep): + from gateway.platforms.base import cache_image_from_url + await cache_image_from_url( + "http://example.com/img.jpg", ext=".jpg", retries=2 + ) + + with pytest.raises(httpx.HTTPStatusError): + asyncio.run(run()) + + # Only 1 attempt, no sleep + assert mock_client.get.call_count == 1 + mock_sleep.assert_not_called() + + +# --------------------------------------------------------------------------- +# cache_audio_from_url (base.py) +# --------------------------------------------------------------------------- + +class TestCacheAudioFromUrl: + """Tests for gateway.platforms.base.cache_audio_from_url""" + + def test_success_on_first_attempt(self, tmp_path, monkeypatch): + """A clean 200 response caches the audio and returns a path.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + fake_response = MagicMock() + fake_response.content = b"\x00\x01 fake audio" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=fake_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client): + from gateway.platforms.base import cache_audio_from_url + return await cache_audio_from_url( + "http://example.com/voice.ogg", ext=".ogg" + ) + + path = asyncio.run(run()) + assert path.endswith(".ogg") + mock_client.get.assert_called_once() + + def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch): + """A timeout on the first attempt is retried; second attempt succeeds.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + fake_response = MagicMock() + fake_response.content = b"audio data" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_timeout_error(), fake_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + mock_sleep = AsyncMock() + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", mock_sleep): + from gateway.platforms.base import cache_audio_from_url + return await cache_audio_from_url( + "http://example.com/voice.ogg", ext=".ogg", retries=2 + ) + + path = asyncio.run(run()) + assert path.endswith(".ogg") + assert mock_client.get.call_count == 2 + mock_sleep.assert_called_once() + + def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch): + """A 429 response on the first attempt is retried; second attempt succeeds.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + ok_response = MagicMock() + ok_response.content = b"audio data" + ok_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_http_status_error(429), ok_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + from gateway.platforms.base import cache_audio_from_url + return await cache_audio_from_url( + "http://example.com/voice.ogg", ext=".ogg", retries=2 + ) + + path = asyncio.run(run()) + assert path.endswith(".ogg") + assert mock_client.get.call_count == 2 + + def test_retries_on_500_then_succeeds(self, tmp_path, monkeypatch): + """A 500 response on the first attempt is retried; second attempt succeeds.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + ok_response = MagicMock() + ok_response.content = b"audio data" + ok_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_http_status_error(500), ok_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + from gateway.platforms.base import cache_audio_from_url + return await cache_audio_from_url( + "http://example.com/voice.ogg", ext=".ogg", retries=2 + ) + + path = asyncio.run(run()) + assert path.endswith(".ogg") + assert mock_client.get.call_count == 2 + + def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch): + """Timeout on every attempt raises after all retries are consumed.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_timeout_error()) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + from gateway.platforms.base import cache_audio_from_url + await cache_audio_from_url( + "http://example.com/voice.ogg", ext=".ogg", retries=2 + ) + + with pytest.raises(httpx.TimeoutException): + asyncio.run(run()) + + # 3 total calls: initial + 2 retries + assert mock_client.get.call_count == 3 + + def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch): + """A 404 (non-retryable) is raised immediately without any retry.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + mock_sleep = AsyncMock() + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_http_status_error(404)) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", mock_sleep): + from gateway.platforms.base import cache_audio_from_url + await cache_audio_from_url( + "http://example.com/voice.ogg", ext=".ogg", retries=2 + ) + + with pytest.raises(httpx.HTTPStatusError): + asyncio.run(run()) + + # Only 1 attempt, no sleep + assert mock_client.get.call_count == 1 + mock_sleep.assert_not_called() + + +# --------------------------------------------------------------------------- +# Slack mock setup (mirrors existing test_slack.py approach) +# --------------------------------------------------------------------------- + +def _ensure_slack_mock(): + if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"): + return + slack_bolt = MagicMock() + slack_bolt.async_app.AsyncApp = MagicMock + slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock + slack_sdk = MagicMock() + slack_sdk.web.async_client.AsyncWebClient = MagicMock + for name, mod in [ + ("slack_bolt", slack_bolt), + ("slack_bolt.async_app", slack_bolt.async_app), + ("slack_bolt.adapter", slack_bolt.adapter), + ("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode), + ("slack_bolt.adapter.socket_mode.async_handler", + slack_bolt.adapter.socket_mode.async_handler), + ("slack_sdk", slack_sdk), + ("slack_sdk.web", slack_sdk.web), + ("slack_sdk.web.async_client", slack_sdk.web.async_client), + ]: + sys.modules.setdefault(name, mod) + + +_ensure_slack_mock() + +import gateway.platforms.slack as _slack_mod # noqa: E402 +_slack_mod.SLACK_AVAILABLE = True + +from gateway.platforms.slack import SlackAdapter # noqa: E402 +from gateway.config import Platform, PlatformConfig # noqa: E402 + + +def _make_slack_adapter(): + config = PlatformConfig(enabled=True, token="xoxb-fake-token") + adapter = SlackAdapter(config) + adapter._app = MagicMock() + adapter._app.client = AsyncMock() + adapter._bot_user_id = "U_BOT" + adapter._running = True + return adapter + + +# --------------------------------------------------------------------------- +# SlackAdapter._download_slack_file +# --------------------------------------------------------------------------- + +class TestSlackDownloadSlackFile: + """Tests for SlackAdapter._download_slack_file""" + + def test_success_on_first_attempt(self, tmp_path, monkeypatch): + """Successful download on first try returns a cached file path.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + adapter = _make_slack_adapter() + + fake_response = MagicMock() + fake_response.content = b"fake image bytes" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=fake_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client): + return await adapter._download_slack_file( + "https://files.slack.com/img.jpg", ext=".jpg" + ) + + path = asyncio.run(run()) + assert path.endswith(".jpg") + mock_client.get.assert_called_once() + + def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch): + """Timeout on first attempt triggers retry; success on second.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + adapter = _make_slack_adapter() + + fake_response = MagicMock() + fake_response.content = b"image bytes" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_timeout_error(), fake_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + mock_sleep = AsyncMock() + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", mock_sleep): + return await adapter._download_slack_file( + "https://files.slack.com/img.jpg", ext=".jpg" + ) + + path = asyncio.run(run()) + assert path.endswith(".jpg") + assert mock_client.get.call_count == 2 + mock_sleep.assert_called_once() + + def test_raises_after_max_retries(self, tmp_path, monkeypatch): + """Timeout on every attempt eventually raises after 3 total tries.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + adapter = _make_slack_adapter() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_timeout_error()) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await adapter._download_slack_file( + "https://files.slack.com/img.jpg", ext=".jpg" + ) + + with pytest.raises(httpx.TimeoutException): + asyncio.run(run()) + + assert mock_client.get.call_count == 3 + + def test_non_retryable_403_raises_immediately(self, tmp_path, monkeypatch): + """A 403 is not retried; it raises immediately.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + adapter = _make_slack_adapter() + + mock_sleep = AsyncMock() + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_http_status_error(403)) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", mock_sleep): + await adapter._download_slack_file( + "https://files.slack.com/img.jpg", ext=".jpg" + ) + + with pytest.raises(httpx.HTTPStatusError): + asyncio.run(run()) + + assert mock_client.get.call_count == 1 + mock_sleep.assert_not_called() + + +# --------------------------------------------------------------------------- +# SlackAdapter._download_slack_file_bytes +# --------------------------------------------------------------------------- + +class TestSlackDownloadSlackFileBytes: + """Tests for SlackAdapter._download_slack_file_bytes""" + + def test_success_returns_bytes(self): + """Successful download returns raw bytes.""" + adapter = _make_slack_adapter() + + fake_response = MagicMock() + fake_response.content = b"raw bytes here" + fake_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=fake_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client): + return await adapter._download_slack_file_bytes( + "https://files.slack.com/file.bin" + ) + + result = asyncio.run(run()) + assert result == b"raw bytes here" + + def test_retries_on_429_then_succeeds(self): + """429 on first attempt is retried; raw bytes returned on second.""" + adapter = _make_slack_adapter() + + ok_response = MagicMock() + ok_response.content = b"final bytes" + ok_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get = AsyncMock( + side_effect=[_make_http_status_error(429), ok_response] + ) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + return await adapter._download_slack_file_bytes( + "https://files.slack.com/file.bin" + ) + + result = asyncio.run(run()) + assert result == b"final bytes" + assert mock_client.get.call_count == 2 + + def test_raises_after_max_retries(self): + """Persistent timeouts raise after all 3 attempts are exhausted.""" + adapter = _make_slack_adapter() + + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=_make_timeout_error()) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + async def run(): + with patch("httpx.AsyncClient", return_value=mock_client), \ + patch("asyncio.sleep", new_callable=AsyncMock): + await adapter._download_slack_file_bytes( + "https://files.slack.com/file.bin" + ) + + with pytest.raises(httpx.TimeoutException): + asyncio.run(run()) + + assert mock_client.get.call_count == 3 + + +# --------------------------------------------------------------------------- +# MattermostAdapter._send_url_as_file +# --------------------------------------------------------------------------- + +def _make_mm_adapter(): + """Build a minimal MattermostAdapter with mocked internals.""" + from gateway.platforms.mattermost import MattermostAdapter + config = PlatformConfig( + enabled=True, token="mm-token-fake", + extra={"url": "https://mm.example.com"}, + ) + adapter = MattermostAdapter(config) + adapter._session = MagicMock() + adapter._upload_file = AsyncMock(return_value="file-id-123") + adapter._api_post = AsyncMock(return_value={"id": "post-id-abc"}) + adapter.send = AsyncMock(return_value=MagicMock(success=True)) + return adapter + + +def _make_aiohttp_resp(status: int, content: bytes = b"file bytes", + content_type: str = "image/jpeg"): + """Build a context-manager mock for an aiohttp response.""" + resp = MagicMock() + resp.status = status + resp.content_type = content_type + resp.read = AsyncMock(return_value=content) + resp.__aenter__ = AsyncMock(return_value=resp) + resp.__aexit__ = AsyncMock(return_value=False) + return resp + + +class TestMattermostSendUrlAsFile: + """Tests for MattermostAdapter._send_url_as_file""" + + def test_success_on_first_attempt(self): + """200 on first attempt → file uploaded and post created.""" + adapter = _make_mm_adapter() + resp = _make_aiohttp_resp(200) + adapter._session.get = MagicMock(return_value=resp) + + async def run(): + with patch("asyncio.sleep", new_callable=AsyncMock): + return await adapter._send_url_as_file( + "C123", "http://cdn.example.com/img.png", "caption", None + ) + + result = asyncio.run(run()) + assert result.success + adapter._upload_file.assert_called_once() + adapter._api_post.assert_called_once() + + def test_retries_on_429_then_succeeds(self): + """429 on first attempt is retried; 200 on second attempt succeeds.""" + adapter = _make_mm_adapter() + + resp_429 = _make_aiohttp_resp(429) + resp_200 = _make_aiohttp_resp(200) + adapter._session.get = MagicMock(side_effect=[resp_429, resp_200]) + + mock_sleep = AsyncMock() + + async def run(): + with patch("asyncio.sleep", mock_sleep): + return await adapter._send_url_as_file( + "C123", "http://cdn.example.com/img.png", None, None + ) + + result = asyncio.run(run()) + assert result.success + assert adapter._session.get.call_count == 2 + mock_sleep.assert_called_once() + + def test_retries_on_500_then_succeeds(self): + """5xx on first attempt is retried; 200 on second attempt succeeds.""" + adapter = _make_mm_adapter() + + resp_500 = _make_aiohttp_resp(500) + resp_200 = _make_aiohttp_resp(200) + adapter._session.get = MagicMock(side_effect=[resp_500, resp_200]) + + async def run(): + with patch("asyncio.sleep", new_callable=AsyncMock): + return await adapter._send_url_as_file( + "C123", "http://cdn.example.com/img.png", None, None + ) + + result = asyncio.run(run()) + assert result.success + assert adapter._session.get.call_count == 2 + + def test_falls_back_to_text_after_max_retries_on_5xx(self): + """Three consecutive 500s exhaust retries; falls back to send() with URL text.""" + adapter = _make_mm_adapter() + + resp_500 = _make_aiohttp_resp(500) + adapter._session.get = MagicMock(return_value=resp_500) + + async def run(): + with patch("asyncio.sleep", new_callable=AsyncMock): + return await adapter._send_url_as_file( + "C123", "http://cdn.example.com/img.png", "my caption", None + ) + + asyncio.run(run()) + + adapter.send.assert_called_once() + text_arg = adapter.send.call_args[0][1] + assert "http://cdn.example.com/img.png" in text_arg + + def test_falls_back_on_client_error(self): + """aiohttp.ClientError on every attempt falls back to send() with URL.""" + import aiohttp + + adapter = _make_mm_adapter() + + error_resp = MagicMock() + error_resp.__aenter__ = AsyncMock( + side_effect=aiohttp.ClientConnectionError("connection refused") + ) + error_resp.__aexit__ = AsyncMock(return_value=False) + adapter._session.get = MagicMock(return_value=error_resp) + + async def run(): + with patch("asyncio.sleep", new_callable=AsyncMock): + return await adapter._send_url_as_file( + "C123", "http://cdn.example.com/img.png", None, None + ) + + asyncio.run(run()) + + adapter.send.assert_called_once() + text_arg = adapter.send.call_args[0][1] + assert "http://cdn.example.com/img.png" in text_arg + + def test_non_retryable_404_falls_back_immediately(self): + """404 is non-retryable (< 500, != 429); send() is called right away.""" + adapter = _make_mm_adapter() + + resp_404 = _make_aiohttp_resp(404) + adapter._session.get = MagicMock(return_value=resp_404) + + mock_sleep = AsyncMock() + + async def run(): + with patch("asyncio.sleep", mock_sleep): + return await adapter._send_url_as_file( + "C123", "http://cdn.example.com/img.png", None, None + ) + + asyncio.run(run()) + + adapter.send.assert_called_once() + # No sleep — fell back on first attempt + mock_sleep.assert_not_called() + assert adapter._session.get.call_count == 1 diff --git a/tests/gateway/test_platform_base.py b/tests/gateway/test_platform_base.py index 1aa0e1144..13b52f24f 100644 --- a/tests/gateway/test_platform_base.py +++ b/tests/gateway/test_platform_base.py @@ -62,6 +62,18 @@ class TestMessageEventGetCommand: event = MessageEvent(text="/") assert event.get_command() == "" + def test_command_with_at_botname(self): + event = MessageEvent(text="/new@TigerNanoBot") + assert event.get_command() == "new" + + def test_command_with_at_botname_and_args(self): + event = MessageEvent(text="/compress@TigerNanoBot") + assert event.get_command() == "compress" + + def test_command_mixed_case_with_at_botname(self): + event = MessageEvent(text="/RESET@TigerNanoBot") + assert event.get_command() == "reset" + class TestMessageEventGetCommandArgs: def test_command_with_args(self): diff --git a/tests/gateway/test_platform_reconnect.py b/tests/gateway/test_platform_reconnect.py index 3073f2f5d..68dfd2044 100644 --- a/tests/gateway/test_platform_reconnect.py +++ b/tests/gateway/test_platform_reconnect.py @@ -344,6 +344,7 @@ class TestRuntimeDisconnectQueuing: async def test_retryable_runtime_error_queued_for_reconnect(self): """Retryable runtime errors should add the platform to _failed_platforms.""" runner = _make_runner() + runner.stop = AsyncMock() adapter = StubAdapter(succeed=True) adapter._set_fatal_error("network_error", "DNS failure", retryable=True) @@ -371,8 +372,12 @@ class TestRuntimeDisconnectQueuing: assert Platform.TELEGRAM not in runner._failed_platforms @pytest.mark.asyncio - async def test_retryable_error_prevents_shutdown_when_queued(self): - """Gateway should not shut down if failed platforms are queued for reconnection.""" + async def test_retryable_error_exits_for_service_restart_when_all_down(self): + """Gateway should exit with failure when all platforms fail with retryable errors. + + This lets systemd Restart=on-failure restart the process, which is more + reliable than in-process background reconnection after exhausted retries. + """ runner = _make_runner() runner.stop = AsyncMock() @@ -382,7 +387,28 @@ class TestRuntimeDisconnectQueuing: await runner._handle_adapter_fatal_error(adapter) - # stop() should NOT have been called since we have platforms queued + # stop() SHOULD be called — gateway exits for systemd restart + runner.stop.assert_called_once() + assert runner._exit_with_failure is True + assert Platform.TELEGRAM in runner._failed_platforms + + @pytest.mark.asyncio + async def test_retryable_error_no_exit_when_other_adapters_still_connected(self): + """Gateway should NOT exit if some adapters are still connected.""" + runner = _make_runner() + runner.stop = AsyncMock() + + failing_adapter = StubAdapter(succeed=True) + failing_adapter._set_fatal_error("network_error", "DNS failure", retryable=True) + runner.adapters[Platform.TELEGRAM] = failing_adapter + + # Another adapter is still connected + healthy_adapter = StubAdapter(succeed=True) + runner.adapters[Platform.DISCORD] = healthy_adapter + + await runner._handle_adapter_fatal_error(failing_adapter) + + # stop() should NOT have been called — Discord is still up runner.stop.assert_not_called() assert Platform.TELEGRAM in runner._failed_platforms diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index c4839133c..95ad2fba7 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -14,8 +14,8 @@ from gateway.session import SessionSource class ProgressCaptureAdapter(BasePlatformAdapter): - def __init__(self): - super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM) + def __init__(self, platform=Platform.TELEGRAM): + super().__init__(PlatformConfig(enabled=True, token="***"), platform) self.sent = [] self.edits = [] self.typing = [] @@ -76,7 +76,7 @@ def _make_runner(adapter): GatewayRunner = gateway_run.GatewayRunner runner = object.__new__(GatewayRunner) - runner.adapters = {Platform.TELEGRAM: adapter} + runner.adapters = {adapter.platform: adapter} runner._voice_mode = {} runner._prefill_messages = [] runner._ephemeral_system_prompt = "" @@ -133,3 +133,87 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa ] assert adapter.edits assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing) + + +@pytest.mark.asyncio +async def test_run_agent_progress_does_not_use_event_message_id_for_telegram_dm(monkeypatch, tmp_path): + """Telegram DM progress must not reuse event message id as thread metadata.""" + monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all") + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = FakeAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + adapter = ProgressCaptureAdapter(platform=Platform.TELEGRAM) + runner = _make_runner(adapter) + gateway_run = importlib.import_module("gateway.run") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="12345", + chat_type="dm", + thread_id=None, + ) + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="sess-2", + session_key="agent:main:telegram:dm:12345", + event_message_id="777", + ) + + assert result["final_response"] == "done" + assert adapter.sent + assert adapter.sent[0]["metadata"] is None + assert all(call["metadata"] is None for call in adapter.typing) + + +@pytest.mark.asyncio +async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch, tmp_path): + """Slack DM progress should keep event ts fallback threading.""" + monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all") + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = FakeAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + adapter = ProgressCaptureAdapter(platform=Platform.SLACK) + runner = _make_runner(adapter) + gateway_run = importlib.import_module("gateway.run") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + + source = SessionSource( + platform=Platform.SLACK, + chat_id="D123", + chat_type="dm", + thread_id=None, + ) + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="sess-3", + session_key="agent:main:slack:dm:D123", + event_message_id="1234567890.000001", + ) + + assert result["final_response"] == "done" + assert adapter.sent + assert adapter.sent[0]["metadata"] == {"thread_id": "1234567890.000001"} + assert all(call["metadata"] == {"thread_id": "1234567890.000001"} for call in adapter.typing) diff --git a/tests/gateway/test_runner_fatal_adapter.py b/tests/gateway/test_runner_fatal_adapter.py index 6eb285059..13b9a7d99 100644 --- a/tests/gateway/test_runner_fatal_adapter.py +++ b/tests/gateway/test_runner_fatal_adapter.py @@ -89,7 +89,8 @@ async def test_runner_queues_retryable_runtime_fatal_for_reconnection(monkeypatc await runner._handle_adapter_fatal_error(adapter) - # Should NOT shut down — platform is queued for reconnection - runner.stop.assert_not_awaited() + # Should shut down with failure — systemd Restart=on-failure will restart + runner.stop.assert_awaited_once() + assert runner._exit_with_failure is True assert Platform.WHATSAPP in runner._failed_platforms assert runner._failed_platforms[Platform.WHATSAPP]["attempts"] == 0 diff --git a/tests/gateway/test_send_image_file.py b/tests/gateway/test_send_image_file.py index 25a841717..cb0e43673 100644 --- a/tests/gateway/test_send_image_file.py +++ b/tests/gateway/test_send_image_file.py @@ -76,7 +76,7 @@ def _ensure_telegram_mock(): telegram_mod.constants.ChatType.CHANNEL = "channel" telegram_mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, telegram_mod) diff --git a/tests/gateway/test_send_retry.py b/tests/gateway/test_send_retry.py new file mode 100644 index 000000000..4005f4071 --- /dev/null +++ b/tests/gateway/test_send_retry.py @@ -0,0 +1,231 @@ +""" +Tests for BasePlatformAdapter._send_with_retry and _is_retryable_error. + +Verifies that: +- Transient network errors trigger retry with backoff +- Permanent errors fall back to plain-text immediately (no retry) +- User receives a delivery-failure notice when all retries are exhausted +- Successful sends on retry return success +- SendResult.retryable flag is respected +""" +import pytest +from unittest.mock import AsyncMock, patch + +from gateway.platforms.base import BasePlatformAdapter, SendResult, _RETRYABLE_ERROR_PATTERNS +from gateway.platforms.base import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Minimal concrete adapter for testing (no real network) +# --------------------------------------------------------------------------- + +class _StubAdapter(BasePlatformAdapter): + def __init__(self): + cfg = PlatformConfig() + super().__init__(cfg, Platform.TELEGRAM) + self._send_results = [] # queue of SendResult to return per call + self._send_calls = [] # record of (chat_id, content) sent + + def _next_result(self) -> SendResult: + if self._send_results: + return self._send_results.pop(0) + return SendResult(success=True, message_id="ok") + + async def send(self, chat_id, content, reply_to=None, metadata=None, **kwargs) -> SendResult: + self._send_calls.append((chat_id, content)) + return self._next_result() + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + pass + + async def send_typing(self, chat_id, metadata=None) -> None: + pass + + async def get_chat_info(self, chat_id): + return {"name": "test", "type": "direct", "chat_id": chat_id} + + +# --------------------------------------------------------------------------- +# _is_retryable_error +# --------------------------------------------------------------------------- + +class TestIsRetryableError: + def test_none_is_not_retryable(self): + assert not _StubAdapter._is_retryable_error(None) + + def test_empty_string_is_not_retryable(self): + assert not _StubAdapter._is_retryable_error("") + + @pytest.mark.parametrize("pattern", _RETRYABLE_ERROR_PATTERNS) + def test_known_pattern_is_retryable(self, pattern): + assert _StubAdapter._is_retryable_error(f"httpx.{pattern.title()}: connection dropped") + + def test_permission_error_not_retryable(self): + assert not _StubAdapter._is_retryable_error("Forbidden: bot was blocked by the user") + + def test_bad_request_not_retryable(self): + assert not _StubAdapter._is_retryable_error("Bad Request: can't parse entities") + + def test_case_insensitive(self): + assert _StubAdapter._is_retryable_error("CONNECTERROR: host unreachable") + + +# --------------------------------------------------------------------------- +# _send_with_retry — success on first attempt +# --------------------------------------------------------------------------- + +class TestSendWithRetrySuccess: + @pytest.mark.asyncio + async def test_success_first_attempt(self): + adapter = _StubAdapter() + adapter._send_results = [SendResult(success=True, message_id="123")] + result = await adapter._send_with_retry("chat1", "hello") + assert result.success + assert len(adapter._send_calls) == 1 + + @pytest.mark.asyncio + async def test_returns_message_id(self): + adapter = _StubAdapter() + adapter._send_results = [SendResult(success=True, message_id="abc")] + result = await adapter._send_with_retry("chat1", "hi") + assert result.message_id == "abc" + + +# --------------------------------------------------------------------------- +# _send_with_retry — network error with successful retry +# --------------------------------------------------------------------------- + +class TestSendWithRetryNetworkRetry: + @pytest.mark.asyncio + async def test_retries_on_connect_error_and_succeeds(self): + adapter = _StubAdapter() + adapter._send_results = [ + SendResult(success=False, error="httpx.ConnectError: connection refused"), + SendResult(success=True, message_id="ok"), + ] + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0) + assert result.success + assert len(adapter._send_calls) == 2 # initial + 1 retry + + @pytest.mark.asyncio + async def test_retries_on_timeout_and_succeeds(self): + adapter = _StubAdapter() + adapter._send_results = [ + SendResult(success=False, error="ReadTimeout: request timed out"), + SendResult(success=False, error="ReadTimeout: request timed out"), + SendResult(success=True, message_id="ok"), + ] + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "hello", max_retries=3, base_delay=0) + assert result.success + assert len(adapter._send_calls) == 3 + + @pytest.mark.asyncio + async def test_retryable_flag_respected(self): + """SendResult.retryable=True should trigger retry even if error string doesn't match.""" + adapter = _StubAdapter() + adapter._send_results = [ + SendResult(success=False, error="internal platform error", retryable=True), + SendResult(success=True, message_id="ok"), + ] + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0) + assert result.success + assert len(adapter._send_calls) == 2 + + @pytest.mark.asyncio + async def test_network_to_nonnetwork_transition_falls_back_to_plaintext(self): + """If error switches from network to formatting mid-retry, fall through to plain-text fallback.""" + adapter = _StubAdapter() + adapter._send_results = [ + SendResult(success=False, error="httpx.ConnectError: host unreachable"), + SendResult(success=False, error="Bad Request: can't parse entities"), + SendResult(success=True, message_id="fallback_ok"), # plain-text fallback + ] + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0) + assert result.success + # 3 calls: initial (network) + 1 retry (non-network, breaks loop) + plain-text fallback + assert len(adapter._send_calls) == 3 + assert "plain text" in adapter._send_calls[-1][1].lower() + + +# --------------------------------------------------------------------------- +# _send_with_retry — all retries exhausted → user notification +# --------------------------------------------------------------------------- + +class TestSendWithRetryExhausted: + @pytest.mark.asyncio + async def test_sends_user_notice_after_exhaustion(self): + adapter = _StubAdapter() + network_err = SendResult(success=False, error="httpx.ConnectError: host unreachable") + # initial + 2 retries + notice attempt + adapter._send_results = [network_err, network_err, network_err, SendResult(success=True)] + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0) + # Result is the last failed one (before notice) + assert not result.success + # 4 total calls: 1 initial + 2 retries + 1 notice + assert len(adapter._send_calls) == 4 + # The notice content should mention delivery failure + notice_content = adapter._send_calls[-1][1] + assert "delivery failed" in notice_content.lower() or "Message delivery failed" in notice_content + + @pytest.mark.asyncio + async def test_notice_send_exception_doesnt_propagate(self): + """If the notice itself throws, _send_with_retry should not raise.""" + adapter = _StubAdapter() + network_err = SendResult(success=False, error="ConnectError") + adapter._send_results = [network_err, network_err, network_err] + + original_send = adapter.send + call_count = [0] + + async def send_with_notice_failure(chat_id, content, **kwargs): + call_count[0] += 1 + if call_count[0] > 3: + raise RuntimeError("notice send also failed") + return network_err + + adapter.send = send_with_notice_failure + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0) + assert not result.success # still failed, but no exception raised + + +# --------------------------------------------------------------------------- +# _send_with_retry — non-network failure → plain-text fallback (no retry) +# --------------------------------------------------------------------------- + +class TestSendWithRetryFallback: + @pytest.mark.asyncio + async def test_non_network_error_falls_back_immediately(self): + adapter = _StubAdapter() + adapter._send_results = [ + SendResult(success=False, error="Bad Request: can't parse entities"), + SendResult(success=True, message_id="fallback_ok"), + ] + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0) + # No sleep — no retry loop for non-network errors + mock_sleep.assert_not_called() + assert result.success + assert len(adapter._send_calls) == 2 + # Fallback content should be plain-text notice + assert "plain text" in adapter._send_calls[1][1].lower() + + @pytest.mark.asyncio + async def test_fallback_failure_logged_but_not_raised(self): + adapter = _StubAdapter() + adapter._send_results = [ + SendResult(success=False, error="Forbidden: bot blocked"), + SendResult(success=False, error="Forbidden: bot blocked"), + ] + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await adapter._send_with_retry("chat1", "hello", max_retries=2) + assert not result.success + assert len(adapter._send_calls) == 2 # original + fallback only diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index f31a80c3a..82281acc2 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -846,7 +846,7 @@ class TestLastPromptTokens: store.update_session("k1", model="openai/gpt-5.4") - store._db.update_token_counts.assert_called_once_with( + store._db.set_token_counts.assert_called_once_with( "s1", input_tokens=0, output_tokens=0, @@ -858,4 +858,48 @@ class TestLastPromptTokens: billing_provider=None, billing_base_url=None, model="openai/gpt-5.4", + absolute=True, ) + + +class TestRewriteTranscriptPreservesReasoning: + """rewrite_transcript must not drop reasoning fields from SQLite.""" + + def test_reasoning_survives_rewrite(self, tmp_path): + from hermes_state import SessionDB + + db = SessionDB(db_path=tmp_path / "test.db") + session_id = "reasoning-test" + db.create_session(session_id=session_id, source="cli") + + # Insert a message WITH all three reasoning fields + db.append_message( + session_id=session_id, + role="assistant", + content="The answer is 42.", + reasoning="I need to think step by step.", + reasoning_details=[{"type": "summary", "text": "step by step"}], + codex_reasoning_items=[{"id": "r1", "type": "reasoning"}], + ) + + # Verify all three were stored + before = db.get_messages_as_conversation(session_id) + assert before[0].get("reasoning") == "I need to think step by step." + assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}] + assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}] + + # Now simulate /retry: build the SessionStore and call rewrite_transcript + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._db = db + store._loaded = True + + # rewrite_transcript receives the messages that load_transcript returned + store.rewrite_transcript(session_id, before) + + # Load again — all three reasoning fields must survive + after = db.get_messages_as_conversation(session_id) + assert after[0].get("reasoning") == "I need to think step by step." + assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}] + assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}] diff --git a/tests/gateway/test_session_hygiene.py b/tests/gateway/test_session_hygiene.py index 80d249347..b8ff8f8a8 100644 --- a/tests/gateway/test_session_hygiene.py +++ b/tests/gateway/test_session_hygiene.py @@ -304,8 +304,12 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t class FakeCompressAgent: def __init__(self, **kwargs): self.model = kwargs.get("model") + self.session_id = kwargs.get("session_id", "fake-session") + self._print_fn = None def _compress_context(self, messages, *_args, **_kwargs): + # Simulate real _compress_context: create a new session_id + self.session_id = f"{self.session_id}_compressed" return ([{"role": "assistant", "content": "compressed"}], None) fake_run_agent = types.ModuleType("run_agent") diff --git a/tests/gateway/test_session_info.py b/tests/gateway/test_session_info.py new file mode 100644 index 000000000..5f04b1a48 --- /dev/null +++ b/tests/gateway/test_session_info.py @@ -0,0 +1,110 @@ +"""Tests for GatewayRunner._format_session_info — session config surfacing.""" + +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path + +from gateway.run import GatewayRunner + + +@pytest.fixture() +def runner(): + """Create a bare GatewayRunner without __init__.""" + return GatewayRunner.__new__(GatewayRunner) + + +def _patch_info(tmp_path, config_yaml, model, runtime): + """Return a context-manager stack that patches _format_session_info deps.""" + cfg_path = tmp_path / "config.yaml" + if config_yaml is not None: + cfg_path.write_text(config_yaml) + return ( + patch("gateway.run._hermes_home", tmp_path), + patch("gateway.run._resolve_gateway_model", return_value=model), + patch("gateway.run._resolve_runtime_agent_kwargs", return_value=runtime), + ) + + +class TestFormatSessionInfo: + + def test_includes_model_name(self, runner, tmp_path): + p1, p2, p3 = _patch_info(tmp_path, "model:\n default: anthropic/claude-opus-4.6\n provider: openrouter\n", + "anthropic/claude-opus-4.6", + {"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"}) + with p1, p2, p3: + info = runner._format_session_info() + assert "claude-opus-4.6" in info + + def test_includes_provider(self, runner, tmp_path): + p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n", + "test-model", + {"provider": "openrouter", "base_url": "", "api_key": ""}) + with p1, p2, p3: + info = runner._format_session_info() + assert "openrouter" in info + + def test_config_context_length(self, runner, tmp_path): + p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 32768\n", + "test-model", + {"provider": "custom", "base_url": "", "api_key": ""}) + with p1, p2, p3: + info = runner._format_session_info() + assert "32K" in info + assert "config" in info + + def test_default_fallback_hint(self, runner, tmp_path): + p1, p2, p3 = _patch_info(tmp_path, "model:\n default: unknown-model-xyz\n", + "unknown-model-xyz", + {"provider": "", "base_url": "", "api_key": ""}) + with p1, p2, p3: + info = runner._format_session_info() + assert "128K" in info + assert "model.context_length" in info + + def test_local_endpoint_shown(self, runner, tmp_path): + p1, p2, p3 = _patch_info( + tmp_path, + "model:\n default: qwen3:8b\n provider: custom\n base_url: http://localhost:11434/v1\n context_length: 8192\n", + "qwen3:8b", + {"provider": "custom", "base_url": "http://localhost:11434/v1", "api_key": ""}) + with p1, p2, p3: + info = runner._format_session_info() + assert "localhost:11434" in info + assert "8K" in info + + def test_cloud_endpoint_hidden(self, runner, tmp_path): + p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n", + "test-model", + {"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"}) + with p1, p2, p3: + info = runner._format_session_info() + assert "Endpoint" not in info + + def test_million_context_format(self, runner, tmp_path): + p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 1000000\n", + "test-model", + {"provider": "", "base_url": "", "api_key": ""}) + with p1, p2, p3: + info = runner._format_session_info() + assert "1.0M" in info + + def test_missing_config(self, runner, tmp_path): + """No config.yaml should not crash.""" + p1, p2, p3 = _patch_info(tmp_path, None, # don't create config + "anthropic/claude-sonnet-4.6", + {"provider": "openrouter", "base_url": "", "api_key": ""}) + with p1, p2, p3: + info = runner._format_session_info() + assert "Model" in info + assert "Context" in info + + def test_runtime_resolution_failure_doesnt_crash(self, runner, tmp_path): + """If runtime resolution raises, should still produce output.""" + cfg_path = tmp_path / "config.yaml" + cfg_path.write_text("model:\n default: test-model\n context_length: 4096\n") + with patch("gateway.run._hermes_home", tmp_path), \ + patch("gateway.run._resolve_gateway_model", return_value="test-model"), \ + patch("gateway.run._resolve_runtime_agent_kwargs", side_effect=RuntimeError("no creds")): + info = runner._format_session_info() + assert "4K" in info + assert "config" in info diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index 8bf5537f4..acd6513e5 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -1,11 +1,42 @@ """Tests for Signal messenger platform adapter.""" +import base64 import json import pytest from unittest.mock import MagicMock, patch, AsyncMock +from urllib.parse import quote from gateway.config import Platform, PlatformConfig +# --------------------------------------------------------------------------- +# Shared Helpers +# --------------------------------------------------------------------------- + +def _make_signal_adapter(monkeypatch, account="+15551234567", **extra): + """Create a SignalAdapter with sensible test defaults.""" + monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", extra.pop("group_allowed", "")) + from gateway.platforms.signal import SignalAdapter + config = PlatformConfig() + config.enabled = True + config.extra = { + "http_url": "http://localhost:8080", + "account": account, + **extra, + } + return SignalAdapter(config) + + +def _stub_rpc(return_value): + """Return an async mock for SignalAdapter._rpc that captures call params.""" + captured = [] + + async def mock_rpc(method, params, rpc_id=None): + captured.append({"method": method, "params": dict(params)}) + return return_value + + return mock_rpc, captured + + # --------------------------------------------------------------------------- # Platform & Config # --------------------------------------------------------------------------- @@ -61,48 +92,22 @@ class TestSignalConfigLoading: # --------------------------------------------------------------------------- class TestSignalAdapterInit: - def _make_config(self, **extra): - config = PlatformConfig() - config.enabled = True - config.extra = { - "http_url": "http://localhost:8080", - "account": "+15551234567", - **extra, - } - return config - def test_init_parses_config(self, monkeypatch): - monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "group123,group456") - - from gateway.platforms.signal import SignalAdapter - adapter = SignalAdapter(self._make_config()) - + adapter = _make_signal_adapter(monkeypatch, group_allowed="group123,group456") assert adapter.http_url == "http://localhost:8080" assert adapter.account == "+15551234567" assert "group123" in adapter.group_allow_from def test_init_empty_allowlist(self, monkeypatch): - monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "") - - from gateway.platforms.signal import SignalAdapter - adapter = SignalAdapter(self._make_config()) - + adapter = _make_signal_adapter(monkeypatch) assert len(adapter.group_allow_from) == 0 def test_init_strips_trailing_slash(self, monkeypatch): - monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "") - - from gateway.platforms.signal import SignalAdapter - adapter = SignalAdapter(self._make_config(http_url="http://localhost:8080/")) - + adapter = _make_signal_adapter(monkeypatch, http_url="http://localhost:8080/") assert adapter.http_url == "http://localhost:8080" def test_self_message_filtering(self, monkeypatch): - monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "") - - from gateway.platforms.signal import SignalAdapter - adapter = SignalAdapter(self._make_config()) - + adapter = _make_signal_adapter(monkeypatch) assert adapter._account_normalized == "+15551234567" @@ -189,6 +194,73 @@ class TestSignalHelpers: assert check_signal_requirements() is False +# --------------------------------------------------------------------------- +# SSE URL Encoding (Bug Fix: phone numbers with + must be URL-encoded) +# --------------------------------------------------------------------------- + +class TestSignalSSEUrlEncoding: + """Verify that phone numbers with + are URL-encoded in the SSE endpoint.""" + + def test_sse_url_encodes_plus_in_account(self): + """The + in E.164 phone numbers must be percent-encoded in the SSE query string.""" + encoded = quote("+31612345678", safe="") + assert encoded == "%2B31612345678" + + def test_sse_url_encoding_preserves_digits(self): + """Digits and country codes should pass through URL encoding unchanged.""" + assert quote("+15551234567", safe="") == "%2B15551234567" + + +# --------------------------------------------------------------------------- +# Attachment Fetch (Bug Fix: parameter must be "id" not "attachmentId") +# --------------------------------------------------------------------------- + +class TestSignalAttachmentFetch: + """Verify that _fetch_attachment uses the correct RPC parameter name.""" + + @pytest.mark.asyncio + async def test_fetch_attachment_uses_id_parameter(self, monkeypatch): + """RPC getAttachment must use 'id', not 'attachmentId' (signal-cli requirement).""" + adapter = _make_signal_adapter(monkeypatch) + + png_data = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + b64_data = base64.b64encode(png_data).decode() + + adapter._rpc, captured = _stub_rpc({"data": b64_data}) + + with patch("gateway.platforms.signal.cache_image_from_bytes", return_value="/tmp/test.png"): + await adapter._fetch_attachment("attachment-123") + + call = captured[0] + assert call["method"] == "getAttachment" + assert call["params"]["id"] == "attachment-123" + assert "attachmentId" not in call["params"], "Must NOT use 'attachmentId' — causes NullPointerException in signal-cli" + assert call["params"]["account"] == "+15551234567" + + @pytest.mark.asyncio + async def test_fetch_attachment_returns_none_on_empty(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + adapter._rpc, _ = _stub_rpc(None) + path, ext = await adapter._fetch_attachment("missing-id") + assert path is None + assert ext == "" + + @pytest.mark.asyncio + async def test_fetch_attachment_handles_dict_response(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + + pdf_data = b"%PDF-1.4" + b"\x00" * 100 + b64_data = base64.b64encode(pdf_data).decode() + + adapter._rpc, _ = _stub_rpc({"data": b64_data}) + + with patch("gateway.platforms.signal.cache_document_from_bytes", return_value="/tmp/test.pdf"): + path, ext = await adapter._fetch_attachment("doc-456") + + assert path == "/tmp/test.pdf" + assert ext == ".pdf" + + # --------------------------------------------------------------------------- # Session Source # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_sse_agent_cancel.py b/tests/gateway/test_sse_agent_cancel.py new file mode 100644 index 000000000..6b5306fbe --- /dev/null +++ b/tests/gateway/test_sse_agent_cancel.py @@ -0,0 +1,280 @@ +"""Tests for SSE client disconnect → agent task cancellation. + +When a streaming /v1/chat/completions client disconnects mid-stream +(network drop, browser tab close), the agent is interrupted via +agent.interrupt() so it stops making LLM API calls, and the asyncio +task wrapper is cancelled. +""" + +import asyncio +import json +import queue +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_adapter(): + """Build a minimal APIServerAdapter with mocked internals.""" + from gateway.platforms.api_server import APIServerAdapter + from gateway.config import PlatformConfig + + config = PlatformConfig(enabled=True, token="test-key") + adapter = APIServerAdapter(config) + return adapter + + +def _make_request(): + """Build a mock aiohttp request.""" + req = MagicMock() + req.headers = {} + return req + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestSSEAgentCancelOnDisconnect: + """gateway/platforms/api_server.py — _write_sse_chat_completion()""" + + def test_agent_task_cancelled_on_client_disconnect(self): + """When response.write raises ConnectionResetError (client dropped), + the agent task must be cancelled.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("hello ") # Some data already queued + + # Agent task that runs forever (simulates a long LLM call) + agent_done = asyncio.Event() + + async def fake_agent(): + await agent_done.wait() + return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + + # Mock response that raises ConnectionResetError on second write + mock_response = AsyncMock(spec=web.StreamResponse) + call_count = 0 + + async def write_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise ConnectionResetError("client disconnected") + + mock_response.write = AsyncMock(side_effect=write_side_effect) + mock_response.prepare = AsyncMock() + + with patch.object(type(adapter), '_write_sse_chat_completion', + adapter._write_sse_chat_completion): + # Patch StreamResponse creation + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-123", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + # The critical assertion: agent_task must be cancelled + assert agent_task.cancelled() or agent_task.done() + # Clean up + agent_done.set() + + asyncio.run(run()) + + def test_agent_task_not_cancelled_on_normal_completion(self): + """On normal stream completion, agent task should NOT be cancelled.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("hello") + stream_q.put(None) # End-of-stream sentinel + + async def fake_agent(): + return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + await asyncio.sleep(0) # Let agent complete + + mock_response = AsyncMock(spec=web.StreamResponse) + mock_response.write = AsyncMock() + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-456", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + # Agent should have completed normally, not been cancelled + assert agent_task.done() + assert not agent_task.cancelled() + + asyncio.run(run()) + + def test_broken_pipe_also_cancels_agent(self): + """BrokenPipeError (another disconnect variant) also cancels the task.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + + async def fake_agent(): + await asyncio.sleep(999) # Never completes + return {}, {} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + + mock_response = AsyncMock(spec=web.StreamResponse) + mock_response.write = AsyncMock(side_effect=BrokenPipeError("pipe broken")) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-789", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + assert agent_task.cancelled() or agent_task.done() + + asyncio.run(run()) + + def test_already_done_task_not_cancelled_on_disconnect(self): + """If agent already finished before disconnect, don't try to cancel.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("data") + + async def fake_agent(): + return {"final_response": "done"}, {} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + await asyncio.sleep(0) # Let agent complete + + mock_response = AsyncMock(spec=web.StreamResponse) + call_count = 0 + + async def write_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise ConnectionResetError("late disconnect") + + mock_response.write = AsyncMock(side_effect=write_side_effect) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-done", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + # Task was already done — should not be cancelled + assert agent_task.done() + assert not agent_task.cancelled() + + asyncio.run(run()) + + def test_agent_interrupt_called_on_disconnect(self): + """When the client disconnects, agent.interrupt() must be called + so the agent thread stops making LLM API calls.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + stream_q.put("hello ") + + agent_done = asyncio.Event() + + async def fake_agent(): + await agent_done.wait() + return {"final_response": "done"}, {} + + # Mock agent with an interrupt method + mock_agent = MagicMock() + mock_agent.interrupt = MagicMock() + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + agent_ref = [mock_agent] + + mock_response = AsyncMock(spec=web.StreamResponse) + call_count = 0 + + async def write_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise ConnectionResetError("client disconnected") + + mock_response.write = AsyncMock(side_effect=write_side_effect) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-int", "gpt-4", 1234567890, + stream_q, agent_task, agent_ref, + ) + + # agent.interrupt() must have been called + mock_agent.interrupt.assert_called_once_with("SSE client disconnected") + # Clean up + agent_done.set() + + asyncio.run(run()) + + def test_agent_ref_none_still_cancels_task(self): + """When agent_ref is not provided (None), the task is still cancelled + on disconnect — just without the interrupt() call.""" + adapter = _make_adapter() + + stream_q = queue.Queue() + + async def fake_agent(): + await asyncio.sleep(999) + return {}, {} + + async def run(): + from aiohttp import web + + agent_task = asyncio.ensure_future(fake_agent()) + + mock_response = AsyncMock(spec=web.StreamResponse) + mock_response.write = AsyncMock(side_effect=BrokenPipeError("gone")) + mock_response.prepare = AsyncMock() + + with patch("gateway.platforms.api_server.web.StreamResponse", + return_value=mock_response): + # No agent_ref passed — should still handle disconnect cleanly + await adapter._write_sse_chat_completion( + _make_request(), "cmpl-noref", "gpt-4", 1234567890, + stream_q, agent_task, + ) + + assert agent_task.cancelled() or agent_task.done() + + asyncio.run(run()) diff --git a/tests/gateway/test_telegram_conflict.py b/tests/gateway/test_telegram_conflict.py index c96768de2..9f1074648 100644 --- a/tests/gateway/test_telegram_conflict.py +++ b/tests/gateway/test_telegram_conflict.py @@ -20,7 +20,7 @@ def _ensure_telegram_mock(): telegram_mod.constants.ChatType.CHANNEL = "channel" telegram_mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, telegram_mod) @@ -29,6 +29,14 @@ _ensure_telegram_mock() from gateway.platforms.telegram import TelegramAdapter # noqa: E402 +@pytest.fixture(autouse=True) +def _no_auto_discovery(monkeypatch): + """Disable DoH auto-discovery so connect() uses the plain builder chain.""" + async def _noop(): + return [] + monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop) + + @pytest.mark.asyncio async def test_connect_rejects_same_host_token_lock(monkeypatch): adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token")) diff --git a/tests/gateway/test_telegram_documents.py b/tests/gateway/test_telegram_documents.py index 0472bdbac..11a8df5f8 100644 --- a/tests/gateway/test_telegram_documents.py +++ b/tests/gateway/test_telegram_documents.py @@ -45,7 +45,7 @@ def _ensure_telegram_mock(): telegram_mod.constants.ChatType.CHANNEL = "channel" telegram_mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, telegram_mod) diff --git a/tests/gateway/test_telegram_format.py b/tests/gateway/test_telegram_format.py index 446a3e1b9..7a50aded4 100644 --- a/tests/gateway/test_telegram_format.py +++ b/tests/gateway/test_telegram_format.py @@ -28,7 +28,7 @@ def _ensure_telegram_mock(): mod.constants.ChatType.SUPERGROUP = "supergroup" mod.constants.ChatType.CHANNEL = "channel" mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, mod) diff --git a/tests/gateway/test_telegram_network.py b/tests/gateway/test_telegram_network.py new file mode 100644 index 000000000..2770211f3 --- /dev/null +++ b/tests/gateway/test_telegram_network.py @@ -0,0 +1,644 @@ +"""Tests for gateway.platforms.telegram_network – fallback transport layer. + +Background +---------- +api.telegram.org resolves to an IP (e.g. 149.154.166.110) that is unreachable +from some networks. The workaround: route TCP through a different IP in the +same Telegram-owned 149.154.160.0/20 block (e.g. 149.154.167.220) while +keeping TLS SNI and the Host header as api.telegram.org so Telegram's edge +servers still accept the request. This is the programmatic equivalent of: + + curl --resolve api.telegram.org:443:149.154.167.220 https://api.telegram.org/bot/getMe + +The TelegramFallbackTransport implements this: try the primary (DNS-resolved) +path first, and on ConnectTimeout / ConnectError fall through to configured +fallback IPs in order, then "stick" to whichever IP works. +""" + +import httpx +import pytest + +from gateway.platforms import telegram_network as tnet + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class FakeTransport(httpx.AsyncBaseTransport): + """Records calls and raises / returns based on a host→action mapping.""" + + def __init__(self, calls, behavior): + self.calls = calls + self.behavior = behavior + self.closed = False + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + self.calls.append( + { + "url_host": request.url.host, + "host_header": request.headers.get("host"), + "sni_hostname": request.extensions.get("sni_hostname"), + "path": request.url.path, + } + ) + action = self.behavior.get(request.url.host, "ok") + if action == "timeout": + raise httpx.ConnectTimeout("timed out") + if action == "connect_error": + raise httpx.ConnectError("connect error") + if isinstance(action, Exception): + raise action + return httpx.Response(200, request=request, text="ok") + + async def aclose(self) -> None: + self.closed = True + + +def _fake_transport_factory(calls, behavior): + """Returns a factory that creates FakeTransport instances.""" + instances = [] + + def factory(**kwargs): + t = FakeTransport(calls, behavior) + instances.append(t) + return t + + factory.instances = instances + return factory + + +def _telegram_request(path="/botTOKEN/getMe"): + return httpx.Request("GET", f"https://api.telegram.org{path}") + + +# ═══════════════════════════════════════════════════════════════════════════ +# IP parsing & validation +# ═══════════════════════════════════════════════════════════════════════════ + +class TestParseFallbackIpEnv: + def test_filters_invalid_and_ipv6(self, caplog): + ips = tnet.parse_fallback_ip_env("149.154.167.220, bad, 2001:67c:4e8:f004::9,149.154.167.220") + assert ips == ["149.154.167.220", "149.154.167.220"] + assert "Ignoring invalid Telegram fallback IP" in caplog.text + assert "Ignoring non-IPv4 Telegram fallback IP" in caplog.text + + def test_none_returns_empty(self): + assert tnet.parse_fallback_ip_env(None) == [] + + def test_empty_string_returns_empty(self): + assert tnet.parse_fallback_ip_env("") == [] + + def test_whitespace_only_returns_empty(self): + assert tnet.parse_fallback_ip_env(" , , ") == [] + + def test_single_valid_ip(self): + assert tnet.parse_fallback_ip_env("149.154.167.220") == ["149.154.167.220"] + + def test_multiple_valid_ips(self): + ips = tnet.parse_fallback_ip_env("149.154.167.220, 149.154.167.221") + assert ips == ["149.154.167.220", "149.154.167.221"] + + def test_rejects_leading_zeros(self, caplog): + """Leading zeros are ambiguous (octal?) so ipaddress rejects them.""" + ips = tnet.parse_fallback_ip_env("149.154.167.010") + assert ips == [] + assert "Ignoring invalid" in caplog.text + + +class TestNormalizeFallbackIps: + def test_deduplication_happens_at_transport_level(self): + """_normalize does not dedup; TelegramFallbackTransport.__init__ does.""" + raw = ["149.154.167.220", "149.154.167.220"] + assert tnet._normalize_fallback_ips(raw) == ["149.154.167.220", "149.154.167.220"] + + def test_empty_strings_skipped(self): + assert tnet._normalize_fallback_ips(["", " ", "149.154.167.220"]) == ["149.154.167.220"] + + +# ═══════════════════════════════════════════════════════════════════════════ +# Request rewriting +# ═══════════════════════════════════════════════════════════════════════════ + +class TestRewriteRequestForIp: + def test_preserves_host_and_sni(self): + request = _telegram_request() + rewritten = tnet._rewrite_request_for_ip(request, "149.154.167.220") + + assert rewritten.url.host == "149.154.167.220" + assert rewritten.headers["host"] == "api.telegram.org" + assert rewritten.extensions["sni_hostname"] == "api.telegram.org" + assert rewritten.url.path == "/botTOKEN/getMe" + + def test_preserves_method_and_path(self): + request = httpx.Request("POST", "https://api.telegram.org/botTOKEN/sendMessage") + rewritten = tnet._rewrite_request_for_ip(request, "149.154.167.220") + + assert rewritten.method == "POST" + assert rewritten.url.path == "/botTOKEN/sendMessage" + + +# ═══════════════════════════════════════════════════════════════════════════ +# Fallback transport – core behavior +# ═══════════════════════════════════════════════════════════════════════════ + +class TestFallbackTransport: + """Primary path fails → try fallback IPs → stick to whichever works.""" + + @pytest.mark.asyncio + async def test_falls_back_on_connect_timeout_and_becomes_sticky(self, monkeypatch): + calls = [] + behavior = {"api.telegram.org": "timeout", "149.154.167.220": "ok"} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + resp = await transport.handle_async_request(_telegram_request()) + + assert resp.status_code == 200 + assert transport._sticky_ip == "149.154.167.220" + # First attempt was primary (api.telegram.org), second was fallback + assert calls[0]["url_host"] == "api.telegram.org" + assert calls[1]["url_host"] == "149.154.167.220" + assert calls[1]["host_header"] == "api.telegram.org" + assert calls[1]["sni_hostname"] == "api.telegram.org" + + # Second request goes straight to sticky IP + calls.clear() + resp2 = await transport.handle_async_request(_telegram_request()) + assert resp2.status_code == 200 + assert calls[0]["url_host"] == "149.154.167.220" + + @pytest.mark.asyncio + async def test_falls_back_on_connect_error(self, monkeypatch): + calls = [] + behavior = {"api.telegram.org": "connect_error", "149.154.167.220": "ok"} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + resp = await transport.handle_async_request(_telegram_request()) + + assert resp.status_code == 200 + assert transport._sticky_ip == "149.154.167.220" + + @pytest.mark.asyncio + async def test_does_not_fallback_on_non_connect_error(self, monkeypatch): + """Errors like ReadTimeout are not connection issues — don't retry.""" + calls = [] + behavior = {"api.telegram.org": httpx.ReadTimeout("read timeout"), "149.154.167.220": "ok"} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + + with pytest.raises(httpx.ReadTimeout): + await transport.handle_async_request(_telegram_request()) + + assert [c["url_host"] for c in calls] == ["api.telegram.org"] + + @pytest.mark.asyncio + async def test_all_ips_fail_raises_last_error(self, monkeypatch): + calls = [] + behavior = {"api.telegram.org": "timeout", "149.154.167.220": "timeout"} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + + with pytest.raises(httpx.ConnectTimeout): + await transport.handle_async_request(_telegram_request()) + + assert [c["url_host"] for c in calls] == ["api.telegram.org", "149.154.167.220"] + assert transport._sticky_ip is None + + @pytest.mark.asyncio + async def test_multiple_fallback_ips_tried_in_order(self, monkeypatch): + calls = [] + behavior = { + "api.telegram.org": "timeout", + "149.154.167.220": "timeout", + "149.154.167.221": "ok", + } + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"]) + resp = await transport.handle_async_request(_telegram_request()) + + assert resp.status_code == 200 + assert transport._sticky_ip == "149.154.167.221" + assert [c["url_host"] for c in calls] == [ + "api.telegram.org", + "149.154.167.220", + "149.154.167.221", + ] + + @pytest.mark.asyncio + async def test_sticky_ip_tried_first_but_falls_through_if_stale(self, monkeypatch): + """If the sticky IP stops working, the transport retries others.""" + calls = [] + behavior = { + "api.telegram.org": "timeout", + "149.154.167.220": "ok", + "149.154.167.221": "ok", + } + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"]) + + # First request: primary fails → .220 works → becomes sticky + await transport.handle_async_request(_telegram_request()) + assert transport._sticky_ip == "149.154.167.220" + + # Now .220 goes bad too + calls.clear() + behavior["149.154.167.220"] = "timeout" + + resp = await transport.handle_async_request(_telegram_request()) + assert resp.status_code == 200 + # Tried sticky (.220) first, then fell through to .221 + assert [c["url_host"] for c in calls] == ["149.154.167.220", "149.154.167.221"] + assert transport._sticky_ip == "149.154.167.221" + + +class TestFallbackTransportPassthrough: + """Requests that don't need fallback behavior.""" + + @pytest.mark.asyncio + async def test_non_telegram_host_bypasses_fallback(self, monkeypatch): + calls = [] + behavior = {} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + request = httpx.Request("GET", "https://example.com/path") + resp = await transport.handle_async_request(request) + + assert resp.status_code == 200 + assert calls[0]["url_host"] == "example.com" + assert transport._sticky_ip is None + + @pytest.mark.asyncio + async def test_empty_fallback_list_uses_primary_only(self, monkeypatch): + calls = [] + behavior = {} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport([]) + resp = await transport.handle_async_request(_telegram_request()) + + assert resp.status_code == 200 + assert calls[0]["url_host"] == "api.telegram.org" + + @pytest.mark.asyncio + async def test_primary_succeeds_no_fallback_needed(self, monkeypatch): + calls = [] + behavior = {"api.telegram.org": "ok"} + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior)) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + resp = await transport.handle_async_request(_telegram_request()) + + assert resp.status_code == 200 + assert transport._sticky_ip is None + assert len(calls) == 1 + + +class TestFallbackTransportInit: + def test_deduplicates_fallback_ips(self, monkeypatch): + monkeypatch.setattr( + tnet.httpx, "AsyncHTTPTransport", lambda **kw: FakeTransport([], {}) + ) + transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.220"]) + assert transport._fallback_ips == ["149.154.167.220"] + + def test_filters_invalid_ips_at_init(self, monkeypatch): + monkeypatch.setattr( + tnet.httpx, "AsyncHTTPTransport", lambda **kw: FakeTransport([], {}) + ) + transport = tnet.TelegramFallbackTransport(["149.154.167.220", "not-an-ip"]) + assert transport._fallback_ips == ["149.154.167.220"] + + def test_uses_proxy_env_for_primary_and_fallback_transports(self, monkeypatch): + seen_kwargs = [] + + def factory(**kwargs): + seen_kwargs.append(kwargs.copy()) + return FakeTransport([], {}) + + for key in ("HTTPS_PROXY", "HTTP_PROXY", "ALL_PROXY", "https_proxy", "http_proxy", "all_proxy"): + monkeypatch.delenv(key, raising=False) + monkeypatch.setenv("HTTPS_PROXY", "http://proxy.example:8080") + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", factory) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220"]) + + assert transport._fallback_ips == ["149.154.167.220"] + assert len(seen_kwargs) == 2 + assert all(kwargs["proxy"] == "http://proxy.example:8080" for kwargs in seen_kwargs) + + +class TestFallbackTransportClose: + @pytest.mark.asyncio + async def test_aclose_closes_all_transports(self, monkeypatch): + factory = _fake_transport_factory([], {}) + monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", factory) + + transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"]) + await transport.aclose() + + # 1 primary + 2 fallback transports + assert len(factory.instances) == 3 + assert all(t.closed for t in factory.instances) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Config layer – TELEGRAM_FALLBACK_IPS env → config.extra +# ═══════════════════════════════════════════════════════════════════════════ + +class TestConfigFallbackIps: + def test_env_var_populates_config_extra(self, monkeypatch): + from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides + + monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "149.154.167.220,149.154.167.221") + config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")}) + _apply_env_overrides(config) + + assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == [ + "149.154.167.220", "149.154.167.221", + ] + + def test_env_var_creates_platform_if_missing(self, monkeypatch): + from gateway.config import GatewayConfig, Platform, _apply_env_overrides + + monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "149.154.167.220") + config = GatewayConfig(platforms={}) + _apply_env_overrides(config) + + assert Platform.TELEGRAM in config.platforms + assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == ["149.154.167.220"] + + def test_env_var_strips_whitespace(self, monkeypatch): + from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides + + monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", " 149.154.167.220 , 149.154.167.221 ") + config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")}) + _apply_env_overrides(config) + + assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == [ + "149.154.167.220", "149.154.167.221", + ] + + def test_empty_env_var_does_not_populate(self, monkeypatch): + from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides + + monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "") + config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")}) + _apply_env_overrides(config) + + assert "fallback_ips" not in config.platforms[Platform.TELEGRAM].extra + + +# ═══════════════════════════════════════════════════════════════════════════ +# Adapter layer – _fallback_ips() reads config correctly +# ═══════════════════════════════════════════════════════════════════════════ + +class TestAdapterFallbackIps: + def _make_adapter(self, extra=None): + import sys + from unittest.mock import MagicMock + + # Ensure telegram mock is in place + if "telegram" not in sys.modules or not hasattr(sys.modules["telegram"], "__file__"): + mod = MagicMock() + mod.ext.ContextTypes.DEFAULT_TYPE = type(None) + mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2" + mod.constants.ChatType.GROUP = "group" + mod.constants.ChatType.SUPERGROUP = "supergroup" + mod.constants.ChatType.CHANNEL = "channel" + mod.constants.ChatType.PRIVATE = "private" + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): + sys.modules.setdefault(name, mod) + + from gateway.config import PlatformConfig + from gateway.platforms.telegram import TelegramAdapter + + config = PlatformConfig(enabled=True, token="test-token") + if extra: + config.extra.update(extra) + return TelegramAdapter(config) + + def test_list_in_extra(self): + adapter = self._make_adapter(extra={"fallback_ips": ["149.154.167.220"]}) + assert adapter._fallback_ips() == ["149.154.167.220"] + + def test_csv_string_in_extra(self): + adapter = self._make_adapter(extra={"fallback_ips": "149.154.167.220,149.154.167.221"}) + assert adapter._fallback_ips() == ["149.154.167.220", "149.154.167.221"] + + def test_empty_extra(self): + adapter = self._make_adapter() + assert adapter._fallback_ips() == [] + + def test_no_extra_attr(self): + adapter = self._make_adapter() + adapter.config.extra = None + assert adapter._fallback_ips() == [] + + def test_invalid_ips_filtered(self): + adapter = self._make_adapter(extra={"fallback_ips": ["149.154.167.220", "not-valid"]}) + assert adapter._fallback_ips() == ["149.154.167.220"] + + +# ═══════════════════════════════════════════════════════════════════════════ +# DoH auto-discovery +# ═══════════════════════════════════════════════════════════════════════════ + +def _doh_answer(*ips: str) -> dict: + """Build a minimal DoH JSON response with A records.""" + return {"Answer": [{"type": 1, "data": ip} for ip in ips]} + + +class FakeDoHClient: + """Mock httpx.AsyncClient for DoH queries.""" + + def __init__(self, responses: dict): + # responses: URL prefix → (status, json_body) | Exception + self._responses = responses + self.requests_made: list[dict] = [] + + @staticmethod + def _make_response(status, body, url): + """Build an httpx.Response with a request attached (needed for raise_for_status).""" + request = httpx.Request("GET", url) + return httpx.Response(status, json=body, request=request) + + async def get(self, url, *, params=None, headers=None, **kwargs): + self.requests_made.append({"url": url, "params": params, "headers": headers}) + for prefix, action in self._responses.items(): + if url.startswith(prefix): + if isinstance(action, Exception): + raise action + status, body = action + return self._make_response(status, body, url) + return self._make_response(200, {}, url) + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class TestDiscoverFallbackIps: + """Tests for discover_fallback_ips() — DoH-based auto-discovery.""" + + def _patch_doh(self, monkeypatch, responses, system_dns_ips=None): + """Wire up fake DoH client and system DNS.""" + client = FakeDoHClient(responses) + monkeypatch.setattr(tnet.httpx, "AsyncClient", lambda **kw: client) + + if system_dns_ips is not None: + addrs = [(None, None, None, None, (ip, 443)) for ip in system_dns_ips] + monkeypatch.setattr(tnet.socket, "getaddrinfo", lambda *a, **kw: addrs) + else: + def _fail(*a, **kw): + raise OSError("dns failed") + monkeypatch.setattr(tnet.socket, "getaddrinfo", _fail) + return client + + @pytest.mark.asyncio + async def test_google_and_cloudflare_ips_collected(self, monkeypatch): + self._patch_doh(monkeypatch, { + "https://dns.google": (200, _doh_answer("149.154.167.220")), + "https://cloudflare-dns.com": (200, _doh_answer("149.154.167.221")), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert "149.154.167.220" in ips + assert "149.154.167.221" in ips + + @pytest.mark.asyncio + async def test_system_dns_ip_excluded(self, monkeypatch): + """The IP from system DNS is the one that doesn't work — exclude it.""" + self._patch_doh(monkeypatch, { + "https://dns.google": (200, _doh_answer("149.154.166.110", "149.154.167.220")), + "https://cloudflare-dns.com": (200, _doh_answer("149.154.166.110")), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == ["149.154.167.220"] + + @pytest.mark.asyncio + async def test_doh_results_deduplicated(self, monkeypatch): + self._patch_doh(monkeypatch, { + "https://dns.google": (200, _doh_answer("149.154.167.220")), + "https://cloudflare-dns.com": (200, _doh_answer("149.154.167.220")), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == ["149.154.167.220"] + + @pytest.mark.asyncio + async def test_doh_timeout_falls_back_to_seed(self, monkeypatch): + self._patch_doh(monkeypatch, { + "https://dns.google": httpx.TimeoutException("timeout"), + "https://cloudflare-dns.com": httpx.TimeoutException("timeout"), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == tnet._SEED_FALLBACK_IPS + + @pytest.mark.asyncio + async def test_doh_connect_error_falls_back_to_seed(self, monkeypatch): + self._patch_doh(monkeypatch, { + "https://dns.google": httpx.ConnectError("refused"), + "https://cloudflare-dns.com": httpx.ConnectError("refused"), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == tnet._SEED_FALLBACK_IPS + + @pytest.mark.asyncio + async def test_doh_malformed_json_falls_back_to_seed(self, monkeypatch): + self._patch_doh(monkeypatch, { + "https://dns.google": (200, {"Status": 0}), # no Answer key + "https://cloudflare-dns.com": (200, {"garbage": True}), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == tnet._SEED_FALLBACK_IPS + + @pytest.mark.asyncio + async def test_one_provider_fails_other_succeeds(self, monkeypatch): + self._patch_doh(monkeypatch, { + "https://dns.google": httpx.TimeoutException("timeout"), + "https://cloudflare-dns.com": (200, _doh_answer("149.154.167.220")), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == ["149.154.167.220"] + + @pytest.mark.asyncio + async def test_system_dns_failure_keeps_all_doh_ips(self, monkeypatch): + """If system DNS fails, nothing gets excluded — all DoH IPs kept.""" + self._patch_doh(monkeypatch, { + "https://dns.google": (200, _doh_answer("149.154.166.110", "149.154.167.220")), + "https://cloudflare-dns.com": (200, _doh_answer()), + }, system_dns_ips=None) # triggers OSError + + ips = await tnet.discover_fallback_ips() + assert "149.154.166.110" in ips + assert "149.154.167.220" in ips + + @pytest.mark.asyncio + async def test_all_doh_ips_same_as_system_dns_uses_seed(self, monkeypatch): + """DoH returns only the same blocked IP — seed list is the fallback.""" + self._patch_doh(monkeypatch, { + "https://dns.google": (200, _doh_answer("149.154.166.110")), + "https://cloudflare-dns.com": (200, _doh_answer("149.154.166.110")), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == tnet._SEED_FALLBACK_IPS + + @pytest.mark.asyncio + async def test_cloudflare_gets_accept_header(self, monkeypatch): + client = self._patch_doh(monkeypatch, { + "https://dns.google": (200, _doh_answer("149.154.167.220")), + "https://cloudflare-dns.com": (200, _doh_answer("149.154.167.221")), + }, system_dns_ips=["149.154.166.110"]) + + await tnet.discover_fallback_ips() + + cf_reqs = [r for r in client.requests_made if "cloudflare" in r["url"]] + assert cf_reqs + assert cf_reqs[0]["headers"]["Accept"] == "application/dns-json" + + @pytest.mark.asyncio + async def test_non_a_records_ignored(self, monkeypatch): + """AAAA records (type 28) and CNAME (type 5) should be skipped.""" + answer = { + "Answer": [ + {"type": 5, "data": "telegram.org"}, # CNAME + {"type": 28, "data": "2001:67c:4e8:f004::9"}, # AAAA + {"type": 1, "data": "149.154.167.220"}, # A ✓ + ] + } + self._patch_doh(monkeypatch, { + "https://dns.google": (200, answer), + "https://cloudflare-dns.com": (200, _doh_answer()), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == ["149.154.167.220"] + + @pytest.mark.asyncio + async def test_invalid_ip_in_doh_response_skipped(self, monkeypatch): + answer = {"Answer": [ + {"type": 1, "data": "not-an-ip"}, + {"type": 1, "data": "149.154.167.220"}, + ]} + self._patch_doh(monkeypatch, { + "https://dns.google": (200, answer), + "https://cloudflare-dns.com": (200, _doh_answer()), + }, system_dns_ips=["149.154.166.110"]) + + ips = await tnet.discover_fallback_ips() + assert ips == ["149.154.167.220"] diff --git a/tests/gateway/test_telegram_network_reconnect.py b/tests/gateway/test_telegram_network_reconnect.py index 822382357..f78a7f208 100644 --- a/tests/gateway/test_telegram_network_reconnect.py +++ b/tests/gateway/test_telegram_network_reconnect.py @@ -27,7 +27,7 @@ def _ensure_telegram_mock(): telegram_mod.constants.ChatType.CHANNEL = "channel" telegram_mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, telegram_mod) @@ -36,6 +36,14 @@ _ensure_telegram_mock() from gateway.platforms.telegram import TelegramAdapter # noqa: E402 +@pytest.fixture(autouse=True) +def _no_auto_discovery(monkeypatch): + """Disable DoH auto-discovery so connect() uses the plain builder chain.""" + async def _noop(): + return [] + monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop) + + def _make_adapter() -> TelegramAdapter: return TelegramAdapter(PlatformConfig(enabled=True, token="test-token")) diff --git a/tests/gateway/test_telegram_reply_mode.py b/tests/gateway/test_telegram_reply_mode.py index 1ec16b512..1218afa0c 100644 --- a/tests/gateway/test_telegram_reply_mode.py +++ b/tests/gateway/test_telegram_reply_mode.py @@ -25,7 +25,7 @@ def _ensure_telegram_mock(): mod.constants.ChatType.SUPERGROUP = "supergroup" mod.constants.ChatType.CHANNEL = "channel" mod.constants.ChatType.PRIVATE = "private" - for name in ("telegram", "telegram.ext", "telegram.constants"): + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): sys.modules.setdefault(name, mod) diff --git a/tests/gateway/test_telegram_thread_fallback.py b/tests/gateway/test_telegram_thread_fallback.py new file mode 100644 index 000000000..e2817d834 --- /dev/null +++ b/tests/gateway/test_telegram_thread_fallback.py @@ -0,0 +1,199 @@ +"""Tests for Telegram send() thread_id fallback. + +When message_thread_id points to a non-existent thread, Telegram returns +BadRequest('Message thread not found'). Since BadRequest is a subclass of +NetworkError in python-telegram-bot, the old retry loop treated this as a +transient error and retried 3 times before silently failing — killing all +tool progress messages, streaming responses, and typing indicators. + +The fix detects "thread not found" BadRequest errors and retries the send +WITHOUT message_thread_id so the message still reaches the chat. +""" + +import sys +import types +from types import SimpleNamespace + +import pytest + +from gateway.config import PlatformConfig, Platform +from gateway.platforms.base import SendResult + + +# ── Fake telegram.error hierarchy ────────────────────────────────────── +# Mirrors the real python-telegram-bot hierarchy: +# BadRequest → NetworkError → TelegramError → Exception + + +class FakeNetworkError(Exception): + pass + + +class FakeBadRequest(FakeNetworkError): + pass + + +# Build a fake telegram module tree so the adapter's internal imports work +_fake_telegram = types.ModuleType("telegram") +_fake_telegram_error = types.ModuleType("telegram.error") +_fake_telegram_error.NetworkError = FakeNetworkError +_fake_telegram_error.BadRequest = FakeBadRequest +_fake_telegram.error = _fake_telegram_error +_fake_telegram_constants = types.ModuleType("telegram.constants") +_fake_telegram_constants.ParseMode = SimpleNamespace(MARKDOWN_V2="MarkdownV2") +_fake_telegram.constants = _fake_telegram_constants + + +@pytest.fixture(autouse=True) +def _inject_fake_telegram(monkeypatch): + """Inject fake telegram modules so the adapter can import from them.""" + monkeypatch.setitem(sys.modules, "telegram", _fake_telegram) + monkeypatch.setitem(sys.modules, "telegram.error", _fake_telegram_error) + monkeypatch.setitem(sys.modules, "telegram.constants", _fake_telegram_constants) + + +def _make_adapter(): + from gateway.platforms.telegram import TelegramAdapter + + config = PlatformConfig(enabled=True, token="fake-token") + adapter = object.__new__(TelegramAdapter) + adapter._config = config + adapter._platform = Platform.TELEGRAM + adapter._connected = True + adapter._dm_topics = {} + adapter._dm_topics_config = [] + adapter._reply_to_mode = "first" + adapter._fallback_ips = [] + adapter._polling_conflict_count = 0 + adapter._polling_network_error_count = 0 + adapter._polling_error_callback_ref = None + adapter.platform = Platform.TELEGRAM + return adapter + + +@pytest.mark.asyncio +async def test_send_retries_without_thread_on_thread_not_found(): + """When message_thread_id causes 'thread not found', retry without it.""" + adapter = _make_adapter() + + call_log = [] + + async def mock_send_message(**kwargs): + call_log.append(dict(kwargs)) + tid = kwargs.get("message_thread_id") + if tid is not None: + raise FakeBadRequest("Message thread not found") + return SimpleNamespace(message_id=42) + + adapter._bot = SimpleNamespace(send_message=mock_send_message) + + result = await adapter.send( + chat_id="123", + content="test message", + metadata={"thread_id": "99999"}, + ) + + assert result.success is True + assert result.message_id == "42" + # First call has thread_id, second call retries without + assert len(call_log) == 2 + assert call_log[0]["message_thread_id"] == 99999 + assert call_log[1]["message_thread_id"] is None + + +@pytest.mark.asyncio +async def test_send_raises_on_other_bad_request(): + """Non-thread BadRequest errors should NOT be retried — they fail immediately.""" + adapter = _make_adapter() + + async def mock_send_message(**kwargs): + raise FakeBadRequest("Chat not found") + + adapter._bot = SimpleNamespace(send_message=mock_send_message) + + result = await adapter.send( + chat_id="123", + content="test message", + metadata={"thread_id": "99999"}, + ) + + assert result.success is False + assert "Chat not found" in result.error + + +@pytest.mark.asyncio +async def test_send_without_thread_id_unaffected(): + """Normal sends without thread_id should work as before.""" + adapter = _make_adapter() + + call_log = [] + + async def mock_send_message(**kwargs): + call_log.append(dict(kwargs)) + return SimpleNamespace(message_id=100) + + adapter._bot = SimpleNamespace(send_message=mock_send_message) + + result = await adapter.send( + chat_id="123", + content="test message", + ) + + assert result.success is True + assert len(call_log) == 1 + assert call_log[0]["message_thread_id"] is None + + +@pytest.mark.asyncio +async def test_send_retries_network_errors_normally(): + """Real transient network errors (not BadRequest) should still be retried.""" + adapter = _make_adapter() + + attempt = [0] + + async def mock_send_message(**kwargs): + attempt[0] += 1 + if attempt[0] < 3: + raise FakeNetworkError("Connection reset") + return SimpleNamespace(message_id=200) + + adapter._bot = SimpleNamespace(send_message=mock_send_message) + + result = await adapter.send( + chat_id="123", + content="test message", + ) + + assert result.success is True + assert attempt[0] == 3 # Two retries then success + + +@pytest.mark.asyncio +async def test_thread_fallback_only_fires_once(): + """After clearing thread_id, subsequent chunks should also use None.""" + adapter = _make_adapter() + + call_log = [] + + async def mock_send_message(**kwargs): + call_log.append(dict(kwargs)) + tid = kwargs.get("message_thread_id") + if tid is not None: + raise FakeBadRequest("Message thread not found") + return SimpleNamespace(message_id=42) + + adapter._bot = SimpleNamespace(send_message=mock_send_message) + + # Send a long message that gets split into chunks + long_msg = "A" * 5000 # Exceeds Telegram's 4096 limit + result = await adapter.send( + chat_id="123", + content=long_msg, + metadata={"thread_id": "99999"}, + ) + + assert result.success is True + # First chunk: attempt with thread → fail → retry without → succeed + # Second chunk: should use thread_id=None directly (effective_thread_id + # was cleared per-chunk but the metadata doesn't change between chunks) + # The key point: the message was delivered despite the invalid thread diff --git a/tests/gateway/test_unauthorized_dm_behavior.py b/tests/gateway/test_unauthorized_dm_behavior.py index 0dbe457a8..02aae301c 100644 --- a/tests/gateway/test_unauthorized_dm_behavior.py +++ b/tests/gateway/test_unauthorized_dm_behavior.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +import gateway.run as gateway_run from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.base import MessageEvent from gateway.session import SessionSource @@ -19,7 +20,7 @@ def _clear_auth_env(monkeypatch) -> None: "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS", "MATRIX_ALLOWED_USERS", - "DINGTALK_ALLOWED_USERS", + "DINGTALK_ALLOWED_USERS", "FEISHU_ALLOWED_USERS", "WECOM_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS", "TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS", @@ -30,7 +31,7 @@ def _clear_auth_env(monkeypatch) -> None: "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", "MATRIX_ALLOW_ALL_USERS", - "DINGTALK_ALLOW_ALL_USERS", + "DINGTALK_ALLOW_ALL_USERS", "FEISHU_ALLOW_ALL_USERS", "WECOM_ALLOW_ALL_USERS", "GATEWAY_ALLOW_ALL_USERS", ): monkeypatch.delenv(key, raising=False) @@ -62,6 +63,32 @@ def _make_runner(platform: Platform, config: GatewayConfig): return runner, adapter +def test_whatsapp_lid_user_matches_phone_allowlist_via_session_mapping(monkeypatch, tmp_path): + _clear_auth_env(monkeypatch) + monkeypatch.setenv("WHATSAPP_ALLOWED_USERS", "15550000001") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + + session_dir = tmp_path / "whatsapp" / "session" + session_dir.mkdir(parents=True) + (session_dir / "lid-mapping-15550000001.json").write_text('"900000000000001"', encoding="utf-8") + (session_dir / "lid-mapping-900000000000001_reverse.json").write_text('"15550000001"', encoding="utf-8") + + runner, _adapter = _make_runner( + Platform.WHATSAPP, + GatewayConfig(platforms={Platform.WHATSAPP: PlatformConfig(enabled=True)}), + ) + + source = SessionSource( + platform=Platform.WHATSAPP, + user_id="900000000000001@lid", + chat_id="900000000000001@lid", + user_name="tester", + chat_type="dm", + ) + + assert runner._is_user_authorized(source) is True + + @pytest.mark.asyncio async def test_unauthorized_dm_pairs_by_default(monkeypatch): _clear_auth_env(monkeypatch) diff --git a/tests/gateway/test_webhook_dynamic_routes.py b/tests/gateway/test_webhook_dynamic_routes.py new file mode 100644 index 000000000..2029dd139 --- /dev/null +++ b/tests/gateway/test_webhook_dynamic_routes.py @@ -0,0 +1,87 @@ +"""Tests for webhook adapter dynamic route loading.""" + +import json +import os +import pytest +from pathlib import Path + +from gateway.config import PlatformConfig +from gateway.platforms.webhook import WebhookAdapter, _DYNAMIC_ROUTES_FILENAME + + +def _make_adapter(routes=None, extra=None): + _extra = extra or {} + if routes: + _extra["routes"] = routes + _extra.setdefault("secret", "test-global-secret") + config = PlatformConfig(enabled=True, extra=_extra) + return WebhookAdapter(config) + + +@pytest.fixture(autouse=True) +def _isolate(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + +class TestDynamicRouteLoading: + def test_no_dynamic_file(self): + adapter = _make_adapter(routes={"static": {"secret": "s"}}) + adapter._reload_dynamic_routes() + assert "static" in adapter._routes + assert len(adapter._dynamic_routes) == 0 + + def test_loads_dynamic_routes(self, tmp_path): + subs = {"my-hook": {"secret": "dynamic-secret", "prompt": "test", "events": []}} + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(json.dumps(subs)) + + adapter = _make_adapter(routes={"static": {"secret": "s"}}) + adapter._reload_dynamic_routes() + assert "my-hook" in adapter._routes + assert "static" in adapter._routes + + def test_static_takes_precedence(self, tmp_path): + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"conflict": {"secret": "dynamic", "prompt": "dyn"}}) + ) + adapter = _make_adapter(routes={"conflict": {"secret": "static", "prompt": "stat"}}) + adapter._reload_dynamic_routes() + assert adapter._routes["conflict"]["secret"] == "static" + + def test_mtime_gated(self, tmp_path): + import time + path = tmp_path / _DYNAMIC_ROUTES_FILENAME + path.write_text(json.dumps({"v1": {"secret": "s"}})) + + adapter = _make_adapter() + adapter._reload_dynamic_routes() + assert "v1" in adapter._dynamic_routes + + # Same mtime — no reload + adapter._dynamic_routes["injected"] = True + adapter._reload_dynamic_routes() + assert "injected" in adapter._dynamic_routes + + # New write — reloads + time.sleep(0.05) + path.write_text(json.dumps({"v2": {"secret": "s"}})) + adapter._reload_dynamic_routes() + assert "v2" in adapter._dynamic_routes + assert "v1" not in adapter._dynamic_routes + + def test_file_removal_clears(self, tmp_path): + path = tmp_path / _DYNAMIC_ROUTES_FILENAME + path.write_text(json.dumps({"temp": {"secret": "s"}})) + adapter = _make_adapter() + adapter._reload_dynamic_routes() + assert "temp" in adapter._dynamic_routes + + path.unlink() + adapter._reload_dynamic_routes() + assert len(adapter._dynamic_routes) == 0 + + def test_corrupted_file(self, tmp_path): + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text("not json") + adapter = _make_adapter(routes={"static": {"secret": "s"}}) + adapter._reload_dynamic_routes() + assert "static" in adapter._routes + assert len(adapter._dynamic_routes) == 0 diff --git a/tests/gateway/test_wecom.py b/tests/gateway/test_wecom.py new file mode 100644 index 000000000..a7101c697 --- /dev/null +++ b/tests/gateway/test_wecom.py @@ -0,0 +1,596 @@ +"""Tests for the WeCom platform adapter.""" + +import base64 +import os +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import SendResult + + +class TestWeComRequirements: + def test_returns_false_without_aiohttp(self, monkeypatch): + monkeypatch.setattr("gateway.platforms.wecom.AIOHTTP_AVAILABLE", False) + monkeypatch.setattr("gateway.platforms.wecom.HTTPX_AVAILABLE", True) + from gateway.platforms.wecom import check_wecom_requirements + + assert check_wecom_requirements() is False + + def test_returns_false_without_httpx(self, monkeypatch): + monkeypatch.setattr("gateway.platforms.wecom.AIOHTTP_AVAILABLE", True) + monkeypatch.setattr("gateway.platforms.wecom.HTTPX_AVAILABLE", False) + from gateway.platforms.wecom import check_wecom_requirements + + assert check_wecom_requirements() is False + + def test_returns_true_when_available(self, monkeypatch): + monkeypatch.setattr("gateway.platforms.wecom.AIOHTTP_AVAILABLE", True) + monkeypatch.setattr("gateway.platforms.wecom.HTTPX_AVAILABLE", True) + from gateway.platforms.wecom import check_wecom_requirements + + assert check_wecom_requirements() is True + + +class TestWeComAdapterInit: + def test_reads_config_from_extra(self): + from gateway.platforms.wecom import WeComAdapter + + config = PlatformConfig( + enabled=True, + extra={ + "bot_id": "cfg-bot", + "secret": "cfg-secret", + "websocket_url": "wss://custom.wecom.example/ws", + "group_policy": "allowlist", + "group_allow_from": ["group-1"], + }, + ) + adapter = WeComAdapter(config) + + assert adapter._bot_id == "cfg-bot" + assert adapter._secret == "cfg-secret" + assert adapter._ws_url == "wss://custom.wecom.example/ws" + assert adapter._group_policy == "allowlist" + assert adapter._group_allow_from == ["group-1"] + + def test_falls_back_to_env_vars(self, monkeypatch): + monkeypatch.setenv("WECOM_BOT_ID", "env-bot") + monkeypatch.setenv("WECOM_SECRET", "env-secret") + monkeypatch.setenv("WECOM_WEBSOCKET_URL", "wss://env.example/ws") + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + assert adapter._bot_id == "env-bot" + assert adapter._secret == "env-secret" + assert adapter._ws_url == "wss://env.example/ws" + + +class TestWeComConnect: + @pytest.mark.asyncio + async def test_connect_records_missing_credentials(self, monkeypatch): + import gateway.platforms.wecom as wecom_module + from gateway.platforms.wecom import WeComAdapter + + monkeypatch.setattr(wecom_module, "AIOHTTP_AVAILABLE", True) + monkeypatch.setattr(wecom_module, "HTTPX_AVAILABLE", True) + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + + success = await adapter.connect() + + assert success is False + assert adapter.has_fatal_error is True + assert adapter.fatal_error_code == "wecom_missing_credentials" + assert "WECOM_BOT_ID" in (adapter.fatal_error_message or "") + + @pytest.mark.asyncio + async def test_connect_records_handshake_failure_details(self, monkeypatch): + import gateway.platforms.wecom as wecom_module + from gateway.platforms.wecom import WeComAdapter + + class DummyClient: + async def aclose(self): + return None + + monkeypatch.setattr(wecom_module, "AIOHTTP_AVAILABLE", True) + monkeypatch.setattr(wecom_module, "HTTPX_AVAILABLE", True) + monkeypatch.setattr( + wecom_module, + "httpx", + SimpleNamespace(AsyncClient=lambda **kwargs: DummyClient()), + ) + + adapter = WeComAdapter( + PlatformConfig(enabled=True, extra={"bot_id": "bot-1", "secret": "secret-1"}) + ) + adapter._open_connection = AsyncMock(side_effect=RuntimeError("invalid secret (errcode=40013)")) + + success = await adapter.connect() + + assert success is False + assert adapter.has_fatal_error is True + assert adapter.fatal_error_code == "wecom_connect_error" + assert "invalid secret" in (adapter.fatal_error_message or "") + + +class TestWeComReplyMode: + @pytest.mark.asyncio + async def test_send_uses_passive_reply_stream_when_reply_context_exists(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._reply_req_ids["msg-1"] = "req-1" + adapter._send_reply_request = AsyncMock( + return_value={"headers": {"req_id": "req-1"}, "errcode": 0} + ) + + result = await adapter.send("chat-123", "hello from reply", reply_to="msg-1") + + assert result.success is True + adapter._send_reply_request.assert_awaited_once() + args = adapter._send_reply_request.await_args.args + assert args[0] == "req-1" + assert args[1]["msgtype"] == "stream" + assert args[1]["stream"]["finish"] is True + assert args[1]["stream"]["content"] == "hello from reply" + + @pytest.mark.asyncio + async def test_send_image_file_uses_passive_reply_media_when_reply_context_exists(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._reply_req_ids["msg-1"] = "req-1" + adapter._prepare_outbound_media = AsyncMock( + return_value={ + "data": b"image-bytes", + "content_type": "image/png", + "file_name": "demo.png", + "detected_type": "image", + "final_type": "image", + "rejected": False, + "reject_reason": None, + "downgraded": False, + "downgrade_note": None, + } + ) + adapter._upload_media_bytes = AsyncMock(return_value={"media_id": "media-1", "type": "image"}) + adapter._send_reply_request = AsyncMock( + return_value={"headers": {"req_id": "req-1"}, "errcode": 0} + ) + + result = await adapter.send_image_file("chat-123", "/tmp/demo.png", reply_to="msg-1") + + assert result.success is True + adapter._send_reply_request.assert_awaited_once() + args = adapter._send_reply_request.await_args.args + assert args[0] == "req-1" + assert args[1] == {"msgtype": "image", "image": {"media_id": "media-1"}} + + +class TestExtractText: + def test_extracts_plain_text(self): + from gateway.platforms.wecom import WeComAdapter + + body = { + "msgtype": "text", + "text": {"content": " hello world "}, + } + text, reply_text = WeComAdapter._extract_text(body) + assert text == "hello world" + assert reply_text is None + + def test_extracts_mixed_text(self): + from gateway.platforms.wecom import WeComAdapter + + body = { + "msgtype": "mixed", + "mixed": { + "msg_item": [ + {"msgtype": "text", "text": {"content": "part1"}}, + {"msgtype": "image", "image": {"url": "https://example.com/x.png"}}, + {"msgtype": "text", "text": {"content": "part2"}}, + ] + }, + } + text, _reply_text = WeComAdapter._extract_text(body) + assert text == "part1\npart2" + + def test_extracts_voice_and_quote(self): + from gateway.platforms.wecom import WeComAdapter + + body = { + "msgtype": "voice", + "voice": {"content": "spoken text"}, + "quote": {"msgtype": "text", "text": {"content": "quoted"}}, + } + text, reply_text = WeComAdapter._extract_text(body) + assert text == "spoken text" + assert reply_text == "quoted" + + +class TestCallbackDispatch: + @pytest.mark.asyncio + @pytest.mark.parametrize("cmd", ["aibot_msg_callback", "aibot_callback"]) + async def test_dispatch_accepts_new_and_legacy_callback_cmds(self, cmd): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._on_message = AsyncMock() + + await adapter._dispatch_payload({"cmd": cmd, "headers": {"req_id": "req-1"}, "body": {}}) + + adapter._on_message.assert_awaited_once() + + +class TestPolicyHelpers: + def test_dm_allowlist(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter( + PlatformConfig(enabled=True, extra={"dm_policy": "allowlist", "allow_from": ["user-1"]}) + ) + assert adapter._is_dm_allowed("user-1") is True + assert adapter._is_dm_allowed("user-2") is False + + def test_group_allowlist_and_per_group_sender_allowlist(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter( + PlatformConfig( + enabled=True, + extra={ + "group_policy": "allowlist", + "group_allow_from": ["group-1"], + "groups": {"group-1": {"allow_from": ["user-1"]}}, + }, + ) + ) + + assert adapter._is_group_allowed("group-1", "user-1") is True + assert adapter._is_group_allowed("group-1", "user-2") is False + assert adapter._is_group_allowed("group-2", "user-1") is False + + +class TestMediaHelpers: + def test_detect_wecom_media_type(self): + from gateway.platforms.wecom import WeComAdapter + + assert WeComAdapter._detect_wecom_media_type("image/png") == "image" + assert WeComAdapter._detect_wecom_media_type("video/mp4") == "video" + assert WeComAdapter._detect_wecom_media_type("audio/amr") == "voice" + assert WeComAdapter._detect_wecom_media_type("application/pdf") == "file" + + def test_voice_non_amr_downgrades_to_file(self): + from gateway.platforms.wecom import WeComAdapter + + result = WeComAdapter._apply_file_size_limits(128, "voice", "audio/mpeg") + + assert result["final_type"] == "file" + assert result["downgraded"] is True + assert "AMR" in (result["downgrade_note"] or "") + + def test_oversized_file_is_rejected(self): + from gateway.platforms.wecom import ABSOLUTE_MAX_BYTES, WeComAdapter + + result = WeComAdapter._apply_file_size_limits(ABSOLUTE_MAX_BYTES + 1, "file", "application/pdf") + + assert result["rejected"] is True + assert "20MB" in (result["reject_reason"] or "") + + def test_decrypt_file_bytes_round_trip(self): + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + from gateway.platforms.wecom import WeComAdapter + + plaintext = b"wecom-secret" + key = os.urandom(32) + pad_len = 32 - (len(plaintext) % 32) + padded = plaintext + bytes([pad_len]) * pad_len + encryptor = Cipher(algorithms.AES(key), modes.CBC(key[:16])).encryptor() + encrypted = encryptor.update(padded) + encryptor.finalize() + + decrypted = WeComAdapter._decrypt_file_bytes(encrypted, base64.b64encode(key).decode("ascii")) + + assert decrypted == plaintext + + @pytest.mark.asyncio + async def test_load_outbound_media_rejects_placeholder_path(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + + with pytest.raises(ValueError, match="placeholder was not replaced"): + await adapter._load_outbound_media("") + + +class TestMediaUpload: + @pytest.mark.asyncio + async def test_upload_media_bytes_uses_sdk_sequence(self, monkeypatch): + import gateway.platforms.wecom as wecom_module + from gateway.platforms.wecom import ( + APP_CMD_UPLOAD_MEDIA_CHUNK, + APP_CMD_UPLOAD_MEDIA_FINISH, + APP_CMD_UPLOAD_MEDIA_INIT, + WeComAdapter, + ) + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + calls = [] + + async def fake_send_request(cmd, body, timeout=0): + calls.append((cmd, body)) + if cmd == APP_CMD_UPLOAD_MEDIA_INIT: + return {"errcode": 0, "body": {"upload_id": "upload-1"}} + if cmd == APP_CMD_UPLOAD_MEDIA_CHUNK: + return {"errcode": 0} + if cmd == APP_CMD_UPLOAD_MEDIA_FINISH: + return { + "errcode": 0, + "body": { + "media_id": "media-1", + "type": "file", + "created_at": "2026-03-18T00:00:00Z", + }, + } + raise AssertionError(f"unexpected cmd {cmd}") + + monkeypatch.setattr(wecom_module, "UPLOAD_CHUNK_SIZE", 4) + adapter._send_request = fake_send_request + + result = await adapter._upload_media_bytes(b"abcdefghij", "file", "demo.bin") + + assert result["media_id"] == "media-1" + assert [cmd for cmd, _body in calls] == [ + APP_CMD_UPLOAD_MEDIA_INIT, + APP_CMD_UPLOAD_MEDIA_CHUNK, + APP_CMD_UPLOAD_MEDIA_CHUNK, + APP_CMD_UPLOAD_MEDIA_CHUNK, + APP_CMD_UPLOAD_MEDIA_FINISH, + ] + assert calls[1][1]["chunk_index"] == 0 + assert calls[2][1]["chunk_index"] == 1 + assert calls[3][1]["chunk_index"] == 2 + + @pytest.mark.asyncio + async def test_download_remote_bytes_rejects_large_content_length(self): + from gateway.platforms.wecom import WeComAdapter + + class FakeResponse: + headers = {"content-length": "10"} + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def raise_for_status(self): + return None + + async def aiter_bytes(self): + yield b"abc" + + class FakeClient: + def stream(self, method, url, headers=None): + return FakeResponse() + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._http_client = FakeClient() + + with pytest.raises(ValueError, match="exceeds WeCom limit"): + await adapter._download_remote_bytes("https://example.com/file.bin", max_bytes=4) + + @pytest.mark.asyncio + async def test_cache_media_decrypts_url_payload_before_writing(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + plaintext = b"secret document bytes" + key = os.urandom(32) + pad_len = 32 - (len(plaintext) % 32) + padded = plaintext + bytes([pad_len]) * pad_len + + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + encryptor = Cipher(algorithms.AES(key), modes.CBC(key[:16])).encryptor() + encrypted = encryptor.update(padded) + encryptor.finalize() + adapter._download_remote_bytes = AsyncMock( + return_value=( + encrypted, + { + "content-type": "application/octet-stream", + "content-disposition": 'attachment; filename="secret.bin"', + }, + ) + ) + + cached = await adapter._cache_media( + "file", + { + "url": "https://example.com/secret.bin", + "aeskey": base64.b64encode(key).decode("ascii"), + }, + ) + + assert cached is not None + cached_path, content_type = cached + assert Path(cached_path).read_bytes() == plaintext + assert content_type == "application/octet-stream" + + +class TestSend: + @pytest.mark.asyncio + async def test_send_uses_proactive_payload(self): + from gateway.platforms.wecom import APP_CMD_SEND, WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._send_request = AsyncMock(return_value={"headers": {"req_id": "req-1"}, "errcode": 0}) + + result = await adapter.send("chat-123", "Hello WeCom") + + assert result.success is True + adapter._send_request.assert_awaited_once_with( + APP_CMD_SEND, + { + "chatid": "chat-123", + "msgtype": "markdown", + "markdown": {"content": "Hello WeCom"}, + }, + ) + + @pytest.mark.asyncio + async def test_send_reports_wecom_errors(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._send_request = AsyncMock(return_value={"errcode": 40001, "errmsg": "bad request"}) + + result = await adapter.send("chat-123", "Hello WeCom") + + assert result.success is False + assert "40001" in (result.error or "") + + @pytest.mark.asyncio + async def test_send_image_falls_back_to_text_for_remote_url(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._send_media_source = AsyncMock(return_value=SendResult(success=False, error="upload failed")) + adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="msg-1")) + + result = await adapter.send_image("chat-123", "https://example.com/demo.png", caption="demo") + + assert result.success is True + adapter.send.assert_awaited_once_with(chat_id="chat-123", content="demo\nhttps://example.com/demo.png", reply_to=None) + + @pytest.mark.asyncio + async def test_send_voice_sends_caption_and_downgrade_note(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._prepare_outbound_media = AsyncMock( + return_value={ + "data": b"voice-bytes", + "content_type": "audio/mpeg", + "file_name": "voice.mp3", + "detected_type": "voice", + "final_type": "file", + "rejected": False, + "reject_reason": None, + "downgraded": True, + "downgrade_note": "语音格式 audio/mpeg 不支持,企微仅支持 AMR 格式,已转为文件格式发送", + } + ) + adapter._upload_media_bytes = AsyncMock(return_value={"media_id": "media-1", "type": "file"}) + adapter._send_media_message = AsyncMock(return_value={"headers": {"req_id": "req-media"}, "errcode": 0}) + adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="msg-1")) + + result = await adapter.send_voice("chat-123", "/tmp/voice.mp3", caption="listen") + + assert result.success is True + adapter._send_media_message.assert_awaited_once_with("chat-123", "file", "media-1") + assert adapter.send.await_count == 2 + adapter.send.assert_any_await(chat_id="chat-123", content="listen", reply_to=None) + adapter.send.assert_any_await( + chat_id="chat-123", + content="ℹ️ 语音格式 audio/mpeg 不支持,企微仅支持 AMR 格式,已转为文件格式发送", + reply_to=None, + ) + + +class TestInboundMessages: + @pytest.mark.asyncio + async def test_on_message_builds_event(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter.handle_message = AsyncMock() + adapter._extract_media = AsyncMock(return_value=(["/tmp/test.png"], ["image/png"])) + + payload = { + "cmd": "aibot_msg_callback", + "headers": {"req_id": "req-1"}, + "body": { + "msgid": "msg-1", + "chatid": "group-1", + "chattype": "group", + "from": {"userid": "user-1"}, + "msgtype": "text", + "text": {"content": "hello"}, + }, + } + + await adapter._on_message(payload) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "hello" + assert event.source.chat_id == "group-1" + assert event.source.user_id == "user-1" + assert event.media_urls == ["/tmp/test.png"] + assert event.media_types == ["image/png"] + + @pytest.mark.asyncio + async def test_on_message_preserves_quote_context(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter.handle_message = AsyncMock() + adapter._extract_media = AsyncMock(return_value=([], [])) + + payload = { + "cmd": "aibot_msg_callback", + "headers": {"req_id": "req-1"}, + "body": { + "msgid": "msg-1", + "chatid": "group-1", + "chattype": "group", + "from": {"userid": "user-1"}, + "msgtype": "text", + "text": {"content": "follow up"}, + "quote": {"msgtype": "text", "text": {"content": "quoted message"}}, + }, + } + + await adapter._on_message(payload) + + event = adapter.handle_message.await_args.args[0] + assert event.reply_to_text == "quoted message" + assert event.reply_to_message_id == "quote:msg-1" + + @pytest.mark.asyncio + async def test_on_message_respects_group_policy(self): + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter( + PlatformConfig( + enabled=True, + extra={"group_policy": "allowlist", "group_allow_from": ["group-allowed"]}, + ) + ) + adapter.handle_message = AsyncMock() + adapter._extract_media = AsyncMock(return_value=([], [])) + + payload = { + "cmd": "aibot_callback", + "headers": {"req_id": "req-1"}, + "body": { + "msgid": "msg-1", + "chatid": "group-blocked", + "chattype": "group", + "from": {"userid": "user-1"}, + "msgtype": "text", + "text": {"content": "hello"}, + }, + } + + await adapter._on_message(payload) + adapter.handle_message.assert_not_awaited() + + +class TestPlatformEnum: + def test_wecom_in_platform_enum(self): + assert Platform.WECOM.value == "wecom" diff --git a/tests/gateway/test_whatsapp_connect.py b/tests/gateway/test_whatsapp_connect.py index 7a2126bb8..61ff8f361 100644 --- a/tests/gateway/test_whatsapp_connect.py +++ b/tests/gateway/test_whatsapp_connect.py @@ -63,6 +63,7 @@ def _make_adapter(): adapter._background_tasks = set() adapter._auto_tts_disabled_chats = set() adapter._message_queue = asyncio.Queue() + adapter._http_session = None return adapter @@ -219,6 +220,7 @@ class TestBridgeRuntimeFailure: fatal_handler = AsyncMock() adapter.set_fatal_error_handler(fatal_handler) adapter._running = True + adapter._http_session = MagicMock() # Persistent session active mock_fh = MagicMock() adapter._bridge_log_fh = mock_fh @@ -242,6 +244,7 @@ class TestBridgeRuntimeFailure: fatal_handler = AsyncMock() adapter.set_fatal_error_handler(fatal_handler) adapter._running = True + adapter._http_session = MagicMock() # Persistent session active mock_fh = MagicMock() adapter._bridge_log_fh = mock_fh @@ -417,3 +420,83 @@ class TestKillPortProcess: with patch("gateway.platforms.whatsapp._IS_WINDOWS", True), \ patch("gateway.platforms.whatsapp.subprocess.run", side_effect=OSError("no netstat")): _kill_port_process(3000) # must not raise + + +# --------------------------------------------------------------------------- +# Persistent HTTP session lifecycle +# --------------------------------------------------------------------------- + +class TestHttpSessionLifecycle: + """Verify persistent aiohttp.ClientSession is created and cleaned up.""" + + @pytest.mark.asyncio + async def test_session_closed_on_disconnect(self): + """disconnect() should close self._http_session.""" + adapter = _make_adapter() + mock_session = AsyncMock() + mock_session.closed = False + adapter._http_session = mock_session + adapter._poll_task = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_session.close.assert_called_once() + assert adapter._http_session is None + + @pytest.mark.asyncio + async def test_session_not_closed_when_already_closed(self): + """disconnect() should skip close() when session is already closed.""" + adapter = _make_adapter() + mock_session = AsyncMock() + mock_session.closed = True + adapter._http_session = mock_session + adapter._poll_task = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_session.close.assert_not_called() + assert adapter._http_session is None + + @pytest.mark.asyncio + async def test_poll_task_cancelled_on_disconnect(self): + """disconnect() should cancel the poll task.""" + adapter = _make_adapter() + mock_task = MagicMock() + mock_task.done.return_value = False + mock_task.cancel = MagicMock() + mock_future = asyncio.Future() + mock_future.set_exception(asyncio.CancelledError()) + mock_task.__await__ = mock_future.__await__ + adapter._poll_task = mock_task + adapter._http_session = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_task.cancel.assert_called_once() + assert adapter._poll_task is None + + @pytest.mark.asyncio + async def test_disconnect_skips_done_poll_task(self): + """disconnect() should not cancel an already-done poll task.""" + adapter = _make_adapter() + mock_task = MagicMock() + mock_task.done.return_value = True + adapter._poll_task = mock_task + adapter._http_session = None + adapter._bridge_process = None + adapter._running = True + adapter._session_lock_identity = None + + await adapter.disconnect() + + mock_task.cancel.assert_not_called() + assert adapter._poll_task is None diff --git a/tests/hermes_cli/test_cmd_update.py b/tests/hermes_cli/test_cmd_update.py index 0ccb7af81..9ffa809a5 100644 --- a/tests/hermes_cli/test_cmd_update.py +++ b/tests/hermes_cli/test_cmd_update.py @@ -105,3 +105,24 @@ class TestCmdUpdateBranchFallback: commands = [" ".join(str(a) for a in c.args[0]) for c in mock_run.call_args_list] pull_cmds = [c for c in commands if "pull" in c] assert len(pull_cmds) == 0 + + def test_update_non_interactive_skips_migration_prompt(self, mock_args, capsys): + """When stdin/stdout aren't TTYs, config migration prompt is skipped.""" + with patch("shutil.which", return_value=None), patch( + "subprocess.run" + ) as mock_run, patch("builtins.input") as mock_input, patch( + "hermes_cli.config.get_missing_env_vars", return_value=["MISSING_KEY"] + ), patch("hermes_cli.config.get_missing_config_fields", return_value=[]), patch( + "hermes_cli.config.check_config_version", return_value=(1, 2) + ), patch("hermes_cli.main.sys") as mock_sys: + mock_sys.stdin.isatty.return_value = False + mock_sys.stdout.isatty.return_value = False + mock_run.side_effect = _make_run_side_effect( + branch="main", verify_ok=True, commit_count="1" + ) + + cmd_update(mock_args) + + mock_input.assert_not_called() + captured = capsys.readouterr() + assert "Non-interactive session" in captured.out diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index 0189f036b..87daa845b 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -1,6 +1,7 @@ """Tests for gateway service management helpers.""" import os +from pathlib import Path from types import SimpleNamespace import hermes_cli.gateway as gateway_cli @@ -152,12 +153,13 @@ class TestLaunchdServiceRecovery: def test_launchd_start_reloads_unloaded_job_and_retries(self, tmp_path, monkeypatch): plist_path = tmp_path / "ai.hermes.gateway.plist" plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8") + label = gateway_cli.get_launchd_label() calls = [] def fake_run(cmd, check=False, **kwargs): calls.append(cmd) - if cmd == ["launchctl", "start", "ai.hermes.gateway"] and calls.count(cmd) == 1: + if cmd == ["launchctl", "start", label] and calls.count(cmd) == 1: raise gateway_cli.subprocess.CalledProcessError(3, cmd, stderr="Could not find service") return SimpleNamespace(returncode=0, stdout="", stderr="") @@ -167,9 +169,9 @@ class TestLaunchdServiceRecovery: gateway_cli.launchd_start() assert calls == [ - ["launchctl", "start", "ai.hermes.gateway"], + ["launchctl", "start", label], ["launchctl", "load", str(plist_path)], - ["launchctl", "start", "ai.hermes.gateway"], + ["launchctl", "start", label], ] def test_launchd_status_reports_local_stale_plist_when_unloaded(self, tmp_path, monkeypatch, capsys): @@ -354,6 +356,20 @@ class TestGeneratedUnitUsesDetectedVenv: assert "/venv/" not in unit or "/.venv/" in unit +class TestGeneratedUnitIncludesLocalBin: + """~/.local/bin must be in PATH so uvx/pipx tools are discoverable.""" + + def test_user_unit_includes_local_bin_in_path(self): + unit = gateway_cli.generate_systemd_unit(system=False) + home = str(Path.home()) + assert f"{home}/.local/bin" in unit + + def test_system_unit_includes_local_bin_in_path(self): + unit = gateway_cli.generate_systemd_unit(system=True) + # System unit uses the resolved home dir from _system_service_identity + assert "/.local/bin" in unit + + class TestEnsureUserSystemdEnv: """Tests for _ensure_user_systemd_env() D-Bus session bus auto-detection.""" diff --git a/tests/hermes_cli/test_nous_subscription.py b/tests/hermes_cli/test_nous_subscription.py new file mode 100644 index 000000000..2c41c0b22 --- /dev/null +++ b/tests/hermes_cli/test_nous_subscription.py @@ -0,0 +1,44 @@ +"""Tests for Nous subscription feature detection.""" + +from hermes_cli import nous_subscription as ns + + +def test_get_nous_subscription_features_recognizes_direct_exa_backend(monkeypatch): + env = {"EXA_API_KEY": "exa-test"} + + monkeypatch.setattr(ns, "get_env_value", lambda name: env.get(name, "")) + monkeypatch.setattr(ns, "get_nous_auth_status", lambda: {}) + monkeypatch.setattr(ns, "managed_nous_tools_enabled", lambda: False) + monkeypatch.setattr(ns, "_toolset_enabled", lambda config, key: key == "web") + monkeypatch.setattr(ns, "_has_agent_browser", lambda: False) + monkeypatch.setattr(ns, "resolve_openai_audio_api_key", lambda: "") + monkeypatch.setattr(ns, "has_direct_modal_credentials", lambda: False) + + features = ns.get_nous_subscription_features({"web": {"backend": "exa"}}) + + assert features.web.available is True + assert features.web.active is True + assert features.web.managed_by_nous is False + assert features.web.direct_override is True + assert features.web.current_provider == "exa" + + +def test_get_nous_subscription_features_prefers_managed_modal_in_auto_mode(monkeypatch): + monkeypatch.setenv("HERMES_ENABLE_NOUS_MANAGED_TOOLS", "1") + monkeypatch.setattr(ns, "get_env_value", lambda name: "") + monkeypatch.setattr(ns, "get_nous_auth_status", lambda: {"logged_in": True}) + monkeypatch.setattr(ns, "managed_nous_tools_enabled", lambda: True) + monkeypatch.setattr(ns, "_toolset_enabled", lambda config, key: key == "terminal") + monkeypatch.setattr(ns, "_has_agent_browser", lambda: False) + monkeypatch.setattr(ns, "resolve_openai_audio_api_key", lambda: "") + monkeypatch.setattr(ns, "has_direct_modal_credentials", lambda: True) + monkeypatch.setattr(ns, "is_managed_tool_gateway_ready", lambda vendor: vendor == "modal") + + features = ns.get_nous_subscription_features( + {"terminal": {"backend": "modal", "modal_mode": "auto"}} + ) + + assert features.modal.available is True + assert features.modal.active is True + assert features.modal.managed_by_nous is True + assert features.modal.direct_override is False diff --git a/tests/hermes_cli/test_profiles.py b/tests/hermes_cli/test_profiles.py new file mode 100644 index 000000000..80152a4a0 --- /dev/null +++ b/tests/hermes_cli/test_profiles.py @@ -0,0 +1,622 @@ +"""Comprehensive tests for hermes_cli.profiles module. + +Tests cover: validation, directory resolution, CRUD operations, active profile +management, export/import, renaming, alias collision checks, profile isolation, +and shell completion generation. +""" + +import json +import os +import tarfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +from hermes_cli.profiles import ( + validate_profile_name, + get_profile_dir, + create_profile, + delete_profile, + list_profiles, + set_active_profile, + get_active_profile, + get_active_profile_name, + resolve_profile_env, + check_alias_collision, + rename_profile, + export_profile, + import_profile, + generate_bash_completion, + generate_zsh_completion, + _get_profiles_root, + _get_default_hermes_home, +) + + +# --------------------------------------------------------------------------- +# Shared fixture: redirect Path.home() and HERMES_HOME for profile tests +# --------------------------------------------------------------------------- + +@pytest.fixture() +def profile_env(tmp_path, monkeypatch): + """Set up an isolated environment for profile tests. + + * Path.home() -> tmp_path (so _get_profiles_root() = tmp_path/.hermes/profiles) + * HERMES_HOME -> tmp_path/.hermes (so get_hermes_home() agrees) + * Creates the bare-minimum ~/.hermes directory. + """ + monkeypatch.setattr(Path, "home", lambda: tmp_path) + default_home = tmp_path / ".hermes" + default_home.mkdir(exist_ok=True) + monkeypatch.setenv("HERMES_HOME", str(default_home)) + return tmp_path + + +# =================================================================== +# TestValidateProfileName +# =================================================================== + +class TestValidateProfileName: + """Tests for validate_profile_name().""" + + @pytest.mark.parametrize("name", ["coder", "work-bot", "a1", "my_agent"]) + def test_valid_names_accepted(self, name): + # Should not raise + validate_profile_name(name) + + @pytest.mark.parametrize("name", ["UPPER", "has space", ".hidden", "-leading"]) + def test_invalid_names_rejected(self, name): + with pytest.raises(ValueError): + validate_profile_name(name) + + def test_too_long_rejected(self): + long_name = "a" * 65 + with pytest.raises(ValueError): + validate_profile_name(long_name) + + def test_max_length_accepted(self): + # 64 chars total: 1 leading + 63 remaining = 64, within [0,63] range + name = "a" * 64 + validate_profile_name(name) + + def test_default_accepted(self): + # 'default' is a special-case pass-through + validate_profile_name("default") + + def test_empty_string_rejected(self): + with pytest.raises(ValueError): + validate_profile_name("") + + +# =================================================================== +# TestGetProfileDir +# =================================================================== + +class TestGetProfileDir: + """Tests for get_profile_dir().""" + + def test_default_returns_hermes_home(self, profile_env): + tmp_path = profile_env + result = get_profile_dir("default") + assert result == tmp_path / ".hermes" + + def test_named_profile_returns_profiles_subdir(self, profile_env): + tmp_path = profile_env + result = get_profile_dir("coder") + assert result == tmp_path / ".hermes" / "profiles" / "coder" + + +# =================================================================== +# TestCreateProfile +# =================================================================== + +class TestCreateProfile: + """Tests for create_profile().""" + + def test_creates_directory_with_subdirs(self, profile_env): + profile_dir = create_profile("coder", no_alias=True) + assert profile_dir.is_dir() + for subdir in ["memories", "sessions", "skills", "skins", "logs", + "plans", "workspace", "cron"]: + assert (profile_dir / subdir).is_dir(), f"Missing subdir: {subdir}" + + def test_duplicate_raises_file_exists(self, profile_env): + create_profile("coder", no_alias=True) + with pytest.raises(FileExistsError): + create_profile("coder", no_alias=True) + + def test_default_raises_value_error(self, profile_env): + with pytest.raises(ValueError, match="default"): + create_profile("default", no_alias=True) + + def test_invalid_name_raises_value_error(self, profile_env): + with pytest.raises(ValueError): + create_profile("INVALID!", no_alias=True) + + def test_clone_config_copies_files(self, profile_env): + tmp_path = profile_env + default_home = tmp_path / ".hermes" + # Create source config files in default profile + (default_home / "config.yaml").write_text("model: test") + (default_home / ".env").write_text("KEY=val") + (default_home / "SOUL.md").write_text("Be helpful.") + + profile_dir = create_profile("coder", clone_config=True, no_alias=True) + + assert (profile_dir / "config.yaml").read_text() == "model: test" + assert (profile_dir / ".env").read_text() == "KEY=val" + assert (profile_dir / "SOUL.md").read_text() == "Be helpful." + + def test_clone_all_copies_entire_tree(self, profile_env): + tmp_path = profile_env + default_home = tmp_path / ".hermes" + # Populate default with some content + (default_home / "memories").mkdir(exist_ok=True) + (default_home / "memories" / "note.md").write_text("remember this") + (default_home / "config.yaml").write_text("model: gpt-4") + # Runtime files that should be stripped + (default_home / "gateway.pid").write_text("12345") + (default_home / "gateway_state.json").write_text("{}") + (default_home / "processes.json").write_text("[]") + + profile_dir = create_profile("coder", clone_all=True, no_alias=True) + + # Content should be copied + assert (profile_dir / "memories" / "note.md").read_text() == "remember this" + assert (profile_dir / "config.yaml").read_text() == "model: gpt-4" + # Runtime files should be stripped + assert not (profile_dir / "gateway.pid").exists() + assert not (profile_dir / "gateway_state.json").exists() + assert not (profile_dir / "processes.json").exists() + + def test_clone_config_missing_files_skipped(self, profile_env): + """Clone config gracefully skips files that don't exist in source.""" + profile_dir = create_profile("coder", clone_config=True, no_alias=True) + # No error; optional files just not copied + assert not (profile_dir / "config.yaml").exists() + assert not (profile_dir / ".env").exists() + assert not (profile_dir / "SOUL.md").exists() + + +# =================================================================== +# TestDeleteProfile +# =================================================================== + +class TestDeleteProfile: + """Tests for delete_profile().""" + + def test_removes_directory(self, profile_env): + profile_dir = create_profile("coder", no_alias=True) + assert profile_dir.is_dir() + # Mock gateway import to avoid real systemd/launchd interaction + with patch("hermes_cli.profiles._cleanup_gateway_service"): + delete_profile("coder", yes=True) + assert not profile_dir.is_dir() + + def test_default_raises_value_error(self, profile_env): + with pytest.raises(ValueError, match="default"): + delete_profile("default", yes=True) + + def test_nonexistent_raises_file_not_found(self, profile_env): + with pytest.raises(FileNotFoundError): + delete_profile("nonexistent", yes=True) + + +# =================================================================== +# TestListProfiles +# =================================================================== + +class TestListProfiles: + """Tests for list_profiles().""" + + def test_returns_default_when_no_named_profiles(self, profile_env): + profiles = list_profiles() + names = [p.name for p in profiles] + assert "default" in names + + def test_includes_named_profiles(self, profile_env): + create_profile("alpha", no_alias=True) + create_profile("beta", no_alias=True) + profiles = list_profiles() + names = [p.name for p in profiles] + assert "alpha" in names + assert "beta" in names + + def test_sorted_alphabetically(self, profile_env): + create_profile("zebra", no_alias=True) + create_profile("alpha", no_alias=True) + create_profile("middle", no_alias=True) + profiles = list_profiles() + named = [p.name for p in profiles if not p.is_default] + assert named == sorted(named) + + def test_default_is_first(self, profile_env): + create_profile("alpha", no_alias=True) + profiles = list_profiles() + assert profiles[0].name == "default" + assert profiles[0].is_default is True + + +# =================================================================== +# TestActiveProfile +# =================================================================== + +class TestActiveProfile: + """Tests for set_active_profile() / get_active_profile().""" + + def test_set_and_get_roundtrip(self, profile_env): + create_profile("coder", no_alias=True) + set_active_profile("coder") + assert get_active_profile() == "coder" + + def test_no_file_returns_default(self, profile_env): + assert get_active_profile() == "default" + + def test_empty_file_returns_default(self, profile_env): + tmp_path = profile_env + active_path = tmp_path / ".hermes" / "active_profile" + active_path.write_text("") + assert get_active_profile() == "default" + + def test_set_to_default_removes_file(self, profile_env): + tmp_path = profile_env + create_profile("coder", no_alias=True) + set_active_profile("coder") + active_path = tmp_path / ".hermes" / "active_profile" + assert active_path.exists() + + set_active_profile("default") + assert not active_path.exists() + + def test_set_nonexistent_raises(self, profile_env): + with pytest.raises(FileNotFoundError): + set_active_profile("nonexistent") + + +# =================================================================== +# TestGetActiveProfileName +# =================================================================== + +class TestGetActiveProfileName: + """Tests for get_active_profile_name().""" + + def test_default_hermes_home_returns_default(self, profile_env): + # HERMES_HOME points to tmp_path/.hermes which is the default + assert get_active_profile_name() == "default" + + def test_profile_path_returns_profile_name(self, profile_env, monkeypatch): + tmp_path = profile_env + create_profile("coder", no_alias=True) + profile_dir = tmp_path / ".hermes" / "profiles" / "coder" + monkeypatch.setenv("HERMES_HOME", str(profile_dir)) + assert get_active_profile_name() == "coder" + + def test_custom_path_returns_custom(self, profile_env, monkeypatch): + tmp_path = profile_env + custom = tmp_path / "some" / "other" / "path" + custom.mkdir(parents=True) + monkeypatch.setenv("HERMES_HOME", str(custom)) + assert get_active_profile_name() == "custom" + + +# =================================================================== +# TestResolveProfileEnv +# =================================================================== + +class TestResolveProfileEnv: + """Tests for resolve_profile_env().""" + + def test_existing_profile_returns_path(self, profile_env): + tmp_path = profile_env + create_profile("coder", no_alias=True) + result = resolve_profile_env("coder") + assert result == str(tmp_path / ".hermes" / "profiles" / "coder") + + def test_default_returns_default_home(self, profile_env): + tmp_path = profile_env + result = resolve_profile_env("default") + assert result == str(tmp_path / ".hermes") + + def test_nonexistent_raises_file_not_found(self, profile_env): + with pytest.raises(FileNotFoundError): + resolve_profile_env("nonexistent") + + def test_invalid_name_raises_value_error(self, profile_env): + with pytest.raises(ValueError): + resolve_profile_env("INVALID!") + + +# =================================================================== +# TestAliasCollision +# =================================================================== + +class TestAliasCollision: + """Tests for check_alias_collision().""" + + def test_normal_name_returns_none(self, profile_env): + # Mock 'which' to return not-found + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=1, stdout="") + result = check_alias_collision("mybot") + assert result is None + + def test_reserved_name_returns_message(self, profile_env): + result = check_alias_collision("hermes") + assert result is not None + assert "reserved" in result.lower() + + def test_subcommand_returns_message(self, profile_env): + result = check_alias_collision("chat") + assert result is not None + assert "subcommand" in result.lower() + + def test_default_is_reserved(self, profile_env): + result = check_alias_collision("default") + assert result is not None + assert "reserved" in result.lower() + + +# =================================================================== +# TestRenameProfile +# =================================================================== + +class TestRenameProfile: + """Tests for rename_profile().""" + + def test_renames_directory(self, profile_env): + tmp_path = profile_env + create_profile("oldname", no_alias=True) + old_dir = tmp_path / ".hermes" / "profiles" / "oldname" + assert old_dir.is_dir() + + # Mock alias collision to avoid subprocess calls + with patch("hermes_cli.profiles.check_alias_collision", return_value="skip"): + new_dir = rename_profile("oldname", "newname") + + assert not old_dir.is_dir() + assert new_dir.is_dir() + assert new_dir == tmp_path / ".hermes" / "profiles" / "newname" + + def test_default_raises_value_error(self, profile_env): + with pytest.raises(ValueError, match="default"): + rename_profile("default", "newname") + + def test_rename_to_default_raises_value_error(self, profile_env): + create_profile("coder", no_alias=True) + with pytest.raises(ValueError, match="default"): + rename_profile("coder", "default") + + def test_nonexistent_raises_file_not_found(self, profile_env): + with pytest.raises(FileNotFoundError): + rename_profile("nonexistent", "newname") + + def test_target_exists_raises_file_exists(self, profile_env): + create_profile("alpha", no_alias=True) + create_profile("beta", no_alias=True) + with pytest.raises(FileExistsError): + rename_profile("alpha", "beta") + + +# =================================================================== +# TestExportImport +# =================================================================== + +class TestExportImport: + """Tests for export_profile() / import_profile().""" + + def test_export_creates_tar_gz(self, profile_env, tmp_path): + create_profile("coder", no_alias=True) + # Put a marker file so we can verify content + profile_dir = get_profile_dir("coder") + (profile_dir / "marker.txt").write_text("hello") + + output = tmp_path / "export" / "coder.tar.gz" + output.parent.mkdir(parents=True, exist_ok=True) + result = export_profile("coder", str(output)) + + assert Path(result).exists() + assert tarfile.is_tarfile(str(result)) + + def test_import_restores_from_archive(self, profile_env, tmp_path): + # Create and export a profile + create_profile("coder", no_alias=True) + profile_dir = get_profile_dir("coder") + (profile_dir / "marker.txt").write_text("hello") + + archive_path = tmp_path / "export" / "coder.tar.gz" + archive_path.parent.mkdir(parents=True, exist_ok=True) + export_profile("coder", str(archive_path)) + + # Delete the profile, then import it back under a new name + import shutil + shutil.rmtree(profile_dir) + assert not profile_dir.is_dir() + + imported = import_profile(str(archive_path), name="coder") + assert imported.is_dir() + assert (imported / "marker.txt").read_text() == "hello" + + def test_import_to_existing_name_raises(self, profile_env, tmp_path): + create_profile("coder", no_alias=True) + profile_dir = get_profile_dir("coder") + + archive_path = tmp_path / "export" / "coder.tar.gz" + archive_path.parent.mkdir(parents=True, exist_ok=True) + export_profile("coder", str(archive_path)) + + # Importing to same existing name should fail + with pytest.raises(FileExistsError): + import_profile(str(archive_path), name="coder") + + def test_export_nonexistent_raises(self, profile_env, tmp_path): + with pytest.raises(FileNotFoundError): + export_profile("nonexistent", str(tmp_path / "out.tar.gz")) + + +# =================================================================== +# TestProfileIsolation +# =================================================================== + +class TestProfileIsolation: + """Verify that two profiles have completely separate paths.""" + + def test_separate_config_paths(self, profile_env): + create_profile("alpha", no_alias=True) + create_profile("beta", no_alias=True) + alpha_dir = get_profile_dir("alpha") + beta_dir = get_profile_dir("beta") + assert alpha_dir / "config.yaml" != beta_dir / "config.yaml" + assert str(alpha_dir) not in str(beta_dir) + + def test_separate_state_db_paths(self, profile_env): + alpha_dir = get_profile_dir("alpha") + beta_dir = get_profile_dir("beta") + assert alpha_dir / "state.db" != beta_dir / "state.db" + + def test_separate_skills_paths(self, profile_env): + create_profile("alpha", no_alias=True) + create_profile("beta", no_alias=True) + alpha_dir = get_profile_dir("alpha") + beta_dir = get_profile_dir("beta") + assert alpha_dir / "skills" != beta_dir / "skills" + # Verify both exist and are independent dirs + assert (alpha_dir / "skills").is_dir() + assert (beta_dir / "skills").is_dir() + + +# =================================================================== +# TestCompletion +# =================================================================== + +class TestCompletion: + """Tests for bash/zsh completion generators.""" + + def test_bash_completion_contains_complete(self): + script = generate_bash_completion() + assert len(script) > 0 + assert "complete" in script + + def test_zsh_completion_contains_compdef(self): + script = generate_zsh_completion() + assert len(script) > 0 + assert "compdef" in script + + def test_bash_completion_has_hermes_profiles_function(self): + script = generate_bash_completion() + assert "_hermes_profiles" in script + + def test_zsh_completion_has_hermes_function(self): + script = generate_zsh_completion() + assert "_hermes" in script + + +# =================================================================== +# TestGetProfilesRoot / TestGetDefaultHermesHome (internal helpers) +# =================================================================== + +class TestInternalHelpers: + """Tests for _get_profiles_root() and _get_default_hermes_home().""" + + def test_profiles_root_under_home(self, profile_env): + tmp_path = profile_env + root = _get_profiles_root() + assert root == tmp_path / ".hermes" / "profiles" + + def test_default_hermes_home(self, profile_env): + tmp_path = profile_env + home = _get_default_hermes_home() + assert home == tmp_path / ".hermes" + + +# =================================================================== +# Edge cases and additional coverage +# =================================================================== + +class TestEdgeCases: + """Additional edge-case tests.""" + + def test_create_profile_returns_correct_path(self, profile_env): + tmp_path = profile_env + result = create_profile("mybot", no_alias=True) + expected = tmp_path / ".hermes" / "profiles" / "mybot" + assert result == expected + + def test_list_profiles_default_info_fields(self, profile_env): + profiles = list_profiles() + default = [p for p in profiles if p.name == "default"][0] + assert default.is_default is True + assert default.gateway_running is False + assert default.skill_count == 0 + + def test_gateway_running_check_with_pid_file(self, profile_env): + """Verify _check_gateway_running reads pid file and probes os.kill.""" + from hermes_cli.profiles import _check_gateway_running + tmp_path = profile_env + default_home = tmp_path / ".hermes" + + # No pid file -> not running + assert _check_gateway_running(default_home) is False + + # Write a PID file with a JSON payload + pid_file = default_home / "gateway.pid" + pid_file.write_text(json.dumps({"pid": 99999})) + + # os.kill(99999, 0) should raise ProcessLookupError -> not running + assert _check_gateway_running(default_home) is False + + # Mock os.kill to simulate a running process + with patch("os.kill", return_value=None): + assert _check_gateway_running(default_home) is True + + def test_gateway_running_check_plain_pid(self, profile_env): + """Pid file containing just a number (legacy format).""" + from hermes_cli.profiles import _check_gateway_running + tmp_path = profile_env + default_home = tmp_path / ".hermes" + pid_file = default_home / "gateway.pid" + pid_file.write_text("99999") + + with patch("os.kill", return_value=None): + assert _check_gateway_running(default_home) is True + + def test_profile_name_boundary_single_char(self): + """Single alphanumeric character is valid.""" + validate_profile_name("a") + validate_profile_name("1") + + def test_profile_name_boundary_all_hyphens(self): + """Name starting with hyphen is invalid.""" + with pytest.raises(ValueError): + validate_profile_name("-abc") + + def test_profile_name_underscore_start(self): + """Name starting with underscore is invalid (must start with [a-z0-9]).""" + with pytest.raises(ValueError): + validate_profile_name("_abc") + + def test_clone_from_named_profile(self, profile_env): + """Clone config from a named (non-default) profile.""" + tmp_path = profile_env + # Create source profile with config + source_dir = create_profile("source", no_alias=True) + (source_dir / "config.yaml").write_text("model: cloned") + (source_dir / ".env").write_text("SECRET=yes") + + target_dir = create_profile( + "target", clone_from="source", clone_config=True, no_alias=True, + ) + assert (target_dir / "config.yaml").read_text() == "model: cloned" + assert (target_dir / ".env").read_text() == "SECRET=yes" + + def test_delete_clears_active_profile(self, profile_env): + """Deleting the active profile resets active to default.""" + tmp_path = profile_env + create_profile("coder", no_alias=True) + set_active_profile("coder") + assert get_active_profile() == "coder" + + with patch("hermes_cli.profiles._cleanup_gateway_service"): + delete_profile("coder", yes=True) + + assert get_active_profile() == "default" diff --git a/tests/hermes_cli/test_setup_openclaw_migration.py b/tests/hermes_cli/test_setup_openclaw_migration.py index be5d61bab..0991b6d1b 100644 --- a/tests/hermes_cli/test_setup_openclaw_migration.py +++ b/tests/hermes_cli/test_setup_openclaw_migration.py @@ -94,7 +94,7 @@ class TestOfferOpenclawMigration: fake_mod.Migrator.assert_called_once() call_kwargs = fake_mod.Migrator.call_args[1] assert call_kwargs["execute"] is True - assert call_kwargs["overwrite"] is False + assert call_kwargs["overwrite"] is True assert call_kwargs["migrate_secrets"] is True assert call_kwargs["preset_name"] == "full" fake_migrator.migrate.assert_called_once() @@ -285,3 +285,182 @@ class TestSetupWizardOpenclawIntegration: setup_mod.run_setup_wizard(args) mock_migration.assert_not_called() + + +# --------------------------------------------------------------------------- +# _get_section_config_summary / _skip_configured_section — unit tests +# --------------------------------------------------------------------------- + + +class TestGetSectionConfigSummary: + """Test the _get_section_config_summary helper.""" + + def test_model_returns_none_without_api_key(self): + with patch.object(setup_mod, "get_env_value", return_value=""): + result = setup_mod._get_section_config_summary({}, "model") + assert result is None + + def test_model_returns_summary_with_api_key(self): + def env_side(key): + return "sk-xxx" if key == "OPENROUTER_API_KEY" else "" + + with patch.object(setup_mod, "get_env_value", side_effect=env_side): + result = setup_mod._get_section_config_summary( + {"model": "openai/gpt-4"}, "model" + ) + assert result == "openai/gpt-4" + + def test_model_returns_dict_default_key(self): + def env_side(key): + return "sk-xxx" if key == "OPENAI_API_KEY" else "" + + with patch.object(setup_mod, "get_env_value", side_effect=env_side): + result = setup_mod._get_section_config_summary( + {"model": {"default": "claude-opus-4", "provider": "anthropic"}}, + "model", + ) + assert result == "claude-opus-4" + + def test_terminal_always_returns(self): + with patch.object(setup_mod, "get_env_value", return_value=""): + result = setup_mod._get_section_config_summary( + {"terminal": {"backend": "docker"}}, "terminal" + ) + assert result == "backend: docker" + + def test_agent_always_returns(self): + with patch.object(setup_mod, "get_env_value", return_value=""): + result = setup_mod._get_section_config_summary( + {"agent": {"max_turns": 120}}, "agent" + ) + assert result == "max turns: 120" + + def test_gateway_returns_none_without_tokens(self): + with patch.object(setup_mod, "get_env_value", return_value=""): + result = setup_mod._get_section_config_summary({}, "gateway") + assert result is None + + def test_gateway_lists_platforms(self): + def env_side(key): + if key == "TELEGRAM_BOT_TOKEN": + return "tok123" + if key == "DISCORD_BOT_TOKEN": + return "disc456" + return "" + + with patch.object(setup_mod, "get_env_value", side_effect=env_side): + result = setup_mod._get_section_config_summary({}, "gateway") + assert "Telegram" in result + assert "Discord" in result + + def test_tools_returns_none_without_keys(self): + with patch.object(setup_mod, "get_env_value", return_value=""): + result = setup_mod._get_section_config_summary({}, "tools") + assert result is None + + def test_tools_lists_configured(self): + def env_side(key): + return "key" if key == "BROWSERBASE_API_KEY" else "" + + with patch.object(setup_mod, "get_env_value", side_effect=env_side): + result = setup_mod._get_section_config_summary({}, "tools") + assert "Browser" in result + + +class TestSkipConfiguredSection: + """Test the _skip_configured_section helper.""" + + def test_returns_false_when_not_configured(self): + with patch.object(setup_mod, "get_env_value", return_value=""): + result = setup_mod._skip_configured_section({}, "model", "Model") + assert result is False + + def test_returns_true_when_user_skips(self): + def env_side(key): + return "sk-xxx" if key == "OPENROUTER_API_KEY" else "" + + with ( + patch.object(setup_mod, "get_env_value", side_effect=env_side), + patch.object(setup_mod, "prompt_yes_no", return_value=False), + ): + result = setup_mod._skip_configured_section( + {"model": "openai/gpt-4"}, "model", "Model" + ) + assert result is True + + def test_returns_false_when_user_wants_reconfig(self): + def env_side(key): + return "sk-xxx" if key == "OPENROUTER_API_KEY" else "" + + with ( + patch.object(setup_mod, "get_env_value", side_effect=env_side), + patch.object(setup_mod, "prompt_yes_no", return_value=True), + ): + result = setup_mod._skip_configured_section( + {"model": "openai/gpt-4"}, "model", "Model" + ) + assert result is False + + +class TestSetupWizardSkipsConfiguredSections: + """After migration, already-configured sections should offer skip.""" + + def test_sections_skipped_when_migration_imported_settings(self, tmp_path): + """When migration ran and API key exists, model section should be skippable. + + Simulates the real flow: get_env_value returns "" during the is_existing + check (before migration), then returns a key after migration imported it. + """ + args = _first_time_args() + + # Track whether migration has "run" — after it does, API key is available + migration_done = {"value": False} + + def env_side(key): + if migration_done["value"] and key == "OPENROUTER_API_KEY": + return "sk-xxx" + return "" + + def fake_migration(hermes_home): + migration_done["value"] = True + return True + + reloaded_config = {"model": "openai/gpt-4"} + + with ( + patch.object(setup_mod, "ensure_hermes_home"), + patch.object( + setup_mod, "load_config", + side_effect=[{}, reloaded_config], + ), + patch.object(setup_mod, "get_hermes_home", return_value=tmp_path), + patch.object(setup_mod, "get_env_value", side_effect=env_side), + patch.object(setup_mod, "is_interactive_stdin", return_value=True), + patch("hermes_cli.auth.get_active_provider", return_value=None), + patch("builtins.input", return_value=""), + # Migration succeeds and flips the env_side flag + patch.object( + setup_mod, "_offer_openclaw_migration", + side_effect=fake_migration, + ), + # User says No to all reconfig prompts + patch.object(setup_mod, "prompt_yes_no", return_value=False), + patch.object(setup_mod, "setup_model_provider") as mock_model, + patch.object(setup_mod, "setup_terminal_backend") as mock_terminal, + patch.object(setup_mod, "setup_agent_settings") as mock_agent, + patch.object(setup_mod, "setup_gateway") as mock_gateway, + patch.object(setup_mod, "setup_tools") as mock_tools, + patch.object(setup_mod, "save_config"), + patch.object(setup_mod, "_print_setup_summary"), + ): + setup_mod.run_setup_wizard(args) + + # Model has API key → skip offered, user said No → section NOT called + mock_model.assert_not_called() + # Terminal/agent always have a summary → skip offered, user said No + mock_terminal.assert_not_called() + mock_agent.assert_not_called() + # Gateway has no tokens (env_side returns "" for gateway keys) → section runs + mock_gateway.assert_called_once() + # Tools have no keys → section runs + mock_tools.assert_called_once() diff --git a/tests/hermes_cli/test_skills_skip_confirm.py b/tests/hermes_cli/test_skills_skip_confirm.py index 7293a6b3c..fd430185f 100644 --- a/tests/hermes_cli/test_skills_skip_confirm.py +++ b/tests/hermes_cli/test_skills_skip_confirm.py @@ -1,10 +1,13 @@ """ -Tests for skip_confirm behavior in /skills install and /skills uninstall. +Tests for skip_confirm and invalidate_cache behavior in /skills install +and /skills uninstall slash commands. -Verifies that --yes / -y bypasses the interactive confirmation prompt -that hangs inside prompt_toolkit's TUI. +Slash commands always skip confirmation (input() hangs in TUI). +Cache invalidation is deferred by default; --now opts into immediate +invalidation (at the cost of breaking prompt cache mid-session). Based on PR #1595 by 333Alden333 (salvaged). +Updated for PR #3586 (cache-aware install/uninstall). """ from unittest.mock import patch, MagicMock @@ -32,23 +35,43 @@ class TestHandleSkillsSlashInstallFlags: _, kwargs = mock_install.call_args assert kwargs.get("skip_confirm") is True - def test_force_flag_sets_force_not_skip(self): + def test_force_flag_sets_force(self): from hermes_cli.skills_hub import handle_skills_slash with patch("hermes_cli.skills_hub.do_install") as mock_install: handle_skills_slash("/skills install test/skill --force") mock_install.assert_called_once() _, kwargs = mock_install.call_args assert kwargs.get("force") is True - assert kwargs.get("skip_confirm") is False + # Slash commands always skip confirmation (input() hangs in TUI) + assert kwargs.get("skip_confirm") is True - def test_no_flags(self): + def test_no_flags_still_skips_confirm(self): + """Slash commands always skip confirmation — input() hangs in TUI.""" from hermes_cli.skills_hub import handle_skills_slash with patch("hermes_cli.skills_hub.do_install") as mock_install: handle_skills_slash("/skills install test/skill") mock_install.assert_called_once() _, kwargs = mock_install.call_args assert kwargs.get("force") is False - assert kwargs.get("skip_confirm") is False + assert kwargs.get("skip_confirm") is True + + def test_default_defers_cache_invalidation(self): + """Without --now, cache invalidation is deferred to next session.""" + from hermes_cli.skills_hub import handle_skills_slash + with patch("hermes_cli.skills_hub.do_install") as mock_install: + handle_skills_slash("/skills install test/skill") + mock_install.assert_called_once() + _, kwargs = mock_install.call_args + assert kwargs.get("invalidate_cache") is False + + def test_now_flag_invalidates_cache(self): + """--now opts into immediate cache invalidation.""" + from hermes_cli.skills_hub import handle_skills_slash + with patch("hermes_cli.skills_hub.do_install") as mock_install: + handle_skills_slash("/skills install test/skill --now") + mock_install.assert_called_once() + _, kwargs = mock_install.call_args + assert kwargs.get("invalidate_cache") is True class TestHandleSkillsSlashUninstallFlags: @@ -70,13 +93,32 @@ class TestHandleSkillsSlashUninstallFlags: _, kwargs = mock_uninstall.call_args assert kwargs.get("skip_confirm") is True - def test_no_flags(self): + def test_no_flags_still_skips_confirm(self): + """Slash commands always skip confirmation — input() hangs in TUI.""" from hermes_cli.skills_hub import handle_skills_slash with patch("hermes_cli.skills_hub.do_uninstall") as mock_uninstall: handle_skills_slash("/skills uninstall test-skill") mock_uninstall.assert_called_once() _, kwargs = mock_uninstall.call_args - assert kwargs.get("skip_confirm", False) is False + assert kwargs.get("skip_confirm") is True + + def test_default_defers_cache_invalidation(self): + """Without --now, cache invalidation is deferred to next session.""" + from hermes_cli.skills_hub import handle_skills_slash + with patch("hermes_cli.skills_hub.do_uninstall") as mock_uninstall: + handle_skills_slash("/skills uninstall test-skill") + mock_uninstall.assert_called_once() + _, kwargs = mock_uninstall.call_args + assert kwargs.get("invalidate_cache") is False + + def test_now_flag_invalidates_cache(self): + """--now opts into immediate cache invalidation.""" + from hermes_cli.skills_hub import handle_skills_slash + with patch("hermes_cli.skills_hub.do_uninstall") as mock_uninstall: + handle_skills_slash("/skills uninstall test-skill --now") + mock_uninstall.assert_called_once() + _, kwargs = mock_uninstall.call_args + assert kwargs.get("invalidate_cache") is True class TestDoInstallSkipConfirm: diff --git a/tests/hermes_cli/test_tool_token_estimation.py b/tests/hermes_cli/test_tool_token_estimation.py new file mode 100644 index 000000000..3e48980bf --- /dev/null +++ b/tests/hermes_cli/test_tool_token_estimation.py @@ -0,0 +1,283 @@ +"""Tests for tool token estimation and curses_ui status_fn support.""" + +from unittest.mock import patch + +import pytest + +# tiktoken is not in core/[all] deps — skip estimation tests when unavailable +_has_tiktoken = True +try: + import tiktoken # noqa: F401 +except ImportError: + _has_tiktoken = False + +_needs_tiktoken = pytest.mark.skipif(not _has_tiktoken, reason="tiktoken not installed") + + +# ─── Token Estimation Tests ────────────────────────────────────────────────── + + +@_needs_tiktoken +def test_estimate_tool_tokens_returns_positive_counts(): + """_estimate_tool_tokens should return a non-empty dict with positive values.""" + from hermes_cli.tools_config import _estimate_tool_tokens, _tool_token_cache + + # Clear cache to force fresh computation + import hermes_cli.tools_config as tc + tc._tool_token_cache = None + + tokens = _estimate_tool_tokens() + + assert isinstance(tokens, dict) + assert len(tokens) > 0 + for name, count in tokens.items(): + assert isinstance(name, str) + assert isinstance(count, int) + assert count > 0, f"Tool {name} has non-positive token count: {count}" + + +@_needs_tiktoken +def test_estimate_tool_tokens_is_cached(): + """Second call should return the same cached dict object.""" + import hermes_cli.tools_config as tc + tc._tool_token_cache = None + + first = tc._estimate_tool_tokens() + second = tc._estimate_tool_tokens() + + assert first is second + + +def test_estimate_tool_tokens_returns_empty_when_tiktoken_unavailable(monkeypatch): + """Graceful degradation when tiktoken cannot be imported.""" + import hermes_cli.tools_config as tc + tc._tool_token_cache = None + + import builtins + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "tiktoken": + raise ImportError("mocked") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mock_import) + + result = tc._estimate_tool_tokens() + + assert result == {} + + # Reset cache for other tests + tc._tool_token_cache = None + + +@_needs_tiktoken +def test_estimate_tool_tokens_covers_known_tools(): + """Should include schemas for well-known tools like terminal, web_search.""" + import hermes_cli.tools_config as tc + tc._tool_token_cache = None + + tokens = tc._estimate_tool_tokens() + + # These tools should always be discoverable + for expected in ("terminal", "web_search", "read_file"): + assert expected in tokens, f"Expected {expected!r} in token estimates" + + +# ─── Status Function Tests ─────────────────────────────────────────────────── + + +def test_prompt_toolset_checklist_passes_status_fn(monkeypatch): + """_prompt_toolset_checklist should pass a status_fn to curses_checklist.""" + import hermes_cli.tools_config as tc + + captured_kwargs = {} + + def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None): + captured_kwargs["status_fn"] = status_fn + captured_kwargs["title"] = title + return selected # Return pre-selected unchanged + + monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist) + + tc._prompt_toolset_checklist("CLI", {"web", "terminal"}) + + assert "status_fn" in captured_kwargs + # If tiktoken is available, status_fn should be set + tokens = tc._estimate_tool_tokens() + if tokens: + assert captured_kwargs["status_fn"] is not None + + +def test_status_fn_returns_formatted_token_count(monkeypatch): + """The status_fn should return a human-readable token count string.""" + import hermes_cli.tools_config as tc + from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS + + captured = {} + + def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None): + captured["status_fn"] = status_fn + return selected + + monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist) + + tc._prompt_toolset_checklist("CLI", {"web", "terminal"}) + + status_fn = captured.get("status_fn") + if status_fn is None: + pytest.skip("tiktoken unavailable; status_fn not created") + + # Find the indices for web and terminal + idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)} + + # Call status_fn with web + terminal selected + result = status_fn({idx_map["web"], idx_map["terminal"]}) + assert "tokens" in result + assert "Est. tool context" in result + + +def test_status_fn_deduplicates_overlapping_tools(monkeypatch): + """When toolsets overlap (browser includes web_search), tokens should not double-count.""" + import hermes_cli.tools_config as tc + from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS + + captured = {} + + def fake_checklist(title, items, selected, *, cancel_returns=None, status_fn=None): + captured["status_fn"] = status_fn + return selected + + monkeypatch.setattr("hermes_cli.curses_ui.curses_checklist", fake_checklist) + + tc._prompt_toolset_checklist("CLI", {"web"}) + + status_fn = captured.get("status_fn") + if status_fn is None: + pytest.skip("tiktoken unavailable; status_fn not created") + + idx_map = {ts_key: i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)} + + # web alone + web_only = status_fn({idx_map["web"]}) + # browser includes web_search, so browser + web should not double-count web_search + browser_only = status_fn({idx_map["browser"]}) + both = status_fn({idx_map["web"], idx_map["browser"]}) + + # Extract numeric token counts from strings like "~8.3k tokens" or "~350 tokens" + import re + + def parse_tokens(s): + m = re.search(r"~([\d.]+)k?\s+tokens", s) + if not m: + return 0 + val = float(m.group(1)) + if "k" in s[m.start():m.end()]: + val *= 1000 + return val + + web_tok = parse_tokens(web_only) + browser_tok = parse_tokens(browser_only) + both_tok = parse_tokens(both) + + # Both together should be LESS than naive sum (due to web_search dedup) + naive_sum = web_tok + browser_tok + assert both_tok < naive_sum, ( + f"Expected deduplication: web({web_tok}) + browser({browser_tok}) = {naive_sum} " + f"but combined = {both_tok}" + ) + + +def test_status_fn_empty_selection(): + """Status function with no tools selected should return ~0 tokens.""" + import hermes_cli.tools_config as tc + + tc._tool_token_cache = None + tokens = tc._estimate_tool_tokens() + if not tokens: + pytest.skip("tiktoken unavailable") + + from hermes_cli.tools_config import CONFIGURABLE_TOOLSETS + from toolsets import resolve_toolset + + ts_keys = [ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS] + + def status_fn(chosen: set) -> str: + all_tools: set = set() + for idx in chosen: + all_tools.update(resolve_toolset(ts_keys[idx])) + total = sum(tokens.get(name, 0) for name in all_tools) + if total >= 1000: + return f"Est. tool context: ~{total / 1000:.1f}k tokens" + return f"Est. tool context: ~{total} tokens" + + result = status_fn(set()) + assert "~0 tokens" in result + + +# ─── Curses UI Status Bar Tests ────────────────────────────────────────────── + + +def test_curses_checklist_numbered_fallback_shows_status(monkeypatch, capsys): + """The numbered fallback should print the status_fn output.""" + from hermes_cli.curses_ui import _numbered_fallback + + def my_status(chosen): + return f"Selected {len(chosen)} items" + + # Simulate user pressing Enter immediately (empty input → confirm) + monkeypatch.setattr("builtins.input", lambda _prompt="": "") + + result = _numbered_fallback( + "Test title", + ["Item A", "Item B", "Item C"], + {0, 2}, + {0, 2}, + status_fn=my_status, + ) + + captured = capsys.readouterr() + assert "Selected 2 items" in captured.out + assert result == {0, 2} + + +def test_curses_checklist_numbered_fallback_without_status(monkeypatch, capsys): + """The numbered fallback should work fine without status_fn.""" + from hermes_cli.curses_ui import _numbered_fallback + + monkeypatch.setattr("builtins.input", lambda _prompt="": "") + + result = _numbered_fallback( + "Test title", + ["Item A", "Item B"], + {0}, + {0}, + ) + + captured = capsys.readouterr() + assert "Est. tool context" not in captured.out + assert result == {0} + + +# ─── Registry get_schema Tests ─────────────────────────────────────────────── + + +def test_registry_get_schema_returns_schema(): + """registry.get_schema() should return a tool's schema dict.""" + from tools.registry import registry + + # Import to trigger discovery + import model_tools # noqa: F401 + + schema = registry.get_schema("terminal") + assert schema is not None + assert "name" in schema + assert schema["name"] == "terminal" + assert "parameters" in schema + + +def test_registry_get_schema_returns_none_for_unknown(): + """registry.get_schema() should return None for unknown tools.""" + from tools.registry import registry + + assert registry.get_schema("nonexistent_tool_xyz") is None diff --git a/tests/hermes_cli/test_tools_config.py b/tests/hermes_cli/test_tools_config.py index dccbce9d3..946ba77fd 100644 --- a/tests/hermes_cli/test_tools_config.py +++ b/tests/hermes_cli/test_tools_config.py @@ -332,3 +332,52 @@ def test_first_install_nous_auto_configures_managed_defaults(monkeypatch): assert config["tts"]["provider"] == "openai" assert config["browser"]["cloud_provider"] == "browserbase" assert configured == [] + +# ── Platform / toolset consistency ──────────────────────────────────────────── + + +class TestPlatformToolsetConsistency: + """Every platform in tools_config.PLATFORMS must have a matching toolset.""" + + def test_all_platforms_have_toolset_definitions(self): + """Each platform's default_toolset must exist in TOOLSETS.""" + from hermes_cli.tools_config import PLATFORMS + from toolsets import TOOLSETS + + for platform, meta in PLATFORMS.items(): + ts_name = meta["default_toolset"] + assert ts_name in TOOLSETS, ( + f"Platform {platform!r} references toolset {ts_name!r} " + f"which is not defined in toolsets.py" + ) + + def test_gateway_toolset_includes_all_messaging_platforms(self): + """hermes-gateway includes list should cover all messaging platforms.""" + from hermes_cli.tools_config import PLATFORMS + from toolsets import TOOLSETS + + gateway_includes = set(TOOLSETS["hermes-gateway"]["includes"]) + # Exclude non-messaging platforms from the check + non_messaging = {"cli", "api_server"} + for platform, meta in PLATFORMS.items(): + if platform in non_messaging: + continue + ts_name = meta["default_toolset"] + assert ts_name in gateway_includes, ( + f"Platform {platform!r} toolset {ts_name!r} missing from " + f"hermes-gateway includes" + ) + + def test_skills_config_covers_tools_config_platforms(self): + """skills_config.PLATFORMS should have entries for all gateway platforms.""" + from hermes_cli.tools_config import PLATFORMS as TOOLS_PLATFORMS + from hermes_cli.skills_config import PLATFORMS as SKILLS_PLATFORMS + + non_messaging = {"api_server"} + for platform in TOOLS_PLATFORMS: + if platform in non_messaging: + continue + assert platform in SKILLS_PLATFORMS, ( + f"Platform {platform!r} in tools_config but missing from " + f"skills_config PLATFORMS" + ) diff --git a/tests/hermes_cli/test_update_autostash.py b/tests/hermes_cli/test_update_autostash.py index 9b8b6d79a..042b4fd47 100644 --- a/tests/hermes_cli/test_update_autostash.py +++ b/tests/hermes_cli/test_update_autostash.py @@ -267,7 +267,8 @@ def test_restore_stashed_changes_user_declines_reset(monkeypatch, tmp_path, caps def test_restore_stashed_changes_auto_resets_non_interactive(monkeypatch, tmp_path, capsys): - """Non-interactive mode auto-resets without prompting.""" + """Non-interactive mode auto-resets without prompting and returns False + instead of sys.exit(1) so the update can continue (gateway /update path).""" calls = [] def fake_run(cmd, **kwargs): @@ -282,9 +283,9 @@ def test_restore_stashed_changes_auto_resets_non_interactive(monkeypatch, tmp_pa monkeypatch.setattr(hermes_main.subprocess, "run", fake_run) - with pytest.raises(SystemExit, match="1"): - hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False) + result = hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False) + assert result is False out = capsys.readouterr().out assert "Working tree reset to clean state" in out reset_calls = [c for c, _ in calls if c[1:3] == ["reset", "--hard"]] @@ -384,3 +385,236 @@ def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path): install_cmds = [c for c in recorded if "pip" in c and "install" in c] assert len(install_cmds) == 1 assert ".[all]" in install_cmds[0] + + +# --------------------------------------------------------------------------- +# ff-only fallback to reset --hard on diverged history +# --------------------------------------------------------------------------- + +def _make_update_side_effect( + current_branch="main", + commit_count="3", + ff_only_fails=False, + reset_fails=False, + fetch_fails=False, + fetch_stderr="", +): + """Build a subprocess.run side_effect for cmd_update tests.""" + recorded = [] + + def side_effect(cmd, **kwargs): + recorded.append(cmd) + joined = " ".join(str(c) for c in cmd) + if "fetch" in joined and "origin" in joined: + if fetch_fails: + return SimpleNamespace(stdout="", stderr=fetch_stderr, returncode=128) + return SimpleNamespace(stdout="", stderr="", returncode=0) + if "rev-parse" in joined and "--abbrev-ref" in joined: + return SimpleNamespace(stdout=f"{current_branch}\n", stderr="", returncode=0) + if "checkout" in joined and "main" in joined: + return SimpleNamespace(stdout="", stderr="", returncode=0) + if "rev-list" in joined: + return SimpleNamespace(stdout=f"{commit_count}\n", stderr="", returncode=0) + if "--ff-only" in joined: + if ff_only_fails: + return SimpleNamespace( + stdout="", + stderr="fatal: Not possible to fast-forward, aborting.\n", + returncode=128, + ) + return SimpleNamespace(stdout="Updating abc..def\n", stderr="", returncode=0) + if "reset" in joined and "--hard" in joined: + if reset_fails: + return SimpleNamespace(stdout="", stderr="error: unable to write\n", returncode=1) + return SimpleNamespace(stdout="HEAD is now at abc123\n", stderr="", returncode=0) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + return side_effect, recorded + + +def test_cmd_update_falls_back_to_reset_when_ff_only_fails(monkeypatch, tmp_path, capsys): + """When --ff-only fails (diverged history), update resets to origin/{branch}.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + side_effect, recorded = _make_update_side_effect(ff_only_fails=True) + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + hermes_main.cmd_update(SimpleNamespace()) + + reset_calls = [c for c in recorded if "reset" in c and "--hard" in c] + assert len(reset_calls) == 1 + assert reset_calls[0] == ["git", "reset", "--hard", "origin/main"] + + out = capsys.readouterr().out + assert "Fast-forward not possible" in out + + +def test_cmd_update_no_reset_when_ff_only_succeeds(monkeypatch, tmp_path): + """When --ff-only succeeds, no reset is attempted.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + side_effect, recorded = _make_update_side_effect() + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + hermes_main.cmd_update(SimpleNamespace()) + + reset_calls = [c for c in recorded if "reset" in c and "--hard" in c] + assert len(reset_calls) == 0 + + +# --------------------------------------------------------------------------- +# Non-main branch → auto-checkout main +# --------------------------------------------------------------------------- + +def test_cmd_update_switches_to_main_from_feature_branch(monkeypatch, tmp_path, capsys): + """When on a feature branch, update checks out main before pulling.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + side_effect, recorded = _make_update_side_effect(current_branch="fix/something") + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + hermes_main.cmd_update(SimpleNamespace()) + + checkout_calls = [c for c in recorded if "checkout" in c and "main" in c] + assert len(checkout_calls) == 1 + + out = capsys.readouterr().out + assert "fix/something" in out + assert "switching to main" in out + + +def test_cmd_update_switches_to_main_from_detached_head(monkeypatch, tmp_path, capsys): + """When in detached HEAD state, update checks out main before pulling.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + side_effect, recorded = _make_update_side_effect(current_branch="HEAD") + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + hermes_main.cmd_update(SimpleNamespace()) + + checkout_calls = [c for c in recorded if "checkout" in c and "main" in c] + assert len(checkout_calls) == 1 + + out = capsys.readouterr().out + assert "detached HEAD" in out + + +def test_cmd_update_restores_stash_and_branch_when_already_up_to_date(monkeypatch, tmp_path, capsys): + """When on a feature branch with no updates, stash is restored and branch switched back.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + # Enable stash so it returns a ref + monkeypatch.setattr( + hermes_main, "_stash_local_changes_if_needed", + lambda *a, **kw: "abc123deadbeef", + ) + restore_calls = [] + monkeypatch.setattr( + hermes_main, "_restore_stashed_changes", + lambda *a, **kw: restore_calls.append(1) or True, + ) + + side_effect, recorded = _make_update_side_effect( + current_branch="fix/something", commit_count="0", + ) + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + hermes_main.cmd_update(SimpleNamespace()) + + # Stash should have been restored + assert len(restore_calls) == 1 + + # Should have checked out back to the original branch + checkout_back = [c for c in recorded if "checkout" in c and "fix/something" in c] + assert len(checkout_back) == 1 + + out = capsys.readouterr().out + assert "Already up to date" in out + + +def test_cmd_update_no_checkout_when_already_on_main(monkeypatch, tmp_path): + """When already on main, no checkout is needed.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + side_effect, recorded = _make_update_side_effect() + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + hermes_main.cmd_update(SimpleNamespace()) + + checkout_calls = [c for c in recorded if "checkout" in c] + assert len(checkout_calls) == 0 + + +# --------------------------------------------------------------------------- +# Fetch failure — friendly error messages +# --------------------------------------------------------------------------- + +def test_cmd_update_network_error_shows_friendly_message(monkeypatch, tmp_path, capsys): + """Network failures during fetch show a user-friendly message.""" + _setup_update_mocks(monkeypatch, tmp_path) + + side_effect, _ = _make_update_side_effect( + fetch_fails=True, + fetch_stderr="fatal: unable to access 'https://...': Could not resolve host: github.com", + ) + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + with pytest.raises(SystemExit, match="1"): + hermes_main.cmd_update(SimpleNamespace()) + + out = capsys.readouterr().out + assert "Network error" in out + + +def test_cmd_update_auth_error_shows_friendly_message(monkeypatch, tmp_path, capsys): + """Auth failures during fetch show a user-friendly message.""" + _setup_update_mocks(monkeypatch, tmp_path) + + side_effect, _ = _make_update_side_effect( + fetch_fails=True, + fetch_stderr="fatal: Authentication failed for 'https://...'", + ) + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + with pytest.raises(SystemExit, match="1"): + hermes_main.cmd_update(SimpleNamespace()) + + out = capsys.readouterr().out + assert "Authentication failed" in out + + +# --------------------------------------------------------------------------- +# reset --hard failure — don't attempt stash restore +# --------------------------------------------------------------------------- + +def test_cmd_update_skips_stash_restore_when_reset_fails(monkeypatch, tmp_path, capsys): + """When reset --hard fails, stash restore is skipped with a helpful message.""" + _setup_update_mocks(monkeypatch, tmp_path) + # Re-enable stash so it actually returns a ref + monkeypatch.setattr( + hermes_main, "_stash_local_changes_if_needed", + lambda *a, **kw: "abc123deadbeef", + ) + restore_calls = [] + monkeypatch.setattr( + hermes_main, "_restore_stashed_changes", + lambda *a, **kw: restore_calls.append(1) or True, + ) + + side_effect, _ = _make_update_side_effect(ff_only_fails=True, reset_fails=True) + monkeypatch.setattr(hermes_main.subprocess, "run", side_effect) + + with pytest.raises(SystemExit, match="1"): + hermes_main.cmd_update(SimpleNamespace()) + + # Stash restore should NOT have been called + assert len(restore_calls) == 0 + + out = capsys.readouterr().out + assert "preserved in stash" in out diff --git a/tests/hermes_cli/test_update_gateway_restart.py b/tests/hermes_cli/test_update_gateway_restart.py index b9cdecaa0..89ac84219 100644 --- a/tests/hermes_cli/test_update_gateway_restart.py +++ b/tests/hermes_cli/test_update_gateway_restart.py @@ -101,6 +101,69 @@ class TestLaunchdPlistReplace: assert replace_idx == run_idx + 1 +class TestLaunchdPlistPath: + def test_plist_contains_environment_variables(self): + plist = gateway_cli.generate_launchd_plist() + assert "EnvironmentVariables" in plist + assert "PATH" in plist + assert "VIRTUAL_ENV" in plist + assert "HERMES_HOME" in plist + + def test_plist_path_includes_venv_bin(self): + plist = gateway_cli.generate_launchd_plist() + detected = gateway_cli._detect_venv_dir() + venv_bin = str(detected / "bin") if detected else str(gateway_cli.PROJECT_ROOT / "venv" / "bin") + assert venv_bin in plist + + def test_plist_path_starts_with_venv_bin(self): + plist = gateway_cli.generate_launchd_plist() + lines = plist.splitlines() + for i, line in enumerate(lines): + if "PATH" in line.strip(): + path_value = lines[i + 1].strip() + path_value = path_value.replace("", "").replace("", "") + detected = gateway_cli._detect_venv_dir() + venv_bin = str(detected / "bin") if detected else str(gateway_cli.PROJECT_ROOT / "venv" / "bin") + assert path_value.startswith(venv_bin + ":") + break + else: + raise AssertionError("PATH key not found in plist") + + def test_plist_path_includes_node_modules_bin(self): + plist = gateway_cli.generate_launchd_plist() + node_bin = str(gateway_cli.PROJECT_ROOT / "node_modules" / ".bin") + lines = plist.splitlines() + for i, line in enumerate(lines): + if "PATH" in line.strip(): + path_value = lines[i + 1].strip() + path_value = path_value.replace("", "").replace("", "") + assert node_bin in path_value.split(":") + break + else: + raise AssertionError("PATH key not found in plist") + + def test_plist_path_includes_current_env_path(self, monkeypatch): + monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin") + plist = gateway_cli.generate_launchd_plist() + assert "/custom/bin" in plist + + def test_plist_path_deduplicates_venv_bin_when_already_in_path(self, monkeypatch): + detected = gateway_cli._detect_venv_dir() + venv_bin = str(detected / "bin") if detected else str(gateway_cli.PROJECT_ROOT / "venv" / "bin") + monkeypatch.setenv("PATH", f"{venv_bin}:/usr/bin:/bin") + plist = gateway_cli.generate_launchd_plist() + lines = plist.splitlines() + for i, line in enumerate(lines): + if "PATH" in line.strip(): + path_value = lines[i + 1].strip() + path_value = path_value.replace("", "").replace("", "") + parts = path_value.split(":") + assert parts.count(venv_bin) == 1 + break + else: + raise AssertionError("PATH key not found in plist") + + # --------------------------------------------------------------------------- # cmd_update — macOS launchd detection # --------------------------------------------------------------------------- @@ -177,6 +240,33 @@ class TestLaunchdPlistRefresh: assert any("unload" in s for s in cmd_strs) assert any("start" in s for s in cmd_strs) + def test_launchd_start_recreates_missing_plist_and_loads_service(self, tmp_path, monkeypatch): + """launchd_start self-heals when the plist file is missing entirely.""" + plist_path = tmp_path / "ai.hermes.gateway.plist" + assert not plist_path.exists() + + monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path) + + calls = [] + def fake_run(cmd, check=False, **kwargs): + calls.append(cmd) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run) + + gateway_cli.launchd_start() + + # Should have created the plist + assert plist_path.exists() + assert "--replace" in plist_path.read_text() + + cmd_strs = [" ".join(c) for c in calls] + # Should load the new plist, then start + assert any("load" in s for s in cmd_strs) + assert any("start" in s for s in cmd_strs) + # Should NOT call unload (nothing to unload) + assert not any("unload" in s for s in cmd_strs) + class TestCmdUpdateLaunchdRestart: """cmd_update correctly detects and handles launchd on macOS.""" diff --git a/tests/hermes_cli/test_webhook_cli.py b/tests/hermes_cli/test_webhook_cli.py new file mode 100644 index 000000000..0094e917c --- /dev/null +++ b/tests/hermes_cli/test_webhook_cli.py @@ -0,0 +1,189 @@ +"""Tests for hermes_cli/webhook.py — webhook subscription CLI.""" + +import json +import os +import pytest +from argparse import Namespace +from pathlib import Path + +from hermes_cli.webhook import ( + webhook_command, + _load_subscriptions, + _save_subscriptions, + _subscriptions_path, + _is_webhook_enabled, +) + + +@pytest.fixture(autouse=True) +def _isolate(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + # Default: webhooks enabled (most tests need this) + monkeypatch.setattr( + "hermes_cli.webhook._is_webhook_enabled", lambda: True + ) + + +def _make_args(**kwargs): + defaults = { + "webhook_action": None, + "name": "", + "prompt": "", + "events": "", + "description": "", + "skills": "", + "deliver": "log", + "deliver_chat_id": "", + "secret": "", + "payload": "", + } + defaults.update(kwargs) + return Namespace(**defaults) + + +class TestSubscribe: + def test_basic_create(self, capsys): + webhook_command(_make_args(webhook_action="subscribe", name="test-hook")) + out = capsys.readouterr().out + assert "Created" in out + assert "/webhooks/test-hook" in out + subs = _load_subscriptions() + assert "test-hook" in subs + + def test_with_options(self, capsys): + webhook_command(_make_args( + webhook_action="subscribe", + name="gh-issues", + events="issues,pull_request", + prompt="Issue: {issue.title}", + deliver="telegram", + deliver_chat_id="12345", + description="Watch GitHub", + )) + subs = _load_subscriptions() + route = subs["gh-issues"] + assert route["events"] == ["issues", "pull_request"] + assert route["prompt"] == "Issue: {issue.title}" + assert route["deliver"] == "telegram" + assert route["deliver_extra"] == {"chat_id": "12345"} + + def test_custom_secret(self): + webhook_command(_make_args( + webhook_action="subscribe", name="s", secret="my-secret" + )) + assert _load_subscriptions()["s"]["secret"] == "my-secret" + + def test_auto_secret(self): + webhook_command(_make_args(webhook_action="subscribe", name="s")) + secret = _load_subscriptions()["s"]["secret"] + assert len(secret) > 20 + + def test_update(self, capsys): + webhook_command(_make_args(webhook_action="subscribe", name="x", prompt="v1")) + webhook_command(_make_args(webhook_action="subscribe", name="x", prompt="v2")) + out = capsys.readouterr().out + assert "Updated" in out + assert _load_subscriptions()["x"]["prompt"] == "v2" + + def test_invalid_name(self, capsys): + webhook_command(_make_args(webhook_action="subscribe", name="bad name!")) + out = capsys.readouterr().out + assert "Error" in out or "Invalid" in out + assert _load_subscriptions() == {} + + +class TestList: + def test_empty(self, capsys): + webhook_command(_make_args(webhook_action="list")) + out = capsys.readouterr().out + assert "No dynamic" in out + + def test_with_entries(self, capsys): + webhook_command(_make_args(webhook_action="subscribe", name="a")) + webhook_command(_make_args(webhook_action="subscribe", name="b")) + capsys.readouterr() # clear + webhook_command(_make_args(webhook_action="list")) + out = capsys.readouterr().out + assert "2 webhook" in out + assert "a" in out + assert "b" in out + + +class TestRemove: + def test_remove_existing(self, capsys): + webhook_command(_make_args(webhook_action="subscribe", name="temp")) + webhook_command(_make_args(webhook_action="remove", name="temp")) + out = capsys.readouterr().out + assert "Removed" in out + assert _load_subscriptions() == {} + + def test_remove_nonexistent(self, capsys): + webhook_command(_make_args(webhook_action="remove", name="nope")) + out = capsys.readouterr().out + assert "No subscription" in out + + def test_selective_remove(self): + webhook_command(_make_args(webhook_action="subscribe", name="keep")) + webhook_command(_make_args(webhook_action="subscribe", name="drop")) + webhook_command(_make_args(webhook_action="remove", name="drop")) + subs = _load_subscriptions() + assert "keep" in subs + assert "drop" not in subs + + +class TestPersistence: + def test_file_written(self): + webhook_command(_make_args(webhook_action="subscribe", name="persist")) + path = _subscriptions_path() + assert path.exists() + data = json.loads(path.read_text()) + assert "persist" in data + + def test_corrupted_file(self): + path = _subscriptions_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("broken{{{") + assert _load_subscriptions() == {} + + +class TestWebhookEnabledGate: + def test_blocks_when_disabled(self, capsys, monkeypatch): + monkeypatch.setattr("hermes_cli.webhook._is_webhook_enabled", lambda: False) + webhook_command(_make_args(webhook_action="subscribe", name="blocked")) + out = capsys.readouterr().out + assert "not enabled" in out.lower() + assert "hermes gateway setup" in out + assert _load_subscriptions() == {} + + def test_blocks_list_when_disabled(self, capsys, monkeypatch): + monkeypatch.setattr("hermes_cli.webhook._is_webhook_enabled", lambda: False) + webhook_command(_make_args(webhook_action="list")) + out = capsys.readouterr().out + assert "not enabled" in out.lower() + + def test_allows_when_enabled(self, capsys): + # _is_webhook_enabled already patched to True by autouse fixture + webhook_command(_make_args(webhook_action="subscribe", name="allowed")) + out = capsys.readouterr().out + assert "Created" in out + assert "allowed" in _load_subscriptions() + + def test_real_check_disabled(self, monkeypatch): + monkeypatch.setattr( + "hermes_cli.webhook._get_webhook_config", + lambda: {}, + ) + monkeypatch.setattr( + "hermes_cli.webhook._is_webhook_enabled", + lambda: bool({}.get("enabled")), + ) + import hermes_cli.webhook as wh_mod + assert wh_mod._is_webhook_enabled() is False + + def test_real_check_enabled(self, monkeypatch): + monkeypatch.setattr( + "hermes_cli.webhook._is_webhook_enabled", + lambda: True, + ) + import hermes_cli.webhook as wh_mod + assert wh_mod._is_webhook_enabled() is True diff --git a/tests/skills/test_memento_cards.py b/tests/skills/test_memento_cards.py new file mode 100644 index 000000000..c1e29039c --- /dev/null +++ b/tests/skills/test_memento_cards.py @@ -0,0 +1,427 @@ +"""Tests for optional-skills/productivity/memento-flashcards/scripts/memento_cards.py""" + +import csv +import json +import os +import sys +import uuid +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest import mock + +import pytest + +# Add the scripts dir so we can import the module directly +SCRIPTS_DIR = Path(__file__).resolve().parents[2] / "optional-skills" / "productivity" / "memento-flashcards" / "scripts" +sys.path.insert(0, str(SCRIPTS_DIR)) + +import memento_cards + + +@pytest.fixture(autouse=True) +def isolated_data(tmp_path, monkeypatch): + """Redirect card storage to a temp directory for every test.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + monkeypatch.setattr(memento_cards, "DATA_DIR", data_dir) + monkeypatch.setattr(memento_cards, "CARDS_FILE", data_dir / "cards.json") + return data_dir + + +def _run(capsys, argv: list[str]) -> dict: + """Run main() with given argv and return parsed JSON output.""" + with mock.patch("sys.argv", ["memento_cards"] + argv): + memento_cards.main() + captured = capsys.readouterr() + return json.loads(captured.out) + + +# ── Add / List / Delete ────────────────────────────────────────────────────── + +class TestCardCRUD: + def test_add_creates_card(self, capsys): + result = _run(capsys, ["add", "--question", "What is 2+2?", "--answer", "4", "--collection", "Math"]) + assert result["ok"] is True + card = result["card"] + assert card["question"] == "What is 2+2?" + assert card["answer"] == "4" + assert card["collection"] == "Math" + assert card["status"] == "learning" + assert card["ease_streak"] == 0 + uuid.UUID(card["id"]) # validates it's a real UUID + + def test_add_default_collection(self, capsys): + result = _run(capsys, ["add", "--question", "Q?", "--answer", "A"]) + assert result["card"]["collection"] == "General" + + def test_list_all(self, capsys): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"]) + _run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"]) + result = _run(capsys, ["list"]) + assert result["count"] == 2 + + def test_list_by_collection(self, capsys): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"]) + _run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"]) + result = _run(capsys, ["list", "--collection", "C1"]) + assert result["count"] == 1 + assert result["cards"][0]["collection"] == "C1" + + def test_list_by_status(self, capsys): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1"]) + result = _run(capsys, ["list", "--status", "learning"]) + assert result["count"] == 1 + result = _run(capsys, ["list", "--status", "retired"]) + assert result["count"] == 0 + + def test_delete_card(self, capsys): + result = _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = result["card"]["id"] + del_result = _run(capsys, ["delete", "--id", card_id]) + assert del_result["ok"] is True + assert del_result["deleted"] == card_id + # Verify gone + list_result = _run(capsys, ["list"]) + assert list_result["count"] == 0 + + def test_delete_nonexistent(self, capsys): + with pytest.raises(SystemExit): + _run(capsys, ["delete", "--id", "nonexistent"]) + + def test_delete_collection(self, capsys): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "ToDelete"]) + _run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "ToDelete"]) + _run(capsys, ["add", "--question", "Q3", "--answer", "A3", "--collection", "Keep"]) + result = _run(capsys, ["delete-collection", "--collection", "ToDelete"]) + assert result["ok"] is True + assert result["deleted_count"] == 2 + list_result = _run(capsys, ["list"]) + assert list_result["count"] == 1 + assert list_result["cards"][0]["collection"] == "Keep" + + +# ── Due Filtering ──────────────────────────────────────────────────────────── + +class TestDueFiltering: + def test_new_card_is_due(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + result = _run(capsys, ["due"]) + assert result["count"] == 1 + + def test_future_card_not_due(self, capsys, monkeypatch): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + # Rate it good (pushes next_review_at to +3 days) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + _run(capsys, ["rate", "--id", card_id, "--rating", "good"]) + result = _run(capsys, ["due"]) + assert result["count"] == 0 + + def test_retired_card_not_due(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + _run(capsys, ["rate", "--id", card_id, "--rating", "retire"]) + result = _run(capsys, ["due"]) + assert result["count"] == 0 + + def test_due_with_collection_filter(self, capsys): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"]) + _run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"]) + result = _run(capsys, ["due", "--collection", "C1"]) + assert result["count"] == 1 + assert result["cards"][0]["collection"] == "C1" + + +# ── Rating and Rescheduling ────────────────────────────────────────────────── + +class TestRating: + def test_hard_adds_1_day(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + before = datetime.now(timezone.utc) + result = _run(capsys, ["rate", "--id", card_id, "--rating", "hard"]) + after = datetime.now(timezone.utc) + next_review = datetime.fromisoformat(result["card"]["next_review_at"]) + assert before + timedelta(days=1) <= next_review <= after + timedelta(days=1) + assert result["card"]["ease_streak"] == 0 + + def test_good_adds_3_days(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + before = datetime.now(timezone.utc) + result = _run(capsys, ["rate", "--id", card_id, "--rating", "good"]) + next_review = datetime.fromisoformat(result["card"]["next_review_at"]) + assert next_review >= before + timedelta(days=3) + assert result["card"]["ease_streak"] == 0 + + def test_easy_adds_7_days_and_increments_streak(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy"]) + assert result["card"]["ease_streak"] == 1 + assert result["card"]["status"] == "learning" + + def test_retire_sets_retired(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + result = _run(capsys, ["rate", "--id", card_id, "--rating", "retire"]) + assert result["card"]["status"] == "retired" + assert result["card"]["ease_streak"] == 0 + + def test_auto_retire_after_3_easys(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + + # Force card to be due by manipulating next_review_at through rate + for i in range(3): + # Load and directly set next_review_at to now so it's ratable + data = memento_cards._load() + for c in data["cards"]: + if c["id"] == card_id: + c["next_review_at"] = memento_cards._iso(memento_cards._now()) + memento_cards._save(data) + + result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy"]) + + assert result["card"]["ease_streak"] == 3 + assert result["card"]["status"] == "retired" + + def test_hard_resets_ease_streak(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + + # Easy twice + for _ in range(2): + data = memento_cards._load() + for c in data["cards"]: + if c["id"] == card_id: + c["next_review_at"] = memento_cards._iso(memento_cards._now()) + memento_cards._save(data) + _run(capsys, ["rate", "--id", card_id, "--rating", "easy"]) + + # Verify streak is 2 + check = _run(capsys, ["list"]) + assert check["cards"][0]["ease_streak"] == 2 + + # Hard resets + data = memento_cards._load() + for c in data["cards"]: + if c["id"] == card_id: + c["next_review_at"] = memento_cards._iso(memento_cards._now()) + memento_cards._save(data) + result = _run(capsys, ["rate", "--id", card_id, "--rating", "hard"]) + assert result["card"]["ease_streak"] == 0 + assert result["card"]["status"] == "learning" + + def test_rate_nonexistent_card(self, capsys): + with pytest.raises(SystemExit): + _run(capsys, ["rate", "--id", "nonexistent", "--rating", "easy"]) + + +# ── CSV Export/Import ──────────────────────────────────────────────────────── + +class TestCSV: + def test_export_import_roundtrip(self, capsys, tmp_path): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"]) + _run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C2"]) + + csv_path = str(tmp_path / "export.csv") + result = _run(capsys, ["export", "--output", csv_path]) + assert result["ok"] is True + assert result["exported"] == 2 + + # Verify CSV content + with open(csv_path, "r") as f: + reader = csv.reader(f) + rows = list(reader) + assert len(rows) == 2 + assert rows[0] == ["Q1", "A1", "C1"] + assert rows[1] == ["Q2", "A2", "C2"] + + # Delete all and reimport + data = memento_cards._load() + data["cards"] = [] + memento_cards._save(data) + + result = _run(capsys, ["import", "--file", csv_path, "--collection", "Fallback"]) + assert result["ok"] is True + assert result["imported"] == 2 + + # Verify imported cards use CSV collection column + list_result = _run(capsys, ["list"]) + collections = {c["collection"] for c in list_result["cards"]} + assert collections == {"C1", "C2"} + + def test_import_without_collection_column(self, capsys, tmp_path): + csv_path = str(tmp_path / "no_col.csv") + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["Q1", "A1"]) + writer.writerow(["Q2", "A2"]) + + result = _run(capsys, ["import", "--file", csv_path, "--collection", "MyDeck"]) + assert result["imported"] == 2 + + list_result = _run(capsys, ["list"]) + assert all(c["collection"] == "MyDeck" for c in list_result["cards"]) + + def test_import_skips_empty_rows(self, capsys, tmp_path): + csv_path = str(tmp_path / "sparse.csv") + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["Q1", "A1"]) + writer.writerow(["", ""]) # empty + writer.writerow(["Q2"]) # only one column + writer.writerow(["Q3", "A3"]) + + result = _run(capsys, ["import", "--file", csv_path, "--collection", "Test"]) + assert result["imported"] == 2 + + def test_import_nonexistent_file(self, capsys, tmp_path): + with pytest.raises(SystemExit): + _run(capsys, ["import", "--file", str(tmp_path / "nope.csv"), "--collection", "X"]) + + +# ── Quiz Batch Add ─────────────────────────────────────────────────────────── + +class TestQuizBatchAdd: + def test_add_quiz_creates_cards(self, capsys): + questions = json.dumps([ + {"question": "Q1?", "answer": "A1"}, + {"question": "Q2?", "answer": "A2"}, + ]) + result = _run(capsys, ["add-quiz", "--video-id", "abc123", "--questions", questions, "--collection", "Quiz - Test"]) + assert result["ok"] is True + assert result["created_count"] == 2 + for card in result["cards"]: + assert card["video_id"] == "abc123" + assert card["collection"] == "Quiz - Test" + + def test_add_quiz_deduplicates_by_video_id(self, capsys): + questions = json.dumps([{"question": "Q?", "answer": "A"}]) + _run(capsys, ["add-quiz", "--video-id", "dup1", "--questions", questions]) + result = _run(capsys, ["add-quiz", "--video-id", "dup1", "--questions", questions]) + assert result["ok"] is True + assert result["skipped"] is True + assert result["reason"] == "duplicate_video_id" + # Only 1 card total (not 2) + list_result = _run(capsys, ["list"]) + assert list_result["count"] == 1 + + def test_add_quiz_invalid_json(self, capsys): + with pytest.raises(SystemExit): + _run(capsys, ["add-quiz", "--video-id", "x", "--questions", "not json"]) + + +# ── Statistics ─────────────────────────────────────────────────────────────── + +class TestStats: + def test_stats_empty(self, capsys): + result = _run(capsys, ["stats"]) + assert result["total"] == 0 + assert result["learning"] == 0 + assert result["retired"] == 0 + assert result["due_now"] == 0 + + def test_stats_counts(self, capsys): + _run(capsys, ["add", "--question", "Q1", "--answer", "A1", "--collection", "C1"]) + _run(capsys, ["add", "--question", "Q2", "--answer", "A2", "--collection", "C1"]) + _run(capsys, ["add", "--question", "Q3", "--answer", "A3", "--collection", "C2"]) + + # Retire one + card_id = _run(capsys, ["list"])["cards"][0]["id"] + _run(capsys, ["rate", "--id", card_id, "--rating", "retire"]) + + result = _run(capsys, ["stats"]) + assert result["total"] == 3 + assert result["learning"] == 2 + assert result["retired"] == 1 + assert result["due_now"] == 2 # 2 learning cards still due + assert result["collections"] == {"C1": 2, "C2": 1} + + +# ── Edge Cases ─────────────────────────────────────────────────────────────── + +class TestEdgeCases: + def test_empty_deck_operations(self, capsys): + """Operations on empty deck shouldn't crash.""" + result = _run(capsys, ["due"]) + assert result["count"] == 0 + result = _run(capsys, ["list"]) + assert result["count"] == 0 + result = _run(capsys, ["stats"]) + assert result["total"] == 0 + + def test_corrupt_json_recovery(self, capsys): + """Corrupt JSON file should be treated as empty.""" + memento_cards.DATA_DIR.mkdir(parents=True, exist_ok=True) + with open(memento_cards.CARDS_FILE, "w") as f: + f.write("{corrupted json...") + result = _run(capsys, ["list"]) + assert result["count"] == 0 + # Can still add + result = _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + assert result["ok"] is True + + def test_missing_cards_key_recovery(self, capsys): + """JSON without 'cards' key should be treated as empty.""" + memento_cards.DATA_DIR.mkdir(parents=True, exist_ok=True) + with open(memento_cards.CARDS_FILE, "w") as f: + json.dump({"version": 1}, f) + result = _run(capsys, ["list"]) + assert result["count"] == 0 + + def test_atomic_write_creates_dir(self, capsys): + """Data dir is created automatically if missing.""" + import shutil + if memento_cards.DATA_DIR.exists(): + shutil.rmtree(memento_cards.DATA_DIR) + result = _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + assert result["ok"] is True + assert memento_cards.CARDS_FILE.exists() + + def test_delete_collection_empty(self, capsys): + """Deleting a nonexistent collection succeeds with 0 deleted.""" + result = _run(capsys, ["delete-collection", "--collection", "Nope"]) + assert result["ok"] is True + assert result["deleted_count"] == 0 + + +# ── User Answer Tracking ──────────────────────────────────────────────────── + +class TestUserAnswer: + def test_rate_stores_user_answer(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy", + "--user-answer", "my answer"]) + assert result["card"]["last_user_answer"] == "my answer" + + def test_rate_without_user_answer_keeps_null(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + result = _run(capsys, ["rate", "--id", card_id, "--rating", "easy"]) + assert result["card"]["last_user_answer"] is None + + def test_new_card_has_last_user_answer_null(self, capsys): + result = _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + assert result["card"]["last_user_answer"] is None + + def test_user_answer_persists_in_list(self, capsys): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + _run(capsys, ["rate", "--id", card_id, "--rating", "easy", + "--user-answer", "my answer"]) + result = _run(capsys, ["list"]) + assert result["cards"][0]["last_user_answer"] == "my answer" + + def test_export_excludes_user_answer(self, capsys, tmp_path): + _run(capsys, ["add", "--question", "Q", "--answer", "A"]) + card_id = _run(capsys, ["list"])["cards"][0]["id"] + _run(capsys, ["rate", "--id", card_id, "--rating", "easy", + "--user-answer", "my answer"]) + csv_path = str(tmp_path / "export.csv") + _run(capsys, ["export", "--output", csv_path]) + with open(csv_path) as f: + rows = list(csv.reader(f)) + # CSV stays 3-column (question, answer, collection) — user_answer is internal only + assert len(rows[0]) == 3 diff --git a/tests/skills/test_youtube_quiz.py b/tests/skills/test_youtube_quiz.py new file mode 100644 index 000000000..182889ff6 --- /dev/null +++ b/tests/skills/test_youtube_quiz.py @@ -0,0 +1,128 @@ +"""Tests for optional-skills/productivity/memento-flashcards/scripts/youtube_quiz.py""" + +import json +import sys +from pathlib import Path +from types import SimpleNamespace +from unittest import mock + +import pytest + +SCRIPTS_DIR = Path(__file__).resolve().parents[2] / "optional-skills" / "productivity" / "memento-flashcards" / "scripts" +sys.path.insert(0, str(SCRIPTS_DIR)) + +import youtube_quiz + + +def _run(capsys, argv: list[str]) -> dict: + """Run main() with given argv and return parsed JSON output.""" + with mock.patch("sys.argv", ["youtube_quiz"] + argv): + youtube_quiz.main() + captured = capsys.readouterr() + return json.loads(captured.out) + + +class TestNormalizeSegments: + def test_basic(self): + segments = [{"text": "hello "}, {"text": " world"}] + assert youtube_quiz._normalize_segments(segments) == "hello world" + + def test_empty_segments(self): + assert youtube_quiz._normalize_segments([]) == "" + + def test_whitespace_only(self): + assert youtube_quiz._normalize_segments([{"text": " "}, {"text": " "}]) == "" + + def test_collapses_multiple_spaces(self): + segments = [{"text": "a b"}, {"text": "c d"}] + assert youtube_quiz._normalize_segments(segments) == "a b c d" + + +class TestFetchMissingDependency: + def test_missing_youtube_transcript_api(self, capsys, monkeypatch): + """When youtube-transcript-api is not installed, report the error.""" + import builtins + real_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "youtube_transcript_api": + raise ImportError("No module named 'youtube_transcript_api'") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mock_import) + + with pytest.raises(SystemExit) as exc_info: + _run(capsys, ["fetch", "test123"]) + + captured = capsys.readouterr() + result = json.loads(captured.out) + assert result["ok"] is False + assert result["error"] == "missing_dependency" + assert "pip install" in result["message"] + + +class TestFetchWithMockedAPI: + def _make_mock_module(self, segments=None, raise_exc=None): + """Create a mock youtube_transcript_api module.""" + mock_module = mock.MagicMock() + + mock_api_instance = mock.MagicMock() + mock_module.YouTubeTranscriptApi.return_value = mock_api_instance + + if raise_exc: + mock_api_instance.fetch.side_effect = raise_exc + else: + raw_data = segments or [{"text": "Hello world"}] + result = mock.MagicMock() + result.to_raw_data.return_value = raw_data + mock_api_instance.fetch.return_value = result + + return mock_module + + def test_successful_fetch(self, capsys): + mock_mod = self._make_mock_module( + segments=[{"text": "This is a test"}, {"text": "transcript segment"}] + ) + with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}): + result = _run(capsys, ["fetch", "abc123"]) + + assert result["ok"] is True + assert result["video_id"] == "abc123" + assert "This is a test" in result["transcript"] + assert "transcript segment" in result["transcript"] + + def test_fetch_error(self, capsys): + mock_mod = self._make_mock_module(raise_exc=Exception("Video unavailable")) + with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}): + with pytest.raises(SystemExit): + _run(capsys, ["fetch", "bad_id"]) + + captured = capsys.readouterr() + result = json.loads(captured.out) + assert result["ok"] is False + assert result["error"] == "transcript_unavailable" + + def test_empty_transcript(self, capsys): + mock_mod = self._make_mock_module(segments=[{"text": ""}, {"text": " "}]) + with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}): + with pytest.raises(SystemExit): + _run(capsys, ["fetch", "empty_vid"]) + + captured = capsys.readouterr() + result = json.loads(captured.out) + assert result["ok"] is False + assert result["error"] == "empty_transcript" + + def test_segments_without_to_raw_data(self, capsys): + """Handle plain list segments (no to_raw_data method).""" + mock_mod = mock.MagicMock() + mock_api = mock.MagicMock() + mock_mod.YouTubeTranscriptApi.return_value = mock_api + # Return a plain list (no to_raw_data attribute) + mock_api.fetch.return_value = [{"text": "plain list"}] + + with mock.patch.dict("sys.modules", {"youtube_transcript_api": mock_mod}): + result = _run(capsys, ["fetch", "plain123"]) + + assert result["ok"] is True + assert result["transcript"] == "plain list" diff --git a/tests/test_anthropic_adapter.py b/tests/test_anthropic_adapter.py index 71638f0d3..7e2e1c767 100644 --- a/tests/test_anthropic_adapter.py +++ b/tests/test_anthropic_adapter.py @@ -801,6 +801,48 @@ class TestConvertMessages: assert all(not (b.get("type") == "text" and b.get("text") == "") for b in assistant_blocks) assert any(b.get("type") == "tool_use" for b in assistant_blocks) + def test_empty_user_message_string_gets_placeholder(self): + """Empty user message strings should get '(empty message)' placeholder. + + Anthropic rejects requests with empty user message content. + Regression test for #3143 — Discord @mention-only messages. + """ + messages = [ + {"role": "user", "content": ""}, + ] + _, result = convert_messages_to_anthropic(messages) + assert result[0]["role"] == "user" + assert result[0]["content"] == "(empty message)" + + def test_whitespace_only_user_message_gets_placeholder(self): + """Whitespace-only user messages should also get placeholder.""" + messages = [ + {"role": "user", "content": " \n\t "}, + ] + _, result = convert_messages_to_anthropic(messages) + assert result[0]["content"] == "(empty message)" + + def test_empty_user_message_list_gets_placeholder(self): + """Empty content list for user messages should get placeholder block.""" + messages = [ + {"role": "user", "content": []}, + ] + _, result = convert_messages_to_anthropic(messages) + assert result[0]["role"] == "user" + assert isinstance(result[0]["content"], list) + assert len(result[0]["content"]) == 1 + assert result[0]["content"][0] == {"type": "text", "text": "(empty message)"} + + def test_user_message_with_empty_text_blocks_gets_placeholder(self): + """User message with only empty text blocks should get placeholder.""" + messages = [ + {"role": "user", "content": [{"type": "text", "text": ""}, {"type": "text", "text": " "}]}, + ] + _, result = convert_messages_to_anthropic(messages) + assert result[0]["role"] == "user" + assert isinstance(result[0]["content"], list) + assert result[0]["content"] == [{"type": "text", "text": "(empty message)"}] + # --------------------------------------------------------------------------- # Build kwargs @@ -884,7 +926,8 @@ class TestBuildAnthropicKwargs: ) assert "thinking" not in kwargs - def test_default_max_tokens(self): + def test_default_max_tokens_uses_model_output_limit(self): + """When max_tokens is None, use the model's native output limit.""" kwargs = build_anthropic_kwargs( model="claude-sonnet-4-20250514", messages=[{"role": "user", "content": "Hi"}], @@ -892,7 +935,135 @@ class TestBuildAnthropicKwargs: max_tokens=None, reasoning_config=None, ) - assert kwargs["max_tokens"] == 16384 + assert kwargs["max_tokens"] == 64_000 # Sonnet 4 output limit + + def test_default_max_tokens_opus_4_6(self): + kwargs = build_anthropic_kwargs( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + ) + assert kwargs["max_tokens"] == 128_000 + + def test_default_max_tokens_sonnet_4_6(self): + kwargs = build_anthropic_kwargs( + model="claude-sonnet-4-6", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + ) + assert kwargs["max_tokens"] == 64_000 + + def test_default_max_tokens_date_stamped_model(self): + """Date-stamped model IDs should resolve via substring match.""" + kwargs = build_anthropic_kwargs( + model="claude-sonnet-4-5-20250929", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + ) + assert kwargs["max_tokens"] == 64_000 + + def test_default_max_tokens_older_model(self): + kwargs = build_anthropic_kwargs( + model="claude-3-5-sonnet-20241022", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + ) + assert kwargs["max_tokens"] == 8_192 + + def test_default_max_tokens_unknown_model_uses_highest(self): + """Unknown future models should get the highest known limit.""" + kwargs = build_anthropic_kwargs( + model="claude-ultra-5-20260101", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + ) + assert kwargs["max_tokens"] == 128_000 + + def test_explicit_max_tokens_overrides_default(self): + """User-specified max_tokens should be respected.""" + kwargs = build_anthropic_kwargs( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=4096, + reasoning_config=None, + ) + assert kwargs["max_tokens"] == 4096 + + def test_context_length_clamp(self): + """max_tokens should be clamped to context_length if it's smaller.""" + kwargs = build_anthropic_kwargs( + model="claude-opus-4-6", # 128K output + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + context_length=50000, + ) + assert kwargs["max_tokens"] == 49999 # context_length - 1 + + def test_context_length_no_clamp_when_larger(self): + """No clamping when context_length exceeds output limit.""" + kwargs = build_anthropic_kwargs( + model="claude-sonnet-4-6", # 64K output + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=None, + reasoning_config=None, + context_length=200000, + ) + assert kwargs["max_tokens"] == 64_000 + + +# --------------------------------------------------------------------------- +# Model output limit lookup +# --------------------------------------------------------------------------- + + +class TestGetAnthropicMaxOutput: + def test_opus_4_6(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-opus-4-6") == 128_000 + + def test_opus_4_6_variant(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-opus-4-6:1m:fast") == 128_000 + + def test_sonnet_4_6(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-sonnet-4-6") == 64_000 + + def test_sonnet_4_date_stamped(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-sonnet-4-20250514") == 64_000 + + def test_claude_3_5_sonnet(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-3-5-sonnet-20241022") == 8_192 + + def test_claude_3_opus(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-3-opus-20240229") == 4_096 + + def test_unknown_future_model(self): + from agent.anthropic_adapter import _get_anthropic_max_output + assert _get_anthropic_max_output("claude-ultra-5-20260101") == 128_000 + + def test_longest_prefix_wins(self): + """'claude-3-5-sonnet' should match before 'claude-3-5'.""" + from agent.anthropic_adapter import _get_anthropic_max_output + # claude-3-5-sonnet (8192) should win over a hypothetical shorter match + assert _get_anthropic_max_output("claude-3-5-sonnet-20241022") == 8_192 # --------------------------------------------------------------------------- diff --git a/tests/test_anthropic_error_handling.py b/tests/test_anthropic_error_handling.py index d6b8717bf..3d7660aa8 100644 --- a/tests/test_anthropic_error_handling.py +++ b/tests/test_anthropic_error_handling.py @@ -217,10 +217,17 @@ def test_529_overloaded_is_retried_and_recovers(monkeypatch): def test_429_exhausts_all_retries_before_raising(monkeypatch): - """429 must retry max_retries times, not abort on first attempt.""" + """429 must retry max_retries times, then return a failed result. + + The agent no longer re-raises after exhausting retries — it returns a + result dict with the error in final_response. This changed when the + fallback-provider feature was added (the agent tries a fallback before + giving up, and returns a result dict either way). + """ agent_cls = _make_agent_cls(_RateLimitError) # always fails - with pytest.raises(_RateLimitError): - _run_with_agent(monkeypatch, agent_cls) + result = _run_with_agent(monkeypatch, agent_cls) + resp = str(result.get("final_response", "")) + assert "429" in resp or "retries" in resp.lower() def test_400_bad_request_is_non_retryable(monkeypatch): diff --git a/tests/test_api_key_providers.py b/tests/test_api_key_providers.py index 95d18bdd8..0c6337d3e 100644 --- a/tests/test_api_key_providers.py +++ b/tests/test_api_key_providers.py @@ -38,6 +38,7 @@ class TestProviderRegistry: @pytest.mark.parametrize("provider_id,name,auth_type", [ ("copilot-acp", "GitHub Copilot ACP", "external_process"), ("copilot", "GitHub Copilot", "api_key"), + ("huggingface", "Hugging Face", "api_key"), ("zai", "Z.AI / GLM", "api_key"), ("kimi-coding", "Kimi / Moonshot", "api_key"), ("minimax", "MiniMax", "api_key"), @@ -87,6 +88,11 @@ class TestProviderRegistry: assert pconfig.api_key_env_vars == ("KILOCODE_API_KEY",) assert pconfig.base_url_env_var == "KILOCODE_BASE_URL" + def test_huggingface_env_vars(self): + pconfig = PROVIDER_REGISTRY["huggingface"] + assert pconfig.api_key_env_vars == ("HF_TOKEN",) + assert pconfig.base_url_env_var == "HF_BASE_URL" + def test_base_urls(self): assert PROVIDER_REGISTRY["copilot"].inference_base_url == "https://api.githubcopilot.com" assert PROVIDER_REGISTRY["copilot-acp"].inference_base_url == "acp://copilot" @@ -96,6 +102,7 @@ class TestProviderRegistry: assert PROVIDER_REGISTRY["minimax-cn"].inference_base_url == "https://api.minimaxi.com/anthropic" assert PROVIDER_REGISTRY["ai-gateway"].inference_base_url == "https://ai-gateway.vercel.sh/v1" assert PROVIDER_REGISTRY["kilocode"].inference_base_url == "https://api.kilo.ai/api/gateway" + assert PROVIDER_REGISTRY["huggingface"].inference_base_url == "https://router.huggingface.co/v1" def test_oauth_providers_unchanged(self): """Ensure we didn't break the existing OAuth providers.""" @@ -199,6 +206,18 @@ class TestResolveProvider: assert resolve_provider("github-copilot-acp") == "copilot-acp" assert resolve_provider("copilot-acp-agent") == "copilot-acp" + def test_explicit_huggingface(self): + assert resolve_provider("huggingface") == "huggingface" + + def test_alias_hf(self): + assert resolve_provider("hf") == "huggingface" + + def test_alias_hugging_face(self): + assert resolve_provider("hugging-face") == "huggingface" + + def test_alias_huggingface_hub(self): + assert resolve_provider("huggingface-hub") == "huggingface" + def test_unknown_provider_raises(self): with pytest.raises(AuthError): resolve_provider("nonexistent-provider-xyz") @@ -235,6 +254,10 @@ class TestResolveProvider: monkeypatch.setenv("KILOCODE_API_KEY", "test-kilo-key") assert resolve_provider("auto") == "kilocode" + def test_auto_detects_hf_token(self, monkeypatch): + monkeypatch.setenv("HF_TOKEN", "hf_test_token") + assert resolve_provider("auto") == "huggingface" + def test_openrouter_takes_priority_over_glm(self, monkeypatch): """OpenRouter API key should win over GLM in auto-detection.""" monkeypatch.setenv("OPENROUTER_API_KEY", "or-key") @@ -243,7 +266,8 @@ class TestResolveProvider: def test_auto_does_not_select_copilot_from_github_token(self, monkeypatch): monkeypatch.setenv("GITHUB_TOKEN", "gh-test-token") - assert resolve_provider("auto") == "openrouter" + with pytest.raises(AuthError, match="No inference provider configured"): + resolve_provider("auto") # ============================================================================= @@ -708,3 +732,55 @@ class TestKimiMoonshotModelListIsolation: coding_models = _PROVIDER_MODELS["kimi-coding"] assert "kimi-for-coding" in coding_models assert "kimi-k2-thinking-turbo" in coding_models + + +# ============================================================================= +# Hugging Face provider model list tests +# ============================================================================= + +class TestHuggingFaceModels: + """Verify Hugging Face model lists are consistent across all locations.""" + + def test_main_provider_models_has_huggingface(self): + from hermes_cli.main import _PROVIDER_MODELS + assert "huggingface" in _PROVIDER_MODELS + models = _PROVIDER_MODELS["huggingface"] + assert len(models) >= 6, "Expected at least 6 curated HF models" + + def test_models_py_has_huggingface(self): + from hermes_cli.models import _PROVIDER_MODELS + assert "huggingface" in _PROVIDER_MODELS + models = _PROVIDER_MODELS["huggingface"] + assert len(models) >= 6 + + def test_model_lists_match(self): + """Model lists in main.py and models.py should be identical.""" + from hermes_cli.main import _PROVIDER_MODELS as main_models + from hermes_cli.models import _PROVIDER_MODELS as models_models + assert main_models["huggingface"] == models_models["huggingface"] + + def test_model_metadata_has_context_lengths(self): + """Every HF model should have a context length entry.""" + from hermes_cli.models import _PROVIDER_MODELS + from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS + hf_models = _PROVIDER_MODELS["huggingface"] + for model in hf_models: + assert model in DEFAULT_CONTEXT_LENGTHS, ( + f"HF model {model!r} missing from DEFAULT_CONTEXT_LENGTHS" + ) + + def test_models_use_org_name_format(self): + """HF models should use org/name format (e.g. Qwen/Qwen3-235B).""" + from hermes_cli.models import _PROVIDER_MODELS + for model in _PROVIDER_MODELS["huggingface"]: + assert "/" in model, f"HF model {model!r} missing org/ prefix" + + def test_provider_aliases_in_models_py(self): + from hermes_cli.models import _PROVIDER_ALIASES + assert _PROVIDER_ALIASES.get("hf") == "huggingface" + assert _PROVIDER_ALIASES.get("hugging-face") == "huggingface" + + def test_provider_label(self): + from hermes_cli.models import _PROVIDER_LABELS + assert "huggingface" in _PROVIDER_LABELS + assert _PROVIDER_LABELS["huggingface"] == "Hugging Face" diff --git a/tests/test_async_httpx_del_neuter.py b/tests/test_async_httpx_del_neuter.py new file mode 100644 index 000000000..ce8e20e70 --- /dev/null +++ b/tests/test_async_httpx_del_neuter.py @@ -0,0 +1,162 @@ +"""Tests for the AsyncHttpxClientWrapper.__del__ neuter fix. + +The OpenAI SDK's ``AsyncHttpxClientWrapper.__del__`` schedules +``aclose()`` via ``asyncio.get_running_loop().create_task()``. When GC +fires during CLI idle time, prompt_toolkit's event loop picks up the task +and crashes with "Event loop is closed" because the underlying TCP +transport is bound to a dead worker loop. + +The three-layer defence: +1. ``neuter_async_httpx_del()`` replaces ``__del__`` with a no-op. +2. A custom asyncio exception handler silences residual errors. +3. ``cleanup_stale_async_clients()`` evicts stale cache entries. +""" + +import asyncio +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Layer 1: neuter_async_httpx_del +# --------------------------------------------------------------------------- + +class TestNeuterAsyncHttpxDel: + """Verify neuter_async_httpx_del replaces __del__ on the SDK class.""" + + def test_del_becomes_noop(self): + """After neuter, __del__ should do nothing (no RuntimeError).""" + from agent.auxiliary_client import neuter_async_httpx_del + + try: + from openai._base_client import AsyncHttpxClientWrapper + except ImportError: + pytest.skip("openai SDK not installed") + + # Save original so we can restore + original_del = AsyncHttpxClientWrapper.__del__ + try: + neuter_async_httpx_del() + # The patched __del__ should be a no-op lambda + assert AsyncHttpxClientWrapper.__del__ is not original_del + # Calling it should not raise, even without a running loop + wrapper = MagicMock(spec=AsyncHttpxClientWrapper) + AsyncHttpxClientWrapper.__del__(wrapper) # Should be silent + finally: + # Restore original to avoid leaking into other tests + AsyncHttpxClientWrapper.__del__ = original_del + + def test_neuter_idempotent(self): + """Calling neuter twice doesn't break anything.""" + from agent.auxiliary_client import neuter_async_httpx_del + + try: + from openai._base_client import AsyncHttpxClientWrapper + except ImportError: + pytest.skip("openai SDK not installed") + + original_del = AsyncHttpxClientWrapper.__del__ + try: + neuter_async_httpx_del() + first_del = AsyncHttpxClientWrapper.__del__ + neuter_async_httpx_del() + second_del = AsyncHttpxClientWrapper.__del__ + # Both calls should succeed; the class should have a no-op + assert first_del is not original_del + assert second_del is not original_del + finally: + AsyncHttpxClientWrapper.__del__ = original_del + + def test_neuter_graceful_without_sdk(self): + """neuter_async_httpx_del doesn't raise if the openai SDK isn't installed.""" + from agent.auxiliary_client import neuter_async_httpx_del + + with patch.dict("sys.modules", {"openai._base_client": None}): + # Should not raise + neuter_async_httpx_del() + + +# --------------------------------------------------------------------------- +# Layer 3: cleanup_stale_async_clients +# --------------------------------------------------------------------------- + +class TestCleanupStaleAsyncClients: + """Verify stale cache entries are evicted and force-closed.""" + + def test_removes_stale_entries(self): + """Entries with a closed loop should be evicted.""" + from agent.auxiliary_client import ( + _client_cache, + _client_cache_lock, + cleanup_stale_async_clients, + ) + + # Create a loop, close it, make a cache entry + loop = asyncio.new_event_loop() + loop.close() + + mock_client = MagicMock() + # Give it _client attribute for _force_close_async_httpx + mock_client._client = MagicMock() + mock_client._client.is_closed = False + + key = ("test_stale", True, "", "", id(loop)) + with _client_cache_lock: + _client_cache[key] = (mock_client, "test-model", loop) + + try: + cleanup_stale_async_clients() + with _client_cache_lock: + assert key not in _client_cache, "Stale entry should be removed" + finally: + # Clean up in case test fails + with _client_cache_lock: + _client_cache.pop(key, None) + + def test_keeps_live_entries(self): + """Entries with an open loop should be preserved.""" + from agent.auxiliary_client import ( + _client_cache, + _client_cache_lock, + cleanup_stale_async_clients, + ) + + loop = asyncio.new_event_loop() # NOT closed + + mock_client = MagicMock() + key = ("test_live", True, "", "", id(loop)) + with _client_cache_lock: + _client_cache[key] = (mock_client, "test-model", loop) + + try: + cleanup_stale_async_clients() + with _client_cache_lock: + assert key in _client_cache, "Live entry should be preserved" + finally: + loop.close() + with _client_cache_lock: + _client_cache.pop(key, None) + + def test_keeps_entries_without_loop(self): + """Sync entries (cached_loop=None) should be preserved.""" + from agent.auxiliary_client import ( + _client_cache, + _client_cache_lock, + cleanup_stale_async_clients, + ) + + mock_client = MagicMock() + key = ("test_sync", False, "", "", 0) + with _client_cache_lock: + _client_cache[key] = (mock_client, "test-model", None) + + try: + cleanup_stale_async_clients() + with _client_cache_lock: + assert key in _client_cache, "Sync entry should be preserved" + finally: + with _client_cache_lock: + _client_cache.pop(key, None) diff --git a/tests/test_cli_init.py b/tests/test_cli_init.py index f41f81bb8..b5598aed1 100644 --- a/tests/test_cli_init.py +++ b/tests/test_cli_init.py @@ -96,6 +96,59 @@ class TestVerboseAndToolProgress: assert cli.tool_progress_mode in ("off", "new", "all", "verbose") +class TestBusyInputMode: + def test_default_busy_input_mode_is_interrupt(self): + cli = _make_cli() + assert cli.busy_input_mode == "interrupt" + + def test_busy_input_mode_queue_is_honored(self): + cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}}) + assert cli.busy_input_mode == "queue" + + def test_unknown_busy_input_mode_falls_back_to_interrupt(self): + cli = _make_cli(config_overrides={"display": {"busy_input_mode": "bogus"}}) + assert cli.busy_input_mode == "interrupt" + + def test_queue_command_works_while_busy(self): + """When agent is running, /queue should still put the prompt in _pending_input.""" + cli = _make_cli() + cli._agent_running = True + cli.process_command("/queue follow up") + assert cli._pending_input.get_nowait() == "follow up" + + def test_queue_command_works_while_idle(self): + """When agent is idle, /queue should still queue (not reject).""" + cli = _make_cli() + cli._agent_running = False + cli.process_command("/queue follow up") + assert cli._pending_input.get_nowait() == "follow up" + + def test_queue_mode_routes_busy_enter_to_pending(self): + """In queue mode, Enter while busy should go to _pending_input, not _interrupt_queue.""" + cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}}) + cli._agent_running = True + # Simulate what handle_enter does for non-command input while busy + text = "follow up" + if cli.busy_input_mode == "queue": + cli._pending_input.put(text) + else: + cli._interrupt_queue.put(text) + assert cli._pending_input.get_nowait() == "follow up" + assert cli._interrupt_queue.empty() + + def test_interrupt_mode_routes_busy_enter_to_interrupt(self): + """In interrupt mode (default), Enter while busy goes to _interrupt_queue.""" + cli = _make_cli() + cli._agent_running = True + text = "redirect" + if cli.busy_input_mode == "queue": + cli._pending_input.put(text) + else: + cli._interrupt_queue.put(text) + assert cli._interrupt_queue.get_nowait() == "redirect" + assert cli._pending_input.empty() + + class TestSingleQueryState: def test_voice_and_interrupt_state_initialized_before_run(self): """Single-query mode calls chat() without going through run().""" diff --git a/tests/test_cli_status_bar.py b/tests/test_cli_status_bar.py index c1dd4b35b..936ec2190 100644 --- a/tests/test_cli_status_bar.py +++ b/tests/test_cli_status_bar.py @@ -182,3 +182,94 @@ class TestCLIUsageReport: assert "Total cost:" in output assert "n/a" in output assert "Pricing unknown for glm-5" in output + + +class TestStatusBarWidthSource: + """Ensure status bar fragments don't overflow the terminal width.""" + + def _make_wide_cli(self): + from datetime import datetime, timedelta + cli_obj = _attach_agent( + _make_cli(), + prompt_tokens=100_000, + completion_tokens=5_000, + total_tokens=105_000, + api_calls=20, + context_tokens=100_000, + context_length=200_000, + ) + cli_obj._status_bar_visible = True + return cli_obj + + def test_fragments_fit_within_announced_width(self): + """Total fragment text length must not exceed the width used to build them.""" + from unittest.mock import MagicMock, patch + cli_obj = self._make_wide_cli() + + for width in (40, 52, 76, 80, 120, 200): + mock_app = MagicMock() + mock_app.output.get_size.return_value = MagicMock(columns=width) + + with patch("prompt_toolkit.application.get_app", return_value=mock_app): + frags = cli_obj._get_status_bar_fragments() + + total_text = "".join(text for _, text in frags) + assert len(total_text) <= width + 4, ( # +4 for minor padding chars + f"At width={width}, fragment total {len(total_text)} chars overflows " + f"({total_text!r})" + ) + + def test_fragments_use_pt_width_over_shutil(self): + """When prompt_toolkit reports a width, shutil.get_terminal_size must not be used.""" + from unittest.mock import MagicMock, patch + cli_obj = self._make_wide_cli() + + mock_app = MagicMock() + mock_app.output.get_size.return_value = MagicMock(columns=120) + + with patch("prompt_toolkit.application.get_app", return_value=mock_app) as mock_get_app, \ + patch("shutil.get_terminal_size") as mock_shutil: + cli_obj._get_status_bar_fragments() + + mock_shutil.assert_not_called() + + def test_fragments_fall_back_to_shutil_when_no_app(self): + """Outside a TUI context (no running app), shutil must be used as fallback.""" + from unittest.mock import MagicMock, patch + cli_obj = self._make_wide_cli() + + with patch("prompt_toolkit.application.get_app", side_effect=Exception("no app")), \ + patch("shutil.get_terminal_size", return_value=MagicMock(columns=100)) as mock_shutil: + frags = cli_obj._get_status_bar_fragments() + + mock_shutil.assert_called() + assert len(frags) > 0 + + def test_build_status_bar_text_uses_pt_width(self): + """_build_status_bar_text() must also prefer prompt_toolkit width.""" + from unittest.mock import MagicMock, patch + cli_obj = self._make_wide_cli() + + mock_app = MagicMock() + mock_app.output.get_size.return_value = MagicMock(columns=80) + + with patch("prompt_toolkit.application.get_app", return_value=mock_app), \ + patch("shutil.get_terminal_size") as mock_shutil: + text = cli_obj._build_status_bar_text() # no explicit width + + mock_shutil.assert_not_called() + assert isinstance(text, str) + assert len(text) > 0 + + def test_explicit_width_skips_pt_lookup(self): + """An explicit width= argument must bypass both PT and shutil lookups.""" + from unittest.mock import patch + cli_obj = self._make_wide_cli() + + with patch("prompt_toolkit.application.get_app") as mock_get_app, \ + patch("shutil.get_terminal_size") as mock_shutil: + text = cli_obj._build_status_bar_text(width=100) + + mock_get_app.assert_not_called() + mock_shutil.assert_not_called() + assert len(text) > 0 diff --git a/tests/test_codex_models.py b/tests/test_codex_models.py index 32fe63153..da178d9be 100644 --- a/tests/test_codex_models.py +++ b/tests/test_codex_models.py @@ -33,6 +33,7 @@ def test_get_codex_model_ids_prioritizes_default_and_cache(tmp_path, monkeypatch assert "gpt-5.3-codex" in models # Non-codex-suffixed models are included when the cache says they're available assert "gpt-5.4" in models + assert "gpt-5.4-mini" in models assert "gpt-5-hidden-codex" not in models @@ -64,7 +65,7 @@ def test_get_codex_model_ids_adds_forward_compat_models_from_templates(monkeypat models = get_codex_model_ids(access_token="codex-access-token") - assert models == ["gpt-5.2-codex", "gpt-5.3-codex", "gpt-5.4", "gpt-5.3-codex-spark"] + assert models == ["gpt-5.2-codex", "gpt-5.4-mini", "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark"] def test_model_command_uses_runtime_access_token_for_codex_list(monkeypatch): diff --git a/tests/test_compressor_fallback_update.py b/tests/test_compressor_fallback_update.py new file mode 100644 index 000000000..064fd9b67 --- /dev/null +++ b/tests/test_compressor_fallback_update.py @@ -0,0 +1,91 @@ +"""Tests that _try_activate_fallback updates the context compressor.""" + +from unittest.mock import MagicMock, patch + +from run_agent import AIAgent +from agent.context_compressor import ContextCompressor + + +def _make_agent_with_compressor() -> AIAgent: + """Build a minimal AIAgent with a context_compressor, skipping __init__.""" + agent = AIAgent.__new__(AIAgent) + + # Primary model settings + agent.model = "primary-model" + agent.provider = "openrouter" + agent.base_url = "https://openrouter.ai/api/v1" + agent.api_key = "sk-primary" + agent.api_mode = "chat_completions" + agent.client = MagicMock() + agent.quiet_mode = True + + # Fallback config + agent._fallback_activated = False + agent._fallback_model = { + "provider": "openai", + "model": "gpt-4o", + } + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 + + # Context compressor with primary model values + compressor = ContextCompressor( + model="primary-model", + threshold_percent=0.50, + base_url="https://openrouter.ai/api/v1", + api_key="sk-primary", + provider="openrouter", + quiet_mode=True, + ) + agent.context_compressor = compressor + + return agent + + +@patch("agent.auxiliary_client.resolve_provider_client") +@patch("agent.model_metadata.get_model_context_length", return_value=128_000) +def test_compressor_updated_on_fallback(mock_ctx_len, mock_resolve): + """After fallback activation, the compressor must reflect the fallback model.""" + agent = _make_agent_with_compressor() + + assert agent.context_compressor.model == "primary-model" + + fb_client = MagicMock() + fb_client.base_url = "https://api.openai.com/v1" + fb_client.api_key = "sk-fallback" + mock_resolve.return_value = (fb_client, None) + + agent._is_direct_openai_url = lambda url: "api.openai.com" in url + agent._emit_status = lambda msg: None + + result = agent._try_activate_fallback() + + assert result is True + assert agent._fallback_activated is True + + c = agent.context_compressor + assert c.model == "gpt-4o" + assert c.base_url == "https://api.openai.com/v1" + assert c.api_key == "sk-fallback" + assert c.provider == "openai" + assert c.context_length == 128_000 + assert c.threshold_tokens == int(128_000 * c.threshold_percent) + + +@patch("agent.auxiliary_client.resolve_provider_client") +@patch("agent.model_metadata.get_model_context_length", return_value=128_000) +def test_compressor_not_present_does_not_crash(mock_ctx_len, mock_resolve): + """If the agent has no compressor, fallback should still succeed.""" + agent = _make_agent_with_compressor() + agent.context_compressor = None + + fb_client = MagicMock() + fb_client.base_url = "https://api.openai.com/v1" + fb_client.api_key = "sk-fallback" + mock_resolve.return_value = (fb_client, None) + + agent._is_direct_openai_url = lambda url: "api.openai.com" in url + agent._emit_status = lambda msg: None + + result = agent._try_activate_fallback() + assert result is True diff --git a/tests/test_context_pressure.py b/tests/test_context_pressure.py index f89daef52..522603fdb 100644 --- a/tests/test_context_pressure.py +++ b/tests/test_context_pressure.py @@ -69,10 +69,12 @@ class TestFormatContextPressure: assert isinstance(result, str) def test_over_100_percent_capped(self): - """Progress > 1.0 should not break the bar.""" + """Progress > 1.0 should cap both bar and percentage text at 100%.""" line = format_context_pressure(1.05, 100_000, 0.50) assert "▰" in line assert line.count("▰") == 20 + assert "100%" in line + assert "105%" not in line class TestFormatContextPressureGateway: @@ -100,6 +102,13 @@ class TestFormatContextPressureGateway: msg = format_context_pressure_gateway(0.80, 0.50) assert "▰" in msg + def test_over_100_percent_capped(self): + """Progress > 1.0 should cap percentage text at 100%.""" + msg = format_context_pressure_gateway(1.09, 0.50) + assert "100% to compaction" in msg + assert "109%" not in msg + assert msg.count("▰") == 20 + # --------------------------------------------------------------------------- # AIAgent context pressure flag tests diff --git a/tests/test_mcp_serve.py b/tests/test_mcp_serve.py new file mode 100644 index 000000000..9dc013cac --- /dev/null +++ b/tests/test_mcp_serve.py @@ -0,0 +1,1111 @@ +""" +Tests for mcp_serve — Hermes MCP server. + +Three layers of tests: +1. Unit tests — helpers, content extraction, attachment parsing +2. EventBridge tests — queue mechanics, cursors, waiters, concurrency +3. End-to-end tests — call actual MCP tools through FastMCP's tool manager + with real session data in SQLite and sessions.json +""" + +import asyncio +import json +import os +import sqlite3 +import time +import threading +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _isolate_hermes_home(tmp_path, monkeypatch): + """Redirect HERMES_HOME to a temp directory.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + try: + import hermes_constants + monkeypatch.setattr(hermes_constants, "get_hermes_home", lambda: tmp_path) + except (ImportError, AttributeError): + pass + return tmp_path + + +@pytest.fixture +def sessions_dir(tmp_path): + sdir = tmp_path / "sessions" + sdir.mkdir(parents=True, exist_ok=True) + return sdir + + +@pytest.fixture +def sample_sessions(): + return { + "agent:main:telegram:dm:123456": { + "session_key": "agent:main:telegram:dm:123456", + "session_id": "20260329_120000_abc123", + "platform": "telegram", + "chat_type": "dm", + "display_name": "Alice", + "created_at": "2026-03-29T12:00:00", + "updated_at": "2026-03-29T14:30:00", + "input_tokens": 50000, + "output_tokens": 2000, + "total_tokens": 52000, + "origin": { + "platform": "telegram", + "chat_id": "123456", + "chat_name": "Alice", + "chat_type": "dm", + "user_id": "123456", + "user_name": "Alice", + "thread_id": None, + "chat_topic": None, + }, + }, + "agent:main:discord:group:789:456": { + "session_key": "agent:main:discord:group:789:456", + "session_id": "20260329_100000_def456", + "platform": "discord", + "chat_type": "group", + "display_name": "Bob", + "created_at": "2026-03-29T10:00:00", + "updated_at": "2026-03-29T13:00:00", + "input_tokens": 30000, + "output_tokens": 1000, + "total_tokens": 31000, + "origin": { + "platform": "discord", + "chat_id": "789", + "chat_name": "#general", + "chat_type": "group", + "user_id": "456", + "user_name": "Bob", + "thread_id": None, + "chat_topic": None, + }, + }, + "agent:main:slack:group:C1234:U5678": { + "session_key": "agent:main:slack:group:C1234:U5678", + "session_id": "20260328_090000_ghi789", + "platform": "slack", + "chat_type": "group", + "display_name": "Carol", + "created_at": "2026-03-28T09:00:00", + "updated_at": "2026-03-28T11:00:00", + "input_tokens": 10000, + "output_tokens": 500, + "total_tokens": 10500, + "origin": { + "platform": "slack", + "chat_id": "C1234", + "chat_name": "#engineering", + "chat_type": "group", + "user_id": "U5678", + "user_name": "Carol", + "thread_id": None, + "chat_topic": None, + }, + }, + } + + +@pytest.fixture +def populated_sessions_dir(sessions_dir, sample_sessions): + (sessions_dir / "sessions.json").write_text(json.dumps(sample_sessions)) + return sessions_dir + + +def _create_test_db(db_path, session_id, messages): + """Create a minimal SQLite DB mimicking hermes_state schema.""" + conn = sqlite3.connect(str(db_path)) + conn.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + source TEXT DEFAULT 'cli', + started_at TEXT, + message_count INTEGER DEFAULT 0 + ) + """) + conn.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT, + tool_call_id TEXT, + tool_calls TEXT, + tool_name TEXT, + timestamp TEXT, + token_count INTEGER DEFAULT 0, + finish_reason TEXT, + reasoning TEXT, + reasoning_details TEXT, + codex_reasoning_items TEXT + ) + """) + conn.execute( + "INSERT OR IGNORE INTO sessions (id, source, started_at, message_count) VALUES (?, 'gateway', ?, ?)", + (session_id, "2026-03-29T12:00:00", len(messages)), + ) + for msg in messages: + content = msg.get("content", "") + if isinstance(content, (list, dict)): + content = json.dumps(content) + conn.execute( + "INSERT INTO messages (session_id, role, content, timestamp, tool_calls) VALUES (?, ?, ?, ?, ?)", + (session_id, msg["role"], content, + msg.get("timestamp", "2026-03-29T12:00:00"), + json.dumps(msg["tool_calls"]) if msg.get("tool_calls") else None), + ) + conn.commit() + conn.close() + + +@pytest.fixture +def mock_session_db(tmp_path, populated_sessions_dir): + """Create a real SQLite DB with test messages and wire it up.""" + db_path = tmp_path / "state.db" + messages = [ + {"role": "user", "content": "Hello Alice!", "timestamp": "2026-03-29T12:00:01"}, + {"role": "assistant", "content": "Hi! How can I help?", "timestamp": "2026-03-29T12:00:05"}, + {"role": "user", "content": "Check the image MEDIA: /tmp/screenshot.png please", + "timestamp": "2026-03-29T12:01:00"}, + {"role": "assistant", "content": "I see the screenshot. It shows a terminal.", + "timestamp": "2026-03-29T12:01:10"}, + {"role": "tool", "content": '{"result": "ok"}', "timestamp": "2026-03-29T12:01:15"}, + {"role": "user", "content": "Thanks!", "timestamp": "2026-03-29T12:02:00"}, + ] + _create_test_db(db_path, "20260329_120000_abc123", messages) + + # Create a mock SessionDB that reads from our test DB + class TestSessionDB: + def __init__(self): + self._db_path = db_path + + def get_messages(self, session_id): + conn = sqlite3.connect(str(self._db_path)) + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT * FROM messages WHERE session_id = ? ORDER BY id", + (session_id,), + ).fetchall() + conn.close() + result = [] + for r in rows: + d = dict(r) + if d.get("tool_calls"): + d["tool_calls"] = json.loads(d["tool_calls"]) + result.append(d) + return result + + return TestSessionDB() + + +# --------------------------------------------------------------------------- +# 1. UNIT TESTS — helpers, extraction, attachments +# --------------------------------------------------------------------------- + +class TestImports: + def test_import_module(self): + import mcp_serve + assert hasattr(mcp_serve, "create_mcp_server") + assert hasattr(mcp_serve, "run_mcp_server") + assert hasattr(mcp_serve, "EventBridge") + + def test_mcp_available_flag(self): + import mcp_serve + assert isinstance(mcp_serve._MCP_SERVER_AVAILABLE, bool) + + +class TestHelpers: + def test_get_sessions_dir(self, tmp_path): + from mcp_serve import _get_sessions_dir + result = _get_sessions_dir() + assert result == tmp_path / "sessions" + + def test_load_sessions_index_empty(self, sessions_dir, monkeypatch): + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + assert mcp_serve._load_sessions_index() == {} + + def test_load_sessions_index_with_data(self, populated_sessions_dir, monkeypatch): + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir) + result = mcp_serve._load_sessions_index() + assert len(result) == 3 + + def test_load_sessions_index_corrupt(self, sessions_dir, monkeypatch): + (sessions_dir / "sessions.json").write_text("not json!") + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + assert mcp_serve._load_sessions_index() == {} + + +class TestContentExtraction: + def test_text(self): + from mcp_serve import _extract_message_content + assert _extract_message_content({"content": "Hello"}) == "Hello" + + def test_multipart(self): + from mcp_serve import _extract_message_content + msg = {"content": [ + {"type": "text", "text": "A"}, + {"type": "image", "url": "http://x.com/i.png"}, + {"type": "text", "text": "B"}, + ]} + assert _extract_message_content(msg) == "A\nB" + + def test_empty(self): + from mcp_serve import _extract_message_content + assert _extract_message_content({"content": ""}) == "" + assert _extract_message_content({}) == "" + assert _extract_message_content({"content": None}) == "" + + +class TestAttachmentExtraction: + def test_image_url_block(self): + from mcp_serve import _extract_attachments + msg = {"content": [ + {"type": "image_url", "image_url": {"url": "http://x.com/pic.jpg"}}, + ]} + att = _extract_attachments(msg) + assert len(att) == 1 + assert att[0] == {"type": "image", "url": "http://x.com/pic.jpg"} + + def test_media_tag_in_text(self): + from mcp_serve import _extract_attachments + msg = {"content": "Here MEDIA: /tmp/out.png done"} + att = _extract_attachments(msg) + assert len(att) == 1 + assert att[0] == {"type": "media", "path": "/tmp/out.png"} + + def test_multiple_media_tags(self): + from mcp_serve import _extract_attachments + msg = {"content": "MEDIA: /a.png and MEDIA: /b.mp3"} + assert len(_extract_attachments(msg)) == 2 + + def test_no_attachments(self): + from mcp_serve import _extract_attachments + assert _extract_attachments({"content": "plain text"}) == [] + + def test_image_content_block(self): + from mcp_serve import _extract_attachments + msg = {"content": [{"type": "image", "url": "http://x.com/p.png"}]} + att = _extract_attachments(msg) + assert att[0]["type"] == "image" + + +# --------------------------------------------------------------------------- +# 2. EVENT BRIDGE TESTS — queue, cursors, waiters, concurrency +# --------------------------------------------------------------------------- + +class TestEventBridge: + def test_create(self): + from mcp_serve import EventBridge + b = EventBridge() + assert b._cursor == 0 + assert b._queue == [] + + def test_enqueue_and_poll(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + b._enqueue(QueueEvent(cursor=0, type="message", session_key="k1", + data={"content": "hi"})) + r = b.poll_events(after_cursor=0) + assert len(r["events"]) == 1 + assert r["events"][0]["type"] == "message" + assert r["next_cursor"] == 1 + + def test_cursor_filter(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + for i in range(5): + b._enqueue(QueueEvent(cursor=0, type="message", session_key=f"s{i}")) + r = b.poll_events(after_cursor=3) + assert len(r["events"]) == 2 + assert r["events"][0]["session_key"] == "s3" + + def test_session_filter(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + b._enqueue(QueueEvent(cursor=0, type="message", session_key="a")) + b._enqueue(QueueEvent(cursor=0, type="message", session_key="b")) + b._enqueue(QueueEvent(cursor=0, type="message", session_key="a")) + r = b.poll_events(after_cursor=0, session_key="a") + assert len(r["events"]) == 2 + + def test_poll_empty(self): + from mcp_serve import EventBridge + r = EventBridge().poll_events(after_cursor=0) + assert r["events"] == [] + assert r["next_cursor"] == 0 + + def test_poll_limit(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + for i in range(10): + b._enqueue(QueueEvent(cursor=0, type="message", session_key=f"s{i}")) + r = b.poll_events(after_cursor=0, limit=3) + assert len(r["events"]) == 3 + + def test_wait_immediate(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + b._enqueue(QueueEvent(cursor=0, type="message", session_key="t", + data={"content": "hi"})) + event = b.wait_for_event(after_cursor=0, timeout_ms=100) + assert event is not None + assert event["type"] == "message" + + def test_wait_timeout(self): + from mcp_serve import EventBridge + start = time.monotonic() + event = EventBridge().wait_for_event(after_cursor=0, timeout_ms=150) + assert event is None + assert time.monotonic() - start >= 0.1 + + def test_wait_wakes_on_enqueue(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + result = [None] + + def waiter(): + result[0] = b.wait_for_event(after_cursor=0, timeout_ms=5000) + + t = threading.Thread(target=waiter) + t.start() + time.sleep(0.05) + b._enqueue(QueueEvent(cursor=0, type="message", session_key="wake")) + t.join(timeout=2) + assert result[0] is not None + assert result[0]["session_key"] == "wake" + + def test_queue_limit(self): + from mcp_serve import EventBridge, QueueEvent, QUEUE_LIMIT + b = EventBridge() + for i in range(QUEUE_LIMIT + 50): + b._enqueue(QueueEvent(cursor=0, type="message", session_key=f"s{i}")) + assert len(b._queue) == QUEUE_LIMIT + + def test_concurrent_enqueue(self): + from mcp_serve import EventBridge, QueueEvent + b = EventBridge() + errors = [] + + def batch(start): + try: + for i in range(100): + b._enqueue(QueueEvent(cursor=0, type="message", + session_key=f"s{start}_{i}")) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=batch, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert not errors + assert len(b._queue) == 500 + assert b._cursor == 500 + + def test_approvals_lifecycle(self): + from mcp_serve import EventBridge + b = EventBridge() + b._pending_approvals["a1"] = { + "id": "a1", "kind": "exec", + "description": "rm -rf /tmp", + "session_key": "test", "created_at": "2026-03-29T12:00:00", + } + assert len(b.list_pending_approvals()) == 1 + result = b.respond_to_approval("a1", "deny") + assert result["resolved"] is True + assert len(b.list_pending_approvals()) == 0 + + def test_respond_nonexistent(self): + from mcp_serve import EventBridge + r = EventBridge().respond_to_approval("nope", "deny") + assert "error" in r + + +# --------------------------------------------------------------------------- +# 3. END-TO-END TESTS — call MCP tools through FastMCP server +# --------------------------------------------------------------------------- + +@pytest.fixture +def mcp_server_e2e(populated_sessions_dir, mock_session_db, monkeypatch): + """Create a fully wired MCP server for E2E testing.""" + mcp = pytest.importorskip("mcp", reason="MCP SDK not installed") + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir) + monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: mock_session_db) + monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: {}) + + bridge = mcp_serve.EventBridge() + server = mcp_serve.create_mcp_server(event_bridge=bridge) + return server, bridge + + +def _run_tool(server, name, args=None): + """Call an MCP tool through FastMCP's tool manager and return parsed JSON.""" + result = asyncio.get_event_loop().run_until_complete( + server._tool_manager.call_tool(name, args or {}) + ) + return json.loads(result) if isinstance(result, str) else result + + +@pytest.fixture +def _event_loop(): + """Ensure an event loop exists for sync tests calling async tools.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() + + +class TestE2EConversationsList: + def test_list_all(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list") + assert result["count"] == 3 + platforms = {c["platform"] for c in result["conversations"]} + assert platforms == {"telegram", "discord", "slack"} + + def test_list_sorted_by_updated(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list") + keys = [c["session_key"] for c in result["conversations"]] + # Telegram (14:30) > Discord (13:00) > Slack (11:00) + assert keys[0] == "agent:main:telegram:dm:123456" + assert keys[1] == "agent:main:discord:group:789:456" + assert keys[2] == "agent:main:slack:group:C1234:U5678" + + def test_filter_by_platform(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list", {"platform": "discord"}) + assert result["count"] == 1 + assert result["conversations"][0]["platform"] == "discord" + + def test_filter_by_platform_case_insensitive(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list", {"platform": "TELEGRAM"}) + assert result["count"] == 1 + + def test_search_by_name(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list", {"search": "Alice"}) + assert result["count"] == 1 + assert result["conversations"][0]["display_name"] == "Alice" + + def test_search_no_match(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list", {"search": "nobody"}) + assert result["count"] == 0 + + def test_limit(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversations_list", {"limit": 2}) + assert result["count"] == 2 + + +class TestE2EConversationGet: + def test_get_existing(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversation_get", + {"session_key": "agent:main:telegram:dm:123456"}) + assert result["platform"] == "telegram" + assert result["display_name"] == "Alice" + assert result["chat_id"] == "123456" + assert result["input_tokens"] == 50000 + + def test_get_nonexistent(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "conversation_get", + {"session_key": "nonexistent:key"}) + assert "error" in result + + +class TestE2EMessagesRead: + def test_read_messages(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "messages_read", + {"session_key": "agent:main:telegram:dm:123456"}) + assert result["count"] > 0 + # Should filter out tool messages — only user/assistant + roles = {m["role"] for m in result["messages"]} + assert "tool" not in roles + assert "user" in roles + assert "assistant" in roles + + def test_read_messages_content(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "messages_read", + {"session_key": "agent:main:telegram:dm:123456"}) + contents = [m["content"] for m in result["messages"]] + assert "Hello Alice!" in contents + assert "Hi! How can I help?" in contents + + def test_read_messages_have_ids(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "messages_read", + {"session_key": "agent:main:telegram:dm:123456"}) + for msg in result["messages"]: + assert "id" in msg + assert msg["id"] # non-empty + + def test_read_with_limit(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "messages_read", + {"session_key": "agent:main:telegram:dm:123456", + "limit": 2}) + assert result["count"] == 2 + + def test_read_nonexistent_session(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "messages_read", + {"session_key": "nonexistent:key"}) + assert "error" in result + + +class TestE2EAttachmentsFetch: + def test_fetch_media_from_message(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + # First get message IDs + msgs = _run_tool(server, "messages_read", + {"session_key": "agent:main:telegram:dm:123456"}) + # Find the message with MEDIA: tag + media_msg = None + for m in msgs["messages"]: + if "MEDIA:" in m["content"]: + media_msg = m + break + assert media_msg is not None, "Should have a message with MEDIA: tag" + + result = _run_tool(server, "attachments_fetch", { + "session_key": "agent:main:telegram:dm:123456", + "message_id": media_msg["id"], + }) + assert result["count"] >= 1 + assert result["attachments"][0]["type"] == "media" + assert result["attachments"][0]["path"] == "/tmp/screenshot.png" + + def test_fetch_from_nonexistent_message(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "attachments_fetch", { + "session_key": "agent:main:telegram:dm:123456", + "message_id": "99999", + }) + assert "error" in result + + def test_fetch_from_nonexistent_session(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "attachments_fetch", { + "session_key": "nonexistent:key", + "message_id": "1", + }) + assert "error" in result + + +class TestE2EEventsPoll: + def test_poll_empty(self, mcp_server_e2e, _event_loop): + server, bridge = mcp_server_e2e + result = _run_tool(server, "events_poll") + assert result["events"] == [] + assert result["next_cursor"] == 0 + + def test_poll_with_events(self, mcp_server_e2e, _event_loop): + from mcp_serve import QueueEvent + server, bridge = mcp_server_e2e + bridge._enqueue(QueueEvent(cursor=0, type="message", + session_key="agent:main:telegram:dm:123456", + data={"role": "user", "content": "Hello"})) + bridge._enqueue(QueueEvent(cursor=0, type="message", + session_key="agent:main:telegram:dm:123456", + data={"role": "assistant", "content": "Hi"})) + + result = _run_tool(server, "events_poll") + assert len(result["events"]) == 2 + assert result["events"][0]["content"] == "Hello" + assert result["events"][1]["content"] == "Hi" + assert result["next_cursor"] == 2 + + def test_poll_cursor_pagination(self, mcp_server_e2e, _event_loop): + from mcp_serve import QueueEvent + server, bridge = mcp_server_e2e + for i in range(5): + bridge._enqueue(QueueEvent(cursor=0, type="message", + session_key=f"s{i}")) + + page1 = _run_tool(server, "events_poll", {"limit": 2}) + assert len(page1["events"]) == 2 + assert page1["next_cursor"] == 2 + + page2 = _run_tool(server, "events_poll", + {"after_cursor": page1["next_cursor"], "limit": 2}) + assert len(page2["events"]) == 2 + assert page2["next_cursor"] == 4 + + def test_poll_session_filter(self, mcp_server_e2e, _event_loop): + from mcp_serve import QueueEvent + server, bridge = mcp_server_e2e + bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a")) + bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="b")) + bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="a")) + + result = _run_tool(server, "events_poll", + {"session_key": "b"}) + assert len(result["events"]) == 1 + + +class TestE2EEventsWait: + def test_wait_timeout(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "events_wait", {"timeout_ms": 100}) + assert result["event"] is None + assert result["reason"] == "timeout" + + def test_wait_with_existing_event(self, mcp_server_e2e, _event_loop): + from mcp_serve import QueueEvent + server, bridge = mcp_server_e2e + bridge._enqueue(QueueEvent(cursor=0, type="message", + session_key="test", + data={"content": "waiting for this"})) + result = _run_tool(server, "events_wait", {"timeout_ms": 100}) + assert result["event"] is not None + assert result["event"]["content"] == "waiting for this" + + def test_wait_caps_timeout(self, mcp_server_e2e, _event_loop): + """Timeout should be capped at 300000ms (5 min).""" + from mcp_serve import QueueEvent + server, bridge = mcp_server_e2e + bridge._enqueue(QueueEvent(cursor=0, type="message", session_key="t")) + # Even with huge timeout, should return immediately since event exists + result = _run_tool(server, "events_wait", {"timeout_ms": 999999}) + assert result["event"] is not None + + +class TestE2EMessagesSend: + def test_send_missing_args(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "messages_send", {"target": "", "message": "hi"}) + assert "error" in result + + def test_send_delegates_to_tool(self, mcp_server_e2e, _event_loop, monkeypatch): + server, _ = mcp_server_e2e + mock = MagicMock(return_value=json.dumps({"success": True, "platform": "telegram"})) + monkeypatch.setattr("tools.send_message_tool.send_message_tool", mock) + + result = _run_tool(server, "messages_send", + {"target": "telegram:123456", "message": "Hello!"}) + assert result["success"] is True + mock.assert_called_once() + call_args = mock.call_args[0][0] + assert call_args["action"] == "send" + assert call_args["target"] == "telegram:123456" + + +class TestE2EChannelsList: + def test_channels_from_sessions(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "channels_list") + assert result["count"] == 3 + targets = {c["target"] for c in result["channels"]} + assert "telegram:123456" in targets + assert "discord:789" in targets + assert "slack:C1234" in targets + + def test_channels_platform_filter(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "channels_list", {"platform": "slack"}) + assert result["count"] == 1 + assert result["channels"][0]["target"] == "slack:C1234" + + def test_channels_with_directory(self, mcp_server_e2e, _event_loop, monkeypatch): + import mcp_serve + monkeypatch.setattr(mcp_serve, "_load_channel_directory", lambda: { + "telegram": [ + {"id": "123456", "name": "Alice", "type": "dm"}, + {"id": "-100999", "name": "Dev Group", "type": "group"}, + ], + }) + # Need to recreate server to pick up the new mock + server, bridge = mcp_server_e2e + # The tool closure already captured the old mock, so test the function directly + directory = mcp_serve._load_channel_directory() + assert len(directory["telegram"]) == 2 + + +class TestE2EPermissions: + def test_list_empty(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "permissions_list_open") + assert result["count"] == 0 + assert result["approvals"] == [] + + def test_list_with_approvals(self, mcp_server_e2e, _event_loop): + server, bridge = mcp_server_e2e + bridge._pending_approvals["a1"] = { + "id": "a1", "kind": "exec", + "description": "sudo rm -rf /", + "session_key": "test", + "created_at": "2026-03-29T12:00:00", + } + result = _run_tool(server, "permissions_list_open") + assert result["count"] == 1 + assert result["approvals"][0]["id"] == "a1" + + def test_respond_allow(self, mcp_server_e2e, _event_loop): + server, bridge = mcp_server_e2e + bridge._pending_approvals["a1"] = {"id": "a1", "kind": "exec"} + result = _run_tool(server, "permissions_respond", + {"id": "a1", "decision": "allow-once"}) + assert result["resolved"] is True + assert result["decision"] == "allow-once" + # Should be gone now + check = _run_tool(server, "permissions_list_open") + assert check["count"] == 0 + + def test_respond_deny(self, mcp_server_e2e, _event_loop): + server, bridge = mcp_server_e2e + bridge._pending_approvals["a2"] = {"id": "a2", "kind": "plugin"} + result = _run_tool(server, "permissions_respond", + {"id": "a2", "decision": "deny"}) + assert result["resolved"] is True + + def test_respond_invalid_decision(self, mcp_server_e2e, _event_loop): + server, bridge = mcp_server_e2e + bridge._pending_approvals["a3"] = {"id": "a3", "kind": "exec"} + result = _run_tool(server, "permissions_respond", + {"id": "a3", "decision": "maybe"}) + assert "error" in result + + def test_respond_nonexistent(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + result = _run_tool(server, "permissions_respond", + {"id": "nope", "decision": "deny"}) + assert "error" in result + + +# --------------------------------------------------------------------------- +# 4. TOOL LISTING — verify all 10 tools are registered +# --------------------------------------------------------------------------- + +class TestToolRegistration: + def test_all_tools_registered(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + tools = server._tool_manager.list_tools() + tool_names = {t.name for t in tools} + + expected = { + "conversations_list", "conversation_get", "messages_read", + "attachments_fetch", "events_poll", "events_wait", + "messages_send", "channels_list", + "permissions_list_open", "permissions_respond", + } + assert expected == tool_names, f"Missing: {expected - tool_names}, Extra: {tool_names - expected}" + + def test_tools_have_descriptions(self, mcp_server_e2e, _event_loop): + server, _ = mcp_server_e2e + for tool in server._tool_manager.list_tools(): + assert tool.description, f"Tool {tool.name} has no description" + + +# --------------------------------------------------------------------------- +# 5. SERVER LIFECYCLE / CLI INTEGRATION +# --------------------------------------------------------------------------- + +class TestServerCreation: + def test_create_server(self, populated_sessions_dir, monkeypatch): + pytest.importorskip("mcp", reason="MCP SDK not installed") + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir) + assert mcp_serve.create_mcp_server() is not None + + def test_create_with_bridge(self, populated_sessions_dir, monkeypatch): + pytest.importorskip("mcp", reason="MCP SDK not installed") + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: populated_sessions_dir) + bridge = mcp_serve.EventBridge() + assert mcp_serve.create_mcp_server(event_bridge=bridge) is not None + + def test_create_without_mcp_sdk(self, monkeypatch): + import mcp_serve + monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", False) + with pytest.raises(ImportError, match="MCP server requires"): + mcp_serve.create_mcp_server() + + +class TestRunMcpServer: + def test_run_without_mcp_exits(self, monkeypatch): + import mcp_serve + monkeypatch.setattr(mcp_serve, "_MCP_SERVER_AVAILABLE", False) + with pytest.raises(SystemExit) as exc_info: + mcp_serve.run_mcp_server() + assert exc_info.value.code == 1 + + +class TestCliIntegration: + def test_parse_serve(self): + import argparse + parser = argparse.ArgumentParser() + subs = parser.add_subparsers(dest="command") + mcp_p = subs.add_parser("mcp") + mcp_sub = mcp_p.add_subparsers(dest="mcp_action") + serve_p = mcp_sub.add_parser("serve") + serve_p.add_argument("-v", "--verbose", action="store_true") + + args = parser.parse_args(["mcp", "serve"]) + assert args.mcp_action == "serve" + assert args.verbose is False + + def test_parse_serve_verbose(self): + import argparse + parser = argparse.ArgumentParser() + subs = parser.add_subparsers(dest="command") + mcp_p = subs.add_parser("mcp") + mcp_sub = mcp_p.add_subparsers(dest="mcp_action") + serve_p = mcp_sub.add_parser("serve") + serve_p.add_argument("-v", "--verbose", action="store_true") + + args = parser.parse_args(["mcp", "serve", "--verbose"]) + assert args.verbose is True + + def test_dispatcher_routes_serve(self, monkeypatch, tmp_path): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + mock_run = MagicMock() + monkeypatch.setattr("mcp_serve.run_mcp_server", mock_run) + + import argparse + args = argparse.Namespace(mcp_action="serve", verbose=True) + from hermes_cli.mcp_config import mcp_command + mcp_command(args) + mock_run.assert_called_once_with(verbose=True) + + +# --------------------------------------------------------------------------- +# 6. EDGE CASES +# --------------------------------------------------------------------------- + +class TestEdgeCases: + def test_empty_sessions_json(self, sessions_dir, monkeypatch): + (sessions_dir / "sessions.json").write_text("{}") + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + assert mcp_serve._load_sessions_index() == {} + + def test_sessions_without_origin(self, sessions_dir, monkeypatch): + data = {"agent:main:telegram:dm:111": { + "session_key": "agent:main:telegram:dm:111", + "session_id": "20260329_120000_xyz", + "platform": "telegram", + "updated_at": "2026-03-29T12:00:00", + }} + (sessions_dir / "sessions.json").write_text(json.dumps(data)) + import mcp_serve + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + entries = mcp_serve._load_sessions_index() + assert entries["agent:main:telegram:dm:111"]["platform"] == "telegram" + + def test_bridge_start_stop(self): + from mcp_serve import EventBridge + b = EventBridge() + assert not b._running + b._running = True + b.stop() + assert not b._running + + def test_truncation(self): + assert len(("x" * 5000)[:2000]) == 2000 + + +# --------------------------------------------------------------------------- +# 7. EVENT BRIDGE POLL LOOP E2E — real SQLite DB, mtime optimization +# --------------------------------------------------------------------------- + +class TestEventBridgePollE2E: + """End-to-end tests for the EventBridge polling loop with real files.""" + + def test_poll_detects_new_messages(self, tmp_path, monkeypatch): + """Write to SQLite + sessions.json, verify EventBridge picks it up.""" + import mcp_serve + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + + session_id = "20260329_150000_poll_test" + db_path = tmp_path / "state.db" + + # Write sessions.json + sessions_data = { + "agent:main:telegram:dm:poll_test": { + "session_key": "agent:main:telegram:dm:poll_test", + "session_id": session_id, + "platform": "telegram", + "chat_type": "dm", + "display_name": "PollTest", + "updated_at": "2026-03-29T15:00:05", + "origin": {"platform": "telegram", "chat_id": "poll_test"}, + } + } + (sessions_dir / "sessions.json").write_text(json.dumps(sessions_data)) + + # Write messages to SQLite + messages = [ + {"role": "user", "content": "First message", + "timestamp": "2026-03-29T15:00:01"}, + {"role": "assistant", "content": "Reply", + "timestamp": "2026-03-29T15:00:03"}, + ] + _create_test_db(db_path, session_id, messages) + + # Create a mock SessionDB that reads our test DB + class TestDB: + def get_messages(self, sid): + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT * FROM messages WHERE session_id = ? ORDER BY id", + (sid,), + ).fetchall() + conn.close() + return [dict(r) for r in rows] + + monkeypatch.setattr(mcp_serve, "_get_session_db", lambda: TestDB()) + + bridge = mcp_serve.EventBridge() + # Run one poll cycle manually + bridge._poll_once(TestDB()) + + # Should have found the messages + result = bridge.poll_events(after_cursor=0) + assert len(result["events"]) == 2 + assert result["events"][0]["role"] == "user" + assert result["events"][0]["content"] == "First message" + assert result["events"][1]["role"] == "assistant" + + def test_poll_skips_when_unchanged(self, tmp_path, monkeypatch): + """Second poll with no file changes should be a no-op.""" + import mcp_serve + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + + session_id = "20260329_150000_skip_test" + db_path = tmp_path / "state.db" + + sessions_data = { + "agent:main:telegram:dm:skip": { + "session_key": "agent:main:telegram:dm:skip", + "session_id": session_id, + "platform": "telegram", + "updated_at": "2026-03-29T15:00:05", + "origin": {"platform": "telegram", "chat_id": "skip"}, + } + } + (sessions_dir / "sessions.json").write_text(json.dumps(sessions_data)) + _create_test_db(db_path, session_id, [ + {"role": "user", "content": "Hello", "timestamp": "2026-03-29T15:00:01"}, + ]) + + class TestDB: + def __init__(self): + self.call_count = 0 + + def get_messages(self, sid): + self.call_count += 1 + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT * FROM messages WHERE session_id = ? ORDER BY id", + (sid,), + ).fetchall() + conn.close() + return [dict(r) for r in rows] + + db = TestDB() + bridge = mcp_serve.EventBridge() + + # First poll — should process + bridge._poll_once(db) + first_calls = db.call_count + assert first_calls >= 1 + + # Second poll — files unchanged, should skip entirely + bridge._poll_once(db) + assert db.call_count == first_calls, \ + "Second poll should skip DB queries when files unchanged" + + def test_poll_detects_new_message_after_db_write(self, tmp_path, monkeypatch): + """Write a new message to the DB after first poll, verify it's detected.""" + import mcp_serve + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir() + monkeypatch.setattr(mcp_serve, "_get_sessions_dir", lambda: sessions_dir) + + session_id = "20260329_150000_new_msg" + db_path = tmp_path / "state.db" + + sessions_data = { + "agent:main:telegram:dm:new": { + "session_key": "agent:main:telegram:dm:new", + "session_id": session_id, + "platform": "telegram", + "updated_at": "2026-03-29T15:00:05", + "origin": {"platform": "telegram", "chat_id": "new"}, + } + } + (sessions_dir / "sessions.json").write_text(json.dumps(sessions_data)) + _create_test_db(db_path, session_id, [ + {"role": "user", "content": "First", "timestamp": "2026-03-29T15:00:01"}, + ]) + + class TestDB: + def get_messages(self, sid): + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT * FROM messages WHERE session_id = ? ORDER BY id", + (sid,), + ).fetchall() + conn.close() + return [dict(r) for r in rows] + + db = TestDB() + bridge = mcp_serve.EventBridge() + + # First poll + bridge._poll_once(db) + r1 = bridge.poll_events(after_cursor=0) + assert len(r1["events"]) == 1 + + # Add a new message to the DB + conn = sqlite3.connect(str(db_path)) + conn.execute( + "INSERT INTO messages (session_id, role, content, timestamp) VALUES (?, ?, ?, ?)", + (session_id, "assistant", "New reply!", "2026-03-29T15:00:10"), + ) + conn.commit() + conn.close() + # Touch the DB file to update mtime (WAL mode may not update mtime on small writes) + os.utime(db_path, None) + + # Update sessions.json updated_at to trigger re-check + sessions_data["agent:main:telegram:dm:new"]["updated_at"] = "2026-03-29T15:00:10" + (sessions_dir / "sessions.json").write_text(json.dumps(sessions_data)) + + # Second poll — should detect the new message + bridge._poll_once(db) + r2 = bridge.poll_events(after_cursor=r1["next_cursor"]) + assert len(r2["events"]) == 1 + assert r2["events"][0]["content"] == "New reply!" + + def test_poll_interval_is_200ms(self): + """Verify the poll interval constant.""" + from mcp_serve import POLL_INTERVAL + assert POLL_INTERVAL == 0.2 diff --git a/tests/test_percentage_clamp.py b/tests/test_percentage_clamp.py new file mode 100644 index 000000000..67d119149 --- /dev/null +++ b/tests/test_percentage_clamp.py @@ -0,0 +1,154 @@ +"""Tests for percentage clamping at 100% across display paths. + +PR #3480 capped context pressure percentage at 100% in agent/display.py +but missed the same unclamped pattern in 4 other files. When token counts +overshoot the context length (possible during streaming or before +compression fires), users see >100% in /stats, gateway status, and +memory tool output. +""" + +import pytest + + +class TestContextCompressorUsagePercent: + """agent/context_compressor.py — get_status() usage_percent""" + + def test_usage_percent_capped_at_100(self): + """Tokens exceeding context_length should still show max 100%.""" + from agent.context_compressor import ContextCompressor + + comp = ContextCompressor.__new__(ContextCompressor) + comp.last_prompt_tokens = 210_000 # exceeds context_length + comp.context_length = 200_000 + comp.threshold_tokens = 160_000 + comp.compression_count = 0 + + status = comp.get_status() + assert status["usage_percent"] <= 100 + + def test_usage_percent_normal(self): + """Normal usage should show correct percentage.""" + from agent.context_compressor import ContextCompressor + + comp = ContextCompressor.__new__(ContextCompressor) + comp.last_prompt_tokens = 100_000 + comp.context_length = 200_000 + comp.threshold_tokens = 160_000 + comp.compression_count = 0 + + status = comp.get_status() + assert status["usage_percent"] == 50.0 + + def test_usage_percent_zero_context_length(self): + """Zero context_length should return 0, not crash.""" + from agent.context_compressor import ContextCompressor + + comp = ContextCompressor.__new__(ContextCompressor) + comp.last_prompt_tokens = 1000 + comp.context_length = 0 + comp.threshold_tokens = 0 + comp.compression_count = 0 + + status = comp.get_status() + assert status["usage_percent"] == 0 + + +class TestMemoryToolPercentClamp: + """tools/memory_tool.py — _success_response and _render_block pct""" + + def test_over_limit_clamped_at_100(self): + """Percentage should be capped at 100 even if current > limit.""" + # Simulate the calculation directly + current = 5500 + limit = 5000 + pct = min(100, int((current / limit) * 100)) if limit > 0 else 0 + assert pct == 100 + + def test_normal_percentage(self): + current = 2500 + limit = 5000 + pct = min(100, int((current / limit) * 100)) if limit > 0 else 0 + assert pct == 50 + + def test_zero_limit_returns_zero(self): + current = 100 + limit = 0 + pct = min(100, int((current / limit) * 100)) if limit > 0 else 0 + assert pct == 0 + + +class TestCLIStatsPercentClamp: + """cli.py — /stats command percentage""" + + def test_over_context_clamped_at_100(self): + """Tokens exceeding context_length should show max 100%.""" + last_prompt = 210_000 + ctx_len = 200_000 + pct = min(100, (last_prompt / ctx_len * 100)) if ctx_len else 0 + assert pct == 100 + + def test_normal_context(self): + last_prompt = 100_000 + ctx_len = 200_000 + pct = min(100, (last_prompt / ctx_len * 100)) if ctx_len else 0 + assert pct == 50.0 + + def test_zero_context_length(self): + last_prompt = 1000 + ctx_len = 0 + pct = min(100, (last_prompt / ctx_len * 100)) if ctx_len else 0 + assert pct == 0 + + +class TestGatewayStatsPercentClamp: + """gateway/run.py — _format_usage_stats percentage""" + + def test_over_context_clamped_at_100(self): + last_prompt_tokens = 210_000 + context_length = 200_000 + pct = min(100, last_prompt_tokens / context_length * 100) if context_length else 0 + assert pct == 100 + + def test_normal_context(self): + last_prompt_tokens = 150_000 + context_length = 200_000 + pct = min(100, last_prompt_tokens / context_length * 100) if context_length else 0 + assert pct == 75.0 + + +class TestSourceLinesAreClamped: + """Verify the actual source files have min(100, ...) applied.""" + + @staticmethod + def _read_file(rel_path: str) -> str: + import os + base = os.path.dirname(os.path.dirname(__file__)) + with open(os.path.join(base, rel_path)) as f: + return f.read() + + def test_context_compressor_clamped(self): + src = self._read_file("agent/context_compressor.py") + assert "min(100," in src, ( + "context_compressor.py usage_percent is not clamped with min(100, ...)" + ) + + def test_gateway_run_clamped(self): + src = self._read_file("gateway/run.py") + # Check that the stats handler has min(100, ...) + assert "min(100, ctx.last_prompt_tokens" in src, ( + "gateway/run.py stats pct is not clamped with min(100, ...)" + ) + + def test_cli_clamped(self): + src = self._read_file("cli.py") + assert "min(100, (last_prompt" in src, ( + "cli.py /stats pct is not clamped with min(100, ...)" + ) + + def test_memory_tool_clamped(self): + src = self._read_file("tools/memory_tool.py") + # Both _success_response and _render_block should have min(100, ...) + count = src.count("min(100, int((current / limit)") + assert count >= 2, ( + f"memory_tool.py has only {count} clamped pct lines, expected >= 2" + ) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index f90853a81..0da5b640d 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -226,6 +226,42 @@ class TestPluginHooks: # Should not raise despite 1/0 mgr.invoke_hook("post_tool_call", tool_name="x", args={}, result="r", task_id="") + def test_hook_return_values_collected(self, tmp_path, monkeypatch): + """invoke_hook() collects non-None return values from callbacks.""" + plugins_dir = tmp_path / "hermes_test" / "plugins" + _make_plugin_dir( + plugins_dir, "ctx_plugin", + register_body=( + 'ctx.register_hook("pre_llm_call", ' + 'lambda **kw: {"context": "memory from plugin"})' + ), + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test")) + + mgr = PluginManager() + mgr.discover_and_load() + + results = mgr.invoke_hook("pre_llm_call", session_id="s1", user_message="hi", + conversation_history=[], is_first_turn=True, model="test") + assert len(results) == 1 + assert results[0] == {"context": "memory from plugin"} + + def test_hook_none_returns_excluded(self, tmp_path, monkeypatch): + """invoke_hook() excludes None returns from the result list.""" + plugins_dir = tmp_path / "hermes_test" / "plugins" + _make_plugin_dir( + plugins_dir, "none_hook", + register_body='ctx.register_hook("post_llm_call", lambda **kw: None)', + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test")) + + mgr = PluginManager() + mgr.discover_and_load() + + results = mgr.invoke_hook("post_llm_call", session_id="s1", + user_message="hi", assistant_response="bye", model="test") + assert results == [] + def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog): """Registering an unknown hook name logs a warning.""" plugins_dir = tmp_path / "hermes_test" / "plugins" diff --git a/tests/test_plugins_cmd.py b/tests/test_plugins_cmd.py index e93e2dc50..ac95571be 100644 --- a/tests/test_plugins_cmd.py +++ b/tests/test_plugins_cmd.py @@ -150,11 +150,11 @@ class TestPluginsCommandDispatch: plugins_command(args) mock_list.assert_called_once() - @patch("hermes_cli.plugins_cmd.cmd_list") - def test_none_falls_through_to_list(self, mock_list): + @patch("hermes_cli.plugins_cmd.cmd_toggle") + def test_none_falls_through_to_toggle(self, mock_toggle): args = self._make_args(None) plugins_command(args) - mock_list.assert_called_once() + mock_toggle.assert_called_once() @patch("hermes_cli.plugins_cmd.cmd_install") def test_install_dispatches(self, mock_install): diff --git a/tests/test_provider_fallback.py b/tests/test_provider_fallback.py new file mode 100644 index 000000000..2bb210955 --- /dev/null +++ b/tests/test_provider_fallback.py @@ -0,0 +1,156 @@ +"""Tests for ordered provider fallback chain (salvage of PR #1761). + +Extends the single-fallback tests in test_fallback_model.py to cover +the new list-based ``fallback_providers`` config format and chain +advancement through multiple providers. +""" + +from unittest.mock import MagicMock, patch + +from run_agent import AIAgent + + +def _make_agent(fallback_model=None): + """Create a minimal AIAgent with optional fallback config.""" + with ( + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + agent = AIAgent( + api_key="test-key", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + fallback_model=fallback_model, + ) + agent.client = MagicMock() + return agent + + +def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"): + mock = MagicMock() + mock.base_url = base_url + mock.api_key = api_key + return mock + + +# ── Chain initialisation ────────────────────────────────────────────────── + + +class TestFallbackChainInit: + def test_no_fallback(self): + agent = _make_agent(fallback_model=None) + assert agent._fallback_chain == [] + assert agent._fallback_index == 0 + assert agent._fallback_model is None + + def test_single_dict_backwards_compat(self): + fb = {"provider": "openai", "model": "gpt-4o"} + agent = _make_agent(fallback_model=fb) + assert agent._fallback_chain == [fb] + assert agent._fallback_model == fb + + def test_list_of_providers(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "zai", "model": "glm-4.7"}, + ] + agent = _make_agent(fallback_model=fbs) + assert len(agent._fallback_chain) == 2 + assert agent._fallback_model == fbs[0] + + def test_invalid_entries_filtered(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "", "model": "glm-4.7"}, + {"provider": "zai"}, + "not-a-dict", + ] + agent = _make_agent(fallback_model=fbs) + assert len(agent._fallback_chain) == 1 + assert agent._fallback_chain[0]["provider"] == "openai" + + def test_empty_list(self): + agent = _make_agent(fallback_model=[]) + assert agent._fallback_chain == [] + assert agent._fallback_model is None + + def test_invalid_dict_no_provider(self): + agent = _make_agent(fallback_model={"model": "gpt-4o"}) + assert agent._fallback_chain == [] + + +# ── Chain advancement ───────────────────────────────────────────────────── + + +class TestFallbackChainAdvancement: + def test_exhausted_returns_false(self): + agent = _make_agent(fallback_model=None) + assert agent._try_activate_fallback() is False + + def test_advances_index(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "zai", "model": "glm-4.7"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client", + return_value=(_mock_client(), "gpt-4o")): + assert agent._try_activate_fallback() is True + assert agent._fallback_index == 1 + assert agent.model == "gpt-4o" + assert agent._fallback_activated is True + + def test_second_fallback_works(self): + fbs = [ + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "zai", "model": "glm-4.7"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client", + return_value=(_mock_client(), "resolved")): + assert agent._try_activate_fallback() is True + assert agent.model == "gpt-4o" + assert agent._try_activate_fallback() is True + assert agent.model == "glm-4.7" + assert agent._fallback_index == 2 + + def test_all_exhausted_returns_false(self): + fbs = [{"provider": "openai", "model": "gpt-4o"}] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client", + return_value=(_mock_client(), "gpt-4o")): + assert agent._try_activate_fallback() is True + assert agent._try_activate_fallback() is False + + def test_skips_unconfigured_provider_to_next(self): + """If resolve_provider_client returns None, skip to next in chain.""" + fbs = [ + {"provider": "broken", "model": "nope"}, + {"provider": "openai", "model": "gpt-4o"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc: + mock_rpc.side_effect = [ + (None, None), # broken provider + (_mock_client(), "gpt-4o"), # fallback succeeds + ] + assert agent._try_activate_fallback() is True + assert agent.model == "gpt-4o" + assert agent._fallback_index == 2 + + def test_skips_provider_that_raises_to_next(self): + """If resolve_provider_client raises, skip to next in chain.""" + fbs = [ + {"provider": "broken", "model": "nope"}, + {"provider": "openai", "model": "gpt-4o"}, + ] + agent = _make_agent(fallback_model=fbs) + with patch("agent.auxiliary_client.resolve_provider_client") as mock_rpc: + mock_rpc.side_effect = [ + RuntimeError("auth failed"), + (_mock_client(), "gpt-4o"), + ] + assert agent._try_activate_fallback() is True + assert agent.model == "gpt-4o" diff --git a/tests/test_reasoning_command.py b/tests/test_reasoning_command.py index 81d452a27..4270d630d 100644 --- a/tests/test_reasoning_command.py +++ b/tests/test_reasoning_command.py @@ -472,6 +472,7 @@ class TestInlineThinkBlockExtraction(unittest.TestCase): agent._extract_reasoning = AIAgent._extract_reasoning.__get__(agent) agent.verbose_logging = False agent.reasoning_callback = None + agent.stream_delta_callback = None # non-streaming by default return agent def test_single_think_block_extracted(self): @@ -605,5 +606,159 @@ class TestEndToEndPipeline(unittest.TestCase): self.assertIsNone(result["last_reasoning"]) +# --------------------------------------------------------------------------- +# Duplicate reasoning box prevention (Bug fix: 3 boxes for 1 reasoning) +# --------------------------------------------------------------------------- + +class TestReasoningDeltasFiredFlag(unittest.TestCase): + """_build_assistant_message should not re-fire reasoning_callback when + reasoning was already streamed via _fire_reasoning_delta.""" + + def _make_agent(self): + from run_agent import AIAgent + agent = AIAgent.__new__(AIAgent) + agent.reasoning_callback = None + agent.stream_delta_callback = None + agent._reasoning_deltas_fired = False + agent.verbose_logging = False + return agent + + def test_fire_reasoning_delta_sets_flag(self): + agent = self._make_agent() + captured = [] + agent.reasoning_callback = lambda t: captured.append(t) + self.assertFalse(agent._reasoning_deltas_fired) + agent._fire_reasoning_delta("thinking...") + self.assertTrue(agent._reasoning_deltas_fired) + self.assertEqual(captured, ["thinking..."]) + + def test_build_assistant_message_skips_callback_when_already_streamed(self): + """When streaming already fired reasoning deltas, the post-stream + _build_assistant_message should NOT re-fire the callback.""" + agent = self._make_agent() + captured = [] + agent.reasoning_callback = lambda t: captured.append(t) + agent.stream_delta_callback = lambda t: None # streaming is active + + # Simulate streaming having fired reasoning + agent._reasoning_deltas_fired = True + + msg = SimpleNamespace( + content="I'll merge that.", + tool_calls=None, + reasoning_content="Let me merge the PR.", + reasoning=None, + reasoning_details=None, + ) + agent._build_assistant_message(msg, "stop") + + # Callback should NOT have been fired again + self.assertEqual(captured, []) + + def test_build_assistant_message_skips_callback_when_streaming_active(self): + """When streaming is active, callback should NEVER fire from + _build_assistant_message — reasoning was already displayed during the + stream (either via reasoning_content deltas or content tag extraction). + Any missed reasoning is caught by the CLI post-response fallback.""" + agent = self._make_agent() + captured = [] + agent.reasoning_callback = lambda t: captured.append(t) + agent.stream_delta_callback = lambda t: None # streaming active + + # Even though _reasoning_deltas_fired is False (reasoning came through + # content tags, not reasoning_content deltas), callback should not fire + agent._reasoning_deltas_fired = False + + msg = SimpleNamespace( + content="I'll merge that.", + tool_calls=None, + reasoning_content="Let me merge the PR.", + reasoning=None, + reasoning_details=None, + ) + agent._build_assistant_message(msg, "stop") + + # Callback should NOT fire — streaming is active + self.assertEqual(captured, []) + + def test_build_assistant_message_fires_callback_without_streaming(self): + """When no streaming is active, callback always fires for structured + reasoning.""" + agent = self._make_agent() + captured = [] + agent.reasoning_callback = lambda t: captured.append(t) + # No streaming + agent.stream_delta_callback = None + agent._reasoning_deltas_fired = False + + msg = SimpleNamespace( + content="I'll merge that.", + tool_calls=None, + reasoning_content="Let me merge the PR.", + reasoning=None, + reasoning_details=None, + ) + agent._build_assistant_message(msg, "stop") + + self.assertEqual(captured, ["Let me merge the PR."]) + + +class TestReasoningShownThisTurnFlag(unittest.TestCase): + """Post-response reasoning display should be suppressed when reasoning + was already shown during streaming in a tool-calling loop.""" + + def _make_cli(self): + from cli import HermesCLI + cli = HermesCLI.__new__(HermesCLI) + cli.show_reasoning = True + cli.streaming_enabled = True + cli._stream_box_opened = False + cli._reasoning_box_opened = False + cli._reasoning_stream_started = False + cli._reasoning_shown_this_turn = False + cli._reasoning_buf = "" + cli._stream_buf = "" + cli._stream_started = False + cli._stream_text_ansi = "" + cli._stream_prefilt = "" + cli._in_reasoning_block = False + cli._reasoning_preview_buf = "" + return cli + + @patch("cli._cprint") + def test_streaming_reasoning_sets_turn_flag(self, mock_cprint): + cli = self._make_cli() + self.assertFalse(cli._reasoning_shown_this_turn) + cli._stream_reasoning_delta("Thinking about it...") + self.assertTrue(cli._reasoning_shown_this_turn) + + @patch("cli._cprint") + def test_turn_flag_survives_reset_stream_state(self, mock_cprint): + """_reasoning_shown_this_turn must NOT be cleared by + _reset_stream_state (called at intermediate turn boundaries).""" + cli = self._make_cli() + cli._stream_reasoning_delta("Thinking...") + self.assertTrue(cli._reasoning_shown_this_turn) + + # Simulate intermediate turn boundary (tool call) + cli._reset_stream_state() + + # Flag must persist + self.assertTrue(cli._reasoning_shown_this_turn) + + @patch("cli._cprint") + def test_turn_flag_cleared_before_new_turn(self, mock_cprint): + """The turn flag should be reset at the start of a new user turn. + This happens outside _reset_stream_state, at the call site.""" + cli = self._make_cli() + cli._reasoning_shown_this_turn = True + + # Simulate new user turn setup + cli._reset_stream_state() + cli._reasoning_shown_this_turn = False # done by process_input + + self.assertFalse(cli._reasoning_shown_this_turn) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index cfed4afbc..a17d8b60d 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -589,6 +589,164 @@ class TestBuildSystemPrompt: prompt = agent._build_system_prompt() assert "NOUS SUBSCRIPTION BLOCK" in prompt + def test_skills_prompt_derives_available_toolsets_from_loaded_tools(self): + tools = _make_tool_defs("web_search", "skills_list", "skill_view", "skill_manage") + toolset_map = { + "web_search": "web", + "skills_list": "skills", + "skill_view": "skills", + "skill_manage": "skills", + } + + with ( + patch("run_agent.get_tool_definitions", return_value=tools), + patch( + "run_agent.check_toolset_requirements", + side_effect=AssertionError("should not re-check toolset requirements"), + ), + patch("run_agent.get_toolset_for_tool", create=True, side_effect=toolset_map.get), + patch("run_agent.build_skills_system_prompt", return_value="SKILLS_PROMPT") as mock_skills, + patch("run_agent.OpenAI"), + ): + agent = AIAgent( + api_key="test-k...7890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + + prompt = agent._build_system_prompt() + + assert "SKILLS_PROMPT" in prompt + assert mock_skills.call_args.kwargs["available_tools"] == set(toolset_map) + assert mock_skills.call_args.kwargs["available_toolsets"] == {"web", "skills"} + + +class TestToolUseEnforcementConfig: + """Tests for the agent.tool_use_enforcement config option.""" + + def _make_agent(self, model="openai/gpt-4.1", tool_use_enforcement="auto"): + """Create an agent with tools and a specific enforcement config.""" + with ( + patch( + "run_agent.get_tool_definitions", + return_value=_make_tool_defs("terminal", "web_search"), + ), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + patch( + "hermes_cli.config.load_config", + return_value={"agent": {"tool_use_enforcement": tool_use_enforcement}}, + ), + ): + a = AIAgent( + model=model, + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + a.client = MagicMock() + return a + + def test_auto_injects_for_gpt(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="openai/gpt-4.1", tool_use_enforcement="auto") + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_auto_injects_for_codex(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="openai/codex-mini", tool_use_enforcement="auto") + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_auto_skips_for_claude(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="anthropic/claude-sonnet-4", tool_use_enforcement="auto") + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE not in prompt + + def test_true_forces_for_all_models(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="anthropic/claude-sonnet-4", tool_use_enforcement=True) + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_string_true_forces_for_all_models(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="anthropic/claude-sonnet-4", tool_use_enforcement="true") + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_always_forces_for_all_models(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="deepseek/deepseek-r1", tool_use_enforcement="always") + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_false_disables_for_gpt(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="openai/gpt-4.1", tool_use_enforcement=False) + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE not in prompt + + def test_string_false_disables(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent(model="openai/gpt-4.1", tool_use_enforcement="off") + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE not in prompt + + def test_custom_list_matches(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent( + model="deepseek/deepseek-r1", + tool_use_enforcement=["deepseek", "gemini"], + ) + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_custom_list_no_match(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent( + model="anthropic/claude-sonnet-4", + tool_use_enforcement=["deepseek", "gemini"], + ) + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE not in prompt + + def test_custom_list_case_insensitive(self): + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + agent = self._make_agent( + model="openai/GPT-4.1", + tool_use_enforcement=["GPT", "Codex"], + ) + prompt = agent._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE in prompt + + def test_no_tools_never_injects(self): + """Even with enforcement=true, no injection when agent has no tools.""" + from agent.prompt_builder import TOOL_USE_ENFORCEMENT_GUIDANCE + with ( + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + patch( + "hermes_cli.config.load_config", + return_value={"agent": {"tool_use_enforcement": True}}, + ), + ): + a = AIAgent( + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + enabled_toolsets=[], + ) + a.client = MagicMock() + prompt = a._build_system_prompt() + assert TOOL_USE_ENFORCEMENT_GUIDANCE not in prompt + class TestInvalidateSystemPrompt: def test_clears_cache(self, agent): @@ -610,7 +768,7 @@ class TestBuildApiKwargs: kwargs = agent._build_api_kwargs(messages) assert kwargs["model"] == agent.model assert kwargs["messages"] is messages - assert kwargs["timeout"] == 900.0 + assert kwargs["timeout"] == 1800.0 def test_provider_preferences_injected(self, agent): agent.providers_allowed = ["Anthropic"] @@ -1345,19 +1503,11 @@ class TestRunConversation: assert result["final_response"] == "Recovered after compression" assert result["completed"] is True - @pytest.mark.parametrize( - ("first_content", "second_content", "expected_final"), - [ - ("Part 1 ", "Part 2", "Part 1 Part 2"), - ("internal reasoning", "Recovered final answer", "Recovered final answer"), - ], - ) - def test_length_finish_reason_requests_continuation( - self, agent, first_content, second_content, expected_final - ): + def test_length_finish_reason_requests_continuation(self, agent): + """Normal truncation (partial real content) triggers continuation.""" self._setup_agent(agent) - first = _mock_response(content=first_content, finish_reason="length") - second = _mock_response(content=second_content, finish_reason="stop") + first = _mock_response(content="Part 1 ", finish_reason="length") + second = _mock_response(content="Part 2", finish_reason="stop") agent.client.chat.completions.create.side_effect = [first, second] with ( @@ -1369,12 +1519,58 @@ class TestRunConversation: assert result["completed"] is True assert result["api_calls"] == 2 - assert result["final_response"] == expected_final + assert result["final_response"] == "Part 1 Part 2" second_call_messages = agent.client.chat.completions.create.call_args_list[1].kwargs["messages"] assert second_call_messages[-1]["role"] == "user" assert "truncated by the output length limit" in second_call_messages[-1]["content"] + def test_length_thinking_exhausted_skips_continuation(self, agent): + """When finish_reason='length' but content is only thinking, skip retries.""" + self._setup_agent(agent) + resp = _mock_response( + content="internal reasoning", + finish_reason="length", + ) + agent.client.chat.completions.create.return_value = resp + + with ( + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("hello") + + # Should return immediately — no continuation, only 1 API call + assert result["completed"] is False + assert result["api_calls"] == 1 + assert "reasoning" in result["error"].lower() + assert "output tokens" in result["error"].lower() + # Should have a user-friendly response (not None) + assert result["final_response"] is not None + assert "Thinking Budget Exhausted" in result["final_response"] + assert "/thinkon" in result["final_response"] + + def test_length_empty_content_detected_as_thinking_exhausted(self, agent): + """When finish_reason='length' and content is None/empty, detect exhaustion.""" + self._setup_agent(agent) + resp = _mock_response(content=None, finish_reason="length") + agent.client.chat.completions.create.return_value = resp + + with ( + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("hello") + + assert result["completed"] is False + assert result["api_calls"] == 1 + assert "reasoning" in result["error"].lower() + # User-friendly message is returned + assert result["final_response"] is not None + assert "Thinking Budget Exhausted" in result["final_response"] + class TestRetryExhaustion: """Regression: retry_count > max_retries was dead code (off-by-one). @@ -2316,6 +2512,8 @@ class TestFallbackAnthropicProvider: def test_fallback_to_anthropic_sets_api_mode(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" @@ -2337,6 +2535,8 @@ class TestFallbackAnthropicProvider: def test_fallback_to_anthropic_enables_prompt_caching(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-20250514"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" @@ -2354,6 +2554,8 @@ class TestFallbackAnthropicProvider: def test_fallback_to_openrouter_uses_openai_client(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://openrouter.ai/api/v1" @@ -2602,6 +2804,50 @@ class TestStreamingApiCall: assert tc[0].function.name == "search" assert tc[1].function.name == "read" + def test_ollama_reused_index_separate_tool_calls(self, agent): + """Ollama sends every tool call at index 0 with different ids. + + Without the fix, names and arguments get concatenated into one slot. + """ + chunks = [ + _make_chunk(tool_calls=[_make_tc_delta(0, "call_a", "search", '{"q":"hello"}')]), + # Second tool call at the SAME index 0, but different id + _make_chunk(tool_calls=[_make_tc_delta(0, "call_b", "read_file", '{"path":"x.py"}')]), + _make_chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._interruptible_streaming_api_call({"messages": []}) + + tc = resp.choices[0].message.tool_calls + assert len(tc) == 2, f"Expected 2 tool calls, got {len(tc)}: {[t.function.name for t in tc]}" + assert tc[0].function.name == "search" + assert tc[0].function.arguments == '{"q":"hello"}' + assert tc[0].id == "call_a" + assert tc[1].function.name == "read_file" + assert tc[1].function.arguments == '{"path":"x.py"}' + assert tc[1].id == "call_b" + + def test_ollama_reused_index_streamed_args(self, agent): + """Ollama with streamed arguments across multiple chunks at same index.""" + chunks = [ + _make_chunk(tool_calls=[_make_tc_delta(0, "call_a", "search", '{"q":')]), + _make_chunk(tool_calls=[_make_tc_delta(0, None, None, '"hello"}')]), + # New tool call, same index 0 + _make_chunk(tool_calls=[_make_tc_delta(0, "call_b", "read", '{}')]), + _make_chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._interruptible_streaming_api_call({"messages": []}) + + tc = resp.choices[0].message.tool_calls + assert len(tc) == 2 + assert tc[0].function.name == "search" + assert tc[0].function.arguments == '{"q":"hello"}' + assert tc[1].function.name == "read" + assert tc[1].function.arguments == '{}' + def test_content_and_tool_calls_together(self, agent): chunks = [ _make_chunk(content="I'll search"), @@ -3003,6 +3249,8 @@ class TestFallbackSetsOAuthFlag: def test_fallback_to_anthropic_oauth_sets_flag(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" @@ -3024,6 +3272,8 @@ class TestFallbackSetsOAuthFlag: def test_fallback_to_anthropic_api_key_clears_flag(self, agent): agent._fallback_activated = False agent._fallback_model = {"provider": "anthropic", "model": "claude-sonnet-4-6"} + agent._fallback_chain = [agent._fallback_model] + agent._fallback_index = 0 mock_client = MagicMock() mock_client.base_url = "https://api.anthropic.com/v1" diff --git a/tests/test_runtime_provider_resolution.py b/tests/test_runtime_provider_resolution.py index 3597986bf..84b018333 100644 --- a/tests/test_runtime_provider_resolution.py +++ b/tests/test_runtime_provider_resolution.py @@ -493,22 +493,22 @@ def test_minimax_default_url_uses_anthropic_messages(monkeypatch): assert resolved["base_url"] == "https://api.minimax.io/anthropic" -def test_minimax_stale_v1_url_auto_corrected(monkeypatch): - """MiniMax with stale /v1 base URL should be auto-corrected to /anthropic.""" +def test_minimax_v1_url_uses_chat_completions(monkeypatch): + """MiniMax with /v1 base URL should use chat_completions (user override for regions where /anthropic 404s).""" monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax") monkeypatch.setattr(rp, "_get_model_config", lambda: {}) monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key") - monkeypatch.setenv("MINIMAX_BASE_URL", "https://api.minimax.io/v1") + monkeypatch.setenv("MINIMAX_BASE_URL", "https://api.minimax.chat/v1") resolved = rp.resolve_runtime_provider(requested="minimax") assert resolved["provider"] == "minimax" - assert resolved["api_mode"] == "anthropic_messages" - assert resolved["base_url"] == "https://api.minimax.io/anthropic" + assert resolved["api_mode"] == "chat_completions" + assert resolved["base_url"] == "https://api.minimax.chat/v1" -def test_minimax_cn_stale_v1_url_auto_corrected(monkeypatch): - """MiniMax-CN with stale /v1 base URL should be auto-corrected to /anthropic.""" +def test_minimax_cn_v1_url_uses_chat_completions(monkeypatch): + """MiniMax-CN with /v1 base URL should use chat_completions (user override).""" monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax-cn") monkeypatch.setattr(rp, "_get_model_config", lambda: {}) monkeypatch.setenv("MINIMAX_CN_API_KEY", "test-minimax-cn-key") @@ -517,8 +517,8 @@ def test_minimax_cn_stale_v1_url_auto_corrected(monkeypatch): resolved = rp.resolve_runtime_provider(requested="minimax-cn") assert resolved["provider"] == "minimax-cn" - assert resolved["api_mode"] == "anthropic_messages" - assert resolved["base_url"] == "https://api.minimaxi.com/anthropic" + assert resolved["api_mode"] == "chat_completions" + assert resolved["base_url"] == "https://api.minimaxi.com/v1" def test_minimax_explicit_api_mode_respected(monkeypatch): @@ -534,8 +534,8 @@ def test_minimax_explicit_api_mode_respected(monkeypatch): assert resolved["api_mode"] == "chat_completions" -def test_alibaba_default_anthropic_endpoint_uses_anthropic_messages(monkeypatch): - """Alibaba with default /apps/anthropic URL should use anthropic_messages mode.""" +def test_alibaba_default_coding_intl_endpoint_uses_chat_completions(monkeypatch): + """Alibaba default coding-intl /v1 URL should use chat_completions mode.""" monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba") monkeypatch.setattr(rp, "_get_model_config", lambda: {}) monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key") @@ -544,22 +544,22 @@ def test_alibaba_default_anthropic_endpoint_uses_anthropic_messages(monkeypatch) resolved = rp.resolve_runtime_provider(requested="alibaba") assert resolved["provider"] == "alibaba" - assert resolved["api_mode"] == "anthropic_messages" - assert resolved["base_url"] == "https://dashscope-intl.aliyuncs.com/apps/anthropic" + assert resolved["api_mode"] == "chat_completions" + assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/v1" -def test_alibaba_openai_compatible_v1_endpoint_stays_chat_completions(monkeypatch): - """Alibaba with /v1 coding endpoint should use chat_completions mode.""" +def test_alibaba_anthropic_endpoint_override_uses_anthropic_messages(monkeypatch): + """Alibaba with /apps/anthropic URL override should auto-detect anthropic_messages mode.""" monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba") monkeypatch.setattr(rp, "_get_model_config", lambda: {}) monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key") - monkeypatch.setenv("DASHSCOPE_BASE_URL", "https://coding-intl.dashscope.aliyuncs.com/v1") + monkeypatch.setenv("DASHSCOPE_BASE_URL", "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic") resolved = rp.resolve_runtime_provider(requested="alibaba") assert resolved["provider"] == "alibaba" - assert resolved["api_mode"] == "chat_completions" - assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/v1" + assert resolved["api_mode"] == "anthropic_messages" + assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic" def test_named_custom_provider_anthropic_api_mode(monkeypatch): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 88e3aa9e8..107a8a4d4 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -362,9 +362,11 @@ class TestStreamingCallbacks: # Text before tool call IS fired (we don't know yet it will have tools) assert "thinking..." in deltas - # Text after tool call is NOT fired - assert " more text" not in deltas - # But content is still accumulated in the response + # Text after tool call IS still routed to stream_delta_callback so that + # reasoning tag extraction can fire (PR #3566). Display-level suppression + # of non-reasoning text happens in the CLI's _stream_delta, not here. + assert " more text" in deltas + # Content is still accumulated in the response assert response.choices[0].message.content == "thinking... more text" @@ -532,6 +534,121 @@ class TestStreamingFallback: mock_non_stream.assert_called_once() assert mock_close.call_count >= 1 + @patch("run_agent.AIAgent._interruptible_api_call") + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_sse_connection_lost_retried_as_transient(self, mock_close, mock_create, mock_non_stream): + """SSE 'Network connection lost' (APIError w/ no status_code) retries like httpx errors. + + OpenRouter sends {"error":{"message":"Network connection lost."}} as an SSE + event when the upstream stream drops. The OpenAI SDK raises APIError from + this. It should be retried at the streaming level, same as httpx connection + errors, before falling back to non-streaming. + """ + from run_agent import AIAgent + import httpx + + # Create an APIError that mimics what the OpenAI SDK raises from SSE error events. + # Key: no status_code attribute (unlike APIStatusError which has one). + from openai import APIError as OAIAPIError + sse_error = OAIAPIError( + message="Network connection lost.", + request=httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions"), + body={"message": "Network connection lost."}, + ) + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = sse_error + mock_create.return_value = mock_client + + fallback_response = SimpleNamespace( + id="fallback", + model="test", + choices=[SimpleNamespace( + index=0, + message=SimpleNamespace( + role="assistant", + content="fallback after SSE retries", + tool_calls=None, + reasoning_content=None, + ), + finish_reason="stop", + )], + usage=None, + ) + mock_non_stream.return_value = fallback_response + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "fallback after SSE retries" + # Should retry 3 times (default HERMES_STREAM_RETRIES=2 → 3 attempts) + # before falling back to non-streaming + assert mock_client.chat.completions.create.call_count == 3 + mock_non_stream.assert_called_once() + # Connection cleanup should happen for each failed retry + assert mock_close.call_count >= 2 + + @patch("run_agent.AIAgent._interruptible_api_call") + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_sse_non_connection_error_falls_back_immediately(self, mock_close, mock_create, mock_non_stream): + """SSE errors that aren't connection-related still fall back immediately (no stream retry).""" + from run_agent import AIAgent + import httpx + + from openai import APIError as OAIAPIError + sse_error = OAIAPIError( + message="Invalid model configuration.", + request=httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions"), + body={"message": "Invalid model configuration."}, + ) + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = sse_error + mock_create.return_value = mock_client + + fallback_response = SimpleNamespace( + id="fallback", + model="test", + choices=[SimpleNamespace( + index=0, + message=SimpleNamespace( + role="assistant", + content="fallback no retry", + tool_calls=None, + reasoning_content=None, + ), + finish_reason="stop", + )], + usage=None, + ) + mock_non_stream.return_value = fallback_response + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "fallback no retry" + # Should NOT retry — goes straight to non-streaming fallback + assert mock_client.chat.completions.create.call_count == 1 + mock_non_stream.assert_called_once() + # ── Test: Reasoning Streaming ──────────────────────────────────────────── diff --git a/tests/test_surrogate_sanitization.py b/tests/test_surrogate_sanitization.py new file mode 100644 index 000000000..defad587e --- /dev/null +++ b/tests/test_surrogate_sanitization.py @@ -0,0 +1,154 @@ +"""Tests for surrogate character sanitization in user input. + +Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps() +inside the OpenAI SDK. They can appear via clipboard paste from rich-text +editors like Google Docs. +""" +import json +import pytest +from unittest.mock import MagicMock, patch + +from run_agent import ( + _sanitize_surrogates, + _sanitize_messages_surrogates, + _SURROGATE_RE, +) + + +class TestSanitizeSurrogates: + """Test the _sanitize_surrogates() helper.""" + + def test_normal_text_unchanged(self): + text = "Hello, this is normal text with unicode: café ñ 日本語 🎉" + assert _sanitize_surrogates(text) == text + + def test_empty_string(self): + assert _sanitize_surrogates("") == "" + + def test_single_surrogate_replaced(self): + result = _sanitize_surrogates("Hello \udce2 world") + assert result == "Hello \ufffd world" + + def test_multiple_surrogates_replaced(self): + result = _sanitize_surrogates("a\ud800b\udc00c\udfff") + assert result == "a\ufffdb\ufffdc\ufffd" + + def test_all_surrogate_range(self): + """Verify the regex catches the full surrogate range.""" + for cp in [0xD800, 0xD900, 0xDA00, 0xDB00, 0xDC00, 0xDD00, 0xDE00, 0xDF00, 0xDFFF]: + text = f"test{chr(cp)}end" + result = _sanitize_surrogates(text) + assert '\ufffd' in result, f"Surrogate U+{cp:04X} not caught" + + def test_result_is_json_serializable(self): + """Sanitized text must survive json.dumps + utf-8 encoding.""" + dirty = "data \udce2\udcb0 from clipboard" + clean = _sanitize_surrogates(dirty) + serialized = json.dumps({"content": clean}, ensure_ascii=False) + # Must not raise UnicodeEncodeError + serialized.encode("utf-8") + + def test_original_surrogates_fail_encoding(self): + """Confirm the original bug: surrogates crash utf-8 encoding.""" + dirty = "data \udce2 from clipboard" + serialized = json.dumps({"content": dirty}, ensure_ascii=False) + with pytest.raises(UnicodeEncodeError): + serialized.encode("utf-8") + + +class TestSanitizeMessagesSurrogates: + """Test the _sanitize_messages_surrogates() helper for message lists.""" + + def test_clean_messages_returns_false(self): + msgs = [ + {"role": "user", "content": "all clean"}, + {"role": "assistant", "content": "me too"}, + ] + assert _sanitize_messages_surrogates(msgs) is False + + def test_dirty_string_content_sanitized(self): + msgs = [ + {"role": "user", "content": "text with \udce2 surrogate"}, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert "\ufffd" in msgs[0]["content"] + assert "\udce2" not in msgs[0]["content"] + + def test_dirty_multimodal_content_sanitized(self): + msgs = [ + {"role": "user", "content": [ + {"type": "text", "text": "multimodal \udce2 content"}, + {"type": "image_url", "image_url": {"url": "http://example.com"}}, + ]}, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert "\ufffd" in msgs[0]["content"][0]["text"] + assert "\udce2" not in msgs[0]["content"][0]["text"] + + def test_mixed_clean_and_dirty(self): + msgs = [ + {"role": "user", "content": "clean text"}, + {"role": "user", "content": "dirty \udce2 text"}, + {"role": "assistant", "content": "clean response"}, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert msgs[0]["content"] == "clean text" + assert "\ufffd" in msgs[1]["content"] + assert msgs[2]["content"] == "clean response" + + def test_non_dict_items_skipped(self): + msgs = ["not a dict", {"role": "user", "content": "ok"}] + assert _sanitize_messages_surrogates(msgs) is False + + def test_tool_messages_sanitized(self): + """Tool results could also contain surrogates from file reads etc.""" + msgs = [ + {"role": "tool", "content": "result with \udce2 data", "tool_call_id": "x"}, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert "\ufffd" in msgs[0]["content"] + + +class TestRunConversationSurrogateSanitization: + """Integration: verify run_conversation sanitizes user_message.""" + + @patch("run_agent.AIAgent._build_system_prompt") + @patch("run_agent.AIAgent._interruptible_streaming_api_call") + @patch("run_agent.AIAgent._interruptible_api_call") + def test_user_message_surrogates_sanitized(self, mock_api, mock_stream, mock_sys): + """Surrogates in user_message are stripped before API call.""" + from run_agent import AIAgent + + mock_sys.return_value = "system prompt" + + # Mock streaming to return a simple response + mock_choice = MagicMock() + mock_choice.message.content = "response" + mock_choice.message.tool_calls = None + mock_choice.message.refusal = None + mock_choice.finish_reason = "stop" + mock_choice.message.reasoning_content = None + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + mock_response.model = "test-model" + mock_response.id = "test-id" + + mock_stream.return_value = mock_response + mock_api.return_value = mock_response + + agent = AIAgent(model="test/model", quiet_mode=True, skip_memory=True, skip_context_files=True) + agent.client = MagicMock() + + # Pass a message with surrogates + result = agent.run_conversation( + user_message="test \udce2 message", + conversation_history=[], + ) + + # The message stored in history should have surrogates replaced + for msg in result.get("messages", []): + if msg.get("role") == "user": + assert "\udce2" not in msg["content"], "Surrogate leaked into stored message" + assert "\ufffd" in msg["content"], "Replacement char not in stored message" diff --git a/tests/tools/test_approval.py b/tests/tools/test_approval.py index b973cb0f0..abdda05fa 100644 --- a/tests/tools/test_approval.py +++ b/tests/tools/test_approval.py @@ -339,6 +339,16 @@ class TestTeePattern: assert dangerous is True assert key is not None + def test_tee_custom_hermes_home_env(self): + dangerous, key, desc = detect_dangerous_command("echo x | tee $HERMES_HOME/.env") + assert dangerous is True + assert key is not None + + def test_tee_quoted_custom_hermes_home_env(self): + dangerous, key, desc = detect_dangerous_command('echo x | tee "$HERMES_HOME/.env"') + assert dangerous is True + assert key is not None + def test_tee_tmp_safe(self): dangerous, key, desc = detect_dangerous_command("echo hello | tee /tmp/output.txt") assert dangerous is False @@ -374,6 +384,30 @@ class TestFindExecFullPathRm: assert key is None +class TestSensitiveRedirectPattern: + """Detect shell redirection writes to sensitive user-managed paths.""" + + def test_redirect_to_custom_hermes_home_env(self): + dangerous, key, desc = detect_dangerous_command("echo x > $HERMES_HOME/.env") + assert dangerous is True + assert key is not None + + def test_append_to_home_ssh_authorized_keys(self): + dangerous, key, desc = detect_dangerous_command("cat key >> $HOME/.ssh/authorized_keys") + assert dangerous is True + assert key is not None + + def test_append_to_tilde_ssh_authorized_keys(self): + dangerous, key, desc = detect_dangerous_command("cat key >> ~/.ssh/authorized_keys") + assert dangerous is True + assert key is not None + + def test_redirect_to_safe_tmp_file(self): + dangerous, key, desc = detect_dangerous_command("echo hello > /tmp/output.txt") + assert dangerous is False + assert key is None + + class TestPatternKeyUniqueness: """Bug: pattern_key is derived by splitting on \\b and taking [1], so patterns starting with the same word (e.g. find -exec rm and find -delete) @@ -512,6 +546,30 @@ class TestGatewayProtection: dangerous, key, desc = detect_dangerous_command(cmd) assert dangerous is False + def test_pkill_hermes_detected(self): + """pkill targeting hermes/gateway processes must be caught.""" + cmd = 'pkill -f "cli.py --gateway"' + dangerous, key, desc = detect_dangerous_command(cmd) + assert dangerous is True + assert "self-termination" in desc + + def test_killall_hermes_detected(self): + cmd = "killall hermes" + dangerous, key, desc = detect_dangerous_command(cmd) + assert dangerous is True + assert "self-termination" in desc + + def test_pkill_gateway_detected(self): + cmd = "pkill -f gateway" + dangerous, key, desc = detect_dangerous_command(cmd) + assert dangerous is True + + def test_pkill_unrelated_not_flagged(self): + """pkill targeting unrelated processes should not be flagged.""" + cmd = "pkill -f nginx" + dangerous, key, desc = detect_dangerous_command(cmd) + assert dangerous is False + class TestNormalizationBypass: """Obfuscation techniques must not bypass dangerous command detection.""" @@ -582,3 +640,4 @@ class TestNormalizationBypass: dangerous, key, desc = detect_dangerous_command(cmd) assert dangerous is False + diff --git a/tests/tools/test_browser_content_none_guard.py b/tests/tools/test_browser_content_none_guard.py new file mode 100644 index 000000000..6952bb938 --- /dev/null +++ b/tests/tools/test_browser_content_none_guard.py @@ -0,0 +1,109 @@ +"""Tests for None guard on browser_tool LLM response content. + +browser_tool.py has two call sites that access response.choices[0].message.content +without checking for None — _extract_relevant_content (line 996) and +browser_vision (line 1626). When reasoning-only models (DeepSeek-R1, QwQ) +return content=None, these produce null snapshots or null analysis. + +These tests verify both sites are guarded. +""" + +import types +from unittest.mock import MagicMock, patch + +import pytest + + +# ── helpers ──────────────────────────────────────────────────────────────── + +def _make_response(content): + """Build a minimal OpenAI-compatible ChatCompletion response stub.""" + message = types.SimpleNamespace(content=content) + choice = types.SimpleNamespace(message=message) + return types.SimpleNamespace(choices=[choice]) + + +# ── _extract_relevant_content (line 996) ────────────────────────────────── + +class TestExtractRelevantContentNoneGuard: + """tools/browser_tool.py — _extract_relevant_content()""" + + def test_none_content_falls_back_to_truncated(self): + """When LLM returns None content, should fall back to truncated snapshot.""" + with patch("tools.browser_tool.call_llm", return_value=_make_response(None)), \ + patch("tools.browser_tool._get_extraction_model", return_value="test-model"): + from tools.browser_tool import _extract_relevant_content + result = _extract_relevant_content("This is a long snapshot text", "find the button") + + assert result is not None + assert isinstance(result, str) + assert len(result) > 0 + + def test_normal_content_returned(self): + """Normal string content should pass through.""" + with patch("tools.browser_tool.call_llm", return_value=_make_response("Extracted content here")), \ + patch("tools.browser_tool._get_extraction_model", return_value="test-model"): + from tools.browser_tool import _extract_relevant_content + result = _extract_relevant_content("snapshot text", "task") + + assert result == "Extracted content here" + + def test_empty_string_content_falls_back(self): + """Empty string content should also fall back to truncated.""" + with patch("tools.browser_tool.call_llm", return_value=_make_response(" ")), \ + patch("tools.browser_tool._get_extraction_model", return_value="test-model"): + from tools.browser_tool import _extract_relevant_content + result = _extract_relevant_content("This is a long snapshot text", "task") + + assert result is not None + assert len(result) > 0 + + +# ── browser_vision (line 1626) ──────────────────────────────────────────── + +class TestBrowserVisionNoneGuard: + """tools/browser_tool.py — browser_vision() analysis extraction""" + + def test_none_content_produces_fallback_message(self): + """When LLM returns None content, analysis should have a fallback message.""" + response = _make_response(None) + analysis = (response.choices[0].message.content or "").strip() + fallback = analysis or "Vision analysis returned no content." + + assert fallback == "Vision analysis returned no content." + + def test_normal_content_passes_through(self): + """Normal analysis content should pass through unchanged.""" + response = _make_response(" The page shows a login form. ") + analysis = (response.choices[0].message.content or "").strip() + fallback = analysis or "Vision analysis returned no content." + + assert fallback == "The page shows a login form." + + +# ── source line verification ────────────────────────────────────────────── + +class TestBrowserSourceLinesAreGuarded: + """Verify the actual source file has the fix applied.""" + + @staticmethod + def _read_file() -> str: + import os + base = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + with open(os.path.join(base, "tools", "browser_tool.py")) as f: + return f.read() + + def test_extract_relevant_content_guarded(self): + src = self._read_file() + # The old unguarded pattern should NOT exist + assert "return response.choices[0].message.content\n" not in src, ( + "browser_tool.py _extract_relevant_content still has unguarded " + ".content return — apply None guard" + ) + + def test_browser_vision_guarded(self): + src = self._read_file() + assert "analysis = response.choices[0].message.content\n" not in src, ( + "browser_tool.py browser_vision still has unguarded " + ".content assignment — apply None guard" + ) diff --git a/tests/tools/test_command_guards.py b/tests/tools/test_command_guards.py index c890a2c6f..a4b43147f 100644 --- a/tests/tools/test_command_guards.py +++ b/tests/tools/test_command_guards.py @@ -95,23 +95,49 @@ class TestTirithAllowSafeCommand: # --------------------------------------------------------------------------- class TestTirithBlock: + """Tirith 'block' is now treated as an approvable warning (not a hard block). + + Users are prompted with the tirith findings and can approve if they + understand the risk. The prompt defaults to deny, so if no input is + provided the command is still blocked — but through the approval flow, + not a hard block bypass. + """ + @patch(_TIRITH_PATCH, return_value=_tirith_result("block", summary="homograph detected")) - def test_tirith_block_safe_command(self, mock_tirith): + def test_tirith_block_prompts_user(self, mock_tirith): + """tirith block goes through approval flow (user gets prompted).""" os.environ["HERMES_INTERACTIVE"] = "1" result = check_all_command_guards("curl http://gооgle.com", "local") + # Default is deny (no input → timeout → deny), so still blocked assert result["approved"] is False - assert "BLOCKED" in result["message"] - assert "homograph" in result["message"] + # But through the approval flow, not a hard block — message says + # "User denied" rather than "Command blocked by security scan" + assert "denied" in result["message"].lower() or "BLOCKED" in result["message"] @patch(_TIRITH_PATCH, return_value=_tirith_result("block", summary="terminal injection")) - def test_tirith_block_plus_dangerous(self, mock_tirith): - """tirith block takes precedence even if command is also dangerous.""" + def test_tirith_block_plus_dangerous_prompts_combined(self, mock_tirith): + """tirith block + dangerous pattern → combined approval prompt.""" os.environ["HERMES_INTERACTIVE"] = "1" result = check_all_command_guards("rm -rf / | curl http://evil", "local") assert result["approved"] is False - assert "BLOCKED" in result["message"] + + @patch(_TIRITH_PATCH, + return_value=_tirith_result("block", + findings=[{"rule_id": "curl_pipe_shell", + "severity": "HIGH", + "title": "Pipe to interpreter", + "description": "Downloaded content executed without inspection"}], + summary="pipe to shell")) + def test_tirith_block_gateway_returns_approval_required(self, mock_tirith): + """In gateway mode, tirith block should return approval_required.""" + os.environ["HERMES_GATEWAY_SESSION"] = "1" + result = check_all_command_guards("curl -fsSL https://x.dev/install.sh | sh", "local") + assert result["approved"] is False + assert result.get("status") == "approval_required" + # Findings should be included in the description + assert "Pipe to interpreter" in result.get("description", "") or "pipe" in result.get("message", "").lower() # --------------------------------------------------------------------------- diff --git a/tests/tools/test_config_null_guard.py b/tests/tools/test_config_null_guard.py new file mode 100644 index 000000000..a6ab64009 --- /dev/null +++ b/tests/tools/test_config_null_guard.py @@ -0,0 +1,111 @@ +"""Tests for config.get() null-coalescing in tool configuration. + +YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)`` +return ``None`` instead of the default — calling ``.lower()`` on that raises +``AttributeError``. These tests verify the ``or`` coalescing guards. +""" + +from unittest.mock import patch +import pytest + + +# ── TTS tool ────────────────────────────────────────────────────────────── + +class TestTTSProviderNullGuard: + """tools/tts_tool.py — _get_provider()""" + + def test_explicit_null_provider_returns_default(self): + """YAML ``tts: {provider: null}`` should fall back to default.""" + from tools.tts_tool import _get_provider, DEFAULT_PROVIDER + + result = _get_provider({"provider": None}) + assert result == DEFAULT_PROVIDER.lower().strip() + + def test_missing_provider_returns_default(self): + """No ``provider`` key at all should also return default.""" + from tools.tts_tool import _get_provider, DEFAULT_PROVIDER + + result = _get_provider({}) + assert result == DEFAULT_PROVIDER.lower().strip() + + def test_valid_provider_passed_through(self): + from tools.tts_tool import _get_provider + + result = _get_provider({"provider": "OPENAI"}) + assert result == "openai" + + +# ── Web tools ───────────────────────────────────────────────────────────── + +class TestWebBackendNullGuard: + """tools/web_tools.py — _get_backend()""" + + @patch("tools.web_tools._load_web_config", return_value={"backend": None}) + def test_explicit_null_backend_does_not_crash(self, _cfg): + """YAML ``web: {backend: null}`` should not raise AttributeError.""" + from tools.web_tools import _get_backend + + # Should not raise — the exact return depends on env key fallback + result = _get_backend() + assert isinstance(result, str) + + @patch("tools.web_tools._load_web_config", return_value={}) + def test_missing_backend_does_not_crash(self, _cfg): + from tools.web_tools import _get_backend + + result = _get_backend() + assert isinstance(result, str) + + +# ── MCP tool ────────────────────────────────────────────────────────────── + +class TestMCPAuthNullGuard: + """tools/mcp_tool.py — MCPServerTask.__init__() auth config line""" + + def test_explicit_null_auth_does_not_crash(self): + """YAML ``auth: null`` in MCP server config should not raise.""" + # Test the expression directly — MCPServerTask.__init__ has many deps + config = {"auth": None, "timeout": 30} + auth_type = (config.get("auth") or "").lower().strip() + assert auth_type == "" + + def test_missing_auth_defaults_to_empty(self): + config = {"timeout": 30} + auth_type = (config.get("auth") or "").lower().strip() + assert auth_type == "" + + def test_valid_auth_passed_through(self): + config = {"auth": "OAUTH", "timeout": 30} + auth_type = (config.get("auth") or "").lower().strip() + assert auth_type == "oauth" + + +# ── Trajectory compressor ───────────────────────────────────────────────── + +class TestTrajectoryCompressorNullGuard: + """trajectory_compressor.py — _detect_provider() and config loading""" + + def test_null_base_url_does_not_crash(self): + """base_url=None should not crash _detect_provider().""" + from trajectory_compressor import CompressionConfig, TrajectoryCompressor + + config = CompressionConfig() + config.base_url = None + + compressor = TrajectoryCompressor.__new__(TrajectoryCompressor) + compressor.config = config + + # Should not raise AttributeError; returns empty string (no match) + result = compressor._detect_provider() + assert result == "" + + def test_config_loading_null_base_url_keeps_default(self): + """YAML ``summarization: {base_url: null}`` should keep default.""" + from trajectory_compressor import CompressionConfig + from hermes_constants import OPENROUTER_BASE_URL + + config = CompressionConfig() + data = {"summarization": {"base_url": None}} + + config.base_url = data["summarization"].get("base_url") or config.base_url + assert config.base_url == OPENROUTER_BASE_URL diff --git a/tests/tools/test_credential_files.py b/tests/tools/test_credential_files.py new file mode 100644 index 000000000..293e2c6da --- /dev/null +++ b/tests/tools/test_credential_files.py @@ -0,0 +1,158 @@ +"""Tests for credential file passthrough registry (tools/credential_files.py).""" + +import os +from pathlib import Path + +import pytest + +from tools.credential_files import ( + clear_credential_files, + get_credential_file_mounts, + register_credential_file, + register_credential_files, + reset_config_cache, +) + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Reset registry between tests.""" + clear_credential_files() + reset_config_cache() + yield + clear_credential_files() + reset_config_cache() + + +class TestRegisterCredentialFile: + def test_registers_existing_file(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "token.json").write_text('{"token": "abc"}') + + result = register_credential_file("token.json") + + assert result is True + mounts = get_credential_file_mounts() + assert len(mounts) == 1 + assert mounts[0]["host_path"] == str(tmp_path / "token.json") + assert mounts[0]["container_path"] == "/root/.hermes/token.json" + + def test_skips_missing_file(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + result = register_credential_file("nonexistent.json") + + assert result is False + assert get_credential_file_mounts() == [] + + def test_custom_container_base(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "cred.json").write_text("{}") + + register_credential_file("cred.json", container_base="/home/user/.hermes") + + mounts = get_credential_file_mounts() + assert mounts[0]["container_path"] == "/home/user/.hermes/cred.json" + + def test_deduplicates_by_container_path(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "token.json").write_text("{}") + + register_credential_file("token.json") + register_credential_file("token.json") + + mounts = get_credential_file_mounts() + assert len(mounts) == 1 + + +class TestRegisterCredentialFiles: + def test_string_entries(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "a.json").write_text("{}") + (tmp_path / "b.json").write_text("{}") + + missing = register_credential_files(["a.json", "b.json"]) + + assert missing == [] + assert len(get_credential_file_mounts()) == 2 + + def test_dict_entries(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "token.json").write_text("{}") + + missing = register_credential_files([ + {"path": "token.json", "description": "OAuth token"}, + ]) + + assert missing == [] + assert len(get_credential_file_mounts()) == 1 + + def test_returns_missing_files(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "exists.json").write_text("{}") + + missing = register_credential_files([ + "exists.json", + "missing.json", + {"path": "also_missing.json"}, + ]) + + assert missing == ["missing.json", "also_missing.json"] + assert len(get_credential_file_mounts()) == 1 + + def test_empty_list(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + assert register_credential_files([]) == [] + + +class TestConfigCredentialFiles: + def test_loads_from_config(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "oauth.json").write_text("{}") + (tmp_path / "config.yaml").write_text( + "terminal:\n credential_files:\n - oauth.json\n" + ) + + mounts = get_credential_file_mounts() + + assert len(mounts) == 1 + assert mounts[0]["host_path"] == str(tmp_path / "oauth.json") + + def test_config_skips_missing_files(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "config.yaml").write_text( + "terminal:\n credential_files:\n - nonexistent.json\n" + ) + + mounts = get_credential_file_mounts() + assert mounts == [] + + def test_combines_skill_and_config(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + (tmp_path / "skill_token.json").write_text("{}") + (tmp_path / "config_token.json").write_text("{}") + (tmp_path / "config.yaml").write_text( + "terminal:\n credential_files:\n - config_token.json\n" + ) + + register_credential_file("skill_token.json") + mounts = get_credential_file_mounts() + + assert len(mounts) == 2 + paths = {m["container_path"] for m in mounts} + assert "/root/.hermes/skill_token.json" in paths + assert "/root/.hermes/config_token.json" in paths + + +class TestGetMountsRechecksExistence: + def test_removed_file_excluded_from_mounts(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + token = tmp_path / "token.json" + token.write_text("{}") + + register_credential_file("token.json") + assert len(get_credential_file_mounts()) == 1 + + # Delete the file after registration + token.unlink() + assert get_credential_file_mounts() == [] diff --git a/tests/tools/test_honcho_tools.py b/tests/tools/test_honcho_tools.py index 16e144541..0651eb52c 100644 --- a/tests/tools/test_honcho_tools.py +++ b/tests/tools/test_honcho_tools.py @@ -1,11 +1,86 @@ """Regression tests for per-call Honcho tool session routing.""" import json -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +from dataclasses import dataclass from tools import honcho_tools +class TestCheckHonchoAvailable: + """Tests for _check_honcho_available (banner + runtime gating).""" + + def setup_method(self): + self.orig_manager = honcho_tools._session_manager + self.orig_key = honcho_tools._session_key + + def teardown_method(self): + honcho_tools._session_manager = self.orig_manager + honcho_tools._session_key = self.orig_key + + def test_returns_true_when_session_active(self): + """Fast path: session context already injected (mid-conversation).""" + honcho_tools._session_manager = MagicMock() + honcho_tools._session_key = "test-key" + assert honcho_tools._check_honcho_available() is True + + def test_returns_true_when_configured_but_no_session(self): + """Slow path: honcho configured but agent not started yet (banner time).""" + honcho_tools._session_manager = None + honcho_tools._session_key = None + + @dataclass + class FakeConfig: + enabled: bool = True + api_key: str = "test-key" + base_url: str = None + + with patch("tools.honcho_tools.HonchoClientConfig", create=True): + with patch( + "honcho_integration.client.HonchoClientConfig" + ) as mock_cls: + mock_cls.from_global_config.return_value = FakeConfig() + assert honcho_tools._check_honcho_available() is True + + def test_returns_false_when_not_configured(self): + """No session, no config: tool genuinely unavailable.""" + honcho_tools._session_manager = None + honcho_tools._session_key = None + + @dataclass + class FakeConfig: + enabled: bool = False + api_key: str = None + base_url: str = None + + with patch( + "honcho_integration.client.HonchoClientConfig" + ) as mock_cls: + mock_cls.from_global_config.return_value = FakeConfig() + assert honcho_tools._check_honcho_available() is False + + def test_returns_false_when_import_fails(self): + """Graceful fallback when honcho_integration not installed.""" + import sys + + honcho_tools._session_manager = None + honcho_tools._session_key = None + + # Hide honcho_integration from the import system to simulate + # an environment where the package is not installed. + hidden = { + k: sys.modules.pop(k) + for k in list(sys.modules) + if k.startswith("honcho_integration") + } + try: + with patch.dict(sys.modules, {"honcho_integration": None, + "honcho_integration.client": None}): + assert honcho_tools._check_honcho_available() is False + finally: + sys.modules.update(hidden) + + class TestHonchoToolSessionContext: def setup_method(self): self.orig_manager = honcho_tools._session_manager diff --git a/tests/tools/test_llm_content_none_guard.py b/tests/tools/test_llm_content_none_guard.py new file mode 100644 index 000000000..b0adea8c7 --- /dev/null +++ b/tests/tools/test_llm_content_none_guard.py @@ -0,0 +1,294 @@ +"""Tests for None guard on response.choices[0].message.content.strip(). + +OpenAI-compatible APIs return ``message.content = None`` when the model +responds with tool calls only or reasoning-only output (e.g. DeepSeek-R1, +Qwen-QwQ via OpenRouter with ``reasoning.enabled = True``). Calling +``.strip()`` on ``None`` raises ``AttributeError``. + +These tests verify that every call site handles ``content is None`` safely, +and that ``extract_content_or_reasoning()`` falls back to structured +reasoning fields when content is empty. +""" + +import asyncio +import types +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agent.auxiliary_client import extract_content_or_reasoning + + +# ── helpers ──────────────────────────────────────────────────────────────── + +def _make_response(content, **msg_attrs): + """Build a minimal OpenAI-compatible ChatCompletion response stub. + + Extra keyword args are set as attributes on the message object + (e.g. reasoning="...", reasoning_content="...", reasoning_details=[...]). + """ + message = types.SimpleNamespace(content=content, tool_calls=None, **msg_attrs) + choice = types.SimpleNamespace(message=message) + return types.SimpleNamespace(choices=[choice]) + + +def _run(coro): + """Run an async coroutine synchronously.""" + return asyncio.get_event_loop().run_until_complete(coro) + + +# ── mixture_of_agents_tool — reference model (line 146) ─────────────────── + +class TestMoAReferenceModelContentNone: + """tools/mixture_of_agents_tool.py — _query_model()""" + + def test_none_content_raises_before_fix(self): + """Demonstrate that None content from a reasoning model crashes.""" + response = _make_response(None) + + # Simulate the exact line: response.choices[0].message.content.strip() + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + """The ``or ""`` guard should convert None to empty string.""" + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + def test_normal_content_unaffected(self): + """Regular string content should pass through unchanged.""" + response = _make_response(" Hello world ") + + content = (response.choices[0].message.content or "").strip() + assert content == "Hello world" + + +# ── mixture_of_agents_tool — aggregator (line 214) ──────────────────────── + +class TestMoAAggregatorContentNone: + """tools/mixture_of_agents_tool.py — _run_aggregator()""" + + def test_none_content_raises_before_fix(self): + response = _make_response(None) + + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + +# ── web_tools — LLM content processor (line 419) ───────────────────────── + +class TestWebToolsProcessorContentNone: + """tools/web_tools.py — _process_with_llm() return line""" + + def test_none_content_raises_before_fix(self): + response = _make_response(None) + + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + +# ── web_tools — synthesis/summarization (line 538) ──────────────────────── + +class TestWebToolsSynthesisContentNone: + """tools/web_tools.py — synthesize_content() final_summary line""" + + def test_none_content_raises_before_fix(self): + response = _make_response(None) + + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + +# ── vision_tools (line 350) ─────────────────────────────────────────────── + +class TestVisionToolsContentNone: + """tools/vision_tools.py — analyze_image() analysis extraction""" + + def test_none_content_raises_before_fix(self): + response = _make_response(None) + + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + +# ── skills_guard (line 963) ─────────────────────────────────────────────── + +class TestSkillsGuardContentNone: + """tools/skills_guard.py — _llm_audit_skill() llm_text extraction""" + + def test_none_content_raises_before_fix(self): + response = _make_response(None) + + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + +# ── session_search_tool (line 164) ──────────────────────────────────────── + +class TestSessionSearchContentNone: + """tools/session_search_tool.py — _summarize_session() return line""" + + def test_none_content_raises_before_fix(self): + response = _make_response(None) + + with pytest.raises(AttributeError): + response.choices[0].message.content.strip() + + def test_none_content_safe_with_or_guard(self): + response = _make_response(None) + + content = (response.choices[0].message.content or "").strip() + assert content == "" + + +# ── integration: verify the actual source lines are guarded ─────────────── + +class TestSourceLinesAreGuarded: + """Read the actual source files and verify the fix is applied. + + These tests will FAIL before the fix (bare .content.strip()) and + PASS after ((.content or "").strip()). + """ + + @staticmethod + def _read_file(rel_path: str) -> str: + import os + base = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + with open(os.path.join(base, rel_path)) as f: + return f.read() + + def test_mixture_of_agents_reference_model_guarded(self): + src = self._read_file("tools/mixture_of_agents_tool.py") + # The unguarded pattern should NOT exist + assert ".message.content.strip()" not in src, ( + "tools/mixture_of_agents_tool.py still has unguarded " + ".content.strip() — apply `(... or \"\").strip()` guard" + ) + + def test_web_tools_guarded(self): + src = self._read_file("tools/web_tools.py") + assert ".message.content.strip()" not in src, ( + "tools/web_tools.py still has unguarded " + ".content.strip() — apply `(... or \"\").strip()` guard" + ) + + def test_vision_tools_guarded(self): + src = self._read_file("tools/vision_tools.py") + assert ".message.content.strip()" not in src, ( + "tools/vision_tools.py still has unguarded " + ".content.strip() — apply `(... or \"\").strip()` guard" + ) + + def test_skills_guard_guarded(self): + src = self._read_file("tools/skills_guard.py") + assert ".message.content.strip()" not in src, ( + "tools/skills_guard.py still has unguarded " + ".content.strip() — apply `(... or \"\").strip()` guard" + ) + + def test_session_search_tool_guarded(self): + src = self._read_file("tools/session_search_tool.py") + assert ".message.content.strip()" not in src, ( + "tools/session_search_tool.py still has unguarded " + ".content.strip() — apply `(... or \"\").strip()` guard" + ) + + +# ── extract_content_or_reasoning() ──────────────────────────────────────── + +class TestExtractContentOrReasoning: + """agent/auxiliary_client.py — extract_content_or_reasoning()""" + + def test_normal_content_returned(self): + response = _make_response(" Hello world ") + assert extract_content_or_reasoning(response) == "Hello world" + + def test_none_content_returns_empty(self): + response = _make_response(None) + assert extract_content_or_reasoning(response) == "" + + def test_empty_string_returns_empty(self): + response = _make_response("") + assert extract_content_or_reasoning(response) == "" + + def test_think_blocks_stripped_with_remaining_content(self): + response = _make_response("internal reasoningThe answer is 42.") + assert extract_content_or_reasoning(response) == "The answer is 42." + + def test_think_only_content_falls_back_to_reasoning_field(self): + """When content is only think blocks, fall back to structured reasoning.""" + response = _make_response( + "some reasoning", + reasoning="The actual reasoning output", + ) + assert extract_content_or_reasoning(response) == "The actual reasoning output" + + def test_none_content_with_reasoning_field(self): + """DeepSeek-R1 pattern: content=None, reasoning='...'""" + response = _make_response(None, reasoning="Step 1: analyze the problem...") + assert extract_content_or_reasoning(response) == "Step 1: analyze the problem..." + + def test_none_content_with_reasoning_content_field(self): + """Moonshot/Novita pattern: content=None, reasoning_content='...'""" + response = _make_response(None, reasoning_content="Let me think about this...") + assert extract_content_or_reasoning(response) == "Let me think about this..." + + def test_none_content_with_reasoning_details(self): + """OpenRouter unified format: reasoning_details=[{summary: ...}]""" + response = _make_response(None, reasoning_details=[ + {"type": "reasoning.summary", "summary": "The key insight is..."}, + ]) + assert extract_content_or_reasoning(response) == "The key insight is..." + + def test_reasoning_fields_not_duplicated(self): + """When reasoning and reasoning_content have the same value, don't duplicate.""" + response = _make_response(None, reasoning="same text", reasoning_content="same text") + assert extract_content_or_reasoning(response) == "same text" + + def test_multiple_reasoning_sources_combined(self): + """Different reasoning sources are joined with double newline.""" + response = _make_response( + None, + reasoning="First part", + reasoning_content="Second part", + ) + result = extract_content_or_reasoning(response) + assert "First part" in result + assert "Second part" in result + + def test_content_preferred_over_reasoning(self): + """When both content and reasoning exist, content wins.""" + response = _make_response("Actual answer", reasoning="Internal reasoning") + assert extract_content_or_reasoning(response) == "Actual answer" diff --git a/tests/tools/test_local_persistent.py b/tests/tools/test_local_persistent.py index b20cca5be..5b9ce2e23 100644 --- a/tests/tools/test_local_persistent.py +++ b/tests/tools/test_local_persistent.py @@ -63,6 +63,18 @@ class TestLocalOneShotRegression: assert r["output"].strip() == "" env.cleanup() + def test_oneshot_heredoc_does_not_leak_fence_wrapper(self): + """Heredoc closing line must not be merged with the fence wrapper tail.""" + env = LocalEnvironment(persistent=False) + cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF" + r = env.execute(cmd) + env.cleanup() + assert r["returncode"] == 0 + assert "heredoc body line" in r["output"] + assert "__hermes_rc" not in r["output"] + assert "printf '" not in r["output"] + assert "exit $" not in r["output"] + class TestLocalPersistent: @pytest.fixture diff --git a/tests/tools/test_managed_browserbase_and_modal.py b/tests/tools/test_managed_browserbase_and_modal.py index 085f19cfd..3c8bb1214 100644 --- a/tests/tools/test_managed_browserbase_and_modal.py +++ b/tests/tools/test_managed_browserbase_and_modal.py @@ -357,7 +357,7 @@ def test_terminal_tool_prefers_managed_modal_when_gateway_ready_and_no_direct_cr assert not direct_ctor.called -def test_terminal_tool_keeps_direct_modal_when_direct_credentials_exist(): +def test_terminal_tool_auto_mode_prefers_managed_modal_when_available(): _install_fake_tools_package() env = os.environ.copy() env.update({ @@ -385,7 +385,43 @@ def test_terminal_tool_keeps_direct_modal_when_direct_credentials_exist(): "container_persistent": True, "modal_mode": "auto", }, - task_id="task-modal-direct", + task_id="task-modal-auto", + ) + + assert result == "managed-modal-env" + assert managed_ctor.called + assert not direct_ctor.called + + +def test_terminal_tool_auto_mode_falls_back_to_direct_modal_when_managed_unavailable(): + _install_fake_tools_package() + env = os.environ.copy() + env.update({ + "MODAL_TOKEN_ID": "tok-id", + "MODAL_TOKEN_SECRET": "tok-secret", + }) + + with patch.dict(os.environ, env, clear=True): + terminal_tool = _load_tool_module("tools.terminal_tool", "terminal_tool.py") + + with ( + patch.object(terminal_tool, "is_managed_tool_gateway_ready", return_value=False), + patch.object(terminal_tool, "_ManagedModalEnvironment", return_value="managed-modal-env") as managed_ctor, + patch.object(terminal_tool, "_ModalEnvironment", return_value="direct-modal-env") as direct_ctor, + ): + result = terminal_tool._create_environment( + env_type="modal", + image="python:3.11", + cwd="/root", + timeout=60, + container_config={ + "container_cpu": 1, + "container_memory": 2048, + "container_disk": 1024, + "container_persistent": True, + "modal_mode": "auto", + }, + task_id="task-modal-direct-fallback", ) assert result == "direct-modal-env" diff --git a/tests/tools/test_mcp_dynamic_discovery.py b/tests/tools/test_mcp_dynamic_discovery.py new file mode 100644 index 000000000..c7c4ae86c --- /dev/null +++ b/tests/tools/test_mcp_dynamic_discovery.py @@ -0,0 +1,170 @@ +"""Tests for MCP dynamic tool discovery (notifications/tools/list_changed).""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.mcp_tool import MCPServerTask, _register_server_tools +from tools.registry import ToolRegistry + + +def _make_mcp_tool(name: str, desc: str = ""): + return SimpleNamespace(name=name, description=desc, inputSchema=None) + + +class TestRegisterServerTools: + """Tests for the extracted _register_server_tools helper.""" + + @pytest.fixture + def mock_registry(self): + return ToolRegistry() + + @pytest.fixture + def mock_toolsets(self): + return { + "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []}, + "custom-toolset": {"tools": [], "description": "Other", "includes": []}, + } + + def test_injects_hermes_toolsets(self, mock_registry, mock_toolsets): + """Tools are injected into hermes-* toolsets but not custom ones.""" + server = MCPServerTask("my_srv") + server._tools = [_make_mcp_tool("my_tool", "desc")] + server.session = MagicMock() + + with patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset"), \ + patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True): + + registered = _register_server_tools("my_srv", server, {}) + + assert "mcp_my_srv_my_tool" in registered + assert "mcp_my_srv_my_tool" in mock_registry.get_all_tool_names() + + # Injected into hermes-* toolsets + assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-cli"]["tools"] + assert "mcp_my_srv_my_tool" in mock_toolsets["hermes-telegram"]["tools"] + # NOT into non-hermes toolsets + assert "mcp_my_srv_my_tool" not in mock_toolsets["custom-toolset"]["tools"] + + +class TestRefreshTools: + """Tests for MCPServerTask._refresh_tools nuke-and-repave cycle.""" + + @pytest.fixture + def mock_registry(self): + return ToolRegistry() + + @pytest.fixture + def mock_toolsets(self): + return { + "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []}, + } + + @pytest.mark.asyncio + async def test_nuke_and_repave(self, mock_registry, mock_toolsets): + """Old tools are removed and new tools registered on refresh.""" + server = MCPServerTask("live_srv") + server._refresh_lock = asyncio.Lock() + server._config = {} + + # Seed initial state: one old tool registered + mock_registry.register( + name="mcp_live_srv_old_tool", toolset="mcp-live_srv", schema={}, + handler=lambda x: x, check_fn=lambda: True, is_async=False, + description="", emoji="", + ) + server._registered_tool_names = ["mcp_live_srv_old_tool"] + mock_toolsets["hermes-cli"]["tools"].append("mcp_live_srv_old_tool") + + # New tool list from server + new_tool = _make_mcp_tool("new_tool", "new behavior") + server.session = SimpleNamespace( + list_tools=AsyncMock( + return_value=SimpleNamespace(tools=[new_tool]) + ) + ) + + with patch("tools.registry.registry", mock_registry), \ + patch("toolsets.create_custom_toolset"), \ + patch.dict("toolsets.TOOLSETS", mock_toolsets, clear=True): + + await server._refresh_tools() + + # Old tool completely gone + assert "mcp_live_srv_old_tool" not in mock_registry.get_all_tool_names() + assert "mcp_live_srv_old_tool" not in mock_toolsets["hermes-cli"]["tools"] + + # New tool registered + assert "mcp_live_srv_new_tool" in mock_registry.get_all_tool_names() + assert "mcp_live_srv_new_tool" in mock_toolsets["hermes-cli"]["tools"] + assert server._registered_tool_names == ["mcp_live_srv_new_tool"] + + +class TestMessageHandler: + """Tests for MCPServerTask._make_message_handler dispatch.""" + + @pytest.mark.asyncio + async def test_dispatches_tool_list_changed(self): + from tools.mcp_tool import _MCP_NOTIFICATION_TYPES + if not _MCP_NOTIFICATION_TYPES: + pytest.skip("MCP SDK ToolListChangedNotification not available") + + from mcp.types import ServerNotification, ToolListChangedNotification + + server = MCPServerTask("notif_srv") + with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh: + handler = server._make_message_handler() + notification = ServerNotification( + root=ToolListChangedNotification(method="notifications/tools/list_changed") + ) + await handler(notification) + mock_refresh.assert_awaited_once() + + @pytest.mark.asyncio + async def test_ignores_exceptions_and_other_messages(self): + server = MCPServerTask("notif_srv") + with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh: + handler = server._make_message_handler() + # Exceptions should not trigger refresh + await handler(RuntimeError("connection dead")) + # Unknown message types should not trigger refresh + await handler({"jsonrpc": "2.0", "result": "ok"}) + mock_refresh.assert_not_awaited() + + +class TestDeregister: + """Tests for ToolRegistry.deregister.""" + + def test_removes_tool(self): + reg = ToolRegistry() + reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x) + assert "foo" in reg.get_all_tool_names() + reg.deregister("foo") + assert "foo" not in reg.get_all_tool_names() + + def test_cleans_up_toolset_check(self): + reg = ToolRegistry() + check = lambda: True # noqa: E731 + reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check) + assert reg.is_toolset_available("ts1") + reg.deregister("foo") + # Toolset check should be gone since no tools remain + assert "ts1" not in reg._toolset_checks + + def test_preserves_toolset_check_if_other_tools_remain(self): + reg = ToolRegistry() + check = lambda: True # noqa: E731 + reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check) + reg.register(name="bar", toolset="ts1", schema={}, handler=lambda x: x) + reg.deregister("foo") + # bar still in ts1, so check should remain + assert "ts1" in reg._toolset_checks + + def test_noop_for_unknown_tool(self): + reg = ToolRegistry() + reg.deregister("nonexistent") # Should not raise diff --git a/tests/tools/test_modal_sandbox_fixes.py b/tests/tools/test_modal_sandbox_fixes.py index 23dfa2f8f..7e3feb5cf 100644 --- a/tests/tools/test_modal_sandbox_fixes.py +++ b/tests/tools/test_modal_sandbox_fixes.py @@ -4,10 +4,9 @@ Covers the bugs discovered while setting up TBLite evaluation: 1. Tool resolution — terminal + file tools load correctly 2. CWD fix — host paths get replaced with /root for container backends 3. ephemeral_disk version check -4. Tilde ~ replaced with /root for container backends -5. ensurepip fix in Modal image builder -6. install_pipx stays True for swerex-remote -7. /home/ added to host prefix check +4. ensurepip fix in Modal image builder +5. No swe-rex dependency — uses native Modal SDK +6. /home/ added to host prefix check """ import os @@ -251,7 +250,7 @@ class TestModalEnvironmentDefaults: # ========================================================================= -# Test 7: ensurepip fix in patches.py +# Test 7: ensurepip fix in ModalEnvironment # ========================================================================= class TestEnsurepipFix: @@ -275,17 +274,24 @@ class TestEnsurepipFix: "to fix pip before Modal's bootstrap" ) - def test_modal_environment_uses_install_pipx(self): - """ModalEnvironment should pass install_pipx to ModalDeployment.""" + def test_modal_environment_uses_native_sdk(self): + """ModalEnvironment should use Modal SDK directly, not swe-rex.""" try: from tools.environments.modal import ModalEnvironment except ImportError: pytest.skip("tools.environments.modal not importable") import inspect - source = inspect.getsource(ModalEnvironment.__init__) - assert "install_pipx" in source, ( - "ModalEnvironment should pass install_pipx to ModalDeployment" + source = inspect.getsource(ModalEnvironment) + assert "swerex" not in source.lower(), ( + "ModalEnvironment should not depend on swe-rex; " + "use Modal SDK directly via Sandbox.create() + exec()" + ) + assert "Sandbox.create.aio" in source, ( + "ModalEnvironment should use async Modal Sandbox.create.aio()" + ) + assert "exec.aio" in source, ( + "ModalEnvironment should use Sandbox.exec.aio() for command execution" ) diff --git a/tests/tools/test_modal_snapshot_isolation.py b/tests/tools/test_modal_snapshot_isolation.py index 0b4f7fc56..e2daf6a31 100644 --- a/tests/tools/test_modal_snapshot_isolation.py +++ b/tests/tools/test_modal_snapshot_isolation.py @@ -4,6 +4,8 @@ import types from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path +import pytest + REPO_ROOT = Path(__file__).resolve().parents[2] TOOLS_DIR = REPO_ROOT / "tools" @@ -24,13 +26,32 @@ def _reset_modules(prefixes: tuple[str, ...]): sys.modules.pop(name, None) +@pytest.fixture(autouse=True) +def _restore_tool_modules(): + original_modules = { + name: module + for name, module in sys.modules.items() + if name == "tools" + or name.startswith("tools.") + or name == "hermes_cli" + or name.startswith("hermes_cli.") + or name == "modal" + or name.startswith("modal.") + } + try: + yield + finally: + _reset_modules(("tools", "hermes_cli", "modal")) + sys.modules.update(original_modules) + + def _install_modal_test_modules( tmp_path: Path, *, fail_on_snapshot_ids: set[str] | None = None, snapshot_id: str = "im-fresh", ): - _reset_modules(("tools", "hermes_cli", "swerex", "modal")) + _reset_modules(("tools", "hermes_cli", "modal")) hermes_cli = types.ModuleType("hermes_cli") hermes_cli.__path__ = [] # type: ignore[attr-defined] @@ -62,7 +83,7 @@ def _install_modal_test_modules( from_id_calls: list[str] = [] registry_calls: list[tuple[str, list[str] | None]] = [] - deployment_calls: list[dict] = [] + create_calls: list[dict] = [] class _FakeImage: @staticmethod @@ -75,53 +96,55 @@ def _install_modal_test_modules( registry_calls.append((image, setup_dockerfile_commands)) return {"kind": "registry", "image": image} - class _FakeRuntime: - async def execute(self, _command): - return types.SimpleNamespace(stdout="ok", exit_code=0) + async def _lookup_aio(_name: str, create_if_missing: bool = False): + return types.SimpleNamespace(name="hermes-agent", create_if_missing=create_if_missing) - class _FakeModalDeployment: - def __init__(self, **kwargs): - deployment_calls.append(dict(kwargs)) - self.image = kwargs["image"] - self.runtime = _FakeRuntime() + class _FakeSandboxInstance: + def __init__(self, image): + self.image = image async def _snapshot_aio(): return types.SimpleNamespace(object_id=snapshot_id) - self._sandbox = types.SimpleNamespace( - snapshot_filesystem=types.SimpleNamespace(aio=_snapshot_aio), - ) + async def _terminate_aio(): + return None - async def start(self): - image = self.image if isinstance(self.image, dict) else {} - image_id = image.get("image_id") - if fail_on_snapshot_ids and image_id in fail_on_snapshot_ids: - raise RuntimeError(f"cannot restore {image_id}") + self.snapshot_filesystem = types.SimpleNamespace(aio=_snapshot_aio) + self.terminate = types.SimpleNamespace(aio=_terminate_aio) - async def stop(self): - return None + async def _create_aio(*_args, image=None, app=None, timeout=None, **kwargs): + create_calls.append({ + "image": image, + "app": app, + "timeout": timeout, + **kwargs, + }) + image_id = image.get("image_id") if isinstance(image, dict) else None + if fail_on_snapshot_ids and image_id in fail_on_snapshot_ids: + raise RuntimeError(f"cannot restore {image_id}") + return _FakeSandboxInstance(image) - class _FakeRexCommand: - def __init__(self, **kwargs): - self.kwargs = kwargs + class _FakeMount: + @staticmethod + def from_local_file(host_path: str, remote_path: str): + return {"host_path": host_path, "remote_path": remote_path} - sys.modules["modal"] = types.SimpleNamespace(Image=_FakeImage) + class _FakeApp: + lookup = types.SimpleNamespace(aio=_lookup_aio) - swerex = types.ModuleType("swerex") - swerex.__path__ = [] # type: ignore[attr-defined] - sys.modules["swerex"] = swerex - swerex_deployment = types.ModuleType("swerex.deployment") - swerex_deployment.__path__ = [] # type: ignore[attr-defined] - sys.modules["swerex.deployment"] = swerex_deployment - sys.modules["swerex.deployment.modal"] = types.SimpleNamespace(ModalDeployment=_FakeModalDeployment) - swerex_runtime = types.ModuleType("swerex.runtime") - swerex_runtime.__path__ = [] # type: ignore[attr-defined] - sys.modules["swerex.runtime"] = swerex_runtime - sys.modules["swerex.runtime.abstract"] = types.SimpleNamespace(Command=_FakeRexCommand) + class _FakeSandbox: + create = types.SimpleNamespace(aio=_create_aio) + + sys.modules["modal"] = types.SimpleNamespace( + Image=_FakeImage, + App=_FakeApp, + Sandbox=_FakeSandbox, + Mount=_FakeMount, + ) return { "snapshot_store": hermes_home / "modal_snapshots.json", - "deployment_calls": deployment_calls, + "create_calls": create_calls, "from_id_calls": from_id_calls, "registry_calls": registry_calls, } @@ -138,7 +161,7 @@ def test_modal_environment_migrates_legacy_snapshot_key_and_uses_snapshot_id(tmp try: assert state["from_id_calls"] == ["im-legacy123"] - assert state["deployment_calls"][0]["image"] == {"kind": "snapshot", "image_id": "im-legacy123"} + assert state["create_calls"][0]["image"] == {"kind": "snapshot", "image_id": "im-legacy123"} assert json.loads(snapshot_store.read_text()) == {"direct:task-legacy": "im-legacy123"} finally: env.cleanup() @@ -154,7 +177,7 @@ def test_modal_environment_prunes_stale_direct_snapshot_and_retries_base_image(t env = modal_module.ModalEnvironment(image="python:3.11", task_id="task-stale") try: - assert [call["image"] for call in state["deployment_calls"]] == [ + assert [call["image"] for call in state["create_calls"]] == [ {"kind": "snapshot", "image_id": "im-stale123"}, {"kind": "registry", "image": "python:3.11"}, ] diff --git a/tests/tools/test_patch_parser.py b/tests/tools/test_patch_parser.py index 77baab8dd..42e5129f5 100644 --- a/tests/tools/test_patch_parser.py +++ b/tests/tools/test_patch_parser.py @@ -185,3 +185,71 @@ class TestApplyUpdate: ' result = 1\n' ' return result + 1' ) + + +class TestAdditionOnlyHunks: + """Regression tests for #3081 — addition-only hunks were silently dropped.""" + + def test_addition_only_hunk_with_context_hint(self): + """A hunk with only + lines should insert at the context hint location.""" + patch = """\ +*** Begin Patch +*** Update File: src/app.py +@@ def main @@ ++def helper(): ++ return 42 +*** End Patch""" + ops, err = parse_v4a_patch(patch) + assert err is None + assert len(ops) == 1 + assert len(ops[0].hunks) == 1 + + hunk = ops[0].hunks[0] + # All lines should be additions + assert all(l.prefix == '+' for l in hunk.lines) + + # Apply to a file that contains the context hint + class FakeFileOps: + written = None + def read_file(self, path, **kw): + return SimpleNamespace( + content="def main():\n pass\n", + error=None, + ) + def write_file(self, path, content): + self.written = content + return SimpleNamespace(error=None) + + file_ops = FakeFileOps() + result = apply_v4a_operations(ops, file_ops) + assert result.success is True + assert "def helper():" in file_ops.written + assert "return 42" in file_ops.written + + def test_addition_only_hunk_without_context_hint(self): + """A hunk with only + lines and no context hint appends at end of file.""" + patch = """\ +*** Begin Patch +*** Update File: src/app.py ++def new_func(): ++ return True +*** End Patch""" + ops, err = parse_v4a_patch(patch) + assert err is None + + class FakeFileOps: + written = None + def read_file(self, path, **kw): + return SimpleNamespace( + content="existing = True\n", + error=None, + ) + def write_file(self, path, content): + self.written = content + return SimpleNamespace(error=None) + + file_ops = FakeFileOps() + result = apply_v4a_operations(ops, file_ops) + assert result.success is True + assert file_ops.written.endswith("def new_func():\n return True\n") + assert "existing = True" in file_ops.written diff --git a/tests/tools/test_registry.py b/tests/tools/test_registry.py index eac4ab04f..455e9f48a 100644 --- a/tests/tools/test_registry.py +++ b/tests/tools/test_registry.py @@ -81,6 +81,33 @@ class TestGetDefinitions: assert len(defs) == 1 assert defs[0]["function"]["name"] == "available" + def test_reuses_shared_check_fn_once_per_call(self): + reg = ToolRegistry() + calls = {"count": 0} + + def shared_check(): + calls["count"] += 1 + return True + + reg.register( + name="first", + toolset="shared", + schema=_make_schema("first"), + handler=_dummy_handler, + check_fn=shared_check, + ) + reg.register( + name="second", + toolset="shared", + schema=_make_schema("second"), + handler=_dummy_handler, + check_fn=shared_check, + ) + + defs = reg.get_definitions({"first", "second"}) + assert len(defs) == 2 + assert calls["count"] == 1 + class TestUnknownToolDispatch: def test_returns_error_json(self): diff --git a/tests/tools/test_send_message_missing_platforms.py b/tests/tools/test_send_message_missing_platforms.py new file mode 100644 index 000000000..8943109e0 --- /dev/null +++ b/tests/tools/test_send_message_missing_platforms.py @@ -0,0 +1,334 @@ +"""Tests for _send_mattermost, _send_matrix, _send_homeassistant, _send_dingtalk.""" + +import asyncio +import os +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from tools.send_message_tool import ( + _send_dingtalk, + _send_homeassistant, + _send_mattermost, + _send_matrix, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_aiohttp_resp(status, json_data=None, text_data=None): + """Build a minimal async-context-manager mock for an aiohttp response.""" + resp = AsyncMock() + resp.status = status + resp.json = AsyncMock(return_value=json_data or {}) + resp.text = AsyncMock(return_value=text_data or "") + return resp + + +def _make_aiohttp_session(resp): + """Wrap a response mock in a session mock that supports async-with for post/put.""" + request_ctx = MagicMock() + request_ctx.__aenter__ = AsyncMock(return_value=resp) + request_ctx.__aexit__ = AsyncMock(return_value=False) + + session = MagicMock() + session.post = MagicMock(return_value=request_ctx) + session.put = MagicMock(return_value=request_ctx) + + session_ctx = MagicMock() + session_ctx.__aenter__ = AsyncMock(return_value=session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + return session_ctx, session + + +# --------------------------------------------------------------------------- +# _send_mattermost +# --------------------------------------------------------------------------- + + +class TestSendMattermost: + def test_success(self): + resp = _make_aiohttp_resp(201, json_data={"id": "post123"}) + session_ctx, session = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx), \ + patch.dict(os.environ, {"MATTERMOST_URL": "", "MATTERMOST_TOKEN": ""}, clear=False): + extra = {"url": "https://mm.example.com"} + result = asyncio.run(_send_mattermost("tok-abc", extra, "channel1", "hello")) + + assert result == {"success": True, "platform": "mattermost", "chat_id": "channel1", "message_id": "post123"} + session.post.assert_called_once() + call_kwargs = session.post.call_args + assert call_kwargs[0][0] == "https://mm.example.com/api/v4/posts" + assert call_kwargs[1]["headers"]["Authorization"] == "Bearer tok-abc" + assert call_kwargs[1]["json"] == {"channel_id": "channel1", "message": "hello"} + + def test_http_error(self): + resp = _make_aiohttp_resp(400, text_data="Bad Request") + session_ctx, _ = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx): + result = asyncio.run(_send_mattermost( + "tok", {"url": "https://mm.example.com"}, "ch", "hi" + )) + + assert "error" in result + assert "400" in result["error"] + assert "Bad Request" in result["error"] + + def test_missing_config(self): + with patch.dict(os.environ, {"MATTERMOST_URL": "", "MATTERMOST_TOKEN": ""}, clear=False): + result = asyncio.run(_send_mattermost("", {}, "ch", "hi")) + + assert "error" in result + assert "MATTERMOST_URL" in result["error"] or "not configured" in result["error"] + + def test_env_var_fallback(self): + resp = _make_aiohttp_resp(200, json_data={"id": "p99"}) + session_ctx, session = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx), \ + patch.dict(os.environ, {"MATTERMOST_URL": "https://mm.env.com", "MATTERMOST_TOKEN": "env-tok"}, clear=False): + result = asyncio.run(_send_mattermost("", {}, "ch", "hi")) + + assert result["success"] is True + call_kwargs = session.post.call_args + assert "https://mm.env.com" in call_kwargs[0][0] + assert call_kwargs[1]["headers"]["Authorization"] == "Bearer env-tok" + + +# --------------------------------------------------------------------------- +# _send_matrix +# --------------------------------------------------------------------------- + + +class TestSendMatrix: + def test_success(self): + resp = _make_aiohttp_resp(200, json_data={"event_id": "$abc123"}) + session_ctx, session = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx), \ + patch.dict(os.environ, {"MATRIX_HOMESERVER": "", "MATRIX_ACCESS_TOKEN": ""}, clear=False): + extra = {"homeserver": "https://matrix.example.com"} + result = asyncio.run(_send_matrix("syt_tok", extra, "!room:example.com", "hello matrix")) + + assert result == { + "success": True, + "platform": "matrix", + "chat_id": "!room:example.com", + "message_id": "$abc123", + } + session.put.assert_called_once() + call_kwargs = session.put.call_args + url = call_kwargs[0][0] + assert url.startswith("https://matrix.example.com/_matrix/client/v3/rooms/!room:example.com/send/m.room.message/") + assert call_kwargs[1]["headers"]["Authorization"] == "Bearer syt_tok" + assert call_kwargs[1]["json"] == {"msgtype": "m.text", "body": "hello matrix"} + + def test_http_error(self): + resp = _make_aiohttp_resp(403, text_data="Forbidden") + session_ctx, _ = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx): + result = asyncio.run(_send_matrix( + "tok", {"homeserver": "https://matrix.example.com"}, + "!room:example.com", "hi" + )) + + assert "error" in result + assert "403" in result["error"] + assert "Forbidden" in result["error"] + + def test_missing_config(self): + with patch.dict(os.environ, {"MATRIX_HOMESERVER": "", "MATRIX_ACCESS_TOKEN": ""}, clear=False): + result = asyncio.run(_send_matrix("", {}, "!room:example.com", "hi")) + + assert "error" in result + assert "MATRIX_HOMESERVER" in result["error"] or "not configured" in result["error"] + + def test_env_var_fallback(self): + resp = _make_aiohttp_resp(200, json_data={"event_id": "$ev1"}) + session_ctx, session = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx), \ + patch.dict(os.environ, { + "MATRIX_HOMESERVER": "https://matrix.env.com", + "MATRIX_ACCESS_TOKEN": "env-tok", + }, clear=False): + result = asyncio.run(_send_matrix("", {}, "!r:env.com", "hi")) + + assert result["success"] is True + url = session.put.call_args[0][0] + assert "matrix.env.com" in url + + def test_txn_id_is_unique_across_calls(self): + """Each call should generate a distinct transaction ID in the URL.""" + txn_ids = [] + + def capture(*args, **kwargs): + url = args[0] + txn_ids.append(url.rsplit("/", 1)[-1]) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=_make_aiohttp_resp(200, json_data={"event_id": "$x"})) + ctx.__aexit__ = AsyncMock(return_value=False) + return ctx + + session = MagicMock() + session.put = capture + session_ctx = MagicMock() + session_ctx.__aenter__ = AsyncMock(return_value=session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + + extra = {"homeserver": "https://matrix.example.com"} + + import time + with patch("aiohttp.ClientSession", return_value=session_ctx): + asyncio.run(_send_matrix("tok", extra, "!r:example.com", "first")) + time.sleep(0.002) + with patch("aiohttp.ClientSession", return_value=session_ctx): + asyncio.run(_send_matrix("tok", extra, "!r:example.com", "second")) + + assert len(txn_ids) == 2 + assert txn_ids[0] != txn_ids[1] + + +# --------------------------------------------------------------------------- +# _send_homeassistant +# --------------------------------------------------------------------------- + + +class TestSendHomeAssistant: + def test_success(self): + resp = _make_aiohttp_resp(200) + session_ctx, session = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx), \ + patch.dict(os.environ, {"HASS_URL": "", "HASS_TOKEN": ""}, clear=False): + extra = {"url": "https://hass.example.com"} + result = asyncio.run(_send_homeassistant("hass-tok", extra, "mobile_app_phone", "alert!")) + + assert result == {"success": True, "platform": "homeassistant", "chat_id": "mobile_app_phone"} + session.post.assert_called_once() + call_kwargs = session.post.call_args + assert call_kwargs[0][0] == "https://hass.example.com/api/services/notify/notify" + assert call_kwargs[1]["headers"]["Authorization"] == "Bearer hass-tok" + assert call_kwargs[1]["json"] == {"message": "alert!", "target": "mobile_app_phone"} + + def test_http_error(self): + resp = _make_aiohttp_resp(401, text_data="Unauthorized") + session_ctx, _ = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx): + result = asyncio.run(_send_homeassistant( + "bad-tok", {"url": "https://hass.example.com"}, + "target", "msg" + )) + + assert "error" in result + assert "401" in result["error"] + assert "Unauthorized" in result["error"] + + def test_missing_config(self): + with patch.dict(os.environ, {"HASS_URL": "", "HASS_TOKEN": ""}, clear=False): + result = asyncio.run(_send_homeassistant("", {}, "target", "msg")) + + assert "error" in result + assert "HASS_URL" in result["error"] or "not configured" in result["error"] + + def test_env_var_fallback(self): + resp = _make_aiohttp_resp(200) + session_ctx, session = _make_aiohttp_session(resp) + + with patch("aiohttp.ClientSession", return_value=session_ctx), \ + patch.dict(os.environ, {"HASS_URL": "https://hass.env.com", "HASS_TOKEN": "env-tok"}, clear=False): + result = asyncio.run(_send_homeassistant("", {}, "notify_target", "hi")) + + assert result["success"] is True + url = session.post.call_args[0][0] + assert "hass.env.com" in url + + +# --------------------------------------------------------------------------- +# _send_dingtalk +# --------------------------------------------------------------------------- + + +class TestSendDingtalk: + def _make_httpx_resp(self, status_code=200, json_data=None): + resp = MagicMock() + resp.status_code = status_code + resp.json = MagicMock(return_value=json_data or {"errcode": 0, "errmsg": "ok"}) + resp.raise_for_status = MagicMock() + return resp + + def _make_httpx_client(self, resp): + client = AsyncMock() + client.post = AsyncMock(return_value=resp) + client_ctx = MagicMock() + client_ctx.__aenter__ = AsyncMock(return_value=client) + client_ctx.__aexit__ = AsyncMock(return_value=False) + return client_ctx, client + + def test_success(self): + resp = self._make_httpx_resp(json_data={"errcode": 0, "errmsg": "ok"}) + client_ctx, client = self._make_httpx_client(resp) + + with patch("httpx.AsyncClient", return_value=client_ctx): + extra = {"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=abc"} + result = asyncio.run(_send_dingtalk(extra, "ignored", "hello dingtalk")) + + assert result == {"success": True, "platform": "dingtalk", "chat_id": "ignored"} + client.post.assert_awaited_once() + call_kwargs = client.post.await_args + assert call_kwargs[0][0] == "https://oapi.dingtalk.com/robot/send?access_token=abc" + assert call_kwargs[1]["json"] == {"msgtype": "text", "text": {"content": "hello dingtalk"}} + + def test_api_error_in_response_body(self): + """DingTalk always returns HTTP 200 but signals errors via errcode.""" + resp = self._make_httpx_resp(json_data={"errcode": 310000, "errmsg": "sign not match"}) + client_ctx, _ = self._make_httpx_client(resp) + + with patch("httpx.AsyncClient", return_value=client_ctx): + result = asyncio.run(_send_dingtalk( + {"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=bad"}, + "ch", "hi" + )) + + assert "error" in result + assert "sign not match" in result["error"] + + def test_http_error(self): + """If raise_for_status throws, the error is caught and returned.""" + resp = self._make_httpx_resp(status_code=429) + resp.raise_for_status = MagicMock(side_effect=Exception("429 Too Many Requests")) + client_ctx, _ = self._make_httpx_client(resp) + + with patch("httpx.AsyncClient", return_value=client_ctx): + result = asyncio.run(_send_dingtalk( + {"webhook_url": "https://oapi.dingtalk.com/robot/send?access_token=tok"}, + "ch", "hi" + )) + + assert "error" in result + assert "DingTalk send failed" in result["error"] + + def test_missing_config(self): + with patch.dict(os.environ, {"DINGTALK_WEBHOOK_URL": ""}, clear=False): + result = asyncio.run(_send_dingtalk({}, "ch", "hi")) + + assert "error" in result + assert "DINGTALK_WEBHOOK_URL" in result["error"] or "not configured" in result["error"] + + def test_env_var_fallback(self): + resp = self._make_httpx_resp(json_data={"errcode": 0, "errmsg": "ok"}) + client_ctx, client = self._make_httpx_client(resp) + + with patch("httpx.AsyncClient", return_value=client_ctx), \ + patch.dict(os.environ, {"DINGTALK_WEBHOOK_URL": "https://oapi.dingtalk.com/robot/send?access_token=env"}, clear=False): + result = asyncio.run(_send_dingtalk({}, "ch", "hi")) + + assert result["success"] is True + call_kwargs = client.post.await_args + assert "access_token=env" in call_kwargs[0][0] diff --git a/tests/tools/test_skill_env_passthrough.py b/tests/tools/test_skill_env_passthrough.py index 19662f984..19737d2ee 100644 --- a/tests/tools/test_skill_env_passthrough.py +++ b/tests/tools/test_skill_env_passthrough.py @@ -63,6 +63,35 @@ class TestSkillViewRegistersPassthrough: assert result["success"] is True assert is_env_passthrough("TENOR_API_KEY") + def test_remote_backend_persisted_env_vars_registered(self, tmp_path, monkeypatch): + """Remote-backed skills still register locally available env vars.""" + monkeypatch.setenv("TERMINAL_ENV", "docker") + _create_skill( + tmp_path, + "test-skill", + frontmatter_extra=( + "required_environment_variables:\n" + " - name: TENOR_API_KEY\n" + " prompt: Enter your Tenor API key\n" + ), + ) + monkeypatch.setattr("tools.skills_tool.SKILLS_DIR", tmp_path) + + from hermes_cli.config import save_env_value + + save_env_value("TENOR_API_KEY", "persisted-value-123") + monkeypatch.delenv("TENOR_API_KEY", raising=False) + + with patch("tools.skills_tool._secret_capture_callback", None): + from tools.skills_tool import skill_view + + result = json.loads(skill_view(name="test-skill")) + + assert result["success"] is True + assert result["setup_needed"] is False + assert result["missing_required_environment_variables"] == [] + assert is_env_passthrough("TENOR_API_KEY") + def test_missing_env_vars_not_registered(self, tmp_path, monkeypatch): """When a skill declares required_environment_variables but the var is NOT set, it should NOT be registered in the passthrough.""" diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py index bd992ec3d..06a2f88ae 100644 --- a/tests/tools/test_skill_manager_tool.py +++ b/tests/tools/test_skill_manager_tool.py @@ -6,6 +6,7 @@ from unittest.mock import patch from tools.skill_manager_tool import ( _validate_name, + _validate_category, _validate_frontmatter, _validate_file_path, _find_skill, @@ -82,6 +83,22 @@ class TestValidateName: assert "Invalid skill name 'skill@name'" in err +class TestValidateCategory: + def test_valid_categories(self): + assert _validate_category(None) is None + assert _validate_category("") is None + assert _validate_category("devops") is None + assert _validate_category("mlops-v2") is None + + def test_path_traversal_rejected(self): + err = _validate_category("../escape") + assert "Invalid category '../escape'" in err + + def test_absolute_path_rejected(self): + err = _validate_category("/tmp/escape") + assert "Invalid category '/tmp/escape'" in err + + # --------------------------------------------------------------------------- # _validate_frontmatter # --------------------------------------------------------------------------- @@ -191,6 +208,29 @@ class TestCreateSkill: result = _create_skill("my-skill", "no frontmatter here") assert result["success"] is False + def test_create_rejects_category_traversal(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + + with patch("tools.skill_manager_tool.SKILLS_DIR", skills_dir): + result = _create_skill("my-skill", VALID_SKILL_CONTENT, category="../escape") + + assert result["success"] is False + assert "Invalid category '../escape'" in result["error"] + assert not (tmp_path / "escape").exists() + + def test_create_rejects_absolute_category(self, tmp_path): + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + outside = tmp_path / "outside" + + with patch("tools.skill_manager_tool.SKILLS_DIR", skills_dir): + result = _create_skill("my-skill", VALID_SKILL_CONTENT, category=str(outside)) + + assert result["success"] is False + assert f"Invalid category '{outside}'" in result["error"] + assert not (outside / "my-skill" / "SKILL.md").exists() + class TestEditSkill: def test_edit_existing_skill(self, tmp_path): diff --git a/tests/tools/test_skills_tool.py b/tests/tools/test_skills_tool.py index 6af2c83cb..82d8b0dd1 100644 --- a/tests/tools/test_skills_tool.py +++ b/tests/tools/test_skills_tool.py @@ -589,38 +589,38 @@ class TestSkillMatchesPlatform: assert skill_matches_platform({"platforms": None}) is True def test_macos_on_darwin(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "darwin" assert skill_matches_platform({"platforms": ["macos"]}) is True def test_macos_on_linux(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "linux" assert skill_matches_platform({"platforms": ["macos"]}) is False def test_linux_on_linux(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "linux" assert skill_matches_platform({"platforms": ["linux"]}) is True def test_linux_on_darwin(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "darwin" assert skill_matches_platform({"platforms": ["linux"]}) is False def test_windows_on_win32(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "win32" assert skill_matches_platform({"platforms": ["windows"]}) is True def test_windows_on_linux(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "linux" assert skill_matches_platform({"platforms": ["windows"]}) is False def test_multi_platform_match(self): """Skills listing multiple platforms should match any of them.""" - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "darwin" assert skill_matches_platform({"platforms": ["macos", "linux"]}) is True mock_sys.platform = "linux" @@ -630,20 +630,20 @@ class TestSkillMatchesPlatform: def test_string_instead_of_list(self): """A single string value should be treated as a one-element list.""" - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "darwin" assert skill_matches_platform({"platforms": "macos"}) is True mock_sys.platform = "linux" assert skill_matches_platform({"platforms": "macos"}) is False def test_case_insensitive(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "darwin" assert skill_matches_platform({"platforms": ["MacOS"]}) is True assert skill_matches_platform({"platforms": ["MACOS"]}) is True def test_unknown_platform_no_match(self): - with patch("tools.skills_tool.sys") as mock_sys: + with patch("agent.skill_utils.sys") as mock_sys: mock_sys.platform = "linux" assert skill_matches_platform({"platforms": ["freebsd"]}) is False @@ -659,7 +659,7 @@ class TestFindAllSkillsPlatformFiltering: def test_excludes_incompatible_platform(self, tmp_path): with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): mock_sys.platform = "linux" _make_skill(tmp_path, "universal-skill") @@ -672,7 +672,7 @@ class TestFindAllSkillsPlatformFiltering: def test_includes_matching_platform(self, tmp_path): with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): mock_sys.platform = "darwin" _make_skill(tmp_path, "mac-only", frontmatter_extra="platforms: [macos]\n") @@ -684,7 +684,7 @@ class TestFindAllSkillsPlatformFiltering: """Skills without platforms field should appear on any platform.""" with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): mock_sys.platform = "win32" _make_skill(tmp_path, "generic-skill") @@ -695,7 +695,7 @@ class TestFindAllSkillsPlatformFiltering: def test_multi_platform_skill(self, tmp_path): with ( patch("tools.skills_tool.SKILLS_DIR", tmp_path), - patch("tools.skills_tool.sys") as mock_sys, + patch("agent.skill_utils.sys") as mock_sys, ): _make_skill( tmp_path, "cross-plat", frontmatter_extra="platforms: [macos, linux]\n" @@ -813,6 +813,29 @@ class TestSkillViewPrerequisites: assert result["setup_needed"] is False assert result["missing_required_environment_variables"] == [] + def test_remote_backend_treats_persisted_env_as_available( + self, tmp_path, monkeypatch + ): + monkeypatch.setenv("TERMINAL_ENV", "docker") + + with patch("tools.skills_tool.SKILLS_DIR", tmp_path): + _make_skill( + tmp_path, + "remote-ready", + frontmatter_extra="prerequisites:\n env_vars: [PERSISTED_REMOTE_KEY]\n", + ) + from hermes_cli.config import save_env_value + + save_env_value("PERSISTED_REMOTE_KEY", "persisted-value") + monkeypatch.delenv("PERSISTED_REMOTE_KEY", raising=False) + raw = skill_view("remote-ready") + + result = json.loads(raw) + assert result["success"] is True + assert result["setup_needed"] is False + assert result["missing_required_environment_variables"] == [] + assert result["readiness_status"] == "available" + def test_no_setup_metadata_when_no_required_envs(self, tmp_path): with patch("tools.skills_tool.SKILLS_DIR", tmp_path): _make_skill(tmp_path, "plain-skill") @@ -878,17 +901,11 @@ class TestSkillViewPrerequisites: assert result["setup_needed"] is True @pytest.mark.parametrize( - "backend,expected_note", - [ - ("ssh", "remote environment"), - ("daytona", "remote environment"), - ("docker", "docker-backed skills"), - ("singularity", "singularity-backed skills"), - ("modal", "modal-backed skills"), - ], + "backend", + ["ssh", "daytona", "docker", "singularity", "modal"], ) - def test_remote_backend_keeps_setup_needed_after_local_secret_capture( - self, tmp_path, monkeypatch, backend, expected_note + def test_remote_backend_becomes_available_after_local_secret_capture( + self, tmp_path, monkeypatch, backend ): monkeypatch.setenv("TERMINAL_ENV", backend) monkeypatch.delenv("TENOR_API_KEY", raising=False) @@ -926,10 +943,10 @@ class TestSkillViewPrerequisites: result = json.loads(raw) assert result["success"] is True assert len(calls) == 1 - assert result["setup_needed"] is True - assert result["readiness_status"] == "setup_needed" - assert result["missing_required_environment_variables"] == ["TENOR_API_KEY"] - assert expected_note in result["setup_note"].lower() + assert result["setup_needed"] is False + assert result["readiness_status"] == "available" + assert result["missing_required_environment_variables"] == [] + assert "setup_note" not in result def test_skill_view_surfaces_skill_read_errors(self, tmp_path, monkeypatch): with patch("tools.skills_tool.SKILLS_DIR", tmp_path): diff --git a/tests/tools/test_terminal_requirements.py b/tests/tools/test_terminal_requirements.py index c55fc8310..2cbe3f711 100644 --- a/tests/tools/test_terminal_requirements.py +++ b/tests/tools/test_terminal_requirements.py @@ -101,6 +101,24 @@ def test_modal_backend_with_managed_gateway_does_not_require_direct_creds_or_min assert terminal_tool_module.check_terminal_requirements() is True +def test_modal_backend_auto_mode_prefers_managed_gateway_over_direct_creds(monkeypatch, tmp_path): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("HERMES_ENABLE_NOUS_MANAGED_TOOLS", "1") + monkeypatch.setenv("TERMINAL_ENV", "modal") + monkeypatch.setenv("MODAL_TOKEN_ID", "tok-id") + monkeypatch.setenv("MODAL_TOKEN_SECRET", "tok-secret") + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) + monkeypatch.setattr(terminal_tool_module, "is_managed_tool_gateway_ready", lambda _vendor: True) + monkeypatch.setattr( + terminal_tool_module.importlib.util, + "find_spec", + lambda _name: (_ for _ in ()).throw(AssertionError("should not be called")), + ) + + assert terminal_tool_module.check_terminal_requirements() is True + + def test_modal_backend_direct_mode_does_not_fall_back_to_managed(monkeypatch, caplog, tmp_path): _clear_terminal_env(monkeypatch) monkeypatch.setenv("TERMINAL_ENV", "modal") @@ -119,6 +137,26 @@ def test_modal_backend_direct_mode_does_not_fall_back_to_managed(monkeypatch, ca ) +def test_modal_backend_managed_mode_does_not_fall_back_to_direct(monkeypatch, caplog, tmp_path): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "modal") + monkeypatch.setenv("TERMINAL_MODAL_MODE", "managed") + monkeypatch.setenv("MODAL_TOKEN_ID", "tok-id") + monkeypatch.setenv("MODAL_TOKEN_SECRET", "tok-secret") + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("USERPROFILE", str(tmp_path)) + monkeypatch.setattr(terminal_tool_module, "is_managed_tool_gateway_ready", lambda _vendor: False) + + with caplog.at_level(logging.ERROR): + ok = terminal_tool_module.check_terminal_requirements() + + assert ok is False + assert any( + "HERMES_ENABLE_NOUS_MANAGED_TOOLS is not enabled" in record.getMessage() + for record in caplog.records + ) + + def test_modal_backend_managed_mode_without_feature_flag_logs_clear_error(monkeypatch, caplog, tmp_path): _clear_terminal_env(monkeypatch) monkeypatch.setenv("TERMINAL_ENV", "modal") diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py index d43f89cf1..0cd4c8e3c 100644 --- a/tests/tools/test_transcription_tools.py +++ b/tests/tools/test_transcription_tools.py @@ -96,6 +96,7 @@ class TestGetProviderFallbackPriority: monkeypatch.setenv("GROQ_API_KEY", "gsk-test") monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider assert _get_provider({}) == "groq" @@ -130,9 +131,10 @@ class TestExplicitProviderRespected: def test_explicit_local_no_fallback_to_openai(self, monkeypatch): """GH-1774: provider=local must not silently fall back to openai even when an OpenAI API key is set.""" - monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key-here") + monkeypatch.setenv("OPENAI_API_KEY", "***") monkeypatch.delenv("GROQ_API_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider result = _get_provider({"provider": "local"}) @@ -141,6 +143,7 @@ class TestExplicitProviderRespected: def test_explicit_local_no_fallback_to_groq(self, monkeypatch): monkeypatch.setenv("GROQ_API_KEY", "gsk-test") with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider result = _get_provider({"provider": "local"}) @@ -181,6 +184,7 @@ class TestExplicitProviderRespected: monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key") monkeypatch.delenv("GROQ_API_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider # Empty dict = no explicit provider, uses DEFAULT_PROVIDER auto-detect @@ -191,6 +195,7 @@ class TestExplicitProviderRespected: monkeypatch.setenv("GROQ_API_KEY", "gsk-test") monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key") with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider result = _get_provider({}) diff --git a/tests/tools/test_vision_tools.py b/tests/tools/test_vision_tools.py index 4f152cebd..97ee57a11 100644 --- a/tests/tools/test_vision_tools.py +++ b/tests/tools/test_vision_tools.py @@ -354,6 +354,78 @@ class TestErrorLoggingExcInfo: assert warning_records[0].exc_info is not None +class TestVisionSafetyGuards: + @pytest.mark.asyncio + async def test_local_non_image_file_rejected_before_llm_call(self, tmp_path): + secret = tmp_path / "secret.txt" + secret.write_text("TOP-SECRET=1\n", encoding="utf-8") + + with patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock) as mock_llm: + result = json.loads(await vision_analyze_tool(str(secret), "extract text")) + + assert result["success"] is False + assert "Only real image files are supported" in result["error"] + mock_llm.assert_not_awaited() + + @pytest.mark.asyncio + async def test_blocked_remote_url_short_circuits_before_download(self): + blocked = { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + } + + with ( + patch("tools.vision_tools.check_website_access", return_value=blocked), + patch("tools.vision_tools._validate_image_url", return_value=True), + patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download, + ): + result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe")) + + assert result["success"] is False + assert "Blocked by website policy" in result["error"] + mock_download.assert_not_awaited() + + @pytest.mark.asyncio + async def test_download_blocks_redirected_final_url(self, tmp_path): + from tools.vision_tools import _download_image + + def fake_check(url): + if url == "https://allowed.test/cat.png": + return None + if url == "https://blocked.test/final.png": + return { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + } + raise AssertionError(f"unexpected URL checked: {url}") + + class FakeResponse: + url = "https://blocked.test/final.png" + content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16 + + def raise_for_status(self): + return None + + with ( + patch("tools.vision_tools.check_website_access", side_effect=fake_check), + patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls, + pytest.raises(PermissionError, match="Blocked by website policy"), + ): + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=FakeResponse()) + mock_client_cls.return_value = mock_client + + await _download_image("https://allowed.test/cat.png", tmp_path / "cat.png", max_retries=1) + + assert not (tmp_path / "cat.png").exists() + + # --------------------------------------------------------------------------- # check_vision_requirements & get_debug_session_info # --------------------------------------------------------------------------- diff --git a/tests/tools/test_web_tools_config.py b/tests/tools/test_web_tools_config.py index 93ab6846f..9e33d7445 100644 --- a/tests/tools/test_web_tools_config.py +++ b/tests/tools/test_web_tools_config.py @@ -220,13 +220,13 @@ class TestFirecrawlClientConfig: response = MagicMock() response.choices = [MagicMock(message=MagicMock(content="summary text"))] - fake_client = MagicMock(base_url="https://api.openrouter.ai/v1") - fake_client.chat.completions.create = AsyncMock(return_value=response) - with patch( - "tools.web_tools.get_async_text_auxiliary_client", - side_effect=[(None, None), (fake_client, "test-model")], - ): + "tools.web_tools._resolve_web_extract_auxiliary", + side_effect=[(None, None, {}), (MagicMock(base_url="https://api.openrouter.ai/v1"), "test-model", {})], + ), patch( + "tools.web_tools.async_call_llm", + new=AsyncMock(return_value=response), + ) as mock_async_call: assert tools.web_tools.check_auxiliary_model() is False result = await tools.web_tools._call_summarizer_llm( "Some content worth summarizing", @@ -235,7 +235,7 @@ class TestFirecrawlClientConfig: ) assert result == "summary text" - fake_client.chat.completions.create.assert_awaited_once() + mock_async_call.assert_awaited_once() # ── Singleton caching ──────────────────────────────────────────── @@ -299,6 +299,7 @@ class TestBackendSelection: _ENV_KEYS = ( "HERMES_ENABLE_NOUS_MANAGED_TOOLS", + "EXA_API_KEY", "PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", @@ -327,6 +328,13 @@ class TestBackendSelection: with patch("tools.web_tools._load_web_config", return_value={"backend": "parallel"}): assert _get_backend() == "parallel" + def test_config_exa(self): + """web.backend=exa in config → 'exa' regardless of other keys.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "exa"}), \ + patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + assert _get_backend() == "exa" + def test_config_firecrawl(self): """web.backend=firecrawl in config → 'firecrawl' even if Parallel key set.""" from tools.web_tools import _get_backend @@ -368,6 +376,20 @@ class TestBackendSelection: patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): assert _get_backend() == "parallel" + def test_fallback_exa_only_key(self): + """Only EXA_API_KEY set → 'exa'.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"EXA_API_KEY": "exa-test"}): + assert _get_backend() == "exa" + + def test_fallback_parallel_takes_priority_over_exa(self): + """Exa should only win the fallback path when it is the only configured backend.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"EXA_API_KEY": "exa-test", "PARALLEL_API_KEY": "par-test"}): + assert _get_backend() == "parallel" + def test_fallback_tavily_only_key(self): """Only TAVILY_API_KEY set → 'tavily'.""" from tools.web_tools import _get_backend @@ -502,6 +524,7 @@ class TestCheckWebApiKey: _ENV_KEYS = ( "HERMES_ENABLE_NOUS_MANAGED_TOOLS", + "EXA_API_KEY", "PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", @@ -527,6 +550,11 @@ class TestCheckWebApiKey: from tools.web_tools import check_web_api_key assert check_web_api_key() is True + def test_exa_key_only(self): + with patch.dict(os.environ, {"EXA_API_KEY": "exa-test"}): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True + def test_firecrawl_key_only(self): with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): from tools.web_tools import check_web_api_key @@ -581,3 +609,9 @@ class TestCheckWebApiKey: with patch.dict(os.environ, {"FIRECRAWL_GATEWAY_URL": "http://127.0.0.1:3002"}, clear=False): from tools.web_tools import check_web_api_key assert check_web_api_key() is True + + +def test_web_requires_env_includes_exa_key(): + from tools.web_tools import _web_requires_env + + assert "EXA_API_KEY" in _web_requires_env() diff --git a/tools/__init__.py b/tools/__init__.py index 9b2542296..3214b979e 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1,262 +1,25 @@ #!/usr/bin/env python3 -""" -Tools Package +"""Tools package namespace. -This package contains all the specific tool implementations for the Hermes Agent. -Each module provides specialized functionality for different capabilities: +Keep package import side effects minimal. Importing ``tools`` should not +eagerly import the full tool stack, because several subsystems load tools while +``hermes_cli.config`` is still initializing. -- web_tools: Web search, content extraction, and crawling -- terminal_tool: Command execution (local/docker/modal/daytona/ssh/singularity backends) -- vision_tools: Image analysis and understanding -- mixture_of_agents_tool: Multi-model collaborative reasoning -- image_generation_tool: Text-to-image generation with upscaling +Callers should import concrete submodules directly, for example: -The tools are imported into model_tools.py which provides a unified interface -for the AI agent to access all capabilities. + import tools.web_tools + from tools import browser_tool + +Python will resolve those submodules via the package path without needing them +to be re-exported here. """ -# Export all tools for easy importing -from .web_tools import ( - web_search_tool, - web_extract_tool, - web_crawl_tool, - check_firecrawl_api_key -) -# Primary terminal tool (local/docker/singularity/modal/daytona/ssh) -from .terminal_tool import ( - terminal_tool, - check_terminal_requirements, - cleanup_vm, - cleanup_all_environments, - get_active_environments_info, - register_task_env_overrides, - clear_task_env_overrides, - TERMINAL_TOOL_DESCRIPTION -) - -from .vision_tools import ( - vision_analyze_tool, - check_vision_requirements -) - -from .mixture_of_agents_tool import ( - mixture_of_agents_tool, - check_moa_requirements -) - -from .image_generation_tool import ( - image_generate_tool, - check_image_generation_requirements -) - -from .skills_tool import ( - skills_list, - skill_view, - check_skills_requirements, - SKILLS_TOOL_DESCRIPTION -) - -from .skill_manager_tool import ( - skill_manage, - check_skill_manage_requirements, - SKILL_MANAGE_SCHEMA -) - -# Browser automation tools (agent-browser + Browserbase) -from .browser_tool import ( - browser_navigate, - browser_snapshot, - browser_click, - browser_type, - browser_scroll, - browser_back, - browser_press, - browser_close, - browser_get_images, - browser_vision, - cleanup_browser, - cleanup_all_browsers, - get_active_browser_sessions, - check_browser_requirements, - BROWSER_TOOL_SCHEMAS -) - -# Cronjob management tools (CLI-only, hermes-cli toolset) -from .cronjob_tools import ( - cronjob, - schedule_cronjob, - list_cronjobs, - remove_cronjob, - check_cronjob_requirements, - get_cronjob_tool_definitions, - CRONJOB_SCHEMA, -) - -# RL Training tools (Tinker-Atropos) -from .rl_training_tool import ( - rl_list_environments, - rl_select_environment, - rl_get_current_config, - rl_edit_config, - rl_start_training, - rl_check_status, - rl_stop_training, - rl_get_results, - rl_list_runs, - rl_test_inference, - check_rl_api_keys, - get_missing_keys, -) - -# File manipulation tools (read, write, patch, search) -from .file_tools import ( - read_file_tool, - write_file_tool, - patch_tool, - search_tool, - get_file_tools, - clear_file_ops_cache, -) - -# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI) -from .tts_tool import ( - text_to_speech_tool, - check_tts_requirements, -) - -# Planning & task management tool -from .todo_tool import ( - todo_tool, - check_todo_requirements, - TODO_SCHEMA, - TodoStore, -) - -# Clarifying questions tool (interactive Q&A with the user) -from .clarify_tool import ( - clarify_tool, - check_clarify_requirements, - CLARIFY_SCHEMA, -) - -# Code execution sandbox (programmatic tool calling) -from .code_execution_tool import ( - execute_code, - check_sandbox_requirements, - EXECUTE_CODE_SCHEMA, -) - -# Subagent delegation (spawn child agents with isolated context) -from .delegate_tool import ( - delegate_task, - check_delegate_requirements, - DELEGATE_TASK_SCHEMA, -) - -# File tools have no external requirements - they use the terminal backend def check_file_requirements(): - """File tools only require terminal backend to be available.""" + """File tools only require terminal backend availability.""" from .terminal_tool import check_terminal_requirements + return check_terminal_requirements() -__all__ = [ - # Web tools - 'web_search_tool', - 'web_extract_tool', - 'web_crawl_tool', - 'check_firecrawl_api_key', - # Terminal tools - 'terminal_tool', - 'check_terminal_requirements', - 'cleanup_vm', - 'cleanup_all_environments', - 'get_active_environments_info', - 'register_task_env_overrides', - 'clear_task_env_overrides', - 'TERMINAL_TOOL_DESCRIPTION', - # Vision tools - 'vision_analyze_tool', - 'check_vision_requirements', - # MoA tools - 'mixture_of_agents_tool', - 'check_moa_requirements', - # Image generation tools - 'image_generate_tool', - 'check_image_generation_requirements', - # Skills tools - 'skills_list', - 'skill_view', - 'check_skills_requirements', - 'SKILLS_TOOL_DESCRIPTION', - # Skill management - 'skill_manage', - 'check_skill_manage_requirements', - 'SKILL_MANAGE_SCHEMA', - # Browser automation tools - 'browser_navigate', - 'browser_snapshot', - 'browser_click', - 'browser_type', - 'browser_scroll', - 'browser_back', - 'browser_press', - 'browser_close', - 'browser_get_images', - 'browser_vision', - 'cleanup_browser', - 'cleanup_all_browsers', - 'get_active_browser_sessions', - 'check_browser_requirements', - 'BROWSER_TOOL_SCHEMAS', - # Cronjob management tools (CLI-only) - 'cronjob', - 'schedule_cronjob', - 'list_cronjobs', - 'remove_cronjob', - 'check_cronjob_requirements', - 'get_cronjob_tool_definitions', - 'CRONJOB_SCHEMA', - # RL Training tools - 'rl_list_environments', - 'rl_select_environment', - 'rl_get_current_config', - 'rl_edit_config', - 'rl_start_training', - 'rl_check_status', - 'rl_stop_training', - 'rl_get_results', - 'rl_list_runs', - 'rl_test_inference', - 'check_rl_api_keys', - 'get_missing_keys', - # File manipulation tools - 'read_file_tool', - 'write_file_tool', - 'patch_tool', - 'search_tool', - 'get_file_tools', - 'clear_file_ops_cache', - 'check_file_requirements', - # Text-to-speech tools - 'text_to_speech_tool', - 'check_tts_requirements', - # Planning & task management tool - 'todo_tool', - 'check_todo_requirements', - 'TODO_SCHEMA', - 'TodoStore', - # Clarifying questions tool - 'clarify_tool', - 'check_clarify_requirements', - 'CLARIFY_SCHEMA', - # Code execution sandbox - 'execute_code', - 'check_sandbox_requirements', - 'EXECUTE_CODE_SCHEMA', - # Subagent delegation - 'delegate_task', - 'check_delegate_requirements', - 'DELEGATE_TASK_SCHEMA', -] +__all__ = ["check_file_requirements"] diff --git a/tools/approval.py b/tools/approval.py index f3ae4e1fe..8ae52407f 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -18,6 +18,21 @@ from typing import Optional logger = logging.getLogger(__name__) +# Sensitive write targets that should trigger approval even when referenced +# via shell expansions like $HOME or $HERMES_HOME. +_SSH_SENSITIVE_PATH = r'(?:~|\$home|\$\{home\})/\.ssh(?:/|$)' +_HERMES_ENV_PATH = ( + r'(?:~\/\.hermes/|' + r'(?:\$home|\$\{home\})/\.hermes/|' + r'(?:\$hermes_home|\$\{hermes_home\})/)' + r'\.env\b' +) +_SENSITIVE_WRITE_TARGET = ( + r'(?:/etc/|/dev/sd|' + rf'{_SSH_SENSITIVE_PATH}|' + rf'{_HERMES_ENV_PATH})' +) + # ========================================================================= # Dangerous command patterns # ========================================================================= @@ -46,13 +61,16 @@ DANGEROUS_PATTERNS = [ (r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"), (r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"), (r'\b(bash|sh|zsh|ksh)\s+<\s*>?\s*["\']?{_SENSITIVE_WRITE_TARGET}', "overwrite system file via redirection"), (r'\bxargs\s+.*\brm\b', "xargs with rm"), (r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"), (r'\bfind\b.*-delete\b', "find -delete"), # Gateway protection: never start gateway outside systemd management (r'gateway\s+run\b.*(&\s*$|&\s*;|\bdisown\b|\bsetsid\b)', "start gateway outside systemd (use 'systemctl --user restart hermes-gateway')"), (r'\bnohup\b.*gateway\s+run\b', "start gateway outside systemd (use 'systemctl --user restart hermes-gateway')"), + # Self-termination protection: prevent agent from killing its own process + (r'\b(pkill|killall)\b.*\b(hermes|gateway|cli\.py)\b', "kill hermes/gateway process (self-termination)"), ] @@ -456,6 +474,33 @@ def check_dangerous_command(command: str, env_type: str, # Combined pre-exec guard (tirith + dangerous command detection) # ========================================================================= +def _format_tirith_description(tirith_result: dict) -> str: + """Build a human-readable description from tirith findings. + + Includes severity, title, and description for each finding so users + can make an informed approval decision. + """ + findings = tirith_result.get("findings") or [] + if not findings: + summary = tirith_result.get("summary") or "security issue detected" + return f"Security scan: {summary}" + + parts = [] + for f in findings: + severity = f.get("severity", "") + title = f.get("title", "") + desc = f.get("description", "") + if title and desc: + parts.append(f"[{severity}] {title}: {desc}" if severity else f"{title}: {desc}") + elif title: + parts.append(f"[{severity}] {title}" if severity else title) + if not parts: + summary = tirith_result.get("summary") or "security issue detected" + return f"Security scan: {summary}" + + return "Security scan — " + "; ".join(parts) + + def check_all_command_guards(command: str, env_type: str, approval_callback=None) -> dict: """Run all pre-exec security checks and return a single approval decision. @@ -499,24 +544,20 @@ def check_all_command_guards(command: str, env_type: str, # --- Phase 2: Decide --- - # If tirith blocks, block immediately (no approval possible) - if tirith_result["action"] == "block": - summary = tirith_result.get("summary") or "security issue detected" - return { - "approved": False, - "message": f"BLOCKED: Command blocked by security scan ({summary}). Do NOT retry.", - } - # Collect warnings that need approval warnings = [] # list of (pattern_key, description, is_tirith) session_key = os.getenv("HERMES_SESSION_KEY", "default") - if tirith_result["action"] == "warn": + # Tirith block/warn → approvable warning with rich findings. + # Previously, tirith "block" was a hard block with no approval prompt. + # Now both block and warn go through the approval flow so users can + # inspect the explanation and approve if they understand the risk. + if tirith_result["action"] in ("block", "warn"): findings = tirith_result.get("findings") or [] rule_id = findings[0].get("rule_id", "unknown") if findings else "unknown" tirith_key = f"tirith:{rule_id}" - tirith_desc = f"Security scan: {tirith_result.get('summary') or 'security warning detected'}" + tirith_desc = _format_tirith_description(tirith_result) if not is_approved(session_key, tirith_key): warnings.append((tirith_key, tirith_desc, True)) diff --git a/tools/browser_tool.py b/tools/browser_tool.py index 3018d5231..d229419d8 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -1027,7 +1027,7 @@ def _extract_relevant_content( if model: call_kwargs["model"] = model response = call_llm(**call_kwargs) - return response.choices[0].message.content + return (response.choices[0].message.content or "").strip() or _truncate_snapshot(snapshot_text) except Exception: return _truncate_snapshot(snapshot_text) @@ -1557,8 +1557,8 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] effective_task_id = task_id or "default" # Save screenshot to persistent location so it can be shared with users - hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) - screenshots_dir = hermes_home / "browser_screenshots" + from hermes_constants import get_hermes_dir + screenshots_dir = get_hermes_dir("cache/screenshots", "browser_screenshots") screenshot_path = screenshots_dir / f"browser_screenshot_{uuid_mod.uuid4().hex}.png" try: @@ -1657,10 +1657,10 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] call_kwargs["model"] = vision_model response = call_llm(**call_kwargs) - analysis = response.choices[0].message.content + analysis = (response.choices[0].message.content or "").strip() response_data = { "success": True, - "analysis": analysis, + "analysis": analysis or "Vision analysis returned no content.", "screenshot_path": str(screenshot_path), } # Include annotation data if annotated screenshot was taken diff --git a/tools/credential_files.py b/tools/credential_files.py new file mode 100644 index 000000000..56c32d572 --- /dev/null +++ b/tools/credential_files.py @@ -0,0 +1,163 @@ +"""Credential file passthrough registry for remote terminal backends. + +Skills that declare ``required_credential_files`` in their frontmatter need +those files available inside sandboxed execution environments (Modal, Docker). +By default remote backends create bare containers with no host files. + +This module provides a session-scoped registry so skill-declared credential +files (and user-configured overrides) are mounted into remote sandboxes. + +Two sources feed the registry: + +1. **Skill declarations** — when a skill is loaded via ``skill_view``, its + ``required_credential_files`` entries are registered here if the files + exist on the host. +2. **User config** — ``terminal.credential_files`` in config.yaml lets users + explicitly list additional files to mount. + +Remote backends (``tools/environments/modal.py``, ``docker.py``) call +:func:`get_credential_file_mounts` at sandbox creation time. + +Each registered entry is a dict:: + + { + "host_path": "/home/user/.hermes/google_token.json", + "container_path": "/root/.hermes/google_token.json", + } +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Dict, List + +logger = logging.getLogger(__name__) + +# Session-scoped list of credential files to mount. +# Key: container_path (deduplicated), Value: host_path +_registered_files: Dict[str, str] = {} + +# Cache for config-based file list (loaded once per process). +_config_files: List[Dict[str, str]] | None = None + + +def _resolve_hermes_home() -> Path: + return Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) + + +def register_credential_file( + relative_path: str, + container_base: str = "/root/.hermes", +) -> bool: + """Register a credential file for mounting into remote sandboxes. + + *relative_path* is relative to ``HERMES_HOME`` (e.g. ``google_token.json``). + Returns True if the file exists on the host and was registered. + """ + hermes_home = _resolve_hermes_home() + host_path = hermes_home / relative_path + if not host_path.is_file(): + logger.debug("credential_files: skipping %s (not found)", host_path) + return False + + container_path = f"{container_base.rstrip('/')}/{relative_path}" + _registered_files[container_path] = str(host_path) + logger.debug("credential_files: registered %s -> %s", host_path, container_path) + return True + + +def register_credential_files( + entries: list, + container_base: str = "/root/.hermes", +) -> List[str]: + """Register multiple credential files from skill frontmatter entries. + + Each entry is either a string (relative path) or a dict with a ``path`` + key. Returns the list of relative paths that were NOT found on the host + (i.e. missing files). + """ + missing = [] + for entry in entries: + if isinstance(entry, str): + rel_path = entry.strip() + elif isinstance(entry, dict): + rel_path = (entry.get("path") or "").strip() + else: + continue + if not rel_path: + continue + if not register_credential_file(rel_path, container_base): + missing.append(rel_path) + return missing + + +def _load_config_files() -> List[Dict[str, str]]: + """Load ``terminal.credential_files`` from config.yaml (cached).""" + global _config_files + if _config_files is not None: + return _config_files + + result: List[Dict[str, str]] = [] + try: + hermes_home = _resolve_hermes_home() + config_path = hermes_home / "config.yaml" + if config_path.exists(): + import yaml + + with open(config_path) as f: + cfg = yaml.safe_load(f) or {} + cred_files = cfg.get("terminal", {}).get("credential_files") + if isinstance(cred_files, list): + for item in cred_files: + if isinstance(item, str) and item.strip(): + host_path = hermes_home / item.strip() + if host_path.is_file(): + container_path = f"/root/.hermes/{item.strip()}" + result.append({ + "host_path": str(host_path), + "container_path": container_path, + }) + except Exception as e: + logger.debug("Could not read terminal.credential_files from config: %s", e) + + _config_files = result + return _config_files + + +def get_credential_file_mounts() -> List[Dict[str, str]]: + """Return all credential files that should be mounted into remote sandboxes. + + Each item has ``host_path`` and ``container_path`` keys. + Combines skill-registered files and user config. + """ + mounts: Dict[str, str] = {} + + # Skill-registered files + for container_path, host_path in _registered_files.items(): + # Re-check existence (file may have been deleted since registration) + if Path(host_path).is_file(): + mounts[container_path] = host_path + + # Config-based files + for entry in _load_config_files(): + cp = entry["container_path"] + if cp not in mounts and Path(entry["host_path"]).is_file(): + mounts[cp] = entry["host_path"] + + return [ + {"host_path": hp, "container_path": cp} + for cp, hp in mounts.items() + ] + + +def clear_credential_files() -> None: + """Reset the skill-scoped registry (e.g. on session reset).""" + _registered_files.clear() + + +def reset_config_cache() -> None: + """Force re-read of config on next access (for testing).""" + global _config_files + _config_files = None diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 0a023c904..84054c6e2 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -372,7 +372,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr }, "deliver": { "type": "string", - "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" + "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" }, "model": { "type": "string", diff --git a/tools/delegate_tool.py b/tools/delegate_tool.py index f974ee8ff..b5b0a57c4 100644 --- a/tools/delegate_tool.py +++ b/tools/delegate_tool.py @@ -289,7 +289,10 @@ def _run_single_child( if interrupted: status = "interrupted" - elif completed and summary: + elif summary: + # A summary means the subagent produced usable output. + # exit_reason ("completed" vs "max_iterations") already + # tells the parent *how* the task ended. status = "completed" else: status = "failed" diff --git a/tools/environments/base.py b/tools/environments/base.py index 896937adf..2b02c3c47 100644 --- a/tools/environments/base.py +++ b/tools/environments/base.py @@ -5,7 +5,7 @@ import os import subprocess from pathlib import Path -from hermes_cli.config import get_hermes_home +from hermes_constants import get_hermes_home def get_sandbox_dir() -> Path: diff --git a/tools/environments/docker.py b/tools/environments/docker.py index c5546dbe4..a24786d17 100644 --- a/tools/environments/docker.py +++ b/tools/environments/docker.py @@ -312,6 +312,24 @@ class DockerEnvironment(BaseEnvironment): elif workspace_explicitly_mounted: logger.debug("Skipping docker cwd mount: /workspace already mounted by user config") + # Mount credential files (OAuth tokens, etc.) declared by skills. + # Read-only so the container can authenticate but not modify host creds. + try: + from tools.credential_files import get_credential_file_mounts + + for mount_entry in get_credential_file_mounts(): + volume_args.extend([ + "-v", + f"{mount_entry['host_path']}:{mount_entry['container_path']}:ro", + ]) + logger.info( + "Docker: mounting credential %s -> %s", + mount_entry["host_path"], + mount_entry["container_path"], + ) + except Exception as e: + logger.debug("Docker: could not load credential file mounts: %s", e) + logger.info(f"Docker volume_args: {volume_args}") all_run_args = list(_SECURITY_ARGS) + writable_args + resource_args + volume_args logger.info(f"Docker run_args: {all_run_args}") @@ -406,8 +424,17 @@ class DockerEnvironment(BaseEnvironment): if effective_stdin is not None: cmd.append("-i") cmd.extend(["-w", work_dir]) - hermes_env = _load_hermes_env_vars() if self._forward_env else {} - for key in self._forward_env: + # Combine explicit docker_forward_env with skill-declared env_passthrough + # vars so skills that declare required_environment_variables (e.g. Notion) + # have their keys forwarded into the container automatically. + forward_keys = set(self._forward_env) + try: + from tools.env_passthrough import get_all_passthrough + forward_keys |= get_all_passthrough() + except Exception: + pass + hermes_env = _load_hermes_env_vars() if forward_keys else {} + for key in sorted(forward_keys): value = os.getenv(key) if value is None: value = hermes_env.get(key) diff --git a/tools/environments/local.py b/tools/environments/local.py index 8ee794e3b..8cd416efa 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -391,12 +391,17 @@ class LocalEnvironment(PersistentShellMixin, BaseEnvironment): effective_stdin = stdin_data user_shell = _find_bash() + # Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line). + # A trailing `; __hermes_rc` glued to `< str: def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]: - """Return a snapshot id for direct Modal restore and whether the key is legacy.""" + """Return a snapshot id and whether it came from the legacy key format.""" snapshots = _load_snapshots() namespaced_key = _direct_snapshot_key(task_id) @@ -104,7 +104,7 @@ def _resolve_modal_image(image_spec: Any) -> Any: class _AsyncWorker: - """Background thread with its own event loop for async-safe swe-rex calls.""" + """Background thread with its own event loop for async-safe Modal calls.""" def __init__(self): self._loop: Optional[asyncio.AbstractEventLoop] = None @@ -136,12 +136,7 @@ class _AsyncWorker: class ModalEnvironment(BaseEnvironment): - """Modal cloud execution via SWE-ReX. - - Uses swe-rex's ModalDeployment directly for sandbox management. - Adds sudo -S support, configurable resources (CPU, memory, disk), - and optional filesystem persistence via Modal's snapshot API. - """ + """Modal cloud execution via native Modal sandboxes.""" def __init__( self, @@ -157,39 +152,73 @@ class ModalEnvironment(BaseEnvironment): self._persistent = persistent_filesystem self._task_id = task_id self._base_image = image - self._deployment = None + self._sandbox = None + self._app = None self._worker = _AsyncWorker() + self._synced_creds: Dict[str, tuple] = {} sandbox_kwargs = dict(modal_sandbox_kwargs or {}) restored_snapshot_id = None restored_from_legacy_key = False if self._persistent: - restored_snapshot_id, restored_from_legacy_key = _get_snapshot_restore_candidate(self._task_id) + restored_snapshot_id, restored_from_legacy_key = _get_snapshot_restore_candidate( + self._task_id + ) if restored_snapshot_id: logger.info("Modal: restoring from snapshot %s", restored_snapshot_id[:20]) + import modal as _modal + + cred_mounts = [] + try: + from tools.credential_files import get_credential_file_mounts + + for mount_entry in get_credential_file_mounts(): + cred_mounts.append( + _modal.Mount.from_local_file( + mount_entry["host_path"], + remote_path=mount_entry["container_path"], + ) + ) + logger.info( + "Modal: mounting credential %s -> %s", + mount_entry["host_path"], + mount_entry["container_path"], + ) + except Exception as e: + logger.debug("Modal: could not load credential file mounts: %s", e) + self._worker.start() - from swerex.deployment.modal import ModalDeployment - - async def _create_and_start(image_spec: Any): - deployment = ModalDeployment( + async def _create_sandbox(image_spec: Any): + app = await _modal.App.lookup.aio("hermes-agent", create_if_missing=True) + create_kwargs = dict(sandbox_kwargs) + if cred_mounts: + existing_mounts = list(create_kwargs.pop("mounts", [])) + existing_mounts.extend(cred_mounts) + create_kwargs["mounts"] = existing_mounts + sandbox = await _modal.Sandbox.create.aio( + "sleep", + "infinity", image=image_spec, - startup_timeout=180.0, - runtime_timeout=3600.0, - deployment_timeout=3600.0, - install_pipx=True, - modal_sandbox_kwargs=sandbox_kwargs, + app=app, + timeout=int(create_kwargs.pop("timeout", 3600)), + **create_kwargs, ) - await deployment.start() - return deployment + return app, sandbox try: target_image_spec = restored_snapshot_id or image try: + # _resolve_modal_image keeps the Modal bootstrap fix together: + # it applies ensurepip via setup_dockerfile_commands before + # Modal builds or restores the image. effective_image = _resolve_modal_image(target_image_spec) - self._deployment = self._worker.run_coroutine(_create_and_start(effective_image)) + self._app, self._sandbox = self._worker.run_coroutine( + _create_sandbox(effective_image), + timeout=300, + ) except Exception as exc: if not restored_snapshot_id: raise @@ -201,18 +230,83 @@ class ModalEnvironment(BaseEnvironment): ) _delete_direct_snapshot(self._task_id, restored_snapshot_id) base_image = _resolve_modal_image(image) - self._deployment = self._worker.run_coroutine(_create_and_start(base_image)) + self._app, self._sandbox = self._worker.run_coroutine( + _create_sandbox(base_image), + timeout=300, + ) else: if restored_snapshot_id and restored_from_legacy_key: _store_direct_snapshot(self._task_id, restored_snapshot_id) - logger.info("Modal: migrated legacy snapshot entry for task %s", self._task_id) + logger.info( + "Modal: migrated legacy snapshot entry for task %s", + self._task_id, + ) except Exception: self._worker.stop() raise - def execute(self, command: str, cwd: str = "", *, - timeout: int | None = None, - stdin_data: str | None = None) -> dict: + logger.info("Modal: sandbox created (task=%s)", self._task_id) + + def _sync_credential_files(self) -> None: + """Push credential files into the running sandbox.""" + try: + from tools.credential_files import get_credential_file_mounts + + mounts = get_credential_file_mounts() + if not mounts: + return + + for entry in mounts: + host_path = entry["host_path"] + container_path = entry["container_path"] + hp = Path(host_path) + try: + stat = hp.stat() + file_key = (stat.st_mtime, stat.st_size) + except OSError: + continue + + if self._synced_creds.get(container_path) == file_key: + continue + + try: + content = hp.read_text(encoding="utf-8") + except Exception: + continue + + import base64 + + b64 = base64.b64encode(content.encode("utf-8")).decode("ascii") + container_dir = str(Path(container_path).parent) + cmd = ( + f"mkdir -p {shlex.quote(container_dir)} && " + f"echo {shlex.quote(b64)} | base64 -d > {shlex.quote(container_path)}" + ) + + async def _write(): + proc = await self._sandbox.exec.aio("bash", "-c", cmd) + await proc.wait.aio() + + self._worker.run_coroutine(_write(), timeout=15) + self._synced_creds[container_path] = file_key + logger.debug( + "Modal: synced credential %s -> %s", + host_path, + container_path, + ) + except Exception as e: + logger.debug("Modal: credential file sync failed: %s", e) + + def execute( + self, + command: str, + cwd: str = "", + *, + timeout: int | None = None, + stdin_data: str | None = None, + ) -> dict: + self._sync_credential_files() + if stdin_data is not None: marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}" while marker in stdin_data: @@ -221,42 +315,48 @@ class ModalEnvironment(BaseEnvironment): exec_command, sudo_stdin = self._prepare_command(command) - # Modal sandboxes execute commands via the Modal SDK and cannot pipe - # subprocess stdin directly the way a local Popen can. When a sudo - # password is present, use a shell-level pipe from printf so that the - # password feeds sudo -S without appearing as an echo argument embedded - # in the shell string. + # Modal sandboxes execute commands via exec() and cannot pipe + # subprocess stdin directly. When a sudo password is present, + # use a shell-level pipe from printf. if sudo_stdin is not None: - import shlex exec_command = ( f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}" ) - from swerex.runtime.abstract import Command as RexCommand - effective_cwd = cwd or self.cwd effective_timeout = timeout or self.timeout + full_command = f"cd {shlex.quote(effective_cwd)} && {exec_command}" result_holder = {"value": None, "error": None} def _run(): try: async def _do_execute(): - return await self._deployment.runtime.execute( - RexCommand( - command=exec_command, - shell=True, - check=False, - cwd=effective_cwd, - timeout=effective_timeout, - merge_output_streams=True, - ) + process = await self._sandbox.exec.aio( + "bash", + "-c", + full_command, + timeout=effective_timeout, ) + stdout = await process.stdout.read.aio() + stderr = await process.stderr.read.aio() + exit_code = await process.wait.aio() + if isinstance(stdout, bytes): + stdout = stdout.decode("utf-8", errors="replace") + if isinstance(stderr, bytes): + stderr = stderr.decode("utf-8", errors="replace") + output = stdout + if stderr: + output = f"{stdout}\n{stderr}" if stdout else stderr + return output, exit_code - output = self._worker.run_coroutine(_do_execute()) + output, exit_code = self._worker.run_coroutine( + _do_execute(), + timeout=effective_timeout + 30, + ) result_holder["value"] = { - "output": output.stdout, - "returncode": output.exit_code, + "output": output, + "returncode": exit_code, } except Exception as e: result_holder["error"] = e @@ -268,7 +368,7 @@ class ModalEnvironment(BaseEnvironment): if is_interrupted(): try: self._worker.run_coroutine( - asyncio.wait_for(self._deployment.stop(), timeout=10), + self._sandbox.terminate.aio(), timeout=15, ) except Exception: @@ -279,44 +379,46 @@ class ModalEnvironment(BaseEnvironment): } if result_holder["error"]: - return {"output": f"Modal execution error: {result_holder['error']}", "returncode": 1} + return { + "output": f"Modal execution error: {result_holder['error']}", + "returncode": 1, + } return result_holder["value"] def cleanup(self): """Snapshot the filesystem (if persistent) then stop the sandbox.""" - if self._deployment is None: + if self._sandbox is None: return if self._persistent: try: - sandbox = getattr(self._deployment, "_sandbox", None) - if sandbox: - async def _snapshot(): - img = await sandbox.snapshot_filesystem.aio() - return img.object_id + async def _snapshot(): + img = await self._sandbox.snapshot_filesystem.aio() + return img.object_id - try: - snapshot_id = self._worker.run_coroutine(_snapshot(), timeout=60) - except Exception: - snapshot_id = None + try: + snapshot_id = self._worker.run_coroutine(_snapshot(), timeout=60) + except Exception: + snapshot_id = None - if snapshot_id: - _store_direct_snapshot(self._task_id, snapshot_id) - logger.info( - "Modal: saved filesystem snapshot %s for task %s", - snapshot_id[:20], - self._task_id, - ) + if snapshot_id: + _store_direct_snapshot(self._task_id, snapshot_id) + logger.info( + "Modal: saved filesystem snapshot %s for task %s", + snapshot_id[:20], + self._task_id, + ) except Exception as e: logger.warning("Modal: filesystem snapshot failed: %s", e) try: self._worker.run_coroutine( - asyncio.wait_for(self._deployment.stop(), timeout=10), + self._sandbox.terminate.aio(), timeout=15, ) except Exception: pass finally: self._worker.stop() - self._deployment = None + self._sandbox = None + self._app = None diff --git a/tools/environments/singularity.py b/tools/environments/singularity.py index 72afbac5f..aa9c7b009 100644 --- a/tools/environments/singularity.py +++ b/tools/environments/singularity.py @@ -16,7 +16,7 @@ import uuid from pathlib import Path from typing import Any, Dict, Optional -from hermes_cli.config import get_hermes_home +from hermes_constants import get_hermes_home from tools.environments.base import BaseEnvironment from tools.interrupt import is_interrupted diff --git a/tools/file_operations.py b/tools/file_operations.py index ebd4d601d..96bdc2d53 100644 --- a/tools/file_operations.py +++ b/tools/file_operations.py @@ -32,6 +32,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Optional, List, Dict, Any from pathlib import Path +from hermes_constants import get_hermes_home # --------------------------------------------------------------------------- @@ -46,7 +47,7 @@ WRITE_DENIED_PATHS = { os.path.join(_HOME, ".ssh", "id_rsa"), os.path.join(_HOME, ".ssh", "id_ed25519"), os.path.join(_HOME, ".ssh", "config"), - os.path.join(_HOME, ".hermes", ".env"), + str(get_hermes_home() / ".env"), os.path.join(_HOME, ".bashrc"), os.path.join(_HOME, ".zshrc"), os.path.join(_HOME, ".profile"), diff --git a/tools/file_tools.py b/tools/file_tools.py index 519178c00..7387c4dcb 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -171,8 +171,9 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = # Security: block direct reads of internal Hermes cache/index files # to prevent prompt injection via catalog or hub metadata files. import pathlib as _pathlib + from hermes_constants import get_hermes_home as _get_hh _resolved = _pathlib.Path(path).expanduser().resolve() - _hermes_home = _pathlib.Path("~/.hermes").expanduser().resolve() + _hermes_home = _get_hh().resolve() _blocked_dirs = [ _hermes_home / "skills" / ".hub" / "index-cache", _hermes_home / "skills" / ".hub", diff --git a/tools/honcho_tools.py b/tools/honcho_tools.py index 4aa86d57a..c3a1ac59c 100644 --- a/tools/honcho_tools.py +++ b/tools/honcho_tools.py @@ -45,8 +45,23 @@ def clear_session_context() -> None: # ── Availability check ── def _check_honcho_available() -> bool: - """Tool is only available when Honcho is active.""" - return _session_manager is not None and _session_key is not None + """Tool is available when Honcho is active OR configured. + + At banner time the session context hasn't been injected yet, but if + a valid config exists the tools *will* activate once the agent starts. + Returning True for "configured" prevents the banner from marking + honcho tools as red/disabled when they're actually going to work. + """ + # Fast path: session already active (mid-conversation) + if _session_manager is not None and _session_key is not None: + return True + # Slow path: check if Honcho is configured (banner time) + try: + from honcho_integration.client import HonchoClientConfig + cfg = HonchoClientConfig.from_global_config() + return cfg.enabled and bool(cfg.api_key or cfg.base_url) + except Exception: + return False def _resolve_session_context(**kwargs): diff --git a/tools/managed_tool_gateway.py b/tools/managed_tool_gateway.py index 4d9da52bf..d3bec0678 100644 --- a/tools/managed_tool_gateway.py +++ b/tools/managed_tool_gateway.py @@ -8,7 +8,7 @@ from datetime import datetime, timezone from dataclasses import dataclass from typing import Callable, Optional -from hermes_cli.config import get_hermes_home +from hermes_constants import get_hermes_home from tools.tool_backend_helpers import managed_nous_tools_enabled _DEFAULT_TOOL_GATEWAY_DOMAIN = "nousresearch.com" diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index f539586eb..4c762150e 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -70,6 +70,7 @@ Thread safety: """ import asyncio +import inspect import json import logging import math @@ -89,6 +90,8 @@ logger = logging.getLogger(__name__) _MCP_AVAILABLE = False _MCP_HTTP_AVAILABLE = False _MCP_SAMPLING_TYPES = False +_MCP_NOTIFICATION_TYPES = False +_MCP_MESSAGE_HANDLER_SUPPORTED = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -98,6 +101,13 @@ try: _MCP_HTTP_AVAILABLE = True except ImportError: _MCP_HTTP_AVAILABLE = False + # Prefer the non-deprecated API (mcp >= 1.24.0); fall back to the + # deprecated wrapper for older SDK versions. + try: + from mcp.client.streamable_http import streamable_http_client + _MCP_NEW_HTTP = True + except ImportError: + _MCP_NEW_HTTP = False # Sampling types -- separated so older SDK versions don't break MCP support try: from mcp.types import ( @@ -112,9 +122,39 @@ try: _MCP_SAMPLING_TYPES = True except ImportError: logger.debug("MCP sampling types not available -- sampling disabled") + # Notification types for dynamic tool discovery (tools/list_changed) + try: + from mcp.types import ( + ServerNotification, + ToolListChangedNotification, + PromptListChangedNotification, + ResourceListChangedNotification, + ) + _MCP_NOTIFICATION_TYPES = True + except ImportError: + logger.debug("MCP notification types not available -- dynamic tool discovery disabled") except ImportError: logger.debug("mcp package not installed -- MCP tool support disabled") + +def _check_message_handler_support() -> bool: + """Check if ClientSession accepts ``message_handler`` kwarg. + + Inspects the constructor signature for backward compatibility with older + MCP SDK versions that don't support notification handlers. + """ + if not _MCP_AVAILABLE: + return False + try: + return "message_handler" in inspect.signature(ClientSession).parameters + except (TypeError, ValueError): + return False + + +_MCP_MESSAGE_HANDLER_SUPPORTED = _check_message_handler_support() +if _MCP_AVAILABLE and not _MCP_MESSAGE_HANDLER_SUPPORTED: + logger.debug("MCP SDK does not support message_handler -- dynamic tool discovery disabled") + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- @@ -690,7 +730,7 @@ class MCPServerTask: __slots__ = ( "name", "session", "tool_timeout", "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", - "_sampling", "_registered_tool_names", "_auth_type", + "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock", ) def __init__(self, name: str): @@ -706,11 +746,80 @@ class MCPServerTask: self._sampling: Optional[SamplingHandler] = None self._registered_tool_names: list[str] = [] self._auth_type: str = "" + self._refresh_lock = asyncio.Lock() def _is_http(self) -> bool: """Check if this server uses HTTP transport.""" return "url" in self._config + # ----- Dynamic tool discovery (notifications/tools/list_changed) ----- + + def _make_message_handler(self): + """Build a ``message_handler`` callback for ``ClientSession``. + + Dispatches on notification type. Only ``ToolListChangedNotification`` + triggers a refresh; prompt and resource change notifications are + logged as stubs for future work. + """ + async def _handler(message): + try: + if isinstance(message, Exception): + logger.debug("MCP message handler (%s): exception: %s", self.name, message) + return + if _MCP_NOTIFICATION_TYPES and isinstance(message, ServerNotification): + match message.root: + case ToolListChangedNotification(): + logger.info( + "MCP server '%s': received tools/list_changed notification", + self.name, + ) + await self._refresh_tools() + case PromptListChangedNotification(): + logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name) + case ResourceListChangedNotification(): + logger.debug("MCP server '%s': resources/list_changed (ignored)", self.name) + case _: + pass + except Exception: + logger.exception("Error in MCP message handler for '%s'", self.name) + return _handler + + async def _refresh_tools(self): + """Re-fetch tools from the server and update the registry. + + Called when the server sends ``notifications/tools/list_changed``. + The lock prevents overlapping refreshes from rapid-fire notifications. + After the initial ``await`` (list_tools), all mutations are synchronous + — atomic from the event loop's perspective. + """ + from tools.registry import registry + from toolsets import TOOLSETS + + async with self._refresh_lock: + # 1. Fetch current tool list from server + tools_result = await self.session.list_tools() + new_mcp_tools = tools_result.tools if hasattr(tools_result, "tools") else [] + + # 2. Remove old tools from hermes-* umbrella toolsets + for ts_name, ts in TOOLSETS.items(): + if ts_name.startswith("hermes-"): + ts["tools"] = [t for t in ts["tools"] if t not in self._registered_tool_names] + + # 3. Deregister old tools from the central registry + for prefixed_name in self._registered_tool_names: + registry.deregister(prefixed_name) + + # 4. Re-register with fresh tool list + self._tools = new_mcp_tools + self._registered_tool_names = _register_server_tools( + self.name, self, self._config + ) + + logger.info( + "MCP server '%s': dynamically refreshed %d tool(s)", + self.name, len(self._registered_tool_names), + ) + async def _run_stdio(self, config: dict): """Run the server using stdio transport.""" command = config.get("command") @@ -731,6 +840,8 @@ class MCPServerTask: ) sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} + if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: + sampling_kwargs["message_handler"] = self._make_message_handler() async with stdio_client(server_params) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: await session.initialize() @@ -762,21 +873,52 @@ class MCPServerTask: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} - _http_kwargs: dict = { - "headers": headers, - "timeout": float(connect_timeout), - } - if _oauth_auth is not None: - _http_kwargs["auth"] = _oauth_auth - async with streamablehttp_client(url, **_http_kwargs) as ( - read_stream, write_stream, _get_session_id, - ): - async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: - await session.initialize() - self.session = session - await self._discover_tools() - self._ready.set() - await self._shutdown_event.wait() + if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: + sampling_kwargs["message_handler"] = self._make_message_handler() + + if _MCP_NEW_HTTP: + # New API (mcp >= 1.24.0): build an explicit httpx.AsyncClient + # matching the SDK's own create_mcp_http_client defaults. + import httpx + + client_kwargs: dict = { + "follow_redirects": True, + "timeout": httpx.Timeout(float(connect_timeout), read=300.0), + } + if headers: + client_kwargs["headers"] = headers + if _oauth_auth is not None: + client_kwargs["auth"] = _oauth_auth + + # Caller owns the client lifecycle — the SDK skips cleanup when + # http_client is provided, so we wrap in async-with. + async with httpx.AsyncClient(**client_kwargs) as http_client: + async with streamable_http_client(url, http_client=http_client) as ( + read_stream, write_stream, _get_session_id, + ): + async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: + await session.initialize() + self.session = session + await self._discover_tools() + self._ready.set() + await self._shutdown_event.wait() + else: + # Deprecated API (mcp < 1.24.0): manages httpx client internally. + _http_kwargs: dict = { + "headers": headers, + "timeout": float(connect_timeout), + } + if _oauth_auth is not None: + _http_kwargs["auth"] = _oauth_auth + async with streamablehttp_client(url, **_http_kwargs) as ( + read_stream, write_stream, _get_session_id, + ): + async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session: + await session.initialize() + self.session = session + await self._discover_tools() + self._ready.set() + await self._shutdown_event.wait() async def _discover_tools(self): """Discover tools from the connected session.""" @@ -797,7 +939,7 @@ class MCPServerTask: """ self._config = config self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT) - self._auth_type = config.get("auth", "").lower().strip() + self._auth_type = (config.get("auth") or "").lower().strip() # Set up sampling handler if enabled and SDK types are available sampling_config = config.get("sampling", {}) @@ -1486,24 +1628,19 @@ def _existing_tool_names() -> List[str]: return names -async def _discover_and_register_server(name: str, config: dict) -> List[str]: - """Connect to a single MCP server, discover tools, and register them. +def _register_server_tools(name: str, server: MCPServerTask, config: dict) -> List[str]: + """Register tools from an already-connected server into the registry. - Also registers utility tools for MCP Resources and Prompts support - (list_resources, read_resource, list_prompts, get_prompt). + Handles include/exclude filtering, utility tools, toolset creation, + and hermes-* umbrella toolset injection. - Returns list of registered tool names. + Used by both initial discovery and dynamic refresh (list_changed). + + Returns: + List of registered prefixed tool names. """ from tools.registry import registry - from toolsets import create_custom_toolset - - connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) - server = await asyncio.wait_for( - _connect_server(name, config), - timeout=connect_timeout, - ) - with _lock: - _servers[name] = server + from toolsets import create_custom_toolset, TOOLSETS registered_names: List[str] = [] toolset_name = f"mcp-{name}" @@ -1589,8 +1726,6 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: ) registered_names.append(util_name) - server._registered_tool_names = list(registered_names) - # Create a custom toolset so these tools are discoverable if registered_names: create_custom_toolset( @@ -1598,6 +1733,31 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: description=f"MCP tools from {name} server", tools=registered_names, ) + # Inject into hermes-* umbrella toolsets for default behavior + for ts_name, ts in TOOLSETS.items(): + if ts_name.startswith("hermes-"): + for tool_name in registered_names: + if tool_name not in ts["tools"]: + ts["tools"].append(tool_name) + + return registered_names + + +async def _discover_and_register_server(name: str, config: dict) -> List[str]: + """Connect to a single MCP server, discover tools, and register them. + + Returns list of registered tool names. + """ + connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) + server = await asyncio.wait_for( + _connect_server(name, config), + timeout=connect_timeout, + ) + with _lock: + _servers[name] = server + + registered_names = _register_server_tools(name, server, config) + server._registered_tool_names = list(registered_names) transport_type = "HTTP" if "url" in config else "stdio" logger.info( diff --git a/tools/memory_tool.py b/tools/memory_tool.py index a344c8745..2d687e94d 100644 --- a/tools/memory_tool.py +++ b/tools/memory_tool.py @@ -339,7 +339,7 @@ class MemoryStore: entries = self._entries_for(target) current = self._char_count(target) limit = self._char_limit(target) - pct = int((current / limit) * 100) if limit > 0 else 0 + pct = min(100, int((current / limit) * 100)) if limit > 0 else 0 resp = { "success": True, @@ -360,7 +360,7 @@ class MemoryStore: limit = self._char_limit(target) content = ENTRY_DELIMITER.join(entries) current = len(content) - pct = int((current / limit) * 100) if limit > 0 else 0 + pct = min(100, int((current / limit) * 100)) if limit > 0 else 0 if target == "user": header = f"USER PROFILE (who the user is) [{pct}% — {current:,}/{limit:,} chars]" diff --git a/tools/mixture_of_agents_tool.py b/tools/mixture_of_agents_tool.py index 18d8840c1..9367a3f1e 100644 --- a/tools/mixture_of_agents_tool.py +++ b/tools/mixture_of_agents_tool.py @@ -52,6 +52,7 @@ import asyncio import datetime from typing import Dict, Any, List, Optional from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key +from agent.auxiliary_client import extract_content_or_reasoning from tools.debug_helpers import DebugSession logger = logging.getLogger(__name__) @@ -143,7 +144,13 @@ async def _run_reference_model_safe( response = await _get_openrouter_client().chat.completions.create(**api_params) - content = response.choices[0].message.content.strip() + content = extract_content_or_reasoning(response) + if not content: + # Reasoning-only response — let the retry loop handle it + logger.warning("%s returned empty content (attempt %s/%s), retrying", model, attempt + 1, max_retries) + if attempt < max_retries - 1: + await asyncio.sleep(min(2 ** (attempt + 1), 60)) + continue logger.info("%s responded (%s characters)", model, len(content)) return model, content, True @@ -211,7 +218,14 @@ async def _run_aggregator_model( response = await _get_openrouter_client().chat.completions.create(**api_params) - content = response.choices[0].message.content.strip() + content = extract_content_or_reasoning(response) + + # Retry once on empty content (reasoning-only response) + if not content: + logger.warning("Aggregator returned empty content, retrying once") + response = await _get_openrouter_client().chat.completions.create(**api_params) + content = extract_content_or_reasoning(response) + logger.info("Aggregation complete (%s characters)", len(content)) return content diff --git a/tools/patch_parser.py b/tools/patch_parser.py index bef196e50..1a11f1413 100644 --- a/tools/patch_parser.py +++ b/tools/patch_parser.py @@ -419,6 +419,23 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]: if error: return False, f"Could not apply hunk: {error}" + else: + # Addition-only hunk (no context or removed lines). + # Insert at the location indicated by the context hint, or at end of file. + insert_text = '\n'.join(replace_lines) + if hunk.context_hint: + hint_pos = new_content.find(hunk.context_hint) + if hint_pos != -1: + # Insert after the line containing the context hint + eol = new_content.find('\n', hint_pos) + if eol != -1: + new_content = new_content[:eol + 1] + insert_text + '\n' + new_content[eol + 1:] + else: + new_content = new_content + '\n' + insert_text + else: + new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n' + else: + new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n' # Write new content write_result = file_ops.write_file(op.file_path, new_content) diff --git a/tools/registry.py b/tools/registry.py index c13d98502..432e1f074 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -87,6 +87,23 @@ class ToolRegistry: if check_fn and toolset not in self._toolset_checks: self._toolset_checks[toolset] = check_fn + def deregister(self, name: str) -> None: + """Remove a tool from the registry. + + Also cleans up the toolset check if no other tools remain in the + same toolset. Used by MCP dynamic tool discovery to nuke-and-repave + when a server sends ``notifications/tools/list_changed``. + """ + entry = self._tools.pop(name, None) + if entry is None: + return + # Drop the toolset check if this was the last tool in that toolset + if entry.toolset in self._toolset_checks and not any( + e.toolset == entry.toolset for e in self._tools.values() + ): + self._toolset_checks.pop(entry.toolset, None) + logger.debug("Deregistered tool: %s", name) + # ------------------------------------------------------------------ # Schema retrieval # ------------------------------------------------------------------ @@ -98,21 +115,26 @@ class ToolRegistry: are included. """ result = [] + check_results: Dict[Callable, bool] = {} for name in sorted(tool_names): entry = self._tools.get(name) if not entry: continue if entry.check_fn: - try: - if not entry.check_fn(): + if entry.check_fn not in check_results: + try: + check_results[entry.check_fn] = bool(entry.check_fn()) + except Exception: + check_results[entry.check_fn] = False if not quiet: - logger.debug("Tool %s unavailable (check failed)", name) - continue - except Exception: + logger.debug("Tool %s check raised; skipping", name) + if not check_results[entry.check_fn]: if not quiet: - logger.debug("Tool %s check raised; skipping", name) + logger.debug("Tool %s unavailable (check failed)", name) continue - result.append({"type": "function", "function": entry.schema}) + # Ensure schema always has a "name" field — use entry.name as fallback + schema_with_name = {**entry.schema, "name": entry.name} + result.append({"type": "function", "function": schema_with_name}) return result # ------------------------------------------------------------------ @@ -146,6 +168,15 @@ class ToolRegistry: """Return sorted list of all registered tool names.""" return sorted(self._tools.keys()) + def get_schema(self, name: str) -> Optional[dict]: + """Return a tool's raw schema dict, bypassing check_fn filtering. + + Useful for token estimation and introspection where availability + doesn't matter — only the schema content does. + """ + entry = self._tools.get(name) + return entry.schema if entry else None + def get_toolset_for_tool(self, name: str) -> Optional[str]: """Return the toolset a tool belongs to, or None.""" entry = self._tools.get(name) diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index cf983445b..d12eed509 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -15,6 +15,7 @@ import time logger = logging.getLogger(__name__) _TELEGRAM_TOPIC_TARGET_RE = re.compile(r"^\s*(-?\d+)(?::(\d+))?\s*$") +_FEISHU_TARGET_RE = re.compile(r"^\s*((?:oc|ou|on|chat|open)_[-A-Za-z0-9]+)(?::([-A-Za-z0-9_]+))?\s*$") _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"} _AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"} @@ -128,6 +129,8 @@ def _handle_send(args): "mattermost": Platform.MATTERMOST, "homeassistant": Platform.HOMEASSISTANT, "dingtalk": Platform.DINGTALK, + "feishu": Platform.FEISHU, + "wecom": Platform.WECOM, "email": Platform.EMAIL, "sms": Platform.SMS, } @@ -198,6 +201,10 @@ def _parse_target_ref(platform_name: str, target_ref: str): match = _TELEGRAM_TOPIC_TARGET_RE.fullmatch(target_ref) if match: return match.group(1), match.group(2), True + if platform_name == "feishu": + match = _FEISHU_TARGET_RE.fullmatch(target_ref) + if match: + return match.group(1), match.group(2), True if target_ref.lstrip("-").isdigit(): return target_ref, None, True return None, None, False @@ -280,6 +287,13 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, from gateway.platforms.discord import DiscordAdapter from gateway.platforms.slack import SlackAdapter + # Feishu adapter import is optional (requires lark-oapi) + try: + from gateway.platforms.feishu import FeishuAdapter + _feishu_available = True + except ImportError: + _feishu_available = False + media_files = media_files or [] # Platform message length limits (from adapter class attributes) @@ -288,6 +302,8 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, Platform.DISCORD: DiscordAdapter.MAX_MESSAGE_LENGTH, Platform.SLACK: SlackAdapter.MAX_MESSAGE_LENGTH, } + if _feishu_available: + _MAX_LENGTHS[Platform.FEISHU] = FeishuAdapter.MAX_MESSAGE_LENGTH # Smart-chunk the message to fit within platform limits. # For short messages or platforms without a known limit this is a no-op. @@ -343,6 +359,18 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, result = await _send_email(pconfig.extra, chat_id, chunk) elif platform == Platform.SMS: result = await _send_sms(pconfig.api_key, chat_id, chunk) + elif platform == Platform.MATTERMOST: + result = await _send_mattermost(pconfig.token, pconfig.extra, chat_id, chunk) + elif platform == Platform.MATRIX: + result = await _send_matrix(pconfig.token, pconfig.extra, chat_id, chunk) + elif platform == Platform.HOMEASSISTANT: + result = await _send_homeassistant(pconfig.token, pconfig.extra, chat_id, chunk) + elif platform == Platform.DINGTALK: + result = await _send_dingtalk(pconfig.extra, chat_id, chunk) + elif platform == Platform.FEISHU: + result = await _send_feishu(pconfig, chat_id, chunk, thread_id=thread_id) + elif platform == Platform.WECOM: + result = await _send_wecom(pconfig.extra, chat_id, chunk) else: result = {"error": f"Direct sending not yet implemented for {platform.value}"} @@ -666,6 +694,193 @@ async def _send_sms(auth_token, chat_id, message): return {"error": f"SMS send failed: {e}"} +async def _send_mattermost(token, extra, chat_id, message): + """Send via Mattermost REST API.""" + try: + import aiohttp + except ImportError: + return {"error": "aiohttp not installed. Run: pip install aiohttp"} + try: + base_url = (extra.get("url") or os.getenv("MATTERMOST_URL", "")).rstrip("/") + token = token or os.getenv("MATTERMOST_TOKEN", "") + if not base_url or not token: + return {"error": "Mattermost not configured (MATTERMOST_URL, MATTERMOST_TOKEN required)"} + url = f"{base_url}/api/v4/posts" + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session: + async with session.post(url, headers=headers, json={"channel_id": chat_id, "message": message}) as resp: + if resp.status not in (200, 201): + body = await resp.text() + return {"error": f"Mattermost API error ({resp.status}): {body}"} + data = await resp.json() + return {"success": True, "platform": "mattermost", "chat_id": chat_id, "message_id": data.get("id")} + except Exception as e: + return {"error": f"Mattermost send failed: {e}"} + + +async def _send_matrix(token, extra, chat_id, message): + """Send via Matrix Client-Server API.""" + try: + import aiohttp + except ImportError: + return {"error": "aiohttp not installed. Run: pip install aiohttp"} + try: + homeserver = (extra.get("homeserver") or os.getenv("MATRIX_HOMESERVER", "")).rstrip("/") + token = token or os.getenv("MATRIX_ACCESS_TOKEN", "") + if not homeserver or not token: + return {"error": "Matrix not configured (MATRIX_HOMESERVER, MATRIX_ACCESS_TOKEN required)"} + txn_id = f"hermes_{int(time.time() * 1000)}" + url = f"{homeserver}/_matrix/client/v3/rooms/{chat_id}/send/m.room.message/{txn_id}" + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session: + async with session.put(url, headers=headers, json={"msgtype": "m.text", "body": message}) as resp: + if resp.status not in (200, 201): + body = await resp.text() + return {"error": f"Matrix API error ({resp.status}): {body}"} + data = await resp.json() + return {"success": True, "platform": "matrix", "chat_id": chat_id, "message_id": data.get("event_id")} + except Exception as e: + return {"error": f"Matrix send failed: {e}"} + + +async def _send_homeassistant(token, extra, chat_id, message): + """Send via Home Assistant notify service.""" + try: + import aiohttp + except ImportError: + return {"error": "aiohttp not installed. Run: pip install aiohttp"} + try: + hass_url = (extra.get("url") or os.getenv("HASS_URL", "")).rstrip("/") + token = token or os.getenv("HASS_TOKEN", "") + if not hass_url or not token: + return {"error": "Home Assistant not configured (HASS_URL, HASS_TOKEN required)"} + url = f"{hass_url}/api/services/notify/notify" + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session: + async with session.post(url, headers=headers, json={"message": message, "target": chat_id}) as resp: + if resp.status not in (200, 201): + body = await resp.text() + return {"error": f"Home Assistant API error ({resp.status}): {body}"} + return {"success": True, "platform": "homeassistant", "chat_id": chat_id} + except Exception as e: + return {"error": f"Home Assistant send failed: {e}"} + + +async def _send_dingtalk(extra, chat_id, message): + """Send via DingTalk robot webhook. + + Note: The gateway's DingTalk adapter uses per-session webhook URLs from + incoming messages (dingtalk-stream SDK). For cross-platform send_message + delivery we use a static robot webhook URL instead, which must be + configured via ``DINGTALK_WEBHOOK_URL`` env var or ``webhook_url`` in the + platform's extra config. + """ + try: + import httpx + except ImportError: + return {"error": "httpx not installed"} + try: + webhook_url = extra.get("webhook_url") or os.getenv("DINGTALK_WEBHOOK_URL", "") + if not webhook_url: + return {"error": "DingTalk not configured. Set DINGTALK_WEBHOOK_URL env var or webhook_url in dingtalk platform extra config."} + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + webhook_url, + json={"msgtype": "text", "text": {"content": message}}, + ) + resp.raise_for_status() + data = resp.json() + if data.get("errcode", 0) != 0: + return {"error": f"DingTalk API error: {data.get('errmsg', 'unknown')}"} + return {"success": True, "platform": "dingtalk", "chat_id": chat_id} + except Exception as e: + return {"error": f"DingTalk send failed: {e}"} + + +async def _send_wecom(extra, chat_id, message): + """Send via WeCom using the adapter's WebSocket send pipeline.""" + try: + from gateway.platforms.wecom import WeComAdapter, check_wecom_requirements + if not check_wecom_requirements(): + return {"error": "WeCom requirements not met. Need aiohttp + WECOM_BOT_ID/SECRET."} + except ImportError: + return {"error": "WeCom adapter not available."} + + try: + from gateway.config import PlatformConfig + pconfig = PlatformConfig(extra=extra) + adapter = WeComAdapter(pconfig) + connected = await adapter.connect() + if not connected: + return {"error": f"WeCom: failed to connect — {adapter.fatal_error_message or 'unknown error'}"} + try: + result = await adapter.send(chat_id, message) + if not result.success: + return {"error": f"WeCom send failed: {result.error}"} + return {"success": True, "platform": "wecom", "chat_id": chat_id, "message_id": result.message_id} + finally: + await adapter.disconnect() + except Exception as e: + return {"error": f"WeCom send failed: {e}"} + + +async def _send_feishu(pconfig, chat_id, message, media_files=None, thread_id=None): + """Send via Feishu/Lark using the adapter's send pipeline.""" + try: + from gateway.platforms.feishu import FeishuAdapter, FEISHU_AVAILABLE + if not FEISHU_AVAILABLE: + return {"error": "Feishu dependencies not installed. Run: pip install 'hermes-agent[feishu]'"} + from gateway.platforms.feishu import FEISHU_DOMAIN, LARK_DOMAIN + except ImportError: + return {"error": "Feishu dependencies not installed. Run: pip install 'hermes-agent[feishu]'"} + + media_files = media_files or [] + + try: + adapter = FeishuAdapter(pconfig) + domain_name = getattr(adapter, "_domain_name", "feishu") + domain = FEISHU_DOMAIN if domain_name != "lark" else LARK_DOMAIN + adapter._client = adapter._build_lark_client(domain) + metadata = {"thread_id": thread_id} if thread_id else None + + last_result = None + if message.strip(): + last_result = await adapter.send(chat_id, message, metadata=metadata) + if not last_result.success: + return {"error": f"Feishu send failed: {last_result.error}"} + + for media_path, is_voice in media_files: + if not os.path.exists(media_path): + return {"error": f"Media file not found: {media_path}"} + + ext = os.path.splitext(media_path)[1].lower() + if ext in _IMAGE_EXTS: + last_result = await adapter.send_image_file(chat_id, media_path, metadata=metadata) + elif ext in _VIDEO_EXTS: + last_result = await adapter.send_video(chat_id, media_path, metadata=metadata) + elif ext in _VOICE_EXTS and is_voice: + last_result = await adapter.send_voice(chat_id, media_path, metadata=metadata) + elif ext in _AUDIO_EXTS: + last_result = await adapter.send_voice(chat_id, media_path, metadata=metadata) + else: + last_result = await adapter.send_document(chat_id, media_path, metadata=metadata) + + if not last_result.success: + return {"error": f"Feishu media send failed: {last_result.error}"} + + if last_result is None: + return {"error": "No deliverable text or media remained after processing MEDIA tags"} + + return { + "success": True, + "platform": "feishu", + "chat_id": chat_id, + "message_id": last_result.message_id, + } + except Exception as e: + return {"error": f"Feishu send failed: {e}"} + + def _check_send_message(): """Gate send_message on gateway running (always available on messaging platforms).""" platform = os.getenv("HERMES_SESSION_PLATFORM", "") diff --git a/tools/session_search_tool.py b/tools/session_search_tool.py index 235585270..3ff36f940 100644 --- a/tools/session_search_tool.py +++ b/tools/session_search_tool.py @@ -21,7 +21,7 @@ import json import logging from typing import Dict, Any, List, Optional, Union -from agent.auxiliary_client import async_call_llm +from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning MAX_SESSION_CHARS = 100_000 MAX_SUMMARY_TOKENS = 10000 @@ -161,7 +161,15 @@ async def _summarize_session( temperature=0.1, max_tokens=MAX_SUMMARY_TOKENS, ) - return response.choices[0].message.content.strip() + content = extract_content_or_reasoning(response) + if content: + return content + # Reasoning-only / empty — let the retry loop handle it + logging.warning("Session search LLM returned empty content (attempt %d/%d)", attempt + 1, max_retries) + if attempt < max_retries - 1: + await asyncio.sleep(1 * (attempt + 1)) + continue + return content except RuntimeError: logging.warning("No auxiliary model available for session summarization") return None @@ -384,23 +392,30 @@ def session_search( }, ensure_ascii=False) summaries = [] - for (session_id, match_info, _, _), result in zip(tasks, results): + for (session_id, match_info, conversation_text, _), result in zip(tasks, results): if isinstance(result, Exception): logging.warning( "Failed to summarize session %s: %s", - session_id, - result, - exc_info=True, + session_id, result, exc_info=True, ) - continue + result = None + + entry = { + "session_id": session_id, + "when": _format_timestamp(match_info.get("session_started")), + "source": match_info.get("source", "unknown"), + "model": match_info.get("model"), + } + if result: - summaries.append({ - "session_id": session_id, - "when": _format_timestamp(match_info.get("session_started")), - "source": match_info.get("source", "unknown"), - "model": match_info.get("model"), - "summary": result, - }) + entry["summary"] = result + else: + # Fallback: raw preview so matched sessions aren't silently + # dropped when the summarizer is unavailable (fixes #3409). + preview = (conversation_text[:500] + "\n…[truncated]") if conversation_text else "No preview available." + entry["summary"] = f"[Raw preview — summarization unavailable]\n{preview}" + + summaries.append(entry) return json.dumps({ "success": True, diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 045e13500..8507a6d13 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -113,6 +113,31 @@ def _validate_name(name: str) -> Optional[str]: return None +def _validate_category(category: Optional[str]) -> Optional[str]: + """Validate an optional category name used as a single directory segment.""" + if category is None: + return None + if not isinstance(category, str): + return "Category must be a string." + + category = category.strip() + if not category: + return None + if "/" in category or "\\" in category: + return ( + f"Invalid category '{category}'. Use lowercase letters, numbers, " + "hyphens, dots, and underscores. Categories must be a single directory name." + ) + if len(category) > MAX_NAME_LENGTH: + return f"Category exceeds {MAX_NAME_LENGTH} characters." + if not VALID_NAME_RE.match(category): + return ( + f"Invalid category '{category}'. Use lowercase letters, numbers, " + "hyphens, dots, and underscores. Categories must be a single directory name." + ) + return None + + def _validate_frontmatter(content: str) -> Optional[str]: """ Validate that SKILL.md content has proper frontmatter with required fields. @@ -241,6 +266,10 @@ def _create_skill(name: str, content: str, category: str = None) -> Dict[str, An if err: return {"success": False, "error": err} + err = _validate_category(category) + if err: + return {"success": False, "error": err} + # Validate content err = _validate_frontmatter(content) if err: @@ -547,6 +576,13 @@ def skill_manage( else: result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"} + if result.get("success"): + try: + from agent.prompt_builder import clear_skills_system_prompt_cache + clear_skills_system_prompt_cache(clear_snapshot=True) + except Exception: + pass + return json.dumps(result, ensure_ascii=False) diff --git a/tools/skills_guard.py b/tools/skills_guard.py index 217863af5..d22b7d294 100644 --- a/tools/skills_guard.py +++ b/tools/skills_guard.py @@ -948,9 +948,9 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult, # Call the LLM via the centralized provider router try: - from agent.auxiliary_client import call_llm + from agent.auxiliary_client import call_llm, extract_content_or_reasoning - response = call_llm( + call_kwargs = dict( provider="openrouter", model=model, messages=[{ @@ -960,7 +960,13 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult, temperature=0, max_tokens=1000, ) - llm_text = response.choices[0].message.content.strip() + response = call_llm(**call_kwargs) + llm_text = extract_content_or_reasoning(response) + + # Retry once on empty content (reasoning-only response) + if not llm_text: + response = call_llm(**call_kwargs) + llm_text = extract_content_or_reasoning(response) except Exception: # LLM audit is best-effort — don't block install if the call fails return static_result diff --git a/tools/skills_hub.py b/tools/skills_hub.py index 3814dddfe..86f8e47d1 100644 --- a/tools/skills_hub.py +++ b/tools/skills_hub.py @@ -251,6 +251,7 @@ class GitHubSource(SkillSource): {"repo": "openai/skills", "path": "skills/"}, {"repo": "anthropics/skills", "path": "skills/"}, {"repo": "VoltAgent/awesome-agent-skills", "path": "skills/"}, + {"repo": "garrytan/gstack", "path": ""}, ] def __init__(self, auth: GitHubAuth, extra_taps: Optional[List[Dict]] = None): @@ -395,7 +396,8 @@ class GitHubSource(SkillSource): if dir_name.startswith(".") or dir_name.startswith("_"): continue - skill_identifier = f"{repo}/{path.rstrip('/')}/{dir_name}" + prefix = path.rstrip("/") + skill_identifier = f"{repo}/{prefix}/{dir_name}" if prefix else f"{repo}/{dir_name}" meta = self.inspect(skill_identifier) if meta: skills.append(meta) diff --git a/tools/skills_tool.py b/tools/skills_tool.py index fef89f198..6c9e2441a 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -78,7 +78,6 @@ from pathlib import Path from typing import Dict, Any, List, Optional, Set, Tuple import yaml -from hermes_cli.config import load_env, _ENV_VAR_NAME_RE from tools.registry import registry logger = logging.getLogger(__name__) @@ -101,11 +100,28 @@ _PLATFORM_MAP = { "linux": "linux", "windows": "win32", } +_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") _EXCLUDED_SKILL_DIRS = frozenset((".git", ".github", ".hub")) _REMOTE_ENV_BACKENDS = frozenset({"docker", "singularity", "modal", "ssh", "daytona"}) _secret_capture_callback = None +def load_env() -> Dict[str, str]: + """Load profile-scoped environment variables from HERMES_HOME/.env.""" + env_path = get_hermes_home() / ".env" + env_vars: Dict[str, str] = {} + if not env_path.exists(): + return env_vars + + with env_path.open() as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, _, value = line.partition("=") + env_vars[key.strip()] = value.strip().strip("\"'") + return env_vars + + class SkillReadinessStatus(str, Enum): AVAILABLE = "available" SETUP_NEEDED = "setup_needed" @@ -120,28 +136,11 @@ def set_secret_capture_callback(callback) -> None: def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool: """Check if a skill is compatible with the current OS platform. - Skills declare platform requirements via a top-level ``platforms`` list - in their YAML frontmatter:: - - platforms: [macos] # macOS only - platforms: [macos, linux] # macOS and Linux - - Valid values: ``macos``, ``linux``, ``windows``. - - If the field is absent or empty the skill is compatible with **all** - platforms (backward-compatible default). + Delegates to ``agent.skill_utils.skill_matches_platform`` — kept here + as a public re-export so existing callers don't need updating. """ - platforms = frontmatter.get("platforms") - if not platforms: - return True # No restriction → loads everywhere - if not isinstance(platforms, list): - platforms = [platforms] - current = sys.platform - for p in platforms: - mapped = _PLATFORM_MAP.get(str(p).lower().strip(), str(p).lower().strip()) - if current.startswith(mapped): - return True - return False + from agent.skill_utils import skill_matches_platform as _impl + return _impl(frontmatter) def _normalize_prerequisite_values(value: Any) -> List[str]: @@ -372,13 +371,8 @@ def _remaining_required_environment_names( capture_result: Dict[str, Any], *, env_snapshot: Dict[str, str] | None = None, - backend: str | None = None, ) -> List[str]: - if backend is None: - backend = _get_terminal_backend_name() missing_names = set(capture_result["missing_names"]) - if backend in _REMOTE_ENV_BACKENDS: - return [entry["name"] for entry in required_env_vars] if env_snapshot is None: env_snapshot = load_env() @@ -419,40 +413,13 @@ def check_skills_requirements() -> bool: def _parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]: + """Parse YAML frontmatter from markdown content. + + Delegates to ``agent.skill_utils.parse_frontmatter`` — kept here + as a public re-export so existing callers don't need updating. """ - Parse YAML frontmatter from markdown content. - - Uses yaml.safe_load for full YAML support (nested metadata, lists, etc.) - with a fallback to simple key:value splitting for robustness. - - Args: - content: Full markdown file content - - Returns: - Tuple of (frontmatter dict, remaining content) - """ - frontmatter = {} - body = content - - if content.startswith("---"): - end_match = re.search(r"\n---\s*\n", content[3:]) - if end_match: - yaml_content = content[3 : end_match.start() + 3] - body = content[end_match.end() + 3 :] - - try: - parsed = yaml.safe_load(yaml_content) - if isinstance(parsed, dict): - frontmatter = parsed - # yaml.safe_load returns None for empty frontmatter - except yaml.YAMLError: - # Fallback: simple key:value parsing for malformed YAML - for line in yaml_content.strip().split("\n"): - if ":" in line: - key, value = line.split(":", 1) - frontmatter[key.strip()] = value.strip() - - return frontmatter, body + from agent.skill_utils import parse_frontmatter + return parse_frontmatter(content) def _get_category_from_path(skill_path: Path) -> Optional[str]: @@ -516,24 +483,13 @@ def _parse_tags(tags_value) -> List[str]: def _get_disabled_skill_names() -> Set[str]: - """Load disabled skill names from config (once per call). + """Load disabled skill names from config. - Resolves platform from ``HERMES_PLATFORM`` env var, falls back to - the global disabled list. + Delegates to ``agent.skill_utils.get_disabled_skill_names`` — kept here + as a public re-export so existing callers don't need updating. """ - import os - try: - from hermes_cli.config import load_config - config = load_config() - skills_cfg = config.get("skills", {}) - resolved_platform = os.getenv("HERMES_PLATFORM") - if resolved_platform: - platform_disabled = skills_cfg.get("platform_disabled", {}).get(resolved_platform) - if platform_disabled is not None: - return set(platform_disabled) - return set(skills_cfg.get("disabled", [])) - except Exception: - return set() + from agent.skill_utils import get_disabled_skill_names + return get_disabled_skill_names() def _is_skill_disabled(name: str, platform: str = None) -> bool: @@ -554,7 +510,7 @@ def _is_skill_disabled(name: str, platform: str = None) -> bool: def _find_all_skills(*, skip_disabled: bool = False) -> List[Dict[str, Any]]: - """Recursively find all skills in ~/.hermes/skills/. + """Recursively find all skills in ~/.hermes/skills/ and external dirs. Args: skip_disabled: If True, return ALL skills regardless of disabled @@ -564,59 +520,68 @@ def _find_all_skills(*, skip_disabled: bool = False) -> List[Dict[str, Any]]: Returns: List of skill metadata dicts (name, description, category). """ - skills = [] + from agent.skill_utils import get_external_skills_dirs - if not SKILLS_DIR.exists(): - return skills + skills = [] + seen_names: set = set() # Load disabled set once (not per-skill) disabled = set() if skip_disabled else _get_disabled_skill_names() + # Scan local dir first, then external dirs (local takes precedence) + dirs_to_scan = [] + if SKILLS_DIR.exists(): + dirs_to_scan.append(SKILLS_DIR) + dirs_to_scan.extend(get_external_skills_dirs()) - for skill_md in SKILLS_DIR.rglob("SKILL.md"): - if any(part in _EXCLUDED_SKILL_DIRS for part in skill_md.parts): - continue - - skill_dir = skill_md.parent - - try: - content = skill_md.read_text(encoding="utf-8")[:4000] - frontmatter, body = _parse_frontmatter(content) - - if not skill_matches_platform(frontmatter): + for scan_dir in dirs_to_scan: + for skill_md in scan_dir.rglob("SKILL.md"): + if any(part in _EXCLUDED_SKILL_DIRS for part in skill_md.parts): continue - name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH] - if name in disabled: + skill_dir = skill_md.parent + + try: + content = skill_md.read_text(encoding="utf-8")[:4000] + frontmatter, body = _parse_frontmatter(content) + + if not skill_matches_platform(frontmatter): + continue + + name = frontmatter.get("name", skill_dir.name)[:MAX_NAME_LENGTH] + if name in seen_names: + continue + if name in disabled: + continue + + description = frontmatter.get("description", "") + if not description: + for line in body.strip().split("\n"): + line = line.strip() + if line and not line.startswith("#"): + description = line + break + + if len(description) > MAX_DESCRIPTION_LENGTH: + description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..." + + category = _get_category_from_path(skill_md) + + seen_names.add(name) + skills.append({ + "name": name, + "description": description, + "category": category, + }) + + except (UnicodeDecodeError, PermissionError) as e: + logger.debug("Failed to read skill file %s: %s", skill_md, e) + continue + except Exception as e: + logger.debug( + "Skipping skill at %s: failed to parse: %s", skill_md, e, exc_info=True + ) continue - - description = frontmatter.get("description", "") - if not description: - for line in body.strip().split("\n"): - line = line.strip() - if line and not line.startswith("#"): - description = line - break - - if len(description) > MAX_DESCRIPTION_LENGTH: - description = description[:MAX_DESCRIPTION_LENGTH - 3] + "..." - - category = _get_category_from_path(skill_md) - - skills.append({ - "name": name, - "description": description, - "category": category, - }) - - except (UnicodeDecodeError, PermissionError) as e: - logger.debug("Failed to read skill file %s: %s", skill_md, e) - continue - except Exception as e: - logger.debug( - "Skipping skill at %s: failed to parse: %s", skill_md, e, exc_info=True - ) - continue return skills @@ -816,7 +781,15 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: JSON string with skill content or error message """ try: - if not SKILLS_DIR.exists(): + from agent.skill_utils import get_external_skills_dirs + + # Build list of all skill directories to search + all_dirs = [] + if SKILLS_DIR.exists(): + all_dirs.append(SKILLS_DIR) + all_dirs.extend(get_external_skills_dirs()) + + if not all_dirs: return json.dumps( { "success": False, @@ -828,27 +801,37 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: skill_dir = None skill_md = None - # Try direct path first (e.g., "mlops/axolotl") - direct_path = SKILLS_DIR / name - if direct_path.is_dir() and (direct_path / "SKILL.md").exists(): - skill_dir = direct_path - skill_md = direct_path / "SKILL.md" - elif direct_path.with_suffix(".md").exists(): - skill_md = direct_path.with_suffix(".md") + # Search all dirs: local first, then external (first match wins) + for search_dir in all_dirs: + # Try direct path first (e.g., "mlops/axolotl") + direct_path = search_dir / name + if direct_path.is_dir() and (direct_path / "SKILL.md").exists(): + skill_dir = direct_path + skill_md = direct_path / "SKILL.md" + break + elif direct_path.with_suffix(".md").exists(): + skill_md = direct_path.with_suffix(".md") + break - # Search by directory name + # Search by directory name across all dirs if not skill_md: - for found_skill_md in SKILLS_DIR.rglob("SKILL.md"): - if found_skill_md.parent.name == name: - skill_dir = found_skill_md.parent - skill_md = found_skill_md + for search_dir in all_dirs: + for found_skill_md in search_dir.rglob("SKILL.md"): + if found_skill_md.parent.name == name: + skill_dir = found_skill_md.parent + skill_md = found_skill_md + break + if skill_md: break # Legacy: flat .md files if not skill_md: - for found_md in SKILLS_DIR.rglob(f"{name}.md"): - if found_md.name != "SKILL.md": - skill_md = found_md + for search_dir in all_dirs: + for found_md in search_dir.rglob(f"{name}.md"): + if found_md.name != "SKILL.md": + skill_md = found_md + break + if skill_md: break if not skill_md or not skill_md.exists(): @@ -875,12 +858,21 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: ensure_ascii=False, ) - # Security: warn if skill is loaded from outside the trusted skills directory + # Security: warn if skill is loaded from outside trusted directories + # (local skills dir + configured external_dirs are all trusted) + _outside_skills_dir = True + _trusted_dirs = [SKILLS_DIR.resolve()] try: - skill_md.resolve().relative_to(SKILLS_DIR.resolve()) - _outside_skills_dir = False - except ValueError: - _outside_skills_dir = True + _trusted_dirs.extend(d.resolve() for d in all_dirs[1:]) + except Exception: + pass + for _td in _trusted_dirs: + try: + skill_md.resolve().relative_to(_td) + _outside_skills_dir = False + break + except ValueError: + continue # Security: detect common prompt injection patterns _INJECTION_PATTERNS = [ @@ -1118,7 +1110,11 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: if script_files: linked_files["scripts"] = script_files - rel_path = str(skill_md.relative_to(SKILLS_DIR)) + try: + rel_path = str(skill_md.relative_to(SKILLS_DIR)) + except ValueError: + # External skill — use path relative to the skill's own parent dir + rel_path = str(skill_md.relative_to(skill_md.parent.parent)) if skill_md.parent.parent else skill_md.name skill_name = frontmatter.get( "name", skill_md.stem if not skill_dir else skill_dir.name ) @@ -1131,8 +1127,7 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: missing_required_env_vars = [ e for e in required_env_vars - if backend in _REMOTE_ENV_BACKENDS - or not _is_env_var_persisted(e["name"], env_snapshot) + if not _is_env_var_persisted(e["name"], env_snapshot) ] capture_result = _capture_required_environment_variables( skill_name, @@ -1144,7 +1139,6 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: required_env_vars, capture_result, env_snapshot=env_snapshot, - backend=backend, ) setup_needed = bool(remaining_missing_required_envs) @@ -1168,6 +1162,27 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: exc_info=True, ) + # Register credential files for mounting into remote sandboxes + # (Modal, Docker). Files that exist on the host are registered; + # missing ones are added to the setup_needed indicators. + required_cred_files_raw = frontmatter.get("required_credential_files", []) + if not isinstance(required_cred_files_raw, list): + required_cred_files_raw = [] + missing_cred_files: list = [] + if required_cred_files_raw: + try: + from tools.credential_files import register_credential_files + + missing_cred_files = register_credential_files(required_cred_files_raw) + if missing_cred_files: + setup_needed = True + except Exception: + logger.debug( + "Could not register credential files for skill %s", + skill_name, + exc_info=True, + ) + result = { "success": True, "name": skill_name, @@ -1183,6 +1198,7 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: "required_environment_variables": required_env_vars, "required_commands": [], "missing_required_environment_variables": remaining_missing_required_envs, + "missing_credential_files": missing_cred_files, "missing_required_commands": [], "setup_needed": setup_needed, "setup_skipped": capture_result["setup_skipped"], @@ -1201,6 +1217,8 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: if setup_needed: missing_items = [ f"env ${env_name}" for env_name in remaining_missing_required_envs + ] + [ + f"file {path}" for path in missing_cred_files ] setup_note = _build_setup_note( SkillReadinessStatus.SETUP_NEEDED, diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index d9d2fa4f7..f4ffeec79 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -52,6 +52,7 @@ logger = logging.getLogger(__name__) # long-running subprocesses immediately instead of blocking until timeout. # --------------------------------------------------------------------------- from tools.interrupt import is_interrupted, _interrupt_event # noqa: F401 — re-exported +# display_hermes_home imported lazily at call site (stale-module safety during hermes update) def ensure_minisweagent_on_path(_repo_root: Path | None = None) -> None: @@ -69,7 +70,7 @@ from tools.tool_backend_helpers import ( coerce_modal_mode, has_direct_modal_credentials, managed_nous_tools_enabled, - normalize_modal_mode, + resolve_modal_backend_state, ) @@ -172,7 +173,8 @@ def _handle_sudo_failure(output: str, env_type: str) -> str: for failure in sudo_failures: if failure in output: - return output + "\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to ~/.hermes/.env on the agent machine." + from hermes_constants import display_hermes_home as _dhh + return output + f"\n\n💡 Tip: To enable sudo over messaging, add SUDO_PASSWORD to {_dhh()}/.env on the agent machine." return output @@ -546,30 +548,12 @@ def _get_env_config() -> Dict[str, Any]: def _get_modal_backend_state(modal_mode: object | None) -> Dict[str, Any]: """Resolve direct vs managed Modal backend selection.""" - requested_mode = coerce_modal_mode(modal_mode) - normalized_mode = normalize_modal_mode(modal_mode) - has_direct = has_direct_modal_credentials() - managed_ready = is_managed_tool_gateway_ready("modal") - managed_mode_blocked = ( - requested_mode == "managed" and not managed_nous_tools_enabled() + return resolve_modal_backend_state( + modal_mode, + has_direct=has_direct_modal_credentials(), + managed_ready=is_managed_tool_gateway_ready("modal"), ) - if normalized_mode == "managed": - selected_backend = "managed" if managed_ready else None - elif normalized_mode == "direct": - selected_backend = "direct" if has_direct else None - else: - selected_backend = "direct" if has_direct else "managed" if managed_ready else None - - return { - "requested_mode": requested_mode, - "mode": normalized_mode, - "has_direct": has_direct, - "managed_ready": managed_ready, - "managed_mode_blocked": managed_mode_blocked, - "selected_backend": selected_backend, - } - def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ssh_config: dict = None, container_config: dict = None, @@ -1347,8 +1331,8 @@ def check_terminal_requirements() -> bool: ) return False - if importlib.util.find_spec("swerex") is None: - logger.error("swe-rex is required for direct modal terminal backend: pip install 'swe-rex[modal]'") + if importlib.util.find_spec("modal") is None: + logger.error("modal is required for direct modal terminal backend: pip install modal") return False return True @@ -1406,7 +1390,8 @@ if __name__ == "__main__": print(f" TERMINAL_MODAL_IMAGE: {os.getenv('TERMINAL_MODAL_IMAGE', default_img)}") print(f" TERMINAL_DAYTONA_IMAGE: {os.getenv('TERMINAL_DAYTONA_IMAGE', default_img)}") print(f" TERMINAL_CWD: {os.getenv('TERMINAL_CWD', os.getcwd())}") - print(f" TERMINAL_SANDBOX_DIR: {os.getenv('TERMINAL_SANDBOX_DIR', '~/.hermes/sandboxes')}") + from hermes_constants import display_hermes_home as _dhh + print(f" TERMINAL_SANDBOX_DIR: {os.getenv('TERMINAL_SANDBOX_DIR', f'{_dhh()}/sandboxes')}") print(f" TERMINAL_TIMEOUT: {os.getenv('TERMINAL_TIMEOUT', '60')}") print(f" TERMINAL_LIFETIME_SECONDS: {os.getenv('TERMINAL_LIFETIME_SECONDS', '300')}") diff --git a/tools/tool_backend_helpers.py b/tools/tool_backend_helpers.py index 4b8d9d157..b65e19174 100644 --- a/tools/tool_backend_helpers.py +++ b/tools/tool_backend_helpers.py @@ -4,6 +4,7 @@ from __future__ import annotations import os from pathlib import Path +from typing import Any, Dict from utils import env_var_enabled @@ -33,10 +34,7 @@ def coerce_modal_mode(value: object | None) -> str: def normalize_modal_mode(value: object | None) -> str: """Return a normalized modal execution mode.""" - mode = coerce_modal_mode(value) - if mode == "managed" and not managed_nous_tools_enabled(): - return "direct" - return mode + return coerce_modal_mode(value) def has_direct_modal_credentials() -> bool: @@ -47,6 +45,42 @@ def has_direct_modal_credentials() -> bool: ) +def resolve_modal_backend_state( + modal_mode: object | None, + *, + has_direct: bool, + managed_ready: bool, +) -> Dict[str, Any]: + """Resolve direct vs managed Modal backend selection. + + Semantics: + - ``direct`` means direct-only + - ``managed`` means managed-only + - ``auto`` prefers managed when available, then falls back to direct + """ + requested_mode = coerce_modal_mode(modal_mode) + normalized_mode = normalize_modal_mode(modal_mode) + managed_mode_blocked = ( + requested_mode == "managed" and not managed_nous_tools_enabled() + ) + + if normalized_mode == "managed": + selected_backend = "managed" if managed_nous_tools_enabled() and managed_ready else None + elif normalized_mode == "direct": + selected_backend = "direct" if has_direct else None + else: + selected_backend = "managed" if managed_nous_tools_enabled() and managed_ready else "direct" if has_direct else None + + return { + "requested_mode": requested_mode, + "mode": normalized_mode, + "has_direct": has_direct, + "managed_ready": managed_ready, + "managed_mode_blocked": managed_mode_blocked, + "selected_backend": selected_backend, + } + + def resolve_openai_audio_api_key() -> str: """Prefer the voice-tools key, but fall back to the normal OpenAI key.""" return ( diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index 4a1f7ed51..976a59d40 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -12,7 +12,7 @@ Provides speech-to-text transcription with three providers: Used by the messaging gateway to automatically transcribe voice messages sent by users on Telegram, Discord, WhatsApp, Slack, and Signal. -Supported input formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg +Supported input formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg, aac Usage:: @@ -74,7 +74,7 @@ COMMON_LOCAL_BIN_DIRS = ("/opt/homebrew/bin", "/usr/local/bin") GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1") OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1") -SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"} +SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg", ".aac"} LOCAL_NATIVE_AUDIO_FORMATS = {".wav", ".aiff", ".aif"} MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB diff --git a/tools/tts_tool.py b/tools/tts_tool.py index 9210c3318..6487dbfa4 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -34,7 +34,6 @@ import tempfile import threading import uuid from pathlib import Path -from hermes_constants import get_hermes_home from typing import Callable, Dict, Any, Optional from urllib.parse import urljoin @@ -79,7 +78,12 @@ DEFAULT_ELEVENLABS_STREAMING_MODEL_ID = "eleven_flash_v2_5" DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts" DEFAULT_OPENAI_VOICE = "alloy" DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1" -DEFAULT_OUTPUT_DIR = str(get_hermes_home() / "audio_cache") + +def _get_default_output_dir() -> str: + from hermes_constants import get_hermes_dir + return str(get_hermes_dir("cache/audio", "audio_cache")) + +DEFAULT_OUTPUT_DIR = _get_default_output_dir() MAX_TEXT_LENGTH = 4000 @@ -107,7 +111,7 @@ def _load_tts_config() -> Dict[str, Any]: def _get_provider(tts_config: Dict[str, Any]) -> str: """Get the configured TTS provider name.""" - return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip() + return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip() # =========================================================================== diff --git a/tools/vision_tools.py b/tools/vision_tools.py index fe81032b0..47b406846 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -37,8 +37,9 @@ from pathlib import Path from typing import Any, Awaitable, Dict, Optional from urllib.parse import urlparse import httpx -from agent.auxiliary_client import async_call_llm +from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning from tools.debug_helpers import DebugSession +from tools.website_policy import check_website_access logger = logging.getLogger(__name__) @@ -76,6 +77,28 @@ def _validate_image_url(url: str) -> bool: return True +def _detect_image_mime_type(image_path: Path) -> Optional[str]: + """Return a MIME type when the file looks like a supported image.""" + with image_path.open("rb") as f: + header = f.read(64) + + if header.startswith(b"\x89PNG\r\n\x1a\n"): + return "image/png" + if header.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + if header.startswith((b"GIF87a", b"GIF89a")): + return "image/gif" + if header.startswith(b"BM"): + return "image/bmp" + if len(header) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"WEBP": + return "image/webp" + if image_path.suffix.lower() == ".svg": + head = image_path.read_text(encoding="utf-8", errors="ignore")[:4096].lower() + if " Path: """ Download an image from a URL to a local destination (async) with retry logic. @@ -115,6 +138,10 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = last_error = None for attempt in range(max_retries): try: + blocked = check_website_access(image_url) + if blocked: + raise PermissionError(blocked["message"]) + # Download the image with appropriate headers using async httpx # Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum) # SSRF: event_hooks validates each redirect target against private IP ranges @@ -131,6 +158,11 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = }, ) response.raise_for_status() + + final_url = str(response.url) + blocked = check_website_access(final_url) + if blocked: + raise PermissionError(blocked["message"]) # Save the image content destination.write_bytes(response.content) @@ -257,6 +289,7 @@ async def vision_analyze_tool( # Track whether we should clean up the file after processing. # Local files (e.g. from the image cache) should NOT be deleted. should_cleanup = True + detected_mime_type = None try: from tools.interrupt import is_interrupted @@ -275,6 +308,9 @@ async def vision_analyze_tool( should_cleanup = False # Don't delete cached/local files elif _validate_image_url(image_url): # Remote URL -- download to a temporary location + blocked = check_website_access(image_url) + if blocked: + raise PermissionError(blocked["message"]) logger.info("Downloading image from URL...") temp_dir = Path("./temp_vision_images") temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg" @@ -289,10 +325,14 @@ async def vision_analyze_tool( image_size_bytes = temp_image_path.stat().st_size image_size_kb = image_size_bytes / 1024 logger.info("Image ready (%.1f KB)", image_size_kb) + + detected_mime_type = _detect_image_mime_type(temp_image_path) + if not detected_mime_type: + raise ValueError("Only real image files are supported for vision analysis.") # Convert image to base64 data URL logger.info("Converting image to base64...") - image_data_url = _image_to_base64_data_url(temp_image_path) + image_data_url = _image_to_base64_data_url(temp_image_path, mime_type=detected_mime_type) # Calculate size in KB for better readability data_size_kb = len(image_data_url) / 1024 logger.info("Image converted to base64 (%.1f KB)", data_size_kb) @@ -346,8 +386,15 @@ async def vision_analyze_tool( call_kwargs["model"] = model response = await async_call_llm(**call_kwargs) - # Extract the analysis - analysis = response.choices[0].message.content.strip() + # Extract the analysis — fall back to reasoning if content is empty + analysis = extract_content_or_reasoning(response) + + # Retry once on empty content (reasoning-only response) + if not analysis: + logger.warning("Vision LLM returned empty content, retrying once") + response = await async_call_llm(**call_kwargs) + analysis = extract_content_or_reasoning(response) + analysis_length = len(analysis) logger.info("Image analysis completed (%s characters)", analysis_length) diff --git a/tools/web_tools.py b/tools/web_tools.py index 7e9e84483..f9b8dfea9 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -13,6 +13,7 @@ Available tools: - web_crawl_tool: Crawl websites with specific instructions Backend compatibility: +- Exa: https://exa.ai (search, extract) - Firecrawl: https://docs.firecrawl.dev/introduction (search, extract, crawl; direct or derived firecrawl-gateway. for Nous Subscribers) - Parallel: https://docs.parallel.ai (search, extract) - Tavily: https://tavily.com (search, extract, crawl) @@ -47,7 +48,11 @@ import asyncio from typing import List, Dict, Any, Optional import httpx from firecrawl import Firecrawl -from agent.auxiliary_client import get_async_text_auxiliary_client +from agent.auxiliary_client import ( + async_call_llm, + extract_content_or_reasoning, + get_async_text_auxiliary_client, +) from tools.debug_helpers import DebugSession from tools.managed_tool_gateway import ( build_vendor_gateway_url, @@ -82,8 +87,8 @@ def _get_backend() -> str: Falls back to whichever API key is present for users who configured keys manually without running setup. """ - configured = _load_web_config().get("backend", "").lower().strip() - if configured in ("parallel", "firecrawl", "tavily"): + configured = (_load_web_config().get("backend") or "").lower().strip() + if configured in ("parallel", "firecrawl", "tavily", "exa"): return configured # Fallback for manual / legacy config — use whichever key is present. @@ -94,6 +99,9 @@ def _get_backend() -> str: ) has_parallel = _has_env("PARALLEL_API_KEY") has_tavily = _has_env("TAVILY_API_KEY") + has_exa = _has_env("EXA_API_KEY") + if has_exa and not has_firecrawl and not has_parallel and not has_tavily: + return "exa" if has_tavily and not has_firecrawl and not has_parallel: return "tavily" if has_parallel and not has_firecrawl: @@ -105,6 +113,8 @@ def _get_backend() -> str: def _is_backend_available(backend: str) -> bool: """Return True when the selected backend is currently usable.""" + if backend == "exa": + return _has_env("EXA_API_KEY") if backend == "parallel": return _has_env("PARALLEL_API_KEY") if backend == "firecrawl": @@ -178,6 +188,7 @@ def _firecrawl_backend_help_suffix() -> str: def _web_requires_env() -> list[str]: """Return tool metadata env vars for the currently enabled web backends.""" requires = [ + "EXA_API_KEY", "PARALLEL_API_KEY", "TAVILY_API_KEY", "FIRECRAWL_API_KEY", @@ -621,18 +632,32 @@ Create a markdown summary that captures all key information in a well-organized, if aux_client is None or not effective_model: logger.warning("No auxiliary model available for web content processing") return None - from agent.auxiliary_client import auxiliary_max_tokens_param - response = await aux_client.chat.completions.create( - model=effective_model, - messages=[ + call_kwargs = { + "task": "web_extract", + "model": effective_model, + "messages": [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} + {"role": "user", "content": user_prompt}, ], - temperature=0.1, - **auxiliary_max_tokens_param(max_tokens), - **({} if not extra_body else {"extra_body": extra_body}), - ) - return response.choices[0].message.content.strip() + "temperature": 0.1, + "max_tokens": max_tokens, + } + if extra_body: + call_kwargs["extra_body"] = extra_body + response = await async_call_llm(**call_kwargs) + content = extract_content_or_reasoning(response) + if content: + return content + # Reasoning-only / empty response — let the retry loop handle it + logger.warning("LLM returned empty content (attempt %d/%d), retrying", attempt + 1, max_retries) + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + retry_delay = min(retry_delay * 2, 60) + continue + return content # Return whatever we got after exhausting retries + except RuntimeError: + logger.warning("No auxiliary model available for web content processing") + return None except Exception as api_error: last_error = api_error if attempt < max_retries - 1: @@ -744,19 +769,26 @@ Create a single, unified markdown summary.""" fallback = fallback[:max_output_size] + "\n\n[... truncated ...]" return fallback - from agent.auxiliary_client import auxiliary_max_tokens_param - response = await aux_client.chat.completions.create( - model=effective_model, - messages=[ + call_kwargs = { + "task": "web_extract", + "model": effective_model, + "messages": [ {"role": "system", "content": "You synthesize multiple summaries into one cohesive, comprehensive summary. Be thorough but concise."}, - {"role": "user", "content": synthesis_prompt} + {"role": "user", "content": synthesis_prompt}, ], - temperature=0.1, - **auxiliary_max_tokens_param(20000), - **({} if not extra_body else {"extra_body": extra_body}), - ) - final_summary = response.choices[0].message.content.strip() - + "temperature": 0.1, + "max_tokens": 20000, + } + if extra_body: + call_kwargs["extra_body"] = extra_body + response = await async_call_llm(**call_kwargs) + final_summary = extract_content_or_reasoning(response) + + # Retry once on empty content (reasoning-only response) + if not final_summary: + logger.warning("Synthesis LLM returned empty content, retrying once") + response = await async_call_llm(**call_kwargs) + final_summary = extract_content_or_reasoning(response) # Enforce hard cap if len(final_summary) > max_output_size: final_summary = final_summary[:max_output_size] + "\n\n[... summary truncated for context management ...]" @@ -810,6 +842,91 @@ def clean_base64_images(text: str) -> str: return cleaned_text +# ─── Exa Client ────────────────────────────────────────────────────────────── + +_exa_client = None + +def _get_exa_client(): + """Get or create the Exa client (lazy initialization). + + Requires EXA_API_KEY environment variable. + """ + from exa_py import Exa + global _exa_client + if _exa_client is None: + api_key = os.getenv("EXA_API_KEY") + if not api_key: + raise ValueError( + "EXA_API_KEY environment variable not set. " + "Get your API key at https://exa.ai" + ) + _exa_client = Exa(api_key=api_key) + _exa_client.headers["x-exa-integration"] = "hermes-agent" + return _exa_client + + +# ─── Exa Search & Extract Helpers ───────────────────────────────────────────── + +def _exa_search(query: str, limit: int = 10) -> dict: + """Search using the Exa SDK and return results as a dict.""" + from tools.interrupt import is_interrupted + if is_interrupted(): + return {"error": "Interrupted", "success": False} + + logger.info("Exa search: '%s' (limit=%d)", query, limit) + response = _get_exa_client().search( + query, + num_results=limit, + contents={ + "highlights": True, + }, + ) + + web_results = [] + for i, result in enumerate(response.results or []): + highlights = result.highlights or [] + web_results.append({ + "url": result.url or "", + "title": result.title or "", + "description": " ".join(highlights) if highlights else "", + "position": i + 1, + }) + + return {"success": True, "data": {"web": web_results}} + + +def _exa_extract(urls: List[str]) -> List[Dict[str, Any]]: + """Extract content from URLs using the Exa SDK. + + Returns a list of result dicts matching the structure expected by the + LLM post-processing pipeline (url, title, content, metadata). + """ + from tools.interrupt import is_interrupted + if is_interrupted(): + return [{"url": u, "error": "Interrupted", "title": ""} for u in urls] + + logger.info("Exa extract: %d URL(s)", len(urls)) + response = _get_exa_client().get_contents( + urls, + text=True, + ) + + results = [] + for result in response.results or []: + content = result.text or "" + url = result.url or "" + title = result.title or "" + results.append({ + "url": url, + "title": title, + "content": content, + "raw_content": content, + "metadata": {"sourceURL": url, "title": title}, + }) + + return results + + # ─── Parallel Search & Extract Helpers ──────────────────────────────────────── def _parallel_search(query: str, limit: int = 5) -> dict: @@ -947,6 +1064,15 @@ def web_search_tool(query: str, limit: int = 5) -> str: _debug.save() return result_json + if backend == "exa": + response_data = _exa_search(query, limit) + debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", [])) + result_json = json.dumps(response_data, indent=2, ensure_ascii=False) + debug_call_data["final_response_size"] = len(result_json) + _debug.log_call("web_search_tool", debug_call_data) + _debug.save() + return result_json + if backend == "tavily": logger.info("Tavily search: '%s' (limit: %d)", query, limit) raw = _tavily_request("search", { @@ -1074,6 +1200,8 @@ async def web_extract_tool( if backend == "parallel": results = await _parallel_extract(safe_urls) + elif backend == "exa": + results = _exa_extract(safe_urls) elif backend == "tavily": logger.info("Tavily extract: %d URL(s)", len(safe_urls)) raw = _tavily_request("extract", { @@ -1737,9 +1865,9 @@ def check_firecrawl_api_key() -> bool: def check_web_api_key() -> bool: """Check whether the configured web backend is available.""" configured = _load_web_config().get("backend", "").lower().strip() - if configured in ("parallel", "firecrawl", "tavily"): + if configured in ("exa", "parallel", "firecrawl", "tavily"): return _is_backend_available(configured) - return any(_is_backend_available(backend) for backend in ("parallel", "firecrawl", "tavily")) + return any(_is_backend_available(backend) for backend in ("exa", "parallel", "firecrawl", "tavily")) def check_auxiliary_model() -> bool: @@ -1771,7 +1899,9 @@ if __name__ == "__main__": if web_available: backend = _get_backend() print(f"✅ Web backend: {backend}") - if backend == "parallel": + if backend == "exa": + print(" Using Exa API (https://exa.ai)") + elif backend == "parallel": print(" Using Parallel API (https://parallel.ai)") elif backend == "tavily": print(" Using Tavily API (https://tavily.com)") @@ -1787,7 +1917,7 @@ if __name__ == "__main__": else: print("❌ No web search backend configured") print( - "Set PARALLEL_API_KEY, TAVILY_API_KEY, FIRECRAWL_API_KEY, FIRECRAWL_API_URL" + "Set EXA_API_KEY, PARALLEL_API_KEY, TAVILY_API_KEY, FIRECRAWL_API_KEY, FIRECRAWL_API_URL" f"{_firecrawl_backend_help_suffix()}" ) diff --git a/toolsets.py b/toolsets.py index 1f6a0674d..ad762555b 100644 --- a/toolsets.py +++ b/toolsets.py @@ -248,6 +248,42 @@ TOOLSETS = { ], "includes": [] }, + + "hermes-api-server": { + "description": "OpenAI-compatible API server — full agent tools accessible via HTTP (no interactive UI tools like clarify or send_message)", + "tools": [ + # Web + "web_search", "web_extract", + # Terminal + process management + "terminal", "process", + # File manipulation + "read_file", "write_file", "patch", "search_files", + # Vision + image generation + "vision_analyze", "image_generate", + # MoA + "mixture_of_agents", + # Skills + "skills_list", "skill_view", "skill_manage", + # Browser automation + "browser_navigate", "browser_snapshot", "browser_click", + "browser_type", "browser_scroll", "browser_back", + "browser_press", "browser_close", "browser_get_images", + "browser_vision", "browser_console", + # Planning & memory + "todo", "memory", + # Session history search + "session_search", + # Code execution + delegation + "execute_code", "delegate_task", + # Cronjob management + "cronjob", + # Home Assistant smart home control (gated on HASS_TOKEN via check_fn) + "ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service", + # Honcho memory tools (gated on honcho being active via check_fn) + "honcho_context", "honcho_profile", "honcho_search", "honcho_conclude", + ], + "includes": [] + }, "hermes-cli": { "description": "Full interactive CLI toolset - all default tools plus cronjob management", @@ -297,6 +333,36 @@ TOOLSETS = { "includes": [] }, + "hermes-mattermost": { + "description": "Mattermost bot toolset - self-hosted team messaging (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + + "hermes-matrix": { + "description": "Matrix bot toolset - decentralized encrypted messaging (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + + "hermes-dingtalk": { + "description": "DingTalk bot toolset - enterprise messaging platform (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + + "hermes-feishu": { + "description": "Feishu/Lark bot toolset - enterprise messaging via Feishu/Lark (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + + "hermes-wecom": { + "description": "WeCom bot toolset - enterprise WeChat messaging (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-sms": { "description": "SMS bot toolset - interact with Hermes via SMS (Twilio)", "tools": _HERMES_CORE_TOOLS, @@ -306,7 +372,7 @@ TOOLSETS = { "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-homeassistant", "hermes-email", "hermes-sms"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom"] } } diff --git a/trajectory_compressor.py b/trajectory_compressor.py index 1bfed6bfc..fd69cd18a 100644 --- a/trajectory_compressor.py +++ b/trajectory_compressor.py @@ -123,7 +123,7 @@ class CompressionConfig: # Summarization if 'summarization' in data: config.summarization_model = data['summarization'].get('model', config.summarization_model) - config.base_url = data['summarization'].get('base_url', config.base_url) + config.base_url = data['summarization'].get('base_url') or config.base_url config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env) config.temperature = data['summarization'].get('temperature', config.temperature) config.max_retries = data['summarization'].get('max_retries', config.max_retries) @@ -386,7 +386,7 @@ class TrajectoryCompressor: def _detect_provider(self) -> str: """Detect the provider name from the configured base_url.""" - url = self.config.base_url.lower() + url = (self.config.base_url or "").lower() if "openrouter" in url: return "openrouter" if "nousresearch.com" in url: diff --git a/uv.lock b/uv.lock index 48720c67f..63161f8a6 100644 --- a/uv.lock +++ b/uv.lock @@ -376,15 +376,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/0a/0896b829a39b5669a2d811e1a79598de661693685cd62b31f11d0c18e65b/av-17.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dba98603fc4665b4f750de86fbaf6c0cfaece970671a9b529e0e3d1711e8367e", size = 22071058, upload-time = "2026-03-14T14:38:43.663Z" }, ] -[[package]] -name = "bashlex" -version = "0.18" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/76/60/aae0bb54f9af5e0128ba90eb83d8d0d506ee8f0475c4fdda3deeda20b1d2/bashlex-0.18.tar.gz", hash = "sha256:5bb03a01c6d5676338c36fd1028009c8ad07e7d61d8a1ce3f513b7fff52796ee", size = 68742, upload-time = "2023-01-18T15:21:26.402Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f4/be/6985abb1011fda8a523cfe21ed9629e397d6e06fb5bae99750402b25c95b/bashlex-0.18-py2.py3-none-any.whl", hash = "sha256:91d73a23a3e51711919c1c899083890cdecffc91d8c088942725ac13e9dcfffa", size = 69539, upload-time = "2023-01-18T15:21:24.167Z" }, -] - [[package]] name = "blinker" version = "1.9.0" @@ -394,34 +385,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, ] -[[package]] -name = "boto3" -version = "1.42.57" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "botocore" }, - { name = "jmespath" }, - { name = "s3transfer" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b6/86/46898eaae75ab2185bcf2af406fb4cd1646a0bc277d5dab8ca36c30b7e5e/boto3-1.42.57.tar.gz", hash = "sha256:b598f1705f231f118a81abbfde0c5b52879b1b1997a1aba513f04d61e7b12cbd", size = 112799, upload-time = "2026-02-25T20:31:59.362Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/21/854be1e1829a33450079c1a05f89ef03a2a44bdad590de3e10dc09d73cbd/boto3-1.42.57-py3-none-any.whl", hash = "sha256:74f47051e3b741a0c1e64d57b891076c2c68f8d7b98aee36b044fab1849b4823", size = 140554, upload-time = "2026-02-25T20:31:53.215Z" }, -] - -[[package]] -name = "botocore" -version = "1.42.57" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jmespath" }, - { name = "python-dateutil" }, - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/cd/9c/f9e289f44985fe5b2e3ffc127a55cf7e87ef88499f5a8001db86d74ecfb1/botocore-1.42.57.tar.gz", hash = "sha256:51f94c602b687a70aa11d8bbea2b741b87b0aef7bddb43e5386247bf4311c479", size = 14940952, upload-time = "2026-02-25T20:31:42.049Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/bd/89d0fdb65488d6ee40194268b07316433b41f3aa3f242676ed804c3200f5/botocore-1.42.57-py3-none-any.whl", hash = "sha256:0d26c09955e52ac5090d9cf9e218542df81670077049a606be7c3bd235208e67", size = 14614741, upload-time = "2026-02-25T20:31:39.081Z" }, -] - [[package]] name = "cachetools" version = "5.5.2" @@ -1637,7 +1600,7 @@ wheels = [ [[package]] name = "hermes-agent" -version = "0.4.0" +version = "0.5.0" source = { editable = "." } dependencies = [ { name = "anthropic" }, @@ -1674,6 +1637,7 @@ all = [ { name = "elevenlabs" }, { name = "honcho-ai" }, { name = "mcp" }, + { name = "modal" }, { name = "numpy" }, { name = "ptyprocess", marker = "sys_platform != 'win32'" }, { name = "pytest" }, @@ -1685,7 +1649,6 @@ all = [ { name = "slack-bolt" }, { name = "slack-sdk" }, { name = "sounddevice" }, - { name = "swe-rex", extra = ["modal"] }, ] cli = [ { name = "simple-term-menu" }, @@ -1725,7 +1688,7 @@ messaging = [ { name = "slack-sdk" }, ] modal = [ - { name = "swe-rex", extra = ["modal"] }, + { name = "modal" }, ] pty = [ { name = "ptyprocess", marker = "sys_platform != 'win32'" }, @@ -1758,7 +1721,7 @@ yc-bench = [ [package.metadata] requires-dist = [ - { name = "agent-client-protocol", marker = "extra == 'acp'", specifier = ">=0.8.1,<1.0" }, + { name = "agent-client-protocol", marker = "extra == 'acp'", specifier = ">=0.8.1,<0.9" }, { name = "aiohttp", marker = "extra == 'homeassistant'", specifier = ">=3.9.0,<4" }, { name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.13.3,<4" }, { name = "aiohttp", marker = "extra == 'sms'", specifier = ">=3.9.0,<4" }, @@ -1797,6 +1760,7 @@ requires-dist = [ { name = "matrix-nio", extras = ["e2e"], marker = "extra == 'matrix'", specifier = ">=0.24.0,<1" }, { name = "mcp", marker = "extra == 'dev'", specifier = ">=1.2.0,<2" }, { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.2.0,<2" }, + { name = "modal", marker = "extra == 'modal'", specifier = ">=1.0.0,<2" }, { name = "numpy", marker = "extra == 'voice'", specifier = ">=1.24.0,<3" }, { name = "openai", specifier = ">=2.21.0,<3" }, { name = "parallel-web", specifier = ">=0.4.2,<1" }, @@ -1819,7 +1783,6 @@ requires-dist = [ { name = "slack-sdk", marker = "extra == 'messaging'", specifier = ">=3.27.0,<4" }, { name = "slack-sdk", marker = "extra == 'slack'", specifier = ">=3.27.0,<4" }, { name = "sounddevice", marker = "extra == 'voice'", specifier = ">=0.4.6,<1" }, - { name = "swe-rex", extras = ["modal"], marker = "extra == 'modal'", specifier = ">=1.4.0,<2" }, { name = "tenacity", specifier = ">=9.1.4,<10" }, { name = "tinker", marker = "extra == 'rl'", git = "https://github.com/thinking-machines-lab/tinker.git" }, { name = "uvicorn", extras = ["standard"], marker = "extra == 'rl'", specifier = ">=0.24.0,<1" }, @@ -2150,15 +2113,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/67/8a/a342b2f0251f3dac4ca17618265d93bf244a2a4d089126e81e4c1056ac50/jiter-0.13.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bb00b6d26db67a05fe3e12c76edc75f32077fb51deed13822dc648fa373bc19", size = 343768, upload-time = "2026-02-02T12:37:55.055Z" }, ] -[[package]] -name = "jmespath" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, -] - [[package]] name = "joblib" version = "1.5.3" @@ -3256,18 +3210,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/41/19c65578ef9a54b3083253c68a607f099642747168fe00f3a2bceb7c3a34/peewee-3.19.0-py3-none-any.whl", hash = "sha256:de220b94766e6008c466e00ce4ba5299b9a832117d9eb36d45d0062f3cfd7417", size = 411885, upload-time = "2026-01-07T17:24:58.33Z" }, ] -[[package]] -name = "pexpect" -version = "4.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "ptyprocess" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, -] - [[package]] name = "pillow" version = "12.1.1" @@ -4301,18 +4243,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, ] -[[package]] -name = "s3transfer" -version = "0.16.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "botocore" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, -] - [[package]] name = "safetensors" version = "0.7.0" @@ -4547,31 +4477,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/e6/412c1e1f200ca8c32ecf10201839183e261ad61ced3ede34a66f6d4be3cf/streamlit-1.55.0-py3-none-any.whl", hash = "sha256:1e4a16449c6131696180f4ddb40ea8c51834e89c2a43e1b0362bc9b1cfd9b415", size = 9075714, upload-time = "2026-03-03T22:25:59.126Z" }, ] -[[package]] -name = "swe-rex" -version = "1.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "bashlex" }, - { name = "fastapi" }, - { name = "pexpect" }, - { name = "pydantic" }, - { name = "python-multipart" }, - { name = "requests" }, - { name = "rich" }, - { name = "uvicorn" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/94/86/a069f93ec866151a4d476d546e60220e66b3788878b6e248b2df3ab2c5f1/swe_rex-1.4.0.tar.gz", hash = "sha256:14f8a24c49a63f9e251340b1109ac75a4aacbaece410f8599209de9bfca843c0", size = 41755, upload-time = "2025-08-14T01:19:20.22Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/0d/d06ab2aa78138055c297490762cd7b4d8ac58a544783f874c869cdb7b534/swe_rex-1.4.0-py3-none-any.whl", hash = "sha256:61261ad03eb23b717b5901cd5d229f24f6e1be2e120aad5c2e5ea3384a1d15ad", size = 47756, upload-time = "2025-08-14T01:19:18.93Z" }, -] - -[package.optional-dependencies] -modal = [ - { name = "boto3" }, - { name = "modal" }, -] - [[package]] name = "sympy" version = "1.14.0" diff --git a/website/docs/developer-guide/contributing.md b/website/docs/developer-guide/contributing.md index 1d1e24c62..603b416ac 100644 --- a/website/docs/developer-guide/contributing.md +++ b/website/docs/developer-guide/contributing.md @@ -90,6 +90,7 @@ pytest tests/ -v - **Comments**: Only when explaining non-obvious intent, trade-offs, or API quirks - **Error handling**: Catch specific exceptions. Use `logger.warning()`/`logger.error()` with `exc_info=True` for unexpected errors - **Cross-platform**: Never assume Unix (see below) +- **Profile-safe paths**: Never hardcode `~/.hermes` — use `get_hermes_home()` from `hermes_constants` for code paths and `display_hermes_home()` for user-facing messages. See [AGENTS.md](https://github.com/NousResearch/hermes-agent/blob/main/AGENTS.md#profiles-multi-instance-support) for full rules. ## Cross-Platform Compatibility diff --git a/website/docs/developer-guide/creating-skills.md b/website/docs/developer-guide/creating-skills.md index f2238d7d5..e5660b61f 100644 --- a/website/docs/developer-guide/creating-skills.md +++ b/website/docs/developer-guide/creating-skills.md @@ -168,11 +168,38 @@ required_environment_variables: The user can skip setup and keep loading the skill. Hermes never exposes the raw secret value to the model. Gateway and messaging sessions show local setup guidance instead of collecting secrets in-band. :::tip Sandbox Passthrough -When your skill is loaded, any declared `required_environment_variables` that are set are **automatically passed through** to `execute_code` and `terminal` sandboxes. Your skill's scripts can access `$TENOR_API_KEY` (or `os.environ["TENOR_API_KEY"]` in Python) without the user needing to configure anything extra. See [Environment Variable Passthrough](/docs/user-guide/security#environment-variable-passthrough) for details. +When your skill is loaded, any declared `required_environment_variables` that are set are **automatically passed through** to `execute_code` and `terminal` sandboxes — including remote backends like Docker and Modal. Your skill's scripts can access `$TENOR_API_KEY` (or `os.environ["TENOR_API_KEY"]` in Python) without the user needing to configure anything extra. See [Environment Variable Passthrough](/docs/user-guide/security#environment-variable-passthrough) for details. ::: Legacy `prerequisites.env_vars` remains supported as a backward-compatible alias. +### Credential File Requirements (OAuth tokens, etc.) + +Skills that use OAuth or file-based credentials can declare files that need to be mounted into remote sandboxes. This is for credentials stored as **files** (not env vars) — typically OAuth token files produced by a setup script. + +```yaml +required_credential_files: + - path: google_token.json + description: Google OAuth2 token (created by setup script) + - path: google_client_secret.json + description: Google OAuth2 client credentials +``` + +Each entry supports: +- `path` (required) — file path relative to `~/.hermes/` +- `description` (optional) — explains what the file is and how it's created + +When loaded, Hermes checks if these files exist. Missing files trigger `setup_needed`. Existing files are automatically: +- **Mounted into Docker** containers as read-only bind mounts +- **Synced into Modal** sandboxes (at creation + before each command, so mid-session OAuth works) +- Available on **local** backend without any special handling + +:::tip When to use which +Use `required_environment_variables` for simple API keys and tokens (strings stored in `~/.hermes/.env`). Use `required_credential_files` for OAuth token files, client secrets, service account JSON, certificates, or any credential that's a file on disk. +::: + +See the `skills/productivity/google-workspace/SKILL.md` for a complete example using both. + ## Skill Guidelines ### No External Dependencies diff --git a/website/docs/getting-started/quickstart.md b/website/docs/getting-started/quickstart.md index 24068d895..27cee7084 100644 --- a/website/docs/getting-started/quickstart.md +++ b/website/docs/getting-started/quickstart.md @@ -50,6 +50,7 @@ hermes setup # Or configure everything at once | **MiniMax** | International MiniMax endpoint | Set `MINIMAX_API_KEY` | | **MiniMax China** | China-region MiniMax endpoint | Set `MINIMAX_CN_API_KEY` | | **Alibaba Cloud** | Qwen models via DashScope | Set `DASHSCOPE_API_KEY` | +| **Hugging Face** | 20+ open models via unified router (Qwen, DeepSeek, Kimi, etc.) | Set `HF_TOKEN` | | **Kilo Code** | KiloCode-hosted models | Set `KILOCODE_API_KEY` | | **OpenCode Zen** | Pay-as-you-go access to curated models | Set `OPENCODE_ZEN_API_KEY` | | **OpenCode Go** | $10/month subscription for open models | Set `OPENCODE_GO_API_KEY` | diff --git a/website/docs/guides/build-a-hermes-plugin.md b/website/docs/guides/build-a-hermes-plugin.md index de3dbec19..abe1e3424 100644 --- a/website/docs/guides/build-a-hermes-plugin.md +++ b/website/docs/guides/build-a-hermes-plugin.md @@ -365,16 +365,18 @@ def register(ctx): Available hooks: -| Hook | When | Arguments | -|------|------|-----------| -| `pre_tool_call` | Before any tool runs | `tool_name`, `args`, `task_id` | -| `post_tool_call` | After any tool returns | `tool_name`, `args`, `result`, `task_id` | -| `pre_llm_call` | Before LLM API call | `messages`, `model` | -| `post_llm_call` | After LLM response | `messages`, `response`, `model` | -| `on_session_start` | Session begins | `session_id`, `platform` | -| `on_session_end` | Session ends | `session_id`, `platform` | +| Hook | When | Arguments | Return | +|------|------|-----------|--------| +| `pre_tool_call` | Before any tool runs | `tool_name`, `args`, `task_id` | — | +| `post_tool_call` | After any tool returns | `tool_name`, `args`, `result`, `task_id` | — | +| `pre_llm_call` | Once per turn, before the LLM loop | `session_id`, `user_message`, `conversation_history`, `is_first_turn`, `model`, `platform` | `{"context": "..."}` | +| `post_llm_call` | Once per turn, after the LLM loop | `session_id`, `user_message`, `assistant_response`, `conversation_history`, `model`, `platform` | — | +| `on_session_start` | New session created (first turn only) | `session_id`, `model`, `platform` | — | +| `on_session_end` | End of every `run_conversation` call | `session_id`, `completed`, `interrupted`, `model`, `platform` | — | -Hooks are observers — they can't modify arguments or return values. If a hook crashes, it's logged and skipped; other hooks and the tool continue normally. +Most hooks are fire-and-forget observers. The exception is `pre_llm_call`: if a callback returns a dict with a `"context"` key (or a plain string), the value is appended to the ephemeral system prompt for the current turn. This allows memory plugins to inject recalled context without touching core code. + +If a hook crashes, it's logged and skipped; other hooks and the agent continue normally. ### Distribute via pip diff --git a/website/docs/guides/team-telegram-assistant.md b/website/docs/guides/team-telegram-assistant.md index 88de9c706..04350bfab 100644 --- a/website/docs/guides/team-telegram-assistant.md +++ b/website/docs/guides/team-telegram-assistant.md @@ -168,11 +168,15 @@ journalctl -u hermes-gateway -f ```bash # macOS — manage the service -launchctl start ai.hermes.gateway -launchctl stop ai.hermes.gateway +hermes gateway start +hermes gateway stop tail -f ~/.hermes/logs/gateway.log ``` +:::tip macOS PATH +The launchd plist captures your shell PATH at install time so gateway subprocesses can find tools like Node.js and ffmpeg. If you install new tools later, re-run `hermes gateway install` to update the plist. +::: + ### Verify It's Running ```bash diff --git a/website/docs/reference/cli-commands.md b/website/docs/reference/cli-commands.md index d527b61e2..d27a34f2c 100644 --- a/website/docs/reference/cli-commands.md +++ b/website/docs/reference/cli-commands.md @@ -39,6 +39,7 @@ hermes [global-options] [subcommand/options] | `hermes login` / `logout` | Authenticate with OAuth-backed providers. | | `hermes status` | Show agent, auth, and platform status. | | `hermes cron` | Inspect and tick the cron scheduler. | +| `hermes webhook` | Manage dynamic webhook subscriptions for event-driven activation. | | `hermes doctor` | Diagnose config and dependency issues. | | `hermes config` | Show, edit, migrate, and query configuration files. | | `hermes pairing` | Approve or revoke messaging pairing codes. | @@ -66,7 +67,7 @@ Common options: | `-q`, `--query "..."` | One-shot, non-interactive prompt. | | `-m`, `--model ` | Override the model for this run. | | `-t`, `--toolsets ` | Enable a comma-separated set of toolsets. | -| `--provider ` | Force a provider: `auto`, `openrouter`, `nous`, `openai-codex`, `copilot`, `copilot-acp`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode`. | +| `--provider ` | Force a provider: `auto`, `openrouter`, `nous`, `openai-codex`, `copilot`, `copilot-acp`, `anthropic`, `huggingface`, `alibaba`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode`. | | `-s`, `--skills ` | Preload one or more skills for the session (can be repeated or comma-separated). | | `-v`, `--verbose` | Verbose output. | | `-Q`, `--quiet` | Programmatic mode: suppress banner/spinner/tool previews. | @@ -214,6 +215,39 @@ hermes cron | `status` | Check whether the cron scheduler is running. | | `tick` | Run due jobs once and exit. | +## `hermes webhook` + +```bash +hermes webhook +``` + +Manage dynamic webhook subscriptions for event-driven agent activation. Requires the webhook platform to be enabled in config — if not configured, prints setup instructions. + +| Subcommand | Description | +|------------|-------------| +| `subscribe` / `add` | Create a webhook route. Returns the URL and HMAC secret to configure on your service. | +| `list` / `ls` | Show all agent-created subscriptions. | +| `remove` / `rm` | Delete a dynamic subscription. Static routes from config.yaml are not affected. | +| `test` | Send a test POST to verify a subscription is working. | + +### `hermes webhook subscribe` + +```bash +hermes webhook subscribe [options] +``` + +| Option | Description | +|--------|-------------| +| `--prompt` | Prompt template with `{dot.notation}` payload references. | +| `--events` | Comma-separated event types to accept (e.g. `issues,pull_request`). Empty = all. | +| `--description` | Human-readable description. | +| `--skills` | Comma-separated skill names to load for the agent run. | +| `--deliver` | Delivery target: `log` (default), `telegram`, `discord`, `slack`, `github_comment`. | +| `--deliver-chat-id` | Target chat/channel ID for cross-platform delivery. | +| `--secret` | Custom HMAC secret. Auto-generated if omitted. | + +Subscriptions persist to `~/.hermes/webhook_subscriptions.json` and are hot-reloaded by the webhook adapter without a gateway restart. + ## `hermes doctor` ```bash @@ -350,32 +384,38 @@ See [ACP Editor Integration](../user-guide/features/acp.md) and [ACP Internals]( hermes mcp ``` -Manage MCP (Model Context Protocol) server configurations. +Manage MCP (Model Context Protocol) server configurations and run Hermes as an MCP server. | Subcommand | Description | |------------|-------------| +| `serve [-v\|--verbose]` | Run Hermes as an MCP server — expose conversations to other agents. | | `add [--url URL] [--command CMD] [--args ...] [--auth oauth\|header]` | Add an MCP server with automatic tool discovery. | | `remove ` (alias: `rm`) | Remove an MCP server from config. | | `list` (alias: `ls`) | List configured MCP servers. | | `test ` | Test connection to an MCP server. | | `configure ` (alias: `config`) | Toggle tool selection for a server. | -See [MCP Config Reference](./mcp-config-reference.md) and [Use MCP with Hermes](../guides/use-mcp-with-hermes.md). +See [MCP Config Reference](./mcp-config-reference.md), [Use MCP with Hermes](../guides/use-mcp-with-hermes.md), and [MCP Server Mode](../user-guide/features/mcp.md#running-hermes-as-an-mcp-server). ## `hermes plugins` ```bash -hermes plugins +hermes plugins [subcommand] ``` -Manage Hermes Agent plugins. +Manage Hermes Agent plugins. Running `hermes plugins` with no subcommand launches an interactive curses checklist to enable/disable installed plugins. | Subcommand | Description | |------------|-------------| +| *(none)* | Interactive toggle UI — enable/disable plugins with arrow keys and space. | | `install [--force]` | Install a plugin from a Git URL or `owner/repo`. | | `update ` | Pull latest changes for an installed plugin. | | `remove ` (aliases: `rm`, `uninstall`) | Remove an installed plugin. | -| `list` (alias: `ls`) | List installed plugins. | +| `enable ` | Enable a disabled plugin. | +| `disable ` | Disable a plugin without removing it. | +| `list` (alias: `ls`) | List installed plugins with enabled/disabled status. | + +Disabled plugins are stored in `config.yaml` under `plugins.disabled` and skipped during loading. See [Plugins](../user-guide/features/plugins.md) and [Build a Hermes Plugin](../guides/build-a-hermes-plugin.md). @@ -423,10 +463,94 @@ hermes insights [--days N] [--source platform] ## `hermes claw` ```bash -hermes claw migrate +hermes claw migrate [options] ``` -Used to migrate settings, memories, skills, and keys from OpenClaw to Hermes. +Migrate your OpenClaw setup to Hermes. Reads from `~/.openclaw` (or a custom path) and writes to `~/.hermes`. + +| Option | Description | +|--------|-------------| +| `--dry-run` | Preview what would be migrated without writing anything. | +| `--preset ` | Migration preset: `full` (default, includes secrets) or `user-data` (excludes API keys). | +| `--overwrite` | Overwrite existing Hermes files on conflicts (default: skip). | +| `--migrate-secrets` | Include API keys in migration (enabled by default with `--preset full`). | +| `--source ` | Custom OpenClaw directory (default: `~/.openclaw`). | +| `--workspace-target ` | Target directory for workspace instructions (AGENTS.md). | +| `--skill-conflict ` | Handle skill name collisions: `skip` (default), `overwrite`, or `rename`. | +| `--yes` | Skip the confirmation prompt. | + +### What gets migrated + +The migration covers your entire OpenClaw footprint. Items are either **directly imported** into Hermes equivalents or **archived** for manual review when there's no direct mapping. + +#### Directly imported + +| Category | OpenClaw source | Hermes destination | +|----------|----------------|-------------------| +| **Persona** | `SOUL.md` | `~/.hermes/SOUL.md` | +| **Workspace instructions** | `AGENTS.md` | `AGENTS.md` in target workspace | +| **Long-term memory** | `MEMORY.md` | `~/.hermes/MEMORY.md` (merged with existing entries) | +| **User profile** | `USER.md` | `~/.hermes/USER.md` (merged with existing entries) | +| **Daily memory files** | `workspace/memory/` | Merged into `~/.hermes/MEMORY.md` | +| **Default model** | Config model setting | `config.yaml` model section | +| **Custom providers** | Provider definitions (baseUrl, apiType, headers) | `config.yaml` custom\_providers | +| **MCP servers** | MCP server definitions | `config.yaml` mcp\_servers | +| **User skills** | Workspace skills | `~/.hermes/skills/openclaw-imports/` | +| **Shared skills** | `~/.openclaw/skills/` | `~/.hermes/skills/openclaw-imports/` | +| **Command allowlist** | Exec approval patterns | `config.yaml` command\_allowlist | +| **Messaging settings** | Allowlists, working directory | `config.yaml` messaging section | +| **Session policies** | Daily/idle reset policies | `config.yaml` session\_reset | +| **Agent defaults** | Compaction, context, thinking settings | `config.yaml` agent section | +| **Browser settings** | Browser automation config | `config.yaml` browser section | +| **Tool settings** | Exec timeout, sandbox, web search | `config.yaml` tools section | +| **Approval rules** | Approval mode and rules | `config.yaml` approvals section | +| **TTS config** | TTS provider and voice | `config.yaml` tts section | +| **TTS assets** | Workspace TTS files | `~/.hermes/tts/` | +| **Gateway config** | Gateway port and auth | `config.yaml` gateway section | +| **Telegram settings** | Bot token, allowlist | `~/.hermes/.env` | +| **Discord settings** | Bot token, allowlist | `~/.hermes/.env` | +| **Slack settings** | Bot/app tokens, allowlist | `~/.hermes/.env` | +| **WhatsApp settings** | Allowlist | `~/.hermes/.env` | +| **Signal settings** | Account, HTTP URL, allowlist | `~/.hermes/.env` | +| **Channel config** | Matrix, Mattermost, IRC, group settings | `config.yaml` + archive | +| **Provider API keys** | OPENROUTER\_API\_KEY, OPENAI\_API\_KEY, ANTHROPIC\_API\_KEY, etc. | `~/.hermes/.env` (requires `--migrate-secrets`) | + +#### Archived for manual review + +These OpenClaw features don't have direct Hermes equivalents. They're saved to an archive directory for you to review and recreate manually. + +| Category | What's archived | How to recreate in Hermes | +|----------|----------------|--------------------------| +| **Cron / scheduled tasks** | Job definitions | Recreate with `hermes cron create` | +| **Plugins** | Plugin configuration, installed extensions | Check the [plugins guide](../user-guide/features/hooks.md) | +| **Hooks and webhooks** | Internal hooks, webhooks, Gmail integration | Use `hermes webhook` or gateway hooks | +| **Memory backend** | QMD, vector search, citation settings | Configure Honcho via `hermes honcho` | +| **Skills registry** | Per-skill enabled/config/env settings | Use `hermes skills config` | +| **UI and identity** | Theme, assistant identity, display prefs | Use `/skin` command or `config.yaml` | +| **Logging** | Diagnostics configuration | Set in `config.yaml` logging section | + +### Security + +API keys are **not migrated by default**. The `--preset full` preset enables secret migration, but only for an allowlist of known keys: `OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `ELEVENLABS_API_KEY`, `TELEGRAM_BOT_TOKEN`, and `VOICE_TOOLS_OPENAI_KEY`. All other secrets are skipped. + +### Examples + +```bash +# Preview what would be migrated +hermes claw migrate --dry-run + +# Full migration including API keys +hermes claw migrate --preset full + +# Migrate user data only (no secrets), overwrite conflicts +hermes claw migrate --preset user-data --overwrite + +# Migrate from a custom OpenClaw path +hermes claw migrate --source /home/user/old-openclaw + +# Migrate and place AGENTS.md in a specific project +hermes claw migrate --workspace-target /home/user/my-project +``` ## Maintenance commands diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index d228c3927..7262bf05b 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -37,10 +37,12 @@ All variables go in `~/.hermes/.env`. You can also set them with `hermes config | `MINIMAX_CN_BASE_URL` | Override MiniMax China base URL (default: `https://api.minimaxi.com/v1`) | | `KILOCODE_API_KEY` | Kilo Code API key ([kilo.ai](https://kilo.ai)) | | `KILOCODE_BASE_URL` | Override Kilo Code base URL (default: `https://api.kilo.ai/api/gateway`) | +| `HF_TOKEN` | Hugging Face token for Inference Providers ([huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)) | +| `HF_BASE_URL` | Override Hugging Face base URL (default: `https://router.huggingface.co/v1`) | | `ANTHROPIC_API_KEY` | Anthropic Console API key ([console.anthropic.com](https://console.anthropic.com/)) | | `ANTHROPIC_TOKEN` | Manual or legacy Anthropic OAuth/setup-token override | | `DASHSCOPE_API_KEY` | Alibaba Cloud DashScope API key for Qwen models ([modelstudio.console.alibabacloud.com](https://modelstudio.console.alibabacloud.com/)) | -| `DASHSCOPE_BASE_URL` | Custom DashScope base URL (default: international endpoint) | +| `DASHSCOPE_BASE_URL` | Custom DashScope base URL (default: `https://coding-intl.dashscope.aliyuncs.com/v1`) | | `DEEPSEEK_API_KEY` | DeepSeek API key for direct DeepSeek access ([platform.deepseek.com](https://platform.deepseek.com/api_keys)) | | `DEEPSEEK_BASE_URL` | Custom DeepSeek API base URL | | `OPENCODE_ZEN_API_KEY` | OpenCode Zen API key — pay-as-you-go access to curated models ([opencode.ai](https://opencode.ai/auth)) | @@ -61,7 +63,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe | Variable | Description | |----------|-------------| -| `HERMES_INFERENCE_PROVIDER` | Override provider selection: `auto`, `openrouter`, `nous`, `openai-codex`, `copilot`, `copilot-acp`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode` (default: `auto`) | +| `HERMES_INFERENCE_PROVIDER` | Override provider selection: `auto`, `openrouter`, `nous`, `openai-codex`, `copilot`, `copilot-acp`, `anthropic`, `huggingface`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode`, `alibaba` (default: `auto`) | | `HERMES_PORTAL_BASE_URL` | Override Nous Portal URL (for development/testing) | | `NOUS_INFERENCE_BASE_URL` | Override Nous inference API URL | | `HERMES_NOUS_MIN_KEY_TTL_SECONDS` | Min agent key TTL before re-mint (default: 1800 = 30min) | @@ -103,7 +105,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe |----------|-------------| | `TERMINAL_ENV` | Backend: `local`, `docker`, `ssh`, `singularity`, `modal`, `daytona` | | `TERMINAL_DOCKER_IMAGE` | Docker image (default: `python:3.11`) | -| `TERMINAL_DOCKER_FORWARD_ENV` | JSON array of env var names to explicitly forward into Docker terminal sessions | +| `TERMINAL_DOCKER_FORWARD_ENV` | JSON array of env var names to explicitly forward into Docker terminal sessions. Note: skill-declared `required_environment_variables` are forwarded automatically — you only need this for vars not declared by any skill. | | `TERMINAL_DOCKER_VOLUMES` | Additional Docker volume mounts (comma-separated `host:container` pairs) | | `TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE` | Advanced opt-in: mount the launch cwd into Docker `/workspace` (`true`/`false`, default: `false`) | | `TERMINAL_SINGULARITY_IMAGE` | Singularity image or `.sif` path | @@ -196,10 +198,25 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI | `DINGTALK_CLIENT_ID` | DingTalk bot AppKey from developer portal ([open.dingtalk.com](https://open.dingtalk.com)) | | `DINGTALK_CLIENT_SECRET` | DingTalk bot AppSecret from developer portal | | `DINGTALK_ALLOWED_USERS` | Comma-separated DingTalk user IDs allowed to message the bot | +| `FEISHU_APP_ID` | Feishu/Lark bot App ID from [open.feishu.cn](https://open.feishu.cn/) | +| `FEISHU_APP_SECRET` | Feishu/Lark bot App Secret | +| `FEISHU_DOMAIN` | `feishu` (China) or `lark` (international). Default: `feishu` | +| `FEISHU_CONNECTION_MODE` | `websocket` (recommended) or `webhook`. Default: `websocket` | +| `FEISHU_ENCRYPT_KEY` | Optional encryption key for webhook mode | +| `FEISHU_VERIFICATION_TOKEN` | Optional verification token for webhook mode | +| `FEISHU_ALLOWED_USERS` | Comma-separated Feishu user IDs allowed to message the bot | +| `FEISHU_HOME_CHANNEL` | Feishu chat ID for cron delivery and notifications | +| `WECOM_BOT_ID` | WeCom AI Bot ID from admin console | +| `WECOM_SECRET` | WeCom AI Bot secret | +| `WECOM_WEBSOCKET_URL` | Custom WebSocket URL (default: `wss://openws.work.weixin.qq.com`) | +| `WECOM_ALLOWED_USERS` | Comma-separated WeCom user IDs allowed to message the bot | +| `WECOM_HOME_CHANNEL` | WeCom chat ID for cron delivery and notifications | | `MATTERMOST_URL` | Mattermost server URL (e.g. `https://mm.example.com`) | | `MATTERMOST_TOKEN` | Bot token or personal access token for Mattermost | | `MATTERMOST_ALLOWED_USERS` | Comma-separated Mattermost user IDs allowed to message the bot | | `MATTERMOST_HOME_CHANNEL` | Channel ID for proactive message delivery (cron, notifications) | +| `MATTERMOST_REQUIRE_MENTION` | Require `@mention` in channels (default: `true`). Set to `false` to respond to all messages. | +| `MATTERMOST_FREE_RESPONSE_CHANNELS` | Comma-separated channel IDs where bot responds without `@mention` | | `MATTERMOST_REPLY_MODE` | Reply style: `thread` (threaded replies) or `off` (flat messages, default) | | `MATRIX_HOMESERVER` | Matrix homeserver URL (e.g. `https://matrix.org`) | | `MATRIX_ACCESS_TOKEN` | Matrix access token for bot authentication | diff --git a/website/docs/reference/faq.md b/website/docs/reference/faq.md index a632bc102..e207420f8 100644 --- a/website/docs/reference/faq.md +++ b/website/docs/reference/faq.md @@ -357,6 +357,23 @@ lsof -i :8080 hermes config show ``` +#### macOS: Node.js / ffmpeg / other tools not found by gateway + +**Cause:** launchd services inherit a minimal PATH (`/usr/bin:/bin:/usr/sbin:/sbin`) that doesn't include Homebrew, nvm, cargo, or other user-installed tool directories. This commonly breaks the WhatsApp bridge (`node not found`) or voice transcription (`ffmpeg not found`). + +**Solution:** The gateway captures your shell PATH when you run `hermes gateway install`. If you installed tools after setting up the gateway, re-run the install to capture the updated PATH: + +```bash +hermes gateway install # Re-snapshots your current PATH +hermes gateway start # Detects the updated plist and reloads +``` + +You can verify the plist has the correct PATH: +```bash +/usr/libexec/PlistBuddy -c "Print :EnvironmentVariables:PATH" \ + ~/Library/LaunchAgents/ai.hermes.gateway.plist +``` + --- ### Performance Issues @@ -472,6 +489,44 @@ If an MCP server crashes mid-request, Hermes will report a timeout. Check the se --- +## Profiles + +### How do profiles differ from just setting HERMES_HOME? + +Profiles are a managed layer on top of `HERMES_HOME`. You *could* manually set `HERMES_HOME=/some/path` before every command, but profiles handle all the plumbing for you: creating the directory structure, generating shell aliases (`hermes-work`), tracking the active profile in `~/.hermes/active_profile`, and syncing skill updates across all profiles automatically. They also integrate with tab completion so you don't have to remember paths. + +### Can two profiles share the same bot token? + +No. Each messaging platform (Telegram, Discord, etc.) requires exclusive access to a bot token. If two profiles try to use the same token simultaneously, the second gateway will fail to connect. Create a separate bot per profile — for Telegram, talk to [@BotFather](https://t.me/BotFather) to make additional bots. + +### Do profiles share memory or sessions? + +No. Each profile has its own memory store, session database, and skills directory. They are completely isolated. If you want to start a new profile with existing memories and sessions, use `hermes profile create newname --clone-all` to copy everything from the current profile. + +### What happens when I run `hermes update`? + +`hermes update` pulls the latest code and reinstalls dependencies **once** (not per-profile). It then syncs updated skills to all profiles automatically. You only need to run `hermes update` once — it covers every profile on the machine. + +### Can I move a profile to a different machine? + +Yes. Export the profile to a portable archive and import it on the other machine: + +```bash +# On the source machine +hermes profile export work ./work-backup.tar.gz + +# Copy the file to the target machine, then: +hermes profile import ./work-backup.tar.gz work +``` + +The imported profile will have all config, memories, sessions, and skills from the export. You may need to update paths or re-authenticate with providers if the new machine has a different setup. + +### How many profiles can I run? + +There is no hard limit. Each profile is just a directory under `~/.hermes/profiles/`. The practical limit depends on your disk space and how many concurrent gateways your system can handle (each gateway is a lightweight Python process). Running dozens of profiles is fine; each idle profile uses no resources. + +--- + ## Still Stuck? If your issue isn't covered here: diff --git a/website/docs/reference/profile-commands.md b/website/docs/reference/profile-commands.md new file mode 100644 index 000000000..a59e27574 --- /dev/null +++ b/website/docs/reference/profile-commands.md @@ -0,0 +1,280 @@ +--- +sidebar_position: 7 +--- + +# Profile Commands Reference + +This page covers all commands related to [Hermes profiles](../user-guide/profiles.md). For general CLI commands, see [CLI Commands Reference](./cli-commands.md). + +## `hermes profile` + +```bash +hermes profile +``` + +Top-level command for managing profiles. Running `hermes profile` without a subcommand shows help. + +| Subcommand | Description | +|------------|-------------| +| `list` | List all profiles. | +| `use` | Set the active (default) profile. | +| `create` | Create a new profile. | +| `delete` | Delete a profile. | +| `show` | Show details about a profile. | +| `alias` | Regenerate the shell alias for a profile. | +| `rename` | Rename a profile. | +| `export` | Export a profile to a tar.gz archive. | +| `import` | Import a profile from a tar.gz archive. | + +## `hermes profile list` + +```bash +hermes profile list +``` + +Lists all profiles. The currently active profile is marked with `*`. + +**Example:** + +```bash +$ hermes profile list + default +* work + dev + personal +``` + +No options. + +## `hermes profile use` + +```bash +hermes profile use +``` + +Sets `` as the active profile. All subsequent `hermes` commands (without `-p`) will use this profile. + +| Argument | Description | +|----------|-------------| +| `` | Profile name to activate. Use `default` to return to the base profile. | + +**Example:** + +```bash +hermes profile use work +hermes profile use default +``` + +## `hermes profile create` + +```bash +hermes profile create [options] +``` + +Creates a new profile. + +| Argument / Option | Description | +|-------------------|-------------| +| `` | Name for the new profile. Must be a valid directory name (alphanumeric, hyphens, underscores). | +| `--clone` | Copy `config.yaml`, `.env`, and `SOUL.md` from the current profile. | +| `--clone-all` | Copy everything (config, memories, skills, sessions, state) from the current profile. | +| `--from ` | Clone from a specific profile instead of the current one. Used with `--clone` or `--clone-all`. | + +**Examples:** + +```bash +# Blank profile — needs full setup +hermes profile create mybot + +# Clone config only from current profile +hermes profile create work --clone + +# Clone everything from current profile +hermes profile create backup --clone-all + +# Clone config from a specific profile +hermes profile create work2 --clone --from work +``` + +## `hermes profile delete` + +```bash +hermes profile delete [options] +``` + +Deletes a profile and removes its shell alias. + +| Argument / Option | Description | +|-------------------|-------------| +| `` | Profile to delete. | +| `--yes`, `-y` | Skip confirmation prompt. | + +**Example:** + +```bash +hermes profile delete mybot +hermes profile delete mybot --yes +``` + +:::warning +This permanently deletes the profile's entire directory including all config, memories, sessions, and skills. Cannot delete the currently active profile. +::: + +## `hermes profile show` + +```bash +hermes profile show [name] +``` + +Displays details about a profile including its home directory, configured model, active platforms, and disk usage. + +| Argument | Description | +|----------|-------------| +| `[name]` | Profile to inspect. Defaults to the current active profile if omitted. | + +**Example:** + +```bash +$ hermes profile show work +Profile: work +Home: ~/.hermes/profiles/work +Model: anthropic/claude-sonnet-4 +Platforms: telegram, discord +Skills: 12 installed +Disk: 48 MB +``` + +## `hermes profile alias` + +```bash +hermes profile alias +``` + +Regenerates the shell alias script at `~/.local/bin/hermes-`. Useful if the alias was accidentally deleted or if you need to update it after moving your Hermes installation. + +| Argument | Description | +|----------|-------------| +| `` | Profile to create/update the alias for. | + +**Example:** + +```bash +hermes profile alias work +# Creates/updates ~/.local/bin/work +``` + +## `hermes profile rename` + +```bash +hermes profile rename +``` + +Renames a profile. Updates the directory and shell alias. + +| Argument | Description | +|----------|-------------| +| `` | Current profile name. | +| `` | New profile name. | + +**Example:** + +```bash +hermes profile rename mybot assistant +# ~/.hermes/profiles/mybot → ~/.hermes/profiles/assistant +# ~/.local/bin/mybot → ~/.local/bin/assistant +``` + +## `hermes profile export` + +```bash +hermes profile export +``` + +Exports a profile as a compressed tar.gz archive. + +| Argument | Description | +|----------|-------------| +| `` | Profile to export. | +| `` | Path for the output archive (e.g., `./work-backup.tar.gz`). | + +**Example:** + +```bash +hermes profile export work ./work-2026-03-29.tar.gz +``` + +## `hermes profile import` + +```bash +hermes profile import [name] +``` + +Imports a profile from a tar.gz archive. + +| Argument | Description | +|----------|-------------| +| `` | Path to the tar.gz archive to import. | +| `[name]` | Name for the imported profile. Defaults to the original profile name from the archive. | + +**Example:** + +```bash +hermes profile import ./work-2026-03-29.tar.gz work-restored +``` + +## `hermes -p` / `hermes --profile` + +```bash +hermes -p [options] +hermes --profile [options] +``` + +Global flag to run any Hermes command under a specific profile without changing the sticky default. This overrides the active profile for the duration of the command. + +| Option | Description | +|--------|-------------| +| `-p `, `--profile ` | Profile to use for this command. | + +**Examples:** + +```bash +hermes -p work chat -q "Check the server status" +hermes --profile dev gateway start +hermes -p personal skills list +hermes -p work config edit +``` + +## `hermes completion` + +```bash +hermes completion +``` + +Generates shell completion scripts. Includes completions for profile names and profile subcommands. + +| Argument | Description | +|----------|-------------| +| `` | Shell to generate completions for: `bash`, `zsh`, or `fish`. | + +**Examples:** + +```bash +# Install completions +hermes completion bash >> ~/.bashrc +hermes completion zsh >> ~/.zshrc +hermes completion fish > ~/.config/fish/completions/hermes.fish + +# Reload shell +source ~/.bashrc +``` + +After installation, tab completion works for: +- `hermes profile ` — subcommands (list, use, create, etc.) +- `hermes profile use ` — profile names +- `hermes -p ` — profile names + +## See also + +- [Profiles User Guide](../user-guide/profiles.md) +- [CLI Commands Reference](./cli-commands.md) +- [FAQ — Profiles section](./faq.md#profiles) diff --git a/website/docs/reference/skills-catalog.md b/website/docs/reference/skills-catalog.md index 4f6889b09..c0d83212f 100644 --- a/website/docs/reference/skills-catalog.md +++ b/website/docs/reference/skills-catalog.md @@ -48,6 +48,14 @@ Creative content generation — ASCII art, hand-drawn style diagrams, and visual | `ascii-video` | "Production pipeline for ASCII art video — any format. Converts video/audio/images/generative input into colored ASCII character video output (MP4, GIF, image sequence). Covers: video-to-ASCII conversion, audio-reactive music visualizers, generative ASCII art animations, hybrid… | `creative/ascii-video` | | `excalidraw` | Create hand-drawn style diagrams using Excalidraw JSON format. Generate .excalidraw files for architecture diagrams, flowcharts, sequence diagrams, concept maps, and more. Files can be opened at excalidraw.com or uploaded for shareable links. | `creative/excalidraw` | +## devops + +DevOps and infrastructure automation skills. + +| Skill | Description | Path | +|-------|-------------|------| +| `webhook-subscriptions` | Create and manage webhook subscriptions for event-driven agent activation. External services (GitHub, Stripe, CI/CD, IoT) POST events to trigger agent runs. Requires webhook platform to be enabled. | `devops/webhook-subscriptions` | + ## dogfood | Skill | Description | Path | @@ -245,11 +253,18 @@ Skills for academic research, paper discovery, literature review, domain reconna | `arxiv` | Search and retrieve academic papers from arXiv using their free REST API. No API key needed. Search by keyword, author, category, or ID. Combine with web_extract or the ocr-and-documents skill to read full paper content. | `research/arxiv` | | `blogwatcher` | Monitor blogs and RSS/Atom feeds for updates using the blogwatcher CLI. Add blogs, scan for new articles, and track what you've read. | `research/blogwatcher` | | `domain-intel` | Passive domain reconnaissance using Python stdlib. Subdomain discovery, SSL certificate inspection, WHOIS lookups, DNS records, domain availability checks, and bulk multi-domain analysis. No API keys required. | `research/domain-intel` | -| `duckduckgo-search` | Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Use the Python DDGS library or CLI to search, then web_extract for full content. | `research/duckduckgo-search` | -| `parallel-cli` | Optional vendor skill for Parallel CLI — agent-native web search, extraction, deep research, enrichment, FindAll, and monitoring. | `research/parallel-cli` | +| `duckduckgo-search` | Free web search via DuckDuckGo — text, news, images, videos. No API key needed. Prefer the `ddgs` CLI when installed; use the Python DDGS library only after verifying that `ddgs` is available in the current runtime. | `research/duckduckgo-search` | | `ml-paper-writing` | Write publication-ready ML/AI papers for NeurIPS, ICML, ICLR, ACL, AAAI, COLM. Use when drafting papers from research repos, structuring arguments, verifying citations, or preparing camera-ready submissions. Includes LaTeX templates, reviewer guidelines, and citation verificatio… | `research/ml-paper-writing` | | `polymarket` | Query Polymarket prediction market data — search markets, get prices, orderbooks, and price history. Read-only via public REST APIs, no API key needed. | `research/polymarket` | +## red-teaming + +Skills for LLM red-teaming, jailbreaking, and safety filter bypass research. + +| Skill | Description | Path | +|-------|-------------|------| +| `godmode` | Jailbreak API-served LLMs using G0DM0D3 techniques — Parseltongue input obfuscation (33 techniques), GODMODE CLASSIC system prompt templates, ULTRAPLINIAN multi-model racing, encoding escalation, and Hermes-native prefill/system prompt integration. Works on any model accessible via API including closed-source models. | `red-teaming/godmode` | + ## smart-home Skills for controlling smart home devices — lights, switches, sensors, and home automation systems. @@ -277,3 +292,84 @@ Skills for interacting with social platforms — posting, reading, monitoring, a | `systematic-debugging` | Use when encountering any bug, test failure, or unexpected behavior. 4-phase root cause investigation — NO fixes without understanding the problem first. | `software-development/systematic-debugging` | | `test-driven-development` | Use when implementing any feature or bugfix, before writing implementation code. Enforces RED-GREEN-REFACTOR cycle with test-first approach. | `software-development/test-driven-development` | | `writing-plans` | Use when you have a spec or requirements for a multi-step task. Creates comprehensive implementation plans with bite-sized tasks, exact file paths, and complete code examples. | `software-development/writing-plans` | + +--- + +# Optional Skills + +Optional skills ship with the repository under `optional-skills/` but are **not active by default**. They cover heavier or niche use cases. Install them with: + +```bash +hermes skills install official// +``` + +## autonomous-ai-agents + +| Skill | Description | Path | +|-------|-------------|------| +| `blackbox` | Delegate coding tasks to Blackbox AI CLI agent. Multi-model agent with built-in judge that runs tasks through multiple LLMs and picks the best result. Requires the blackbox CLI and a Blackbox AI API key. | `autonomous-ai-agents/blackbox` | + +## blockchain + +| Skill | Description | Path | +|-------|-------------|------| +| `base` | Query Base (Ethereum L2) blockchain data with USD pricing — wallet balances, token info, transaction details, gas analysis, contract inspection, whale detection, and live network stats. Uses Base RPC + CoinGecko. No API key required. | `blockchain/base` | +| `solana` | Query Solana blockchain data with USD pricing — wallet balances, token portfolios with values, transaction details, NFTs, whale detection, and live network stats. Uses Solana RPC + CoinGecko. No API key required. | `blockchain/solana` | + +## creative + +| Skill | Description | Path | +|-------|-------------|------| +| `blender-mcp` | Control Blender directly from Hermes via socket connection to the blender-mcp addon. Create 3D objects, materials, animations, and run arbitrary Blender Python (bpy) code. | `creative/blender-mcp` | +| `meme-generation` | Generate real meme images by picking a template and overlaying text with Pillow. Produces actual .png meme files. | `creative/meme-generation` | + +## devops + +| Skill | Description | Path | +|-------|-------------|------| +| `docker-management` | Manage Docker containers, images, volumes, networks, and Compose stacks — lifecycle ops, debugging, cleanup, and Dockerfile optimization. | `devops/docker-management` | + +## email + +| Skill | Description | Path | +|-------|-------------|------| +| `agentmail` | Give the agent its own dedicated email inbox via AgentMail. Send, receive, and manage email autonomously using agent-owned email addresses (e.g. hermes-agent@agentmail.to). | `email/agentmail` | + +## health + +| Skill | Description | Path | +|-------|-------------|------| +| `neuroskill-bci` | Connect to a running NeuroSkill instance and incorporate the user's real-time cognitive and emotional state (focus, relaxation, mood, cognitive load, drowsiness, heart rate, HRV, sleep staging, and 40+ derived EXG scores) into responses. Requires a BCI wearable (Muse 2/S or OpenBCI) and the NeuroSkill desktop app. | `health/neuroskill-bci` | + +## mcp + +| Skill | Description | Path | +|-------|-------------|------| +| `fastmcp` | Build, test, inspect, install, and deploy MCP servers with FastMCP in Python. Use when creating a new MCP server, wrapping an API or database as MCP tools, exposing resources or prompts, or preparing a FastMCP server for HTTP deployment. | `mcp/fastmcp` | + +## migration + +| Skill | Description | Path | +|-------|-------------|------| +| `openclaw-migration` | Migrate a user's OpenClaw customization footprint into Hermes Agent. Imports Hermes-compatible memories, SOUL.md, command allowlists, user skills, and selected workspace assets from ~/.openclaw, then reports what could not be migrated and why. | `migration/openclaw-migration` | + +## productivity + +| Skill | Description | Path | +|-------|-------------|------| +| `telephony` | Give Hermes phone capabilities — provision and persist a Twilio number, send and receive SMS/MMS, make direct calls, and place AI-driven outbound calls through Bland.ai or Vapi. | `productivity/telephony` | + +## research + +| Skill | Description | Path | +|-------|-------------|------| +| `bioinformatics` | Gateway to 400+ bioinformatics skills from bioSkills and ClawBio. Covers genomics, transcriptomics, single-cell, variant calling, pharmacogenomics, metagenomics, structural biology, and more. | `research/bioinformatics` | +| `qmd` | Search personal knowledge bases, notes, docs, and meeting transcripts locally using qmd — a hybrid retrieval engine with BM25, vector search, and LLM reranking. Supports CLI and MCP integration. | `research/qmd` | + +## security + +| Skill | Description | Path | +|-------|-------------|------| +| `1password` | Set up and use 1Password CLI (op). Use when installing the CLI, enabling desktop app integration, signing in, and reading/injecting secrets for commands. | `security/1password` | +| `oss-forensics` | Supply chain investigation, evidence recovery, and forensic analysis for GitHub repositories. Covers deleted commit recovery, force-push detection, IOC extraction, multi-source evidence collection, and structured forensic reporting. | `security/oss-forensics` | +| `sherlock` | OSINT username search across 400+ social networks. Hunt down social media accounts by username. | `security/sherlock` | diff --git a/website/docs/reference/toolsets-reference.md b/website/docs/reference/toolsets-reference.md index bb1813379..83cf92e4c 100644 --- a/website/docs/reference/toolsets-reference.md +++ b/website/docs/reference/toolsets-reference.md @@ -19,10 +19,16 @@ Toolsets are named bundles of tools that you can enable with `hermes chat --tool | `file` | core | `patch`, `read_file`, `search_files`, `write_file` | | `hermes-acp` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_console`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `delegate_task`, `execute_code`, `memory`, `patch`, `process`, `read_file`, `search_files`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` | | `hermes-cli` | platform | `browser_back`, `browser_click`, `browser_close`, `browser_console`, `browser_get_images`, `browser_navigate`, `browser_press`, `browser_scroll`, `browser_snapshot`, `browser_type`, `browser_vision`, `clarify`, `cronjob`, `delegate_task`, `execute_code`, `ha_call_service`, `ha_get_state`, `ha_list_entities`, `ha_list_services`, `honcho_conclude`, `honcho_context`, `honcho_profile`, `honcho_search`, `image_generate`, `memory`, `mixture_of_agents`, `patch`, `process`, `read_file`, `search_files`, `send_message`, `session_search`, `skill_manage`, `skill_view`, `skills_list`, `terminal`, `text_to_speech`, `todo`, `vision_analyze`, `web_extract`, `web_search`, `write_file` | +| `hermes-api-server` | platform | _(same as hermes-cli)_ | +| `hermes-dingtalk` | platform | _(same as hermes-cli)_ | +| `hermes-feishu` | platform | _(same as hermes-cli)_ | +| `hermes-wecom` | platform | _(same as hermes-cli)_ | | `hermes-discord` | platform | _(same as hermes-cli)_ | | `hermes-email` | platform | _(same as hermes-cli)_ | | `hermes-gateway` | composite | Union of all messaging platform toolsets | | `hermes-homeassistant` | platform | _(same as hermes-cli)_ | +| `hermes-matrix` | platform | _(same as hermes-cli)_ | +| `hermes-mattermost` | platform | _(same as hermes-cli)_ | | `hermes-signal` | platform | _(same as hermes-cli)_ | | `hermes-slack` | platform | _(same as hermes-cli)_ | | `hermes-sms` | platform | _(same as hermes-cli)_ | diff --git a/website/docs/user-guide/configuration.md b/website/docs/user-guide/configuration.md index 4aa5afb0b..3d88377d5 100644 --- a/website/docs/user-guide/configuration.md +++ b/website/docs/user-guide/configuration.md @@ -92,8 +92,13 @@ You need at least one way to connect to an LLM. Use `hermes model` to switch pro | **Kilo Code** | `KILOCODE_API_KEY` in `~/.hermes/.env` (provider: `kilocode`) | | **OpenCode Zen** | `OPENCODE_ZEN_API_KEY` in `~/.hermes/.env` (provider: `opencode-zen`) | | **OpenCode Go** | `OPENCODE_GO_API_KEY` in `~/.hermes/.env` (provider: `opencode-go`) | +| **Hugging Face** | `HF_TOKEN` in `~/.hermes/.env` (provider: `huggingface`, aliases: `hf`) | | **Custom Endpoint** | `hermes model` (saved in `config.yaml`) or `OPENAI_BASE_URL` + `OPENAI_API_KEY` in `~/.hermes/.env` | +:::tip Model key alias +In the `model:` config section, you can use either `default:` or `model:` as the key name for your model ID. Both `model: { default: my-model }` and `model: { model: my-model }` work identically. +::: + :::info Codex Note The OpenAI Codex provider authenticates via device code (open a URL, enter a code). Hermes stores the resulting credentials in its own auth store under `~/.hermes/auth.json` and can import existing Codex CLI credentials from `~/.codex/auth.json` when present. No Codex CLI installation is required. ::: @@ -211,7 +216,7 @@ hermes chat --provider minimax-cn --model MiniMax-M2.7 # Requires: MINIMAX_CN_API_KEY in ~/.hermes/.env # Alibaba Cloud / DashScope (Qwen models) -hermes chat --provider alibaba --model qwen-plus +hermes chat --provider alibaba --model qwen3.5-plus # Requires: DASHSCOPE_API_KEY in ~/.hermes/.env ``` @@ -224,6 +229,32 @@ model: Base URLs can be overridden with `GLM_BASE_URL`, `KIMI_BASE_URL`, `MINIMAX_BASE_URL`, `MINIMAX_CN_BASE_URL`, or `DASHSCOPE_BASE_URL` environment variables. +### Hugging Face Inference Providers + +[Hugging Face Inference Providers](https://huggingface.co/docs/inference-providers) routes to 20+ open models through a unified OpenAI-compatible endpoint (`router.huggingface.co/v1`). Requests are automatically routed to the fastest available backend (Groq, Together, SambaNova, etc.) with automatic failover. + +```bash +# Use any available model +hermes chat --provider huggingface --model Qwen/Qwen3-235B-A22B-Thinking-2507 +# Requires: HF_TOKEN in ~/.hermes/.env + +# Short alias +hermes chat --provider hf --model deepseek-ai/DeepSeek-V3.2 +``` + +Or set it permanently in `config.yaml`: +```yaml +model: + provider: "huggingface" + default: "Qwen/Qwen3-235B-A22B-Thinking-2507" +``` + +Get your token at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) — make sure to enable the "Make calls to Inference Providers" permission. Free tier included ($0.10/month credit, no markup on provider rates). + +You can append routing suffixes to model names: `:fastest` (default), `:cheapest`, or `:provider_name` to force a specific backend. + +The base URL can be overridden with `HF_BASE_URL`. + ## Custom & Self-Hosted LLM Providers Hermes Agent works with **any OpenAI-compatible API endpoint**. If a server implements `/v1/chat/completions`, you can point Hermes at it. This means you can use local models, GPU inference servers, multi-provider routers, or any third-party API. @@ -627,7 +658,7 @@ fallback_model: When activated, the fallback swaps the model and provider mid-session without losing your conversation. It fires **at most once** per session. -Supported providers: `openrouter`, `nous`, `openai-codex`, `copilot`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `custom`. +Supported providers: `openrouter`, `nous`, `openai-codex`, `copilot`, `anthropic`, `huggingface`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `custom`. :::tip Fallback is configured exclusively through `config.yaml` — there are no environment variables for it. For full details on when it triggers, supported providers, and how it interacts with auxiliary tasks and delegation, see [Fallback Providers](/docs/user-guide/features/fallback-providers). @@ -998,6 +1029,7 @@ auxiliary: model: "" # e.g. "google/gemini-2.5-flash" base_url: "" api_key: "" + timeout: 30 # seconds # Dangerous command approval classifier approval: @@ -1005,8 +1037,17 @@ auxiliary: model: "" base_url: "" api_key: "" + timeout: 30 # seconds + + # Context compression timeout (separate from compression.* config) + compression: + timeout: 120 # seconds — compression summarizes long conversations, needs more time ``` +:::tip +Each auxiliary task has a configurable `timeout` (in seconds). Defaults: vision 30s, web_extract 30s, approval 30s, compression 120s. Increase these if you use slow local models for auxiliary tasks. +::: + :::info Context compression has its own top-level `compression:` block with `summary_provider`, `summary_model`, and `summary_base_url` — see [Context Compression](#context-compression) above. The fallback model uses a `fallback_model:` block — see [Fallback Model](#fallback-model) above. All three follow the same provider/model/base_url pattern. ::: @@ -1138,6 +1179,24 @@ You can also change the reasoning effort at runtime with the `/reasoning` comman /reasoning hide # Hide model thinking ``` +## Tool-Use Enforcement + +Some models (especially GPT-family) occasionally describe intended actions as text instead of making tool calls. Tool-use enforcement injects guidance that steers the model back to actually calling tools. + +```yaml +agent: + tool_use_enforcement: "auto" # "auto" | true | false | ["model-substring", ...] +``` + +| Value | Behavior | +|-------|----------| +| `"auto"` (default) | Enabled for GPT models (`gpt-`, `openai/gpt-`) and disabled for all others. | +| `true` | Always enabled for all models. | +| `false` | Always disabled. | +| `["gpt-", "o1-", "custom-model"]` | Enabled only for models whose name contains one of the listed substrings. | + +When enabled, the system prompt includes guidance reminding the model to make actual tool calls rather than describing what it would do. This is transparent to the user and has no effect on models that already use tools reliably. + ## TTS Configuration ```yaml diff --git a/website/docs/user-guide/docker.md b/website/docs/user-guide/docker.md new file mode 100644 index 000000000..229919774 --- /dev/null +++ b/website/docs/user-guide/docker.md @@ -0,0 +1,56 @@ +# Hermes Agent — Docker + +Want to run Hermes Agent, but without installing packages on your host? This'll sort you out. + +This will let you run the agent in a container, with the most relevant modes outlined below. + +The container stores all user data (config, API keys, sessions, skills, memories) in a single directory mounted from the host at `/opt/data`. The image itself is stateless and can be upgraded by pulling a new version without losing any configuration. + +## Quick start + +If this is your first time running Hermes Agent, create a data directory on the host and start the container interactively to run the setup wizard: + +```sh +mkdir -p ~/.hermes +docker run -it --rm \ + -v ~/.hermes:/opt/data \ + nousresearch/hermes-agent +``` + +This drops you into the setup wizard, which will prompt you for your API keys and write them to `~/.hermes/.env`. You only need to do this once. It is highly recommended to set up a chat system for the gateway to work with at this point. + +## Running in gateway mode + +Once configured, run the container in the background as a persistent gateway (Telegram, Discord, Slack, WhatsApp, etc.): + +```sh +docker run -d \ + --name hermes \ + --restart unless-stopped \ + -v ~/.hermes:/opt/data \ + nousresearch/hermes-agent gateway run +``` + +## Running interactively (CLI chat) + +To open an interactive chat session against a running data directory: + +```sh +docker run -it --rm \ + -v ~/.hermes:/opt/data \ + nousresearch/hermes-agent +``` + +## Upgrading + +Pull the latest image and recreate the container. Your data directory is untouched. + +```sh +docker pull nousresearch/hermes-agent:latest +docker rm -f hermes +docker run -d \ + --name hermes \ + --restart unless-stopped \ + -v ~/.hermes:/opt/data \ + nousresearch/hermes-agent +``` diff --git a/website/docs/user-guide/features/api-server.md b/website/docs/user-guide/features/api-server.md index 3fab67441..6739ad7ab 100644 --- a/website/docs/user-guide/features/api-server.md +++ b/website/docs/user-guide/features/api-server.md @@ -154,7 +154,7 @@ Lists `hermes-agent` as an available model. Required by most frontends for model ### GET /health -Health check. Returns `{"status": "ok"}`. +Health check. Returns `{"status": "ok"}`. Also available at **GET /v1/health** for OpenAI-compatible clients that expect the `/v1/` prefix. ## System Prompt Handling @@ -199,6 +199,12 @@ The default bind address (`127.0.0.1`) is for local-only use. Browser access is # config.yaml support coming in a future release. ``` +## Security Headers + +All responses include security headers: +- `X-Content-Type-Options: nosniff` — prevents MIME type sniffing +- `Referrer-Policy: no-referrer` — prevents referrer leakage + ## CORS The API server does **not** enable browser CORS by default. @@ -209,6 +215,11 @@ For direct browser access, set an explicit allowlist: API_SERVER_CORS_ORIGINS=http://localhost:3000,http://127.0.0.1:3000 ``` +When CORS is enabled: +- **Preflight responses** include `Access-Control-Max-Age: 600` (10 minute cache) +- **SSE streaming responses** include CORS headers so browser EventSource clients work correctly +- **`Idempotency-Key`** is an allowed request header — clients can send it for deduplication (responses are cached by key for 5 minutes) + Most documented frontends such as Open WebUI connect server-to-server and do not need CORS at all. ## Compatible Frontends diff --git a/website/docs/user-guide/features/fallback-providers.md b/website/docs/user-guide/features/fallback-providers.md index 63e9337e4..e46f69e35 100644 --- a/website/docs/user-guide/features/fallback-providers.md +++ b/website/docs/user-guide/features/fallback-providers.md @@ -44,6 +44,8 @@ Both `provider` and `model` are **required**. If either is missing, the fallback | MiniMax | `minimax` | `MINIMAX_API_KEY` | | MiniMax (China) | `minimax-cn` | `MINIMAX_CN_API_KEY` | | Kilo Code | `kilocode` | `KILOCODE_API_KEY` | +| Alibaba / DashScope | `alibaba` | `DASHSCOPE_API_KEY` | +| Hugging Face | `huggingface` | `HF_TOKEN` | | Custom endpoint | `custom` | `base_url` + `api_key_env` (see below) | ### Custom Endpoint Fallback @@ -161,7 +163,7 @@ When a task's provider is set to `"auto"` (the default), Hermes tries providers ```text OpenRouter → Nous Portal → Custom endpoint → Codex OAuth → -API-key providers (z.ai, Kimi, MiniMax, Anthropic) → give up +API-key providers (z.ai, Kimi, MiniMax, Hugging Face, Anthropic) → give up ``` **For vision tasks:** diff --git a/website/docs/user-guide/features/hooks.md b/website/docs/user-guide/features/hooks.md index 272ea9cea..87c7f9846 100644 --- a/website/docs/user-guide/features/hooks.md +++ b/website/docs/user-guide/features/hooks.md @@ -88,6 +88,26 @@ Handlers registered for `command:*` fire for any `command:` event (`command:mode ### Examples +#### Boot Checklist (BOOT.md) — Built-in + +The gateway ships with a built-in `boot-md` hook that looks for `~/.hermes/BOOT.md` on every startup. If the file exists, the agent runs its instructions in a background session. No installation needed — just create the file. + +**Create `~/.hermes/BOOT.md`:** + +```markdown +# Startup Checklist + +1. Check if any cron jobs failed overnight — run `hermes cron list` +2. Send a message to Discord #general saying "Gateway restarted, all systems go" +3. Check if /opt/app/deploy.log has any errors from the last 24 hours +``` + +The agent runs these instructions in a background thread so it doesn't block gateway startup. If nothing needs attention, the agent replies with `[SILENT]` and no message is delivered. + +:::tip +No BOOT.md? The hook silently skips — zero overhead. Create the file whenever you need startup automation, delete it when you don't. +::: + #### Telegram Alert on Long Tasks Send yourself a message when the agent takes more than 10 steps: @@ -209,10 +229,10 @@ def register(ctx): |------|-----------|-------------------| | `pre_tool_call` | Before any tool executes | `tool_name`, `args`, `task_id` | | `post_tool_call` | After any tool returns | `tool_name`, `args`, `result`, `task_id` | -| `pre_llm_call` | Before LLM API request | *(planned — not yet wired)* | -| `post_llm_call` | After LLM API response | *(planned — not yet wired)* | -| `on_session_start` | Session begins | *(planned — not yet wired)* | -| `on_session_end` | Session ends | *(planned — not yet wired)* | +| `pre_llm_call` | Before LLM API request | `session_id`, `user_message`, `conversation_history`, `is_first_turn`, `model`, `platform` | +| `post_llm_call` | After LLM API response | `session_id`, `user_message`, `assistant_response`, `conversation_history`, `model`, `platform` | +| `on_session_start` | Session begins | `session_id`, `model`, `platform` | +| `on_session_end` | Session ends | `session_id`, `completed`, `interrupted`, `model`, `platform` | Callbacks receive keyword arguments matching the columns above: diff --git a/website/docs/user-guide/features/mcp.md b/website/docs/user-guide/features/mcp.md index 15890015b..9b8326d46 100644 --- a/website/docs/user-guide/features/mcp.md +++ b/website/docs/user-guide/features/mcp.md @@ -403,6 +403,105 @@ Because Hermes now only registers those wrappers when both are true: This is intentional and keeps the tool list honest. +## Running Hermes as an MCP server + +In addition to connecting **to** MCP servers, Hermes can also **be** an MCP server. This lets other MCP-capable agents (Claude Code, Cursor, Codex, or any MCP client) use Hermes's messaging capabilities — list conversations, read message history, and send messages across all your connected platforms. + +### When to use this + +- You want Claude Code, Cursor, or another coding agent to send and read Telegram/Discord/Slack messages through Hermes +- You want a single MCP server that bridges to all of Hermes's connected messaging platforms at once +- You already have a running Hermes gateway with connected platforms + +### Quick start + +```bash +hermes mcp serve +``` + +This starts a stdio MCP server. The MCP client (not you) manages the process lifecycle. + +### MCP client configuration + +Add Hermes to your MCP client config. For example, in Claude Code's `~/.claude/claude_desktop_config.json`: + +```json +{ + "mcpServers": { + "hermes": { + "command": "hermes", + "args": ["mcp", "serve"] + } + } +} +``` + +Or if you installed Hermes in a specific location: + +```json +{ + "mcpServers": { + "hermes": { + "command": "/home/user/.hermes/hermes-agent/venv/bin/hermes", + "args": ["mcp", "serve"] + } + } +} +``` + +### Available tools + +The MCP server exposes 10 tools, matching OpenClaw's channel bridge surface plus a Hermes-specific channel browser: + +| Tool | Description | +|------|-------------| +| `conversations_list` | List active messaging conversations. Filter by platform or search by name. | +| `conversation_get` | Get detailed info about one conversation by session key. | +| `messages_read` | Read recent message history for a conversation. | +| `attachments_fetch` | Extract non-text attachments (images, media) from a specific message. | +| `events_poll` | Poll for new conversation events since a cursor position. | +| `events_wait` | Long-poll / block until the next event arrives (near-real-time). | +| `messages_send` | Send a message through a platform (e.g. `telegram:123456`, `discord:#general`). | +| `channels_list` | List available messaging targets across all platforms. | +| `permissions_list_open` | List pending approval requests observed during this bridge session. | +| `permissions_respond` | Allow or deny a pending approval request. | + +### Event system + +The MCP server includes a live event bridge that polls Hermes's session database for new messages. This gives MCP clients near-real-time awareness of incoming conversations: + +``` +# Poll for new events (non-blocking) +events_poll(after_cursor=0) + +# Wait for next event (blocks up to timeout) +events_wait(after_cursor=42, timeout_ms=30000) +``` + +Event types: `message`, `approval_requested`, `approval_resolved` + +The event queue is in-memory and starts when the bridge connects. Older messages are available through `messages_read`. + +### Options + +```bash +hermes mcp serve # Normal mode +hermes mcp serve --verbose # Debug logging on stderr +``` + +### How it works + +The MCP server reads conversation data directly from Hermes's session store (`~/.hermes/sessions/sessions.json` and the SQLite database). A background thread polls the database for new messages and maintains an in-memory event queue. For sending messages, it uses the same `send_message` infrastructure as the Hermes agent itself. + +The gateway does NOT need to be running for read operations (listing conversations, reading history, polling events). It DOES need to be running for send operations, since the platform adapters need active connections. + +### Current limits + +- Stdio transport only (no HTTP MCP transport yet) +- Event polling at ~200ms intervals via mtime-optimized DB polling (skips work when files are unchanged) +- No `claude/channel` push notification protocol yet +- Text-only sends (no media/attachment sending through `messages_send`) + ## Related docs - [Use MCP with Hermes](/docs/guides/use-mcp-with-hermes) diff --git a/website/docs/user-guide/features/plugins.md b/website/docs/user-guide/features/plugins.md index 30ab6c35a..0f2e20f17 100644 --- a/website/docs/user-guide/features/plugins.md +++ b/website/docs/user-guide/features/plugins.md @@ -52,10 +52,10 @@ Plugins can register callbacks for these lifecycle events. See the **[Event Hook |------|-----------| | `pre_tool_call` | Before any tool executes | | `post_tool_call` | After any tool returns | -| `pre_llm_call` | Before LLM API request *(planned)* | -| `post_llm_call` | After LLM API response *(planned)* | -| `on_session_start` | Session begins *(planned)* | -| `on_session_end` | Session ends *(planned)* | +| `pre_llm_call` | Once per turn, before the LLM loop — can return `{"context": "..."}` to inject into the system prompt | +| `post_llm_call` | Once per turn, after the LLM loop completes | +| `on_session_start` | New session created (first turn only) | +| `on_session_end` | End of every `run_conversation` call | ## Slash commands @@ -87,9 +87,26 @@ The handler receives the argument string (everything after `/greet`) and returns ## Managing plugins -``` -/plugins # list loaded plugins in a session -hermes config set display.show_cost true # show cost in status bar +```bash +hermes plugins # interactive toggle UI — enable/disable with checkboxes +hermes plugins list # table view with enabled/disabled status +hermes plugins install user/repo # install from Git +hermes plugins update my-plugin # pull latest +hermes plugins remove my-plugin # uninstall +hermes plugins enable my-plugin # re-enable a disabled plugin +hermes plugins disable my-plugin # disable without removing ``` +Running `hermes plugins` with no arguments launches an interactive curses checklist (same UI as `hermes tools`) where you can toggle plugins on/off with arrow keys and space. + +Disabled plugins remain installed but are skipped during loading. The disabled list is stored in `config.yaml` under `plugins.disabled`: + +```yaml +plugins: + disabled: + - my-noisy-plugin +``` + +In a running session, `/plugins` shows which plugins are currently loaded. + See the **[full guide](/docs/guides/build-a-hermes-plugin)** for handler contracts, schema format, hook behavior, error handling, and common mistakes. diff --git a/website/docs/user-guide/features/skills.md b/website/docs/user-guide/features/skills.md index d21c98885..3d166b978 100644 --- a/website/docs/user-guide/features/skills.md +++ b/website/docs/user-guide/features/skills.md @@ -8,7 +8,9 @@ description: "On-demand knowledge documents — progressive disclosure, agent-ma Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage and are compatible with the [agentskills.io](https://agentskills.io/specification) open standard. -All skills live in **`~/.hermes/skills/`** — a single directory that serves as the source of truth. On fresh install, bundled skills are copied from the repo. Hub-installed and agent-created skills also go here. The agent can modify or delete any skill. +All skills live in **`~/.hermes/skills/`** — the primary directory and source of truth. On fresh install, bundled skills are copied from the repo. Hub-installed and agent-created skills also go here. The agent can modify or delete any skill. + +You can also point Hermes at **external skill directories** — additional folders scanned alongside the local one. See [External Skill Directories](#external-skill-directories) below. See also: @@ -164,6 +166,47 @@ Once set, declared env vars are **automatically passed through** to `execute_cod └── .bundled_manifest # Tracks seeded bundled skills ``` +## External Skill Directories + +If you maintain skills outside of Hermes — for example, a shared `~/.agents/skills/` directory used by multiple AI tools — you can tell Hermes to scan those directories too. + +Add `external_dirs` under the `skills` section in `~/.hermes/config.yaml`: + +```yaml +skills: + external_dirs: + - ~/.agents/skills + - /home/shared/team-skills + - ${SKILLS_REPO}/skills +``` + +Paths support `~` expansion and `${VAR}` environment variable substitution. + +### How it works + +- **Read-only**: External dirs are only scanned for skill discovery. When the agent creates or edits a skill, it always writes to `~/.hermes/skills/`. +- **Local precedence**: If the same skill name exists in both the local dir and an external dir, the local version wins. +- **Full integration**: External skills appear in the system prompt index, `skills_list`, `skill_view`, and as `/skill-name` slash commands — no different from local skills. +- **Non-existent paths are silently skipped**: If a configured directory doesn't exist, Hermes ignores it without errors. Useful for optional shared directories that may not be present on every machine. + +### Example + +```text +~/.hermes/skills/ # Local (primary, read-write) +├── devops/deploy-k8s/ +│ └── SKILL.md +└── mlops/axolotl/ + └── SKILL.md + +~/.agents/skills/ # External (read-only, shared) +├── my-custom-workflow/ +│ └── SKILL.md +└── team-conventions/ + └── SKILL.md +``` + +All four skills appear in your skill index. If you create a new skill called `my-custom-workflow` locally, it shadows the external version. + ## Agent-Managed Skills (skill_manage tool) The agent can create, update, and delete its own skills via the `skill_manage` tool. This is the agent's **procedural memory** — when it figures out a non-trivial workflow, it saves the approach as a skill for future reuse. @@ -277,9 +320,12 @@ hermes skills install well-known:https://mintlify.com/docs/.well-known/skills/mi Hermes can install directly from GitHub repositories and GitHub-based taps. This is useful when you already know the repo/path or want to add your own custom source repo. -- OpenAI skills: [openai/skills](https://github.com/openai/skills) -- Anthropic skills: [anthropics/skills](https://github.com/anthropics/skills) -- Example community tap source: [VoltAgent/awesome-agent-skills](https://github.com/VoltAgent/awesome-agent-skills) +Default taps (browsable without any setup): +- [openai/skills](https://github.com/openai/skills) +- [anthropics/skills](https://github.com/anthropics/skills) +- [VoltAgent/awesome-agent-skills](https://github.com/VoltAgent/awesome-agent-skills) +- [garrytan/gstack](https://github.com/garrytan/gstack) + - Example: ```bash diff --git a/website/docs/user-guide/features/tools.md b/website/docs/user-guide/features/tools.md index 981d2caf2..5e1ab601e 100644 --- a/website/docs/user-guide/features/tools.md +++ b/website/docs/user-guide/features/tools.md @@ -104,7 +104,7 @@ hermes config set terminal.singularity_image ~/python.sif ### Modal (Serverless Cloud) ```bash -uv pip install "swe-rex[modal]" +uv pip install modal modal setup hermes config set terminal.backend modal ``` diff --git a/website/docs/user-guide/messaging/discord.md b/website/docs/user-guide/messaging/discord.md index 0c2148c59..df97930a6 100644 --- a/website/docs/user-guide/messaging/discord.md +++ b/website/docs/user-guide/messaging/discord.md @@ -95,13 +95,17 @@ You'll land on the **General Information** page. Note the **Application ID** — 1. In the left sidebar, click **Bot**. 2. Discord automatically creates a bot user for your application. You'll see the bot's username, which you can customize. 3. Under **Authorization Flow**: - - Set **Public Bot** to **OFF** — this prevents other people from inviting your bot to their servers. + - Set **Public Bot** to **ON** — required to use the Discord-provided invite link (recommended). This allows the Installation tab to generate a default authorization URL. - Leave **Require OAuth2 Code Grant** set to **OFF**. :::tip You can set a custom avatar and banner for your bot on this page. This is what users will see in Discord. ::: +:::info[Private Bot Alternative] +If you prefer to keep your bot private (Public Bot = OFF), you **must** use the **Manual URL** method in Step 5 instead of the Installation tab. The Discord-provided link requires Public Bot to be enabled. +::: + ## Step 3: Enable Privileged Gateway Intents This is the most critical step in the entire setup. Without the correct intents enabled, your bot will connect to Discord but **will not be able to read message content**. @@ -149,6 +153,10 @@ You need an OAuth2 URL to invite the bot to your server. There are two ways to d ### Option A: Using the Installation Tab (Recommended) +:::note[Requires Public Bot] +This method requires **Public Bot** to be set to **ON** in Step 2. If you set Public Bot to OFF, use the Manual URL method below instead. +::: + 1. In the left sidebar, click **Installation**. 2. Under **Installation Contexts**, enable **Guild Install**. 3. For **Install Link**, select **Discord Provided Link**. @@ -361,3 +369,6 @@ Always set `DISCORD_ALLOWED_USERS` to restrict who can interact with the bot. Wi ::: For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md). + + + diff --git a/website/docs/user-guide/messaging/email.md b/website/docs/user-guide/messaging/email.md index c302532b1..c1cf6f5f3 100644 --- a/website/docs/user-guide/messaging/email.md +++ b/website/docs/user-guide/messaging/email.md @@ -104,6 +104,7 @@ The adapter polls the IMAP inbox for UNSEEN messages at a configurable interval - Documents (PDF, ZIP, etc.) → available for file access - **HTML-only emails** have tags stripped for plain text extraction - **Self-messages** are filtered out to prevent reply loops +- **Automated/noreply senders** are silently ignored — `noreply@`, `mailer-daemon@`, `bounce@`, `no-reply@`, and emails with `Auto-Submitted`, `Precedence: bulk`, or `List-Unsubscribe` headers ### Sending Replies diff --git a/website/docs/user-guide/messaging/feishu.md b/website/docs/user-guide/messaging/feishu.md new file mode 100644 index 000000000..f515648b8 --- /dev/null +++ b/website/docs/user-guide/messaging/feishu.md @@ -0,0 +1,129 @@ +--- +sidebar_position: 11 +title: "Feishu / Lark" +description: "Set up Hermes Agent as a Feishu or Lark bot" +--- + +# Feishu / Lark Setup + +Hermes Agent integrates with Feishu and Lark as a full-featured bot. Once connected, you can chat with the agent in direct messages or group chats, receive cron job results in a home chat, and send text, images, audio, and file attachments through the normal gateway flow. + +The integration supports both connection modes: + +- `websocket` — recommended; Hermes opens the outbound connection and you do not need a public webhook endpoint +- `webhook` — useful when you want Feishu/Lark to push events into your gateway over HTTP + +## How Hermes Behaves + +| Context | Behavior | +|---------|----------| +| Direct messages | Hermes responds to every message. | +| Group chats | Hermes responds when the bot is addressed in the chat. | +| Shared group chats | By default, session history is isolated per user inside a shared chat. | + +This shared-chat behavior is controlled by `config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +Set it to `false` only if you explicitly want one shared conversation per chat. + +## Step 1: Create a Feishu / Lark App + +1. Open the Feishu or Lark developer console: + - Feishu: + - Lark: +2. Create a new app. +3. In **Credentials & Basic Info**, copy the **App ID** and **App Secret**. +4. Enable the **Bot** capability for the app. + +:::warning +Keep the App Secret private. Anyone with it can impersonate your app. +::: + +## Step 2: Choose a Connection Mode + +### Recommended: WebSocket mode + +Use WebSocket mode when Hermes runs on your laptop, workstation, or a private server. No public URL is required. + +```bash +FEISHU_CONNECTION_MODE=websocket +``` + +### Optional: Webhook mode + +Use webhook mode only when you already run Hermes behind a reachable HTTP endpoint. + +```bash +FEISHU_CONNECTION_MODE=webhook +``` + +In webhook mode, Hermes serves a Feishu endpoint at: + +```text +/feishu/webhook +``` + +## Step 3: Configure Hermes + +### Option A: Interactive Setup + +```bash +hermes gateway setup +``` + +Select **Feishu / Lark** and fill in the prompts. + +### Option B: Manual Configuration + +Add the following to `~/.hermes/.env`: + +```bash +FEISHU_APP_ID=cli_xxx +FEISHU_APP_SECRET=secret_xxx +FEISHU_DOMAIN=feishu +FEISHU_CONNECTION_MODE=websocket + +# Optional but strongly recommended +FEISHU_ALLOWED_USERS=ou_xxx,ou_yyy +FEISHU_HOME_CHANNEL=oc_xxx +``` + +`FEISHU_DOMAIN` accepts: + +- `feishu` for Feishu China +- `lark` for Lark international + +## Step 4: Start the Gateway + +```bash +hermes gateway +``` + +Then message the bot from Feishu/Lark to confirm that the connection is live. + +## Home Chat + +Use `/set-home` in a Feishu/Lark chat to mark it as the home channel for cron job results and cross-platform notifications. + +You can also preconfigure it: + +```bash +FEISHU_HOME_CHANNEL=oc_xxx +``` + +## Security + +For production use, set an allowlist: + +```bash +FEISHU_ALLOWED_USERS=ou_xxx,ou_yyy +``` + +If you leave the allowlist empty, anyone who can reach the bot may be able to use it. + +## Toolset + +Feishu / Lark uses the `hermes-feishu` platform preset, which includes the same core tools as Telegram and other gateway-based messaging platforms. diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 6069df4f4..9073e45ff 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -6,7 +6,7 @@ description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, # Messaging Gateway -Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. +Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, Feishu/Lark, WeCom, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. For the full voice feature set — including CLI microphone mode, spoken replies in messaging, and Discord voice-channel conversations — see [Voice Mode](/docs/user-guide/features/voice-mode) and [Use Voice Mode with Hermes](/docs/guides/use-voice-mode-with-hermes). @@ -27,6 +27,8 @@ flowchart TB mm[Mattermost] mx[Matrix] dt[DingTalk] + fs[Feishu/Lark] + wc[WeCom] api["API Server
(OpenAI-compatible)"] wh[Webhooks] end @@ -289,12 +291,27 @@ If you run multiple Hermes installations on the same machine (with different `HE ### macOS (launchd) ```bash -hermes gateway install -launchctl start ai.hermes.gateway -launchctl stop ai.hermes.gateway -tail -f ~/.hermes/logs/gateway.log +hermes gateway install # Install as launchd agent +hermes gateway start # Start the service +hermes gateway stop # Stop the service +hermes gateway status # Check status +tail -f ~/.hermes/logs/gateway.log # View logs ``` +The generated plist lives at `~/Library/LaunchAgents/ai.hermes.gateway.plist`. It includes three environment variables: + +- **PATH** — your full shell PATH at install time, with the venv `bin/` and `node_modules/.bin` prepended. This ensures user-installed tools (Node.js, ffmpeg, etc.) are available to gateway subprocesses like the WhatsApp bridge. +- **VIRTUAL_ENV** — points to the Python virtualenv so tools can resolve packages correctly. +- **HERMES_HOME** — scopes the gateway to your Hermes installation. + +:::tip PATH changes after install +launchd plists are static — if you install new tools (e.g. a new Node.js version via nvm, or ffmpeg via Homebrew) after setting up the gateway, run `hermes gateway install` again to capture the updated PATH. The gateway will detect the stale plist and reload automatically. +::: + +:::info Multiple installations +Like the Linux systemd service, each `HERMES_HOME` directory gets its own launchd label. The default `~/.hermes` uses `ai.hermes.gateway`; other installations use `ai.hermes.gateway-`. +::: + ## Platform-Specific Toolsets Each platform has its own toolset: @@ -313,6 +330,8 @@ Each platform has its own toolset: | Mattermost | `hermes-mattermost` | Full tools including terminal | | Matrix | `hermes-matrix` | Full tools including terminal | | DingTalk | `hermes-dingtalk` | Full tools including terminal | +| Feishu/Lark | `hermes-feishu` | Full tools including terminal | +| WeCom | `hermes-wecom` | Full tools including terminal | | API Server | `hermes` (default) | Full tools including terminal | | Webhooks | `hermes-webhook` | Full tools including terminal | @@ -329,5 +348,7 @@ Each platform has its own toolset: - [Mattermost Setup](mattermost.md) - [Matrix Setup](matrix.md) - [DingTalk Setup](dingtalk.md) +- [Feishu/Lark Setup](feishu.md) +- [WeCom Setup](wecom.md) - [Open WebUI + API Server](open-webui.md) - [Webhooks](webhooks.md) diff --git a/website/docs/user-guide/messaging/mattermost.md b/website/docs/user-guide/messaging/mattermost.md index f959bb872..cff50e94d 100644 --- a/website/docs/user-guide/messaging/mattermost.md +++ b/website/docs/user-guide/messaging/mattermost.md @@ -149,6 +149,12 @@ MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c # Optional: reply mode (thread or off, default: off) # MATTERMOST_REPLY_MODE=thread + +# Optional: respond without @mention (default: true = require mention) +# MATTERMOST_REQUIRE_MENTION=false + +# Optional: channels where bot responds without @mention (comma-separated channel IDs) +# MATTERMOST_FREE_RESPONSE_CHANNELS=channel_id_1,channel_id_2 ``` Optional behavior settings in `~/.hermes/config.yaml`: @@ -206,6 +212,19 @@ Set it in your `~/.hermes/.env`: MATTERMOST_REPLY_MODE=thread ``` +## Mention Behavior + +By default, the bot only responds in channels when `@mentioned`. You can change this: + +| Variable | Default | Description | +|----------|---------|-------------| +| `MATTERMOST_REQUIRE_MENTION` | `true` | Set to `false` to respond to all messages in channels (DMs always work). | +| `MATTERMOST_FREE_RESPONSE_CHANNELS` | _(none)_ | Comma-separated channel IDs where the bot responds without `@mention`, even when require_mention is true. | + +To find a channel ID in Mattermost: open the channel, click the channel name header, and look for the ID in the URL or channel details. + +When the bot is `@mentioned`, the mention is automatically stripped from the message before processing. + ## Troubleshooting ### Bot is not responding to messages diff --git a/website/docs/user-guide/messaging/signal.md b/website/docs/user-guide/messaging/signal.md index ceebc3515..d47b7ca03 100644 --- a/website/docs/user-guide/messaging/signal.md +++ b/website/docs/user-guide/messaging/signal.md @@ -36,22 +36,6 @@ brew install signal-cli # Extract and add to PATH ``` -### Alternative: Docker (signal-cli-rest-api) - -If you prefer Docker, use the [signal-cli-rest-api](https://github.com/bbernhard/signal-cli-rest-api) container: - -```bash -docker run -d --name signal-cli \ - -p 8080:8080 \ - -v $HOME/.local/share/signal-cli:/home/.local/share/signal-cli \ - -e MODE=json-rpc \ - bbernhard/signal-cli-rest-api -``` - -:::tip -Use `MODE=json-rpc` for best performance. The `normal` mode spawns a JVM per request and is much slower. -::: - --- ## Step 1: Link Your Signal Account diff --git a/website/docs/user-guide/messaging/slack.md b/website/docs/user-guide/messaging/slack.md index a40ba470f..f011dcd78 100644 --- a/website/docs/user-guide/messaging/slack.md +++ b/website/docs/user-guide/messaging/slack.md @@ -114,7 +114,22 @@ Without these events, Slack simply never delivers channel messages to the bot. --- -## Step 5: Install App to Workspace +## Step 5: Enable the Messages Tab + +This step enables direct messages to the bot. Without it, users see **"Sending messages to this app has been turned off"** when trying to DM the bot. + +1. In the sidebar, go to **Features → App Home** +2. Scroll to **Show Tabs** +3. Toggle **Messages Tab** to ON +4. Check **"Allow users to send Slash commands and messages from the messages tab"** + +:::danger Without this step, DMs are completely blocked +Even with all the correct scopes and event subscriptions, Slack will not allow users to send direct messages to the bot unless the Messages Tab is enabled. This is a Slack platform requirement, not a Hermes configuration issue. +::: + +--- + +## Step 6: Install App to Workspace 1. In the sidebar, go to **Settings → Install App** 2. Click **Install to Workspace** @@ -129,7 +144,7 @@ to take effect. The Install App page will show a banner prompting you to do so. --- -## Step 6: Find User IDs for the Allowlist +## Step 7: Find User IDs for the Allowlist Hermes uses Slack **Member IDs** (not usernames or display names) for the allowlist. @@ -144,7 +159,7 @@ Member IDs look like `U01ABC2DEF3`. You need your own Member ID at minimum. --- -## Step 7: Configure Hermes +## Step 8: Configure Hermes Add the following to your `~/.hermes/.env` file: @@ -175,7 +190,7 @@ sudo hermes gateway install --system # Linux only: boot-time system service --- -## Step 8: Invite the Bot to Channels +## Step 9: Invite the Bot to Channels After starting the gateway, you need to **invite the bot** to any channel where you want it to respond: @@ -239,6 +254,7 @@ Hermes supports voice on Slack: | Bot works in DMs but not in channels | **Most common issue.** Add `message.channels` and `message.groups` to event subscriptions, reinstall the app, and invite the bot to the channel with `/invite @Hermes Agent` | | Bot doesn't respond to @mentions in channels | 1) Check `message.channels` event is subscribed. 2) Bot must be invited to the channel. 3) Ensure `channels:history` scope is added. 4) Reinstall the app after scope/event changes | | Bot ignores messages in private channels | Add both the `message.groups` event subscription and `groups:history` scope, then reinstall the app and `/invite` the bot | +| "Sending messages to this app has been turned off" in DMs | Enable the **Messages Tab** in App Home settings (see Step 5) | | "not_authed" or "invalid_auth" errors | Regenerate your Bot Token and App Token, update `.env` | | Bot responds but can't post in a channel | Invite the bot to the channel with `/invite @Hermes Agent` | | "missing_scope" error | Add the required scope in OAuth & Permissions, then **reinstall** the app | diff --git a/website/docs/user-guide/messaging/webhooks.md b/website/docs/user-guide/messaging/webhooks.md index 817446386..b804152f2 100644 --- a/website/docs/user-guide/messaging/webhooks.md +++ b/website/docs/user-guide/messaging/webhooks.md @@ -15,7 +15,7 @@ The agent processes the event and can respond by posting comments on PRs, sendin ## Quick Start 1. Enable via `hermes gateway setup` or environment variables -2. Define webhook routes in `config.yaml` +2. Define routes in `config.yaml` **or** create them dynamically with `hermes webhook subscribe` 3. Point your service at `http://your-server:8644/webhooks/` --- @@ -205,6 +205,56 @@ For cross-platform delivery (telegram, discord, slack, signal, sms), the target --- +## Dynamic Subscriptions (CLI) {#dynamic-subscriptions} + +In addition to static routes in `config.yaml`, you can create webhook subscriptions dynamically using the `hermes webhook` CLI command. This is especially useful when the agent itself needs to set up event-driven triggers. + +### Create a subscription + +```bash +hermes webhook subscribe github-issues \ + --events "issues" \ + --prompt "New issue #{issue.number}: {issue.title}\nBy: {issue.user.login}\n\n{issue.body}" \ + --deliver telegram \ + --deliver-chat-id "-100123456789" \ + --description "Triage new GitHub issues" +``` + +This returns the webhook URL and an auto-generated HMAC secret. Configure your service to POST to that URL. + +### List subscriptions + +```bash +hermes webhook list +``` + +### Remove a subscription + +```bash +hermes webhook remove github-issues +``` + +### Test a subscription + +```bash +hermes webhook test github-issues +hermes webhook test github-issues --payload '{"issue": {"number": 42, "title": "Test"}}' +``` + +### How dynamic subscriptions work + +- Subscriptions are stored in `~/.hermes/webhook_subscriptions.json` +- The webhook adapter hot-reloads this file on each incoming request (mtime-gated, negligible overhead) +- Static routes from `config.yaml` always take precedence over dynamic ones with the same name +- Dynamic subscriptions use the same route format and capabilities as static routes (events, prompt templates, skills, delivery) +- No gateway restart required — subscribe and it's immediately live + +### Agent-driven subscriptions + +The agent can create subscriptions via the terminal tool when guided by the `webhook-subscriptions` skill. Ask the agent to "set up a webhook for GitHub issues" and it will run the appropriate `hermes webhook subscribe` command. + +--- + ## Security {#security} The webhook adapter includes multiple layers of security: diff --git a/website/docs/user-guide/messaging/wecom.md b/website/docs/user-guide/messaging/wecom.md new file mode 100644 index 000000000..e5a551b8f --- /dev/null +++ b/website/docs/user-guide/messaging/wecom.md @@ -0,0 +1,86 @@ +--- +sidebar_position: 14 +title: "WeCom (Enterprise WeChat)" +description: "Connect Hermes Agent to WeCom via the AI Bot WebSocket gateway" +--- + +# WeCom (Enterprise WeChat) + +Connect Hermes to [WeCom](https://work.weixin.qq.com/) (企业微信), Tencent's enterprise messaging platform. The adapter uses WeCom's AI Bot WebSocket gateway for real-time bidirectional communication — no public endpoint or webhook needed. + +## Prerequisites + +- A WeCom organization account +- An AI Bot created in the WeCom Admin Console +- The Bot ID and Secret from the bot's credentials page + +## Setup + +### 1. Create an AI Bot + +1. Log in to the [WeCom Admin Console](https://work.weixin.qq.com/wework_admin/frame) +2. Navigate to **Applications** → **Create Application** → **AI Bot** +3. Configure the bot name and description +4. Copy the **Bot ID** and **Secret** from the credentials page + +### 2. Configure Hermes + +Run the interactive setup: + +```bash +hermes gateway setup +``` + +Select **WeCom** and enter your Bot ID and Secret. + +Or set environment variables in `~/.hermes/.env`: + +```bash +WECOM_BOT_ID=your-bot-id +WECOM_SECRET=your-secret + +# Optional: restrict access +WECOM_ALLOWED_USERS=user_id_1,user_id_2 + +# Optional: home channel for cron/notifications +WECOM_HOME_CHANNEL=chat_id +``` + +### 3. Start the gateway + +```bash +hermes gateway start +``` + +## Features + +- **WebSocket transport** — persistent connection, no public endpoint needed +- **DM and group messaging** — configurable access policies +- **Media support** — images, files, voice, video upload and download +- **AES-encrypted media** — automatic decryption for inbound attachments +- **Quote context** — preserves reply threading +- **Markdown rendering** — rich text responses +- **Auto-reconnect** — exponential backoff on connection drops + +## Configuration Options + +Set these in `config.yaml` under `platforms.wecom.extra`: + +| Key | Default | Description | +|-----|---------|-------------| +| `bot_id` | — | WeCom AI Bot ID (required) | +| `secret` | — | WeCom AI Bot Secret (required) | +| `websocket_url` | `wss://openws.work.weixin.qq.com` | WebSocket gateway URL | +| `dm_policy` | `open` | DM access: `open`, `allowlist`, `disabled`, `pairing` | +| `group_policy` | `open` | Group access: `open`, `allowlist`, `disabled` | +| `allow_from` | `[]` | User IDs allowed for DMs (when dm_policy=allowlist) | +| `group_allow_from` | `[]` | Group IDs allowed (when group_policy=allowlist) | + +## Troubleshooting + +| Problem | Fix | +|---------|-----| +| "WECOM_BOT_ID and WECOM_SECRET are required" | Set both env vars or configure in setup wizard | +| "invalid secret (errcode=40013)" | Verify the secret matches your bot's credentials | +| "Timed out waiting for subscribe acknowledgement" | Check network connectivity to `openws.work.weixin.qq.com` | +| Bot doesn't respond in groups | Check `group_policy` setting and group allowlist | diff --git a/website/docs/user-guide/messaging/whatsapp.md b/website/docs/user-guide/messaging/whatsapp.md index 57212df15..1c5226813 100644 --- a/website/docs/user-guide/messaging/whatsapp.md +++ b/website/docs/user-guide/messaging/whatsapp.md @@ -173,6 +173,7 @@ whatsapp: | **Logged out unexpectedly** | WhatsApp unlinks devices after long inactivity. Keep the phone on and connected to the network, then re-pair with `hermes whatsapp` if needed. | | **Bridge crashes or reconnect loops** | Restart the gateway, update Hermes, and re-pair if the session was invalidated by a WhatsApp protocol change. | | **Bot stops working after WhatsApp update** | Update Hermes to get the latest bridge version, then re-pair. | +| **macOS: "Node.js not installed" but node works in terminal** | launchd services don't inherit your shell PATH. Run `hermes gateway install` to re-snapshot your current PATH into the plist, then `hermes gateway start`. See the [Gateway Service docs](./index.md#macos-launchd) for details. | | **Messages not being received** | Verify `WHATSAPP_ALLOWED_USERS` includes the sender's number (with country code, no `+` or spaces). | | **Bot replies to strangers with a pairing code** | Set `whatsapp.unauthorized_dm_behavior: ignore` in `~/.hermes/config.yaml` if you want unauthorized DMs to be silently ignored instead. | diff --git a/website/docs/user-guide/profiles.md b/website/docs/user-guide/profiles.md new file mode 100644 index 000000000..5da6d8ab2 --- /dev/null +++ b/website/docs/user-guide/profiles.md @@ -0,0 +1,202 @@ +--- +sidebar_position: 2 +--- + +# Profiles: Running Multiple Agents + +Run multiple independent Hermes agents on the same machine — each with its own config, API keys, memory, sessions, skills, and gateway. + +## What are profiles? + +A profile is a fully isolated Hermes environment. Each profile gets its own directory containing its own `config.yaml`, `.env`, `SOUL.md`, memories, sessions, skills, cron jobs, and state database. Profiles let you run separate agents for different purposes — a coding assistant, a personal bot, a research agent — without any cross-contamination. + +When you create a profile, it automatically becomes its own command. Create a profile called `coder` and you immediately have `coder chat`, `coder setup`, `coder gateway start`, etc. + +## Quick start + +```bash +hermes profile create coder # creates profile + "coder" command alias +coder setup # configure API keys and model +coder chat # start chatting +``` + +That's it. `coder` is now a fully independent agent. It has its own config, its own memory, its own everything. + +## Creating a profile + +### Blank profile + +```bash +hermes profile create mybot +``` + +Creates a fresh profile with bundled skills seeded. Run `mybot setup` to configure API keys, model, and gateway tokens. + +### Clone config only (`--clone`) + +```bash +hermes profile create work --clone +``` + +Copies your current profile's `config.yaml`, `.env`, and `SOUL.md` into the new profile. Same API keys and model, but fresh sessions and memory. Edit `~/.hermes/profiles/work/.env` for different API keys, or `~/.hermes/profiles/work/SOUL.md` for a different personality. + +### Clone everything (`--clone-all`) + +```bash +hermes profile create backup --clone-all +``` + +Copies **everything** — config, API keys, personality, all memories, full session history, skills, cron jobs, plugins. A complete snapshot. Useful for backups or forking an agent that already has context. + +### Clone from a specific profile + +```bash +hermes profile create work --clone --clone-from coder +``` + +## Using profiles + +### Command aliases + +Every profile automatically gets a command alias at `~/.local/bin/`: + +```bash +coder chat # chat with the coder agent +coder setup # configure coder's settings +coder gateway start # start coder's gateway +coder doctor # check coder's health +coder skills list # list coder's skills +coder config set model.model anthropic/claude-sonnet-4 +``` + +The alias works with every hermes subcommand — it's just `hermes -p ` under the hood. + +### The `-p` flag + +You can also target a profile explicitly with any command: + +```bash +hermes -p coder chat +hermes --profile=coder doctor +hermes chat -p coder -q "hello" # works in any position +``` + +### Sticky default (`hermes profile use`) + +```bash +hermes profile use coder +hermes chat # now targets coder +hermes tools # configures coder's tools +hermes profile use default # switch back +``` + +Sets a default so plain `hermes` commands target that profile. Like `kubectl config use-context`. + +### Knowing where you are + +The CLI always shows which profile is active: + +- **Prompt**: `coder ❯` instead of `❯` +- **Banner**: Shows `Profile: coder` on startup +- **`hermes profile`**: Shows current profile name, path, model, gateway status + +## Running gateways + +Each profile runs its own gateway as a separate process with its own bot token: + +```bash +coder gateway start # starts coder's gateway +assistant gateway start # starts assistant's gateway (separate process) +``` + +### Different bot tokens + +Each profile has its own `.env` file. Configure a different Telegram/Discord/Slack bot token in each: + +```bash +# Edit coder's tokens +nano ~/.hermes/profiles/coder/.env + +# Edit assistant's tokens +nano ~/.hermes/profiles/assistant/.env +``` + +### Safety: token locks + +If two profiles accidentally use the same bot token, the second gateway will be blocked with a clear error naming the conflicting profile. Supported for Telegram, Discord, Slack, WhatsApp, and Signal. + +### Persistent services + +```bash +coder gateway install # creates hermes-gateway-coder systemd/launchd service +assistant gateway install # creates hermes-gateway-assistant service +``` + +Each profile gets its own service name. They run independently. + +## Configuring profiles + +Each profile has its own: + +- **`config.yaml`** — model, provider, toolsets, all settings +- **`.env`** — API keys, bot tokens +- **`SOUL.md`** — personality and instructions + +```bash +coder config set model.model anthropic/claude-sonnet-4 +echo "You are a focused coding assistant." > ~/.hermes/profiles/coder/SOUL.md +``` + +## Updating + +`hermes update` pulls code once (shared) and syncs new bundled skills to **all** profiles automatically: + +```bash +hermes update +# → Code updated (12 commits) +# → Skills synced: default (up to date), coder (+2 new), assistant (+2 new) +``` + +User-modified skills are never overwritten. + +## Managing profiles + +```bash +hermes profile list # show all profiles with status +hermes profile show coder # detailed info for one profile +hermes profile rename coder dev-bot # rename (updates alias + service) +hermes profile export coder # export to coder.tar.gz +hermes profile import coder.tar.gz # import from archive +``` + +## Deleting a profile + +```bash +hermes profile delete coder +``` + +This stops the gateway, removes the systemd/launchd service, removes the command alias, and deletes all profile data. You'll be asked to type the profile name to confirm. + +Use `--yes` to skip confirmation: `hermes profile delete coder --yes` + +:::note +You cannot delete the default profile (`~/.hermes`). To remove everything, use `hermes uninstall`. +::: + +## Tab completion + +```bash +# Bash +eval "$(hermes completion bash)" + +# Zsh +eval "$(hermes completion zsh)" +``` + +Add the line to your `~/.bashrc` or `~/.zshrc` for persistent completion. Completes profile names after `-p`, profile subcommands, and top-level commands. + +## How it works + +Profiles use the `HERMES_HOME` environment variable. When you run `coder chat`, the wrapper script sets `HERMES_HOME=~/.hermes/profiles/coder` before launching hermes. Since 119+ files in the codebase resolve paths via `get_hermes_home()`, everything automatically scopes to the profile's directory — config, sessions, memory, skills, state database, gateway PID, logs, and cron jobs. + +The default profile is simply `~/.hermes` itself. No migration needed — existing installs work identically. diff --git a/website/docs/user-guide/security.md b/website/docs/user-guide/security.md index b38cdcb14..4d51161e1 100644 --- a/website/docs/user-guide/security.md +++ b/website/docs/user-guide/security.md @@ -43,6 +43,8 @@ The following patterns trigger approval prompts (defined in `tools/approval.py`) | `bash -c`, `python -e` | Shell/script execution via flags | | `find -exec rm`, `find -delete` | Find with destructive actions | | Fork bomb patterns | Fork bombs | +| `pkill`/`killall` hermes/gateway | Self-termination prevention | +| `gateway run` with `&`/`disown`/`nohup` | Prevents starting gateway outside service manager | :::info **Container bypass**: When running in `docker`, `singularity`, `modal`, or `daytona` backends, dangerous command checks are **skipped** because the container itself is the security boundary. Destructive commands inside a container can't harm the host. @@ -276,7 +278,11 @@ required_environment_variables: help: Get a key from https://developers.google.com/tenor ``` -After loading this skill, `TENOR_API_KEY` passes through to both `execute_code` and `terminal` subprocesses — no manual configuration needed. +After loading this skill, `TENOR_API_KEY` passes through to `execute_code`, `terminal` (local), **and remote backends (Docker, Modal)** — no manual configuration needed. + +:::info Docker & Modal +Prior to v0.5.1, Docker's `forward_env` was a separate system from the skill passthrough. They are now merged — skill-declared env vars are automatically forwarded into Docker containers and Modal sandboxes without needing to add them to `docker_forward_env` manually. +::: **2. Config-based passthrough (manual)** @@ -289,17 +295,49 @@ terminal: - ANOTHER_TOKEN ``` +### Credential File Passthrough (OAuth tokens, etc.) {#credential-file-passthrough} + +Some skills need **files** (not just env vars) in the sandbox — for example, Google Workspace stores OAuth tokens as `google_token.json` in `~/.hermes/`. Skills declare these in frontmatter: + +```yaml +required_credential_files: + - path: google_token.json + description: Google OAuth2 token (created by setup script) + - path: google_client_secret.json + description: Google OAuth2 client credentials +``` + +When loaded, Hermes checks if these files exist in `~/.hermes/` and registers them for mounting: + +- **Docker**: Read-only bind mounts (`-v host:container:ro`) +- **Modal**: Mounted at sandbox creation + synced before each command (handles mid-session OAuth setup) +- **Local**: No action needed (files already accessible) + +You can also list credential files manually in `config.yaml`: + +```yaml +terminal: + credential_files: + - google_token.json + - my_custom_oauth_token.json +``` + +Paths are relative to `~/.hermes/`. Files are mounted to `/root/.hermes/` inside the container. + ### What Each Sandbox Filters | Sandbox | Default Filter | Passthrough Override | |---------|---------------|---------------------| | **execute_code** | Blocks vars containing `KEY`, `TOKEN`, `SECRET`, `PASSWORD`, `CREDENTIAL`, `PASSWD`, `AUTH` in name; only allows safe-prefix vars through | ✅ Passthrough vars bypass both checks | | **terminal** (local) | Blocks explicit Hermes infrastructure vars (provider keys, gateway tokens, tool API keys) | ✅ Passthrough vars bypass the blocklist | +| **terminal** (Docker) | No host env vars by default | ✅ Passthrough vars + `docker_forward_env` forwarded via `-e` | +| **terminal** (Modal) | No host env/files by default | ✅ Credential files mounted; env passthrough via sync | | **MCP** | Blocks everything except safe system vars + explicitly configured `env` | ❌ Not affected by passthrough (use MCP `env` config instead) | ### Security Considerations - The passthrough only affects vars you or your skills explicitly declare — the default security posture is unchanged for arbitrary LLM-generated code +- Credential files are mounted **read-only** into Docker containers - Skills Guard scans skill content for suspicious env access patterns before installation - Missing/unset vars are never registered (you can't leak what doesn't exist) - Hermes infrastructure secrets (provider API keys, gateway tokens) should never be added to `env_passthrough` — they have dedicated mechanisms @@ -392,7 +430,7 @@ security: When `tirith_fail_open` is `true` (default), commands proceed if tirith is not installed or times out. Set to `false` in high-security environments to block commands when tirith is unavailable. -Tirith's verdict integrates with the approval flow: safe commands pass through, suspicious commands trigger user approval, and dangerous commands are blocked. +Tirith's verdict integrates with the approval flow: safe commands pass through, while both suspicious and blocked commands trigger user approval with the full tirith findings (severity, title, description, safer alternatives). Users can approve or deny — the default choice is deny to keep unattended scenarios secure. ### Context File Injection Protection diff --git a/website/sidebars.ts b/website/sidebars.ts index c7fef2ae1..73c943031 100644 --- a/website/sidebars.ts +++ b/website/sidebars.ts @@ -37,6 +37,8 @@ const sidebars: SidebarsConfig = { 'user-guide/configuration', 'user-guide/sessions', 'user-guide/security', + 'user-guide/docker', + 'user-guide/profiles', { type: 'category', label: 'Messaging Gateway', @@ -52,6 +54,8 @@ const sidebars: SidebarsConfig = { 'user-guide/messaging/mattermost', 'user-guide/messaging/matrix', 'user-guide/messaging/dingtalk', + 'user-guide/messaging/feishu', + 'user-guide/messaging/wecom', 'user-guide/messaging/open-webui', 'user-guide/messaging/webhooks', ], @@ -152,6 +156,7 @@ const sidebars: SidebarsConfig = { 'reference/mcp-config-reference', 'reference/skills-catalog', 'reference/optional-skills-catalog', + 'reference/profile-commands', 'reference/environment-variables', 'reference/faq', ],