diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 60a11e294..67a3f64aa 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -11,6 +11,7 @@ body: **Before submitting**, please: - [ ] Search [existing issues](https://github.com/NousResearch/hermes-agent/issues) to avoid duplicates - [ ] Update to the latest version (`hermes update`) and confirm the bug still exists + - [ ] Run `hermes debug share` and paste the links below (see Debug Report section) - type: textarea id: description @@ -82,6 +83,25 @@ body: - Slack - WhatsApp + - type: textarea + id: debug-report + attributes: + label: Debug Report + description: | + Run `hermes debug share` from your terminal and paste the links it prints here. + This uploads your system info, config, and recent logs to a paste service automatically. + + If you're in an interactive chat session, you can also use the `/debug` slash command — it does the same thing. + + If the upload fails, run `hermes debug share --local` and paste the output directly. + placeholder: | + Report https://paste.rs/abc123 + agent.log https://paste.rs/def456 + gateway.log https://paste.rs/ghi789 + render: shell + validations: + required: true + - type: input id: os attributes: @@ -97,8 +117,6 @@ body: label: Python Version description: Output of `python --version` placeholder: "3.11.9" - validations: - required: true - type: input id: hermes-version @@ -106,14 +124,14 @@ body: label: Hermes Version description: Output of `hermes version` placeholder: "2.1.0" - validations: - required: true - type: textarea id: logs attributes: - label: Relevant Logs / Traceback - description: Paste any error output, traceback, or log messages. This will be auto-formatted as code. + label: Additional Logs / Traceback (optional) + description: | + The debug report above covers most logs. 
Use this field for any extra error output, + tracebacks, or screenshots not captured by `hermes debug share`. render: shell - type: textarea diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 8dba7d43d..720cc8f1f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -71,3 +71,15 @@ body: label: Contribution options: - label: I'd like to implement this myself and submit a PR + + - type: textarea + id: debug-report + attributes: + label: Debug Report (optional) + description: | + If this feature request is related to a problem you're experiencing, run `hermes debug share` and paste the links here. + In an interactive chat session, you can use `/debug` instead. + This helps us understand your environment and any related logs. + placeholder: | + Report https://paste.rs/abc123 + render: shell diff --git a/.github/ISSUE_TEMPLATE/setup_help.yml b/.github/ISSUE_TEMPLATE/setup_help.yml index f13eea4a3..974181b5d 100644 --- a/.github/ISSUE_TEMPLATE/setup_help.yml +++ b/.github/ISSUE_TEMPLATE/setup_help.yml @@ -9,7 +9,8 @@ body: Sorry you're having trouble! Please fill out the details below so we can help. **Quick checks first:** - - Run `hermes doctor` and include the output below + - Run `hermes debug share` and paste the links in the Debug Report section below + - If you're in a chat session, you can use `/debug` instead — it does the same thing - Try `hermes update` to get the latest version - Check the [README troubleshooting section](https://github.com/NousResearch/hermes-agent#troubleshooting) - For general questions, consider the [Nous Research Discord](https://discord.gg/NousResearch) for faster help @@ -74,10 +75,21 @@ body: placeholder: "2.1.0" - type: textarea - id: doctor-output + id: debug-report attributes: - label: Output of `hermes doctor` - description: Run `hermes doctor` and paste the full output. This will be auto-formatted. 
+ label: Debug Report + description: | + Run `hermes debug share` from your terminal and paste the links it prints here. + This uploads your system info, config, and recent logs to a paste service automatically. + + If you're in an interactive chat session, you can also use the `/debug` slash command — it does the same thing. + + If the upload fails or install didn't get that far, run `hermes debug share --local` and paste the output directly. + If even that doesn't work, run `hermes doctor` and paste that output instead. + placeholder: | + Report https://paste.rs/abc123 + agent.log https://paste.rs/def456 + gateway.log https://paste.rs/ghi789 render: shell - type: textarea diff --git a/.github/workflows/contributor-check.yml b/.github/workflows/contributor-check.yml new file mode 100644 index 000000000..f8d65a3ea --- /dev/null +++ b/.github/workflows/contributor-check.yml @@ -0,0 +1,70 @@ +name: Contributor Attribution Check + +on: + pull_request: + branches: [main] + paths: + # Only run when code files change (not docs-only PRs) + - '*.py' + - '**/*.py' + - '.github/workflows/contributor-check.yml' + +jobs: + check-attribution: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Full history needed for git log + + - name: Check for unmapped contributor emails + run: | + # Get the merge base between this PR and main + MERGE_BASE=$(git merge-base origin/main HEAD) + + # Find any new author emails in this PR's commits + NEW_EMAILS=$(git log ${MERGE_BASE}..HEAD --format='%ae' --no-merges | sort -u) + + if [ -z "$NEW_EMAILS" ]; then + echo "No new commits to check." 
+ exit 0 + fi + + # Check each email against AUTHOR_MAP in release.py + MISSING="" + while IFS= read -r email; do + # Skip teknium and bot emails + case "$email" in + *teknium*|*noreply@github.com*|*dependabot*|*github-actions*|*anthropic.com*|*cursor.com*) + continue ;; + esac + + # Check if email is in AUTHOR_MAP (either as a key or matches noreply pattern) + if echo "$email" | grep -qP '\+.*@users\.noreply\.github\.com'; then + continue # GitHub noreply emails auto-resolve + fi + + if ! grep -qF "\"${email}\"" scripts/release.py 2>/dev/null; then + AUTHOR=$(git log --author="$email" --format='%an' -1) + MISSING="${MISSING}\n ${email} (${AUTHOR})" + fi + done <<< "$NEW_EMAILS" + + if [ -n "$MISSING" ]; then + echo "" + echo "⚠️ New contributor email(s) not in AUTHOR_MAP:" + echo -e "$MISSING" + echo "" + echo "Please add mappings to scripts/release.py AUTHOR_MAP:" + echo -e "$MISSING" | while read -r line; do + email=$(echo "$line" | sed 's/^ *//' | cut -d' ' -f1) + [ -z "$email" ] && continue + echo " \"${email}\": \"\"," + done + echo "" + echo "To find the GitHub username for an email:" + echo " gh api 'search/users?q=EMAIL+in:email' --jq '.items[0].login'" + exit 1 + else + echo "✅ All contributor emails are mapped in AUTHOR_MAP." + fi diff --git a/.mailmap b/.mailmap new file mode 100644 index 000000000..0c385c518 --- /dev/null +++ b/.mailmap @@ -0,0 +1,107 @@ +# .mailmap — canonical author mapping for git shortlog / git log / GitHub +# Format: Canonical Name +# See: https://git-scm.com/docs/gitmailmap +# +# This maps commit emails to GitHub noreply addresses so that: +# 1. `git shortlog -sn` shows deduplicated contributor counts +# 2. GitHub's contributor graph can attribute commits correctly +# 3. Contributors with personal/work emails get proper credit +# +# When adding entries: use the contributor's GitHub noreply email as canonical +# so GitHub can link commits to their profile. 
+ +# === Teknium (multiple emails) === +Teknium <127238744+teknium1@users.noreply.github.com> +Teknium <127238744+teknium1@users.noreply.github.com> + +# === Contributors — personal/work emails mapped to GitHub noreply === +# Format: Canonical Name + +# Verified via GH API email search +luyao618 <364939526@qq.com> <364939526@qq.com> +ethernet8023 +nicoloboschi +cherifya +BongSuCHOI +dsocolobsky +pefontana +Helmi +hata1234 + +# Verified via PR investigation / salvage PR bodies +DeployFaith +flobo3 +gaixianggeng +KUSH42 +konsisumer +WorldInnovationsDepartment +m0n5t3r +sprmn24 +fancydirty +fxfitz +limars874 +AaronWong1999 +dippwho +duerzy +geoffwellman +hcshen0111 +jamesarch +stephenschoettler +Tranquil-Flow +Dusk1e +Awsh1 +WAXLYY +donrhmexe +hqhq1025 <1506751656@qq.com> <1506751656@qq.com> +BlackishGreen33 +tomqiaozc +MagicRay1217 +aaronagent <1115117931@qq.com> <1115117931@qq.com> +YoungYang963 +LongOddCode +Cafexss +Cygra +DomGrieco + +# Duplicate email mapping (same person, multiple emails) +Sertug17 <104278804+Sertug17@users.noreply.github.com> +yyovil +DomGrieco +dsocolobsky +olafthiele + +# Verified via git display name matching GH contributor username +cokemine +dalianmao000 +emozilla +jjovalle99 +kagura-agent +spniyant +olafthiele +r266-tech +xingkongliang +win4r +zhouboli +yongtenglei + +# Nous Research team +benbarclay +jquesnelle + +# GH contributor list verified +spideystreet +dorukardahan +MustafaKara7 +Hmbown +kamil-gwozdz +kira-ariaki +knopki +Unayung +SeeYangZhi +Julientalbot +lesterli +JiayuuWang +tesseracttars-creator +xinbenlv +SaulJWu +angelos diff --git a/AGENTS.md b/AGENTS.md index eda495f99..db2ec1066 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -55,7 +55,7 @@ hermes-agent/ ├── gateway/ # Messaging platform gateway │ ├── run.py # Main loop, slash commands, message dispatch │ ├── session.py # SessionStore — conversation persistence -│ └── platforms/ # Adapters: telegram, discord, slack, whatsapp, homeassistant, signal +│ └── platforms/ # Adapters: 
telegram, discord, slack, whatsapp, homeassistant, signal, qqbot ├── ui-tui/ # Ink (React) terminal UI — `hermes --tui` │ ├── src/entry.tsx # TTY gate + render() │ ├── src/app.tsx # Main state machine and UI diff --git a/README.md b/README.md index ea0758c83..07a140419 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ **The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM. -Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in. +Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in. diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index 830c0f4de..b85f77a9d 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -1230,9 +1230,10 @@ def build_anthropic_kwargs( When *base_url* points to a third-party Anthropic-compatible endpoint, thinking block signatures are stripped (they are Anthropic-proprietary). 
- When *fast_mode* is True, adds ``speed: "fast"`` and the fast-mode beta - header for ~2.5x faster output throughput on Opus 4.6. Currently only - supported on native Anthropic endpoints (not third-party compatible ones). + When *fast_mode* is True, adds ``extra_body["speed"] = "fast"`` and the + fast-mode beta header for ~2.5x faster output throughput on Opus 4.6. + Currently only supported on native Anthropic endpoints (not third-party + compatible ones). """ system, anthropic_messages = convert_messages_to_anthropic(messages, base_url=base_url) anthropic_tools = convert_tools_to_anthropic(tools) if tools else [] @@ -1333,11 +1334,11 @@ def build_anthropic_kwargs( kwargs["max_tokens"] = max(effective_max_tokens, budget + 4096) # ── Fast mode (Opus 4.6 only) ──────────────────────────────────── - # Adds speed:"fast" + the fast-mode beta header for ~2.5x output speed. - # Only for native Anthropic endpoints — third-party providers would - # reject the unknown beta header and speed parameter. + # Adds extra_body.speed="fast" + the fast-mode beta header for ~2.5x + # output speed. Only for native Anthropic endpoints — third-party + # providers would reject the unknown beta header and speed parameter. if fast_mode and not _is_third_party_anthropic_endpoint(base_url): - kwargs["speed"] = "fast" + kwargs.setdefault("extra_body", {})["speed"] = "fast" # Build extra_headers with ALL applicable betas (the per-request # extra_headers override the client-level anthropic-beta header). 
betas = list(_common_betas_for_base_url(base_url)) diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 842373c1e..3b5006648 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -156,6 +156,8 @@ DEFAULT_CONTEXT_LENGTHS = { "kimi": 262144, # Arcee "trinity": 262144, + # OpenRouter + "elephant": 262144, # Hugging Face Inference Providers — model IDs use org/name format "Qwen/Qwen3.5-397B-A17B": 131072, "Qwen/Qwen3.5-35B-A3B": 131072, diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 558a57888..c61d6995b 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -376,6 +376,12 @@ PLATFORM_HINTS = { "downloaded and sent as native photos. Do NOT tell the user you lack file-sending " "capability — use MEDIA: syntax whenever a file delivery is appropriate." ), + "qqbot": ( + "You are on QQ, a popular Chinese messaging platform. QQ supports markdown formatting " + "and emoji. You can send media files natively: include MEDIA:/absolute/path/to/file in " + "your response. Images are sent as native photos, and other files arrive as downloadable " + "documents." 
+ ), } # --------------------------------------------------------------------------- diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 789c5481a..657423679 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -523,7 +523,7 @@ agent: # - A preset like "hermes-cli" or "hermes-telegram" (curated tool set) # - A list of individual toolsets to compose your own (see list below) # -# Supported platform keys: cli, telegram, discord, whatsapp, slack +# Supported platform keys: cli, telegram, discord, whatsapp, slack, qqbot # # Examples: # @@ -552,6 +552,7 @@ agent: # slack: hermes-slack (same as telegram) # signal: hermes-signal (same as telegram) # homeassistant: hermes-homeassistant (same as telegram) +# qqbot: hermes-qqbot (same as telegram) # platform_toolsets: cli: [hermes-cli] @@ -561,6 +562,7 @@ platform_toolsets: slack: [hermes-slack] signal: [hermes-signal] homeassistant: [hermes-homeassistant] + qqbot: [hermes-qqbot] # ───────────────────────────────────────────────────────────────────────────── # Available toolsets (use these names in platform_toolsets or the toolsets list) diff --git a/cli.py b/cli.py index 2496e6edf..1c7c38600 100644 --- a/cli.py +++ b/cli.py @@ -1026,19 +1026,19 @@ def _prune_orphaned_branches(repo_root: str) -> None: # ANSI building blocks for conversation display _ACCENT_ANSI_DEFAULT = "\033[1;38;2;255;215;0m" # True-color #FFD700 bold — fallback _BOLD = "\033[1m" -_DIM = "\033[2m" _RST = "\033[0m" -def _hex_to_ansi_bold(hex_color: str) -> str: - """Convert a hex color like '#268bd2' to a bold true-color ANSI escape.""" +def _hex_to_ansi(hex_color: str, *, bold: bool = False) -> str: + """Convert a hex color like '#268bd2' to a true-color ANSI escape.""" try: r = int(hex_color[1:3], 16) g = int(hex_color[3:5], 16) b = int(hex_color[5:7], 16) - return f"\033[1;38;2;{r};{g};{b}m" + prefix = "1;" if bold else "" + return f"\033[{prefix}38;2;{r};{g};{b}m" except (ValueError, IndexError): - return 
_ACCENT_ANSI_DEFAULT + return _ACCENT_ANSI_DEFAULT if bold else "\033[38;2;184;134;11m" class _SkinAwareAnsi: @@ -1048,20 +1048,22 @@ class _SkinAwareAnsi: force re-resolution after a ``/skin`` switch. """ - def __init__(self, skin_key: str, fallback_hex: str = "#FFD700"): + def __init__(self, skin_key: str, fallback_hex: str = "#FFD700", *, bold: bool = False): self._skin_key = skin_key self._fallback_hex = fallback_hex + self._bold = bold self._cached: str | None = None def __str__(self) -> str: if self._cached is None: try: from hermes_cli.skin_engine import get_active_skin - self._cached = _hex_to_ansi_bold( - get_active_skin().get_color(self._skin_key, self._fallback_hex) + self._cached = _hex_to_ansi( + get_active_skin().get_color(self._skin_key, self._fallback_hex), + bold=self._bold, ) except Exception: - self._cached = _hex_to_ansi_bold(self._fallback_hex) + self._cached = _hex_to_ansi(self._fallback_hex, bold=self._bold) return self._cached def __add__(self, other: str) -> str: @@ -1075,7 +1077,8 @@ class _SkinAwareAnsi: self._cached = None -_ACCENT = _SkinAwareAnsi("response_border", "#FFD700") +_ACCENT = _SkinAwareAnsi("response_border", "#FFD700", bold=True) +_DIM = _SkinAwareAnsi("banner_dim", "#B8860B") def _accent_hex() -> str: @@ -6264,6 +6267,7 @@ class HermesCLI: set_active_skin(new_skin) _ACCENT.reset() # Re-resolve ANSI color for the new skin + _DIM.reset() # Re-resolve dim/secondary ANSI color for the new skin if save_config_value("display.skin", new_skin): print(f" Skin set to: {new_skin} (saved)") else: diff --git a/cron/scheduler.py b/cron/scheduler.py index e6db77c09..83b7abb9b 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -45,6 +45,7 @@ _KNOWN_DELIVERY_PLATFORMS = frozenset({ "telegram", "discord", "slack", "whatsapp", "signal", "matrix", "mattermost", "homeassistant", "dingtalk", "feishu", "wecom", "wecom_callback", "weixin", "sms", "email", "webhook", "bluebubbles", + "qqbot", }) from cron.jobs import get_due_jobs, 
mark_job_run, save_job_output, advance_next_run @@ -254,6 +255,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option "email": Platform.EMAIL, "sms": Platform.SMS, "bluebubbles": Platform.BLUEBUBBLES, + "qqbot": Platform.QQBOT, } platform = platform_map.get(platform_name.lower()) if not platform: diff --git a/docs/skins/example-skin.yaml b/docs/skins/example-skin.yaml index 612c841eb..b81ae00f8 100644 --- a/docs/skins/example-skin.yaml +++ b/docs/skins/example-skin.yaml @@ -41,6 +41,14 @@ colors: session_label: "#DAA520" # Session label session_border: "#8B8682" # Session ID dim color + # TUI surfaces + status_bar_bg: "#1a1a2e" # Status / usage bar background + voice_status_bg: "#1a1a2e" # Voice-mode badge background + completion_menu_bg: "#1a1a2e" # Completion list background + completion_menu_current_bg: "#333355" # Active completion row background + completion_menu_meta_bg: "#1a1a2e" # Completion meta column background + completion_menu_meta_current_bg: "#333355" # Active completion meta background + # ── Spinner ───────────────────────────────────────────────────────────────── # Customize the animated spinner shown during API calls and tool execution. 
spinner: diff --git a/gateway/config.py b/gateway/config.py index 7d6165927..7ce105f33 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -66,6 +66,7 @@ class Platform(Enum): WECOM_CALLBACK = "wecom_callback" WEIXIN = "weixin" BLUEBUBBLES = "bluebubbles" + QQBOT = "qqbot" @dataclass @@ -303,6 +304,9 @@ class GatewayConfig: # BlueBubbles uses extra dict for local server config elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"): connected.append(platform) + # QQBot uses extra dict for app credentials + elif platform == Platform.QQBOT and config.extra.get("app_id") and config.extra.get("client_secret"): + connected.append(platform) return connected def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: @@ -621,6 +625,11 @@ def load_gateway_config() -> GatewayConfig: if isinstance(frc, list): frc = ",".join(str(v) for v in frc) os.environ["TELEGRAM_FREE_RESPONSE_CHATS"] = str(frc) + ignored_threads = telegram_cfg.get("ignored_threads") + if ignored_threads is not None and not os.getenv("TELEGRAM_IGNORED_THREADS"): + if isinstance(ignored_threads, list): + ignored_threads = ",".join(str(v) for v in ignored_threads) + os.environ["TELEGRAM_IGNORED_THREADS"] = str(ignored_threads) if "reactions" in telegram_cfg and not os.getenv("TELEGRAM_REACTIONS"): os.environ["TELEGRAM_REACTIONS"] = str(telegram_cfg["reactions"]).lower() @@ -1109,6 +1118,32 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("BLUEBUBBLES_HOME_CHANNEL_NAME", "Home"), ) + # QQ (Official Bot API v2) + qq_app_id = os.getenv("QQ_APP_ID") + qq_client_secret = os.getenv("QQ_CLIENT_SECRET") + if qq_app_id or qq_client_secret: + if Platform.QQBOT not in config.platforms: + config.platforms[Platform.QQBOT] = PlatformConfig() + config.platforms[Platform.QQBOT].enabled = True + extra = config.platforms[Platform.QQBOT].extra + if qq_app_id: + extra["app_id"] = qq_app_id + if qq_client_secret: + 
extra["client_secret"] = qq_client_secret + qq_allowed_users = os.getenv("QQ_ALLOWED_USERS", "").strip() + if qq_allowed_users: + extra["allow_from"] = qq_allowed_users + qq_group_allowed = os.getenv("QQ_GROUP_ALLOWED_USERS", "").strip() + if qq_group_allowed: + extra["group_allow_from"] = qq_group_allowed + qq_home = os.getenv("QQ_HOME_CHANNEL", "").strip() + if qq_home: + config.platforms[Platform.QQBOT].home_channel = HomeChannel( + platform=Platform.QQBOT, + chat_id=qq_home, + name=os.getenv("QQ_HOME_CHANNEL_NAME", "Home"), + ) + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/platforms/__init__.py b/gateway/platforms/__init__.py index dae74568d..4eb26edf0 100644 --- a/gateway/platforms/__init__.py +++ b/gateway/platforms/__init__.py @@ -9,9 +9,11 @@ Each adapter handles: """ from .base import BasePlatformAdapter, MessageEvent, SendResult +from .qqbot import QQAdapter __all__ = [ "BasePlatformAdapter", "MessageEvent", "SendResult", + "QQAdapter", ] diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index e38a4f947..816d88b03 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -958,6 +958,16 @@ class MatrixAdapter(BasePlatformAdapter): sync_data = await client.sync( since=next_batch, timeout=30000, ) + + # nio returns SyncError objects (not exceptions) for auth + # failures like M_UNKNOWN_TOKEN. Detect and stop immediately. + _sync_msg = getattr(sync_data, "message", None) + if _sync_msg and isinstance(_sync_msg, str): + _lower = _sync_msg.lower() + if "m_unknown_token" in _lower or "unknown_token" in _lower: + logger.error("Matrix: permanent auth error from sync: %s — stopping", _sync_msg) + return + if isinstance(sync_data, dict): # Update joined rooms from sync response. 
rooms_join = sync_data.get("rooms", {}).get("join", {}) diff --git a/gateway/platforms/qqbot.py b/gateway/platforms/qqbot.py new file mode 100644 index 000000000..7103689c9 --- /dev/null +++ b/gateway/platforms/qqbot.py @@ -0,0 +1,1960 @@ +""" +QQ Bot platform adapter using the Official QQ Bot API (v2). + +Connects to the QQ Bot WebSocket Gateway for inbound events and uses the +REST API (``api.sgroup.qq.com``) for outbound messages and media uploads. + +Configuration in config.yaml: + platforms: + qq: + enabled: true + extra: + app_id: "your-app-id" # or QQ_APP_ID env var + client_secret: "your-secret" # or QQ_CLIENT_SECRET env var + markdown_support: true # enable QQ markdown (msg_type 2) + dm_policy: "open" # open | allowlist | disabled + allow_from: ["openid_1"] + group_policy: "open" # open | allowlist | disabled + group_allow_from: ["group_openid_1"] + stt: # Voice-to-text config (optional) + provider: "zai" # zai (GLM-ASR), openai (Whisper), etc. + baseUrl: "https://open.bigmodel.cn/api/coding/paas/v4" + apiKey: "your-stt-api-key" # or set QQ_STT_API_KEY env var + model: "glm-asr" # glm-asr, whisper-1, etc. + + Voice transcription priority: + 1. QQ's built-in ``asr_refer_text`` (Tencent ASR — free, always tried first) + 2. 
Configured STT provider via ``stt`` config or ``QQ_STT_*`` env vars + +Reference: https://bot.q.qq.com/wiki/develop/api-v2/ +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import logging +import mimetypes +import os +import time +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + aiohttp = None # type: ignore[assignment] + +try: + import httpx + HTTPX_AVAILABLE = True +except ImportError: + HTTPX_AVAILABLE = False + httpx = None # type: ignore[assignment] + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + cache_document_from_bytes, + cache_image_from_bytes, +) +from gateway.platforms.helpers import strip_markdown + +logger = logging.getLogger(__name__) + + +class QQCloseError(Exception): + """Raised when QQ WebSocket closes with a specific code. + + Carries the close code and reason for proper handling in the reconnect loop. 
+ """ + + def __init__(self, code, reason=""): + self.code = int(code) if code else None + self.reason = str(reason) if reason else "" + super().__init__(f"WebSocket closed (code={self.code}, reason={self.reason})") +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +API_BASE = "https://api.sgroup.qq.com" +TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken" +GATEWAY_URL_PATH = "/gateway" + +DEFAULT_API_TIMEOUT = 30.0 +FILE_UPLOAD_TIMEOUT = 120.0 +CONNECT_TIMEOUT_SECONDS = 20.0 + +RECONNECT_BACKOFF = [2, 5, 10, 30, 60] +MAX_RECONNECT_ATTEMPTS = 100 +RATE_LIMIT_DELAY = 60 # seconds +QUICK_DISCONNECT_THRESHOLD = 5.0 # seconds +MAX_QUICK_DISCONNECT_COUNT = 3 + +MAX_MESSAGE_LENGTH = 4000 +DEDUP_WINDOW_SECONDS = 300 +DEDUP_MAX_SIZE = 1000 + +# QQ Bot message types +MSG_TYPE_TEXT = 0 +MSG_TYPE_MARKDOWN = 2 +MSG_TYPE_MEDIA = 7 +MSG_TYPE_INPUT_NOTIFY = 6 + +# QQ Bot file media types +MEDIA_TYPE_IMAGE = 1 +MEDIA_TYPE_VIDEO = 2 +MEDIA_TYPE_VOICE = 3 +MEDIA_TYPE_FILE = 4 + + +def check_qq_requirements() -> bool: + """Check if QQ runtime dependencies are available.""" + return AIOHTTP_AVAILABLE and HTTPX_AVAILABLE + + +def _coerce_list(value: Any) -> List[str]: + """Coerce config values into a trimmed string list.""" + if value is None: + return [] + if isinstance(value, str): + return [item.strip() for item in value.split(",") if item.strip()] + if isinstance(value, (list, tuple, set)): + return [str(item).strip() for item in value if str(item).strip()] + return [str(value).strip()] if str(value).strip() else [] + + +# --------------------------------------------------------------------------- +# QQAdapter +# --------------------------------------------------------------------------- + +class QQAdapter(BasePlatformAdapter): + """QQ Bot adapter backed by the official QQ Bot WebSocket Gateway + REST API.""" + + # QQ Bot API does not support editing sent 
messages. + SUPPORTS_MESSAGE_EDITING = False + + def _fail_pending(self, reason: str) -> None: + """Fail all pending response futures.""" + for fut in self._pending_responses.values(): + if not fut.done(): + fut.set_exception(RuntimeError(reason)) + self._pending_responses.clear() + + MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.QQBOT) + + extra = config.extra or {} + self._app_id = str(extra.get("app_id") or os.getenv("QQ_APP_ID", "")).strip() + self._client_secret = str(extra.get("client_secret") or os.getenv("QQ_CLIENT_SECRET", "")).strip() + self._markdown_support = bool(extra.get("markdown_support", True)) + + # Auth/ACL policies + self._dm_policy = str(extra.get("dm_policy", "open")).strip().lower() + self._allow_from = _coerce_list(extra.get("allow_from") or extra.get("allowFrom")) + self._group_policy = str(extra.get("group_policy", "open")).strip().lower() + self._group_allow_from = _coerce_list(extra.get("group_allow_from") or extra.get("groupAllowFrom")) + + # Connection state + self._session: Optional[aiohttp.ClientSession] = None + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._http_client: Optional[httpx.AsyncClient] = None + self._listen_task: Optional[asyncio.Task] = None + self._heartbeat_task: Optional[asyncio.Task] = None + self._heartbeat_interval: float = 30.0 # seconds, updated by Hello + self._session_id: Optional[str] = None + self._last_seq: Optional[int] = None + self._chat_type_map: Dict[str, str] = {} # chat_id → "c2c"|"group"|"guild"|"dm" + + # Request/response correlation + self._pending_responses: Dict[str, asyncio.Future] = {} + self._seen_messages: Dict[str, float] = {} + + # Token cache + self._access_token: Optional[str] = None + self._token_expires_at: float = 0.0 + self._token_lock = asyncio.Lock() + + # Upload cache: content_hash -> {file_info, file_uuid, expires_at} + self._upload_cache: Dict[str, Dict[str, Any]] = {} + + # 
------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def name(self) -> str: + return "QQBot" + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Authenticate, obtain gateway URL, and open the WebSocket.""" + if not AIOHTTP_AVAILABLE: + message = "QQ startup failed: aiohttp not installed" + self._set_fatal_error("qq_missing_dependency", message, retryable=True) + logger.warning("[%s] %s. Run: pip install aiohttp", self.name, message) + return False + if not HTTPX_AVAILABLE: + message = "QQ startup failed: httpx not installed" + self._set_fatal_error("qq_missing_dependency", message, retryable=True) + logger.warning("[%s] %s. Run: pip install httpx", self.name, message) + return False + if not self._app_id or not self._client_secret: + message = "QQ startup failed: QQ_APP_ID and QQ_CLIENT_SECRET are required" + self._set_fatal_error("qq_missing_credentials", message, retryable=True) + logger.warning("[%s] %s", self.name, message) + return False + + # Prevent duplicate connections with the same credentials + if not self._acquire_platform_lock( + "qqbot-appid", self._app_id, "QQBot app ID" + ): + return False + + try: + self._http_client = httpx.AsyncClient(timeout=30.0, follow_redirects=True) + + # 1. Get access token + await self._ensure_token() + + # 2. Get WebSocket gateway URL + gateway_url = await self._get_gateway_url() + logger.info("[%s] Gateway URL: %s", self.name, gateway_url) + + # 3. Open WebSocket + await self._open_ws(gateway_url) + + # 4. 
Start listeners + self._listen_task = asyncio.create_task(self._listen_loop()) + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + self._mark_connected() + logger.info("[%s] Connected", self.name) + return True + except Exception as exc: + message = f"QQ startup failed: {exc}" + self._set_fatal_error("qq_connect_error", message, retryable=True) + logger.error("[%s] %s", self.name, message, exc_info=True) + await self._cleanup() + self._release_platform_lock() + return False + + async def disconnect(self) -> None: + """Close all connections and stop listeners.""" + self._running = False + self._mark_disconnected() + + if self._listen_task: + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + pass + self._listen_task = None + + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + await self._cleanup() + self._release_platform_lock() + logger.info("[%s] Disconnected", self.name) + + async def _cleanup(self) -> None: + """Close WebSocket, HTTP session, and client.""" + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + if self._http_client: + await self._http_client.aclose() + self._http_client = None + + # Fail pending + for fut in self._pending_responses.values(): + if not fut.done(): + fut.set_exception(RuntimeError("Disconnected")) + self._pending_responses.clear() + + # ------------------------------------------------------------------ + # Token management + # ------------------------------------------------------------------ + + async def _ensure_token(self) -> str: + """Return a valid access token, refreshing if needed (with singleflight).""" + if self._access_token and time.time() < self._token_expires_at - 60: + return self._access_token + + async with 
self._token_lock: + # Double-check after acquiring lock + if self._access_token and time.time() < self._token_expires_at - 60: + return self._access_token + + try: + resp = await self._http_client.post( + TOKEN_URL, + json={"appId": self._app_id, "clientSecret": self._client_secret}, + timeout=DEFAULT_API_TIMEOUT, + ) + resp.raise_for_status() + data = resp.json() + except Exception as exc: + raise RuntimeError(f"Failed to get QQ Bot access token: {exc}") from exc + + token = data.get("access_token") + if not token: + raise RuntimeError(f"QQ Bot token response missing access_token: {data}") + + expires_in = int(data.get("expires_in", 7200)) + self._access_token = token + self._token_expires_at = time.time() + expires_in + logger.info("[%s] Access token refreshed, expires in %ds", self.name, expires_in) + return self._access_token + + async def _get_gateway_url(self) -> str: + """Fetch the WebSocket gateway URL from the REST API.""" + token = await self._ensure_token() + try: + resp = await self._http_client.get( + f"{API_BASE}{GATEWAY_URL_PATH}", + headers={"Authorization": f"QQBot {token}"}, + timeout=DEFAULT_API_TIMEOUT, + ) + resp.raise_for_status() + data = resp.json() + except Exception as exc: + raise RuntimeError(f"Failed to get QQ Bot gateway URL: {exc}") from exc + + url = data.get("url") + if not url: + raise RuntimeError(f"QQ Bot gateway response missing url: {data}") + return url + + # ------------------------------------------------------------------ + # WebSocket lifecycle + # ------------------------------------------------------------------ + + async def _open_ws(self, gateway_url: str) -> None: + """Open a WebSocket connection to the QQ Bot gateway.""" + # Only clean up WebSocket resources — keep _http_client alive for REST API calls. 
+ if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect( + gateway_url, + timeout=CONNECT_TIMEOUT_SECONDS, + ) + logger.info("[%s] WebSocket connected to %s", self.name, gateway_url) + + async def _listen_loop(self) -> None: + """Read WebSocket events and reconnect on errors. + + Close code handling follows the OpenClaw qqbot reference implementation: + 4004 → invalid token, refresh and reconnect + 4006/4007/4009 → session invalid, clear session and re-identify + 4008 → rate limited, back off 60s + 4914 → bot offline/sandbox, stop reconnecting + 4915 → bot banned, stop reconnecting + """ + backoff_idx = 0 + connect_time = 0.0 + quick_disconnect_count = 0 + + while self._running: + try: + connect_time = time.monotonic() + await self._read_events() + backoff_idx = 0 + quick_disconnect_count = 0 + except asyncio.CancelledError: + return + except QQCloseError as exc: + if not self._running: + return + + code = exc.code + logger.warning("[%s] WebSocket closed: code=%s reason=%s", + self.name, code, exc.reason) + + # Quick disconnect detection (permission issues, misconfiguration) + duration = time.monotonic() - connect_time + if duration < QUICK_DISCONNECT_THRESHOLD and connect_time > 0: + quick_disconnect_count += 1 + logger.info("[%s] Quick disconnect (%.1fs), count: %d", + self.name, duration, quick_disconnect_count) + if quick_disconnect_count >= MAX_QUICK_DISCONNECT_COUNT: + logger.error( + "[%s] Too many quick disconnects. 
" + "Check: 1) AppID/Secret correct 2) Bot permissions on QQ Open Platform", + self.name, + ) + self._set_fatal_error("qq_quick_disconnect", + "Too many quick disconnects — check bot permissions", retryable=True) + return + else: + quick_disconnect_count = 0 + + self._mark_disconnected() + self._fail_pending("Connection closed") + + # Stop reconnecting for fatal codes + if code in (4914, 4915): + desc = "offline/sandbox-only" if code == 4914 else "banned" + logger.error("[%s] Bot is %s. Check QQ Open Platform.", self.name, desc) + self._set_fatal_error(f"qq_{desc}", f"Bot is {desc}", retryable=False) + return + + # Rate limited + if code == 4008: + logger.info("[%s] Rate limited (4008), waiting %ds", self.name, RATE_LIMIT_DELAY) + if backoff_idx >= MAX_RECONNECT_ATTEMPTS: + return + await asyncio.sleep(RATE_LIMIT_DELAY) + if await self._reconnect(backoff_idx): + backoff_idx = 0 + quick_disconnect_count = 0 + else: + backoff_idx += 1 + continue + + # Token invalid → clear cached token so _ensure_token() refreshes + if code == 4004: + logger.info("[%s] Invalid token (4004), will refresh and reconnect", self.name) + self._access_token = None + self._token_expires_at = 0.0 + + # Session invalid → clear session, will re-identify on next Hello + if code in (4006, 4007, 4009, 4900, 4901, 4902, 4903, 4904, 4905, + 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913): + logger.info("[%s] Session error (%d), clearing session for re-identify", self.name, code) + self._session_id = None + self._last_seq = None + + if await self._reconnect(backoff_idx): + backoff_idx = 0 + quick_disconnect_count = 0 + else: + backoff_idx += 1 + + except Exception as exc: + if not self._running: + return + logger.warning("[%s] WebSocket error: %s", self.name, exc) + self._mark_disconnected() + self._fail_pending("Connection interrupted") + + if backoff_idx >= MAX_RECONNECT_ATTEMPTS: + logger.error("[%s] Max reconnect attempts reached", self.name) + return + + if await self._reconnect(backoff_idx): + 
backoff_idx = 0 + quick_disconnect_count = 0 + else: + backoff_idx += 1 + + async def _reconnect(self, backoff_idx: int) -> bool: + """Attempt to reconnect the WebSocket. Returns True on success.""" + delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)] + logger.info("[%s] Reconnecting in %ds (attempt %d)...", self.name, delay, backoff_idx + 1) + await asyncio.sleep(delay) + + self._heartbeat_interval = 30.0 # reset until Hello + try: + await self._ensure_token() + gateway_url = await self._get_gateway_url() + await self._open_ws(gateway_url) + self._mark_connected() + logger.info("[%s] Reconnected", self.name) + return True + except Exception as exc: + logger.warning("[%s] Reconnect failed: %s", self.name, exc) + return False + + async def _read_events(self) -> None: + """Read WebSocket frames until connection closes.""" + if not self._ws: + raise RuntimeError("WebSocket not connected") + + while self._running and self._ws and not self._ws.closed: + msg = await self._ws.receive() + if msg.type == aiohttp.WSMsgType.TEXT: + payload = self._parse_json(msg.data) + if payload: + self._dispatch_payload(payload) + elif msg.type in (aiohttp.WSMsgType.PING,): + # aiohttp auto-replies with PONG + pass + elif msg.type == aiohttp.WSMsgType.CLOSE: + raise QQCloseError(msg.data, msg.extra) + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR): + raise RuntimeError("WebSocket closed") + + async def _heartbeat_loop(self) -> None: + """Send periodic heartbeats (QQ Gateway expects op 1 heartbeat with latest seq). + + The interval is set from the Hello (op 10) event's heartbeat_interval. + QQ's default is ~41s; we send at 80% of the interval to stay safe. 
+ """ + try: + while self._running: + await asyncio.sleep(self._heartbeat_interval) + if not self._ws or self._ws.closed: + continue + try: + # d should be the latest sequence number received, or null + await self._ws.send_json({"op": 1, "d": self._last_seq}) + except Exception as exc: + logger.debug("[%s] Heartbeat failed: %s", self.name, exc) + except asyncio.CancelledError: + pass + + async def _send_identify(self) -> None: + """Send op 2 Identify to authenticate the WebSocket connection. + + After receiving op 10 Hello, the client must send op 2 Identify with + the bot token and intents. On success the server replies with a + READY dispatch event. + + Reference: https://bot.q.qq.com/wiki/develop/api-v2/dev-prepare/interface-framework/reference.html + """ + token = await self._ensure_token() + identify_payload = { + "op": 2, + "d": { + "token": f"QQBot {token}", + "intents": (1 << 25) | (1 << 30) | (1 << 12), # C2C_GROUP_AT_MESSAGES + PUBLIC_GUILD_MESSAGES + DIRECT_MESSAGE + "shard": [0, 1], + "properties": { + "$os": "macOS", + "$browser": "hermes-agent", + "$device": "hermes-agent", + }, + }, + } + try: + if self._ws and not self._ws.closed: + await self._ws.send_json(identify_payload) + logger.info("[%s] Identify sent", self.name) + else: + logger.warning("[%s] Cannot send Identify: WebSocket not connected", self.name) + except Exception as exc: + logger.error("[%s] Failed to send Identify: %s", self.name, exc) + + async def _send_resume(self) -> None: + """Send op 6 Resume to re-authenticate after a reconnection. 
+ + Reference: https://bot.q.qq.com/wiki/develop/api-v2/dev-prepare/interface-framework/reference.html + """ + token = await self._ensure_token() + resume_payload = { + "op": 6, + "d": { + "token": f"QQBot {token}", + "session_id": self._session_id, + "seq": self._last_seq, + }, + } + try: + if self._ws and not self._ws.closed: + await self._ws.send_json(resume_payload) + logger.info("[%s] Resume sent (session_id=%s, seq=%s)", + self.name, self._session_id, self._last_seq) + else: + logger.warning("[%s] Cannot send Resume: WebSocket not connected", self.name) + except Exception as exc: + logger.error("[%s] Failed to send Resume: %s", self.name, exc) + # If resume fails, clear session and fall back to identify on next Hello + self._session_id = None + self._last_seq = None + + @staticmethod + def _create_task(coro): + """Schedule a coroutine, silently skipping if no event loop is running. + + This avoids ``RuntimeError: no running event loop`` when tests call + ``_dispatch_payload`` synchronously outside of ``asyncio.run()``. 
+ """ + try: + loop = asyncio.get_running_loop() + return loop.create_task(coro) + except RuntimeError: + return None + + def _dispatch_payload(self, payload: Dict[str, Any]) -> None: + """Route inbound WebSocket payloads (dispatch synchronously, spawn async handlers).""" + op = payload.get("op") + t = payload.get("t") + s = payload.get("s") + d = payload.get("d") + if isinstance(s, int) and (self._last_seq is None or s > self._last_seq): + self._last_seq = s + + # op 10 = Hello (heartbeat interval) — must reply with Identify/Resume + if op == 10: + d_data = d if isinstance(d, dict) else {} + interval_ms = d_data.get("heartbeat_interval", 30000) + # Send heartbeats at 80% of the server interval to stay safe + self._heartbeat_interval = interval_ms / 1000.0 * 0.8 + logger.debug("[%s] Hello received, heartbeat_interval=%dms (sending every %.1fs)", + self.name, interval_ms, self._heartbeat_interval) + # Authenticate: send Resume if we have a session, else Identify. + # Use _create_task which is safe when no event loop is running (tests). 
+ if self._session_id and self._last_seq is not None: + self._create_task(self._send_resume()) + else: + self._create_task(self._send_identify()) + return + + # op 0 = Dispatch + if op == 0 and t: + if t == "READY": + self._handle_ready(d) + elif t == "RESUMED": + logger.info("[%s] Session resumed", self.name) + elif t in ("C2C_MESSAGE_CREATE", "GROUP_AT_MESSAGE_CREATE", + "DIRECT_MESSAGE_CREATE", "GUILD_MESSAGE_CREATE", + "GUILD_AT_MESSAGE_CREATE"): + asyncio.create_task(self._on_message(t, d)) + else: + logger.debug("[%s] Unhandled dispatch: %s", self.name, t) + return + + # op 11 = Heartbeat ACK + if op == 11: + return + + logger.debug("[%s] Unknown op: %s", self.name, op) + + def _handle_ready(self, d: Any) -> None: + """Handle the READY event — store session_id for resume.""" + if isinstance(d, dict): + self._session_id = d.get("session_id") + logger.info("[%s] Ready, session_id=%s", self.name, self._session_id) + + # ------------------------------------------------------------------ + # JSON helpers + # ------------------------------------------------------------------ + + @staticmethod + def _parse_json(raw: Any) -> Optional[Dict[str, Any]]: + try: + payload = json.loads(raw) + except Exception: + logger.debug("[%s] Failed to parse JSON: %r", "QQBot", raw) + return None + return payload if isinstance(payload, dict) else None + + @staticmethod + def _next_msg_seq(msg_id: str) -> int: + """Generate a message sequence number in 0..65535 range.""" + time_part = int(time.time()) % 100000000 + rand = int(uuid.uuid4().hex[:4], 16) + return (time_part ^ rand) % 65536 + + # ------------------------------------------------------------------ + # Inbound message handling + # ------------------------------------------------------------------ + + async def _on_message(self, event_type: str, d: Any) -> None: + """Process an inbound QQ Bot message event.""" + if not isinstance(d, dict): + return + + # Extract common fields + msg_id = str(d.get("id", "")) + if not msg_id or 
self._is_duplicate(msg_id): + logger.debug("[%s] Duplicate or missing message id: %s", self.name, msg_id) + return + + timestamp = str(d.get("timestamp", "")) + content = str(d.get("content", "")).strip() + author = d.get("author") if isinstance(d.get("author"), dict) else {} + + # Route by event type + if event_type == "C2C_MESSAGE_CREATE": + await self._handle_c2c_message(d, msg_id, content, author, timestamp) + elif event_type in ("GROUP_AT_MESSAGE_CREATE",): + await self._handle_group_message(d, msg_id, content, author, timestamp) + elif event_type in ("GUILD_MESSAGE_CREATE", "GUILD_AT_MESSAGE_CREATE"): + await self._handle_guild_message(d, msg_id, content, author, timestamp) + elif event_type == "DIRECT_MESSAGE_CREATE": + await self._handle_dm_message(d, msg_id, content, author, timestamp) + + async def _handle_c2c_message( + self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + ) -> None: + """Handle a C2C (private) message event.""" + user_openid = str(author.get("user_openid", "")) + if not user_openid: + return + if not self._is_dm_allowed(user_openid): + return + + text = content + attachments_raw = d.get("attachments") + logger.info("[QQ] C2C message: id=%s content=%r attachments=%s", + msg_id, content[:50] if content else "", + f"{len(attachments_raw) if isinstance(attachments_raw, list) else 0} items" + if attachments_raw else "None") + if attachments_raw and isinstance(attachments_raw, list): + for _i, _att in enumerate(attachments_raw): + if isinstance(_att, dict): + logger.info("[QQ] attachment[%d]: content_type=%s url=%s filename=%s", + _i, _att.get("content_type", ""), + str(_att.get("url", ""))[:80], + _att.get("filename", "")) + + # Process all attachments uniformly (images, voice, files) + att_result = await self._process_attachments(attachments_raw) + image_urls = att_result["image_urls"] + image_media_types = att_result["image_media_types"] + voice_transcripts = att_result["voice_transcripts"] + 
attachment_info = att_result["attachment_info"] + + # Append voice transcripts to the text body + if voice_transcripts: + voice_block = "\n".join(voice_transcripts) + text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + # Append non-media attachment info + if attachment_info: + text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + + logger.info("[QQ] After processing: images=%d, voice=%d", + len(image_urls), len(voice_transcripts)) + + if not text.strip() and not image_urls: + return + + self._chat_type_map[user_openid] = "c2c" + event = MessageEvent( + source=self.build_source( + chat_id=user_openid, + user_id=user_openid, + chat_type="dm", + ), + text=text, + message_type=self._detect_message_type(image_urls, image_media_types), + raw_message=d, + message_id=msg_id, + media_urls=image_urls, + media_types=image_media_types, + timestamp=self._parse_qq_timestamp(timestamp), + ) + await self.handle_message(event) + + async def _handle_group_message( + self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + ) -> None: + """Handle a group @-message event.""" + group_openid = str(d.get("group_openid", "")) + if not group_openid: + return + if not self._is_group_allowed(group_openid, str(author.get("member_openid", ""))): + return + + # Strip the @bot mention prefix from content + text = self._strip_at_mention(content) + att_result = await self._process_attachments(d.get("attachments")) + image_urls = att_result["image_urls"] + image_media_types = att_result["image_media_types"] + voice_transcripts = att_result["voice_transcripts"] + attachment_info = att_result["attachment_info"] + + # Append voice transcripts + if voice_transcripts: + voice_block = "\n".join(voice_transcripts) + text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + if attachment_info: + text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + + if 
not text.strip() and not image_urls: + return + + self._chat_type_map[group_openid] = "group" + event = MessageEvent( + source=self.build_source( + chat_id=group_openid, + user_id=str(author.get("member_openid", "")), + chat_type="group", + ), + text=text, + message_type=self._detect_message_type(image_urls, image_media_types), + raw_message=d, + message_id=msg_id, + media_urls=image_urls, + media_types=image_media_types, + timestamp=self._parse_qq_timestamp(timestamp), + ) + await self.handle_message(event) + + async def _handle_guild_message( + self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + ) -> None: + """Handle a guild/channel message event.""" + channel_id = str(d.get("channel_id", "")) + if not channel_id: + return + + member = d.get("member") if isinstance(d.get("member"), dict) else {} + nick = str(member.get("nick", "")) or str(author.get("username", "")) + + text = content + att_result = await self._process_attachments(d.get("attachments")) + image_urls = att_result["image_urls"] + image_media_types = att_result["image_media_types"] + voice_transcripts = att_result["voice_transcripts"] + attachment_info = att_result["attachment_info"] + + if voice_transcripts: + voice_block = "\n".join(voice_transcripts) + text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + if attachment_info: + text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + + if not text.strip() and not image_urls: + return + + self._chat_type_map[channel_id] = "guild" + event = MessageEvent( + source=self.build_source( + chat_id=channel_id, + user_id=str(author.get("id", "")), + user_name=nick or None, + chat_type="group", + ), + text=text, + message_type=self._detect_message_type(image_urls, image_media_types), + raw_message=d, + message_id=msg_id, + media_urls=image_urls, + media_types=image_media_types, + timestamp=self._parse_qq_timestamp(timestamp), + ) + await 
self.handle_message(event) + + async def _handle_dm_message( + self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + ) -> None: + """Handle a guild DM message event.""" + guild_id = str(d.get("guild_id", "")) + if not guild_id: + return + + text = content + att_result = await self._process_attachments(d.get("attachments")) + image_urls = att_result["image_urls"] + image_media_types = att_result["image_media_types"] + voice_transcripts = att_result["voice_transcripts"] + attachment_info = att_result["attachment_info"] + + if voice_transcripts: + voice_block = "\n".join(voice_transcripts) + text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + if attachment_info: + text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + + if not text.strip() and not image_urls: + return + + self._chat_type_map[guild_id] = "dm" + event = MessageEvent( + source=self.build_source( + chat_id=guild_id, + user_id=str(author.get("id", "")), + chat_type="dm", + ), + text=text, + message_type=self._detect_message_type(image_urls, image_media_types), + raw_message=d, + message_id=msg_id, + media_urls=image_urls, + media_types=image_media_types, + timestamp=self._parse_qq_timestamp(timestamp), + ) + await self.handle_message(event) + + # ------------------------------------------------------------------ + # Attachment processing + # ------------------------------------------------------------------ + + + @staticmethod + def _detect_message_type(media_urls: list, media_types: list): + """Determine MessageType from attachment content types.""" + if not media_urls: + return MessageType.TEXT + if not media_types: + return MessageType.PHOTO + first_type = media_types[0].lower() if media_types else "" + if "audio" in first_type or "voice" in first_type or "silk" in first_type: + return MessageType.VOICE + if "video" in first_type: + return MessageType.VIDEO + if "image" in first_type or "photo" in 
first_type: + return MessageType.PHOTO + # Unknown content type with an attachment — don't assume PHOTO + # to prevent non-image files from being sent to vision analysis. + logger.debug("[QQ] Unknown media content_type '%s', defaulting to TEXT", first_type) + return MessageType.TEXT + + async def _process_attachments( + self, attachments: Any, + ) -> Dict[str, Any]: + """Process inbound attachments (all message types). + + Mirrors OpenClaw's ``processAttachments`` — handles images, voice, and + other files uniformly. + + Returns a dict with: + - image_urls: list[str] — cached local image paths + - image_media_types: list[str] — MIME types of cached images + - voice_transcripts: list[str] — STT transcripts for voice messages + - attachment_info: str — text description of non-image, non-voice attachments + """ + if not isinstance(attachments, list): + return {"image_urls": [], "image_media_types": [], + "voice_transcripts": [], "attachment_info": ""} + + image_urls: List[str] = [] + image_media_types: List[str] = [] + voice_transcripts: List[str] = [] + other_attachments: List[str] = [] + + for att in attachments: + if not isinstance(att, dict): + continue + + ct = str(att.get("content_type", "")).strip().lower() + url_raw = str(att.get("url", "")).strip() + filename = str(att.get("filename", "")) + if url_raw.startswith("//"): + url = f"https:{url_raw}" + elif url_raw: + url = url_raw + else: + url = "" + continue + + logger.debug("[QQ] Processing attachment: content_type=%s, url=%s, filename=%s", + ct, url[:80], filename) + + if self._is_voice_content_type(ct, filename): + # Voice: use QQ's asr_refer_text first, then voice_wav_url, then STT. 
+ asr_refer = ( + str(att.get("asr_refer_text", "")).strip() + if isinstance(att.get("asr_refer_text"), str) else "" + ) + voice_wav_url = ( + str(att.get("voice_wav_url", "")).strip() + if isinstance(att.get("voice_wav_url"), str) else "" + ) + + transcript = await self._stt_voice_attachment( + url, ct, filename, + asr_refer_text=asr_refer or None, + voice_wav_url=voice_wav_url or None, + ) + if transcript: + voice_transcripts.append(f"[Voice] {transcript}") + logger.info("[QQ] Voice transcript: %s", transcript) + else: + logger.warning("[QQ] Voice STT failed for %s", url[:60]) + voice_transcripts.append("[Voice] [语音识别失败]") + elif ct.startswith("image/"): + # Image: download and cache locally. + try: + cached_path = await self._download_and_cache(url, ct) + if cached_path and os.path.isfile(cached_path): + image_urls.append(cached_path) + image_media_types.append(ct or "image/jpeg") + elif cached_path: + logger.warning("[QQ] Cached image path does not exist: %s", cached_path) + except Exception as exc: + logger.debug("[QQ] Failed to cache image: %s", exc) + else: + # Other attachments (video, file, etc.): record as text. 
+ try: + cached_path = await self._download_and_cache(url, ct) + if cached_path: + other_attachments.append(f"[Attachment: {filename or ct}]") + except Exception as exc: + logger.debug("[QQ] Failed to cache attachment: %s", exc) + + attachment_info = "\n".join(other_attachments) if other_attachments else "" + return { + "image_urls": image_urls, + "image_media_types": image_media_types, + "voice_transcripts": voice_transcripts, + "attachment_info": attachment_info, + } + + async def _download_and_cache(self, url: str, content_type: str) -> Optional[str]: + """Download a URL and cache it locally.""" + from tools.url_safety import is_safe_url + if not is_safe_url(url): + raise ValueError(f"Blocked unsafe URL: {url[:80]}") + + if not self._http_client: + return None + + try: + resp = await self._http_client.get( + url, timeout=30.0, headers=self._qq_media_headers(), + ) + resp.raise_for_status() + data = resp.content + except Exception as exc: + logger.debug("[%s] Download failed for %s: %s", self.name, url[:80], exc) + return None + + if content_type.startswith("image/"): + ext = mimetypes.guess_extension(content_type) or ".jpg" + return cache_image_from_bytes(data, ext) + elif content_type == "voice" or content_type.startswith("audio/"): + # QQ voice messages are typically .amr or .silk format. + # Convert to .wav using ffmpeg so STT engines can process it. 
+ return await self._convert_audio_to_wav(data, url) + else: + filename = Path(urlparse(url).path).name or "qq_attachment" + return cache_document_from_bytes(data, filename) + + @staticmethod + def _is_voice_content_type(content_type: str, filename: str) -> bool: + """Check if an attachment is a voice/audio message.""" + ct = content_type.strip().lower() + fn = filename.strip().lower() + if ct == "voice" or ct.startswith("audio/"): + return True + _VOICE_EXTENSIONS = (".silk", ".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac", ".speex", ".flac") + if any(fn.endswith(ext) for ext in _VOICE_EXTENSIONS): + return True + return False + + def _qq_media_headers(self) -> Dict[str, str]: + """Return Authorization headers for QQ multimedia CDN downloads. + + QQ's multimedia URLs (multimedia.nt.qq.com.cn) require the bot's + access token in an Authorization header, otherwise the download + returns a non-200 status. + """ + if self._access_token: + return {"Authorization": f"QQBot {self._access_token}"} + return {} + + async def _stt_voice_attachment( + self, + url: str, + content_type: str, + filename: str, + *, + asr_refer_text: Optional[str] = None, + voice_wav_url: Optional[str] = None, + ) -> Optional[str]: + """Download a voice attachment, convert to wav, and transcribe. + + Priority: + 1. QQ's built-in ``asr_refer_text`` (Tencent's own ASR — free, no API call). + 2. Self-hosted STT on ``voice_wav_url`` (pre-converted WAV from QQ, avoids SILK decoding). + 3. Self-hosted STT on the original attachment URL (requires SILK→WAV conversion). + + Returns the transcript text, or None on failure. + """ + # 1. 
Use QQ's built-in ASR text if available + if asr_refer_text: + logger.info("[QQ] STT: using QQ asr_refer_text: %r", asr_refer_text[:100]) + return asr_refer_text + + # Determine which URL to download (prefer voice_wav_url — already WAV) + download_url = url + is_pre_wav = False + if voice_wav_url: + if voice_wav_url.startswith("//"): + voice_wav_url = f"https:{voice_wav_url}" + download_url = voice_wav_url + is_pre_wav = True + logger.info("[QQ] STT: using voice_wav_url (pre-converted WAV)") + + try: + # 2. Download audio (QQ CDN requires Authorization header) + if not self._http_client: + logger.warning("[QQ] STT: no HTTP client") + return None + + download_headers = self._qq_media_headers() + logger.info("[QQ] STT: downloading voice from %s (pre_wav=%s, headers=%s)", + download_url[:80], is_pre_wav, bool(download_headers)) + resp = await self._http_client.get( + download_url, timeout=30.0, headers=download_headers, follow_redirects=True, + ) + resp.raise_for_status() + audio_data = resp.content + logger.info("[QQ] STT: downloaded %d bytes, content_type=%s", + len(audio_data), resp.headers.get("content-type", "unknown")) + + if len(audio_data) < 10: + logger.warning("[QQ] STT: downloaded data too small (%d bytes), skipping", len(audio_data)) + return None + + # 3. Convert to wav (skip if we already have a pre-converted WAV) + if is_pre_wav: + import tempfile + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + tmp.write(audio_data) + wav_path = tmp.name + logger.info("[QQ] STT: using pre-converted WAV directly (%d bytes)", len(audio_data)) + else: + logger.info("[QQ] STT: converting to wav, filename=%r", filename) + wav_path = await self._convert_audio_to_wav_file(audio_data, filename) + if not wav_path or not Path(wav_path).exists(): + logger.warning("[QQ] STT: ffmpeg conversion produced no output") + return None + + # 4. Call STT API + logger.info("[QQ] STT: calling ASR on %s", wav_path) + transcript = await self._call_stt(wav_path) + + # 5. 
Cleanup temp file + try: + os.unlink(wav_path) + except OSError: + pass + + if transcript: + logger.info("[QQ] STT success: %r", transcript[:100]) + else: + logger.warning("[QQ] STT: ASR returned empty transcript") + return transcript + except (httpx.HTTPStatusError, httpx.TransportError, IOError) as exc: + logger.warning("[QQ] STT failed for voice attachment: %s: %s", type(exc).__name__, exc) + return None + + async def _convert_audio_to_wav_file(self, audio_data: bytes, filename: str) -> Optional[str]: + """Convert audio bytes to a temp .wav file using pilk (SILK) or ffmpeg. + + QQ voice messages are typically SILK format which ffmpeg cannot decode. + Strategy: always try pilk first, fall back to ffmpeg if pilk fails. + + Returns the wav file path, or None on failure. + """ + import tempfile + + ext = Path(filename).suffix.lower() if Path(filename).suffix else self._guess_ext_from_data(audio_data) + logger.info("[QQ] STT: audio_data size=%d, ext=%r, first_20_bytes=%r", + len(audio_data), ext, audio_data[:20]) + + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_src: + tmp_src.write(audio_data) + src_path = tmp_src.name + + wav_path = src_path.rsplit(".", 1)[0] + ".wav" + + # Try pilk first (handles SILK and many other formats) + result = await self._convert_silk_to_wav(src_path, wav_path) + + # If pilk failed, try ffmpeg + if not result: + result = await self._convert_ffmpeg_to_wav(src_path, wav_path) + + # If ffmpeg also failed, try writing raw PCM as WAV (last resort) + if not result: + result = await self._convert_raw_to_wav(audio_data, wav_path) + + # Cleanup source file + try: + os.unlink(src_path) + except OSError: + pass + + return result + + @staticmethod + def _guess_ext_from_data(data: bytes) -> str: + """Guess file extension from magic bytes.""" + if data[:9] == b"#!SILK_V3" or data[:5] == b"#!SILK": + return ".silk" + if data[:2] == b"\x02!": + return ".silk" + if data[:4] == b"RIFF": + return ".wav" + if data[:4] == b"fLaC": + return 
".flac" + if data[:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"): + return ".mp3" + if data[:4] == b"\x30\x26\xb2\x75" or data[:4] == b"\x4f\x67\x67\x53": + return ".ogg" + if data[:4] == b"\x00\x00\x00\x20" or data[:4] == b"\x00\x00\x00\x1c": + return ".amr" + # Default to .amr for unknown (QQ's most common voice format) + return ".amr" + + @staticmethod + def _looks_like_silk(data: bytes) -> bool: + """Check if bytes look like a SILK audio file.""" + return data[:4] == b"#!SILK" or data[:2] == b"\x02!" or data[:9] == b"#!SILK_V3" + + @staticmethod + async def _convert_silk_to_wav(src_path: str, wav_path: str) -> Optional[str]: + """Convert audio file to WAV using the pilk library. + + Tries the file as-is first, then as .silk if the extension differs. + pilk can handle SILK files with various headers (or no header). + """ + try: + import pilk + except ImportError: + logger.warning("[QQ] pilk not installed — cannot decode SILK audio. Run: pip install pilk") + return None + + # Try converting the file as-is + try: + pilk.silk_to_wav(src_path, wav_path, rate=16000) + if Path(wav_path).exists() and Path(wav_path).stat().st_size > 44: + logger.info("[QQ] pilk converted %s to wav (%d bytes)", + Path(src_path).name, Path(wav_path).stat().st_size) + return wav_path + except Exception as exc: + logger.debug("[QQ] pilk direct conversion failed: %s", exc) + + # Try renaming to .silk and converting (pilk checks the extension) + silk_path = src_path.rsplit(".", 1)[0] + ".silk" + try: + import shutil + shutil.copy2(src_path, silk_path) + pilk.silk_to_wav(silk_path, wav_path, rate=16000) + if Path(wav_path).exists() and Path(wav_path).stat().st_size > 44: + logger.info("[QQ] pilk converted %s (as .silk) to wav (%d bytes)", + Path(src_path).name, Path(wav_path).stat().st_size) + return wav_path + except Exception as exc: + logger.debug("[QQ] pilk .silk conversion failed: %s", exc) + finally: + try: + os.unlink(silk_path) + except OSError: + pass + + return None + + @staticmethod 
+ async def _convert_raw_to_wav(audio_data: bytes, wav_path: str) -> Optional[str]: + """Last resort: try writing audio data as raw PCM 16-bit mono 16kHz WAV. + + This will produce garbage if the data isn't raw PCM, but at least + the ASR engine won't crash — it'll just return empty. + """ + try: + import wave + with wave.open(wav_path, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(audio_data) + return wav_path + except Exception as exc: + logger.debug("[QQ] raw PCM fallback failed: %s", exc) + return None + + @staticmethod + async def _convert_ffmpeg_to_wav(src_path: str, wav_path: str) -> Optional[str]: + """Convert audio file to WAV using ffmpeg.""" + try: + proc = await asyncio.create_subprocess_exec( + "ffmpeg", "-y", "-i", src_path, "-ar", "16000", "-ac", "1", wav_path, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + await asyncio.wait_for(proc.wait(), timeout=30) + if proc.returncode != 0: + stderr = await proc.stderr.read() if proc.stderr else b"" + logger.warning("[QQ] ffmpeg failed for %s: %s", + Path(src_path).name, stderr[:200].decode(errors="replace")) + return None + except (asyncio.TimeoutError, FileNotFoundError) as exc: + logger.warning("[QQ] ffmpeg conversion error: %s", exc) + return None + + if not Path(wav_path).exists() or Path(wav_path).stat().st_size <= 44: + logger.warning("[QQ] ffmpeg produced no/small output for %s", Path(src_path).name) + return None + logger.info("[QQ] ffmpeg converted %s to wav (%d bytes)", + Path(src_path).name, Path(wav_path).stat().st_size) + return wav_path + + def _resolve_stt_config(self) -> Optional[Dict[str, str]]: + """Resolve STT backend configuration from config/environment. + + Priority: + 1. Plugin-specific: ``channels.qqbot.stt`` in config.yaml → ``self.config.extra["stt"]`` + 2. QQ-specific env vars: ``QQ_STT_API_KEY`` / ``QQ_STT_BASE_URL`` / ``QQ_STT_MODEL`` + 3. 
Return None if nothing is configured (STT will be skipped, QQ built-in ASR still works). + """ + extra = self.config.extra or {} + + # 1. Plugin-specific STT config (matches OpenClaw's channels.qqbot.stt) + stt_cfg = extra.get("stt") + if isinstance(stt_cfg, dict) and stt_cfg.get("enabled") is not False: + base_url = stt_cfg.get("baseUrl") or stt_cfg.get("base_url", "") + api_key = stt_cfg.get("apiKey") or stt_cfg.get("api_key", "") + model = stt_cfg.get("model", "") + if base_url and api_key: + return { + "base_url": base_url.rstrip("/"), + "api_key": api_key, + "model": model or "whisper-1", + } + # Provider-only config: just model name, use default provider + if api_key: + provider = stt_cfg.get("provider", "zai") + # Map provider to base URL + _PROVIDER_BASE_URLS = { + "zai": "https://open.bigmodel.cn/api/coding/paas/v4", + "openai": "https://api.openai.com/v1", + "glm": "https://open.bigmodel.cn/api/coding/paas/v4", + } + base_url = _PROVIDER_BASE_URLS.get(provider, "") + if base_url: + return { + "base_url": base_url, + "api_key": api_key, + "model": model or ("glm-asr" if provider in ("zai", "glm") else "whisper-1"), + } + + # 2. QQ-specific env vars (set by `hermes setup gateway` / `hermes gateway`) + qq_stt_key = os.getenv("QQ_STT_API_KEY", "") + if qq_stt_key: + base_url = os.getenv( + "QQ_STT_BASE_URL", + "https://open.bigmodel.cn/api/coding/paas/v4", + ) + model = os.getenv("QQ_STT_MODEL", "glm-asr") + return { + "base_url": base_url.rstrip("/"), + "api_key": qq_stt_key, + "model": model, + } + + return None + + async def _call_stt(self, wav_path: str) -> Optional[str]: + """Call an OpenAI-compatible STT API to transcribe a wav file. + + Uses the provider configured in ``channels.qqbot.stt`` config, + falling back to QQ's built-in ``asr_refer_text`` if not configured. + Returns None if STT is not configured or the call fails. 
+ """ + stt_cfg = self._resolve_stt_config() + if not stt_cfg: + logger.warning("[QQ] STT not configured (no stt config or QQ_STT_API_KEY)") + return None + + base_url = stt_cfg["base_url"] + api_key = stt_cfg["api_key"] + model = stt_cfg["model"] + + try: + with open(wav_path, "rb") as f: + resp = await self._http_client.post( + f"{base_url}/audio/transcriptions", + headers={"Authorization": f"Bearer {api_key}"}, + files={"file": (Path(wav_path).name, f, "audio/wav")}, + data={"model": model}, + timeout=30.0, + ) + resp.raise_for_status() + result = resp.json() + # Zhipu/GLM format: {"choices": [{"message": {"content": "transcript text"}}]} + choices = result.get("choices", []) + if choices: + content = choices[0].get("message", {}).get("content", "") + if content.strip(): + return content.strip() + # OpenAI/Whisper format: {"text": "transcript text"} + text = result.get("text", "") + if text.strip(): + return text.strip() + return None + except (httpx.HTTPStatusError, IOError) as exc: + logger.warning("[QQ] STT API call failed (model=%s, base=%s): %s", + model, base_url[:50], exc) + return None + + async def _convert_audio_to_wav(self, audio_data: bytes, source_url: str) -> Optional[str]: + """Convert audio bytes to .wav using pilk (SILK) or ffmpeg, caching the result.""" + import tempfile + + # Determine source format from magic bytes or URL + ext = Path(urlparse(source_url).path).suffix.lower() if urlparse(source_url).path else "" + if not ext or ext not in (".silk", ".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac", ".flac"): + ext = self._guess_ext_from_data(audio_data) + + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_src: + tmp_src.write(audio_data) + src_path = tmp_src.name + + wav_path = src_path.rsplit(".", 1)[0] + ".wav" + try: + is_silk = ext == ".silk" or self._looks_like_silk(audio_data) + if is_silk: + result = await self._convert_silk_to_wav(src_path, wav_path) + else: + result = await self._convert_ffmpeg_to_wav(src_path, 
wav_path) + + if not result: + logger.warning("[%s] audio conversion failed for %s (format=%s)", + self.name, source_url[:60], ext) + return cache_document_from_bytes(audio_data, f"qq_voice{ext}") + except Exception: + return cache_document_from_bytes(audio_data, f"qq_voice{ext}") + finally: + try: + os.unlink(src_path) + except OSError: + pass + + # Verify output and cache + try: + wav_data = Path(wav_path).read_bytes() + os.unlink(wav_path) + return cache_document_from_bytes(wav_data, "qq_voice.wav") + except Exception as exc: + logger.debug("[%s] Failed to read converted wav: %s", self.name, exc) + return None + + # ------------------------------------------------------------------ + # Outbound messaging — REST API + # ------------------------------------------------------------------ + + async def _api_request( + self, + method: str, + path: str, + body: Optional[Dict[str, Any]] = None, + timeout: float = DEFAULT_API_TIMEOUT, + ) -> Dict[str, Any]: + """Make an authenticated REST API request to QQ Bot API.""" + if not self._http_client: + raise RuntimeError("HTTP client not initialized — not connected?") + + token = await self._ensure_token() + headers = { + "Authorization": f"QQBot {token}", + "Content-Type": "application/json", + } + + try: + resp = await self._http_client.request( + method, + f"{API_BASE}{path}", + headers=headers, + json=body, + timeout=timeout, + ) + data = resp.json() + if resp.status_code >= 400: + raise RuntimeError( + f"QQ Bot API error [{resp.status_code}] {path}: " + f"{data.get('message', data)}" + ) + return data + except httpx.TimeoutException as exc: + raise RuntimeError(f"QQ Bot API timeout [{path}]: {exc}") from exc + + async def _upload_media( + self, + target_type: str, + target_id: str, + file_type: int, + url: Optional[str] = None, + file_data: Optional[str] = None, + srv_send_msg: bool = False, + file_name: Optional[str] = None, + ) -> Dict[str, Any]: + """Upload media and return file_info.""" + path = 
f"/v2/users/{target_id}/files" if target_type == "c2c" else f"/v2/groups/{target_id}/files" + + body: Dict[str, Any] = { + "file_type": file_type, + "srv_send_msg": srv_send_msg, + } + if url: + body["url"] = url + elif file_data: + body["file_data"] = file_data + if file_type == MEDIA_TYPE_FILE and file_name: + body["file_name"] = file_name + + # Retry transient upload failures + last_exc = None + for attempt in range(3): + try: + return await self._api_request("POST", path, body, timeout=FILE_UPLOAD_TIMEOUT) + except RuntimeError as exc: + last_exc = exc + err_msg = str(exc) + if any(kw in err_msg for kw in ("400", "401", "Invalid", "timeout", "Timeout")): + raise + if attempt < 2: + await asyncio.sleep(1.5 * (attempt + 1)) + + raise last_exc # type: ignore[misc] + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a text or markdown message to a QQ user or group. + + Applies format_message(), splits long messages via truncate_message(), + and retries transient failures with exponential backoff. 
+ """ + del metadata + + if not self.is_connected: + return SendResult(success=False, error="Not connected") + + if not content or not content.strip(): + return SendResult(success=True) + + formatted = self.format_message(content) + chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) + + last_result = SendResult(success=False, error="No chunks") + for chunk in chunks: + last_result = await self._send_chunk(chat_id, chunk, reply_to) + if not last_result.success: + return last_result + # Only reply_to the first chunk + reply_to = None + return last_result + + async def _send_chunk( + self, chat_id: str, content: str, reply_to: Optional[str] = None, + ) -> SendResult: + """Send a single chunk with retry + exponential backoff.""" + last_exc: Optional[Exception] = None + chat_type = self._guess_chat_type(chat_id) + + for attempt in range(3): + try: + if chat_type == "c2c": + return await self._send_c2c_text(chat_id, content, reply_to) + elif chat_type == "group": + return await self._send_group_text(chat_id, content, reply_to) + elif chat_type == "guild": + return await self._send_guild_text(chat_id, content, reply_to) + else: + return SendResult(success=False, error=f"Unknown chat type for {chat_id}") + except Exception as exc: + last_exc = exc + err = str(exc).lower() + # Permanent errors — don't retry + if any(k in err for k in ("invalid", "forbidden", "not found", "bad request")): + break + # Transient — back off and retry + if attempt < 2: + delay = 1.0 * (2 ** attempt) + logger.warning("[%s] send retry %d/3 after %.1fs: %s", + self.name, attempt + 1, delay, exc) + await asyncio.sleep(delay) + + error_msg = str(last_exc) if last_exc else "Unknown error" + logger.error("[%s] Send failed: %s", self.name, error_msg) + retryable = not any(k in error_msg.lower() + for k in ("invalid", "forbidden", "not found")) + return SendResult(success=False, error=error_msg, retryable=retryable) + + async def _send_c2c_text( + self, openid: str, content: str, 
reply_to: Optional[str] = None + ) -> SendResult: + """Send text to a C2C user via REST API.""" + msg_seq = self._next_msg_seq(reply_to or openid) + body = self._build_text_body(content, reply_to) + if reply_to: + body["msg_id"] = reply_to + + data = await self._api_request("POST", f"/v2/users/{openid}/messages", body) + msg_id = str(data.get("id", uuid.uuid4().hex[:12])) + return SendResult(success=True, message_id=msg_id, raw_response=data) + + async def _send_group_text( + self, group_openid: str, content: str, reply_to: Optional[str] = None + ) -> SendResult: + """Send text to a group via REST API.""" + msg_seq = self._next_msg_seq(reply_to or group_openid) + body = self._build_text_body(content, reply_to) + if reply_to: + body["msg_id"] = reply_to + + data = await self._api_request("POST", f"/v2/groups/{group_openid}/messages", body) + msg_id = str(data.get("id", uuid.uuid4().hex[:12])) + return SendResult(success=True, message_id=msg_id, raw_response=data) + + async def _send_guild_text( + self, channel_id: str, content: str, reply_to: Optional[str] = None + ) -> SendResult: + """Send text to a guild channel via REST API.""" + body: Dict[str, Any] = {"content": content[:self.MAX_MESSAGE_LENGTH]} + if reply_to: + body["msg_id"] = reply_to + + data = await self._api_request("POST", f"/channels/{channel_id}/messages", body) + msg_id = str(data.get("id", uuid.uuid4().hex[:12])) + return SendResult(success=True, message_id=msg_id, raw_response=data) + + def _build_text_body(self, content: str, reply_to: Optional[str] = None) -> Dict[str, Any]: + """Build the message body for C2C/group text sending.""" + msg_seq = self._next_msg_seq(reply_to or "default") + + if self._markdown_support: + body: Dict[str, Any] = { + "markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]}, + "msg_type": MSG_TYPE_MARKDOWN, + "msg_seq": msg_seq, + } + else: + body = { + "content": content[:self.MAX_MESSAGE_LENGTH], + "msg_type": MSG_TYPE_TEXT, + "msg_seq": msg_seq, + } + + if 
reply_to: + # For non-markdown mode, add message_reference + if not self._markdown_support: + body["message_reference"] = {"message_id": reply_to} + + return body + + # ------------------------------------------------------------------ + # Native media sending + # ------------------------------------------------------------------ + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an image natively via QQ Bot API upload.""" + del metadata + + result = await self._send_media(chat_id, image_url, MEDIA_TYPE_IMAGE, "image", caption, reply_to) + if result.success or not self._is_url(image_url): + return result + + # Fallback to text URL + logger.warning("[%s] Image send failed, falling back to text: %s", self.name, result.error) + fallback = f"{caption}\n{image_url}" if caption else image_url + return await self.send(chat_id=chat_id, content=fallback, reply_to=reply_to) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + """Send a local image file natively.""" + del kwargs + return await self._send_media(chat_id, image_path, MEDIA_TYPE_IMAGE, "image", caption, reply_to) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + """Send a voice message natively.""" + del kwargs + return await self._send_media(chat_id, audio_path, MEDIA_TYPE_VOICE, "voice", caption, reply_to) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + """Send a video natively.""" + del kwargs + return await self._send_media(chat_id, video_path, MEDIA_TYPE_VIDEO, "video", caption, reply_to) + + async 
def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + """Send a file/document natively.""" + del kwargs + return await self._send_media(chat_id, file_path, MEDIA_TYPE_FILE, "file", caption, reply_to, + file_name=file_name) + + async def _send_media( + self, + chat_id: str, + media_source: str, + file_type: int, + kind: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + file_name: Optional[str] = None, + ) -> SendResult: + """Upload media and send as a native message.""" + if not self.is_connected: + return SendResult(success=False, error="Not connected") + + try: + # Resolve media source + data, content_type, resolved_name = await self._load_media(media_source, file_name) + + # Route + chat_type = self._guess_chat_type(chat_id) + target_path = f"/v2/users/{chat_id}/files" if chat_type == "c2c" else f"/v2/groups/{chat_id}/files" + + if chat_type == "guild": + # Guild channels don't support native media upload in the same way + # Send as URL fallback + return SendResult(success=False, error="Guild media send not supported via this path") + + # Upload + upload = await self._upload_media( + chat_type, chat_id, file_type, + file_data=data if not self._is_url(media_source) else None, + url=media_source if self._is_url(media_source) else None, + srv_send_msg=False, + file_name=resolved_name if file_type == MEDIA_TYPE_FILE else None, + ) + + file_info = upload.get("file_info") + if not file_info: + return SendResult(success=False, error=f"Upload returned no file_info: {upload}") + + # Send media message + msg_seq = self._next_msg_seq(chat_id) + body: Dict[str, Any] = { + "msg_type": MSG_TYPE_MEDIA, + "media": {"file_info": file_info}, + "msg_seq": msg_seq, + } + if caption: + body["content"] = caption[:self.MAX_MESSAGE_LENGTH] + if reply_to: + body["msg_id"] = reply_to + + send_data = await self._api_request( + 
"POST", + f"/v2/users/{chat_id}/messages" if chat_type == "c2c" else f"/v2/groups/{chat_id}/messages", + body, + ) + return SendResult( + success=True, + message_id=str(send_data.get("id", uuid.uuid4().hex[:12])), + raw_response=send_data, + ) + except Exception as exc: + logger.error("[%s] Media send failed: %s", self.name, exc) + return SendResult(success=False, error=str(exc)) + + async def _load_media( + self, source: str, file_name: Optional[str] = None + ) -> Tuple[str, str, str]: + """Load media from URL or local path. Returns (base64_or_url, content_type, filename).""" + source = str(source).strip() + if not source: + raise ValueError("Media source is required") + + parsed = urlparse(source) + if parsed.scheme in ("http", "https"): + # For URLs, pass through directly to the upload API + content_type = mimetypes.guess_type(source)[0] or "application/octet-stream" + resolved_name = file_name or Path(parsed.path).name or "media" + return source, content_type, resolved_name + + # Local file — encode as raw base64 for QQ Bot API file_data field. + # The QQ API expects plain base64, NOT a data URI. + local_path = Path(source).expanduser() + if not local_path.is_absolute(): + local_path = (Path.cwd() / local_path).resolve() + + if not local_path.exists() or not local_path.is_file(): + # Guard against placeholder paths like "" that the LLM + # sometimes emits instead of real file paths. 
+ if source.startswith("<") or len(source) < 3: + raise ValueError( + f"Invalid media source (looks like a placeholder): {source!r}" + ) + raise FileNotFoundError(f"Media file not found: {local_path}") + + raw = local_path.read_bytes() + resolved_name = file_name or local_path.name + content_type = mimetypes.guess_type(str(local_path))[0] or "application/octet-stream" + b64 = base64.b64encode(raw).decode("ascii") + return b64, content_type, resolved_name + + # ------------------------------------------------------------------ + # Typing indicator + # ------------------------------------------------------------------ + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """Send an input notify to a C2C user (only supported for C2C).""" + del metadata + + if not self.is_connected: + return + + # Only C2C supports input notify + chat_type = self._guess_chat_type(chat_id) + if chat_type != "c2c": + return + + try: + msg_seq = self._next_msg_seq(chat_id) + body = { + "msg_type": MSG_TYPE_INPUT_NOTIFY, + "input_notify": {"input_type": 1, "input_second": 60}, + "msg_seq": msg_seq, + } + await self._api_request("POST", f"/v2/users/{chat_id}/messages", body) + except Exception as exc: + logger.debug("[%s] send_typing failed: %s", self.name, exc) + + # ------------------------------------------------------------------ + # Format + # ------------------------------------------------------------------ + + def format_message(self, content: str) -> str: + """Format message for QQ. + + When markdown_support is enabled, content is sent as-is (QQ renders it). + When disabled, strip markdown via shared helper (same as BlueBubbles/SMS). 
+ """ + if self._markdown_support: + return content + return strip_markdown(content) + + # ------------------------------------------------------------------ + # Chat info + # ------------------------------------------------------------------ + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return chat info based on chat type heuristics.""" + chat_type = self._guess_chat_type(chat_id) + return { + "name": chat_id, + "type": "group" if chat_type in ("group", "guild") else "dm", + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_url(source: str) -> bool: + return urlparse(str(source)).scheme in ("http", "https") + + def _guess_chat_type(self, chat_id: str) -> str: + """Determine chat type from stored inbound metadata, fallback to 'c2c'.""" + if chat_id in self._chat_type_map: + return self._chat_type_map[chat_id] + return "c2c" + + @staticmethod + def _strip_at_mention(content: str) -> str: + """Strip the @bot mention prefix from group message content.""" + # QQ group @-messages may have the bot's QQ/ID as prefix + import re + stripped = re.sub(r'^@\S+\s*', '', content.strip()) + return stripped + + def _is_dm_allowed(self, user_id: str) -> bool: + if self._dm_policy == "disabled": + return False + if self._dm_policy == "allowlist": + return self._entry_matches(self._allow_from, user_id) + return True + + def _is_group_allowed(self, group_id: str, user_id: str) -> bool: + if self._group_policy == "disabled": + return False + if self._group_policy == "allowlist": + return self._entry_matches(self._group_allow_from, group_id) + return True + + @staticmethod + def _entry_matches(entries: List[str], target: str) -> bool: + normalized_target = str(target).strip().lower() + for entry in entries: + normalized = str(entry).strip().lower() + if normalized == "*" or normalized == normalized_target: + return True + 
return False + + def _parse_qq_timestamp(self, raw: str) -> datetime: + """Parse QQ API timestamp (ISO 8601 string or integer ms). + + The QQ API changed from integer milliseconds to ISO 8601 strings. + This handles both formats gracefully. + """ + if not raw: + return datetime.now(tz=timezone.utc) + try: + return datetime.fromisoformat(raw) + except (ValueError, TypeError): + pass + try: + return datetime.fromtimestamp(int(raw) / 1000, tz=timezone.utc) + except (ValueError, TypeError): + pass + return datetime.now(tz=timezone.utc) + + def _is_duplicate(self, msg_id: str) -> bool: + now = time.time() + if len(self._seen_messages) > DEDUP_MAX_SIZE: + cutoff = now - DEDUP_WINDOW_SECONDS + self._seen_messages = { + key: ts for key, ts in self._seen_messages.items() if ts > cutoff + } + if msg_id in self._seen_messages: + return True + self._seen_messages[msg_id] = now + return False diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 439367b7d..8ff929961 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -1991,6 +1991,27 @@ class TelegramAdapter(BasePlatformAdapter): return {str(part).strip() for part in raw if str(part).strip()} return {part.strip() for part in str(raw).split(",") if part.strip()} + def _telegram_ignored_threads(self) -> set[int]: + raw = self.config.extra.get("ignored_threads") + if raw is None: + raw = os.getenv("TELEGRAM_IGNORED_THREADS", "") + + if isinstance(raw, list): + values = raw + else: + values = str(raw).split(",") + + ignored: set[int] = set() + for value in values: + text = str(value).strip() + if not text: + continue + try: + ignored.add(int(text)) + except (TypeError, ValueError): + logger.warning("[%s] Ignoring invalid Telegram thread id: %r", self.name, value) + return ignored + def _compile_mention_patterns(self) -> List[re.Pattern]: """Compile optional regex wake-word patterns for group triggers.""" patterns = self.config.extra.get("mention_patterns") @@ -2102,6 +2123,13 
@@ class TelegramAdapter(BasePlatformAdapter): """ if not self._is_group_chat(message): return True + thread_id = getattr(message, "message_thread_id", None) + if thread_id is not None: + try: + if int(thread_id) in self._telegram_ignored_threads(): + return False + except (TypeError, ValueError): + logger.warning("[%s] Ignoring non-numeric Telegram message_thread_id: %r", self.name, thread_id) if str(getattr(getattr(message, "chat", None), "id", "")) in self._telegram_free_response_chats(): return True if not self._telegram_require_mention(): diff --git a/gateway/platforms/webhook.py b/gateway/platforms/webhook.py index eac7ed80e..c37445b17 100644 --- a/gateway/platforms/webhook.py +++ b/gateway/platforms/webhook.py @@ -203,6 +203,7 @@ class WebhookAdapter(BasePlatformAdapter): "wecom_callback", "weixin", "bluebubbles", + "qqbot", ): return await self._deliver_cross_platform( deliver_type, content, delivery diff --git a/gateway/run.py b/gateway/run.py index c23b499bf..76084dcbd 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1499,6 +1499,7 @@ class GatewayRunner: "WECOM_CALLBACK_ALLOWED_USERS", "WEIXIN_ALLOWED_USERS", "BLUEBUBBLES_ALLOWED_USERS", + "QQ_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( @@ -1512,7 +1513,8 @@ class GatewayRunner: "WECOM_ALLOW_ALL_USERS", "WECOM_CALLBACK_ALLOW_ALL_USERS", "WEIXIN_ALLOW_ALL_USERS", - "BLUEBUBBLES_ALLOW_ALL_USERS") + "BLUEBUBBLES_ALLOW_ALL_USERS", + "QQ_ALLOW_ALL_USERS") ) if not _any_allowlist and not _allow_all: logger.warning( @@ -2255,8 +2257,15 @@ class GatewayRunner: return None return BlueBubblesAdapter(config) + elif platform == Platform.QQBOT: + from gateway.platforms.qqbot import QQAdapter, check_qq_requirements + if not check_qq_requirements(): + logger.warning("QQBot: aiohttp/httpx missing or QQ_APP_ID/QQ_CLIENT_SECRET not configured") + return None + return QQAdapter(config) + return None - + def 
_is_user_authorized(self, source: SessionSource) -> bool: """ Check if a user is authorized to use the bot. @@ -2296,6 +2305,7 @@ class GatewayRunner: Platform.WECOM_CALLBACK: "WECOM_CALLBACK_ALLOWED_USERS", Platform.WEIXIN: "WEIXIN_ALLOWED_USERS", Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS", + Platform.QQBOT: "QQ_ALLOWED_USERS", } platform_allow_all_map = { Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS", @@ -2313,6 +2323,7 @@ class GatewayRunner: Platform.WECOM_CALLBACK: "WECOM_CALLBACK_ALLOW_ALL_USERS", Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS", Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS", + Platform.QQBOT: "QQ_ALLOW_ALL_USERS", } # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) @@ -6566,7 +6577,7 @@ class GatewayRunner: Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP, Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX, Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK, - Platform.FEISHU, Platform.WECOM, Platform.WECOM_CALLBACK, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.LOCAL, + Platform.FEISHU, Platform.WECOM, Platform.WECOM_CALLBACK, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.QQBOT, Platform.LOCAL, }) async def _handle_debug_command(self, event: MessageEvent) -> str: @@ -7906,13 +7917,14 @@ class GatewayRunner: _adapter = self.adapters.get(source.platform) if _adapter: # Platforms that don't support editing sent messages - # (e.g. WeChat) must not show a cursor in intermediate - # sends — the cursor would be permanently visible because - # it can never be edited away. Use an empty cursor for - # such platforms so streaming still delivers the final - # response, just without the typing indicator. + # (e.g. QQ, WeChat) should skip streaming entirely — + # without edit support, the consumer sends a partial + # first message that can never be updated, resulting in + # duplicate messages (partial + final). 
_adapter_supports_edit = getattr(_adapter, "SUPPORTS_MESSAGE_EDITING", True) - _effective_cursor = _scfg.cursor if _adapter_supports_edit else "" + if not _adapter_supports_edit: + raise RuntimeError("skip streaming for non-editable platform") + _effective_cursor = _scfg.cursor # Some Matrix clients render the streaming cursor # as a visible tofu/white-box artifact. Keep # streaming text on Matrix, but suppress the cursor. diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 240084e9b..e6d96c802 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -64,6 +64,18 @@ class GatewayStreamConsumer: # progressive edits for the remainder of the stream. _MAX_FLOOD_STRIKES = 3 + # Reasoning/thinking tags that models emit inline in content. + # Must stay in sync with cli.py _OPEN_TAGS/_CLOSE_TAGS and + # run_agent.py _strip_think_blocks() tag variants. + _OPEN_THINK_TAGS = ( + "", "", "", + "", "", "", + ) + _CLOSE_THINK_TAGS = ( + "", "", "", + "", "", "", + ) + def __init__( self, adapter: Any, @@ -88,6 +100,10 @@ class GatewayStreamConsumer: self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff self._final_response_sent = False + # Think-block filter state (mirrors CLI's _stream_delta tag suppression) + self._in_think_block = False + self._think_buffer = "" + @property def already_sent(self) -> bool: """True if at least one message was sent or edited during the run.""" @@ -132,6 +148,112 @@ class GatewayStreamConsumer: """Signal that the stream is complete.""" self._queue.put(_DONE) + # ── Think-block filtering ──────────────────────────────────────── + # Models like MiniMax emit inline ... blocks in their + # content. The CLI's _stream_delta suppresses these via a state + # machine; we do the same here so gateway users never see raw + # reasoning tags. 
The agent also strips them from the final + # response (run_agent.py _strip_think_blocks), but the stream + # consumer sends intermediate edits before that stripping happens. + + def _filter_and_accumulate(self, text: str) -> None: + """Add a text delta to the accumulated buffer, suppressing think blocks. + + Uses a state machine that tracks whether we are inside a + reasoning/thinking block. Text inside such blocks is silently + discarded. Partial tags at buffer boundaries are held back in + ``_think_buffer`` until enough characters arrive to decide. + """ + buf = self._think_buffer + text + self._think_buffer = "" + + while buf: + if self._in_think_block: + # Look for the earliest closing tag + best_idx = -1 + best_len = 0 + for tag in self._CLOSE_THINK_TAGS: + idx = buf.find(tag) + if idx != -1 and (best_idx == -1 or idx < best_idx): + best_idx = idx + best_len = len(tag) + + if best_len: + # Found closing tag — discard block, process remainder + self._in_think_block = False + buf = buf[best_idx + best_len:] + else: + # No closing tag yet — hold tail that could be a + # partial closing tag prefix, discard the rest. + max_tag = max(len(t) for t in self._CLOSE_THINK_TAGS) + self._think_buffer = buf[-max_tag:] if len(buf) > max_tag else buf + return + else: + # Look for earliest opening tag at a block boundary + # (start of text / preceded by newline + optional whitespace). + # This prevents false positives when models *mention* tags + # in prose (e.g. "the tag is used for…"). 
+ best_idx = -1 + best_len = 0 + for tag in self._OPEN_THINK_TAGS: + search_start = 0 + while True: + idx = buf.find(tag, search_start) + if idx == -1: + break + # Block-boundary check (mirrors cli.py logic) + if idx == 0: + is_boundary = ( + not self._accumulated + or self._accumulated.endswith("\n") + ) + else: + preceding = buf[:idx] + last_nl = preceding.rfind("\n") + if last_nl == -1: + is_boundary = ( + (not self._accumulated + or self._accumulated.endswith("\n")) + and preceding.strip() == "" + ) + else: + is_boundary = preceding[last_nl + 1:].strip() == "" + + if is_boundary and (best_idx == -1 or idx < best_idx): + best_idx = idx + best_len = len(tag) + break # first boundary hit for this tag is enough + search_start = idx + 1 + + if best_len: + # Emit text before the tag, enter think block + self._accumulated += buf[:best_idx] + self._in_think_block = True + buf = buf[best_idx + best_len:] + else: + # No opening tag — check for a partial tag at the tail + held_back = 0 + for tag in self._OPEN_THINK_TAGS: + for i in range(1, len(tag)): + if buf.endswith(tag[:i]) and i > held_back: + held_back = i + if held_back: + self._accumulated += buf[:-held_back] + self._think_buffer = buf[-held_back:] + else: + self._accumulated += buf + return + + def _flush_think_buffer(self) -> None: + """Flush any held-back partial-tag buffer into accumulated text. + + Called when the stream ends (got_done) so that partial text that + was held back waiting for a possible opening tag is not lost. 
+ """ + if self._think_buffer and not self._in_think_block: + self._accumulated += self._think_buffer + self._think_buffer = "" + async def run(self) -> None: """Async task that drains the queue and edits the platform message.""" # Platform message length limit — leave room for cursor + formatting @@ -156,10 +278,16 @@ class GatewayStreamConsumer: if isinstance(item, tuple) and len(item) == 2 and item[0] is _COMMENTARY: commentary_text = item[1] break - self._accumulated += item + self._filter_and_accumulate(item) except queue.Empty: break + # Flush any held-back partial-tag buffer on stream end + # so trailing text that was waiting for a potential open + # tag is not lost. + if got_done: + self._flush_think_buffer() + # Decide whether to flush an edit now = time.monotonic() elapsed = now - self._last_edit_time @@ -504,10 +632,26 @@ class GatewayStreamConsumer: visible_without_cursor = text if self.cfg.cursor: visible_without_cursor = visible_without_cursor.replace(self.cfg.cursor, "") - if not visible_without_cursor.strip(): + _visible_stripped = visible_without_cursor.strip() + if not _visible_stripped: return True # cursor-only / whitespace-only update if not text.strip(): return True # nothing to send is "success" + # Guard: do not create a brand-new standalone message when the only + # visible content is a handful of characters alongside the streaming + # cursor. During rapid tool-calling the model often emits 1-2 tokens + # before switching to tool calls; the resulting "X ▉" message risks + # leaving the cursor permanently visible if the follow-up edit (to + # strip the cursor on segment break) is rate-limited by the platform. + # This was reported on Telegram, Matrix, and other clients where the + # ▉ block character renders as a visible white box ("tofu"). + # Existing messages (edits) are unaffected — only first sends gated. 
+ _MIN_NEW_MSG_CHARS = 4 + if (self._message_id is None + and self.cfg.cursor + and self.cfg.cursor in text + and len(_visible_stripped) < _MIN_NEW_MSG_CHARS): + return True # too short for a standalone message — accumulate more try: if self._message_id is not None: if self._edit_supported: diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 9d1d82e8c..e63a1ebb6 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -224,7 +224,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { ), "ai-gateway": ProviderConfig( id="ai-gateway", - name="AI Gateway", + name="Vercel AI Gateway", auth_type="api_key", inference_base_url="https://ai-gateway.vercel.sh/v1", api_key_env_vars=("AI_GATEWAY_API_KEY",), diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 964311fb7..b0ac79104 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -12,6 +12,9 @@ from __future__ import annotations import os import re +import shutil +import subprocess +import time from collections.abc import Callable, Mapping from dataclasses import dataclass from typing import Any @@ -614,6 +617,10 @@ class SlashCommandCompleter(Completer): ) -> None: self._skill_commands_provider = skill_commands_provider self._command_filter = command_filter + # Cached project file list for fuzzy @ completions + self._file_cache: list[str] = [] + self._file_cache_time: float = 0.0 + self._file_cache_cwd: str = "" def _command_allowed(self, slash_command: str) -> bool: if self._command_filter is None: @@ -798,46 +805,138 @@ class SlashCommandCompleter(Completer): count += 1 return - # Bare @ or @partial — show matching files/folders from cwd + # Bare @ or @partial — fuzzy project-wide file search query = word[1:] # strip the @ - if not query: - search_dir, match_prefix = ".", "" - else: - expanded = os.path.expanduser(query) - if expanded.endswith("/"): - search_dir, match_prefix = expanded, "" - else: - search_dir = os.path.dirname(expanded) or "." 
- match_prefix = os.path.basename(expanded) + yield from self._fuzzy_file_completions(word, query, limit) - try: - entries = os.listdir(search_dir) - except OSError: + def _get_project_files(self) -> list[str]: + """Return cached list of project files (refreshed every 5s).""" + cwd = os.getcwd() + now = time.monotonic() + if ( + self._file_cache + and self._file_cache_cwd == cwd + and now - self._file_cache_time < 5.0 + ): + return self._file_cache + + files: list[str] = [] + # Try rg first (fast, respects .gitignore), then fd, then find. + for cmd in [ + ["rg", "--files", "--sortr=modified", cwd], + ["rg", "--files", cwd], + ["fd", "--type", "f", "--base-directory", cwd], + ]: + tool = cmd[0] + if not shutil.which(tool): + continue + try: + proc = subprocess.run( + cmd, capture_output=True, text=True, timeout=2, + cwd=cwd, + ) + if proc.returncode == 0 and proc.stdout.strip(): + raw = proc.stdout.strip().split("\n") + # Store relative paths + for p in raw[:5000]: + rel = os.path.relpath(p, cwd) if os.path.isabs(p) else p + files.append(rel) + break + except (subprocess.TimeoutExpired, OSError): + continue + + self._file_cache = files + self._file_cache_time = now + self._file_cache_cwd = cwd + return files + + @staticmethod + def _score_path(filepath: str, query: str) -> int: + """Score a file path against a fuzzy query. Higher = better match.""" + if not query: + return 1 # show everything when query is empty + + filename = os.path.basename(filepath) + lower_file = filename.lower() + lower_path = filepath.lower() + lower_q = query.lower() + + # Exact filename match + if lower_file == lower_q: + return 100 + # Filename starts with query + if lower_file.startswith(lower_q): + return 80 + # Filename contains query as substring + if lower_q in lower_file: + return 60 + # Full path contains query + if lower_q in lower_path: + return 40 + # Initials / abbreviation match: e.g. 
"fo" matches "file_operations" + # Check if query chars appear in order in filename + qi = 0 + for c in lower_file: + if qi < len(lower_q) and c == lower_q[qi]: + qi += 1 + if qi == len(lower_q): + # Bonus if matches land on word boundaries (after _, -, /, .) + boundary_hits = 0 + qi = 0 + prev = "_" # treat start as boundary + for c in lower_file: + if qi < len(lower_q) and c == lower_q[qi]: + if prev in "_-./": + boundary_hits += 1 + qi += 1 + prev = c + if boundary_hits >= len(lower_q) * 0.5: + return 35 + return 25 + return 0 + + def _fuzzy_file_completions(self, word: str, query: str, limit: int = 20): + """Yield fuzzy file completions for bare @query.""" + files = self._get_project_files() + + if not query: + # No query — show recently modified files (already sorted by mtime) + for fp in files[:limit]: + is_dir = fp.endswith("/") + filename = os.path.basename(fp) + kind = "folder" if is_dir else "file" + meta = "dir" if is_dir else _file_size_label( + os.path.join(os.getcwd(), fp) + ) + yield Completion( + f"@{kind}:{fp}", + start_position=-len(word), + display=filename, + display_meta=meta, + ) return - count = 0 - prefix_lower = match_prefix.lower() - for entry in sorted(entries): - if match_prefix and not entry.lower().startswith(prefix_lower): - continue - if entry.startswith("."): - continue # skip hidden files in bare @ mode - if count >= limit: - break - full_path = os.path.join(search_dir, entry) - is_dir = os.path.isdir(full_path) - display_path = os.path.relpath(full_path) - suffix = "/" if is_dir else "" + # Score and rank + scored = [] + for fp in files: + s = self._score_path(fp, query) + if s > 0: + scored.append((s, fp)) + scored.sort(key=lambda x: (-x[0], x[1])) + + for _, fp in scored[:limit]: + is_dir = fp.endswith("/") + filename = os.path.basename(fp) kind = "folder" if is_dir else "file" - meta = "dir" if is_dir else _file_size_label(full_path) - completion = f"@{kind}:{display_path}{suffix}" - yield Completion( - completion, - 
start_position=-len(word), - display=entry + suffix, - display_meta=meta, + meta = "dir" if is_dir else _file_size_label( + os.path.join(os.getcwd(), fp) + ) + yield Completion( + f"@{kind}:{fp}", + start_position=-len(word), + display=filename, + display_meta=f"{fp} {meta}" if meta else fp, ) - count += 1 @staticmethod def _skin_completions(sub_text: str, sub_lower: str): diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 64a5bd1a9..78cc30157 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -45,6 +45,9 @@ _EXTRA_ENV_KEYS = frozenset({ "WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY", "WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS", "BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD", + "QQ_APP_ID", "QQ_CLIENT_SECRET", "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", + "QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS", "QQ_ALLOW_ALL_USERS", "QQ_MARKDOWN_SUPPORT", + "QQ_STT_API_KEY", "QQ_STT_BASE_URL", "QQ_STT_MODEL", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", "MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE", @@ -1331,6 +1334,53 @@ OPTIONAL_ENV_VARS = { "password": False, "category": "messaging", }, + "BLUEBUBBLES_ALLOW_ALL_USERS": { + "description": "Allow all BlueBubbles users without allowlist", + "prompt": "Allow All BlueBubbles Users", + "category": "messaging", + }, + "QQ_APP_ID": { + "description": "QQ Bot App ID from QQ Open Platform (q.qq.com)", + "prompt": "QQ App ID", + "url": "https://q.qq.com", + "category": "messaging", + }, + "QQ_CLIENT_SECRET": { + "description": "QQ Bot Client Secret from QQ Open Platform", + "prompt": "QQ Client Secret", + "password": True, + "category": "messaging", + }, + "QQ_ALLOWED_USERS": { + "description": "Comma-separated QQ user IDs allowed to use the bot", + "prompt": "QQ Allowed Users", + "category": "messaging", + }, + "QQ_GROUP_ALLOWED_USERS": { + "description": "Comma-separated QQ 
group IDs allowed to interact with the bot", + "prompt": "QQ Group Allowed Users", + "category": "messaging", + }, + "QQ_ALLOW_ALL_USERS": { + "description": "Allow all QQ users without an allowlist (true/false)", + "prompt": "Allow All QQ Users", + "category": "messaging", + }, + "QQ_HOME_CHANNEL": { + "description": "Default QQ channel/group for cron delivery and notifications", + "prompt": "QQ Home Channel", + "category": "messaging", + }, + "QQ_HOME_CHANNEL_NAME": { + "description": "Display name for the QQ home channel", + "prompt": "QQ Home Channel Name", + "category": "messaging", + }, + "QQ_SANDBOX": { + "description": "Enable QQ sandbox mode for development testing (true/false)", + "prompt": "QQ Sandbox Mode", + "category": "messaging", + }, "GATEWAY_ALLOW_ALL_USERS": { "description": "Allow all users to interact with messaging bots (true/false). Default: false.", "prompt": "Allow all users (true/false)", diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index 19c332b35..34a57aad2 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -729,7 +729,7 @@ def run_doctor(args): # MiniMax: the /anthropic endpoint doesn't support /models, but the /v1 endpoint does. 
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL", True), ("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL", True), - ("AI Gateway", ("AI_GATEWAY_API_KEY",), "https://ai-gateway.vercel.sh/v1/models", "AI_GATEWAY_BASE_URL", True), + ("Vercel AI Gateway", ("AI_GATEWAY_API_KEY",), "https://ai-gateway.vercel.sh/v1/models", "AI_GATEWAY_BASE_URL", True), ("Kilo Code", ("KILOCODE_API_KEY",), "https://api.kilo.ai/api/gateway/models", "KILOCODE_BASE_URL", True), ("OpenCode Zen", ("OPENCODE_ZEN_API_KEY",), "https://opencode.ai/zen/v1/models", "OPENCODE_ZEN_BASE_URL", True), ("OpenCode Go", ("OPENCODE_GO_API_KEY",), "https://opencode.ai/zen/go/v1/models", "OPENCODE_GO_BASE_URL", True), diff --git a/hermes_cli/dump.py b/hermes_cli/dump.py index 491bf6e2c..a52079085 100644 --- a/hermes_cli/dump.py +++ b/hermes_cli/dump.py @@ -131,6 +131,7 @@ def _configured_platforms() -> list[str]: "wecom": "WECOM_BOT_ID", "wecom_callback": "WECOM_CALLBACK_CORP_ID", "weixin": "WEIXIN_ACCOUNT_ID", + "qqbot": "QQ_APP_ID", } return [name for name, env in checks.items() if os.getenv(env)] diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 628319d57..fe7bb9bd8 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -1913,6 +1913,29 @@ _PLATFORMS = [ "help": "Phone number or Apple ID to deliver cron results and notifications to."}, ], }, + { + "key": "qqbot", + "label": "QQ Bot", + "emoji": "🐧", + "token_var": "QQ_APP_ID", + "setup_instructions": [ + "1. Register a QQ Bot application at q.qq.com", + "2. Note your App ID and App Secret from the application page", + "3. Enable the required intents (C2C, Group, Guild messages)", + "4. 
Configure sandbox or publish the bot", + ], + "vars": [ + {"name": "QQ_APP_ID", "prompt": "QQ Bot App ID", "password": False, + "help": "Your QQ Bot App ID from q.qq.com."}, + {"name": "QQ_CLIENT_SECRET", "prompt": "QQ Bot App Secret", "password": True, + "help": "Your QQ Bot App Secret from q.qq.com."}, + {"name": "QQ_ALLOWED_USERS", "prompt": "Allowed user OpenIDs (comma-separated, leave empty for open access)", "password": False, + "is_allowlist": True, + "help": "Optional — restrict DM access to specific user OpenIDs."}, + {"name": "QQ_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False, + "help": "OpenID to deliver cron results and notifications to."}, + ], + }, ] diff --git a/hermes_cli/main.py b/hermes_cli/main.py index a2c797fd4..2277940b8 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1839,6 +1839,10 @@ def _model_flow_custom(config): model_name = input("Model name (e.g. gpt-4, llama-3-70b): ").strip() context_length_str = input("Context length in tokens [leave blank for auto-detect]: ").strip() + + # Prompt for a display name — shown in the provider menu on future runs + default_name = _auto_provider_name(effective_url) + display_name = input(f"Display name [{default_name}]: ").strip() or default_name except (KeyboardInterrupt, EOFError): print("\nCancelled.") return @@ -1894,15 +1898,37 @@ def _model_flow_custom(config): print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.") # Auto-save to custom_providers so it appears in the menu next time - _save_custom_provider(effective_url, effective_key, model_name or "", context_length=context_length) + _save_custom_provider(effective_url, effective_key, model_name or "", + context_length=context_length, name=display_name) -def _save_custom_provider(base_url, api_key="", model="", context_length=None): +def _auto_provider_name(base_url: str) -> str: + """Generate a display name from a custom endpoint URL. 
+ + Returns a human-friendly label like "Local (localhost:11434)" or + "RunPod (xyz.runpod.io)". Used as the default when prompting the + user for a display name during custom endpoint setup. + """ + import re + clean = base_url.replace("https://", "").replace("http://", "").rstrip("/") + clean = re.sub(r"/v1/?$", "", clean) + name = clean.split("/")[0] + if "localhost" in name or "127.0.0.1" in name: + name = f"Local ({name})" + elif "runpod" in name.lower(): + name = f"RunPod ({name})" + else: + name = name.capitalize() + return name + + +def _save_custom_provider(base_url, api_key="", model="", context_length=None, + name=None): """Save a custom endpoint to custom_providers in config.yaml. Deduplicates by base_url — if the URL already exists, updates the model name and context_length but doesn't add a duplicate entry. - Auto-generates a display name from the URL hostname. + Uses *name* when provided, otherwise auto-generates from the URL. """ from hermes_cli.config import load_config, save_config @@ -1930,20 +1956,9 @@ def _save_custom_provider(base_url, api_key="", model="", context_length=None): save_config(cfg) return # already saved, updated if needed - # Auto-generate a name from the URL - import re - clean = base_url.replace("https://", "").replace("http://", "").rstrip("/") - # Remove /v1 suffix for cleaner names - clean = re.sub(r"/v1/?$", "", clean) - # Use hostname:port as the name - name = clean.split("/")[0] - # Capitalize for readability - if "localhost" in name or "127.0.0.1" in name: - name = f"Local ({name})" - elif "runpod" in name.lower(): - name = f"RunPod ({name})" - else: - name = name.capitalize() + # Use provided name or auto-generate from URL + if not name: + name = _auto_provider_name(base_url) entry = {"name": name, "base_url": base_url} if api_key: diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index c7f422c6d..08dfb7f41 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -705,6 +705,10 @@ 
def switch_model( error_message=msg, ) + # Apply auto-correction if validation found a closer match + if validation.get("corrected_model"): + new_model = validation["corrected_model"] + # --- OpenCode api_mode override --- if target_provider in {"opencode-zen", "opencode-go", "opencode", "opencode-go"}: api_mode = opencode_model_api_mode(target_provider, new_model) diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 23e0d9e1f..3d23ee557 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -29,6 +29,7 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [ ("qwen/qwen3.6-plus", ""), ("anthropic/claude-sonnet-4.5", ""), ("anthropic/claude-haiku-4.5", ""), + ("openrouter/elephant-alpha", "free"), ("openai/gpt-5.4", ""), ("openai/gpt-5.4-mini", ""), ("xiaomi/mimo-v2-pro", ""), @@ -97,6 +98,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = { "arcee-ai/trinity-large-thinking", "openai/gpt-5.4-pro", "openai/gpt-5.4-nano", + "openrouter/elephant-alpha", ], "openai-codex": _codex_curated_models(), "copilot-acp": [ @@ -512,6 +514,7 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [ ProviderEntry("openrouter", "OpenRouter", "OpenRouter (100+ models, pay-per-use)"), ProviderEntry("anthropic", "Anthropic", "Anthropic (Claude models — API key or Claude Code)"), ProviderEntry("openai-codex", "OpenAI Codex", "OpenAI Codex"), + ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"), ProviderEntry("qwen-oauth", "Qwen OAuth (Portal)", "Qwen OAuth (reuses local Qwen CLI login)"), ProviderEntry("copilot", "GitHub Copilot", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"), ProviderEntry("copilot-acp", "GitHub Copilot ACP", "GitHub Copilot ACP (spawns `copilot --acp --stdio`)"), @@ -525,12 +528,11 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [ ProviderEntry("minimax", "MiniMax", "MiniMax (global direct API)"), ProviderEntry("minimax-cn", "MiniMax (China)", "MiniMax China (domestic direct API)"), ProviderEntry("alibaba", "Alibaba Cloud 
(DashScope)","Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"), - ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"), ProviderEntry("arcee", "Arcee AI", "Arcee AI (Trinity models — direct API)"), ProviderEntry("kilocode", "Kilo Code", "Kilo Code (Kilo Gateway API)"), ProviderEntry("opencode-zen", "OpenCode Zen", "OpenCode Zen (35+ curated models, pay-as-you-go)"), ProviderEntry("opencode-go", "OpenCode Go", "OpenCode Go (open models, $10/month subscription)"), - ProviderEntry("ai-gateway", "AI Gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"), + ProviderEntry("ai-gateway", "Vercel AI Gateway", "Vercel AI Gateway (200+ models, pay-per-use)"), ] # Derived dicts — used throughout the codebase @@ -1818,6 +1820,17 @@ def validate_requested_model( "message": None, } + # Auto-correct if the top match is very similar (e.g. typo) + auto = get_close_matches(requested_for_lookup, api_models, n=1, cutoff=0.9) + if auto: + return { + "accepted": True, + "persist": True, + "recognized": True, + "corrected_model": auto[0], + "message": f"Auto-corrected `{requested}` → `{auto[0]}`", + } + suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5) suggestion_text = "" if suggestions: @@ -1869,6 +1882,16 @@ def validate_requested_model( "recognized": True, "message": None, } + # Auto-correct if the top match is very similar (e.g. 
typo)
+    auto = get_close_matches(requested_for_lookup, codex_models, n=1, cutoff=0.9)
+    if auto:
+        return {
+            "accepted": True,
+            "persist": True,
+            "recognized": True,
+            "corrected_model": auto[0],
+            "message": f"Auto-corrected `{requested}` → `{auto[0]}`",
+        }
     suggestions = get_close_matches(requested_for_lookup, codex_models, n=3, cutoff=0.5)
     suggestion_text = ""
     if suggestions:
@@ -1895,10 +1918,27 @@
             "recognized": True,
             "message": None,
         }
-    suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
-    suggestion_text = ""
-    if suggestions:
-        suggestion_text = "\n   Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
+    else:
+        # NOTE(review): an earlier draft said unlisted models are accepted
+        # ("warn but allow" — e.g. Z.AI Pro/Max can use glm-5 even though
+        # /models omits it), but this branch still falls through to the
+        # `"accepted": False` return below. Confirm which behavior is intended.
+
+        # Auto-correct if the top match is very similar (e.g. typo)
+        auto = get_close_matches(requested_for_lookup, api_models, n=1, cutoff=0.9)
+        if auto:
+            return {
+                "accepted": True,
+                "persist": True,
+                "recognized": True,
+                "corrected_model": auto[0],
+                "message": f"Auto-corrected `{requested}` → `{auto[0]}`",
+            }
+
+        suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
+        suggestion_text = ""
+        if suggestions:
+            suggestion_text = "\n   Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
 
     return {
         "accepted": False,
diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py
index df47ed095..1fc3a3a85 100644
--- a/hermes_cli/platforms.py
+++ b/hermes_cli/platforms.py
@@ -35,6 +35,7 @@ PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
    ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")),
    ("wecom_callback", PlatformInfo(label="💬 WeCom Callback", default_toolset="hermes-wecom-callback")),
    ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
+    ("qqbot", 
PlatformInfo(label="💬 QQBot", default_toolset="hermes-qqbot")), ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")), ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")), ]) diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 13a31b2a8..a1f8db31f 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -584,6 +584,45 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]: +def get_pre_tool_call_block_message( + tool_name: str, + args: Optional[Dict[str, Any]], + task_id: str = "", + session_id: str = "", + tool_call_id: str = "", +) -> Optional[str]: + """Check ``pre_tool_call`` hooks for a blocking directive. + + Plugins that need to enforce policy (rate limiting, security + restrictions, approval workflows) can return:: + + {"action": "block", "message": "Reason the tool was blocked"} + + from their ``pre_tool_call`` callback. The first valid block + directive wins. Invalid or irrelevant hook return values are + silently ignored so existing observer-only hooks are unaffected. 
+ """ + hook_results = invoke_hook( + "pre_tool_call", + tool_name=tool_name, + args=args if isinstance(args, dict) else {}, + task_id=task_id, + session_id=session_id, + tool_call_id=tool_call_id, + ) + + for result in hook_results: + if not isinstance(result, dict): + continue + if result.get("action") != "block": + continue + message = result.get("message") + if isinstance(message, str) and message: + return message + + return None + + def get_plugin_context_engine(): """Return the plugin-registered context engine, or None.""" return get_plugin_manager()._context_engine @@ -608,7 +647,7 @@ def get_plugin_toolsets() -> List[tuple]: toolset_tools: Dict[str, List[str]] = {} toolset_plugin: Dict[str, LoadedPlugin] = {} for tool_name in manager._plugin_tool_names: - entry = registry._tools.get(tool_name) + entry = registry.get_entry(tool_name) if not entry: continue ts = entry.toolset @@ -617,7 +656,7 @@ def get_plugin_toolsets() -> List[tuple]: # Map toolsets back to the plugin that registered them for _name, loaded in manager._plugins.items(): for tool_name in loaded.tools_registered: - entry = registry._tools.get(tool_name) + entry = registry.get_entry(tool_name) if entry and entry.toolset in toolset_tools: toolset_plugin.setdefault(entry.toolset, loaded) diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index 54b9ae65c..b2dec61cd 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -287,6 +287,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An # Resolve the API key from the env var name stored in key_env key_env = str(entry.get("key_env", "") or "").strip() resolved_api_key = os.getenv(key_env, "").strip() if key_env else "" + # Fall back to inline api_key when key_env is absent or unresolvable + if not resolved_api_key: + resolved_api_key = str(entry.get("api_key", "") or "").strip() if requested_norm in {ep_name, name_norm, f"custom:{name_norm}"}: # Found match by 
provider key diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index aadf369f5..9044871dc 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -776,7 +776,7 @@ def setup_model_provider(config: dict, *, quick: bool = False): "minimax": "MiniMax", "minimax-cn": "MiniMax CN", "anthropic": "Anthropic", - "ai-gateway": "AI Gateway", + "ai-gateway": "Vercel AI Gateway", "custom": "your custom endpoint", } _prov_display = _prov_names.get(selected_provider, selected_provider or "your provider") @@ -1969,6 +1969,54 @@ def _setup_wecom_callback(): _gw_setup() +def _setup_qqbot(): + """Configure QQ Bot gateway.""" + print_header("QQ Bot") + existing = get_env_value("QQ_APP_ID") + if existing: + print_info("QQ Bot: already configured") + if not prompt_yes_no("Reconfigure QQ Bot?", False): + return + + print_info("Connects Hermes to QQ via the Official QQ Bot API (v2).") + print_info(" Requires a QQ Bot application at q.qq.com") + print_info(" Reference: https://bot.q.qq.com/wiki/develop/api-v2/") + print() + + app_id = prompt("QQ Bot App ID") + if not app_id: + print_warning("App ID is required — skipping QQ Bot setup") + return + save_env_value("QQ_APP_ID", app_id.strip()) + + client_secret = prompt("QQ Bot App Secret", password=True) + if not client_secret: + print_warning("App Secret is required — skipping QQ Bot setup") + return + save_env_value("QQ_CLIENT_SECRET", client_secret) + print_success("QQ Bot credentials saved") + + print() + print_info("🔒 Security: Restrict who can DM your bot") + print_info(" Use QQ user OpenIDs (found in event payloads)") + print() + allowed_users = prompt("Allowed user OpenIDs (comma-separated, leave empty for open access)") + if allowed_users: + save_env_value("QQ_ALLOWED_USERS", allowed_users.replace(" ", "")) + print_success("QQ Bot allowlist configured") + else: + print_info("⚠️ No allowlist set — anyone can DM the bot!") + + print() + print_info("📬 Home Channel: OpenID for cron job delivery and notifications.") + 
home_channel = prompt("Home channel OpenID (leave empty to set later)")
+    if home_channel:
+        save_env_value("QQ_HOME_CHANNEL", home_channel)
+
+    print()
+    print_success("QQ Bot configured!")
+
+
 def _setup_bluebubbles():
     """Configure BlueBubbles iMessage gateway."""
     print_header("BlueBubbles (iMessage)")
@@ -2034,6 +2082,15 @@ def _setup_bluebubbles():
     print_info("   Install: https://docs.bluebubbles.app/helper-bundle/installation")
 
 
+def _setup_qqbot():
+    """NOTE(review): duplicate — `_setup_qqbot` is defined twice in this module; this later stub shadows the interactive version added above, making it dead code. Keep exactly one. Configures QQ Bot via standard platform setup."""
+    from hermes_cli.gateway import _PLATFORMS
+    qq_platform = next((p for p in _PLATFORMS if p["key"] == "qqbot"), None)
+    if qq_platform:
+        from hermes_cli.gateway import _setup_standard_platform
+        _setup_standard_platform(qq_platform)
+
+
 def _setup_webhooks():
     """Configure webhook integration."""
     print_header("Webhooks")
@@ -2097,6 +2154,7 @@ _GATEWAY_PLATFORMS = [
     ("WeCom Callback (Self-Built App)", "WECOM_CALLBACK_CORP_ID", _setup_wecom_callback),
     ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),
     ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles),
+    ("QQ Bot", "QQ_APP_ID", _setup_qqbot),
     ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks),
 ]
@@ -2148,6 +2206,7 @@ def setup_gateway(config: dict):
         or get_env_value("WECOM_BOT_ID")
         or get_env_value("WEIXIN_ACCOUNT_ID")
         or get_env_value("BLUEBUBBLES_SERVER_URL")
+        or get_env_value("QQ_APP_ID")
         or get_env_value("WEBHOOK_ENABLED")
     )
     if any_messaging:
@@ -2169,6 +2228,8 @@
             missing_home.append("Slack")
         if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"):
             missing_home.append("BlueBubbles")
+        if get_env_value("QQ_APP_ID") and not get_env_value("QQ_HOME_CHANNEL"):
+            missing_home.append("QQBot")
 
     if missing_home:
         print()
diff --git a/hermes_cli/skin_engine.py b/hermes_cli/skin_engine.py
index 5fad176b0..b992ada06 100644
--- a/hermes_cli/skin_engine.py
+++ b/hermes_cli/skin_engine.py
@@ 
-32,6 +32,12 @@ All fields are optional. Missing values inherit from the ``default`` skin. response_border: "#FFD700" # Response box border (ANSI) session_label: "#DAA520" # Session label color session_border: "#8B8682" # Session ID dim color + status_bar_bg: "#1a1a2e" # TUI status/usage bar background + voice_status_bg: "#1a1a2e" # TUI voice status background + completion_menu_bg: "#1a1a2e" # Completion menu background + completion_menu_current_bg: "#333355" # Active completion row background + completion_menu_meta_bg: "#1a1a2e" # Completion meta column background + completion_menu_meta_current_bg: "#333355" # Active completion meta background # Spinner: customize the animated spinner during API calls spinner: @@ -87,6 +93,8 @@ BUILT-IN SKINS - ``ares`` — Crimson/bronze war-god theme with custom spinner wings - ``mono`` — Clean grayscale monochrome - ``slate`` — Cool blue developer-focused theme +- ``daylight`` — Light background theme with dark text and blue accents +- ``warm-lightmode`` — Warm brown/gold text for light terminal backgrounds USER SKINS ========== @@ -304,6 +312,80 @@ _BUILTIN_SKINS: Dict[str, Dict[str, Any]] = { }, "tool_prefix": "┊", }, + "daylight": { + "name": "daylight", + "description": "Light theme for bright terminals with dark text and cool blue accents", + "colors": { + "banner_border": "#2563EB", + "banner_title": "#0F172A", + "banner_accent": "#1D4ED8", + "banner_dim": "#475569", + "banner_text": "#111827", + "ui_accent": "#2563EB", + "ui_label": "#0F766E", + "ui_ok": "#15803D", + "ui_error": "#B91C1C", + "ui_warn": "#B45309", + "prompt": "#111827", + "input_rule": "#93C5FD", + "response_border": "#2563EB", + "session_label": "#1D4ED8", + "session_border": "#64748B", + "status_bar_bg": "#E5EDF8", + "voice_status_bg": "#E5EDF8", + "completion_menu_bg": "#F8FAFC", + "completion_menu_current_bg": "#DBEAFE", + "completion_menu_meta_bg": "#EEF2FF", + "completion_menu_meta_current_bg": "#BFDBFE", + }, + "spinner": {}, + "branding": { + 
"agent_name": "Hermes Agent", + "welcome": "Welcome to Hermes Agent! Type your message or /help for commands.", + "goodbye": "Goodbye! ⚕", + "response_label": " ⚕ Hermes ", + "prompt_symbol": "❯ ", + "help_header": "[?] Available Commands", + }, + "tool_prefix": "│", + }, + "warm-lightmode": { + "name": "warm-lightmode", + "description": "Warm light mode — dark brown/gold text for light terminal backgrounds", + "colors": { + "banner_border": "#8B6914", + "banner_title": "#5C3D11", + "banner_accent": "#8B4513", + "banner_dim": "#8B7355", + "banner_text": "#2C1810", + "ui_accent": "#8B4513", + "ui_label": "#5C3D11", + "ui_ok": "#2E7D32", + "ui_error": "#C62828", + "ui_warn": "#E65100", + "prompt": "#2C1810", + "input_rule": "#8B6914", + "response_border": "#8B6914", + "session_label": "#5C3D11", + "session_border": "#A0845C", + "status_bar_bg": "#F5F0E8", + "voice_status_bg": "#F5F0E8", + "completion_menu_bg": "#F5EFE0", + "completion_menu_current_bg": "#E8DCC8", + "completion_menu_meta_bg": "#F0E8D8", + "completion_menu_meta_current_bg": "#DFCFB0", + }, + "spinner": {}, + "branding": { + "agent_name": "Hermes Agent", + "welcome": "Welcome to Hermes Agent! Type your message or /help for commands.", + "goodbye": "Goodbye! \u2695", + "response_label": " \u2695 Hermes ", + "prompt_symbol": "\u276f ", + "help_header": "(^_^)? 
Available Commands", + }, + "tool_prefix": "\u250a", + }, "poseidon": { "name": "poseidon", "description": "Ocean-god theme — deep blue and seafoam", @@ -685,6 +767,12 @@ def get_prompt_toolkit_style_overrides() -> Dict[str, str]: label = skin.get_color("ui_label", title) warn = skin.get_color("ui_warn", "#FF8C00") error = skin.get_color("ui_error", "#FF6B6B") + status_bg = skin.get_color("status_bar_bg", "#1a1a2e") + voice_bg = skin.get_color("voice_status_bg", status_bg) + menu_bg = skin.get_color("completion_menu_bg", "#1a1a2e") + menu_current_bg = skin.get_color("completion_menu_current_bg", "#333355") + menu_meta_bg = skin.get_color("completion_menu_meta_bg", menu_bg) + menu_meta_current_bg = skin.get_color("completion_menu_meta_current_bg", menu_current_bg) return { "input-area": prompt, @@ -692,13 +780,20 @@ def get_prompt_toolkit_style_overrides() -> Dict[str, str]: "prompt": prompt, "prompt-working": f"{dim} italic", "hint": f"{dim} italic", + "status-bar": f"bg:{status_bg} {text}", + "status-bar-strong": f"bg:{status_bg} {title} bold", + "status-bar-dim": f"bg:{status_bg} {dim}", + "status-bar-good": f"bg:{status_bg} {skin.get_color('ui_ok', '#8FBC8F')} bold", + "status-bar-warn": f"bg:{status_bg} {warn} bold", + "status-bar-bad": f"bg:{status_bg} {skin.get_color('banner_accent', warn)} bold", + "status-bar-critical": f"bg:{status_bg} {error} bold", "input-rule": input_rule, "image-badge": f"{label} bold", - "completion-menu": f"bg:#1a1a2e {text}", - "completion-menu.completion": f"bg:#1a1a2e {text}", - "completion-menu.completion.current": f"bg:#333355 {title}", - "completion-menu.meta.completion": f"bg:#1a1a2e {dim}", - "completion-menu.meta.completion.current": f"bg:#333355 {label}", + "completion-menu": f"bg:{menu_bg} {text}", + "completion-menu.completion": f"bg:{menu_bg} {text}", + "completion-menu.completion.current": f"bg:{menu_current_bg} {title}", + "completion-menu.meta.completion": f"bg:{menu_meta_bg} {dim}", + 
"completion-menu.meta.completion.current": f"bg:{menu_meta_current_bg} {label}", "clarify-border": input_rule, "clarify-title": f"{title} bold", "clarify-question": f"{text} bold", @@ -716,4 +811,6 @@ def get_prompt_toolkit_style_overrides() -> Dict[str, str]: "approval-cmd": f"{dim} italic", "approval-choice": dim, "approval-selected": f"{title} bold", + "voice-status": f"bg:{voice_bg} {label}", + "voice-status-recording": f"bg:{voice_bg} {error} bold", } diff --git a/hermes_cli/status.py b/hermes_cli/status.py index a7745d65f..5ec93f24d 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -305,6 +305,7 @@ def show_status(args): "WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None), "Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"), "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"), + "QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 343007cab..d74f7ea72 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -426,6 +426,8 @@ def _get_enabled_platforms() -> List[str]: enabled.append("slack") if get_env_value("WHATSAPP_ENABLED"): enabled.append("whatsapp") + if get_env_value("QQ_APP_ID"): + enabled.append("qqbot") return enabled diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 89d60a299..f73104ce8 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -96,6 +96,11 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = { "description": "Default model (e.g. 
anthropic/claude-sonnet-4.6)", "category": "general", }, + "model_context_length": { + "type": "number", + "description": "Context window override (0 = auto-detect from model metadata)", + "category": "general", + }, "terminal.backend": { "type": "select", "description": "Terminal execution backend", @@ -246,6 +251,17 @@ def _build_schema_from_config( CONFIG_SCHEMA = _build_schema_from_config(DEFAULT_CONFIG) +# Inject virtual fields that don't live in DEFAULT_CONFIG but are surfaced +# by the normalize/denormalize cycle. Insert model_context_length right after +# the "model" key so it renders adjacent in the frontend. +_mcl_entry = _SCHEMA_OVERRIDES["model_context_length"] +_ordered_schema: Dict[str, Dict[str, Any]] = {} +for _k, _v in CONFIG_SCHEMA.items(): + _ordered_schema[_k] = _v + if _k == "model": + _ordered_schema["model_context_length"] = _mcl_entry +CONFIG_SCHEMA = _ordered_schema + class ConfigUpdate(BaseModel): config: dict @@ -408,11 +424,19 @@ def _normalize_config_for_web(config: Dict[str, Any]) -> Dict[str, Any]: or a dict (``{default: ..., provider: ..., base_url: ...}``). The schema is built from DEFAULT_CONFIG where ``model`` is a string, but user configs often have the dict form. Normalize to the string form so the frontend schema matches. + + Also surfaces ``model_context_length`` as a top-level field so the web UI can + display and edit it. A value of 0 means "auto-detect". 
""" config = dict(config) # shallow copy model_val = config.get("model") if isinstance(model_val, dict): + # Extract context_length before flattening the dict + ctx_len = model_val.get("context_length", 0) config["model"] = model_val.get("default", model_val.get("name", "")) + config["model_context_length"] = ctx_len if isinstance(ctx_len, int) else 0 + else: + config["model_context_length"] = 0 return config @@ -433,6 +457,93 @@ async def get_schema(): return {"fields": CONFIG_SCHEMA, "category_order": _CATEGORY_ORDER} +_EMPTY_MODEL_INFO: dict = { + "model": "", + "provider": "", + "auto_context_length": 0, + "config_context_length": 0, + "effective_context_length": 0, + "capabilities": {}, +} + + +@app.get("/api/model/info") +def get_model_info(): + """Return resolved model metadata for the currently configured model. + + Calls the same context-length resolution chain the agent uses, so the + frontend can display "Auto-detected: 200K" alongside the override field. + Also returns model capabilities (vision, reasoning, tools) when available. 
+ """ + try: + cfg = load_config() + model_cfg = cfg.get("model", "") + + # Extract model name and provider from the config + if isinstance(model_cfg, dict): + model_name = model_cfg.get("default", model_cfg.get("name", "")) + provider = model_cfg.get("provider", "") + base_url = model_cfg.get("base_url", "") + config_ctx = model_cfg.get("context_length") + else: + model_name = str(model_cfg) if model_cfg else "" + provider = "" + base_url = "" + config_ctx = None + + if not model_name: + return dict(_EMPTY_MODEL_INFO, provider=provider) + + # Resolve auto-detected context length (pass config_ctx=None to get + # purely auto-detected value, then separately report the override) + try: + from agent.model_metadata import get_model_context_length + auto_ctx = get_model_context_length( + model=model_name, + base_url=base_url, + provider=provider, + config_context_length=None, # ignore override — we want auto value + ) + except Exception: + auto_ctx = 0 + + config_ctx_int = 0 + if isinstance(config_ctx, int) and config_ctx > 0: + config_ctx_int = config_ctx + + # Effective is what the agent actually uses + effective_ctx = config_ctx_int if config_ctx_int > 0 else auto_ctx + + # Try to get model capabilities from models.dev + caps = {} + try: + from agent.models_dev import get_model_capabilities + mc = get_model_capabilities(provider=provider, model=model_name) + if mc is not None: + caps = { + "supports_tools": mc.supports_tools, + "supports_vision": mc.supports_vision, + "supports_reasoning": mc.supports_reasoning, + "context_window": mc.context_window, + "max_output_tokens": mc.max_output_tokens, + "model_family": mc.model_family, + } + except Exception: + pass + + return { + "model": model_name, + "provider": provider, + "auto_context_length": auto_ctx, + "config_context_length": config_ctx_int, + "effective_context_length": effective_ctx, + "capabilities": caps, + } + except Exception: + _log.exception("GET /api/model/info failed") + return dict(_EMPTY_MODEL_INFO) + + 
def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: """Reverse _normalize_config_for_web before saving. @@ -440,12 +551,24 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: to recover model subkeys (provider, base_url, api_mode, etc.) that were stripped from the GET response. The frontend only sees model as a flat string; the rest is preserved transparently. + + Also handles ``model_context_length`` — writes it back into the model dict + as ``context_length``. A value of 0 or absent means "auto-detect" (omitted + from the dict so get_model_context_length() uses its normal resolution). """ config = dict(config) # Remove any _model_meta that might have leaked in (shouldn't happen # with the stripped GET response, but be defensive) config.pop("_model_meta", None) + # Extract and remove model_context_length before processing model + ctx_override = config.pop("model_context_length", 0) + if not isinstance(ctx_override, int): + try: + ctx_override = int(ctx_override) + except (TypeError, ValueError): + ctx_override = 0 + model_val = config.get("model") if isinstance(model_val, str) and model_val: # Read the current disk config to recover model subkeys @@ -455,7 +578,20 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: if isinstance(disk_model, dict): # Preserve all subkeys, update default with the new value disk_model["default"] = model_val + # Write context_length into the model dict (0 = remove/auto) + if ctx_override > 0: + disk_model["context_length"] = ctx_override + else: + disk_model.pop("context_length", None) config["model"] = disk_model + else: + # Model was previously a bare string — upgrade to dict if + # user is setting a context_length override + if ctx_override > 0: + config["model"] = { + "default": model_val, + "context_length": ctx_override, + } except Exception: pass # can't read disk config — just use the string form return config diff --git a/hermes_logging.py 
b/hermes_logging.py index 6d611ba7c..dbef21328 100644 --- a/hermes_logging.py +++ b/hermes_logging.py @@ -78,6 +78,10 @@ def set_session_context(session_id: str) -> None: _session_context.session_id = session_id +def clear_session_context() -> None: + """Clear the session ID for the current thread.""" + _session_context.session_id = None + # --------------------------------------------------------------------------- # Record factory — injects session_tag into every LogRecord at creation diff --git a/model_tools.py b/model_tools.py index c37007c41..1924b2516 100644 --- a/model_tools.py +++ b/model_tools.py @@ -464,6 +464,7 @@ def handle_function_call( session_id: Optional[str] = None, user_task: Optional[str] = None, enabled_tools: Optional[List[str]] = None, + skip_pre_tool_call_hook: bool = False, ) -> str: """ Main function call dispatcher that routes calls to the tool registry. @@ -484,31 +485,53 @@ def handle_function_call( # Coerce string arguments to their schema-declared types (e.g. "42"→42) function_args = coerce_tool_args(function_name, function_args) - # Notify the read-loop tracker when a non-read/search tool runs, - # so the *consecutive* counter resets (reads after other work are fine). - if function_name not in _READ_SEARCH_TOOLS: - try: - from tools.file_tools import notify_other_tool_call - notify_other_tool_call(task_id or "default") - except Exception: - pass # file_tools may not be loaded yet - try: if function_name in _AGENT_LOOP_TOOLS: return json.dumps({"error": f"{function_name} must be handled by the agent loop"}) - try: - from hermes_cli.plugins import invoke_hook - invoke_hook( - "pre_tool_call", - tool_name=function_name, - args=function_args, - task_id=task_id or "", - session_id=session_id or "", - tool_call_id=tool_call_id or "", - ) - except Exception: - pass + # Check plugin hooks for a block directive (unless caller already + # checked — e.g. run_agent._invoke_tool passes skip=True to + # avoid double-firing the hook). 
+ if not skip_pre_tool_call_hook: + block_message: Optional[str] = None + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + block_message = get_pre_tool_call_block_message( + function_name, + function_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + ) + except Exception: + pass + + if block_message is not None: + return json.dumps({"error": block_message}, ensure_ascii=False) + else: + # Still fire the hook for observers — just don't check for blocking + # (the caller already did that). + try: + from hermes_cli.plugins import invoke_hook + invoke_hook( + "pre_tool_call", + tool_name=function_name, + args=function_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + ) + except Exception: + pass + + # Notify the read-loop tracker when a non-read/search tool runs, + # so the *consecutive* counter resets (reads after other work are fine). + if function_name not in _READ_SEARCH_TOOLS: + try: + from tools.file_tools import notify_other_tool_call + notify_other_tool_call(task_id or "default") + except Exception: + pass # file_tools may not be loaded yet if function_name == "execute_code": # Prefer the caller-provided list so subagents can't overwrite diff --git a/optional-skills/health/fitness-nutrition/SKILL.md b/optional-skills/health/fitness-nutrition/SKILL.md new file mode 100644 index 000000000..672f0ccd0 --- /dev/null +++ b/optional-skills/health/fitness-nutrition/SKILL.md @@ -0,0 +1,255 @@ +--- +name: fitness-nutrition +description: > + Gym workout planner and nutrition tracker. Search 690+ exercises by muscle, + equipment, or category via wger. Look up macros and calories for 380,000+ + foods via USDA FoodData Central. Compute BMI, TDEE, one-rep max, macro + splits, and body fat — pure Python, no pip installs. Built for anyone + chasing gains, cutting weight, or just trying to eat better. 
+version: 1.0.0 +authors: + - haileymarshall +license: MIT +metadata: + hermes: + tags: [health, fitness, nutrition, gym, workout, diet, exercise] + category: health + prerequisites: + commands: [curl, python3] +required_environment_variables: + - name: USDA_API_KEY + prompt: "USDA FoodData Central API key (free)" + help: "Get one free at https://fdc.nal.usda.gov/api-key-signup/ — or skip to use DEMO_KEY with lower rate limits" + required_for: "higher rate limits on food/nutrition lookups (DEMO_KEY works without signup)" + optional: true +--- + +# Fitness & Nutrition + +Expert fitness coach and sports nutritionist skill. Two data sources +plus offline calculators — everything a gym-goer needs in one place. + +**Data sources (all free, no pip dependencies):** + +- **wger** (https://wger.de/api/v2/) — open exercise database, 690+ exercises with muscles, equipment, images. Public endpoints need zero authentication. +- **USDA FoodData Central** (https://api.nal.usda.gov/fdc/v1/) — US government nutrition database, 380,000+ foods. `DEMO_KEY` works instantly; free signup for higher limits. + +**Offline calculators (pure stdlib Python):** + +- BMI, TDEE (Mifflin-St Jeor), one-rep max (Epley/Brzycki/Lombardi), macro splits, body fat % (US Navy method) + +--- + +## When to Use + +Trigger this skill when the user asks about: +- Exercises, workouts, gym routines, muscle groups, workout splits +- Food macros, calories, protein content, meal planning, calorie counting +- Body composition: BMI, body fat, TDEE, caloric surplus/deficit +- One-rep max estimates, training percentages, progressive overload +- Macro ratios for cutting, bulking, or maintenance + +--- + +## Procedure + +### Exercise Lookup (wger API) + +All wger public endpoints return JSON and require no auth. Always add +`format=json` and `language=2` (English) to exercise queries. 
+ +**Step 1 — Identify what the user wants:** + +- By muscle → use `/api/v2/exercise/?muscles={id}&language=2&status=2&format=json` +- By category → use `/api/v2/exercise/?category={id}&language=2&status=2&format=json` +- By equipment → use `/api/v2/exercise/?equipment={id}&language=2&status=2&format=json` +- By name → use `/api/v2/exercise/search/?term={query}&language=english&format=json` +- Full details → use `/api/v2/exerciseinfo/{exercise_id}/?format=json` + +**Step 2 — Reference IDs (so you don't need extra API calls):** + +Exercise categories: + +| ID | Category | +|----|-------------| +| 8 | Arms | +| 9 | Legs | +| 10 | Abs | +| 11 | Chest | +| 12 | Back | +| 13 | Shoulders | +| 14 | Calves | +| 15 | Cardio | + +Muscles: + +| ID | Muscle | ID | Muscle | +|----|---------------------------|----|-------------------------| +| 1 | Biceps brachii | 2 | Anterior deltoid | +| 3 | Serratus anterior | 4 | Pectoralis major | +| 5 | Obliquus externus | 6 | Gastrocnemius | +| 7 | Rectus abdominis | 8 | Gluteus maximus | +| 9 | Trapezius | 10 | Quadriceps femoris | +| 11 | Biceps femoris | 12 | Latissimus dorsi | +| 13 | Brachialis | 14 | Triceps brachii | +| 15 | Soleus | | | + +Equipment: + +| ID | Equipment | +|----|----------------| +| 1 | Barbell | +| 3 | Dumbbell | +| 4 | Gym mat | +| 5 | Swiss Ball | +| 6 | Pull-up bar | +| 7 | none (bodyweight) | +| 8 | Bench | +| 9 | Incline bench | +| 10 | Kettlebell | + +**Step 3 — Fetch and present results:** + +```bash +# Search exercises by name +QUERY="$1" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$QUERY") +curl -s "https://wger.de/api/v2/exercise/search/?term=${ENCODED}&language=english&format=json" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +for s in data.get('suggestions',[])[:10]: + d=s.get('data',{}) + print(f\" ID {d.get('id','?'):>4} | {d.get('name','N/A'):<35} | Category: {d.get('category','N/A')}\") +" +``` + +```bash +# Get full details for a 
specific exercise +EXERCISE_ID="$1" +curl -s "https://wger.de/api/v2/exerciseinfo/${EXERCISE_ID}/?format=json" \ + | python3 -c " +import json,sys,html,re +data=json.load(sys.stdin) +trans=[t for t in data.get('translations',[]) if t.get('language')==2] +t=trans[0] if trans else data.get('translations',[{}])[0] +desc=re.sub('<[^>]+>','',html.unescape(t.get('description','N/A'))) +print(f\"Exercise : {t.get('name','N/A')}\") +print(f\"Category : {data.get('category',{}).get('name','N/A')}\") +print(f\"Primary : {', '.join(m.get('name_en','') for m in data.get('muscles',[])) or 'N/A'}\") +print(f\"Secondary : {', '.join(m.get('name_en','') for m in data.get('muscles_secondary',[])) or 'none'}\") +print(f\"Equipment : {', '.join(e.get('name','') for e in data.get('equipment',[])) or 'bodyweight'}\") +print(f\"How to : {desc[:500]}\") +imgs=data.get('images',[]) +if imgs: print(f\"Image : {imgs[0].get('image','')}\") +" +``` + +```bash +# List exercises filtering by muscle, category, or equipment +# Combine filters as needed: ?muscles=4&equipment=1&language=2&status=2 +FILTER="$1" # e.g. "muscles=4" or "category=11" or "equipment=3" +curl -s "https://wger.de/api/v2/exercise/?${FILTER}&language=2&status=2&limit=20&format=json" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +print(f'Found {data.get(\"count\",0)} exercises.') +for ex in data.get('results',[]): + print(f\" ID {ex['id']:>4} | muscles: {ex.get('muscles',[])} | equipment: {ex.get('equipment',[])}\") +" +``` + +### Nutrition Lookup (USDA FoodData Central) + +Uses `USDA_API_KEY` env var if set, otherwise falls back to `DEMO_KEY`. +DEMO_KEY = 30 requests/hour. Free signup key = 1,000 requests/hour. 
+ +```bash +# Search foods by name +FOOD="$1" +API_KEY="${USDA_API_KEY:-DEMO_KEY}" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$FOOD") +curl -s "https://api.nal.usda.gov/fdc/v1/foods/search?api_key=${API_KEY}&query=${ENCODED}&pageSize=5&dataType=Foundation,SR%20Legacy" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +foods=data.get('foods',[]) +if not foods: print('No foods found.'); sys.exit() +for f in foods: + n={x['nutrientName']:x.get('value','?') for x in f.get('foodNutrients',[])} + cal=n.get('Energy','?'); prot=n.get('Protein','?') + fat=n.get('Total lipid (fat)','?'); carb=n.get('Carbohydrate, by difference','?') + print(f\"{f.get('description','N/A')}\") + print(f\" Per 100g: {cal} kcal | {prot}g protein | {fat}g fat | {carb}g carbs\") + print(f\" FDC ID: {f.get('fdcId','N/A')}\") + print() +" +``` + +```bash +# Detailed nutrient profile by FDC ID +FDC_ID="$1" +API_KEY="${USDA_API_KEY:-DEMO_KEY}" +curl -s "https://api.nal.usda.gov/fdc/v1/food/${FDC_ID}?api_key=${API_KEY}" \ + | python3 -c " +import json,sys +d=json.load(sys.stdin) +print(f\"Food: {d.get('description','N/A')}\") +print(f\"{'Nutrient':<40} {'Amount':>8} {'Unit'}\") +print('-'*56) +for x in sorted(d.get('foodNutrients',[]),key=lambda x:x.get('nutrient',{}).get('rank',9999)): + nut=x.get('nutrient',{}); amt=x.get('amount',0) + if amt and float(amt)>0: + print(f\" {nut.get('name',''):<38} {amt:>8} {nut.get('unitName','')}\") +" +``` + +### Offline Calculators + +Use the helper scripts in `scripts/` for batch operations, +or run inline for single calculations: + +- `python3 scripts/body_calc.py bmi ` +- `python3 scripts/body_calc.py tdee ` +- `python3 scripts/body_calc.py 1rm ` +- `python3 scripts/body_calc.py macros ` +- `python3 scripts/body_calc.py bodyfat [hip_cm] ` + +See `references/FORMULAS.md` for the science behind each formula. 
+ +--- + +## Pitfalls + +- wger exercise endpoint returns **all languages by default** — always add `language=2` for English +- wger includes **unverified user submissions** — add `status=2` to only get approved exercises +- USDA `DEMO_KEY` has **30 req/hour** — add `sleep 2` between batch requests or get a free key +- USDA data is **per 100g** — remind users to scale to their actual portion size +- BMI does not distinguish muscle from fat — high BMI in muscular people is not necessarily unhealthy +- Body fat formulas are **estimates** (±3-5%) — recommend DEXA scans for precision +- 1RM formulas lose accuracy above 10 reps — use sets of 3-5 for best estimates +- wger's `exercise/search` endpoint uses `term` not `query` as the parameter name + +--- + +## Verification + +After running exercise search: confirm results include exercise names, muscle groups, and equipment. +After nutrition lookup: confirm per-100g macros are returned with kcal, protein, fat, carbs. +After calculators: sanity-check outputs (e.g. TDEE should be 1500-3500 for most adults). 
+ +--- + +## Quick Reference + +| Task | Source | Endpoint | +|------|--------|----------| +| Search exercises by name | wger | `GET /api/v2/exercise/search/?term=&language=english` | +| Exercise details | wger | `GET /api/v2/exerciseinfo/{id}/` | +| Filter by muscle | wger | `GET /api/v2/exercise/?muscles={id}&language=2&status=2` | +| Filter by equipment | wger | `GET /api/v2/exercise/?equipment={id}&language=2&status=2` | +| List categories | wger | `GET /api/v2/exercisecategory/` | +| List muscles | wger | `GET /api/v2/muscle/` | +| Search foods | USDA | `GET /fdc/v1/foods/search?query=&dataType=Foundation,SR Legacy` | +| Food details | USDA | `GET /fdc/v1/food/{fdcId}` | +| BMI / TDEE / 1RM / macros | offline | `python3 scripts/body_calc.py` | \ No newline at end of file diff --git a/optional-skills/health/fitness-nutrition/references/FORMULAS.md b/optional-skills/health/fitness-nutrition/references/FORMULAS.md new file mode 100644 index 000000000..763c0b3a1 --- /dev/null +++ b/optional-skills/health/fitness-nutrition/references/FORMULAS.md @@ -0,0 +1,100 @@ +# Formulas Reference + +Scientific references for all calculators used in the fitness-nutrition skill. + +## BMI (Body Mass Index) + +**Formula:** BMI = weight (kg) / height (m)² + +| Category | BMI Range | +|-------------|------------| +| Underweight | < 18.5 | +| Normal | 18.5 – 24.9 | +| Overweight | 25.0 – 29.9 | +| Obese | 30.0+ | + +**Limitation:** BMI does not distinguish muscle from fat. A muscular person +can have a high BMI while being lean. Use body fat % for a better picture. + +Reference: Quetelet, A. (1832). Keys et al., Int J Obes (1972). + +## TDEE (Total Daily Energy Expenditure) + +Uses the **Mifflin-St Jeor equation** — the most accurate BMR predictor for +the general population according to the ADA (2005). 
+ +**BMR formulas:** + +- Male: BMR = 10 × weight(kg) + 6.25 × height(cm) − 5 × age + 5 +- Female: BMR = 10 × weight(kg) + 6.25 × height(cm) − 5 × age − 161 + +**Activity multipliers:** + +| Level | Description | Multiplier | +|-------|--------------------------------|------------| +| 1 | Sedentary (desk job) | 1.200 | +| 2 | Lightly active (1-3 days/wk) | 1.375 | +| 3 | Moderately active (3-5 days) | 1.550 | +| 4 | Very active (6-7 days) | 1.725 | +| 5 | Extremely active (2x/day) | 1.900 | + +Reference: Mifflin et al., Am J Clin Nutr 51, 241-247 (1990). + +## One-Rep Max (1RM) + +Three validated formulas. Average of all three is most reliable. + +- **Epley:** 1RM = w × (1 + r/30) +- **Brzycki:** 1RM = w × 36 / (37 − r) +- **Lombardi:** 1RM = w × r^0.1 + +All formulas are most accurate for r ≤ 10. Above 10 reps, error increases. + +Reference: LeSuer et al., J Strength Cond Res 11(4), 211-213 (1997). + +## Macro Splits + +Recommended splits based on goal: + +| Goal | Protein | Fat | Carbs | Calorie Offset | +|-------------|---------|------|-------|----------------| +| Fat loss | 40% | 30% | 30% | −500 kcal | +| Maintenance | 30% | 30% | 40% | 0 | +| Lean bulk | 30% | 25% | 45% | +400 kcal | + +Protein targets for muscle growth: 1.6–2.2 g/kg body weight per day. +Minimum fat intake: 0.5 g/kg to support hormone production. + +Conversion: Protein = 4 kcal/g, Fat = 9 kcal/g, Carbs = 4 kcal/g. + +Reference: Morton et al., Br J Sports Med 52, 376–384 (2018). + +## Body Fat % (US Navy Method) + +**Male:** + +BF% = 86.010 × log₁₀(waist − neck) − 70.041 × log₁₀(height) + 36.76 + +**Female:** + +BF% = 163.205 × log₁₀(waist + hip − neck) − 97.684 × log₁₀(height) − 78.387 + +All measurements in centimeters. + +| Category | Male | Female | +|--------------|--------|--------| +| Essential | 2-5% | 10-13% | +| Athletic | 6-13% | 14-20% | +| Fitness | 14-17% | 21-24% | +| Average | 18-24% | 25-31% | +| Obese | 25%+ | 32%+ | + +Accuracy: ±3-5% compared to DEXA. 
Measure at the navel (waist), +at the Adam's apple (neck), and widest point (hip, females only). + +Reference: Hodgdon & Beckett, Naval Health Research Center (1984). + +## APIs + +- wger: https://wger.de/api/v2/ — AGPL-3.0, exercise data is CC-BY-SA 3.0 +- USDA FoodData Central: https://api.nal.usda.gov/fdc/v1/ — public domain (CC0 1.0) \ No newline at end of file diff --git a/optional-skills/health/fitness-nutrition/scripts/body_calc.py b/optional-skills/health/fitness-nutrition/scripts/body_calc.py new file mode 100644 index 000000000..2d07129ce --- /dev/null +++ b/optional-skills/health/fitness-nutrition/scripts/body_calc.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" +body_calc.py — All-in-one fitness calculator. + +Subcommands: + bmi + tdee + 1rm + macros + bodyfat [hip_cm] + +No external dependencies — stdlib only. +""" +import sys +import math + + +def bmi(weight_kg, height_cm): + h = height_cm / 100 + val = weight_kg / (h * h) + if val < 18.5: + cat = "Underweight" + elif val < 25: + cat = "Normal weight" + elif val < 30: + cat = "Overweight" + else: + cat = "Obese" + print(f"BMI: {val:.1f} — {cat}") + print() + print("Ranges:") + print(f" Underweight : < 18.5") + print(f" Normal : 18.5 – 24.9") + print(f" Overweight : 25.0 – 29.9") + print(f" Obese : 30.0+") + + +def tdee(weight_kg, height_cm, age, sex, activity): + if sex.upper() == "M": + bmr = 10 * weight_kg + 6.25 * height_cm - 5 * age + 5 + else: + bmr = 10 * weight_kg + 6.25 * height_cm - 5 * age - 161 + + multipliers = { + 1: ("Sedentary (desk job, no exercise)", 1.2), + 2: ("Lightly active (1-3 days/week)", 1.375), + 3: ("Moderately active (3-5 days/week)", 1.55), + 4: ("Very active (6-7 days/week)", 1.725), + 5: ("Extremely active (athlete + physical job)", 1.9), + } + + label, mult = multipliers.get(activity, ("Moderate", 1.55)) + total = bmr * mult + + print(f"BMR (Mifflin-St Jeor): {bmr:.0f} kcal/day") + print(f"Activity: {label} (x{mult})") + print(f"TDEE: {total:.0f} kcal/day") + print() 
+ print("Calorie targets:") + print(f" Aggressive cut (-750): {total - 750:.0f} kcal/day") + print(f" Fat loss (-500): {total - 500:.0f} kcal/day") + print(f" Mild cut (-250): {total - 250:.0f} kcal/day") + print(f" Maintenance : {total:.0f} kcal/day") + print(f" Lean bulk (+250): {total + 250:.0f} kcal/day") + print(f" Bulk (+500): {total + 500:.0f} kcal/day") + + +def one_rep_max(weight, reps): + if reps < 1: + print("Error: reps must be at least 1.") + sys.exit(1) + if reps == 1: + print(f"1RM = {weight:.1f} (actual single)") + return + + epley = weight * (1 + reps / 30) + brzycki = weight * (36 / (37 - reps)) if reps < 37 else 0 + lombardi = weight * (reps ** 0.1) + avg = (epley + brzycki + lombardi) / 3 + + print(f"Estimated 1RM ({weight} x {reps} reps):") + print(f" Epley : {epley:.1f}") + print(f" Brzycki : {brzycki:.1f}") + print(f" Lombardi : {lombardi:.1f}") + print(f" Average : {avg:.1f}") + print() + print("Training percentages off average 1RM:") + for pct, rep_range in [ + (100, "1"), (95, "1-2"), (90, "3-4"), (85, "4-6"), + (80, "6-8"), (75, "8-10"), (70, "10-12"), + (65, "12-15"), (60, "15-20"), + ]: + print(f" {pct:>3}% = {avg * pct / 100:>7.1f} (~{rep_range} reps)") + + +def macros(tdee_kcal, goal): + goal = goal.lower() + if goal in ("cut", "lose", "deficit"): + cals = tdee_kcal - 500 + p, f, c = 0.40, 0.30, 0.30 + label = "Fat Loss (-500 kcal)" + elif goal in ("bulk", "gain", "surplus"): + cals = tdee_kcal + 400 + p, f, c = 0.30, 0.25, 0.45 + label = "Lean Bulk (+400 kcal)" + else: + cals = tdee_kcal + p, f, c = 0.30, 0.30, 0.40 + label = "Maintenance" + + prot_g = cals * p / 4 + fat_g = cals * f / 9 + carb_g = cals * c / 4 + + print(f"Goal: {label}") + print(f"Daily calories: {cals:.0f} kcal") + print() + print(f" Protein : {prot_g:>6.0f}g ({p * 100:.0f}%) = {prot_g * 4:.0f} kcal") + print(f" Fat : {fat_g:>6.0f}g ({f * 100:.0f}%) = {fat_g * 9:.0f} kcal") + print(f" Carbs : {carb_g:>6.0f}g ({c * 100:.0f}%) = {carb_g * 4:.0f} kcal") + print() + 
print(f"Per meal (3 meals): P {prot_g / 3:.0f}g | F {fat_g / 3:.0f}g | C {carb_g / 3:.0f}g") + print(f"Per meal (4 meals): P {prot_g / 4:.0f}g | F {fat_g / 4:.0f}g | C {carb_g / 4:.0f}g") + + +def bodyfat(sex, neck_cm, waist_cm, hip_cm, height_cm): + sex = sex.upper() + if sex == "M": + if waist_cm <= neck_cm: + print("Error: waist must be larger than neck."); sys.exit(1) + bf = 86.010 * math.log10(waist_cm - neck_cm) - 70.041 * math.log10(height_cm) + 36.76 + else: + if (waist_cm + hip_cm) <= neck_cm: + print("Error: waist + hip must be larger than neck."); sys.exit(1) + bf = 163.205 * math.log10(waist_cm + hip_cm - neck_cm) - 97.684 * math.log10(height_cm) - 78.387 + + print(f"Estimated body fat: {bf:.1f}%") + + if sex == "M": + ranges = [ + (6, "Essential fat (2-5%)"), + (14, "Athletic (6-13%)"), + (18, "Fitness (14-17%)"), + (25, "Average (18-24%)"), + ] + default = "Obese (25%+)" + else: + ranges = [ + (14, "Essential fat (10-13%)"), + (21, "Athletic (14-20%)"), + (25, "Fitness (21-24%)"), + (32, "Average (25-31%)"), + ] + default = "Obese (32%+)" + + cat = default + for threshold, label in ranges: + if bf < threshold: + cat = label + break + + print(f"Category: {cat}") + print(f"Method: US Navy circumference formula") + + +def usage(): + print(__doc__) + sys.exit(1) + + +def main(): + if len(sys.argv) < 2: + usage() + + cmd = sys.argv[1].lower() + + try: + if cmd == "bmi": + bmi(float(sys.argv[2]), float(sys.argv[3])) + + elif cmd == "tdee": + tdee( + float(sys.argv[2]), float(sys.argv[3]), + int(sys.argv[4]), sys.argv[5], int(sys.argv[6]), + ) + + elif cmd in ("1rm", "orm"): + one_rep_max(float(sys.argv[2]), int(sys.argv[3])) + + elif cmd == "macros": + macros(float(sys.argv[2]), sys.argv[3]) + + elif cmd == "bodyfat": + sex = sys.argv[2] + if sex.upper() == "M": + bodyfat(sex, float(sys.argv[3]), float(sys.argv[4]), 0, float(sys.argv[5])) + else: + bodyfat(sex, float(sys.argv[3]), float(sys.argv[4]), float(sys.argv[5]), float(sys.argv[6])) + + else: + 
print(f"Unknown command: {cmd}") + usage() + + except (IndexError, ValueError) as e: + print(f"Error: {e}") + usage() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/optional-skills/health/fitness-nutrition/scripts/nutrition_search.py b/optional-skills/health/fitness-nutrition/scripts/nutrition_search.py new file mode 100644 index 000000000..7494f6c38 --- /dev/null +++ b/optional-skills/health/fitness-nutrition/scripts/nutrition_search.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +nutrition_search.py — Search USDA FoodData Central for nutrition info. + +Usage: + python3 nutrition_search.py "chicken breast" + python3 nutrition_search.py "rice" "eggs" "broccoli" + echo -e "oats\\nbanana\\nwhey protein" | python3 nutrition_search.py - + +Reads USDA_API_KEY from environment, falls back to DEMO_KEY. +No external dependencies. +""" +import sys +import os +import json +import time +import urllib.request +import urllib.parse +import urllib.error + +API_KEY = os.environ.get("USDA_API_KEY", "DEMO_KEY") +BASE = "https://api.nal.usda.gov/fdc/v1" + + +def search(query, max_results=3): + encoded = urllib.parse.quote(query) + url = ( + f"{BASE}/foods/search?api_key={API_KEY}" + f"&query={encoded}&pageSize={max_results}" + f"&dataType=Foundation,SR%20Legacy" + ) + try: + req = urllib.request.Request(url, headers={"Accept": "application/json"}) + with urllib.request.urlopen(req, timeout=15) as r: + return json.loads(r.read()) + except Exception as e: + print(f" API error: {e}", file=sys.stderr) + return None + + +def display(food): + nutrients = {n["nutrientName"]: n.get("value", "?") for n in food.get("foodNutrients", [])} + cal = nutrients.get("Energy", "?") + prot = nutrients.get("Protein", "?") + fat = nutrients.get("Total lipid (fat)", "?") + carb = nutrients.get("Carbohydrate, by difference", "?") + fib = nutrients.get("Fiber, total dietary", "?") + sug = nutrients.get("Sugars, total including NLEA", "?") + + print(f" 
{food.get('description', 'N/A')}") + print(f" Calories : {cal} kcal") + print(f" Protein : {prot}g") + print(f" Fat : {fat}g") + print(f" Carbs : {carb}g (fiber: {fib}g, sugar: {sug}g)") + print(f" FDC ID : {food.get('fdcId', 'N/A')}") + + +def main(): + if len(sys.argv) < 2: + print(__doc__) + sys.exit(1) + + if sys.argv[1] == "-": + queries = [line.strip() for line in sys.stdin if line.strip()] + else: + queries = sys.argv[1:] + + for query in queries: + print(f"\n--- {query.upper()} (per 100g) ---") + data = search(query, max_results=2) + if not data or not data.get("foods"): + print(" No results found.") + else: + for food in data["foods"]: + display(food) + print() + if len(queries) > 1: + time.sleep(1) # respect rate limits + + if API_KEY == "DEMO_KEY": + print("\nTip: using DEMO_KEY (30 req/hr). Set USDA_API_KEY for 1000 req/hr.") + print("Free signup: https://fdc.nal.usda.gov/api-key-signup/") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/optional-skills/research/drug-discovery/SKILL.md b/optional-skills/research/drug-discovery/SKILL.md new file mode 100644 index 000000000..dc3bd3e7b --- /dev/null +++ b/optional-skills/research/drug-discovery/SKILL.md @@ -0,0 +1,226 @@ +--- +name: drug-discovery +description: > + Pharmaceutical research assistant for drug discovery workflows. Search + bioactive compounds on ChEMBL, calculate drug-likeness (Lipinski Ro5, QED, + TPSA, synthetic accessibility), look up drug-drug interactions via + OpenFDA, interpret ADMET profiles, and assist with lead optimization. + Use for medicinal chemistry questions, molecule property analysis, clinical + pharmacology, and open-science drug research. 
+version: 1.0.0 +author: bennytimz +license: MIT +metadata: + hermes: + tags: [science, chemistry, pharmacology, research, health] +prerequisites: + commands: [curl, python3] +--- + +# Drug Discovery & Pharmaceutical Research + +You are an expert pharmaceutical scientist and medicinal chemist with deep +knowledge of drug discovery, cheminformatics, and clinical pharmacology. +Use this skill for all pharma/chemistry research tasks. + +## Core Workflows + +### 1 — Bioactive Compound Search (ChEMBL) + +Search ChEMBL (the world's largest open bioactivity database) for compounds +by target, activity, or molecule name. No API key required. + +```bash +# Search compounds by target name (e.g. "EGFR", "COX-2", "ACE") +TARGET="$1" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$TARGET") +curl -s "https://www.ebi.ac.uk/chembl/api/data/target/search?q=${ENCODED}&format=json" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +targets=data.get('targets',[])[:5] +for t in targets: + print(f\"ChEMBL ID : {t.get('target_chembl_id')}\") + print(f\"Name : {t.get('pref_name')}\") + print(f\"Type : {t.get('target_type')}\") + print() +" +``` + +```bash +# Get bioactivity data for a ChEMBL target ID +TARGET_ID="$1" # e.g. CHEMBL203 +curl -s "https://www.ebi.ac.uk/chembl/api/data/activity?target_chembl_id=${TARGET_ID}&pchembl_value__gte=6&limit=10&format=json" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +acts=data.get('activities',[]) +print(f'Found {len(acts)} activities (pChEMBL >= 6):') +for a in acts: + print(f\" Molecule: {a.get('molecule_chembl_id')} | {a.get('standard_type')}: {a.get('standard_value')} {a.get('standard_units')} | pChEMBL: {a.get('pchembl_value')}\") +" +``` + +```bash +# Look up a specific molecule by ChEMBL ID +MOL_ID="$1" # e.g. 
CHEMBL25 (aspirin) +curl -s "https://www.ebi.ac.uk/chembl/api/data/molecule/${MOL_ID}?format=json" \ + | python3 -c " +import json,sys +m=json.load(sys.stdin) +props=m.get('molecule_properties',{}) or {} +print(f\"Name : {m.get('pref_name','N/A')}\") +print(f\"SMILES : {m.get('molecule_structures',{}).get('canonical_smiles','N/A') if m.get('molecule_structures') else 'N/A'}\") +print(f\"MW : {props.get('full_mwt','N/A')} Da\") +print(f\"LogP : {props.get('alogp','N/A')}\") +print(f\"HBD : {props.get('hbd','N/A')}\") +print(f\"HBA : {props.get('hba','N/A')}\") +print(f\"TPSA : {props.get('psa','N/A')} Ų\") +print(f\"Ro5 violations: {props.get('num_ro5_violations','N/A')}\") +print(f\"QED : {props.get('qed_weighted','N/A')}\") +" +``` + +### 2 — Drug-Likeness Calculation (Lipinski Ro5 + Veber) + +Assess any molecule against established oral bioavailability rules using +PubChem's free property API — no RDKit install needed. + +```bash +COMPOUND="$1" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$COMPOUND") +curl -s "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/${ENCODED}/property/MolecularWeight,XLogP,HBondDonorCount,HBondAcceptorCount,RotatableBondCount,TPSA,InChIKey/JSON" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +props=data['PropertyTable']['Properties'][0] +mw = float(props.get('MolecularWeight', 0)) +logp = float(props.get('XLogP', 0)) +hbd = int(props.get('HBondDonorCount', 0)) +hba = int(props.get('HBondAcceptorCount', 0)) +rot = int(props.get('RotatableBondCount', 0)) +tpsa = float(props.get('TPSA', 0)) +print('=== Lipinski Rule of Five (Ro5) ===') +print(f' MW {mw:.1f} Da {\"✓\" if mw<=500 else \"✗ VIOLATION (>500)\"}') +print(f' LogP {logp:.2f} {\"✓\" if logp<=5 else \"✗ VIOLATION (>5)\"}') +print(f' HBD {hbd} {\"✓\" if hbd<=5 else \"✗ VIOLATION (>5)\"}') +print(f' HBA {hba} {\"✓\" if hba<=10 else \"✗ VIOLATION (>10)\"}') +viol = sum([mw>500, logp>5, hbd>5, hba>10]) +print(f' 
Violations: {viol}/4 {\"→ Likely orally bioavailable\" if viol<=1 else \"→ Poor oral bioavailability predicted\"}') +print() +print('=== Veber Oral Bioavailability Rules ===') +print(f' TPSA {tpsa:.1f} Ų {\"✓\" if tpsa<=140 else \"✗ VIOLATION (>140)\"}') +print(f' Rot. bonds {rot} {\"✓\" if rot<=10 else \"✗ VIOLATION (>10)\"}') +print(f' Both rules met: {\"Yes → good oral absorption predicted\" if tpsa<=140 and rot<=10 else \"No → reduced oral absorption\"}') +" +``` + +### 3 — Drug Interaction & Safety Lookup (OpenFDA) + +```bash +DRUG="$1" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$DRUG") +curl -s "https://api.fda.gov/drug/label.json?search=drug_interactions:\"${ENCODED}\"&limit=3" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +results=data.get('results',[]) +if not results: + print('No interaction data found in FDA labels.') + sys.exit() +for r in results[:2]: + brand=r.get('openfda',{}).get('brand_name',['Unknown'])[0] + generic=r.get('openfda',{}).get('generic_name',['Unknown'])[0] + interactions=r.get('drug_interactions',['N/A'])[0] + print(f'--- {brand} ({generic}) ---') + print(interactions[:800]) + print() +" +``` + +```bash +DRUG="$1" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$DRUG") +curl -s "https://api.fda.gov/drug/event.json?search=patient.drug.medicinalproduct:\"${ENCODED}\"&count=patient.reaction.reactionmeddrapt.exact&limit=10" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +results=data.get('results',[]) +if not results: + print('No adverse event data found.') + sys.exit() +print(f'Top adverse events reported:') +for r in results[:10]: + print(f\" {r['count']:>5}x {r['term']}\") +" +``` + +### 4 — PubChem Compound Search + +```bash +COMPOUND="$1" +ENCODED=$(python3 -c "import urllib.parse,sys; print(urllib.parse.quote(sys.argv[1]))" "$COMPOUND") +CID=$(curl -s 
"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/${ENCODED}/cids/TXT" | head -1 | tr -d '[:space:]') +echo "PubChem CID: $CID" +curl -s "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/${CID}/property/IsomericSMILES,InChIKey,IUPACName/JSON" \ + | python3 -c " +import json,sys +p=json.load(sys.stdin)['PropertyTable']['Properties'][0] +print(f\"IUPAC Name : {p.get('IUPACName','N/A')}\") +print(f\"SMILES : {p.get('IsomericSMILES','N/A')}\") +print(f\"InChIKey : {p.get('InChIKey','N/A')}\") +" +``` + +### 5 — Target & Disease Literature (OpenTargets) + +```bash +GENE="$1" +curl -s -X POST "https://api.platform.opentargets.org/api/v4/graphql" \ + -H "Content-Type: application/json" \ + -d "{\"query\":\"{ search(queryString: \\\"${GENE}\\\", entityNames: [\\\"target\\\"], page: {index: 0, size: 1}) { hits { id score object { ... on Target { id approvedSymbol approvedName associatedDiseases(page: {index: 0, size: 5}) { count rows { score disease { id name } } } } } } } }\"}" \ + | python3 -c " +import json,sys +data=json.load(sys.stdin) +hits=data.get('data',{}).get('search',{}).get('hits',[]) +if not hits: + print('Target not found.') + sys.exit() +obj=hits[0]['object'] +print(f\"Target: {obj.get('approvedSymbol')} — {obj.get('approvedName')}\") +assoc=obj.get('associatedDiseases',{}) +print(f\"Associated with {assoc.get('count',0)} diseases. Top associations:\") +for row in assoc.get('rows',[]): + print(f\" Score {row['score']:.3f} | {row['disease']['name']}\") +" +``` + +## Reasoning Guidelines + +When analysing drug-likeness or molecular properties, always: + +1. **State raw values first** — MW, LogP, HBD, HBA, TPSA, RotBonds +2. **Apply rule sets** — Ro5 (Lipinski), Veber, Ghose filter where relevant +3. **Flag liabilities** — metabolic hotspots, hERG risk, high TPSA for CNS penetration +4. **Suggest optimizations** — bioisosteric replacements, prodrug strategies, ring truncation +5. 
**Cite the source API** — ChEMBL, PubChem, OpenFDA, or OpenTargets + +For ADMET questions, reason through Absorption, Distribution, Metabolism, Excretion, Toxicity systematically. See references/ADMET_REFERENCE.md for detailed guidance. + +## Important Notes + +- All APIs are free, public, require no authentication +- ChEMBL rate limits: add sleep 1 between batch requests +- FDA data reflects reported adverse events, not necessarily causation +- Always recommend consulting a licensed pharmacist or physician for clinical decisions + +## Quick Reference + +| Task | API | Endpoint | +|------|-----|----------| +| Find target | ChEMBL | `/api/data/target/search?q=` | +| Get bioactivity | ChEMBL | `/api/data/activity?target_chembl_id=` | +| Molecule properties | PubChem | `/rest/pug/compound/name/{name}/property/` | +| Drug interactions | OpenFDA | `/drug/label.json?search=drug_interactions:` | +| Adverse events | OpenFDA | `/drug/event.json?search=...&count=reaction` | +| Gene-disease | OpenTargets | GraphQL POST `/api/v4/graphql` | diff --git a/optional-skills/research/drug-discovery/references/ADMET_REFERENCE.md b/optional-skills/research/drug-discovery/references/ADMET_REFERENCE.md new file mode 100644 index 000000000..92a5e9503 --- /dev/null +++ b/optional-skills/research/drug-discovery/references/ADMET_REFERENCE.md @@ -0,0 +1,66 @@ +# ADMET Reference Guide + +Comprehensive reference for Absorption, Distribution, Metabolism, Excretion, and Toxicity (ADMET) analysis in drug discovery. + +## Drug-Likeness Rule Sets + +### Lipinski's Rule of Five (Ro5) + +| Property | Threshold | +|----------|-----------| +| Molecular Weight (MW) | ≤ 500 Da | +| Lipophilicity (LogP) | ≤ 5 | +| H-Bond Donors (HBD) | ≤ 5 | +| H-Bond Acceptors (HBA) | ≤ 10 | + +Reference: Lipinski et al., Adv. Drug Deliv. Rev. 23, 3–25 (1997). 
+ +### Veber's Oral Bioavailability Rules + +| Property | Threshold | +|----------|-----------| +| TPSA | ≤ 140 Ų | +| Rotatable Bonds | ≤ 10 | + +Reference: Veber et al., J. Med. Chem. 45, 2615–2623 (2002). + +### CNS Penetration (BBB) + +| Property | CNS-Optimal | +|----------|-------------| +| MW | ≤ 400 Da | +| LogP | 1–3 | +| TPSA | < 90 Ų | +| HBD | ≤ 3 | + +## CYP450 Metabolism + +| Isoform | % Drugs | Notable inhibitors | +|---------|---------|-------------------| +| CYP3A4 | ~50% | Grapefruit, ketoconazole | +| CYP2D6 | ~25% | Fluoxetine, paroxetine | +| CYP2C9 | ~15% | Fluconazole, amiodarone | +| CYP2C19 | ~10% | Omeprazole, fluoxetine | +| CYP1A2 | ~5% | Fluvoxamine, ciprofloxacin | + +## hERG Cardiac Toxicity Risk + +Structural alerts: basic nitrogen (pKa 7–9) + aromatic ring + hydrophobic moiety, LogP > 3.5 + basic amine. + +Mitigation: reduce basicity, introduce polar groups, break planarity. + +## Common Bioisosteric Replacements + +| Original | Bioisostere | Purpose | +|----------|-------------|---------| +| -COOH | -tetrazole, -SO₂NH₂ | Improve permeability | +| -OH (phenol) | -F, -CN | Reduce glucuronidation | +| Phenyl | Pyridine, thiophene | Reduce LogP | +| Ester | -CONHR | Reduce hydrolysis | + +## Key APIs + +- ChEMBL: https://www.ebi.ac.uk/chembl/api/data/ +- PubChem: https://pubchem.ncbi.nlm.nih.gov/rest/pug/ +- OpenFDA: https://api.fda.gov/drug/ +- OpenTargets GraphQL: https://api.platform.opentargets.org/api/v4/graphql diff --git a/optional-skills/research/drug-discovery/scripts/chembl_target.py b/optional-skills/research/drug-discovery/scripts/chembl_target.py new file mode 100644 index 000000000..1346b999a --- /dev/null +++ b/optional-skills/research/drug-discovery/scripts/chembl_target.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +""" +chembl_target.py — Search ChEMBL for a target and retrieve top active compounds. +Usage: python3 chembl_target.py "EGFR" --min-pchembl 7 --limit 20 +No external dependencies. 
+""" +import sys, json, time, argparse +import urllib.request, urllib.parse, urllib.error + +BASE = "https://www.ebi.ac.uk/chembl/api/data" + +def get(endpoint): + try: + req = urllib.request.Request(f"{BASE}{endpoint}", headers={"Accept":"application/json"}) + with urllib.request.urlopen(req, timeout=15) as r: + return json.loads(r.read()) + except Exception as e: + print(f"API error: {e}", file=sys.stderr); return None + +def main(): + parser = argparse.ArgumentParser(description="ChEMBL target → active compounds") + parser.add_argument("target") + parser.add_argument("--min-pchembl", type=float, default=6.0) + parser.add_argument("--limit", type=int, default=10) + args = parser.parse_args() + + enc = urllib.parse.quote(args.target) + data = get(f"/target/search?q={enc}&limit=5&format=json") + if not data or not data.get("targets"): + print("No targets found."); sys.exit(1) + + t = data["targets"][0] + tid = t.get("target_chembl_id","") + print(f"\nTarget: {t.get('pref_name')} ({tid})") + print(f"Type: {t.get('target_type')} | Organism: {t.get('organism','N/A')}") + print(f"\nFetching compounds with pChEMBL ≥ {args.min_pchembl}...\n") + + acts = get(f"/activity?target_chembl_id={tid}&pchembl_value__gte={args.min_pchembl}&assay_type=B&limit={args.limit}&order_by=-pchembl_value&format=json") + if not acts or not acts.get("activities"): + print("No activities found."); sys.exit(0) + + print(f"{'Molecule':<18} {'pChEMBL':>8} {'Type':<12} {'Value':<10} {'Units'}") + print("-"*65) + seen = set() + for a in acts["activities"]: + mid = a.get("molecule_chembl_id","N/A") + if mid in seen: continue + seen.add(mid) + print(f"{mid:<18} {str(a.get('pchembl_value','N/A')):>8} {str(a.get('standard_type','N/A')):<12} {str(a.get('standard_value','N/A')):<10} {a.get('standard_units','N/A')}") + time.sleep(0.1) + print(f"\nTotal: {len(seen)} unique molecules") + +if __name__ == "__main__": main() diff --git a/optional-skills/research/drug-discovery/scripts/ro5_screen.py 
b/optional-skills/research/drug-discovery/scripts/ro5_screen.py new file mode 100644 index 000000000..84e438fa1 --- /dev/null +++ b/optional-skills/research/drug-discovery/scripts/ro5_screen.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +""" +ro5_screen.py — Batch Lipinski Ro5 + Veber screening via PubChem API. +Usage: python3 ro5_screen.py aspirin ibuprofen paracetamol +No external dependencies beyond stdlib. +""" +import sys, json, time, argparse +import urllib.request, urllib.parse, urllib.error + +BASE = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name" +PROPS = "MolecularWeight,XLogP,HBondDonorCount,HBondAcceptorCount,RotatableBondCount,TPSA" + +def fetch(name): + url = f"{BASE}/{urllib.parse.quote(name)}/property/{PROPS}/JSON" + try: + with urllib.request.urlopen(url, timeout=10) as r: + return json.loads(r.read())["PropertyTable"]["Properties"][0] + except Exception: + return None + +def check(p): + mw,logp,hbd,hba,rot,tpsa = float(p.get("MolecularWeight",0)),float(p.get("XLogP",0)),int(p.get("HBondDonorCount",0)),int(p.get("HBondAcceptorCount",0)),int(p.get("RotatableBondCount",0)),float(p.get("TPSA",0)) + v = sum([mw>500,logp>5,hbd>5,hba>10]) + return dict(mw=mw,logp=logp,hbd=hbd,hba=hba,rot=rot,tpsa=tpsa,violations=v,ro5=v<=1,veber=tpsa<=140 and rot<=10,ok=v<=1 and tpsa<=140 and rot<=10) + +def report(name, r): + if not r: print(f"✗ {name:30s} — not found"); return + s = "✓ PASS" if r["ok"] else "✗ FAIL" + flags = (f" [Ro5 violations:{r['violations']}]" if not r["ro5"] else "") + (" [Veber fail]" if not r["veber"] else "") + print(f"{s} {name:28s} MW={r['mw']:.0f} LogP={r['logp']:.2f} HBD={r['hbd']} HBA={r['hba']} TPSA={r['tpsa']:.0f} RotB={r['rot']}{flags}") + +def main(): + compounds = sys.stdin.read().splitlines() if len(sys.argv)<2 or sys.argv[1]=="-" else sys.argv[1:] + print(f"\n{'Status':<8} {'Compound':<30} Properties\n" + "-"*85) + passed = 0 + for name in compounds: + props = fetch(name.strip()) + result = check(props) if props else None + 
report(name.strip(), result) + if result and result["ok"]: passed += 1 + time.sleep(0.3) + print(f"\nSummary: {passed}/{len(compounds)} passed Ro5 + Veber.\n") + +if __name__ == "__main__": main() diff --git a/run_agent.py b/run_agent.py index fdfdca85a..676e0ffc7 100644 --- a/run_agent.py +++ b/run_agent.py @@ -6144,6 +6144,12 @@ class AIAgent: elif self.reasoning_config.get("effort"): reasoning_effort = self.reasoning_config["effort"] + # Clamp effort levels not supported by the Responses API model. + # GPT-5.4 supports none/low/medium/high/xhigh but not "minimal". + # "minimal" is valid on OpenRouter and GPT-5 but fails on 5.2/5.4. + _effort_clamp = {"minimal": "low"} + reasoning_effort = _effort_clamp.get(reasoning_effort, reasoning_effort) + kwargs = { "model": self.model, "instructions": instructions, @@ -6891,6 +6897,18 @@ class AIAgent: tools. Used by the concurrent execution path; the sequential path retains its own inline invocation for backward-compatible display handling. """ + # Check plugin hooks for a block directive before executing anything. 
+ block_message: Optional[str] = None + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + block_message = get_pre_tool_call_block_message( + function_name, function_args, task_id=effective_task_id or "", + ) + except Exception: + pass + if block_message is not None: + return json.dumps({"error": block_message}, ensure_ascii=False) + if function_name == "todo": from tools.todo_tool import todo_tool as _todo_tool return _todo_tool( @@ -6955,6 +6973,7 @@ class AIAgent: tool_call_id=tool_call_id, session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, + skip_pre_tool_call_hook=True, ) def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: @@ -7185,12 +7204,6 @@ class AIAgent: function_name = tool_call.function.name - # Reset nudge counters when the relevant tool is actually used - if function_name == "memory": - self._turns_since_memory = 0 - elif function_name == "skill_manage": - self._iters_since_skill = 0 - try: function_args = json.loads(tool_call.function.arguments) except json.JSONDecodeError as e: @@ -7199,6 +7212,27 @@ class AIAgent: if not isinstance(function_args, dict): function_args = {} + # Check plugin hooks for a block directive before executing. + _block_msg: Optional[str] = None + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + _block_msg = get_pre_tool_call_block_message( + function_name, function_args, task_id=effective_task_id or "", + ) + except Exception: + pass + + if _block_msg is not None: + # Tool blocked by plugin policy — skip counter resets. + # Execution is handled below in the tool dispatch chain. 
+ pass + else: + # Reset nudge counters when the relevant tool is actually used + if function_name == "memory": + self._turns_since_memory = 0 + elif function_name == "skill_manage": + self._iters_since_skill = 0 + if not self.quiet_mode: args_str = json.dumps(function_args, ensure_ascii=False) if self.verbose_logging: @@ -7208,33 +7242,35 @@ class AIAgent: args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}") - self._current_tool = function_name - self._touch_activity(f"executing tool: {function_name}") + if _block_msg is None: + self._current_tool = function_name + self._touch_activity(f"executing tool: {function_name}") # Set activity callback for long-running tool execution (terminal # commands, etc.) so the gateway's inactivity monitor doesn't kill # the agent while a command is running. - try: - from tools.environments.base import set_activity_callback - set_activity_callback(self._touch_activity) - except Exception: - pass + if _block_msg is None: + try: + from tools.environments.base import set_activity_callback + set_activity_callback(self._touch_activity) + except Exception: + pass - if self.tool_progress_callback: + if _block_msg is None and self.tool_progress_callback: try: preview = _build_tool_preview(function_name, function_args) self.tool_progress_callback("tool.started", function_name, preview, function_args) except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") - if self.tool_start_callback: + if _block_msg is None and self.tool_start_callback: try: self.tool_start_callback(tool_call.id, function_name, function_args) except Exception as cb_err: logging.debug(f"Tool start callback error: {cb_err}") # Checkpoint: snapshot working dir before file-mutating tools - if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled: + if _block_msg is None and 
function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled: try: file_path = function_args.get("path", "") if file_path: @@ -7246,7 +7282,7 @@ class AIAgent: pass # never block tool execution # Checkpoint before destructive terminal commands - if function_name == "terminal" and self._checkpoint_mgr.enabled: + if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled: try: cmd = function_args.get("command", "") if _is_destructive_command(cmd): @@ -7259,7 +7295,11 @@ class AIAgent: tool_start_time = time.time() - if function_name == "todo": + if _block_msg is not None: + # Tool blocked by plugin policy — return error without executing. + function_result = json.dumps({"error": _block_msg}, ensure_ascii=False) + tool_duration = 0.0 + elif function_name == "todo": from tools.todo_tool import todo_tool as _todo_tool function_result = _todo_tool( todos=function_args.get("todos"), @@ -7402,6 +7442,7 @@ class AIAgent: tool_call_id=tool_call.id, session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, + skip_pre_tool_call_hook=True, ) _spinner_result = function_result except Exception as tool_error: @@ -7421,6 +7462,7 @@ class AIAgent: tool_call_id=tool_call.id, session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, + skip_pre_tool_call_hook=True, ) except Exception as tool_error: function_result = f"Error executing tool '{function_name}': {tool_error}" diff --git a/scripts/contributor_audit.py b/scripts/contributor_audit.py index 5d39f8316..474b0d52b 100644 --- a/scripts/contributor_audit.py +++ b/scripts/contributor_audit.py @@ -333,6 +333,16 @@ def main(): default=None, help="Path to a release notes file to check for missing contributors", ) + parser.add_argument( + "--strict", + action="store_true", + help="Exit with code 1 if new unmapped emails are found (for CI)", + ) + parser.add_argument( + "--diff-base", + 
default=None, + help="Git ref to diff against (only flag emails from commits after this ref)", + ) args = parser.parse_args() print(f"=== Contributor Audit: {args.since_tag}..{args.until} ===") @@ -398,6 +408,42 @@ def main(): for email, name in sorted(all_unknowns.items()): print(f' "{email}": "{name}",') + # ---- Strict mode: fail CI if new unmapped emails are introduced ---- + if args.strict and all_unknowns: + # In strict mode, check if ANY unknown emails come from commits in this + # PR's diff range (new unmapped emails that weren't there before). + # This is the CI gate: existing unknowns are grandfathered, but new + # commits must have their author email in AUTHOR_MAP. + new_unknowns = {} + if args.diff_base: + # Only flag emails from commits after diff_base + new_commits_output = git( + "log", f"{args.diff_base}..HEAD", + "--format=%ae", "--no-merges", + ) + new_emails = set(new_commits_output.splitlines()) if new_commits_output else set() + for email, name in all_unknowns.items(): + if email in new_emails: + new_unknowns[email] = name + else: + new_unknowns = all_unknowns + + if new_unknowns: + print() + print(f"=== STRICT MODE FAILURE: {len(new_unknowns)} new unmapped email(s) ===") + print("Add these to AUTHOR_MAP in scripts/release.py before merging:") + print() + for email, name in sorted(new_unknowns.items()): + print(f' "{email}": "",') + print() + print("To find the GitHub username:") + print(" gh api 'search/users?q=EMAIL+in:email' --jq '.items[0].login'") + strict_failed = True + else: + strict_failed = False + else: + strict_failed = False + # ---- Release file comparison ---- if args.release_file: print() @@ -419,6 +465,9 @@ def main(): print() print("Done.") + if strict_failed: + sys.exit(1) + if __name__ == "__main__": main() diff --git a/scripts/release.py b/scripts/release.py index 84d057ea0..5cc938ca3 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -112,6 +112,85 @@ AUTHOR_MAP = { "dalvidjr2022@gmail.com": "Jr-kenny", 
"m@statecraft.systems": "mbierling", "balyan.sid@gmail.com": "balyansid", + "oluwadareab12@gmail.com": "bennytimz", + # ── bulk addition: 75 emails resolved via API, PR salvage bodies, noreply + # crossref, and GH contributor list matching (April 2026 audit) ── + "1115117931@qq.com": "aaronagent", + "1506751656@qq.com": "hqhq1025", + "364939526@qq.com": "luyao618", + "aaronwong1999@icloud.com": "AaronWong1999", + "agents@kylefrench.dev": "DeployFaith", + "angelos@oikos.lan.home.malaiwah.com": "angelos", + "aptx4561@gmail.com": "cokemine", + "arilotter@gmail.com": "ethernet8023", + "ben@nousresearch.com": "benbarclay", + "birdiegyal@gmail.com": "yyovil", + "boschi1997@gmail.com": "nicoloboschi", + "chef.ya@gmail.com": "cherifya", + "chlqhdtn98@gmail.com": "BongSuCHOI", + "coffeemjj@gmail.com": "Cafexss", + "dalianmao0107@gmail.com": "dalianmao000", + "der@konsi.org": "konsisumer", + "dgrieco@redhat.com": "DomGrieco", + "dhicham.pro@gmail.com": "spideystreet", + "dipp.who@gmail.com": "dippwho", + "don.rhm@gmail.com": "donrhmexe", + "dorukardahan@hotmail.com": "dorukardahan", + "dsocolobsky@gmail.com": "dsocolobsky", + "duerzy@gmail.com": "duerzy", + "emozilla@nousresearch.com": "emozilla", + "fancydirty@gmail.com": "fancydirty", + "floptopbot33@gmail.com": "flobo3", + "fontana.pedro93@gmail.com": "pefontana", + "francis.x.fitzpatrick@gmail.com": "fxfitz", + "frank@helmschrott.de": "Helmi", + "gaixg94@gmail.com": "gaixianggeng", + "geoff.wellman@gmail.com": "geoffwellman", + "han.shan@live.cn": "jamesarch", + "haolong@microsoft.com": "LongOddCode", + "hata1234@gmail.com": "hata1234", + "hmbown@gmail.com": "Hmbown", + "iacobs@m0n5t3r.info": "m0n5t3r", + "jiayuw794@gmail.com": "JiayuuWang", + "jonny@nousresearch.com": "jquesnelle", + "juan.ovalle@mistral.ai": "jjovalle99", + "julien.talbot@ergonomia.re": "Julientalbot", + "kagura.chen28@gmail.com": "kagura-agent", + "kamil@gwozdz.me": "kamil-gwozdz", + "karamusti912@gmail.com": "MustafaKara7", + "kira@ariaki.me": 
"kira-ariaki", + "knopki@duck.com": "knopki", + "limars874@gmail.com": "limars874", + "lisicheng168@gmail.com": "lesterli", + "mingjwan@microsoft.com": "MagicRay1217", + "niyant@spicefi.xyz": "spniyant", + "olafthiele@gmail.com": "olafthiele", + "oncuevtv@gmail.com": "sprmn24", + "programming@olafthiele.com": "olafthiele", + "r2668940489@gmail.com": "r266-tech", + "s5460703@gmail.com": "BlackishGreen33", + "saul.jj.wu@gmail.com": "SaulJWu", + "shenhaocheng19990111@gmail.com": "hcshen0111", + "sjtuwbh@gmail.com": "Cygra", + "srhtsrht17@gmail.com": "Sertug17", + "stephenschoettler@gmail.com": "stephenschoettler", + "tanishq231003@gmail.com": "yyovil", + "tesseracttars@gmail.com": "tesseracttars-creator", + "tianliangjay@gmail.com": "xingkongliang", + "tranquil_flow@protonmail.com": "Tranquil-Flow", + "unayung@gmail.com": "Unayung", + "vorvul.danylo@gmail.com": "WorldInnovationsDepartment", + "win4r@outlook.com": "win4r", + "xush@xush.org": "KUSH42", + "yangzhi.see@gmail.com": "SeeYangZhi", + "yongtenglei@gmail.com": "yongtenglei", + "young@YoungdeMacBook-Pro.local": "YoungYang963", + "ysfalweshcan@gmail.com": "Awsh1", + "ysfwaxlycan@gmail.com": "WAXLYY", + "yusufalweshdemir@gmail.com": "Dusk1e", + "zhouboli@gmail.com": "zhouboli", + "zqiao@microsoft.com": "tomqiaozc", + "zzn+pa@zzn.im": "xinbenlv", } diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index e6a9d1919..3b44cba4d 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -365,7 +365,7 @@ class TestExpiredCodexFallback: def test_hermes_oauth_file_sets_oauth_flag(self, monkeypatch): """OAuth-style tokens should get is_oauth=*** (token is not sk-ant-api-*).""" # Mock resolve_anthropic_token to return an OAuth-style token - with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="hermes-oauth-jwt-token"), \ + with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-oat-hermes-token"), \ 
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build, \ patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)): mock_build.return_value = MagicMock() @@ -420,7 +420,7 @@ class TestExpiredCodexFallback: def test_claude_code_oauth_env_sets_flag(self, monkeypatch): """CLAUDE_CODE_OAUTH_TOKEN env var should get is_oauth=True.""" - monkeypatch.setenv("CLAUDE_CODE_OAUTH_TOKEN", "cc-oauth-token-test") + monkeypatch.setenv("CLAUDE_CODE_OAUTH_TOKEN", "sk-ant-oat-cc-test-token") monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False) with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build: mock_build.return_value = MagicMock() @@ -786,7 +786,7 @@ class TestAuxiliaryPoolAwareness: patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()), patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"), ): - client, model = get_vision_auxiliary_client() + provider, client, model = resolve_vision_provider_client() assert client is not None assert client.__class__.__name__ == "AnthropicAuxiliaryClient" diff --git a/tests/agent/test_compress_focus.py b/tests/agent/test_compress_focus.py index a569eb9e3..8b5b1d35d 100644 --- a/tests/agent/test_compress_focus.py +++ b/tests/agent/test_compress_focus.py @@ -25,6 +25,11 @@ def _make_compressor(): compressor._previous_summary = None compressor._summary_failure_cooldown_until = 0.0 compressor.summary_model = None + compressor.model = "test-model" + compressor.provider = "test" + compressor.base_url = "http://localhost" + compressor.api_key = "test-key" + compressor.api_mode = "chat_completions" return compressor diff --git a/tests/agent/test_memory_user_id.py b/tests/agent/test_memory_user_id.py index 04f90c74c..c1b82208d 100644 --- a/tests/agent/test_memory_user_id.py +++ b/tests/agent/test_memory_user_id.py @@ -109,14 +109,12 @@ class TestMemoryManagerUserIdThreading: assert "user_id" not in p._init_kwargs def 
test_multiple_providers_all_receive_user_id(self): - from agent.builtin_memory_provider import BuiltinMemoryProvider - mgr = MemoryManager() - # Use builtin + one external (MemoryManager only allows one external) - builtin = BuiltinMemoryProvider() - ext = RecordingProvider("external") - mgr.add_provider(builtin) - mgr.add_provider(ext) + # Use one provider named "builtin" (always accepted) and one external + p1 = RecordingProvider("builtin") + p2 = RecordingProvider("external") + mgr.add_provider(p1) + mgr.add_provider(p2) mgr.initialize_all( session_id="sess-multi", @@ -124,8 +122,10 @@ class TestMemoryManagerUserIdThreading: user_id="slack_U12345", ) - assert ext._init_kwargs.get("user_id") == "slack_U12345" - assert ext._init_kwargs.get("platform") == "slack" + assert p1._init_kwargs.get("user_id") == "slack_U12345" + assert p1._init_kwargs.get("platform") == "slack" + assert p2._init_kwargs.get("user_id") == "slack_U12345" + assert p2._init_kwargs.get("platform") == "slack" # --------------------------------------------------------------------------- @@ -211,17 +211,17 @@ class TestHonchoUserIdScoping: """Verify Honcho plugin uses gateway user_id for peer_name when provided.""" def test_gateway_user_id_overrides_peer_name(self): - """When user_id is in kwargs, cfg.peer_name should be overridden.""" + """When user_id is in kwargs and no explicit peer_name, user_id should be used.""" from plugins.memory.honcho import HonchoMemoryProvider provider = HonchoMemoryProvider() - # Create a mock config with a static peer_name + # Create a mock config with NO explicit peer_name mock_cfg = MagicMock() mock_cfg.enabled = True mock_cfg.api_key = "test-key" mock_cfg.base_url = None - mock_cfg.peer_name = "static-user" + mock_cfg.peer_name = "" # No explicit peer_name — user_id should fill it mock_cfg.recall_mode = "tools" # Use tools mode to defer session init with patch( diff --git a/tests/cli/test_cli_interrupt_subagent.py b/tests/cli/test_cli_interrupt_subagent.py index 
f4322ea6b..6821a6725 100644 --- a/tests/cli/test_cli_interrupt_subagent.py +++ b/tests/cli/test_cli_interrupt_subagent.py @@ -63,6 +63,7 @@ class TestCLISubagentInterrupt(unittest.TestCase): parent._delegate_depth = 0 parent._delegate_spinner = None parent.tool_progress_callback = None + parent._execution_thread_id = None # We'll track what happens with _active_children original_children = parent._active_children diff --git a/tests/cli/test_cli_provider_resolution.py b/tests/cli/test_cli_provider_resolution.py index 353b3234e..9c5bf0cca 100644 --- a/tests/cli/test_cli_provider_resolution.py +++ b/tests/cli/test_cli_provider_resolution.py @@ -576,8 +576,9 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys): monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None) # After the probe detects a single model ("llm"), the flow asks - # "Use this model? [Y/n]:" — confirm with Enter, then context length. - answers = iter(["http://localhost:8000", "local-key", "", ""]) + # "Use this model? [Y/n]:" — confirm with Enter, then context length, + # then display name. 
+ answers = iter(["http://localhost:8000", "local-key", "", "", ""]) monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) monkeypatch.setattr("getpass.getpass", lambda _prompt="": next(answers)) @@ -641,3 +642,46 @@ def test_cmd_model_forwards_nous_login_tls_options(monkeypatch): "ca_bundle": "/tmp/local-ca.pem", "insecure": True, } + + +# --------------------------------------------------------------------------- +# _auto_provider_name — unit tests +# --------------------------------------------------------------------------- + +def test_auto_provider_name_localhost(): + from hermes_cli.main import _auto_provider_name + assert _auto_provider_name("http://localhost:11434/v1") == "Local (localhost:11434)" + assert _auto_provider_name("http://127.0.0.1:1234/v1") == "Local (127.0.0.1:1234)" + + +def test_auto_provider_name_runpod(): + from hermes_cli.main import _auto_provider_name + assert "RunPod" in _auto_provider_name("https://xyz.runpod.io/v1") + + +def test_auto_provider_name_remote(): + from hermes_cli.main import _auto_provider_name + result = _auto_provider_name("https://api.together.xyz/v1") + assert result == "Api.together.xyz" + + +def test_save_custom_provider_uses_provided_name(monkeypatch, tmp_path): + """When a display name is passed, it should appear in the saved entry.""" + import yaml + from hermes_cli.main import _save_custom_provider + + cfg_path = tmp_path / "config.yaml" + cfg_path.write_text(yaml.dump({})) + + monkeypatch.setattr( + "hermes_cli.config.load_config", lambda: yaml.safe_load(cfg_path.read_text()) or {}, + ) + saved = {} + def _save(cfg): + saved.update(cfg) + monkeypatch.setattr("hermes_cli.config.save_config", _save) + + _save_custom_provider("http://localhost:11434/v1", name="Ollama") + entries = saved.get("custom_providers", []) + assert len(entries) == 1 + assert entries[0]["name"] == "Ollama" diff --git a/tests/cli/test_fast_command.py b/tests/cli/test_fast_command.py index d39453c10..bc6c8e5fb 100644 --- 
a/tests/cli/test_fast_command.py +++ b/tests/cli/test_fast_command.py @@ -369,7 +369,8 @@ class TestAnthropicFastModeAdapter(unittest.TestCase): reasoning_config=None, fast_mode=True, ) - assert kwargs.get("speed") == "fast" + assert kwargs.get("extra_body", {}).get("speed") == "fast" + assert "speed" not in kwargs assert "extra_headers" in kwargs assert _FAST_MODE_BETA in kwargs["extra_headers"].get("anthropic-beta", "") @@ -384,6 +385,7 @@ class TestAnthropicFastModeAdapter(unittest.TestCase): reasoning_config=None, fast_mode=False, ) + assert kwargs.get("extra_body", {}).get("speed") is None assert "speed" not in kwargs assert "extra_headers" not in kwargs @@ -400,9 +402,24 @@ class TestAnthropicFastModeAdapter(unittest.TestCase): base_url="https://api.minimax.io/anthropic/v1", ) # Third-party endpoints should NOT get speed or fast-mode beta + assert kwargs.get("extra_body", {}).get("speed") is None assert "speed" not in kwargs assert "extra_headers" not in kwargs + def test_fast_mode_kwargs_are_safe_for_sdk_unpacking(self): + from agent.anthropic_adapter import build_anthropic_kwargs + + kwargs = build_anthropic_kwargs( + model="claude-opus-4-6", + messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}], + tools=None, + max_tokens=None, + reasoning_config=None, + fast_mode=True, + ) + assert "speed" not in kwargs + assert kwargs.get("extra_body", {}).get("speed") == "fast" + class TestConfigDefault(unittest.TestCase): def test_default_config_has_service_tier(self): diff --git a/tests/gateway/test_display_config.py b/tests/gateway/test_display_config.py index c9ad51280..ae2eac66e 100644 --- a/tests/gateway/test_display_config.py +++ b/tests/gateway/test_display_config.py @@ -220,41 +220,6 @@ class TestPlatformDefaults: assert resolve_display_setting({}, "telegram", "streaming") is None -# --------------------------------------------------------------------------- -# get_effective_display / get_platform_defaults -# 
--------------------------------------------------------------------------- - -class TestHelpers: - """Helper functions return correct composite results.""" - - def test_get_effective_display_merges_correctly(self): - from gateway.display_config import get_effective_display - - config = { - "display": { - "tool_progress": "new", - "show_reasoning": True, - "platforms": { - "telegram": {"tool_progress": "verbose"}, - }, - } - } - eff = get_effective_display(config, "telegram") - assert eff["tool_progress"] == "verbose" # platform override - assert eff["show_reasoning"] is True # global - assert "tool_preview_length" in eff # default filled in - - def test_get_platform_defaults_returns_dict(self): - from gateway.display_config import get_platform_defaults - - defaults = get_platform_defaults("telegram") - assert "tool_progress" in defaults - assert "show_reasoning" in defaults - # Returns a new dict (not the shared tier dict) - defaults["tool_progress"] = "changed" - assert get_platform_defaults("telegram")["tool_progress"] != "changed" - - # --------------------------------------------------------------------------- # Config migration: tool_progress_overrides → display.platforms # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_email.py b/tests/gateway/test_email.py index b6da07921..44e38aff4 100644 --- a/tests/gateway/test_email.py +++ b/tests/gateway/test_email.py @@ -334,10 +334,12 @@ class TestChannelDirectory(unittest.TestCase): """Verify email in channel directory session-based discovery.""" def test_email_in_session_discovery(self): - import gateway.channel_directory - import inspect - source = inspect.getsource(gateway.channel_directory.build_channel_directory) - self.assertIn('"email"', source) + from gateway.config import Platform + # Verify email is a Platform enum member — the dynamic loop in + # build_channel_directory iterates all Platform members, so email + # is included automatically as 
long as it's in the enum. + email_values = [p.value for p in Platform] + self.assertIn("email", email_values) class TestGatewaySetup(unittest.TestCase): diff --git a/tests/gateway/test_feishu.py b/tests/gateway/test_feishu.py index 2ef84f744..7b23a6985 100644 --- a/tests/gateway/test_feishu.py +++ b/tests/gateway/test_feishu.py @@ -631,6 +631,14 @@ class TestAdapterBehavior(unittest.TestCase): calls.append("card_action") return self + def register_p2_im_chat_member_bot_added_v1(self, _handler): + calls.append("bot_added") + return self + + def register_p2_im_chat_member_bot_deleted_v1(self, _handler): + calls.append("bot_deleted") + return self + def build(self): calls.append("build") return "handler" @@ -654,6 +662,8 @@ class TestAdapterBehavior(unittest.TestCase): "reaction_created", "reaction_deleted", "card_action", + "bot_added", + "bot_deleted", "build", ], ) diff --git a/tests/gateway/test_qqbot.py b/tests/gateway/test_qqbot.py new file mode 100644 index 000000000..d3ca5320d --- /dev/null +++ b/tests/gateway/test_qqbot.py @@ -0,0 +1,460 @@ +"""Tests for the QQ Bot platform adapter.""" + +import json +import os +import sys +from unittest import mock + +import pytest + +from gateway.config import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_config(**extra): + """Build a PlatformConfig(enabled=True, extra=extra) for testing.""" + return PlatformConfig(enabled=True, extra=extra) + + +# --------------------------------------------------------------------------- +# check_qq_requirements +# --------------------------------------------------------------------------- + +class TestQQRequirements: + def test_returns_bool(self): + from gateway.platforms.qqbot import check_qq_requirements + result = check_qq_requirements() + assert isinstance(result, bool) + + +# 
--------------------------------------------------------------------------- +# QQAdapter.__init__ +# --------------------------------------------------------------------------- + +class TestQQAdapterInit: + def _make(self, **extra): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter(_make_config(**extra)) + + def test_basic_attributes(self): + adapter = self._make(app_id="123", client_secret="sec") + assert adapter._app_id == "123" + assert adapter._client_secret == "sec" + + def test_env_fallback(self): + with mock.patch.dict(os.environ, {"QQ_APP_ID": "env_id", "QQ_CLIENT_SECRET": "env_sec"}, clear=False): + adapter = self._make() + assert adapter._app_id == "env_id" + assert adapter._client_secret == "env_sec" + + def test_env_fallback_extra_wins(self): + with mock.patch.dict(os.environ, {"QQ_APP_ID": "env_id"}, clear=False): + adapter = self._make(app_id="extra_id", client_secret="sec") + assert adapter._app_id == "extra_id" + + def test_dm_policy_default(self): + adapter = self._make(app_id="a", client_secret="b") + assert adapter._dm_policy == "open" + + def test_dm_policy_explicit(self): + adapter = self._make(app_id="a", client_secret="b", dm_policy="allowlist") + assert adapter._dm_policy == "allowlist" + + def test_group_policy_default(self): + adapter = self._make(app_id="a", client_secret="b") + assert adapter._group_policy == "open" + + def test_allow_from_parsing_string(self): + adapter = self._make(app_id="a", client_secret="b", allow_from="x, y , z") + assert adapter._allow_from == ["x", "y", "z"] + + def test_allow_from_parsing_list(self): + adapter = self._make(app_id="a", client_secret="b", allow_from=["a", "b"]) + assert adapter._allow_from == ["a", "b"] + + def test_allow_from_default_empty(self): + adapter = self._make(app_id="a", client_secret="b") + assert adapter._allow_from == [] + + def test_group_allow_from(self): + adapter = self._make(app_id="a", client_secret="b", group_allow_from="g1,g2") + assert 
adapter._group_allow_from == ["g1", "g2"] + + def test_markdown_support_default(self): + adapter = self._make(app_id="a", client_secret="b") + assert adapter._markdown_support is True + + def test_markdown_support_false(self): + adapter = self._make(app_id="a", client_secret="b", markdown_support=False) + assert adapter._markdown_support is False + + def test_name_property(self): + adapter = self._make(app_id="a", client_secret="b") + assert adapter.name == "QQBot" + + +# --------------------------------------------------------------------------- +# _coerce_list +# --------------------------------------------------------------------------- + +class TestCoerceList: + def _fn(self, value): + from gateway.platforms.qqbot import _coerce_list + return _coerce_list(value) + + def test_none(self): + assert self._fn(None) == [] + + def test_string(self): + assert self._fn("a, b ,c") == ["a", "b", "c"] + + def test_list(self): + assert self._fn(["x", "y"]) == ["x", "y"] + + def test_empty_string(self): + assert self._fn("") == [] + + def test_tuple(self): + assert self._fn(("a", "b")) == ["a", "b"] + + def test_single_item_string(self): + assert self._fn("hello") == ["hello"] + + +# --------------------------------------------------------------------------- +# _is_voice_content_type +# --------------------------------------------------------------------------- + +class TestIsVoiceContentType: + def _fn(self, content_type, filename): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter._is_voice_content_type(content_type, filename) + + def test_voice_content_type(self): + assert self._fn("voice", "msg.silk") is True + + def test_audio_content_type(self): + assert self._fn("audio/mp3", "file.mp3") is True + + def test_voice_extension(self): + assert self._fn("", "file.silk") is True + + def test_non_voice(self): + assert self._fn("image/jpeg", "photo.jpg") is False + + def test_audio_extension_amr(self): + assert self._fn("", "recording.amr") is True + + +# 
--------------------------------------------------------------------------- +# _strip_at_mention +# --------------------------------------------------------------------------- + +class TestStripAtMention: + def _fn(self, content): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter._strip_at_mention(content) + + def test_removes_mention(self): + result = self._fn("@BotUser hello there") + assert result == "hello there" + + def test_no_mention(self): + result = self._fn("just text") + assert result == "just text" + + def test_empty_string(self): + assert self._fn("") == "" + + def test_only_mention(self): + assert self._fn("@Someone ") == "" + + +# --------------------------------------------------------------------------- +# _is_dm_allowed +# --------------------------------------------------------------------------- + +class TestDmAllowed: + def _make_adapter(self, **extra): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter(_make_config(**extra)) + + def test_open_policy(self): + adapter = self._make_adapter(app_id="a", client_secret="b", dm_policy="open") + assert adapter._is_dm_allowed("any_user") is True + + def test_disabled_policy(self): + adapter = self._make_adapter(app_id="a", client_secret="b", dm_policy="disabled") + assert adapter._is_dm_allowed("any_user") is False + + def test_allowlist_match(self): + adapter = self._make_adapter(app_id="a", client_secret="b", dm_policy="allowlist", allow_from="user1,user2") + assert adapter._is_dm_allowed("user1") is True + + def test_allowlist_no_match(self): + adapter = self._make_adapter(app_id="a", client_secret="b", dm_policy="allowlist", allow_from="user1,user2") + assert adapter._is_dm_allowed("user3") is False + + def test_allowlist_wildcard(self): + adapter = self._make_adapter(app_id="a", client_secret="b", dm_policy="allowlist", allow_from="*") + assert adapter._is_dm_allowed("anyone") is True + + +# --------------------------------------------------------------------------- 
+# _is_group_allowed +# --------------------------------------------------------------------------- + +class TestGroupAllowed: + def _make_adapter(self, **extra): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter(_make_config(**extra)) + + def test_open_policy(self): + adapter = self._make_adapter(app_id="a", client_secret="b", group_policy="open") + assert adapter._is_group_allowed("grp1", "user1") is True + + def test_allowlist_match(self): + adapter = self._make_adapter(app_id="a", client_secret="b", group_policy="allowlist", group_allow_from="grp1") + assert adapter._is_group_allowed("grp1", "user1") is True + + def test_allowlist_no_match(self): + adapter = self._make_adapter(app_id="a", client_secret="b", group_policy="allowlist", group_allow_from="grp1") + assert adapter._is_group_allowed("grp2", "user1") is False + + +# --------------------------------------------------------------------------- +# _resolve_stt_config +# --------------------------------------------------------------------------- + +class TestResolveSTTConfig: + def _make_adapter(self, **extra): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter(_make_config(**extra)) + + def test_no_config(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + with mock.patch.dict(os.environ, {}, clear=True): + assert adapter._resolve_stt_config() is None + + def test_env_config(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + with mock.patch.dict(os.environ, { + "QQ_STT_API_KEY": "key123", + "QQ_STT_BASE_URL": "https://example.com/v1", + "QQ_STT_MODEL": "my-model", + }, clear=True): + cfg = adapter._resolve_stt_config() + assert cfg is not None + assert cfg["api_key"] == "key123" + assert cfg["base_url"] == "https://example.com/v1" + assert cfg["model"] == "my-model" + + def test_extra_config(self): + stt_cfg = { + "baseUrl": "https://custom.api/v4", + "apiKey": "sk_extra", + "model": "glm-asr", + } + adapter = 
self._make_adapter(app_id="a", client_secret="b", stt=stt_cfg) + with mock.patch.dict(os.environ, {}, clear=True): + cfg = adapter._resolve_stt_config() + assert cfg is not None + assert cfg["base_url"] == "https://custom.api/v4" + assert cfg["api_key"] == "sk_extra" + assert cfg["model"] == "glm-asr" + + +# --------------------------------------------------------------------------- +# _detect_message_type +# --------------------------------------------------------------------------- + +class TestDetectMessageType: + def _fn(self, media_urls, media_types): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter._detect_message_type(media_urls, media_types) + + def test_no_media(self): + from gateway.platforms.base import MessageType + assert self._fn([], []) == MessageType.TEXT + + def test_image(self): + from gateway.platforms.base import MessageType + assert self._fn(["file.jpg"], ["image/jpeg"]) == MessageType.PHOTO + + def test_voice(self): + from gateway.platforms.base import MessageType + assert self._fn(["voice.silk"], ["audio/silk"]) == MessageType.VOICE + + def test_video(self): + from gateway.platforms.base import MessageType + assert self._fn(["vid.mp4"], ["video/mp4"]) == MessageType.VIDEO + + +# --------------------------------------------------------------------------- +# QQCloseError +# --------------------------------------------------------------------------- + +class TestQQCloseError: + def test_attributes(self): + from gateway.platforms.qqbot import QQCloseError + err = QQCloseError(4004, "bad token") + assert err.code == 4004 + assert err.reason == "bad token" + + def test_code_none(self): + from gateway.platforms.qqbot import QQCloseError + err = QQCloseError(None, "") + assert err.code is None + + def test_string_to_int(self): + from gateway.platforms.qqbot import QQCloseError + err = QQCloseError("4914", "banned") + assert err.code == 4914 + assert err.reason == "banned" + + def test_message_format(self): + from 
gateway.platforms.qqbot import QQCloseError + err = QQCloseError(4008, "rate limit") + assert "4008" in str(err) + assert "rate limit" in str(err) + + +# --------------------------------------------------------------------------- +# _dispatch_payload +# --------------------------------------------------------------------------- + +class TestDispatchPayload: + def _make_adapter(self, **extra): + from gateway.platforms.qqbot import QQAdapter + adapter = QQAdapter(_make_config(**extra)) + return adapter + + def test_unknown_op(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + # Should not raise + adapter._dispatch_payload({"op": 99, "d": {}}) + # last_seq should remain None + assert adapter._last_seq is None + + def test_op10_updates_heartbeat_interval(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + adapter._dispatch_payload({"op": 10, "d": {"heartbeat_interval": 50000}}) + # Should be 50000 / 1000 * 0.8 = 40.0 + assert adapter._heartbeat_interval == 40.0 + + def test_op11_heartbeat_ack(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + # Should not raise + adapter._dispatch_payload({"op": 11, "t": "HEARTBEAT_ACK", "s": 42}) + + def test_seq_tracking(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + adapter._dispatch_payload({"op": 0, "t": "READY", "s": 100, "d": {}}) + assert adapter._last_seq == 100 + + def test_seq_increments(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + adapter._dispatch_payload({"op": 0, "t": "READY", "s": 5, "d": {}}) + adapter._dispatch_payload({"op": 0, "t": "SOME_EVENT", "s": 10, "d": {}}) + assert adapter._last_seq == 10 + + +# --------------------------------------------------------------------------- +# READY / RESUMED handling +# --------------------------------------------------------------------------- + +class TestReadyHandling: + def _make_adapter(self, **extra): + from gateway.platforms.qqbot import QQAdapter + return 
QQAdapter(_make_config(**extra)) + + def test_ready_stores_session(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + adapter._dispatch_payload({ + "op": 0, "t": "READY", + "s": 1, + "d": {"session_id": "sess_abc123"}, + }) + assert adapter._session_id == "sess_abc123" + + def test_resumed_preserves_session(self): + adapter = self._make_adapter(app_id="a", client_secret="b") + adapter._session_id = "old_sess" + adapter._last_seq = 50 + adapter._dispatch_payload({ + "op": 0, "t": "RESUMED", "s": 60, "d": {}, + }) + # Session should remain unchanged on RESUMED + assert adapter._session_id == "old_sess" + assert adapter._last_seq == 60 + + +# --------------------------------------------------------------------------- +# _parse_json +# --------------------------------------------------------------------------- + +class TestParseJson: + def _fn(self, raw): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter._parse_json(raw) + + def test_valid_json(self): + result = self._fn('{"op": 10, "d": {}}') + assert result == {"op": 10, "d": {}} + + def test_invalid_json(self): + result = self._fn("not json") + assert result is None + + def test_none_input(self): + result = self._fn(None) + assert result is None + + def test_non_dict_json(self): + result = self._fn('"just a string"') + assert result is None + + def test_empty_dict(self): + result = self._fn('{}') + assert result == {} + + +# --------------------------------------------------------------------------- +# _build_text_body +# --------------------------------------------------------------------------- + +class TestBuildTextBody: + def _make_adapter(self, **extra): + from gateway.platforms.qqbot import QQAdapter + return QQAdapter(_make_config(**extra)) + + def test_plain_text(self): + adapter = self._make_adapter(app_id="a", client_secret="b", markdown_support=False) + body = adapter._build_text_body("hello world") + assert body["msg_type"] == 0 # MSG_TYPE_TEXT + assert 
body["content"] == "hello world" + + def test_markdown_text(self): + adapter = self._make_adapter(app_id="a", client_secret="b", markdown_support=True) + body = adapter._build_text_body("**bold** text") + assert body["msg_type"] == 2 # MSG_TYPE_MARKDOWN + assert body["markdown"]["content"] == "**bold** text" + + def test_truncation(self): + adapter = self._make_adapter(app_id="a", client_secret="b", markdown_support=False) + long_text = "x" * 10000 + body = adapter._build_text_body(long_text) + assert len(body["content"]) == adapter.MAX_MESSAGE_LENGTH + + def test_empty_string(self): + adapter = self._make_adapter(app_id="a", client_secret="b", markdown_support=False) + body = adapter._build_text_body("") + assert body["content"] == "" + + def test_reply_to(self): + adapter = self._make_adapter(app_id="a", client_secret="b", markdown_support=False) + body = adapter._build_text_body("reply text", reply_to="msg_123") + assert body.get("message_reference", {}).get("message_id") == "msg_123" diff --git a/tests/gateway/test_restart_drain.py b/tests/gateway/test_restart_drain.py index 0c1324664..cfc2c364c 100644 --- a/tests/gateway/test_restart_drain.py +++ b/tests/gateway/test_restart_drain.py @@ -13,7 +13,10 @@ from tests.gateway.restart_test_helpers import make_restart_runner, make_restart @pytest.mark.asyncio -async def test_restart_command_while_busy_requests_drain_without_interrupt(): +async def test_restart_command_while_busy_requests_drain_without_interrupt(monkeypatch): + # Ensure INVOCATION_ID is NOT set — systemd sets this in service mode, + # which changes the restart call signature. 
+ monkeypatch.delenv("INVOCATION_ID", raising=False) runner, _adapter = make_restart_runner() runner.request_restart = MagicMock(return_value=True) event = MessageEvent( diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py index 9f556f884..5a643a1ef 100644 --- a/tests/gateway/test_session_env.py +++ b/tests/gateway/test_session_env.py @@ -186,10 +186,13 @@ def test_set_session_env_includes_session_key(): session_key="tg:-1001:17585", ) + # Capture baseline value before setting (may be non-empty from another + # test in the same pytest-xdist worker sharing the context). + baseline = get_session_env("HERMES_SESSION_KEY") tokens = runner._set_session_env(context) assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585" runner._clear_session_env(tokens) - assert get_session_env("HERMES_SESSION_KEY") == "" + assert get_session_env("HERMES_SESSION_KEY") == baseline def test_session_key_no_race_condition_with_contextvars(monkeypatch): diff --git a/tests/gateway/test_session_hygiene.py b/tests/gateway/test_session_hygiene.py index 5488296f6..325c24fac 100644 --- a/tests/gateway/test_session_hygiene.py +++ b/tests/gateway/test_session_hygiene.py @@ -374,6 +374,7 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t chat_id="-1001", chat_type="group", thread_id="17585", + user_id="12345", ), message_id="1", ) diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index d8a1be2d2..38532e66b 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -155,6 +155,90 @@ class TestSendOrEditMediaStripping: adapter.send.assert_not_called() + @pytest.mark.asyncio + async def test_short_text_with_cursor_skips_new_message(self): + """Short text + cursor should not create a standalone new message. + + During rapid tool-calling the model often emits 1-2 tokens before + switching to tool calls. 
Sending 'I ▉' as a new message risks + leaving the cursor permanently visible if the follow-up edit is + rate-limited. The guard should skip the first send and let the + text accumulate into the next segment. + """ + adapter = MagicMock() + adapter.send = AsyncMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + StreamConsumerConfig(cursor=" ▉"), + ) + # No message_id yet (first send) — short text + cursor should be skipped + assert consumer._message_id is None + result = await consumer._send_or_edit("I ▉") + assert result is True + adapter.send.assert_not_called() + + # 3 chars is still under the threshold + result = await consumer._send_or_edit("Hi! ▉") + assert result is True + adapter.send.assert_not_called() + + @pytest.mark.asyncio + async def test_longer_text_with_cursor_sends_new_message(self): + """Text >= 4 visible chars + cursor should create a new message normally.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + StreamConsumerConfig(cursor=" ▉"), + ) + result = await consumer._send_or_edit("Hello ▉") + assert result is True + adapter.send.assert_called_once() + + @pytest.mark.asyncio + async def test_short_text_without_cursor_sends_normally(self): + """Short text without cursor (e.g. 
final edit) should send normally.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + adapter.send = AsyncMock(return_value=send_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + StreamConsumerConfig(cursor=" ▉"), + ) + # No cursor in text — even short text should be sent + result = await consumer._send_or_edit("OK") + assert result is True + adapter.send.assert_called_once() + + @pytest.mark.asyncio + async def test_short_text_cursor_edit_existing_message_allowed(self): + """Short text + cursor editing an existing message should proceed.""" + adapter = MagicMock() + edit_result = SimpleNamespace(success=True) + adapter.edit_message = AsyncMock(return_value=edit_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + StreamConsumerConfig(cursor=" ▉"), + ) + consumer._message_id = "msg_1" # Existing message — guard should not fire + consumer._last_sent_text = "" + result = await consumer._send_or_edit("I ▉") + assert result is True + adapter.edit_message.assert_called_once() + # ── Integration: full stream run ───────────────────────────────────────── @@ -507,7 +591,7 @@ class TestSegmentBreakOnToolBoundary: config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉") consumer = GatewayStreamConsumer(adapter, "chat_123", config) - prefix = "abc" + prefix = "Hello world" tail = "x" * 620 consumer.on_delta(prefix) task = asyncio.create_task(consumer.run()) @@ -680,3 +764,202 @@ class TestCancelledConsumerSetsFlags: # Without a successful send, final_response_sent should stay False # so the normal gateway send path can deliver the response. 
assert consumer.final_response_sent is False + + +# ── Think-block filtering unit tests ───────────────────────────────────── + + +def _make_consumer() -> GatewayStreamConsumer: + """Create a bare consumer for unit-testing the filter (no adapter needed).""" + adapter = MagicMock() + return GatewayStreamConsumer(adapter, "chat_test") + + +class TestFilterAndAccumulate: + """Unit tests for _filter_and_accumulate think-block suppression.""" + + def test_plain_text_passes_through(self): + c = _make_consumer() + c._filter_and_accumulate("Hello world") + assert c._accumulated == "Hello world" + + def test_complete_think_block_stripped(self): + c = _make_consumer() + c._filter_and_accumulate("internal reasoningAnswer here") + assert c._accumulated == "Answer here" + + def test_think_block_in_middle(self): + c = _make_consumer() + c._filter_and_accumulate("Prefix\nreasoning\nSuffix") + assert c._accumulated == "Prefix\n\nSuffix" + + def test_think_block_split_across_deltas(self): + c = _make_consumer() + c._filter_and_accumulate("start of") + c._filter_and_accumulate(" reasoningvisible text") + assert c._accumulated == "visible text" + + def test_opening_tag_split_across_deltas(self): + c = _make_consumer() + c._filter_and_accumulate("hiddenshown") + assert c._accumulated == "shown" + + def test_closing_tag_split_across_deltas(self): + c = _make_consumer() + c._filter_and_accumulate("hiddenshown") + assert c._accumulated == "shown" + + def test_multiple_think_blocks(self): + c = _make_consumer() + # Consecutive blocks with no text between them — both stripped + c._filter_and_accumulate( + "block1block2visible" + ) + assert c._accumulated == "visible" + + def test_multiple_think_blocks_with_text_between(self): + """Think tag after non-whitespace is NOT a boundary (prose safety).""" + c = _make_consumer() + c._filter_and_accumulate( + "block1Ablock2B" + ) + # Second follows 'A' (not a block boundary) — treated as prose + assert "A" in c._accumulated + assert "B" in 
c._accumulated + + def test_thinking_tag_variant(self): + c = _make_consumer() + c._filter_and_accumulate("deep thoughtResult") + assert c._accumulated == "Result" + + def test_thought_tag_variant(self): + c = _make_consumer() + c._filter_and_accumulate("Gemma styleOutput") + assert c._accumulated == "Output" + + def test_reasoning_scratchpad_variant(self): + c = _make_consumer() + c._filter_and_accumulate( + "long planDone" + ) + assert c._accumulated == "Done" + + def test_case_insensitive_THINKING(self): + c = _make_consumer() + c._filter_and_accumulate("capsanswer") + assert c._accumulated == "answer" + + def test_prose_mention_not_stripped(self): + """ mentioned mid-line in prose should NOT trigger filtering.""" + c = _make_consumer() + c._filter_and_accumulate("The tag is used for reasoning") + assert "" in c._accumulated + assert "used for reasoning" in c._accumulated + + def test_prose_mention_after_text(self): + """ after non-whitespace on same line is not a block boundary.""" + c = _make_consumer() + c._filter_and_accumulate("Try using some content tags") + assert "" in c._accumulated + + def test_think_at_line_start_is_stripped(self): + """ at start of a new line IS a block boundary.""" + c = _make_consumer() + c._filter_and_accumulate("Previous line\nreasoningNext") + assert "Previous line\nNext" == c._accumulated + + def test_think_with_only_whitespace_before(self): + """ preceded by only whitespace on its line is a boundary.""" + c = _make_consumer() + c._filter_and_accumulate(" hiddenvisible") + # Leading whitespace before the tag is emitted, then block is stripped + assert c._accumulated == " visible" + + def test_flush_think_buffer_on_non_tag(self): + """Partial tag that turns out not to be a tag is flushed.""" + c = _make_consumer() + c._filter_and_accumulate("still thinking") + c._flush_think_buffer() + assert c._accumulated == "" + + def test_unclosed_think_block_suppresses(self): + """An unclosed suppresses all subsequent content.""" + c = 
_make_consumer() + c._filter_and_accumulate("Before\nreasoning that never ends...") + assert c._accumulated == "Before\n" + + def test_multiline_think_block(self): + c = _make_consumer() + c._filter_and_accumulate( + "\nLine 1\nLine 2\nLine 3\nFinal answer" + ) + assert c._accumulated == "Final answer" + + def test_segment_reset_preserves_think_state(self): + """_reset_segment_state should NOT clear think-block filter state.""" + c = _make_consumer() + c._filter_and_accumulate("start") + c._reset_segment_state() + # Still inside think block — subsequent text should be suppressed + c._filter_and_accumulate("still hiddenvisible") + assert c._accumulated == "visible" + + +class TestFilterAndAccumulateIntegration: + """Integration: verify think blocks don't leak through the full run() path.""" + + @pytest.mark.asyncio + async def test_think_block_not_sent_to_platform(self): + """Think blocks should be filtered before platform edit.""" + adapter = MagicMock() + adapter.send = AsyncMock( + return_value=SimpleNamespace(success=True, message_id="msg_1") + ) + adapter.edit_message = AsyncMock( + return_value=SimpleNamespace(success=True) + ) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_test", + StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5), + ) + + # Simulate streaming: think block then visible text + consumer.on_delta("deep reasoning here") + consumer.on_delta("The answer is 42.") + consumer.finish() + + task = asyncio.create_task(consumer.run()) + await asyncio.sleep(0.15) + + # The final text sent to the platform should NOT contain + all_calls = list(adapter.send.call_args_list) + list( + adapter.edit_message.call_args_list + ) + for call in all_calls: + args, kwargs = call + content = kwargs.get("content") or (args[0] if args else "") + assert "" not in content, f"Think tag leaked: {content}" + assert "deep reasoning" not in content + + try: + task.cancel() + await task + except asyncio.CancelledError: + 
pass diff --git a/tests/gateway/test_telegram_group_gating.py b/tests/gateway/test_telegram_group_gating.py index 99675605d..15ffca9ec 100644 --- a/tests/gateway/test_telegram_group_gating.py +++ b/tests/gateway/test_telegram_group_gating.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock from gateway.config import Platform, PlatformConfig, load_gateway_config -def _make_adapter(require_mention=None, free_response_chats=None, mention_patterns=None): +def _make_adapter(require_mention=None, free_response_chats=None, mention_patterns=None, ignored_threads=None): from gateway.platforms.telegram import TelegramAdapter extra = {} @@ -15,6 +15,8 @@ def _make_adapter(require_mention=None, free_response_chats=None, mention_patter extra["free_response_chats"] = free_response_chats if mention_patterns is not None: extra["mention_patterns"] = mention_patterns + if ignored_threads is not None: + extra["ignored_threads"] = ignored_threads adapter = object.__new__(TelegramAdapter) adapter.platform = Platform.TELEGRAM @@ -28,7 +30,16 @@ def _make_adapter(require_mention=None, free_response_chats=None, mention_patter return adapter -def _group_message(text="hello", *, chat_id=-100, reply_to_bot=False, entities=None, caption=None, caption_entities=None): +def _group_message( + text="hello", + *, + chat_id=-100, + thread_id=None, + reply_to_bot=False, + entities=None, + caption=None, + caption_entities=None, +): reply_to_message = None if reply_to_bot: reply_to_message = SimpleNamespace(from_user=SimpleNamespace(id=999)) @@ -37,6 +48,7 @@ def _group_message(text="hello", *, chat_id=-100, reply_to_bot=False, entities=N caption=caption, entities=entities or [], caption_entities=caption_entities or [], + message_thread_id=thread_id, chat=SimpleNamespace(id=chat_id, type="group"), reply_to_message=reply_to_message, ) @@ -69,6 +81,14 @@ def test_free_response_chats_bypass_mention_requirement(): assert adapter._should_process_message(_group_message("hello everyone", chat_id=-201)) is 
False +def test_ignored_threads_drop_group_messages_before_other_gates(): + adapter = _make_adapter(require_mention=False, free_response_chats=["-200"], ignored_threads=[31, "42"]) + + assert adapter._should_process_message(_group_message("hello everyone", chat_id=-200, thread_id=31)) is False + assert adapter._should_process_message(_group_message("hello everyone", chat_id=-200, thread_id=42)) is False + assert adapter._should_process_message(_group_message("hello everyone", chat_id=-200, thread_id=99)) is True + + def test_regex_mention_patterns_allow_custom_wake_words(): adapter = _make_adapter(require_mention=True, mention_patterns=[r"^\s*chompy\b"]) @@ -108,3 +128,23 @@ def test_config_bridges_telegram_group_settings(monkeypatch, tmp_path): assert __import__("os").environ["TELEGRAM_REQUIRE_MENTION"] == "true" assert json.loads(__import__("os").environ["TELEGRAM_MENTION_PATTERNS"]) == [r"^\s*chompy\b"] assert __import__("os").environ["TELEGRAM_FREE_RESPONSE_CHATS"] == "-123" + + +def test_config_bridges_telegram_ignored_threads(monkeypatch, tmp_path): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text( + "telegram:\n" + " ignored_threads:\n" + " - 31\n" + " - \"42\"\n", + encoding="utf-8", + ) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.delenv("TELEGRAM_IGNORED_THREADS", raising=False) + + config = load_gateway_config() + + assert config is not None + assert __import__("os").environ["TELEGRAM_IGNORED_THREADS"] == "31,42" diff --git a/tests/gateway/test_ws_auth_retry.py b/tests/gateway/test_ws_auth_retry.py index beef6722e..0da397933 100644 --- a/tests/gateway/test_ws_auth_retry.py +++ b/tests/gateway/test_ws_auth_retry.py @@ -130,13 +130,17 @@ class TestMatrixSyncAuthRetry: sync_count = 0 - async def fake_sync(timeout=30000): + async def fake_sync(timeout=30000, since=None): nonlocal sync_count sync_count += 1 return SyncError("M_UNKNOWN_TOKEN: Invalid access token") adapter._client = 
MagicMock() adapter._client.sync = fake_sync + adapter._client.sync_store = MagicMock() + adapter._client.sync_store.get_next_batch = AsyncMock(return_value=None) + adapter._pending_megolm = [] + adapter._joined_rooms = set() async def run(): import sys @@ -157,13 +161,17 @@ class TestMatrixSyncAuthRetry: call_count = 0 - async def fake_sync(timeout=30000): + async def fake_sync(timeout=30000, since=None): nonlocal call_count call_count += 1 raise RuntimeError("HTTP 401 Unauthorized") adapter._client = MagicMock() adapter._client.sync = fake_sync + adapter._client.sync_store = MagicMock() + adapter._client.sync_store.get_next_batch = AsyncMock(return_value=None) + adapter._pending_megolm = [] + adapter._joined_rooms = set() async def run(): import types @@ -188,7 +196,7 @@ class TestMatrixSyncAuthRetry: call_count = 0 - async def fake_sync(timeout=30000): + async def fake_sync(timeout=30000, since=None): nonlocal call_count call_count += 1 if call_count >= 2: @@ -198,6 +206,10 @@ class TestMatrixSyncAuthRetry: adapter._client = MagicMock() adapter._client.sync = fake_sync + adapter._client.sync_store = MagicMock() + adapter._client.sync_store.get_next_batch = AsyncMock(return_value=None) + adapter._pending_megolm = [] + adapter._joined_rooms = set() async def run(): import types diff --git a/tests/hermes_cli/test_api_key_providers.py b/tests/hermes_cli/test_api_key_providers.py index 0e1183471..0e8badc6e 100644 --- a/tests/hermes_cli/test_api_key_providers.py +++ b/tests/hermes_cli/test_api_key_providers.py @@ -44,7 +44,7 @@ class TestProviderRegistry: ("kimi-coding", "Kimi / Moonshot", "api_key"), ("minimax", "MiniMax", "api_key"), ("minimax-cn", "MiniMax (China)", "api_key"), - ("ai-gateway", "AI Gateway", "api_key"), + ("ai-gateway", "Vercel AI Gateway", "api_key"), ("kilocode", "Kilo Code", "api_key"), ]) def test_provider_registered(self, provider_id, name, auth_type): diff --git a/tests/hermes_cli/test_auth_commands.py b/tests/hermes_cli/test_auth_commands.py 
index 2ebdb1cc7..b26757a22 100644 --- a/tests/hermes_cli/test_auth_commands.py +++ b/tests/hermes_cli/test_auth_commands.py @@ -238,6 +238,10 @@ def test_auth_remove_reindexes_priorities(tmp_path, monkeypatch): def test_auth_remove_accepts_label_target(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + monkeypatch.setattr( + "agent.credential_pool._seed_from_singletons", + lambda provider, entries: (False, set()), + ) _write_auth_store( tmp_path, { @@ -281,6 +285,10 @@ def test_auth_remove_accepts_label_target(tmp_path, monkeypatch): def test_auth_remove_prefers_exact_numeric_label_over_index(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + monkeypatch.setattr( + "agent.credential_pool._seed_from_singletons", + lambda provider, entries: (False, set()), + ) _write_auth_store( tmp_path, { diff --git a/tests/hermes_cli/test_auth_provider_gate.py b/tests/hermes_cli/test_auth_provider_gate.py index 2eacb71be..f65ae71b8 100644 --- a/tests/hermes_cli/test_auth_provider_gate.py +++ b/tests/hermes_cli/test_auth_provider_gate.py @@ -18,6 +18,13 @@ def _write_auth_store(tmp_path, payload: dict) -> None: (hermes_home / "auth.json").write_text(json.dumps(payload, indent=2)) +@pytest.fixture(autouse=True) +def _clean_anthropic_env(monkeypatch): + """Strip Anthropic env vars so CI secrets don't leak into tests.""" + for key in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"): + monkeypatch.delenv(key, raising=False) + + def test_returns_false_when_no_config(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) (tmp_path / "hermes").mkdir(parents=True, exist_ok=True) diff --git a/tests/hermes_cli/test_model_validation.py b/tests/hermes_cli/test_model_validation.py index 3b83b81da..4ce318395 100644 --- a/tests/hermes_cli/test_model_validation.py +++ b/tests/hermes_cli/test_model_validation.py @@ -436,9 +436,23 @@ class TestValidateApiNotFound: def 
test_warning_includes_suggestions(self): result = _validate("anthropic/claude-opus-4.5") - assert result["accepted"] is False - assert result["persist"] is False - assert "Similar models" in result["message"] + assert result["accepted"] is True + # Close match auto-corrects; less similar inputs show suggestions + assert "Auto-corrected" in result["message"] or "Similar models" in result["message"] + + def test_auto_correction_returns_corrected_model(self): + """When a very close match exists, validate returns corrected_model.""" + result = _validate("anthropic/claude-opus-4.5") + assert result["accepted"] is True + assert result.get("corrected_model") == "anthropic/claude-opus-4.6" + assert result["recognized"] is True + + def test_dissimilar_model_shows_suggestions_not_autocorrect(self): + """Models too different for auto-correction still get suggestions.""" + result = _validate("anthropic/claude-nonexistent") + assert result["accepted"] is True + assert result.get("corrected_model") is None + assert "not found" in result["message"] # -- validate — API unreachable — reject with guidance ---------------- @@ -487,3 +501,40 @@ class TestValidateApiFallback: assert result["persist"] is False assert "http://localhost:8000/v1/models" in result["message"] assert "http://localhost:8000/v1" in result["message"] + + +# -- validate — Codex auto-correction ------------------------------------------ + +class TestValidateCodexAutoCorrection: + """Auto-correction for typos on openai-codex provider.""" + + def test_missing_dash_auto_corrects(self): + """gpt5.3-codex (missing dash) auto-corrects to gpt-5.3-codex.""" + codex_models = ["gpt-5.4-mini", "gpt-5.4", "gpt-5.3-codex", + "gpt-5.2-codex", "gpt-5.1-codex-max"] + with patch("hermes_cli.models.provider_model_ids", return_value=codex_models): + result = validate_requested_model("gpt5.3-codex", "openai-codex") + assert result["accepted"] is True + assert result["recognized"] is True + assert result["corrected_model"] == 
"gpt-5.3-codex" + assert "Auto-corrected" in result["message"] + + def test_exact_match_no_correction(self): + """Exact model name does not trigger auto-correction.""" + codex_models = ["gpt-5.4-mini", "gpt-5.4", "gpt-5.3-codex"] + with patch("hermes_cli.models.provider_model_ids", return_value=codex_models): + result = validate_requested_model("gpt-5.3-codex", "openai-codex") + assert result["accepted"] is True + assert result["recognized"] is True + assert result.get("corrected_model") is None + assert result["message"] is None + + def test_very_different_name_falls_to_suggestions(self): + """Names too different for auto-correction get the suggestion list.""" + codex_models = ["gpt-5.4-mini", "gpt-5.4", "gpt-5.3-codex"] + with patch("hermes_cli.models.provider_model_ids", return_value=codex_models): + result = validate_requested_model("totally-wrong", "openai-codex") + assert result["accepted"] is True + assert result["recognized"] is False + assert result.get("corrected_model") is None + assert "not found" in result["message"] diff --git a/tests/hermes_cli/test_opencode_go_in_model_list.py b/tests/hermes_cli/test_opencode_go_in_model_list.py index 493d41b99..7f0815233 100644 --- a/tests/hermes_cli/test_opencode_go_in_model_list.py +++ b/tests/hermes_cli/test_opencode_go_in_model_list.py @@ -16,8 +16,10 @@ def test_opencode_go_appears_when_api_key_set(): assert opencode_go is not None, "opencode-go should appear when OPENCODE_GO_API_KEY is set" assert opencode_go["models"] == ["glm-5", "kimi-k2.5", "mimo-v2-pro", "mimo-v2-omni", "minimax-m2.7", "minimax-m2.5"] - # opencode-go is in PROVIDER_TO_MODELS_DEV, so it appears as "built-in" (Part 1) - assert opencode_go["source"] == "built-in" + # opencode-go can appear as "built-in" (from PROVIDER_TO_MODELS_DEV when + # models.dev is reachable) or "hermes" (from HERMES_OVERLAYS fallback when + # the API is unavailable, e.g. in CI). 
+ assert opencode_go["source"] in ("built-in", "hermes") def test_opencode_go_not_appears_when_no_creds(): diff --git a/tests/hermes_cli/test_plugins.py b/tests/hermes_cli/test_plugins.py index ec29a4e90..7be1be617 100644 --- a/tests/hermes_cli/test_plugins.py +++ b/tests/hermes_cli/test_plugins.py @@ -18,6 +18,7 @@ from hermes_cli.plugins import ( PluginManager, PluginManifest, get_plugin_manager, + get_pre_tool_call_block_message, discover_plugins, invoke_hook, ) @@ -310,6 +311,50 @@ class TestPluginHooks: assert any("on_banana" in record.message for record in caplog.records) +class TestPreToolCallBlocking: + """Tests for the pre_tool_call block directive helper.""" + + def test_block_message_returned_for_valid_directive(self, monkeypatch): + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: [{"action": "block", "message": "blocked by plugin"}], + ) + assert get_pre_tool_call_block_message("todo", {}, task_id="t1") == "blocked by plugin" + + def test_invalid_returns_are_ignored(self, monkeypatch): + """Various malformed hook returns should not trigger a block.""" + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: [ + "block", # not a dict + 123, # not a dict + {"action": "block"}, # missing message + {"action": "deny", "message": "nope"}, # wrong action + {"message": "missing action"}, # no action key + {"action": "block", "message": 123}, # message not str + ], + ) + assert get_pre_tool_call_block_message("todo", {}, task_id="t1") is None + + def test_none_when_no_hooks(self, monkeypatch): + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: [], + ) + assert get_pre_tool_call_block_message("web_search", {"q": "test"}) is None + + def test_first_valid_block_wins(self, monkeypatch): + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: [ + {"action": "allow"}, + {"action": "block", "message": "first blocker"}, + 
{"action": "block", "message": "second blocker"}, + ], + ) + assert get_pre_tool_call_block_message("terminal", {}) == "first blocker" + + # ── TestPluginContext ────────────────────────────────────────────────────── diff --git a/tests/hermes_cli/test_skin_engine.py b/tests/hermes_cli/test_skin_engine.py index b11d168c7..aadcde3a6 100644 --- a/tests/hermes_cli/test_skin_engine.py +++ b/tests/hermes_cli/test_skin_engine.py @@ -78,6 +78,28 @@ class TestBuiltinSkins: assert skin.name == "slate" assert skin.get_color("banner_title") == "#7eb8f6" + def test_daylight_skin_loads(self): + from hermes_cli.skin_engine import load_skin + + skin = load_skin("daylight") + assert skin.name == "daylight" + assert skin.tool_prefix == "│" + assert skin.get_color("banner_title") == "#0F172A" + assert skin.get_color("status_bar_bg") == "#E5EDF8" + assert skin.get_color("voice_status_bg") == "#E5EDF8" + assert skin.get_color("completion_menu_bg") == "#F8FAFC" + assert skin.get_color("completion_menu_current_bg") == "#DBEAFE" + assert skin.get_color("completion_menu_meta_bg") == "#EEF2FF" + assert skin.get_color("completion_menu_meta_current_bg") == "#BFDBFE" + + def test_warm_lightmode_skin_loads(self): + from hermes_cli.skin_engine import load_skin + + skin = load_skin("warm-lightmode") + assert skin.name == "warm-lightmode" + assert skin.get_color("banner_text") == "#2C1810" + assert skin.get_color("completion_menu_bg") == "#F5EFE0" + def test_unknown_skin_falls_back_to_default(self): from hermes_cli.skin_engine import load_skin skin = load_skin("nonexistent_skin_xyz") @@ -114,6 +136,8 @@ class TestSkinManagement: assert "ares" in names assert "mono" in names assert "slate" in names + assert "daylight" in names + assert "warm-lightmode" in names for s in skins: assert "source" in s assert s["source"] == "builtin" @@ -242,6 +266,15 @@ class TestCliBrandingHelpers: "completion-menu.completion.current", "completion-menu.meta.completion", "completion-menu.meta.completion.current", + 
"status-bar", + "status-bar-strong", + "status-bar-dim", + "status-bar-good", + "status-bar-warn", + "status-bar-bad", + "status-bar-critical", + "voice-status", + "voice-status-recording", "clarify-border", "clarify-title", "clarify-question", @@ -277,3 +310,9 @@ class TestCliBrandingHelpers: assert overrides["clarify-title"] == f"{skin.get_color('banner_title')} bold" assert overrides["sudo-prompt"] == f"{skin.get_color('ui_error')} bold" assert overrides["approval-title"] == f"{skin.get_color('ui_warn')} bold" + + set_active_skin("daylight") + skin = get_active_skin() + overrides = get_prompt_toolkit_style_overrides() + assert overrides["status-bar"] == f"bg:{skin.get_color('status_bar_bg')} {skin.get_color('banner_text')}" + assert overrides["voice-status"] == f"bg:{skin.get_color('voice_status_bg')} {skin.get_color('ui_label')}" diff --git a/tests/hermes_cli/test_web_server.py b/tests/hermes_cli/test_web_server.py index ffa614cd9..1bbbdba1c 100644 --- a/tests/hermes_cli/test_web_server.py +++ b/tests/hermes_cli/test_web_server.py @@ -673,3 +673,282 @@ class TestNewEndpoints: resp = self.client.get("/api/auth/session-token") assert resp.status_code == 200 assert resp.json()["token"] == _SESSION_TOKEN + + +# --------------------------------------------------------------------------- +# Model context length: normalize/denormalize + /api/model/info +# --------------------------------------------------------------------------- + + +class TestModelContextLength: + """Tests for model_context_length in normalize/denormalize and /api/model/info.""" + + def test_normalize_extracts_context_length_from_dict(self): + """normalize should surface context_length from model dict.""" + from hermes_cli.web_server import _normalize_config_for_web + + cfg = { + "model": { + "default": "anthropic/claude-opus-4.6", + "provider": "openrouter", + "context_length": 200000, + } + } + result = _normalize_config_for_web(cfg) + assert result["model"] == "anthropic/claude-opus-4.6" + assert 
result["model_context_length"] == 200000 + + def test_normalize_bare_string_model_yields_zero(self): + """normalize should set model_context_length=0 for bare string model.""" + from hermes_cli.web_server import _normalize_config_for_web + + result = _normalize_config_for_web({"model": "anthropic/claude-sonnet-4"}) + assert result["model"] == "anthropic/claude-sonnet-4" + assert result["model_context_length"] == 0 + + def test_normalize_dict_without_context_length_yields_zero(self): + """normalize should default to 0 when model dict has no context_length.""" + from hermes_cli.web_server import _normalize_config_for_web + + cfg = {"model": {"default": "test/model", "provider": "openrouter"}} + result = _normalize_config_for_web(cfg) + assert result["model_context_length"] == 0 + + def test_normalize_non_int_context_length_yields_zero(self): + """normalize should coerce non-int context_length to 0.""" + from hermes_cli.web_server import _normalize_config_for_web + + cfg = {"model": {"default": "test/model", "context_length": "invalid"}} + result = _normalize_config_for_web(cfg) + assert result["model_context_length"] == 0 + + def test_denormalize_writes_context_length_into_model_dict(self): + """denormalize should write model_context_length back into model dict.""" + from hermes_cli.web_server import _denormalize_config_from_web + from hermes_cli.config import save_config + + # Set up disk config with model as a dict + save_config({ + "model": {"default": "anthropic/claude-opus-4.6", "provider": "openrouter"} + }) + + result = _denormalize_config_from_web({ + "model": "anthropic/claude-opus-4.6", + "model_context_length": 100000, + }) + assert isinstance(result["model"], dict) + assert result["model"]["context_length"] == 100000 + assert "model_context_length" not in result # virtual field removed + + def test_denormalize_zero_removes_context_length(self): + """denormalize with model_context_length=0 should remove context_length key.""" + from hermes_cli.web_server 
import _denormalize_config_from_web + from hermes_cli.config import save_config + + save_config({ + "model": { + "default": "anthropic/claude-opus-4.6", + "provider": "openrouter", + "context_length": 50000, + } + }) + + result = _denormalize_config_from_web({ + "model": "anthropic/claude-opus-4.6", + "model_context_length": 0, + }) + assert isinstance(result["model"], dict) + assert "context_length" not in result["model"] + + def test_denormalize_upgrades_bare_string_to_dict(self): + """denormalize should upgrade bare string model to dict when context_length set.""" + from hermes_cli.web_server import _denormalize_config_from_web + from hermes_cli.config import save_config + + # Disk has model as bare string + save_config({"model": "anthropic/claude-sonnet-4"}) + + result = _denormalize_config_from_web({ + "model": "anthropic/claude-sonnet-4", + "model_context_length": 65000, + }) + assert isinstance(result["model"], dict) + assert result["model"]["default"] == "anthropic/claude-sonnet-4" + assert result["model"]["context_length"] == 65000 + + def test_denormalize_bare_string_stays_string_when_zero(self): + """denormalize should keep bare string model as string when context_length=0.""" + from hermes_cli.web_server import _denormalize_config_from_web + from hermes_cli.config import save_config + + save_config({"model": "anthropic/claude-sonnet-4"}) + + result = _denormalize_config_from_web({ + "model": "anthropic/claude-sonnet-4", + "model_context_length": 0, + }) + assert result["model"] == "anthropic/claude-sonnet-4" + + def test_denormalize_coerces_string_context_length(self): + """denormalize should handle string model_context_length from frontend.""" + from hermes_cli.web_server import _denormalize_config_from_web + from hermes_cli.config import save_config + + save_config({ + "model": {"default": "test/model", "provider": "openrouter"} + }) + + result = _denormalize_config_from_web({ + "model": "test/model", + "model_context_length": "32000", + }) + assert 
isinstance(result["model"], dict) + assert result["model"]["context_length"] == 32000 + + +class TestModelContextLengthSchema: + """Tests for model_context_length placement in CONFIG_SCHEMA.""" + + def test_schema_has_model_context_length(self): + from hermes_cli.web_server import CONFIG_SCHEMA + assert "model_context_length" in CONFIG_SCHEMA + + def test_schema_model_context_length_after_model(self): + """model_context_length should appear immediately after model in schema.""" + from hermes_cli.web_server import CONFIG_SCHEMA + keys = list(CONFIG_SCHEMA.keys()) + model_idx = keys.index("model") + assert keys[model_idx + 1] == "model_context_length" + + def test_schema_model_context_length_is_number(self): + from hermes_cli.web_server import CONFIG_SCHEMA + entry = CONFIG_SCHEMA["model_context_length"] + assert entry["type"] == "number" + assert "category" in entry + + +class TestModelInfoEndpoint: + """Tests for GET /api/model/info endpoint.""" + + @pytest.fixture(autouse=True) + def _setup(self): + try: + from starlette.testclient import TestClient + except ImportError: + pytest.skip("fastapi/starlette not installed") + from hermes_cli.web_server import app + self.client = TestClient(app) + + def test_model_info_returns_200(self): + resp = self.client.get("/api/model/info") + assert resp.status_code == 200 + data = resp.json() + assert "model" in data + assert "provider" in data + assert "auto_context_length" in data + assert "config_context_length" in data + assert "effective_context_length" in data + assert "capabilities" in data + + def test_model_info_with_dict_config(self, monkeypatch): + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws, "load_config", lambda: { + "model": { + "default": "anthropic/claude-opus-4.6", + "provider": "openrouter", + "context_length": 100000, + } + }) + + with patch("agent.model_metadata.get_model_context_length", return_value=200000): + resp = self.client.get("/api/model/info") + + data = resp.json() + assert 
data["model"] == "anthropic/claude-opus-4.6" + assert data["provider"] == "openrouter" + assert data["auto_context_length"] == 200000 + assert data["config_context_length"] == 100000 + assert data["effective_context_length"] == 100000 # override wins + + def test_model_info_auto_detect_when_no_override(self, monkeypatch): + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws, "load_config", lambda: { + "model": {"default": "anthropic/claude-opus-4.6", "provider": "openrouter"} + }) + + with patch("agent.model_metadata.get_model_context_length", return_value=200000): + resp = self.client.get("/api/model/info") + + data = resp.json() + assert data["auto_context_length"] == 200000 + assert data["config_context_length"] == 0 + assert data["effective_context_length"] == 200000 # auto wins + + def test_model_info_empty_model(self, monkeypatch): + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws, "load_config", lambda: {"model": ""}) + + resp = self.client.get("/api/model/info") + data = resp.json() + assert data["model"] == "" + assert data["effective_context_length"] == 0 + + def test_model_info_bare_string_model(self, monkeypatch): + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws, "load_config", lambda: { + "model": "anthropic/claude-sonnet-4" + }) + + with patch("agent.model_metadata.get_model_context_length", return_value=200000): + resp = self.client.get("/api/model/info") + + data = resp.json() + assert data["model"] == "anthropic/claude-sonnet-4" + assert data["provider"] == "" + assert data["config_context_length"] == 0 + assert data["effective_context_length"] == 200000 + + def test_model_info_capabilities(self, monkeypatch): + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws, "load_config", lambda: { + "model": {"default": "anthropic/claude-opus-4.6", "provider": "openrouter"} + }) + + mock_caps = MagicMock() + mock_caps.supports_tools = True + mock_caps.supports_vision = True + mock_caps.supports_reasoning = 
True + mock_caps.context_window = 200000 + mock_caps.max_output_tokens = 32000 + mock_caps.model_family = "claude-opus" + + with patch("agent.model_metadata.get_model_context_length", return_value=200000), \ + patch("agent.models_dev.get_model_capabilities", return_value=mock_caps): + resp = self.client.get("/api/model/info") + + caps = resp.json()["capabilities"] + assert caps["supports_tools"] is True + assert caps["supports_vision"] is True + assert caps["supports_reasoning"] is True + assert caps["max_output_tokens"] == 32000 + assert caps["model_family"] == "claude-opus" + + def test_model_info_graceful_on_metadata_error(self, monkeypatch): + """Endpoint should return zeros on import/resolution errors, not 500.""" + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws, "load_config", lambda: { + "model": "some/obscure-model" + }) + + with patch("agent.model_metadata.get_model_context_length", side_effect=Exception("boom")): + resp = self.client.get("/api/model/info") + + assert resp.status_code == 200 + data = resp.json() + assert data["auto_context_length"] == 0 diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 568077fd7..d71e6a625 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -1442,7 +1442,7 @@ class TestConcurrentToolExecution: tool_call_id=None, session_id=agent.session_id, enabled_tools=list(agent.valid_tool_names), - + skip_pre_tool_call_hook=True, ) assert result == "result" @@ -1489,6 +1489,73 @@ class TestConcurrentToolExecution: mock_todo.assert_called_once() assert "ok" in result + def test_invoke_tool_blocked_returns_error_and_skips_execution(self, agent, monkeypatch): + """_invoke_tool should return error JSON when a plugin blocks the tool.""" + monkeypatch.setattr( + "hermes_cli.plugins.get_pre_tool_call_block_message", + lambda *args, **kwargs: "Blocked by test policy", + ) + with patch("tools.todo_tool.todo_tool", side_effect=AssertionError("should not 
run")) as mock_todo: + result = agent._invoke_tool("todo", {"todos": []}, "task-1") + + assert json.loads(result) == {"error": "Blocked by test policy"} + mock_todo.assert_not_called() + + def test_invoke_tool_blocked_skips_handle_function_call(self, agent, monkeypatch): + """Blocked registry tools should not reach handle_function_call.""" + monkeypatch.setattr( + "hermes_cli.plugins.get_pre_tool_call_block_message", + lambda *args, **kwargs: "Blocked", + ) + with patch("run_agent.handle_function_call", side_effect=AssertionError("should not run")): + result = agent._invoke_tool("web_search", {"q": "test"}, "task-1") + + assert json.loads(result) == {"error": "Blocked"} + + def test_sequential_blocked_tool_skips_checkpoints_and_callbacks(self, agent, monkeypatch): + """Sequential path: blocked tool should not trigger checkpoints or start callbacks.""" + tool_call = _mock_tool_call(name="write_file", + arguments='{"path":"test.txt","content":"hello"}', + call_id="c1") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call]) + messages = [] + + monkeypatch.setattr( + "hermes_cli.plugins.get_pre_tool_call_block_message", + lambda *args, **kwargs: "Blocked by policy", + ) + agent._checkpoint_mgr.enabled = True + agent._checkpoint_mgr.ensure_checkpoint = MagicMock( + side_effect=AssertionError("checkpoint should not run") + ) + + starts = [] + agent.tool_start_callback = lambda *a: starts.append(a) + + with patch("run_agent.handle_function_call", side_effect=AssertionError("should not run")): + agent._execute_tool_calls_sequential(mock_msg, messages, "task-1") + + agent._checkpoint_mgr.ensure_checkpoint.assert_not_called() + assert starts == [] + assert len(messages) == 1 + assert messages[0]["role"] == "tool" + assert json.loads(messages[0]["content"]) == {"error": "Blocked by policy"} + + def test_blocked_memory_tool_does_not_reset_counter(self, agent, monkeypatch): + """Blocked memory tool should not reset the nudge counter.""" + agent._turns_since_memory 
= 5 + monkeypatch.setattr( + "hermes_cli.plugins.get_pre_tool_call_block_message", + lambda *args, **kwargs: "Blocked", + ) + with patch("tools.memory_tool.memory_tool", side_effect=AssertionError("should not run")): + result = agent._invoke_tool( + "memory", {"action": "add", "target": "memory", "content": "x"}, "task-1", + ) + + assert json.loads(result) == {"error": "Blocked"} + assert agent._turns_since_memory == 5 + class TestPathsOverlap: """Unit tests for the _paths_overlap helper.""" diff --git a/tests/run_agent/test_run_agent_codex_responses.py b/tests/run_agent/test_run_agent_codex_responses.py index 0fca9e4df..785d85886 100644 --- a/tests/run_agent/test_run_agent_codex_responses.py +++ b/tests/run_agent/test_run_agent_codex_responses.py @@ -287,6 +287,69 @@ def test_build_api_kwargs_codex(monkeypatch): assert "extra_body" not in kwargs +def test_build_api_kwargs_codex_clamps_minimal_effort(monkeypatch): + """'minimal' reasoning effort is clamped to 'low' on the Responses API. + + GPT-5.4 supports none/low/medium/high/xhigh but NOT 'minimal'. + Users may configure 'minimal' via OpenRouter conventions, so the Codex + Responses path must clamp it to the nearest supported level. 
+ """ + _patch_agent_bootstrap(monkeypatch) + + agent = run_agent.AIAgent( + model="gpt-5-codex", + base_url="https://chatgpt.com/backend-api/codex", + api_key="codex-token", + quiet_mode=True, + max_iterations=4, + skip_context_files=True, + skip_memory=True, + reasoning_config={"enabled": True, "effort": "minimal"}, + ) + agent._cleanup_task_resources = lambda task_id: None + agent._persist_session = lambda messages, history=None: None + agent._save_trajectory = lambda messages, user_message, completed: None + agent._save_session_log = lambda messages: None + + kwargs = agent._build_api_kwargs( + [ + {"role": "system", "content": "You are Hermes."}, + {"role": "user", "content": "Ping"}, + ] + ) + + assert kwargs["reasoning"]["effort"] == "low" + + +def test_build_api_kwargs_codex_preserves_supported_efforts(monkeypatch): + """Effort levels natively supported by the Responses API pass through unchanged.""" + _patch_agent_bootstrap(monkeypatch) + + for effort in ("low", "medium", "high", "xhigh"): + agent = run_agent.AIAgent( + model="gpt-5-codex", + base_url="https://chatgpt.com/backend-api/codex", + api_key="codex-token", + quiet_mode=True, + max_iterations=4, + skip_context_files=True, + skip_memory=True, + reasoning_config={"enabled": True, "effort": effort}, + ) + agent._cleanup_task_resources = lambda task_id: None + agent._persist_session = lambda messages, history=None: None + agent._save_trajectory = lambda messages, user_message, completed: None + agent._save_session_log = lambda messages: None + + kwargs = agent._build_api_kwargs( + [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + ) + assert kwargs["reasoning"]["effort"] == effort, f"{effort} should pass through unchanged" + + def test_build_api_kwargs_copilot_responses_omits_openai_only_fields(monkeypatch): agent = _build_copilot_agent(monkeypatch) kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}]) diff --git a/tests/test_model_tools.py 
b/tests/test_model_tools.py index 5e3b1d6ce..bb8a79ab0 100644 --- a/tests/test_model_tools.py +++ b/tests/test_model_tools.py @@ -91,6 +91,91 @@ class TestAgentLoopTools: assert "terminal" not in _AGENT_LOOP_TOOLS +# ========================================================================= +# Pre-tool-call blocking via plugin hooks +# ========================================================================= + +class TestPreToolCallBlocking: + """Verify that pre_tool_call hooks can block tool execution.""" + + def test_blocked_tool_returns_error_and_skips_dispatch(self, monkeypatch): + def fake_invoke_hook(hook_name, **kwargs): + if hook_name == "pre_tool_call": + return [{"action": "block", "message": "Blocked by policy"}] + return [] + + dispatch_called = False + _orig_dispatch = None + + def fake_dispatch(*args, **kwargs): + nonlocal dispatch_called + dispatch_called = True + raise AssertionError("dispatch should not run when blocked") + + monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook) + monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch) + + result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1")) + assert result == {"error": "Blocked by policy"} + assert not dispatch_called + + def test_blocked_tool_skips_read_loop_notification(self, monkeypatch): + notifications = [] + + def fake_invoke_hook(hook_name, **kwargs): + if hook_name == "pre_tool_call": + return [{"action": "block", "message": "Blocked"}] + return [] + + monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook) + monkeypatch.setattr("model_tools.registry.dispatch", + lambda *a, **kw: (_ for _ in ()).throw(AssertionError("should not run"))) + monkeypatch.setattr("tools.file_tools.notify_other_tool_call", + lambda task_id: notifications.append(task_id)) + + result = json.loads(handle_function_call("web_search", {"q": "test"}, task_id="t1")) + assert result == {"error": "Blocked"} + assert notifications == [] + + 
def test_invalid_hook_returns_do_not_block(self, monkeypatch): + """Malformed hook returns should be ignored — tool executes normally.""" + def fake_invoke_hook(hook_name, **kwargs): + if hook_name == "pre_tool_call": + return [ + "block", + {"action": "block"}, # missing message + {"action": "deny", "message": "nope"}, + ] + return [] + + monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook) + monkeypatch.setattr("model_tools.registry.dispatch", + lambda *a, **kw: json.dumps({"ok": True})) + + result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1")) + assert result == {"ok": True} + + def test_skip_flag_prevents_double_block_check(self, monkeypatch): + """When skip_pre_tool_call_hook=True, blocking is not checked (caller did it).""" + hook_calls = [] + + def fake_invoke_hook(hook_name, **kwargs): + hook_calls.append(hook_name) + return [] + + monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook) + monkeypatch.setattr("model_tools.registry.dispatch", + lambda *a, **kw: json.dumps({"ok": True})) + + handle_function_call("web_search", {"q": "test"}, task_id="t1", + skip_pre_tool_call_hook=True) + + # Hook still fires for observer notification, but get_pre_tool_call_block_message + # is not called — invoke_hook fires directly in the skip=True branch. 
+ assert "pre_tool_call" in hook_calls + assert "post_tool_call" in hook_calls + + # ========================================================================= # Legacy toolset map # ========================================================================= diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index 13c345070..774bf9893 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -1,7 +1,6 @@ """Tests for toolsets.py — toolset resolution, validation, and composition.""" -import pytest - +from tools.registry import ToolRegistry from toolsets import ( TOOLSETS, get_toolset, @@ -15,6 +14,18 @@ from toolsets import ( ) +def _dummy_handler(args, **kwargs): + return "{}" + + +def _make_schema(name: str, description: str = "test tool"): + return { + "name": name, + "description": description, + "parameters": {"type": "object", "properties": {}}, + } + + class TestGetToolset: def test_known_toolset(self): ts = get_toolset("web") @@ -52,6 +63,25 @@ class TestResolveToolset: def test_unknown_toolset_returns_empty(self): assert resolve_toolset("nonexistent") == [] + def test_plugin_toolset_uses_registry_snapshot(self, monkeypatch): + reg = ToolRegistry() + reg.register( + name="plugin_b", + toolset="plugin_example", + schema=_make_schema("plugin_b", "B"), + handler=_dummy_handler, + ) + reg.register( + name="plugin_a", + toolset="plugin_example", + schema=_make_schema("plugin_a", "A"), + handler=_dummy_handler, + ) + + monkeypatch.setattr("tools.registry.registry", reg) + + assert resolve_toolset("plugin_example") == ["plugin_a", "plugin_b"] + def test_all_alias(self): tools = resolve_toolset("all") assert len(tools) > 10 # Should resolve all tools from all toolsets @@ -141,3 +171,20 @@ class TestToolsetConsistency: # All platform toolsets should be identical for ts in tool_sets[1:]: assert ts == tool_sets[0] + + +class TestPluginToolsets: + def test_get_all_toolsets_includes_plugin_toolset(self, monkeypatch): + reg = ToolRegistry() + 
reg.register( + name="plugin_tool", + toolset="plugin_bundle", + schema=_make_schema("plugin_tool", "Plugin tool"), + handler=_dummy_handler, + ) + + monkeypatch.setattr("tools.registry.registry", reg) + + all_toolsets = get_all_toolsets() + assert "plugin_bundle" in all_toolsets + assert all_toolsets["plugin_bundle"]["tools"] == ["plugin_tool"] diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index a269218c2..d2fbc7c10 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -380,7 +380,7 @@ class TestStubSchemaDrift(unittest.TestCase): # Parameters that are internal (injected by the handler, not user-facing) _INTERNAL_PARAMS = {"task_id", "user_task"} # Parameters intentionally blocked in the sandbox - _BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete"} + _BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete", "watch_patterns"} def test_stubs_cover_all_schema_params(self): """Every user-facing parameter in the real schema must appear in the diff --git a/tests/tools/test_interrupt.py b/tests/tools/test_interrupt.py index 13b5041d6..61a898ac3 100644 --- a/tests/tools/test_interrupt.py +++ b/tests/tools/test_interrupt.py @@ -29,8 +29,11 @@ class TestInterruptModule: def test_thread_safety(self): """Set from one thread targeting another thread's ident.""" - from tools.interrupt import set_interrupt, is_interrupted + from tools.interrupt import set_interrupt, is_interrupted, _interrupted_threads, _lock set_interrupt(False) + # Clear any stale thread idents left by prior tests in this worker. + with _lock: + _interrupted_threads.clear() seen = {"value": False} diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 663895c0b..883bbe318 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -6,6 +6,8 @@ All tests use mocks -- no real MCP servers or subprocesses are started. 
import asyncio import json import os +import threading +import time from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch @@ -255,6 +257,77 @@ class TestToolHandler: finally: _servers.pop("test_srv", None) + def test_interrupted_call_returns_interrupted_error(self): + from tools.mcp_tool import _make_tool_handler, _servers + + mock_session = MagicMock() + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server + + try: + handler = _make_tool_handler("test_srv", "greet", 120) + def _interrupting_run(coro, timeout=30): + coro.close() + raise InterruptedError("User sent a new message") + with patch( + "tools.mcp_tool._run_on_mcp_loop", + side_effect=_interrupting_run, + ): + result = json.loads(handler({})) + assert result == {"error": "MCP call interrupted: user sent a new message"} + finally: + _servers.pop("test_srv", None) + + +class TestRunOnMCPLoopInterrupts: + def test_interrupt_cancels_waiting_mcp_call(self): + import tools.mcp_tool as mcp_mod + from tools.interrupt import set_interrupt + + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + + cancelled = threading.Event() + + async def _slow_call(): + try: + await asyncio.sleep(5) + return "done" + except asyncio.CancelledError: + cancelled.set() + raise + + old_loop = mcp_mod._mcp_loop + old_thread = mcp_mod._mcp_thread + mcp_mod._mcp_loop = loop + mcp_mod._mcp_thread = thread + + waiter_tid = threading.current_thread().ident + + def _interrupt_soon(): + time.sleep(0.2) + set_interrupt(True, waiter_tid) + + interrupter = threading.Thread(target=_interrupt_soon, daemon=True) + interrupter.start() + + try: + with pytest.raises(InterruptedError, match="User sent a new message"): + mcp_mod._run_on_mcp_loop(_slow_call(), timeout=2) + + deadline = time.time() + 2 + while time.time() < deadline and not cancelled.is_set(): + time.sleep(0.05) + assert cancelled.is_set() + finally: + 
set_interrupt(False, waiter_tid) + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=2) + loop.close() + mcp_mod._mcp_loop = old_loop + mcp_mod._mcp_thread = old_thread + # --------------------------------------------------------------------------- # Tool registration (discovery + register) diff --git a/tests/tools/test_registry.py b/tests/tools/test_registry.py index 455e9f48a..6b2756886 100644 --- a/tests/tools/test_registry.py +++ b/tests/tools/test_registry.py @@ -1,6 +1,7 @@ """Tests for the central tool registry.""" import json +import threading from tools.registry import ToolRegistry @@ -167,6 +168,32 @@ class TestToolsetAvailability: ) assert reg.get_all_tool_names() == ["a_tool", "z_tool"] + def test_get_registered_toolset_names(self): + reg = ToolRegistry() + reg.register( + name="first", toolset="zeta", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="second", toolset="alpha", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="third", toolset="alpha", schema=_make_schema(), handler=_dummy_handler + ) + assert reg.get_registered_toolset_names() == ["alpha", "zeta"] + + def test_get_tool_names_for_toolset(self): + reg = ToolRegistry() + reg.register( + name="z_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="a_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler + ) + reg.register( + name="other_tool", toolset="other", schema=_make_schema(), handler=_dummy_handler + ) + assert reg.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"] + def test_handler_exception_returns_error(self): reg = ToolRegistry() @@ -301,6 +328,22 @@ class TestEmojiMetadata: assert reg.get_emoji("t") == "⚡" +class TestEntryLookup: + def test_get_entry_returns_registered_entry(self): + reg = ToolRegistry() + reg.register( + name="alpha", toolset="core", schema=_make_schema("alpha"), handler=_dummy_handler + ) + entry = reg.get_entry("alpha") + assert 
entry is not None + assert entry.name == "alpha" + assert entry.toolset == "core" + + def test_get_entry_returns_none_for_unknown_tool(self): + reg = ToolRegistry() + assert reg.get_entry("missing") is None + + class TestSecretCaptureResultContract: def test_secret_request_result_does_not_include_secret_value(self): result = { @@ -309,3 +352,141 @@ class TestSecretCaptureResultContract: "validated": False, } assert "secret" not in json.dumps(result).lower() + + +class TestThreadSafety: + def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch): + reg = ToolRegistry() + reg.register( + name="alpha", + toolset="gated", + schema=_make_schema("alpha"), + handler=_dummy_handler, + check_fn=lambda: False, + ) + + entries, toolset_checks = reg._snapshot_state() + + def snapshot_then_mutate(): + reg.deregister("alpha") + return entries, toolset_checks + + monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate) + + toolsets = reg.get_available_toolsets() + assert toolsets["gated"]["available"] is False + assert toolsets["gated"]["tools"] == ["alpha"] + + def test_check_tool_availability_tolerates_concurrent_register(self): + reg = ToolRegistry() + check_started = threading.Event() + writer_done = threading.Event() + errors = [] + result_holder = {} + writer_completed_during_check = {} + + def blocking_check(): + check_started.set() + writer_completed_during_check["value"] = writer_done.wait(timeout=1) + return True + + reg.register( + name="alpha", + toolset="gated", + schema=_make_schema("alpha"), + handler=_dummy_handler, + check_fn=blocking_check, + ) + reg.register( + name="beta", + toolset="plain", + schema=_make_schema("beta"), + handler=_dummy_handler, + ) + + def reader(): + try: + result_holder["value"] = reg.check_tool_availability() + except Exception as exc: # pragma: no cover - exercised on failure only + errors.append(exc) + + def writer(): + assert check_started.wait(timeout=1) + reg.register( + name="gamma", + toolset="new", + 
schema=_make_schema("gamma"), + handler=_dummy_handler, + ) + writer_done.set() + + reader_thread = threading.Thread(target=reader) + writer_thread = threading.Thread(target=writer) + reader_thread.start() + writer_thread.start() + reader_thread.join(timeout=2) + writer_thread.join(timeout=2) + + assert not reader_thread.is_alive() + assert not writer_thread.is_alive() + assert writer_completed_during_check["value"] is True + assert errors == [] + + available, unavailable = result_holder["value"] + assert "gated" in available + assert "plain" in available + assert unavailable == [] + + def test_get_available_toolsets_tolerates_concurrent_deregister(self): + reg = ToolRegistry() + check_started = threading.Event() + writer_done = threading.Event() + errors = [] + result_holder = {} + writer_completed_during_check = {} + + def blocking_check(): + check_started.set() + writer_completed_during_check["value"] = writer_done.wait(timeout=1) + return True + + reg.register( + name="alpha", + toolset="gated", + schema=_make_schema("alpha"), + handler=_dummy_handler, + check_fn=blocking_check, + ) + reg.register( + name="beta", + toolset="plain", + schema=_make_schema("beta"), + handler=_dummy_handler, + ) + + def reader(): + try: + result_holder["value"] = reg.get_available_toolsets() + except Exception as exc: # pragma: no cover - exercised on failure only + errors.append(exc) + + def writer(): + assert check_started.wait(timeout=1) + reg.deregister("beta") + writer_done.set() + + reader_thread = threading.Thread(target=reader) + writer_thread = threading.Thread(target=writer) + reader_thread.start() + writer_thread.start() + reader_thread.join(timeout=2) + writer_thread.join(timeout=2) + + assert not reader_thread.is_alive() + assert not writer_thread.is_alive() + assert writer_completed_during_check["value"] is True + assert errors == [] + + toolsets = result_holder["value"] + assert "gated" in toolsets + assert toolsets["gated"]["available"] is True diff --git 
a/tools/approval.py b/tools/approval.py index 70420976b..3e9ccdf75 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -313,6 +313,17 @@ def disable_session_yolo(session_key: str) -> None: _session_yolo.discard(session_key) +def clear_session(session_key: str) -> None: + """Remove all approval and yolo state for a given session.""" + if not session_key: + return + with _lock: + _session_approved.pop(session_key, None) + _session_yolo.discard(session_key) + _pending.pop(session_key, None) + _gateway_queues.pop(session_key, None) + + def is_session_yolo_enabled(session_key: str) -> bool: """Return True when YOLO bypass is enabled for a specific session.""" if not session_key: diff --git a/tools/file_operations.py b/tools/file_operations.py index 29180931d..b6ab271cd 100644 --- a/tools/file_operations.py +++ b/tools/file_operations.py @@ -556,27 +556,54 @@ class ShellFileOperations(FileOperations): def _suggest_similar_files(self, path: str) -> ReadResult: """Suggest similar files when the requested file is not found.""" - # Get directory and filename dir_path = os.path.dirname(path) or "." 
filename = os.path.basename(path) - - # List files in directory - ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -20" + basename_no_ext = os.path.splitext(filename)[0] + ext = os.path.splitext(filename)[1].lower() + lower_name = filename.lower() + + # List files in the target directory + ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -50" ls_result = self._exec(ls_cmd) - - similar = [] + + scored: list = [] # (score, filepath) — higher is better if ls_result.exit_code == 0 and ls_result.stdout.strip(): - files = ls_result.stdout.strip().split('\n') - # Simple similarity: files that share some characters with the target - for f in files: - # Check if filenames share significant overlap - common = set(filename.lower()) & set(f.lower()) - if len(common) >= len(filename) * 0.5: # 50% character overlap - similar.append(os.path.join(dir_path, f)) - + for f in ls_result.stdout.strip().split('\n'): + if not f: + continue + lf = f.lower() + score = 0 + + # Exact match (shouldn't happen, but guard) + if lf == lower_name: + score = 100 + # Same base name, different extension (e.g. 
config.yml vs config.yaml) + elif os.path.splitext(f)[0].lower() == basename_no_ext.lower(): + score = 90 + # Target is prefix of candidate or vice-versa + elif lf.startswith(lower_name) or lower_name.startswith(lf): + score = 70 + # Substring match (candidate contains query) + elif lower_name in lf: + score = 60 + # Reverse substring (query contains candidate name) + elif lf in lower_name and len(lf) > 2: + score = 40 + # Same extension with some overlap + elif ext and os.path.splitext(f)[1].lower() == ext: + common = set(lower_name) & set(lf) + if len(common) >= max(len(lower_name), len(lf)) * 0.4: + score = 30 + + if score > 0: + scored.append((score, os.path.join(dir_path, f))) + + scored.sort(key=lambda x: -x[0]) + similar = [fp for _, fp in scored[:5]] + return ReadResult( error=f"File not found: {path}", - similar_files=similar[:5] # Limit to 5 suggestions + similar_files=similar ) def read_file_raw(self, path: str) -> ReadResult: @@ -845,8 +872,33 @@ class ShellFileOperations(FileOperations): # Validate that the path exists before searching check = self._exec(f"test -e {self._escape_shell_arg(path)} && echo exists || echo not_found") if "not_found" in check.stdout: + # Try to suggest nearby paths + parent = os.path.dirname(path) or "." 
+ basename_query = os.path.basename(path) + hint_parts = [f"Path not found: {path}"] + # Check if parent directory exists and list similar entries + parent_check = self._exec( + f"test -d {self._escape_shell_arg(parent)} && echo yes || echo no" + ) + if "yes" in parent_check.stdout and basename_query: + ls_result = self._exec( + f"ls -1 {self._escape_shell_arg(parent)} 2>/dev/null | head -20" + ) + if ls_result.exit_code == 0 and ls_result.stdout.strip(): + lower_q = basename_query.lower() + candidates = [] + for entry in ls_result.stdout.strip().split('\n'): + if not entry: + continue + le = entry.lower() + if lower_q in le or le in lower_q or le.startswith(lower_q[:3]): + candidates.append(os.path.join(parent, entry)) + if candidates: + hint_parts.append( + "Similar paths: " + ", ".join(candidates[:5]) + ) return SearchResult( - error=f"Path not found: {path}. Verify the path exists (use 'terminal' to check).", + error=". ".join(hint_parts), total_count=0 ) @@ -912,7 +964,8 @@ class ShellFileOperations(FileOperations): rg --files respects .gitignore and excludes hidden directories by default, and uses parallel directory traversal for ~200x speedup - over find on wide trees. + over find on wide trees. Results are sorted by modification time + (most recently edited first) when rg >= 13.0 supports --sortr. """ # rg --files -g uses glob patterns; wrap bare names so they match # at any depth (equivalent to find -name). @@ -922,14 +975,25 @@ class ShellFileOperations(FileOperations): glob_pattern = pattern fetch_limit = limit + offset - cmd = ( - f"rg --files -g {self._escape_shell_arg(glob_pattern)} " + # Try mtime-sorted first (rg 13+); fall back to unsorted if not supported. 
+ cmd_sorted = ( + f"rg --files --sortr=modified -g {self._escape_shell_arg(glob_pattern)} " f"{self._escape_shell_arg(path)} 2>/dev/null " f"| head -n {fetch_limit}" ) - result = self._exec(cmd, timeout=60) - + result = self._exec(cmd_sorted, timeout=60) all_files = [f for f in result.stdout.strip().split('\n') if f] + + if not all_files: + # --sortr may have failed on older rg; retry without it. + cmd_plain = ( + f"rg --files -g {self._escape_shell_arg(glob_pattern)} " + f"{self._escape_shell_arg(path)} 2>/dev/null " + f"| head -n {fetch_limit}" + ) + result = self._exec(cmd_plain, timeout=60) + all_files = [f for f in result.stdout.strip().split('\n') if f] + page = all_files[offset:offset + limit] return SearchResult( diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index e953998cc..2356830c4 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -70,6 +70,7 @@ Thread safety: """ import asyncio +import concurrent.futures import inspect import json import logging @@ -1167,13 +1168,43 @@ def _ensure_mcp_loop(): def _run_on_mcp_loop(coro, timeout: float = 30): - """Schedule a coroutine on the MCP event loop and block until done.""" + """Schedule a coroutine on the MCP event loop and block until done. + + Poll in short intervals so the calling agent thread can honor user + interrupts while the MCP work is still running on the background loop. 
+ """ + from tools.interrupt import is_interrupted + with _lock: loop = _mcp_loop if loop is None or not loop.is_running(): raise RuntimeError("MCP event loop is not running") future = asyncio.run_coroutine_threadsafe(coro, loop) - return future.result(timeout=timeout) + deadline = None if timeout is None else time.monotonic() + timeout + + while True: + if is_interrupted(): + future.cancel() + raise InterruptedError("User sent a new message") + + wait_timeout = 0.1 + if deadline is not None: + remaining = deadline - time.monotonic() + if remaining <= 0: + return future.result(timeout=0) + wait_timeout = min(wait_timeout, remaining) + + try: + return future.result(timeout=wait_timeout) + except concurrent.futures.TimeoutError: + continue + + +def _interrupted_call_result() -> str: + """Standardized JSON error for a user-interrupted MCP tool call.""" + return json.dumps({ + "error": "MCP call interrupted: user sent a new message" + }) # --------------------------------------------------------------------------- @@ -1299,6 +1330,8 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): try: return _run_on_mcp_loop(_call(), timeout=tool_timeout) + except InterruptedError: + return _interrupted_call_result() except Exception as exc: logger.error( "MCP tool %s/%s call failed: %s", @@ -1342,6 +1375,8 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float): try: return _run_on_mcp_loop(_call(), timeout=tool_timeout) + except InterruptedError: + return _interrupted_call_result() except Exception as exc: logger.error( "MCP %s/list_resources failed: %s", server_name, exc, @@ -1386,6 +1421,8 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float): try: return _run_on_mcp_loop(_call(), timeout=tool_timeout) + except InterruptedError: + return _interrupted_call_result() except Exception as exc: logger.error( "MCP %s/read_resource failed: %s", server_name, exc, @@ -1433,6 +1470,8 @@ def 
_make_list_prompts_handler(server_name: str, tool_timeout: float): try: return _run_on_mcp_loop(_call(), timeout=tool_timeout) + except InterruptedError: + return _interrupted_call_result() except Exception as exc: logger.error( "MCP %s/list_prompts failed: %s", server_name, exc, @@ -1488,6 +1527,8 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float): try: return _run_on_mcp_loop(_call(), timeout=tool_timeout) + except InterruptedError: + return _interrupted_call_result() except Exception as exc: logger.error( "MCP %s/get_prompt failed: %s", server_name, exc, diff --git a/tools/registry.py b/tools/registry.py index d3590a42c..d6aff8348 100644 --- a/tools/registry.py +++ b/tools/registry.py @@ -16,6 +16,7 @@ Import chain (circular-import safe): import json import logging +import threading from typing import Callable, Dict, List, Optional, Set logger = logging.getLogger(__name__) @@ -51,6 +52,49 @@ class ToolRegistry: def __init__(self): self._tools: Dict[str, ToolEntry] = {} self._toolset_checks: Dict[str, Callable] = {} + # MCP dynamic refresh can mutate the registry while other threads are + # reading tool metadata, so keep mutations serialized and readers on + # stable snapshots. 
+ self._lock = threading.RLock() + + def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]: + """Return a coherent snapshot of registry entries and toolset checks.""" + with self._lock: + return list(self._tools.values()), dict(self._toolset_checks) + + def _snapshot_entries(self) -> List[ToolEntry]: + """Return a stable snapshot of registered tool entries.""" + return self._snapshot_state()[0] + + def _snapshot_toolset_checks(self) -> Dict[str, Callable]: + """Return a stable snapshot of toolset availability checks.""" + return self._snapshot_state()[1] + + def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool: + """Run a toolset check, treating missing or failing checks as unavailable/available.""" + if not check: + return True + try: + return bool(check()) + except Exception: + logger.debug("Toolset %s check raised; marking unavailable", toolset) + return False + + def get_entry(self, name: str) -> Optional[ToolEntry]: + """Return a registered tool entry by name, or None.""" + with self._lock: + return self._tools.get(name) + + def get_registered_toolset_names(self) -> List[str]: + """Return sorted unique toolset names present in the registry.""" + return sorted({entry.toolset for entry in self._snapshot_entries()}) + + def get_tool_names_for_toolset(self, toolset: str) -> List[str]: + """Return sorted tool names registered under a given toolset.""" + return sorted( + entry.name for entry in self._snapshot_entries() + if entry.toolset == toolset + ) # ------------------------------------------------------------------ # Registration @@ -70,27 +114,28 @@ class ToolRegistry: max_result_size_chars: int | float | None = None, ): """Register a tool. 
Called at module-import time by each tool file.""" - existing = self._tools.get(name) - if existing and existing.toolset != toolset: - logger.warning( - "Tool name collision: '%s' (toolset '%s') is being " - "overwritten by toolset '%s'", - name, existing.toolset, toolset, + with self._lock: + existing = self._tools.get(name) + if existing and existing.toolset != toolset: + logger.warning( + "Tool name collision: '%s' (toolset '%s') is being " + "overwritten by toolset '%s'", + name, existing.toolset, toolset, + ) + self._tools[name] = ToolEntry( + name=name, + toolset=toolset, + schema=schema, + handler=handler, + check_fn=check_fn, + requires_env=requires_env or [], + is_async=is_async, + description=description or schema.get("description", ""), + emoji=emoji, + max_result_size_chars=max_result_size_chars, ) - self._tools[name] = ToolEntry( - name=name, - toolset=toolset, - schema=schema, - handler=handler, - check_fn=check_fn, - requires_env=requires_env or [], - is_async=is_async, - description=description or schema.get("description", ""), - emoji=emoji, - max_result_size_chars=max_result_size_chars, - ) - if check_fn and toolset not in self._toolset_checks: - self._toolset_checks[toolset] = check_fn + if check_fn and toolset not in self._toolset_checks: + self._toolset_checks[toolset] = check_fn def deregister(self, name: str) -> None: """Remove a tool from the registry. @@ -99,14 +144,15 @@ class ToolRegistry: same toolset. Used by MCP dynamic tool discovery to nuke-and-repave when a server sends ``notifications/tools/list_changed``. 
""" - entry = self._tools.pop(name, None) - if entry is None: - return - # Drop the toolset check if this was the last tool in that toolset - if entry.toolset in self._toolset_checks and not any( - e.toolset == entry.toolset for e in self._tools.values() - ): - self._toolset_checks.pop(entry.toolset, None) + with self._lock: + entry = self._tools.pop(name, None) + if entry is None: + return + # Drop the toolset check if this was the last tool in that toolset + if entry.toolset in self._toolset_checks and not any( + e.toolset == entry.toolset for e in self._tools.values() + ): + self._toolset_checks.pop(entry.toolset, None) logger.debug("Deregistered tool: %s", name) # ------------------------------------------------------------------ @@ -121,8 +167,9 @@ class ToolRegistry: """ result = [] check_results: Dict[Callable, bool] = {} + entries_by_name = {entry.name: entry for entry in self._snapshot_entries()} for name in sorted(tool_names): - entry = self._tools.get(name) + entry = entries_by_name.get(name) if not entry: continue if entry.check_fn: @@ -153,7 +200,7 @@ class ToolRegistry: * All exceptions are caught and returned as ``{"error": "..."}`` for consistent error format. 
""" - entry = self._tools.get(name) + entry = self.get_entry(name) if not entry: return json.dumps({"error": f"Unknown tool: {name}"}) try: @@ -171,7 +218,7 @@ class ToolRegistry: def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float: """Return per-tool max result size, or *default* (or global default).""" - entry = self._tools.get(name) + entry = self.get_entry(name) if entry and entry.max_result_size_chars is not None: return entry.max_result_size_chars if default is not None: @@ -181,7 +228,7 @@ class ToolRegistry: def get_all_tool_names(self) -> List[str]: """Return sorted list of all registered tool names.""" - return sorted(self._tools.keys()) + return sorted(entry.name for entry in self._snapshot_entries()) def get_schema(self, name: str) -> Optional[dict]: """Return a tool's raw schema dict, bypassing check_fn filtering. @@ -189,22 +236,22 @@ class ToolRegistry: Useful for token estimation and introspection where availability doesn't matter — only the schema content does. """ - entry = self._tools.get(name) + entry = self.get_entry(name) return entry.schema if entry else None def get_toolset_for_tool(self, name: str) -> Optional[str]: """Return the toolset a tool belongs to, or None.""" - entry = self._tools.get(name) + entry = self.get_entry(name) return entry.toolset if entry else None def get_emoji(self, name: str, default: str = "⚡") -> str: """Return the emoji for a tool, or *default* if unset.""" - entry = self._tools.get(name) + entry = self.get_entry(name) return (entry.emoji if entry and entry.emoji else default) def get_tool_to_toolset_map(self) -> Dict[str, str]: """Return ``{tool_name: toolset_name}`` for every registered tool.""" - return {name: e.toolset for name, e in self._tools.items()} + return {entry.name: entry.toolset for entry in self._snapshot_entries()} def is_toolset_available(self, toolset: str) -> bool: """Check if a toolset's requirements are met. 
@@ -212,28 +259,30 @@ class ToolRegistry: Returns False (rather than crashing) when the check function raises an unexpected exception (e.g. network error, missing import, bad config). """ - check = self._toolset_checks.get(toolset) - if not check: - return True - try: - return bool(check()) - except Exception: - logger.debug("Toolset %s check raised; marking unavailable", toolset) - return False + with self._lock: + check = self._toolset_checks.get(toolset) + return self._evaluate_toolset_check(toolset, check) def check_toolset_requirements(self) -> Dict[str, bool]: """Return ``{toolset: available_bool}`` for every toolset.""" - toolsets = set(e.toolset for e in self._tools.values()) - return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)} + entries, toolset_checks = self._snapshot_state() + toolsets = sorted({entry.toolset for entry in entries}) + return { + toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset)) + for toolset in toolsets + } def get_available_toolsets(self) -> Dict[str, dict]: """Return toolset metadata for UI display.""" toolsets: Dict[str, dict] = {} - for entry in self._tools.values(): + entries, toolset_checks = self._snapshot_state() + for entry in entries: ts = entry.toolset if ts not in toolsets: toolsets[ts] = { - "available": self.is_toolset_available(ts), + "available": self._evaluate_toolset_check( + ts, toolset_checks.get(ts) + ), "tools": [], "description": "", "requirements": [], @@ -248,13 +297,14 @@ class ToolRegistry: def get_toolset_requirements(self) -> Dict[str, dict]: """Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat.""" result: Dict[str, dict] = {} - for entry in self._tools.values(): + entries, toolset_checks = self._snapshot_state() + for entry in entries: ts = entry.toolset if ts not in result: result[ts] = { "name": ts, "env_vars": [], - "check_fn": self._toolset_checks.get(ts), + "check_fn": toolset_checks.get(ts), "setup_url": None, "tools": [], } @@ -270,18 +320,19 
@@ class ToolRegistry: available = [] unavailable = [] seen = set() - for entry in self._tools.values(): + entries, toolset_checks = self._snapshot_state() + for entry in entries: ts = entry.toolset if ts in seen: continue seen.add(ts) - if self.is_toolset_available(ts): + if self._evaluate_toolset_check(ts, toolset_checks.get(ts)): available.append(ts) else: unavailable.append({ "name": ts, "env_vars": entry.requires_env, - "tools": [e.name for e in self._tools.values() if e.toolset == ts], + "tools": [e.name for e in entries if e.toolset == ts], }) return available, unavailable diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index a2b3e984c..391e03baa 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -152,6 +152,7 @@ def _handle_send(args): "whatsapp": Platform.WHATSAPP, "signal": Platform.SIGNAL, "bluebubbles": Platform.BLUEBUBBLES, + "qqbot": Platform.QQBOT, "matrix": Platform.MATRIX, "mattermost": Platform.MATTERMOST, "homeassistant": Platform.HOMEASSISTANT, @@ -426,6 +427,8 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, result = await _send_wecom(pconfig.extra, chat_id, chunk) elif platform == Platform.BLUEBUBBLES: result = await _send_bluebubbles(pconfig.extra, chat_id, chunk) + elif platform == Platform.QQBOT: + result = await _send_qqbot(pconfig, chat_id, chunk) else: result = {"error": f"Direct sending not yet implemented for {platform.value}"} @@ -1038,6 +1041,58 @@ def _check_send_message(): return False +async def _send_qqbot(pconfig, chat_id, message): + """Send via QQBot using the REST API directly (no WebSocket needed). + + Uses the QQ Bot Open Platform REST endpoints to get an access token + and post a message. Works for guild channels without requiring + a running gateway adapter. + """ + try: + import httpx + except ImportError: + return _error("QQBot direct send requires httpx. 
Run: pip install httpx") + + extra = pconfig.extra or {} + appid = extra.get("app_id") or os.getenv("QQ_APP_ID", "") + secret = (pconfig.token or extra.get("client_secret") + or os.getenv("QQ_CLIENT_SECRET", "")) + if not appid or not secret: + return _error("QQBot: QQ_APP_ID / QQ_CLIENT_SECRET not configured.") + + try: + async with httpx.AsyncClient(timeout=15) as client: + # Step 1: Get access token + token_resp = await client.post( + "https://bots.qq.com/app/getAppAccessToken", + json={"appId": str(appid), "clientSecret": str(secret)}, + ) + if token_resp.status_code != 200: + return _error(f"QQBot token request failed: {token_resp.status_code}") + token_data = token_resp.json() + access_token = token_data.get("access_token") + if not access_token: + return _error(f"QQBot: no access_token in response") + + # Step 2: Send message via REST + headers = { + "Authorization": f"QQBotAccessToken {access_token}", + "Content-Type": "application/json", + } + url = f"https://api.sgroup.qq.com/channels/{chat_id}/messages" + payload = {"content": message[:4000], "msg_type": 0} + + resp = await client.post(url, json=payload, headers=headers) + if resp.status_code in (200, 201): + data = resp.json() + return {"success": True, "platform": "qqbot", "chat_id": chat_id, + "message_id": data.get("id")} + else: + return _error(f"QQBot send failed: {resp.status_code} {resp.text}") + except Exception as e: + return _error(f"QQBot send failed: {e}") + + # --- Registry --- from tools.registry import registry, tool_error diff --git a/tools/skills_tool.py b/tools/skills_tool.py index 5a9e80f34..90839b9a7 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -245,6 +245,9 @@ def _get_required_environment_variables( if isinstance(required_for, str) and required_for.strip(): normalized["required_for"] = required_for.strip() + if entry.get("optional"): + normalized["optional"] = True + seen.add(env_name) required.append(normalized) @@ -378,6 +381,8 @@ def 
_remaining_required_environment_names( remaining = [] for entry in required_env_vars: name = entry["name"] + if entry.get("optional"): + continue if name in missing_names or not _is_env_var_persisted(name, env_snapshot): remaining.append(name) return remaining @@ -1042,7 +1047,8 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str: missing_required_env_vars = [ e for e in required_env_vars - if not _is_env_var_persisted(e["name"], env_snapshot) + if not e.get("optional") + and not _is_env_var_persisted(e["name"], env_snapshot) ] capture_result = _capture_required_environment_variables( skill_name, diff --git a/toolsets.py b/toolsets.py index 57e03d250..2e7a0a92a 100644 --- a/toolsets.py +++ b/toolsets.py @@ -359,6 +359,12 @@ TOOLSETS = { "includes": [] }, + "hermes-qqbot": { + "description": "QQBot toolset - QQ messaging via Official Bot API v2 (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-wecom": { "description": "WeCom bot toolset - enterprise WeChat messaging (full access)", "tools": _HERMES_CORE_TOOLS, @@ -386,7 +392,7 @@ TOOLSETS = { "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-wecom-callback", "hermes-weixin", "hermes-webhook"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-wecom-callback", "hermes-weixin", "hermes-qqbot", "hermes-webhook"] } } @@ -449,7 +455,7 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]: if 
name in _get_plugin_toolset_names(): try: from tools.registry import registry - return [e.name for e in registry._tools.values() if e.toolset == name] + return registry.get_tool_names_for_toolset(name) except Exception: pass return [] @@ -495,9 +501,9 @@ def _get_plugin_toolset_names() -> Set[str]: try: from tools.registry import registry return { - entry.toolset - for entry in registry._tools.values() - if entry.toolset not in TOOLSETS + toolset_name + for toolset_name in registry.get_registered_toolset_names() + if toolset_name not in TOOLSETS } except Exception: return set() @@ -518,7 +524,7 @@ def get_all_toolsets() -> Dict[str, Dict[str, Any]]: if ts_name not in result: try: from tools.registry import registry - tools = [e.name for e in registry._tools.values() if e.toolset == ts_name] + tools = registry.get_tool_names_for_toolset(ts_name) result[ts_name] = { "description": f"Plugin toolset: {ts_name}", "tools": tools, diff --git a/ui-tui/src/app.tsx b/ui-tui/src/app.tsx index c329057c7..03ae3b57e 100644 --- a/ui-tui/src/app.tsx +++ b/ui-tui/src/app.tsx @@ -8,6 +8,7 @@ import { Box, NoSelect, ScrollBox, + type ScrollBoxHandle, Text, useApp, useHasSelection, @@ -15,7 +16,7 @@ import { useSelection, useStdout } from '@hermes/ink' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { type RefObject, useCallback, useEffect, useMemo, useRef, useState, useSyncExternalStore } from 'react' import { Banner, Panel, SessionPanel } from './components/branding.js' import { MaskedPrompt } from './components/maskedPrompt.js' @@ -44,7 +45,8 @@ import { pick, sameToolTrailGroup, stripTrailingPasteNewlines, - toolTrailLabel + toolTrailLabel, + userDisplay } from './lib/text.js' import { DEFAULT_THEME, fromSkin, type Theme } from './theme.js' import type { @@ -70,6 +72,7 @@ const STARTUP_RESUME_ID = (process.env.HERMES_TUI_RESUME ?? 
'').trim() const LARGE_PASTE = { chars: 8000, lines: 80 } const MAX_HISTORY = 800 const REASONING_PULSE_MS = 700 +const STREAM_BATCH_MS = 16 const WHEEL_SCROLL_STEP = 3 const MOUSE_TRACKING = !/^(1|true|yes|on)$/.test((process.env.HERMES_TUI_DISABLE_MOUSE ?? '').trim().toLowerCase()) const PASTE_SNIPPET_RE = /\[\[[^\n]*?\]\]/g @@ -78,14 +81,18 @@ const DETAILS_MODES: DetailsMode[] = ['hidden', 'collapsed', 'expanded'] const parseDetailsMode = (v: unknown): DetailsMode | null => { const s = typeof v === 'string' ? v.trim().toLowerCase() : '' + return DETAILS_MODES.includes(s as DetailsMode) ? (s as DetailsMode) : null } const resolveDetailsMode = (d: any): DetailsMode => - parseDetailsMode(d?.details_mode) - ?? ({ full: 'expanded' as const, collapsed: 'collapsed' as const, truncated: 'collapsed' as const }[ - String(d?.thinking_mode ?? '').trim().toLowerCase() - ] ?? 'collapsed') + parseDetailsMode(d?.details_mode) ?? + { full: 'expanded' as const, collapsed: 'collapsed' as const, truncated: 'collapsed' as const }[ + String(d?.thinking_mode ?? '') + .trim() + .toLowerCase() + ] ?? + 'collapsed' const nextDetailsMode = (m: DetailsMode): DetailsMode => DETAILS_MODES[(DETAILS_MODES.indexOf(m) + 1) % DETAILS_MODES.length]! @@ -98,6 +105,7 @@ type PasteSnippet = { label: string; text: string } const shortCwd = (cwd: string, max = 28) => { const home = process.env.HOME const path = home && cwd.startsWith(home) ? `~${cwd.slice(home.length)}` : cwd + return path.length <= max ? path : `…${path.slice(-(max - 1))}` } @@ -288,6 +296,78 @@ function PromptBox({ children, color }: { children: React.ReactNode; color: stri ) } +const upperBound = (arr: ArrayLike, target: number) => { + let lo = 0 + let hi = arr.length + + while (lo < hi) { + const mid = (lo + hi) >> 1 + + if (arr[mid]! 
<= target) { + lo = mid + 1 + } else { + hi = mid + } + } + + return lo +} + +function StickyPromptTracker({ + messages, + offsets, + scrollRef, + onChange +}: { + messages: readonly Msg[] + offsets: ArrayLike + scrollRef: RefObject + onChange: (text: string) => void +}) { + useSyncExternalStore( + useCallback((cb: () => void) => scrollRef.current?.subscribe(cb) ?? (() => {}), [scrollRef]), + () => { + const s = scrollRef.current + + if (!s) { + return NaN + } + const top = Math.max(0, s.getScrollTop() + s.getPendingDelta()) + + return s.isSticky() ? -1 - top : top + }, + () => NaN + ) + + const s = scrollRef.current + const top = Math.max(0, (s?.getScrollTop() ?? 0) + (s?.getPendingDelta() ?? 0)) + + let text = '' + + if (!(s?.isSticky() ?? true) && messages.length) { + const first = Math.max(0, Math.min(messages.length - 1, upperBound(offsets, top) - 1)) + + if (!(messages[first]?.role === 'user' && (offsets[first] ?? 0) + 1 >= top)) { + for (let i = first - 1; i >= 0; i--) { + if (messages[i]?.role !== 'user') { + continue + } + + if ((offsets[i] ?? 
0) + 1 >= top) { + continue + } + text = userDisplay(messages[i]!.text.trim()).replace(/\s+/g, ' ').trim() + + break + } + } + } + + useEffect(() => onChange(text), [onChange, text]) + + return null +} + // ── App ────────────────────────────────────────────────────────────── export function App({ gw }: { gw: GatewayClient }) { @@ -343,6 +423,7 @@ export function App({ gw }: { gw: GatewayClient }) { const [reasoningStreaming, setReasoningStreaming] = useState(false) const [statusBar, setStatusBar] = useState(true) const [lastUserMsg, setLastUserMsg] = useState('') + const [stickyPrompt, setStickyPrompt] = useState('') const [pasteSnips, setPasteSnips] = useState([]) const [streaming, setStreaming] = useState('') const [turnTrail, setTurnTrail] = useState([]) @@ -371,11 +452,13 @@ export function App({ gw }: { gw: GatewayClient }) { const colsRef = useRef(cols) const turnToolsRef = useRef([]) const persistedToolLabelsRef = useRef>(new Set()) + const streamTimerRef = useRef | null>(null) + const reasoningTimerRef = useRef | null>(null) const reasoningStreamingTimerRef = useRef | null>(null) const statusTimerRef = useRef | null>(null) const busyRef = useRef(busy) const sidRef = useRef(sid) - const scrollRef = useRef(null) + const scrollRef = useRef(null) const onEventRef = useRef<(ev: GatewayEvent) => void>(() => {}) const configMtimeRef = useRef(0) colsRef.current = cols @@ -407,6 +490,28 @@ export function App({ gw }: { gw: GatewayClient }) { }, REASONING_PULSE_MS) }, []) + const scheduleStreaming = useCallback(() => { + if (streamTimerRef.current) { + return + } + + streamTimerRef.current = setTimeout(() => { + streamTimerRef.current = null + setStreaming(buf.current.trimStart()) + }, STREAM_BATCH_MS) + }, []) + + const scheduleReasoning = useCallback(() => { + if (reasoningTimerRef.current) { + return + } + + reasoningTimerRef.current = setTimeout(() => { + reasoningTimerRef.current = null + setReasoning(reasoningRef.current) + }, STREAM_BATCH_MS) + }, []) + const 
endReasoningPhase = useCallback(() => { if (reasoningStreamingTimerRef.current) { clearTimeout(reasoningStreamingTimerRef.current) @@ -419,6 +524,14 @@ export function App({ gw }: { gw: GatewayClient }) { useEffect( () => () => { + if (streamTimerRef.current) { + clearTimeout(streamTimerRef.current) + } + + if (reasoningTimerRef.current) { + clearTimeout(reasoningTimerRef.current) + } + if (reasoningStreamingTimerRef.current) { clearTimeout(reasoningStreamingTimerRef.current) } @@ -432,6 +545,7 @@ export function App({ gw }: { gw: GatewayClient }) { const empty = !messages.length const isBlocked = blocked() + const virtualRows = useMemo( () => historyItems.map((msg, index) => ({ @@ -441,6 +555,7 @@ export function App({ gw }: { gw: GatewayClient }) { })), [historyItems] ) + const virtualHistory = useVirtualHistory(scrollRef, virtualRows) // ── Resize RPC ─────────────────────────────────────────────────── @@ -651,10 +766,26 @@ export function App({ gw }: { gw: GatewayClient }) { setApproval(null) setSudo(null) setSecret(null) + + if (streamTimerRef.current) { + clearTimeout(streamTimerRef.current) + streamTimerRef.current = null + } + setStreaming('') buf.current = '' } + const clearReasoning = () => { + if (reasoningTimerRef.current) { + clearTimeout(reasoningTimerRef.current) + reasoningTimerRef.current = null + } + + reasoningRef.current = '' + setReasoning('') + } + const die = () => { gw.kill() exit() @@ -670,13 +801,14 @@ export function App({ gw }: { gw: GatewayClient }) { const resetSession = () => { idle() - setReasoning('') + clearReasoning() setVoiceRecording(false) setVoiceProcessing(false) setSid(null as any) // will be set by caller setInfo(null) setHistoryItems([]) setMessages([]) + setStickyPrompt('') setPasteSnips([]) setActivity([]) setBgTasks(new Set()) @@ -688,11 +820,12 @@ export function App({ gw }: { gw: GatewayClient }) { const resetVisibleHistory = (info: SessionInfo | null = null) => { idle() - setReasoning('') + clearReasoning() 
setMessages([]) setHistoryItems(info ? [introMsg(info)] : []) setInfo(info) setUsage(info?.usage ? { ...ZERO, ...info.usage } : ZERO) + setStickyPrompt('') setPasteSnips([]) setActivity([]) setLastUserMsg('') @@ -888,10 +1021,12 @@ export function App({ gw }: { gw: GatewayClient }) { const send = (text: string) => { const expandPasteSnips = (value: string) => { const byLabel = new Map() + for (const item of pasteSnips) { const list = byLabel.get(item.label) list ? list.push(item.text) : byLabel.set(item.label, [item.text]) } + return value.replace(PASTE_SNIPPET_RE, token => byLabel.get(token)?.shift() ?? token) } @@ -1192,11 +1327,13 @@ export function App({ gw }: { gw: GatewayClient }) { if (key.wheelUp) { scrollRef.current?.scrollBy(-WHEEL_SCROLL_STEP) + return } if (key.wheelDown) { scrollRef.current?.scrollBy(WHEEL_SCROLL_STEP) + return } @@ -1204,6 +1341,7 @@ export function App({ gw }: { gw: GatewayClient }) { const viewport = scrollRef.current?.getViewportHeight() ?? Math.max(6, (stdout?.rows ?? 24) - 8) const step = Math.max(4, viewport - 2) scrollRef.current?.scrollBy(key.pageUp ? -step : step) + return } @@ -1272,7 +1410,7 @@ export function App({ gw }: { gw: GatewayClient }) { partial ? 
appendMessage({ role: 'assistant', text: partial + '\n\n*[interrupted]*' }) : sys('interrupted') idle() - setReasoning('') + clearReasoning() setActivity([]) turnToolsRef.current = [] setStatus('interrupted') @@ -1287,6 +1425,7 @@ export function App({ gw }: { gw: GatewayClient }) { }, 1500) } else if (hasSelection) { const copied = selection.copySelection() + if (copied) { sys('copied selection') } @@ -1457,7 +1596,7 @@ export function App({ gw }: { gw: GatewayClient }) { case 'message.start': setBusy(true) endReasoningPhase() - setReasoning('') + clearReasoning() setActivity([]) setTurnTrail([]) turnToolsRef.current = [] @@ -1534,7 +1673,8 @@ export function App({ gw }: { gw: GatewayClient }) { case 'reasoning.delta': if (p?.text) { - setReasoning(prev => prev + p.text) + reasoningRef.current += p.text + scheduleReasoning() pulseReasoningStreaming() } @@ -1657,7 +1797,7 @@ export function App({ gw }: { gw: GatewayClient }) { if (p?.text && !interruptedRef.current) { buf.current = p.rendered ?? buf.current + p.text - setStreaming(buf.current.trimStart()) + scheduleStreaming() } break @@ -1673,7 +1813,7 @@ export function App({ gw }: { gw: GatewayClient }) { const finalText = (p?.rendered ?? p?.text ?? 
buf.current).trimStart() idle() - setReasoning('') + clearReasoning() setStreaming('') if (!wasInterrupted) { @@ -1715,7 +1855,7 @@ export function App({ gw }: { gw: GatewayClient }) { case 'error': idle() - setReasoning('') + clearReasoning() turnToolsRef.current = [] persistedToolLabelsRef.current.clear() @@ -1734,6 +1874,7 @@ export function App({ gw }: { gw: GatewayClient }) { [ appendMessage, bellOnComplete, + clearReasoning, dequeue, endReasoningPhase, gw, @@ -1743,6 +1884,8 @@ export function App({ gw }: { gw: GatewayClient }) { pushActivity, pushTrail, rpc, + scheduleReasoning, + scheduleStreaming, sendQueued, sys, stdout @@ -1883,8 +2026,10 @@ export function App({ gw }: { gw: GatewayClient }) { { const mode = arg.trim().toLowerCase() + if (!['hidden', 'collapsed', 'expanded', 'cycle', 'toggle'].includes(mode)) { sys('usage: /details [hidden|collapsed|expanded|cycle]') + return true } @@ -1895,12 +2040,13 @@ export function App({ gw }: { gw: GatewayClient }) { } return true - case 'copy': { if (!arg && hasSelection) { const copied = selection.copySelection() + if (copied) { sys('copied selection') + return true } } @@ -1949,6 +2095,7 @@ export function App({ gw }: { gw: GatewayClient }) { case 'sb': if (arg && !['on', 'off', 'toggle'].includes(arg.trim().toLowerCase())) { sys('usage: /statusbar [on|off|toggle]') + return true } @@ -2798,7 +2945,7 @@ export function App({ gw }: { gw: GatewayClient }) { } idle() - setReasoning('') + clearReasoning() setActivity([]) turnToolsRef.current = [] setStatus('interrupted') @@ -2855,16 +3002,17 @@ export function App({ gw }: { gw: GatewayClient }) { const durationLabel = sid ? fmtDuration(clockNow - sessionStartedAt) : '' const voiceLabel = voiceRecording ? 'REC' : voiceProcessing ? 'STT' : `voice ${voiceEnabled ? 
'on' : 'off'}` const cwdLabel = shortCwd(info?.cwd || process.env.HERMES_CWD || process.cwd()) + const showStreamingArea = Boolean(streaming) + const visibleHistory = virtualRows.slice(virtualHistory.start, virtualHistory.end) + const showStickyPrompt = !!stickyPrompt const hasReasoning = Boolean(reasoning.trim()) + const showProgressArea = detailsMode === 'hidden' ? activity.some(i => i.tone !== 'info') : Boolean(busy || tools.length || turnTrail.length || hasReasoning || activity.length) - const showStreamingArea = Boolean(streaming) - const visibleHistory = virtualRows.slice(virtualHistory.start, virtualHistory.end) - // ── Render ─────────────────────────────────────────────────────── return ( @@ -2891,23 +3039,7 @@ export function App({ gw }: { gw: GatewayClient }) { {virtualHistory.bottomSpacer > 0 ? : null} - {showStreamingArea && ( - - - - )} - - - - - {showProgressArea && ( - + {showProgressArea && ( - - )} + )} + {showStreamingArea && ( + + )} + + + + + + {clarify && ( )} - + {showStickyPrompt ? ( + + + {stickyPrompt} + + ) : ( + + )} {statusBar && ( - {compactPreview(hasAnsi(msg.text) ? stripAnsi(msg.text) : msg.text, Math.max(24, cols - 14)) || '(empty tool result)'} + {compactPreview(hasAnsi(msg.text) ? stripAnsi(msg.text) : msg.text, Math.max(24, cols - 14)) || + '(empty tool result)'} ) @@ -44,16 +48,27 @@ export const MessageLine = memo(function MessageLine({ const showDetails = detailsMode !== 'hidden' && (Boolean(msg.tools?.length) || Boolean(thinking)) const content = (() => { - if (msg.kind === 'slash') return {msg.text} - if (msg.role !== 'user' && hasAnsi(msg.text)) return {msg.text} - if (msg.role === 'assistant') return + if (msg.kind === 'slash') { + return {msg.text} + } + + if (msg.role !== 'user' && hasAnsi(msg.text)) { + return {msg.text} + } + + if (msg.role === 'assistant') { + return isStreaming ? 
{msg.text} : + } if (msg.role === 'user' && msg.text.length > LONG_MSG && isPasteBackedText(msg.text)) { const [head, ...rest] = userDisplay(msg.text).split('[long message]') + return ( {head} - [long message] + + [long message] + {rest.join('')} ) @@ -76,7 +91,9 @@ export const MessageLine = memo(function MessageLine({ - {glyph}{' '} + + {glyph}{' '} + {content} diff --git a/ui-tui/src/components/textInput.tsx b/ui-tui/src/components/textInput.tsx index 8df818811..e6e23cc45 100644 --- a/ui-tui/src/components/textInput.tsx +++ b/ui-tui/src/components/textInput.tsx @@ -176,8 +176,10 @@ function offsetFromPosition(value: string, row: number, col: number, cols: numbe if (line === targetRow) { return index } + line++ column = 0 + continue } @@ -187,6 +189,7 @@ function offsetFromPosition(value: string, row: number, col: number, cols: numbe if (line === targetRow) { return index } + line++ column = 0 } @@ -333,7 +336,9 @@ export function TextInput({ }, [cur, display, focus, placeholder]) const clickCursor = (e: { localRow?: number; localCol?: number }) => { - if (!focus) return + if (!focus) { + return + } const next = offsetFromPosition(display, e.localRow ?? 0, e.localCol ?? 
0, columns) setCur(next) curRef.current = next @@ -442,7 +447,6 @@ export function TextInput({ k.upArrow || k.downArrow || (k.ctrl && inp === 'c') || - (k.ctrl && inp === 't') || k.tab || (k.shift && k.tab) || k.pageUp || @@ -568,7 +572,7 @@ export function TextInput({ // ── Render ─────────────────────────────────────────────────────── return ( - + {rendered} ) diff --git a/ui-tui/src/components/thinking.tsx b/ui-tui/src/components/thinking.tsx index b24766ab3..418ee0c54 100644 --- a/ui-tui/src/components/thinking.tsx +++ b/ui-tui/src/components/thinking.tsx @@ -4,7 +4,6 @@ import spinners, { type BrailleSpinnerName } from 'unicode-animations' import { FACES, VERBS } from '../constants.js' import { - compactPreview, formatToolCall, parseToolTrailResultLine, pick, @@ -20,6 +19,7 @@ const TOOL: BrailleSpinnerName[] = ['cascade', 'scan', 'diagswipe', 'fillsweep', const fmtElapsed = (ms: number) => { const sec = Math.max(0, ms) / 1000 + return sec < 10 ? `${sec.toFixed(1)}s` : `${Math.round(sec)}s` } @@ -28,6 +28,7 @@ const fmtElapsed = (ms: number) => { export function Spinner({ color, variant = 'think' }: { color: string; variant?: 'think' | 'tool' }) { const [spin] = useState(() => { const raw = spinners[pick(variant === 'tool' ? TOOL : THINK)] + return { ...raw, frames: raw.frames.map(f => [...f][0] ?? 
'⠀') } }) @@ -35,6 +36,7 @@ export function Spinner({ color, variant = 'think' }: { color: string; variant?: useEffect(() => { const id = setInterval(() => setFrame(f => (f + 1) % spin.frames.length), spin.interval) + return () => clearInterval(id) }, [spin]) @@ -52,22 +54,46 @@ function Detail({ color, content, dimColor }: DetailRow) { ) } -function StreamCursor({ color, dimColor, streaming = false, visible = false }: { - color: string; dimColor?: boolean; streaming?: boolean; visible?: boolean +function StreamCursor({ + color, + dimColor, + streaming = false, + visible = false +}: { + color: string + dimColor?: boolean + streaming?: boolean + visible?: boolean }) { const [on, setOn] = useState(true) useEffect(() => { const id = setInterval(() => setOn(v => !v), 420) + return () => clearInterval(id) }, []) - return visible ? {streaming && on ? '▍' : ' '} : null + return visible ? ( + + {streaming && on ? '▍' : ' '} + + ) : null } -function Chevron({ count, onClick, open, summary, t, title, tone = 'dim' }: { - count?: number; onClick: () => void; open: boolean; summary?: string - t: Theme; title: string; tone?: 'dim' | 'error' | 'warn' +function Chevron({ + count, + onClick, + open, + t, + title, + tone = 'dim' +}: { + count?: number + onClick: () => void + open: boolean + t: Theme + title: string + tone?: 'dim' | 'error' | 'warn' }) { const color = tone === 'error' ? t.color.error : tone === 'warn' ? t.color.warn : t.color.dim @@ -75,8 +101,8 @@ function Chevron({ count, onClick, open, summary, t, title, tone = 'dim' }: { {open ? '▾ ' : '▸ '} - {title}{typeof count === 'number' ? ` (${count})` : ''} - {summary ? · {summary} : null} + {title} + {typeof count === 'number' ? 
` (${count})` : ''} ) @@ -85,14 +111,23 @@ function Chevron({ count, onClick, open, summary, t, title, tone = 'dim' }: { // ── Thinking ───────────────────────────────────────────────────────── export const Thinking = memo(function Thinking({ - active = false, mode = 'truncated', reasoning, streaming = false, t + active = false, + mode = 'truncated', + reasoning, + streaming = false, + t }: { - active?: boolean; mode?: ThinkingMode; reasoning: string; streaming?: boolean; t: Theme + active?: boolean + mode?: ThinkingMode + reasoning: string + streaming?: boolean + t: Theme }) { const [tick, setTick] = useState(0) useEffect(() => { const id = setInterval(() => setTick(v => v + 1), 1100) + return () => clearInterval(id) }, []) @@ -126,13 +161,25 @@ export const Thinking = memo(function Thinking({ type Group = { color: string; content: ReactNode; details: DetailRow[]; key: string } export const ToolTrail = memo(function ToolTrail({ - busy = false, detailsMode = 'collapsed', reasoningActive = false, - reasoning = '', reasoningStreaming = false, t, - tools = [], trail = [], activity = [] + busy = false, + detailsMode = 'collapsed', + reasoningActive = false, + reasoning = '', + reasoningStreaming = false, + t, + tools = [], + trail = [], + activity = [] }: { - busy?: boolean; detailsMode?: DetailsMode; reasoningActive?: boolean - reasoning?: string; reasoningStreaming?: boolean; t: Theme - tools?: ActiveTool[]; trail?: string[]; activity?: ActivityItem[] + busy?: boolean + detailsMode?: DetailsMode + reasoningActive?: boolean + reasoning?: string + reasoningStreaming?: boolean + t: Theme + tools?: ActiveTool[] + trail?: string[] + activity?: ActivityItem[] }) { const [now, setNow] = useState(() => Date.now()) const [openThinking, setOpenThinking] = useState(false) @@ -140,19 +187,33 @@ export const ToolTrail = memo(function ToolTrail({ const [openMeta, setOpenMeta] = useState(false) useEffect(() => { - if (!tools.length) return - const id = setInterval(() => 
setNow(Date.now()), 200) + if (!tools.length || (detailsMode === 'collapsed' && !openTools)) { + return + } + const id = setInterval(() => setNow(Date.now()), 500) + return () => clearInterval(id) - }, [tools.length]) + }, [detailsMode, openTools, tools.length]) useEffect(() => { - if (detailsMode === 'expanded') { setOpenThinking(true); setOpenTools(true); setOpenMeta(true) } - if (detailsMode === 'hidden') { setOpenThinking(false); setOpenTools(false); setOpenMeta(false) } + if (detailsMode === 'expanded') { + setOpenThinking(true) + setOpenTools(true) + setOpenMeta(true) + } + + if (detailsMode === 'hidden') { + setOpenThinking(false) + setOpenTools(false) + setOpenMeta(false) + } }, [detailsMode]) const cot = thinkingPreview(reasoning, 'full', THINKING_COT_MAX) - if (!busy && !trail.length && !tools.length && !activity.length && !cot && !reasoningActive) return null + if (!busy && !trail.length && !tools.length && !activity.length && !cot && !reasoningActive) { + return null + } // ── Build groups + meta ──────────────────────────────────────── @@ -167,12 +228,19 @@ export const ToolTrail = memo(function ToolTrail({ groups.push({ color: parsed.mark === '✗' ? t.color.error : t.color.cornsilk, content: parsed.detail ? parsed.call : `${parsed.call} ${parsed.mark}`, - details: [], key: `tr-${i}` - }) - if (parsed.detail) pushDetail({ - color: parsed.mark === '✗' ? t.color.error : t.color.dim, - content: parsed.detail, dimColor: parsed.mark !== '✗', key: `tr-${i}-d` + details: [], + key: `tr-${i}` }) + + if (parsed.detail) { + pushDetail({ + color: parsed.mark === '✗' ? 
t.color.error : t.color.dim, + content: parsed.detail, + dimColor: parsed.mark !== '✗', + key: `tr-${i}-d` + }) + } + continue } @@ -183,16 +251,24 @@ export const ToolTrail = memo(function ToolTrail({ details: [{ color: t.color.dim, content: 'drafting...', dimColor: true, key: `tr-${i}-d` }], key: `tr-${i}` }) + continue } if (line === 'analyzing tool output…') { pushDetail({ - color: t.color.dim, dimColor: true, key: `tr-${i}`, - content: groups.length - ? <> {line} - : line + color: t.color.dim, + dimColor: true, + key: `tr-${i}`, + content: groups.length ? ( + <> + {line} + + ) : ( + line + ) }) + continue } @@ -201,7 +277,9 @@ export const ToolTrail = memo(function ToolTrail({ for (const tool of tools) { groups.push({ - color: t.color.cornsilk, key: tool.id, details: [], + color: t.color.cornsilk, + key: tool.id, + details: [], content: ( <> {formatToolCall(tool.name, tool.context || '')} @@ -211,18 +289,6 @@ export const ToolTrail = memo(function ToolTrail({ }) } - if (cot && groups.length) { - pushDetail({ - color: t.color.dim, dimColor: true, key: 'cot', - content: <>{cot} - }) - } else if (reasoningActive && groups.length) { - pushDetail({ - color: t.color.dim, dimColor: true, key: 'cot', - content: - }) - } - for (const item of activity.slice(-4)) { const glyph = item.tone === 'error' ? '✗' : item.tone === 'warn' ? '!' : '·' const color = item.tone === 'error' ? t.color.error : item.tone === 'warn' ? t.color.warn : t.color.dim @@ -233,12 +299,13 @@ export const ToolTrail = memo(function ToolTrail({ const hasTools = groups.length > 0 const hasMeta = meta.length > 0 - const hasThinking = !hasTools && (busy || !!cot || reasoningActive) + const hasThinking = !!cot || reasoningActive || (busy && !hasTools) // ── Hidden: errors/warnings only ────────────────────────────── if (detailsMode === 'hidden') { const alerts = activity.filter(i => i.tone !== 'info').slice(-2) + return alerts.length ? 
( {alerts.map(i => ( @@ -253,61 +320,95 @@ export const ToolTrail = memo(function ToolTrail({ // ── Shared render fragments ──────────────────────────────────── const thinkingBlock = hasThinking ? ( - busy - ? - : cot - ? - : } dimColor key="cot" /> + busy ? ( + + ) : cot ? ( + + ) : ( + } + dimColor + key="cot" + /> + ) ) : null - const toolBlock = hasTools ? groups.map(g => ( - - - - {g.content} - - {g.details.map(d => )} - - )) : null + const toolBlock = hasTools + ? groups.map(g => ( + + + + {g.content} + + {g.details.map(d => ( + + ))} + + )) + : null - const metaBlock = hasMeta ? meta.map((row, i) => ( - - {i === meta.length - 1 ? '└ ' : '├ '} - {row.content} - - )) : null + const metaBlock = hasMeta + ? meta.map((row, i) => ( + + {i === meta.length - 1 ? '└ ' : '├ '} + {row.content} + + )) + : null // ── Expanded: flat, no accordions ────────────────────────────── if (detailsMode === 'expanded') { - return {thinkingBlock}{toolBlock}{metaBlock} + return ( + + {thinkingBlock} + {toolBlock} + {metaBlock} + + ) } // ── Collapsed: clickable accordions ──────────────────────────── - const metaTone: 'dim' | 'error' | 'warn' = - activity.some(i => i.tone === 'error') ? 'error' - : activity.some(i => i.tone === 'warn') ? 'warn' : 'dim' + const metaTone: 'dim' | 'error' | 'warn' = activity.some(i => i.tone === 'error') + ? 'error' + : activity.some(i => i.tone === 'warn') + ? 'warn' + : 'dim' return ( {hasThinking && ( <> - setOpenThinking(v => !v)} open={openThinking} summary={cot ? compactPreview(cot, 56) : busy ? 
'running…' : ''} t={t} title="Thinking" /> + setOpenThinking(v => !v)} open={openThinking} t={t} title="Thinking" /> {openThinking && thinkingBlock} )} {hasTools && ( <> - setOpenTools(v => !v)} open={openTools} t={t} title="Tool calls" /> + setOpenTools(v => !v)} + open={openTools} + t={t} + title="Tool calls" + /> {openTools && toolBlock} )} {hasMeta && ( <> - setOpenMeta(v => !v)} open={openMeta} t={t} title="Activity" tone={metaTone} /> + setOpenMeta(v => !v)} + open={openMeta} + t={t} + title="Activity" + tone={metaTone} + /> {openMeta && metaBlock} )} diff --git a/ui-tui/src/hooks/useVirtualHistory.ts b/ui-tui/src/hooks/useVirtualHistory.ts index 877fde5d7..868a35c4a 100644 --- a/ui-tui/src/hooks/useVirtualHistory.ts +++ b/ui-tui/src/hooks/useVirtualHistory.ts @@ -1,19 +1,30 @@ -import { useCallback, useEffect, useLayoutEffect, useMemo, useRef, useState, useSyncExternalStore, type RefObject } from 'react' - import type { ScrollBoxHandle } from '@hermes/ink' +import { + type RefObject, + useCallback, + useEffect, + useLayoutEffect, + useMemo, + useRef, + useState, + useSyncExternalStore +} from 'react' const ESTIMATE = 4 const OVERSCAN = 40 const MAX_MOUNTED = 260 const COLD_START = 40 -const QUANTUM = 8 +const QUANTUM = OVERSCAN >> 1 const upperBound = (arr: number[], target: number) => { - let lo = 0, hi = arr.length + let lo = 0, + hi = arr.length + while (lo < hi) { const mid = (lo + hi) >> 1 - arr[mid]! <= target ? lo = mid + 1 : hi = mid + arr[mid]! <= target ? (lo = mid + 1) : (hi = mid) } + return lo } @@ -28,14 +39,15 @@ export function useVirtualHistory( const [ver, setVer] = useState(0) useSyncExternalStore( - useCallback( - (cb: () => void) => scrollRef.current?.subscribe(cb) ?? (() => () => {}), - [scrollRef] - ), + useCallback((cb: () => void) => scrollRef.current?.subscribe(cb) ?? 
(() => () => {}), [scrollRef]), () => { const s = scrollRef.current - if (!s) return NaN + + if (!s) { + return NaN + } const b = Math.floor(s.getScrollTop() / QUANTUM) + return s.isSticky() ? -b - 1 : b }, () => NaN @@ -44,6 +56,7 @@ export function useVirtualHistory( useEffect(() => { const keep = new Set(items.map(i => i.key)) let dirty = false + for (const k of heights.current.keys()) { if (!keep.has(k)) { heights.current.delete(k) @@ -52,13 +65,19 @@ export function useVirtualHistory( dirty = true } } - if (dirty) setVer(v => v + 1) + + if (dirty) { + setVer(v => v + 1) + } }, [items]) const offsets = useMemo(() => { const out = new Array(items.length + 1).fill(0) - for (let i = 0; i < items.length; i++) + + for (let i = 0; i < items.length; i++) { out[i + 1] = out[i]! + Math.max(1, Math.floor(heights.current.get(items[i]!.key) ?? estimate)) + } + return out }, [estimate, items, ver]) @@ -67,7 +86,8 @@ export function useVirtualHistory( const vp = Math.max(0, scrollRef.current?.getViewportHeight() ?? 0) const sticky = scrollRef.current?.isSticky() ?? true - let start = 0, end = items.length + let start = 0, + end = items.length if (items.length > 0) { if (vp <= 0) { @@ -79,37 +99,46 @@ export function useVirtualHistory( } if (end - start > maxMounted) { - sticky - ? (start = Math.max(0, end - maxMounted)) - : (end = Math.min(items.length, start + maxMounted)) + sticky ? (start = Math.max(0, end - maxMounted)) : (end = Math.min(items.length, start + maxMounted)) } const measureRef = useCallback((key: string) => { let fn = refs.current.get(key) + if (!fn) { - fn = (el: any) => el ? nodes.current.set(key, el) : nodes.current.delete(key) + fn = (el: any) => (el ? 
nodes.current.set(key, el) : nodes.current.delete(key)) refs.current.set(key, fn) } + return fn }, []) useLayoutEffect(() => { let dirty = false + for (let i = start; i < end; i++) { const k = items[i]?.key - if (!k) continue + + if (!k) { + continue + } const h = Math.ceil(nodes.current.get(k)?.yogaNode?.getComputedHeight?.() ?? 0) + if (h > 0 && heights.current.get(k) !== h) { heights.current.set(k, h) dirty = true } } - if (dirty) setVer(v => v + 1) + + if (dirty) { + setVer(v => v + 1) + } }, [end, items, start]) return { start, end, + offsets, topSpacer: offsets[start] ?? 0, bottomSpacer: Math.max(0, total - (offsets[end] ?? total)), measureRef diff --git a/ui-tui/src/types/hermes-ink.d.ts b/ui-tui/src/types/hermes-ink.d.ts index d6ecb7f61..6b3001a85 100644 --- a/ui-tui/src/types/hermes-ink.d.ts +++ b/ui-tui/src/types/hermes-ink.d.ts @@ -49,8 +49,10 @@ declare module '@hermes/ink' { export type ScrollBoxHandle = { readonly scrollTo: (y: number) => void readonly scrollBy: (dy: number) => void + readonly scrollToElement: (el: unknown, offset?: number) => void readonly scrollToBottom: () => void readonly getScrollTop: () => number + readonly getPendingDelta: () => number readonly getViewportHeight: () => number readonly isSticky: () => boolean readonly subscribe: (listener: () => void) => () => void diff --git a/web/package-lock.json b/web/package-lock.json index d9aa7a951..71ca2c7a7 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -14,6 +14,7 @@ "lucide-react": "^0.577.0", "react": "^19.2.4", "react-dom": "^19.2.4", + "react-router-dom": "^7.14.1", "tailwind-merge": "^3.5.0", "tailwindcss": "^4.2.1" }, @@ -2208,6 +2209,19 @@ "dev": true, "license": "MIT" }, + "node_modules/cookie": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-1.1.1.tgz", + "integrity": "sha512-ei8Aos7ja0weRpFzJnEA9UHJ/7XQmqglbRwnf2ATjcB9Wq874VKH9kfjjirM6UhU2/E5fFYadylyhFldcqSidQ==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + 
"funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/cross-spawn": { "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", @@ -3403,6 +3417,44 @@ "node": ">=0.10.0" } }, + "node_modules/react-router": { + "version": "7.14.1", + "resolved": "https://registry.npmjs.org/react-router/-/react-router-7.14.1.tgz", + "integrity": "sha512-5BCvFskyAAVumqhEKh/iPhLOIkfxcEUz8WqFIARCkMg8hZZzDYX9CtwxXA0e+qT8zAxmMC0x3Ckb9iMONwc5jg==", + "license": "MIT", + "dependencies": { + "cookie": "^1.0.1", + "set-cookie-parser": "^2.6.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "react": ">=18", + "react-dom": ">=18" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + } + } + }, + "node_modules/react-router-dom": { + "version": "7.14.1", + "resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-7.14.1.tgz", + "integrity": "sha512-ZkrQuwwhGibjQLqH1eCdyiZyLWglPxzxdl5tgwgKEyCSGC76vmAjleGocRe3J/MLfzMUIKwaFJWpFVJhK3d2xA==", + "license": "MIT", + "dependencies": { + "react-router": "7.14.1" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "react": ">=18", + "react-dom": ">=18" + } + }, "node_modules/resolve-from": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", @@ -3473,6 +3525,12 @@ "semver": "bin/semver.js" } }, + "node_modules/set-cookie-parser": { + "version": "2.7.2", + "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.2.tgz", + "integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==", + "license": "MIT" + }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", diff --git a/web/package.json b/web/package.json index 87dbfdb79..09675d283 100644 --- a/web/package.json +++ 
b/web/package.json @@ -16,6 +16,7 @@ "lucide-react": "^0.577.0", "react": "^19.2.4", "react-dom": "^19.2.4", + "react-router-dom": "^7.14.1", "tailwind-merge": "^3.5.0", "tailwindcss": "^4.2.1" }, diff --git a/web/src/App.tsx b/web/src/App.tsx index d52757c20..3d2832ccb 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -8,16 +8,18 @@ import LogsPage from "@/pages/LogsPage"; import AnalyticsPage from "@/pages/AnalyticsPage"; import CronPage from "@/pages/CronPage"; import SkillsPage from "@/pages/SkillsPage"; +import { LanguageSwitcher } from "@/components/LanguageSwitcher"; +import { useI18n } from "@/i18n"; const NAV_ITEMS = [ - { id: "status", label: "Status", icon: Activity }, - { id: "sessions", label: "Sessions", icon: MessageSquare }, - { id: "analytics", label: "Analytics", icon: BarChart3 }, - { id: "logs", label: "Logs", icon: FileText }, - { id: "cron", label: "Cron", icon: Clock }, - { id: "skills", label: "Skills", icon: Package }, - { id: "config", label: "Config", icon: Settings }, - { id: "env", label: "Keys", icon: KeyRound }, + { id: "status", labelKey: "status" as const, icon: Activity }, + { id: "sessions", labelKey: "sessions" as const, icon: MessageSquare }, + { id: "analytics", labelKey: "analytics" as const, icon: BarChart3 }, + { id: "logs", labelKey: "logs" as const, icon: FileText }, + { id: "cron", labelKey: "cron" as const, icon: Clock }, + { id: "skills", labelKey: "skills" as const, icon: Package }, + { id: "config", labelKey: "config" as const, icon: Settings }, + { id: "env", labelKey: "keys" as const, icon: KeyRound }, ] as const; type PageId = (typeof NAV_ITEMS)[number]["id"]; @@ -37,6 +39,7 @@ export default function App() { const [page, setPage] = useState("status"); const [animKey, setAnimKey] = useState(0); const initialRef = useRef(true); + const { t } = useI18n(); useEffect(() => { // Skip the animation key bump on initial mount to avoid re-mounting @@ -68,7 +71,7 @@ export default function App() { {/* Nav — icons only on 
mobile, icon+label on sm+ */} - {/* Version badge — hidden on mobile */} -
- - Web UI + {/* Right side: language switcher + version badge */} +
+ + + {t.app.webUi}
@@ -112,10 +116,10 @@ export default function App() {
- Hermes Agent + {t.app.footer.name} - Nous Research + {t.app.footer.org}
diff --git a/web/src/components/AutoField.tsx b/web/src/components/AutoField.tsx index 67f6739e9..44128cf9f 100644 --- a/web/src/components/AutoField.tsx +++ b/web/src/components/AutoField.tsx @@ -1,6 +1,6 @@ import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; -import { Select } from "@/components/ui/select"; +import { Select, SelectOption } from "@/components/ui/select"; import { Switch } from "@/components/ui/switch"; function FieldHint({ schema, schemaKey }: { schema: Record; schemaKey: string }) { @@ -44,11 +44,11 @@ export function AutoField({
- onChange(v)}> {options.map((opt) => ( - + ))}
@@ -85,7 +85,7 @@ export function AutoField({
A real terminal interfaceFull TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.