diff --git a/.github/workflows/supply-chain-audit.yml b/.github/workflows/supply-chain-audit.yml index 9eb76e6a5f3..7ff734ca943 100644 --- a/.github/workflows/supply-chain-audit.yml +++ b/.github/workflows/supply-chain-audit.yml @@ -47,14 +47,17 @@ jobs: HEAD="${{ github.event.pull_request.head.sha }}" # Added lines only, excluding lockfiles. - DIFF=$(git diff "$BASE".."$HEAD" -- . ':!uv.lock' ':!*.lock' ':!package-lock.json' ':!yarn.lock' || true) + # Three-dot diff (base...head) diffs from the merge base to HEAD, + # so only changes introduced by this PR are included โ€” not changes + # that landed on main after the PR branched off. + DIFF=$(git diff "$BASE"..."$HEAD" -- . ':!uv.lock' ':!*.lock' ':!package-lock.json' ':!yarn.lock' || true) FINDINGS="" # --- .pth files (auto-execute on Python startup) --- # The exact mechanism used in the litellm supply chain attack: # https://github.com/BerriAI/litellm/issues/24512 - PTH_FILES=$(git diff --name-only "$BASE".."$HEAD" | grep '\.pth$' || true) + PTH_FILES=$(git diff --name-only "$BASE"..."$HEAD" | grep '\.pth$' || true) if [ -n "$PTH_FILES" ]; then FINDINGS="${FINDINGS} ### ๐Ÿšจ CRITICAL: .pth file added or modified @@ -97,7 +100,7 @@ jobs: # --- Install-hook files (setup.py/sitecustomize/usercustomize/__init__.pth) --- # These execute during pip install or interpreter startup. - SETUP_HITS=$(git diff --name-only "$BASE".."$HEAD" | grep -E '(^|/)(setup\.py|setup\.cfg|sitecustomize\.py|usercustomize\.py|__init__\.pth)$' || true) + SETUP_HITS=$(git diff --name-only "$BASE"..."$HEAD" | grep -E '(^|/)(setup\.py|setup\.cfg|sitecustomize\.py|usercustomize\.py|__init__\.pth)$' || true) if [ -n "$SETUP_HITS" ]; then FINDINGS="${FINDINGS} ### ๐Ÿšจ CRITICAL: Install-hook file added or modified @@ -158,7 +161,7 @@ jobs: HEAD="${{ github.event.pull_request.head.sha }}" # Only check added lines in pyproject.toml - ADDED=$(git diff "$BASE".."$HEAD" -- pyproject.toml | grep '^+' | grep -v '^+++' || true) + ADDED=$(git diff "$BASE"..."$HEAD" -- pyproject.toml | grep '^+' | grep -v '^+++' || true) if [ -z "$ADDED" ]; then echo "found=false" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3ffaa10d009..b48b0bab080 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,11 +23,22 @@ concurrency: jobs: test: runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + slice: [1, 2, 3, 4, 5, 6] steps: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Restore duration cache + uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 + with: + path: test_durations.json + # Single stable key. main always overwrites, PRs always find it. + key: test-durations + - name: Install ripgrep (prebuilt binary) run: | set -euo pipefail @@ -54,7 +65,7 @@ jobs: source .venv/bin/activate uv pip install -e ".[all,dev]" - - name: Run tests + - name: Run tests (slice ${{ matrix.slice }}/6) # Per-file isolation via scripts/run_tests_parallel.py: discovers # every test_*.py file under tests/ (excluding integration/ + e2e/), # then runs `python -m pytest ` in a freshly-spawned subprocess @@ -72,15 +83,61 @@ jobs: # state across files, which is exactly the leakage we wanted to # fix. ThreadPoolExecutor + subprocess.run is ~60 lines and does # the job with cleaner semantics. + # + # Matrix slicing (--slice I/N): files are distributed across 6 + # jobs by cached duration (LPT algorithm) so each job gets + # roughly equal wall time. Without a cache, files default to 2s + # estimate and get split roughly evenly by count โ€” still correct, + # just not perfectly balanced. run: | source .venv/bin/activate - python scripts/run_tests_parallel.py + python scripts/run_tests_parallel.py --slice ${{ matrix.slice }}/6 env: # Ensure tests don't accidentally call real APIs OPENROUTER_API_KEY: "" OPENAI_API_KEY: "" NOUS_API_KEY: "" + - name: Upload per-slice durations + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: test-durations-slice-${{ matrix.slice }} + path: test_durations.json + retention-days: 1 + + # Merge per-slice duration data into a single cache, so future runs + # (including PRs) get balanced slicing. + save-durations: + needs: test + if: always() && github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + steps: + - name: Download all slice durations + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + pattern: test-durations-slice-* + path: durations + merge-multiple: true + + - name: Merge into single durations file + run: | + python3 -c " + import json, glob, os + merged = {} + for f in glob.glob('durations/*test_durations.json'): + with open(f) as fh: + merged.update(json.load(fh)) + with open('test_durations.json', 'w') as fh: + json.dump(merged, fh, indent=2, sort_keys=True) + print(f'Merged {len(merged)} file durations') + " + + - name: Save merged duration cache + uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 + with: + path: test_durations.json + key: test-durations + e2e: runs-on: ubuntu-latest timeout-minutes: 15 @@ -121,4 +178,4 @@ jobs: env: OPENROUTER_API_KEY: "" OPENAI_API_KEY: "" - NOUS_API_KEY: "" + NOUS_API_KEY: "" \ No newline at end of file diff --git a/.gitignore b/.gitignore index 2dbd15c6c7d..8bbe7235ee9 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ __pycache__/web_tools.cpython-310.pyc logs/ data/ .pytest_cache/ +test_durations.json .pytest-cache/ tmp/ temp_vision_images/ diff --git a/README.md b/README.md index b659f56fa53..9b148164294 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,27 @@ hermes doctor # Diagnose any issues ๐Ÿ“– **[Full documentation โ†’](https://hermes-agent.nousresearch.com/docs/)** +--- + +## Skip the API-key collection โ€” Nous Portal + +Hermes works with whatever provider you want โ€” that's not changing. But if you'd rather not collect five separate API keys for the model, web search, image generation, TTS, and a cloud browser, **[Nous Portal](https://portal.nousresearch.com)** covers all of them under one subscription: + +- **300+ models** โ€” pick any of them with `/model ` +- **Tool Gateway** โ€” web search (Firecrawl), image generation (FAL), text-to-speech (OpenAI), cloud browser (Browser Use), all routed through your sub. No extra accounts. + +One command from a fresh install: + +```bash +hermes setup --portal +``` + +That logs you in via OAuth, sets Nous as your provider, and turns on the Tool Gateway. Check what's wired up any time with `hermes portal status`. Full details on the [Tool Gateway docs page](https://hermes-agent.nousresearch.com/docs/user-guide/features/tool-gateway). + +You can still bring your own keys per-tool whenever you want โ€” the gateway is per-backend, not all-or-nothing. + +--- + ## CLI vs Messaging Quick Reference Hermes has two entry points: start the terminal UI with `hermes`, or run the gateway and talk to it from Telegram, Discord, Slack, WhatsApp, Signal, or Email. Once you're in a conversation, many slash commands are shared across both interfaces. diff --git a/README.zh-CN.md b/README.zh-CN.md index 9a964574413..e2228234ce6 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -65,6 +65,27 @@ hermes doctor # ่ฏŠๆ–ญ้—ฎ้ข˜ ๐Ÿ“– **[ๅฎŒๆ•ดๆ–‡ๆกฃ โ†’](https://hermes-agent.nousresearch.com/docs/)** +--- + +## ็œๅŽปๅˆฐๅค„ๆ”ถ้›† API Key โ€” Nous Portal + +Hermes ๅง‹็ปˆๅ…่ฎธไฝ ไฝฟ็”จไปปๆ„ๆœๅŠกๅ•†๏ผŒ่ฟ™็‚นไธไผšๆ”นๅ˜ใ€‚ไฝ†ๅฆ‚ๆžœไฝ ไธๆƒณไธบๆจกๅž‹ใ€็ฝ‘้กตๆœ็ดขใ€ๅ›พๅƒ็”Ÿๆˆใ€TTSใ€ไบ‘ๆต่งˆๅ™จๅˆ†ๅˆซๅŽป็”ณ่ฏทไบ”ไธชไธๅŒ็š„ API Key๏ผŒ**[Nous Portal](https://portal.nousresearch.com)** ็”จไธ€ไธช่ฎข้˜…ๅฐฑ่ƒฝ่ฆ†็›–ๅ…จ้ƒจ๏ผš + +- **300+ ๆจกๅž‹** โ€” ็”จ `/model ` ้šๆ—ถๅˆ‡ๆข +- **Tool Gateway** โ€” ็ฝ‘้กตๆœ็ดข๏ผˆFirecrawl๏ผ‰ใ€ๅ›พๅƒ็”Ÿๆˆ๏ผˆFAL๏ผ‰ใ€ๆ–‡ๆœฌ่ฝฌ่ฏญ้Ÿณ๏ผˆOpenAI๏ผ‰ใ€ไบ‘ๆต่งˆๅ™จ๏ผˆBrowser Use๏ผ‰๏ผŒๅ…จ้ƒจ้€š่ฟ‡่ฎข้˜…ๆ‰˜็ฎกใ€‚ๆ— ้œ€้ขๅค–ๆณจๅ†Œไปปไฝ•่ดฆๆˆทใ€‚ + +ๅ…จๆ–ฐๅฎ‰่ฃ…ๆ—ถไธ€ๆกๅ‘ฝไปคๅณๅฏ๏ผš + +```bash +hermes setup --portal +``` + +ๅฎƒไผš้€š่ฟ‡ OAuth ็™ปๅฝ•ใ€ๆŠŠ Nous ่ฎพไธบๆŽจ็†ๆœๅŠกๅ•†๏ผŒๅนถๅฏ็”จ Tool Gatewayใ€‚้šๆ—ถ็”จ `hermes portal status` ๆŸฅ็œ‹่ทฏ็”ฑ็Šถๆ€ใ€‚ๅฎŒๆ•ด่ฏดๆ˜Ž่ง [Tool Gateway ๆ–‡ๆกฃ](https://hermes-agent.nousresearch.com/docs/user-guide/features/tool-gateway)ใ€‚ + +ไฝ ้šๆ—ถๅฏไปฅๆŒ‰ๅทฅๅ…ทๅ•็‹ฌๅˆ‡ๅ›ž่‡ชๅทฑ็š„ API Key โ€” Gateway ๆ˜ฏๆŒ‰ๅทฅๅ…ท็ฒ’ๅบฆ็”Ÿๆ•ˆ็š„๏ผŒไธๆ˜ฏไธ€ๅˆ€ๅˆ‡ใ€‚ + +--- + ## CLI ไธŽๆถˆๆฏๅนณๅฐ ๅฟซ้€Ÿๅฏน็…ง Hermes ๆœ‰ไธค็งๅ…ฅๅฃ๏ผš็”จ `hermes` ๅฏๅŠจ็ปˆ็ซฏ UI๏ผŒๆˆ–่ฟ่กŒ็ฝ‘ๅ…ณไปŽ Telegramใ€Discordใ€Slackใ€WhatsAppใ€Signal ๆˆ– Email ไธŽไน‹ๅฏน่ฏใ€‚่ฟ›ๅ…ฅๅฏน่ฏๅŽ๏ผŒ่ฎธๅคšๆ–œๆ ๅ‘ฝไปคๅœจไธค็ง็•Œ้ขไธญ้€š็”จใ€‚ diff --git a/acp_adapter/server.py b/acp_adapter/server.py index fbdee70527a..81c22c18774 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -1534,7 +1534,11 @@ class HermesACPAgent(acp.Agent): ) except Exception: logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True) - if final_response and conn and not streamed_message: + if final_response and conn and (not streamed_message or result.get("response_transformed")): + # Deliver the final response when streaming did not already send it, + # or when a plugin hook transformed the response after streaming + # finished (e.g. transform_llm_output) โ€” otherwise the appended / + # rewritten text never reaches the client. update = acp.update_agent_message_text(final_response) await conn.session_update(session_id, update) diff --git a/agent/agent_init.py b/agent/agent_init.py index be9a09dd2f5..e20755c5091 100644 --- a/agent/agent_init.py +++ b/agent/agent_init.py @@ -607,6 +607,31 @@ def init_agent( # Falling back would send Anthropic credentials to third-party endpoints (Fixes #1739, #minimax-401). _is_native_anthropic = agent.provider == "anthropic" effective_key = (api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or "") + + # MiniMax OAuth issues short-lived (~15-min) access tokens. The + # Anthropic SDK caches ``api_key`` as a static string at client + # construction time, so a session that resolves the bearer once + # at startup will keep sending the same token until MiniMax + # returns 401 mid-session. Swap the static string for a callable + # token provider โ€” ``build_anthropic_client`` recognizes the + # callable and installs an httpx event hook that mints a fresh + # bearer per outbound request (re-reading auth.json so a refresh + # persisted by another process is visible immediately). + # The cached refresh path is a no-op when the token still has + # ``MINIMAX_OAUTH_REFRESH_SKEW_SECONDS`` of life left, so steady- + # state cost is one file read + one timestamp compare per request. + if agent.provider == "minimax-oauth" and isinstance(effective_key, str) and effective_key: + try: + from hermes_cli.auth import build_minimax_oauth_token_provider + effective_key = build_minimax_oauth_token_provider() + except Exception as _mm_exc: # noqa: BLE001 โ€” never block startup on this + import logging as _logging + _logging.getLogger(__name__).warning( + "MiniMax OAuth: failed to install per-request token provider " + "(%s); falling back to static bearer that will expire ~15min in.", + _mm_exc, + ) + agent.api_key = effective_key agent._anthropic_api_key = effective_key agent._anthropic_base_url = base_url @@ -618,7 +643,7 @@ def init_agent( # that cause 401/403 on their endpoints. Guards #1739 and # the third-party identity-injection bug. from agent.anthropic_adapter import _is_oauth_token as _is_oat - agent._is_anthropic_oauth = _is_oat(effective_key) if _is_native_anthropic else False + agent._is_anthropic_oauth = _is_oat(effective_key) if (_is_native_anthropic and isinstance(effective_key, str)) else False agent._anthropic_client = build_anthropic_client(effective_key, base_url, timeout=_provider_timeout) # No OpenAI client needed for Anthropic mode agent.client = None @@ -951,16 +976,14 @@ def init_agent( # Expose session ID to tools (terminal, execute_code) so agents can # reference their own session for --resume commands, cross-session - # coordination, and logging. Uses the ContextVar system from - # session_context.py for concurrency safety (gateway runs multiple - # sessions in one process). Also writes os.environ as fallback for - # CLI mode where ContextVars aren't used. - os.environ["HERMES_SESSION_ID"] = agent.session_id + # coordination, and logging. Keep the ContextVar and os.environ + # fallback synchronized because different tool paths still read both. try: - from gateway.session_context import _SESSION_ID - _SESSION_ID.set(agent.session_id) + from gateway.session_context import set_current_session_id + + set_current_session_id(agent.session_id) except Exception: - pass # CLI/test mode โ€” ContextVar not needed + os.environ["HERMES_SESSION_ID"] = agent.session_id # Session logs go into ~/.hermes/sessions/ alongside gateway sessions hermes_home = get_hermes_home() @@ -1125,7 +1148,18 @@ def init_agent( # through _ra().get_tool_definitions()). Duplicate function names cause # 400 errors on providers that enforce unique names (e.g. Xiaomi # MiMo via Nous Portal). - if agent._memory_manager and agent.tools is not None: + # + # Respect the platform's enabled_toolsets configuration (#5544): + # enabled_toolsets is None โ†’ no filter, inject (backward compat) + # "memory" in enabled_toolsets โ†’ user opted in, inject + # otherwise (incl. []) โ†’ user excluded memory, skip injection + # + # Without this gate, `platform_toolsets: telegram: []` still leaks memory + # provider tools (fact_store, etc.) into the tool surface โ€” a 10x latency + # penalty on local models and a frequent trigger of tool-call loops. + if agent._memory_manager and agent.tools is not None and ( + agent.enabled_toolsets is None or "memory" in agent.enabled_toolsets + ): _existing_tool_names = { t.get("function", {}).get("name") for t in agent.tools @@ -1393,6 +1427,7 @@ def init_agent( base_url=agent.base_url, api_key=getattr(agent, "api_key", ""), provider=agent.provider, + api_mode=agent.api_mode, ) if not agent.quiet_mode: _ra().logger.info("Using context engine: %s", _selected_engine.name) @@ -1435,8 +1470,22 @@ def init_agent( # errors. Even with the cache fix, dedup is the right defense # against plugin paths that may register the same schemas via # ctx.register_tool(). Mirrors the memory tools dedup above. + # + # Respect the platform's enabled_toolsets configuration (#5544): + # context engine tools follow the same gating pattern as memory + # provider tools โ€” without the gate, `platform_toolsets: telegram: []` + # would still leak lcm_* tools into the tool surface and incur the + # same local-model latency penalty. agent._context_engine_tool_names: set = set() - if hasattr(agent, "context_compressor") and agent.context_compressor and agent.tools is not None: + if ( + hasattr(agent, "context_compressor") + and agent.context_compressor + and agent.tools is not None + and ( + agent.enabled_toolsets is None + or "context_engine" in agent.enabled_toolsets + ) + ): _existing_tool_names = { t.get("function", {}).get("name") for t in agent.tools diff --git a/agent/agent_runtime_helpers.py b/agent/agent_runtime_helpers.py index b98fe4b44e7..f7c8819eb5e 100644 --- a/agent/agent_runtime_helpers.py +++ b/agent/agent_runtime_helpers.py @@ -132,7 +132,7 @@ def convert_to_trajectory_format(agent, messages: List[Dict[str, Any]], user_que except json.JSONDecodeError: # This shouldn't happen since we validate and retry during conversation, # but if it does, log warning and use empty dict - logging.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}") + logger.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}") arguments = {} tool_call_json = { @@ -617,9 +617,28 @@ def recover_with_credential_pool( # existing entitlement keyword set in ``_is_entitlement_failure``. # Any 403 against ``xai-oauth`` is treated as entitlement here so # the refresh loop can't spin in those cases either. + # + # Exception (#29344): xAI's ``[WKE=unauthenticated:...]`` suffix and + # the ``OAuth2 access token could not be validated`` phrasing are + # xAI's authoritative "this is a stale token, not entitlement" + # signal. When either fires we must NOT apply the catch-all + # override โ€” refresh is the recoverable path for these bodies, and + # blanket-classifying them as entitlement was the bug that left + # long-running TUI sessions stuck on stale tokens until the user + # exited and reopened. is_entitlement = agent._is_entitlement_failure(error_context, status_code) if not is_entitlement and status_code == 403 and (agent.provider or "") == "xai-oauth": - is_entitlement = True + _disambiguator_haystack = " ".join( + str(error_context.get(k) or "").lower() + for k in ("message", "reason", "code", "error") + if isinstance(error_context, dict) + ) + _is_xai_auth_failure = ( + "[wke=unauthenticated:" in _disambiguator_haystack + or "oauth2 access token could not be validated" in _disambiguator_haystack + ) + if not _is_xai_auth_failure: + is_entitlement = True if is_entitlement: _ra().logger.info( "Credential %s โ€” entitlement-shaped 403 from %s; " @@ -728,7 +747,7 @@ def try_recover_primary_transport( time.sleep(wait_time) return True except Exception as e: - logging.warning("Primary transport recovery failed: %s", e) + logger.warning("Primary transport recovery failed: %s", e) return False # โ”€โ”€ End provider fallback โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -891,19 +910,20 @@ def restore_primary_runtime(agent) -> bool: base_url=rt["compressor_base_url"], api_key=rt["compressor_api_key"], provider=rt["compressor_provider"], + api_mode=rt.get("compressor_api_mode", ""), ) # โ”€โ”€ Reset fallback chain for the new turn โ”€โ”€ agent._fallback_activated = False agent._fallback_index = 0 - logging.info( + logger.info( "Primary runtime restored for new turn: %s (%s)", agent.model, agent.provider, ) return True except Exception as e: - logging.warning("Failed to restore primary runtime: %s", e) + logger.warning("Failed to restore primary runtime: %s", e) return False # Which error types indicate a transient transport failure worth @@ -1064,10 +1084,7 @@ def dump_api_request_debug( timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") dump_file = agent.logs_dir / f"request_dump_{agent.session_id}_{timestamp}.json" - dump_file.write_text( - json.dumps(dump_payload, ensure_ascii=False, indent=2, default=str), - encoding="utf-8", - ) + atomic_json_write(dump_file, dump_payload, default=str) agent._vprint(f"{agent.log_prefix}๐Ÿงพ Request debug dump written to: {dump_file}") @@ -1077,7 +1094,7 @@ def dump_api_request_debug( return dump_file except Exception as dump_error: if agent.verbose_logging: - logging.warning(f"Failed to dump API request debug payload: {dump_error}") + logger.warning(f"Failed to dump API request debug payload: {dump_error}") return None @@ -1352,6 +1369,22 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo # API key โ€” falling back would send Anthropic credentials to third-party endpoints. _is_native_anthropic = new_provider == "anthropic" effective_key = (api_key or agent.api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or agent.api_key or "") + + # MiniMax OAuth: swap static string for a per-request callable token + # provider so the rebuilt client survives 15-min token expiry. See + # the matching block in agent_init.py for the full rationale. + if new_provider == "minimax-oauth" and isinstance(effective_key, str) and effective_key: + try: + from hermes_cli.auth import build_minimax_oauth_token_provider + effective_key = build_minimax_oauth_token_provider() + except Exception as _mm_exc: # noqa: BLE001 + import logging as _logging + _logging.getLogger(__name__).warning( + "MiniMax OAuth: failed to install per-request token provider " + "on switch (%s); using static bearer.", + _mm_exc, + ) + agent.api_key = effective_key agent._anthropic_api_key = effective_key agent._anthropic_base_url = base_url or getattr(agent, "_anthropic_base_url", None) @@ -1359,7 +1392,7 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo effective_key, agent._anthropic_base_url, timeout=get_provider_request_timeout(agent.provider, agent.model), ) - agent._is_anthropic_oauth = _is_oauth_token(effective_key) if _is_native_anthropic else False + agent._is_anthropic_oauth = _is_oauth_token(effective_key) if (_is_native_anthropic and isinstance(effective_key, str)) else False agent.client = None agent._client_kwargs = {} else: @@ -1446,6 +1479,7 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo "compressor_api_key": getattr(_cc, "api_key", "") if _cc else "", "compressor_provider": getattr(_cc, "provider", agent.provider) if _cc else agent.provider, "compressor_context_length": _cc.context_length if _cc else 0, + "compressor_api_mode": getattr(_cc, "api_mode", agent.api_mode) if _cc else agent.api_mode, "compressor_threshold_tokens": _cc.threshold_tokens if _cc else 0, } if api_mode == "anthropic_messages": @@ -1477,7 +1511,7 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo agent._fallback_chain = fallback_chain agent._fallback_model = fallback_chain[0] if fallback_chain else None - logging.info( + logger.info( "Model switched in-place: %s (%s) -> %s (%s)", old_model, old_provider, new_model, new_provider, ) @@ -2116,33 +2150,56 @@ def apply_pending_steer_to_tool_results(agent, messages: list, num_tool_msgs: in def force_close_tcp_sockets(client: Any) -> int: - """Force-close underlying TCP sockets to prevent CLOSE-WAIT accumulation. + """Abort in-flight TCP I/O by shutting down sockets WITHOUT closing FDs. - When a provider drops a connection mid-stream, httpx's ``client.close()`` - performs a graceful shutdown which leaves sockets in CLOSE-WAIT until the - OS times them out (often minutes). This method walks the httpx transport - pool and issues ``socket.shutdown(SHUT_RDWR)`` + ``socket.close()`` to - force an immediate TCP RST, freeing the file descriptors. + When a provider drops a connection mid-stream โ€” or the user issues an + interrupt โ€” we want to unblock httpx's reader/writer immediately rather + than waiting for the kernel's per-connection timeout. ``shutdown(SHUT_RDWR)`` + achieves that: it sends FIN, breaks any pending ``recv``/``send`` with EOF + or ``EPIPE``, but does NOT release the file descriptor. - Returns the number of sockets force-closed. + Historically this helper also called ``socket.close()`` so the FD got + released immediately, but that's unsafe when (as is the case for both the + interrupt-abort path and stale-call kill path) the helper runs on a + different thread than the one driving the request: + + * The Python ``socket.socket`` we close here is the SAME object held by + httpx's pool, so closing it via Python sets its ``_fd`` to -1 and + future operations on that Python object fail safely. + * BUT the SSL wrapper (``ssl.SSLSocket``'s underlying OpenSSL ``BIO``) + caches the raw integer FD. Once ``os.close(fd)`` runs, the kernel may + immediately recycle that integer to the next ``open()`` call โ€” e.g. + the kanban dispatcher opening ``kanban.db``. + * The owning worker thread then unwinds httpx, the SSL layer flushes a + pending TLS record, and the encrypted bytes get written into the + wrong file (issue #29507: 24-byte TLS application-data record + clobbering SQLite header bytes 5..28). + + The fix is to let the owning thread own the close. ``shutdown()`` from any + thread is FD-safe; ``close()`` is not. The httpx connection's own close + path โ€” which runs from the worker thread when it unwinds โ€” will release + the FD via the same ``socket.socket`` object, and because Python's socket + close atomically swaps ``_fd`` to -1 *before* issuing ``os.close``, there + is no FD-aliasing window when only one thread closes. + + Returns the number of sockets shut down. (Field kept as + ``tcp_force_closed=N`` in the log line for backwards-compatible parsing.) """ import socket as _socket - closed = 0 + shutdown_count = 0 try: for sock in _iter_pool_sockets(client): try: sock.shutdown(_socket.SHUT_RDWR) except OSError: + # Already shut down / not connected / FD invalid โ€” all benign. pass - try: - sock.close() - except OSError: - pass - closed += 1 + # IMPORTANT (#29507): do NOT call sock.close() here. See docstring. + shutdown_count += 1 except Exception as exc: _ra().logger.debug("Force-close TCP sockets sweep error: %s", exc) - return closed + return shutdown_count diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index c94d664a434..8c06f3d517b 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -1606,182 +1606,155 @@ def _content_parts_to_anthropic_blocks(parts: Any) -> List[Dict[str, Any]]: return out -def convert_messages_to_anthropic( - messages: List[Dict], - base_url: str | None = None, - model: str | None = None, -) -> Tuple[Optional[Any], List[Dict]]: - """Convert OpenAI-format messages to Anthropic format. +def _convert_assistant_message(m: Dict[str, Any]) -> Dict[str, Any]: + """Convert an assistant message to Anthropic content blocks. - Returns (system_prompt, anthropic_messages). - System messages are extracted since Anthropic takes them as a separate param. - system_prompt is a string or list of content blocks (when cache_control present). - - When *base_url* is provided and points to a third-party Anthropic-compatible - endpoint, all thinking block signatures are stripped. Signatures are - Anthropic-proprietary โ€” third-party endpoints cannot validate them and will - reject them with HTTP 400 "Invalid signature in thinking block". - - When *model* is provided and matches the Kimi / Moonshot family (or - *base_url* is a Kimi / Moonshot host), unsigned thinking blocks - synthesised from ``reasoning_content`` are preserved on replayed - assistant tool-call messages โ€” Kimi requires the field to exist, even - if empty. + Handles thinking blocks, regular content, tool calls, and + reasoning_content injection for Kimi/DeepSeek endpoints. """ - system = None - result = [] - - for m in messages: - role = m.get("role", "user") - content = m.get("content", "") - - if role == "system": - if isinstance(content, list): - # Preserve cache_control markers on content blocks - has_cache = any( - p.get("cache_control") for p in content if isinstance(p, dict) - ) - if has_cache: - system = [p for p in content if isinstance(p, dict)] - else: - system = "\n".join( - p["text"] for p in content if p.get("type") == "text" - ) - else: - system = content - continue - - if role == "assistant": - blocks = _extract_preserved_thinking_blocks(m) - if content: - if isinstance(content, list): - converted_content = _convert_content_to_anthropic(content) - if isinstance(converted_content, list): - blocks.extend(converted_content) - else: - blocks.append({"type": "text", "text": str(content)}) - for tc in m.get("tool_calls", []): - if not tc or not isinstance(tc, dict): - continue - fn = tc.get("function", {}) - args = fn.get("arguments", "{}") - try: - parsed_args = json.loads(args) if isinstance(args, str) else args - except (json.JSONDecodeError, ValueError): - parsed_args = {} - blocks.append({ - "type": "tool_use", - "id": _sanitize_tool_id(tc.get("id", "")), - "name": fn.get("name", ""), - "input": parsed_args, - }) - # Kimi's /coding endpoint (Anthropic protocol) requires assistant - # tool-call messages to carry reasoning_content when thinking is - # enabled server-side. Preserve it as a thinking block so Kimi - # can validate the message history. See hermes-agent#13848. - # - # Accept empty string "" โ€” _copy_reasoning_content_for_api() - # injects "" as a tier-3 fallback for Kimi tool-call messages - # that had no reasoning. Kimi requires the field to exist, even - # if empty. - # - # Prepend (not append): Anthropic protocol requires thinking - # blocks before text and tool_use blocks. - # - # Guard: only add when reasoning_details didn't already contribute - # thinking blocks. On native Anthropic, reasoning_details produces - # signed thinking blocks โ€” adding another unsigned one from - # reasoning_content would create a duplicate (same text) that gets - # downgraded to a spurious text block on the last assistant message. - reasoning_content = m.get("reasoning_content") - _already_has_thinking = any( - isinstance(b, dict) and b.get("type") in {"thinking", "redacted_thinking"} - for b in blocks - ) - if isinstance(reasoning_content, str) and not _already_has_thinking: - blocks.insert(0, {"type": "thinking", "thinking": reasoning_content}) - # Anthropic rejects empty assistant content - effective = blocks or content - if not effective or effective == "": - effective = [{"type": "text", "text": "(empty)"}] - result.append({"role": "assistant", "content": effective}) - continue - - if role == "tool": - # Sanitize tool_use_id and ensure non-empty content. - # Computer-use (and other multimodal) tool results arrive as - # either a list of OpenAI-style content parts, or a dict - # marked `_multimodal` with an embedded `content` list. Convert - # both into Anthropic `tool_result` inner blocks (text + image). - multimodal_blocks: Optional[List[Dict[str, Any]]] = None - if isinstance(content, dict) and content.get("_multimodal"): - multimodal_blocks = _content_parts_to_anthropic_blocks( - content.get("content") or [] - ) - # Fallback text if the conversion produced nothing usable. - if not multimodal_blocks and content.get("text_summary"): - multimodal_blocks = [ - {"type": "text", "text": str(content["text_summary"])} - ] - elif isinstance(content, list): - converted = _content_parts_to_anthropic_blocks(content) - if any(b.get("type") == "image" for b in converted): - multimodal_blocks = converted - # Back-compat: some callers stash blocks under a private key. - if multimodal_blocks is None: - stashed = m.get("_anthropic_content_blocks") - if isinstance(stashed, list) and stashed: - text_content = content if isinstance(content, str) and content.strip() else None - multimodal_blocks = ( - [{"type": "text", "text": text_content}] + stashed - if text_content else list(stashed) - ) - - if multimodal_blocks: - result_content: Any = multimodal_blocks - elif isinstance(content, str): - result_content = content - else: - result_content = json.dumps(content) if content else "(no output)" - if not result_content: - result_content = "(no output)" - tool_result = { - "type": "tool_result", - "tool_use_id": _sanitize_tool_id(m.get("tool_call_id", "")), - "content": result_content, - } - if isinstance(m.get("cache_control"), dict): - tool_result["cache_control"] = dict(m["cache_control"]) - # Merge consecutive tool results into one user message - if ( - result - and result[-1]["role"] == "user" - and isinstance(result[-1]["content"], list) - and result[-1]["content"] - and result[-1]["content"][0].get("type") == "tool_result" - ): - result[-1]["content"].append(tool_result) - else: - result.append({"role": "user", "content": [tool_result]}) - continue - - # Regular user message โ€” validate non-empty content (Anthropic rejects empty) + content = m.get("content", "") + blocks = _extract_preserved_thinking_blocks(m) + if content: if isinstance(content, list): - converted_blocks = _convert_content_to_anthropic(content) - # Check if all text blocks are empty - if not converted_blocks or all( - b.get("text", "").strip() == "" - for b in converted_blocks - if isinstance(b, dict) and b.get("type") == "text" - ): - converted_blocks = [{"type": "text", "text": "(empty message)"}] - result.append({"role": "user", "content": converted_blocks}) + converted_content = _convert_content_to_anthropic(content) + if isinstance(converted_content, list): + blocks.extend(converted_content) else: - # Validate string content is non-empty - if not content or (isinstance(content, str) and not content.strip()): - content = "(empty message)" - result.append({"role": "user", "content": content}) + blocks.append({"type": "text", "text": str(content)}) + for tc in m.get("tool_calls", []): + if not tc or not isinstance(tc, dict): + continue + fn = tc.get("function", {}) + args = fn.get("arguments", "{}") + try: + parsed_args = json.loads(args) if isinstance(args, str) else args + except (json.JSONDecodeError, ValueError): + parsed_args = {} + blocks.append({ + "type": "tool_use", + "id": _sanitize_tool_id(tc.get("id", "")), + "name": fn.get("name", ""), + "input": parsed_args, + }) + # Kimi's /coding endpoint (Anthropic protocol) requires assistant + # tool-call messages to carry reasoning_content when thinking is + # enabled server-side. Preserve it as a thinking block so Kimi + # can validate the message history. See hermes-agent#13848. + # + # Accept empty string "" โ€” _copy_reasoning_content_for_api() + # injects "" as a tier-3 fallback for Kimi tool-call messages + # that had no reasoning. Kimi requires the field to exist, even + # if empty. + # + # Prepend (not append): Anthropic protocol requires thinking + # blocks before text and tool_use blocks. + # + # Guard: only add when reasoning_details didn't already contribute + # thinking blocks. On native Anthropic, reasoning_details produces + # signed thinking blocks โ€” adding another unsigned one from + # reasoning_content would create a duplicate (same text) that gets + # downgraded to a spurious text block on the last assistant message. + reasoning_content = m.get("reasoning_content") + _already_has_thinking = any( + isinstance(b, dict) and b.get("type") in {"thinking", "redacted_thinking"} + for b in blocks + ) + if isinstance(reasoning_content, str) and not _already_has_thinking: + blocks.insert(0, {"type": "thinking", "thinking": reasoning_content}) + # Anthropic rejects empty assistant content + effective = blocks or content + if not effective or effective == "": + effective = [{"type": "text", "text": "(empty)"}] + return {"role": "assistant", "content": effective} + +def _convert_tool_message_to_result( + result: List[Dict[str, Any]], m: Dict[str, Any] +) -> None: + """Convert a tool message to an Anthropic tool_result, merging consecutive + results into one user message. + + Mutates ``result`` in place โ€” either appends a new user message or extends + the trailing user message's tool_result list. + """ + content = m.get("content", "") + multimodal_blocks: Optional[List[Dict[str, Any]]] = None + if isinstance(content, dict) and content.get("_multimodal"): + multimodal_blocks = _content_parts_to_anthropic_blocks( + content.get("content") or [] + ) + # Fallback text if the conversion produced nothing usable. + if not multimodal_blocks and content.get("text_summary"): + multimodal_blocks = [ + {"type": "text", "text": str(content["text_summary"])} + ] + elif isinstance(content, list): + converted = _content_parts_to_anthropic_blocks(content) + if any(b.get("type") == "image" for b in converted): + multimodal_blocks = converted + # Back-compat: some callers stash blocks under a private key. + if multimodal_blocks is None: + stashed = m.get("_anthropic_content_blocks") + if isinstance(stashed, list) and stashed: + text_content = content if isinstance(content, str) and content.strip() else None + multimodal_blocks = ( + [{"type": "text", "text": text_content}] + stashed + if text_content else list(stashed) + ) + + if multimodal_blocks: + result_content: Any = multimodal_blocks + elif isinstance(content, str): + result_content = content + else: + result_content = json.dumps(content) if content else "(no output)" + if not result_content: + result_content = "(no output)" + tool_result = { + "type": "tool_result", + "tool_use_id": _sanitize_tool_id(m.get("tool_call_id", "")), + "content": result_content, + } + if isinstance(m.get("cache_control"), dict): + tool_result["cache_control"] = dict(m["cache_control"]) + # Merge consecutive tool results into one user message + if ( + result + and result[-1]["role"] == "user" + and isinstance(result[-1]["content"], list) + and result[-1]["content"] + and result[-1]["content"][0].get("type") == "tool_result" + ): + result[-1]["content"].append(tool_result) + else: + result.append({"role": "user", "content": [tool_result]}) + + +def _convert_user_message(content: Any) -> Dict[str, Any]: + """Validate and convert a user message to anthropic format.""" + if isinstance(content, list): + converted_blocks = _convert_content_to_anthropic(content) + if not converted_blocks or all( + b.get("text", "").strip() == "" + for b in converted_blocks + if isinstance(b, dict) and b.get("type") == "text" + ): + converted_blocks = [{"type": "text", "text": "(empty message)"}] + return {"role": "user", "content": converted_blocks} + else: + if not content or (isinstance(content, str) and not content.strip()): + content = "(empty message)" + return {"role": "user", "content": content} + + +def _strip_orphaned_tool_blocks(result: List[Dict[str, Any]]) -> None: + """Strip tool_use blocks with no matching tool_result, and vice versa. + + Context compression or session truncation can remove either side of a + tool-call pair. Anthropic rejects both orphans with HTTP 400. + + Mutates ``result`` in place. + """ # Strip orphaned tool_use blocks (no matching tool_result follows) tool_result_ids = set() for m in result: @@ -1799,10 +1772,7 @@ def convert_messages_to_anthropic( if not m["content"]: m["content"] = [{"type": "text", "text": "(tool call removed)"}] - # Strip orphaned tool_result blocks (no matching tool_use precedes them). - # This is the mirror of the above: context compression or session truncation - # can remove an assistant message containing a tool_use while leaving the - # subsequent tool_result intact. Anthropic rejects these with a 400. + # Strip orphaned tool_result blocks (no matching tool_use precedes them) tool_use_ids = set() for m in result: if m["role"] == "assistant" and isinstance(m["content"], list): @@ -1819,12 +1789,16 @@ def convert_messages_to_anthropic( if not m["content"]: m["content"] = [{"type": "text", "text": "(tool result removed)"}] - # Enforce strict role alternation (Anthropic rejects consecutive same-role messages) + +def _merge_consecutive_roles(result: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Merge consecutive same-role messages to enforce Anthropic alternation. + + Returns a new list (caller must rebind ``result``). + """ fixed = [] for m in result: if fixed and fixed[-1]["role"] == m["role"]: if m["role"] == "user": - # Merge consecutive user messages prev_content = fixed[-1]["content"] curr_content = m["content"] if isinstance(prev_content, str) and isinstance(curr_content, str): @@ -1832,7 +1806,6 @@ def convert_messages_to_anthropic( elif isinstance(prev_content, list) and isinstance(curr_content, list): fixed[-1]["content"] = prev_content + curr_content else: - # Mixed types โ€” wrap string in list if isinstance(prev_content, str): prev_content = [{"type": "text", "text": prev_content}] if isinstance(curr_content, str): @@ -1855,7 +1828,6 @@ def convert_messages_to_anthropic( elif isinstance(prev_blocks, str) and isinstance(curr_blocks, str): fixed[-1]["content"] = prev_blocks + "\n" + curr_blocks else: - # Mixed types โ€” normalize both to list and merge if isinstance(prev_blocks, str): prev_blocks = [{"type": "text", "text": prev_blocks}] if isinstance(curr_blocks, str): @@ -1863,37 +1835,34 @@ def convert_messages_to_anthropic( fixed[-1]["content"] = prev_blocks + curr_blocks else: fixed.append(m) - result = fixed + return fixed - # โ”€โ”€ Thinking block signature management โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - # Anthropic signs thinking blocks against the full turn content. - # Any upstream mutation (context compression, session truncation, - # orphan stripping, message merging) invalidates the signature, - # causing HTTP 400 "Invalid signature in thinking block". - # - # Signatures are Anthropic-proprietary. Third-party endpoints - # (MiniMax, Microsoft Foundry, self-hosted proxies) cannot validate - # them and will reject them outright. When targeting a third-party - # endpoint, strip ALL thinking/redacted_thinking blocks from every - # assistant message โ€” the third-party will generate its own - # thinking blocks if it supports extended thinking. - # - # For direct Anthropic (strategy following clawdbot/OpenClaw): - # 1. Strip thinking/redacted_thinking from all assistant messages - # EXCEPT the last one โ€” preserves reasoning continuity on the - # current tool-use chain while avoiding stale signature errors. - # 2. Downgrade unsigned thinking blocks (no signature) to text โ€” - # Anthropic can't validate them and will reject them. - # 3. Strip cache_control from thinking/redacted_thinking blocks โ€” - # cache markers can interfere with signature validation. + +def _manage_thinking_signatures( + result: List[Dict[str, Any]], base_url: str | None, model: str | None +) -> None: + """Strip or preserve thinking blocks based on endpoint type. + + Anthropic signs thinking blocks against the full turn content. + Any upstream mutation (context compression, session truncation, orphan + stripping, message merging) invalidates the signature, causing HTTP 400 + "Invalid signature in thinking block". + + Signatures are Anthropic-proprietary. Third-party endpoints (MiniMax, + Azure AI Foundry, AWS Bedrock, self-hosted proxies) cannot validate them + and will reject them outright. Kimi's /coding and DeepSeek's /anthropic + endpoints speak the Anthropic protocol upstream but require unsigned + thinking blocks (synthesised from ``reasoning_content``) to round-trip on + replayed assistant tool-call messages. See hermes-agent#13848 (Kimi) and + hermes-agent#16748 (DeepSeek). + + Mutates ``result`` in place. + """ _THINKING_TYPES = frozenset(("thinking", "redacted_thinking")) _is_third_party = _is_third_party_anthropic_endpoint(base_url) - # Kimi /coding and DeepSeek /anthropic share a contract: both speak the - # Anthropic Messages protocol upstream but require that thinking blocks - # synthesised from reasoning_content round-trip on subsequent turns when - # thinking is enabled. Signed Anthropic blocks still have to be stripped - # (neither endpoint can validate Anthropic's signatures); unsigned blocks - # are preserved. See hermes-agent#13848 (Kimi) and #16748 (DeepSeek). + # Kimi / DeepSeek share a contract: strip signed Anthropic blocks + # (neither upstream can validate Anthropic signatures), preserve unsigned + # ones synthesised from reasoning_content. See #13848, #16748. _preserve_unsigned_thinking = ( _is_kimi_family_endpoint(base_url, model) or _is_deepseek_anthropic_endpoint(base_url) @@ -1910,26 +1879,19 @@ def convert_messages_to_anthropic( continue if _preserve_unsigned_thinking: - # Kimi's /coding and DeepSeek's /anthropic endpoints both enable - # thinking server-side and require unsigned thinking blocks on - # replayed assistant tool-call messages. Strip signed Anthropic - # blocks (neither upstream can validate Anthropic signatures) but - # preserve the unsigned ones we synthesised from reasoning_content. + # Kimi / DeepSeek: strip signed, preserve unsigned. new_content = [] for b in m["content"]: if not isinstance(b, dict) or b.get("type") not in _THINKING_TYPES: new_content.append(b) continue if b.get("signature") or b.get("data"): - # Anthropic-signed block โ€” upstream can't validate, strip + # Signed (or redacted-with-data) โ€” upstream can't validate, strip. continue - # Unsigned thinking (synthesised from reasoning_content) โ€” - # keep it: the upstream needs it for message-history validation. new_content.append(b) m["content"] = new_content or [{"type": "text", "text": "(empty)"}] elif _is_third_party or idx != last_assistant_idx: - # Third-party endpoint: strip ALL thinking blocks from every - # assistant message โ€” signatures are Anthropic-proprietary. + # Third-party: strip ALL thinking blocks (signatures are proprietary). # Direct Anthropic: strip from non-latest assistant messages only. stripped = [ b for b in m["content"] @@ -1937,24 +1899,21 @@ def convert_messages_to_anthropic( ] m["content"] = stripped or [{"type": "text", "text": "(thinking elided)"}] else: - # Latest assistant on direct Anthropic: keep signed thinking - # blocks for reasoning continuity; downgrade unsigned ones to - # plain text. + # Latest assistant on direct Anthropic: keep signed, downgrade unsigned + # to text so the reasoning isn't lost. new_content = [] for b in m["content"]: if not isinstance(b, dict) or b.get("type") not in _THINKING_TYPES: new_content.append(b) continue if b.get("type") == "redacted_thinking": - # Redacted blocks use 'data' for the signature payload + # Redacted blocks use 'data' for the signature payload โ€” + # drop the block when 'data' is missing (can't be validated). if b.get("data"): new_content.append(b) - # else: drop โ€” no data means it can't be validated elif b.get("signature"): - # Signed thinking block โ€” keep it new_content.append(b) else: - # Unsigned thinking โ€” downgrade to text so it's not lost thinking_text = b.get("thinking", "") if thinking_text: new_content.append({"type": "text", "text": thinking_text}) @@ -1966,12 +1925,15 @@ def convert_messages_to_anthropic( if isinstance(b, dict) and b.get("type") in _THINKING_TYPES: b.pop("cache_control", None) - # โ”€โ”€ Image eviction: keep only the most recent N screenshots โ”€โ”€โ”€โ”€โ”€ - # computer_use screenshots (base64 images) sit inside tool_result - # blocks: they accumulate and are sent with every API call. Each - # costs ~1,465 tokens; after 10+ the conversation becomes slow - # even for simple text queries. Walk backward, keep the most recent - # _MAX_KEEP_IMAGES, replace older ones with a text placeholder. + +def _evict_old_screenshots(result: List[Dict[str, Any]]) -> None: + """Keep only the most recent ``_MAX_KEEP_IMAGES`` computer-use screenshots. + + Base64 images cost ~1,465 tokens each and accumulate across tool calls. + Walk backward, keep the most recent N, replace older ones with a placeholder. + + Mutates ``result`` in place. + """ _MAX_KEEP_IMAGES = 3 _image_count = 0 for msg in reversed(result): @@ -1998,6 +1960,68 @@ def convert_messages_to_anthropic( for b in inner ] + +def convert_messages_to_anthropic( + messages: List[Dict], + base_url: str | None = None, + model: str | None = None, +) -> Tuple[Optional[Any], List[Dict]]: + """Convert OpenAI-format messages to Anthropic format. + + Returns (system_prompt, anthropic_messages). + System messages are extracted since Anthropic takes them as a separate param. + system_prompt is a string or list of content blocks (when cache_control present). + + When *base_url* is provided and points to a third-party Anthropic-compatible + endpoint, all thinking block signatures are stripped. Signatures are + Anthropic-proprietary โ€” third-party endpoints cannot validate them and will + reject them with HTTP 400 "Invalid signature in thinking block". + + When *model* is provided and matches the Kimi / Moonshot family (or + *base_url* is a Kimi / Moonshot host), unsigned thinking blocks + synthesised from ``reasoning_content`` are preserved on replayed + assistant tool-call messages โ€” Kimi requires the field to exist, even + if empty. + """ + system = None + result: List[Dict[str, Any]] = [] + + for m in messages: + role = m.get("role", "user") + content = m.get("content", "") + + if role == "system": + if isinstance(content, list): + # Preserve cache_control markers on content blocks + has_cache = any( + p.get("cache_control") for p in content if isinstance(p, dict) + ) + if has_cache: + system = [p for p in content if isinstance(p, dict)] + else: + system = "\n".join( + p["text"] for p in content if p.get("type") == "text" + ) + else: + system = content + continue + + if role == "assistant": + result.append(_convert_assistant_message(m)) + continue + + if role == "tool": + _convert_tool_message_to_result(result, m) + continue + + # Regular user message + result.append(_convert_user_message(content)) + + _strip_orphaned_tool_blocks(result) + result = _merge_consecutive_roles(result) + _manage_thinking_signatures(result, base_url, model) + _evict_old_screenshots(result) + return system, result @@ -2098,9 +2122,13 @@ def build_anthropic_kwargs( block["text"] = text # 3. Prefix tool names with mcp_ (Claude Code convention) + # Skip names that already begin with the marker โ€” native MCP server + # tools (from mcp_servers: in config.yaml) are registered under their + # full mcp__ name and would double-prefix otherwise, + # breaking round-trip registry lookup in normalize_response. GH-25255. if anthropic_tools: for tool in anthropic_tools: - if "name" in tool: + if "name" in tool and not tool["name"].startswith(_MCP_TOOL_PREFIX): tool["name"] = _MCP_TOOL_PREFIX + tool["name"] # 4. Prefix tool names in message history (tool_use and tool_result blocks) diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 89dc7d935b4..37880190426 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -3730,6 +3730,37 @@ _VISION_AUTO_PROVIDER_ORDER = ( ) +def _main_model_supports_vision(provider: str, model: Optional[str]) -> bool: + """Return True when ``provider``/``model`` is known to accept image input. + + Used by the vision auto-detect chain to skip the user's main provider + when it's known to be text-only (e.g. DeepSeek, gpt-oss without vision). + Without this guard, ``resolve_vision_provider_client(provider="auto")`` + would happily return the main-provider client and any subsequent image + payload would surface as a cryptic provider-side error + (``unknown variant `image_url`, expected `text```, #31179). + + Returns True when capability lookup is unknown โ€” preserves the historical + behaviour of attempting the call, so providers we haven't catalogued yet + don't silently regress to text-only. + """ + try: + from agent.image_routing import _lookup_supports_vision + from hermes_cli.config import load_config + except ImportError: + return True + try: + supports = _lookup_supports_vision(provider, model, load_config()) + except Exception: # pragma: no cover - defensive + return True + if supports is None: + # No capability data โ€” keep current behaviour and let the call attempt + # happen rather than silently skipping. This avoids false-positive + # skips for new/custom providers. + return True + return bool(supports) + + def _normalize_vision_provider(provider: Optional[str]) -> str: return _normalize_aux_provider(provider) @@ -3870,6 +3901,23 @@ def resolve_vision_provider_client( "vision support) โ€” falling through to aggregator chain", main_provider, ) + elif not _main_model_supports_vision(main_provider, vision_model): + # The main model is known to be text-only (e.g. DeepSeek V4, + # gpt-oss-120b without vision). Building a client and sending + # an image would produce a cryptic provider-side error like + # ``unknown variant `image_url`, expected `text``` (#31179). + # Fall through to the aggregator chain instead. + # + # Only log the provider name (not the model) โ€” mirrors the + # sibling _PROVIDERS_WITHOUT_VISION branch above, and avoids + # CodeQL py/clear-text-logging-sensitive-data heuristic false + # positives on multi-value interpolations. + logger.debug( + "Vision auto-detect: skipping main provider %s " + "(reports no vision capability) โ€” falling through to " + "aggregator chain", + main_provider, + ) else: rpc_client, rpc_model = resolve_provider_client( main_provider, vision_model, @@ -4281,6 +4329,23 @@ def _get_cached_client( return client, model or default_model +# Aliases that target direct REST APIs not modeled as first-class providers +# in PROVIDER_REGISTRY. Used for ``auxiliary..provider`` so users can +# write the obvious name and have it resolve to a working ``custom`` endpoint +# without needing to know our internal provider IDs. +# +# Why these specifically: PROVIDER_REGISTRY has ``openai-codex`` (OAuth) and +# ``custom`` (manual base_url + OPENAI_API_KEY) but no plain ``openai`` for +# direct API-key access. Users predictably type ``provider: openai`` and +# expect it to use OPENAI_API_KEY against api.openai.com. Previously this +# silently fell back to the user's main provider, sending OpenAI model names +# to e.g. DeepSeek and producing cryptic ``unknown variant 'image_url'`` +# errors (issue #31179). +_AUX_DIRECT_API_BASE_URLS: Dict[str, str] = { + "openai": "https://api.openai.com/v1", +} + + def _resolve_task_provider_model( task: str = None, provider: str = None, @@ -4317,6 +4382,25 @@ def _resolve_task_provider_model( resolved_model = model or cfg_model resolved_api_mode = cfg_api_mode + # Convenience aliases for direct API-key endpoints that aren't first-class + # providers (e.g. ``provider: openai`` โ†’ custom + api.openai.com/v1). + # Applied to both explicit args and config-derived values. When the user + # has already supplied a base_url we keep their endpoint but still rewrite + # the provider to ``custom`` so resolution doesn't hit the + # PROVIDER_REGISTRY-only path (which has no ``openai`` entry). + def _expand_direct_api_alias(prov: Optional[str], existing_base: Optional[str]) -> Tuple[Optional[str], Optional[str]]: + if not prov: + return prov, existing_base + target_base = _AUX_DIRECT_API_BASE_URLS.get(prov.strip().lower()) + if target_base is None: + return prov, existing_base + return "custom", existing_base or target_base + + if provider: + provider, base_url = _expand_direct_api_alias(provider, base_url) + if cfg_provider: + cfg_provider, cfg_base_url = _expand_direct_api_alias(cfg_provider, cfg_base_url) + if base_url: return "custom", resolved_model, base_url, api_key, resolved_api_mode if provider: @@ -4344,7 +4428,17 @@ _DEFAULT_AUX_TIMEOUT = 30.0 def _get_auxiliary_task_config(task: str) -> Dict[str, Any]: - """Return the config dict for auxiliary., or {} when unavailable.""" + """Return the config dict for auxiliary., or {} when unavailable. + + For plugin-registered auxiliary tasks (see + :meth:`hermes_cli.plugins.PluginContext.register_auxiliary_task`) the + plugin's declared *defaults* are layered underneath the user's config + so an unconfigured plugin task still works: + + plugin defaults โ† config.yaml auxiliary. (user wins) + + Built-in tasks ignore this path (their defaults live in DEFAULT_CONFIG). + """ if not task: return {} try: @@ -4354,7 +4448,27 @@ def _get_auxiliary_task_config(task: str) -> Dict[str, Any]: return {} aux = config.get("auxiliary", {}) if isinstance(config, dict) else {} task_config = aux.get(task, {}) if isinstance(aux, dict) else {} - return task_config if isinstance(task_config, dict) else {} + if not isinstance(task_config, dict): + task_config = {} + + # Layer plugin-declared defaults underneath user config so + # ctx.register_auxiliary_task(defaults={...}) takes effect without + # forcing the user to write config.yaml entries. + try: + from hermes_cli.plugins import get_plugin_auxiliary_tasks + for _entry in get_plugin_auxiliary_tasks(): + if _entry.get("key") == task: + _defaults = _entry.get("defaults") or {} + if isinstance(_defaults, dict): + merged = dict(_defaults) + merged.update(task_config) + return merged + break + except Exception: + # Plugin discovery failure must not break aux task config reads. + pass + + return task_config def _get_task_timeout(task: str, default: float = _DEFAULT_AUX_TIMEOUT) -> float: diff --git a/agent/background_review.py b/agent/background_review.py index ba65b2b1bc8..35d3d5191a0 100644 --- a/agent/background_review.py +++ b/agent/background_review.py @@ -115,7 +115,10 @@ _SKILL_REVIEW_PROMPT = ( "Protected skills (DO NOT edit these):\n" " โ€ข Bundled skills (shipped with Hermes, e.g. 'hermes-agent').\n" " โ€ข Hub-installed skills (installed via 'hermes skills install').\n" - " โ€ข Pinned skills (marked via 'hermes curator pin').\n" + "Pinned skills (marked via 'hermes curator pin') CAN be improved โ€” " + "pin only blocks deletion/archive/consolidation by the curator, not " + "content updates. Patch them when a pitfall or missing step turns up, " + "same as any other agent-created skill.\n" "If the only skills that need updating are protected, say\n" "'Nothing to save.' and stop.\n\n" "Do NOT capture (these become persistent self-imposed constraints " @@ -198,7 +201,10 @@ _COMBINED_REVIEW_PROMPT = ( "Protected skills (DO NOT edit these):\n" " โ€ข Bundled skills (shipped with Hermes, e.g. 'hermes-agent').\n" " โ€ข Hub-installed skills (installed via 'hermes skills install').\n" - " โ€ข Pinned skills (marked via 'hermes curator pin').\n" + "Pinned skills (marked via 'hermes curator pin') CAN be improved โ€” " + "pin only blocks deletion/archive/consolidation by the curator, not " + "content updates. Patch them when a pitfall or missing step turns up, " + "same as any other agent-created skill.\n" "If the only skills that need updating are protected, say\n" "'Nothing to save.' and stop.\n\n" "Do NOT capture as skills (these become persistent self-imposed " diff --git a/agent/chat_completion_helpers.py b/agent/chat_completion_helpers.py index c68f2271f5b..b3261b60d0b 100644 --- a/agent/chat_completion_helpers.py +++ b/agent/chat_completion_helpers.py @@ -91,23 +91,55 @@ def interruptible_api_call(agent, api_kwargs: dict): provider fallback. """ result = {"response": None, "error": None} - request_client_holder = {"client": None} + request_client_holder = {"client": None, "owner_tid": None} request_client_lock = threading.Lock() def _set_request_client(client): with request_client_lock: request_client_holder["client"] = client + # #29507: stamp the owning thread so a stranger-thread interrupt + # only shuts the connection down rather than racing the worker + # for FD ownership during ``client.close()``. + request_client_holder["owner_tid"] = threading.get_ident() return client def _take_request_client(): with request_client_lock: client = request_client_holder.get("client") request_client_holder["client"] = None + request_client_holder["owner_tid"] = None return client def _close_request_client_once(reason: str) -> None: - request_client = _take_request_client() - if request_client is not None: + # #29507: dispatch on the calling thread. + # + # When ``_call`` (the worker) reaches its ``finally`` it owns the + # close and we pop + fully close as before. When a *stranger* thread + # (the interrupt-check loop, the stale-call detector) drives the + # close, only shut the sockets down so the worker's blocked + # ``recv``/``send`` unwinds with an ``EPIPE`` / EOF โ€” and let the + # worker close ``client`` from its own thread on its way out. That + # avoids the FD-recycling race where the kernel reassigned a + # just-closed TLS socket FD to ``kanban.db``, and the still-live SSL + # BIO on the worker thread then wrote a 24-byte TLS application-data + # record into the SQLite header (#29507). + with request_client_lock: + request_client = request_client_holder.get("client") + owner_tid = request_client_holder.get("owner_tid") + stranger_thread = ( + request_client is not None + and owner_tid is not None + and owner_tid != threading.get_ident() + ) + if not stranger_thread: + # Owning thread (or no recorded owner) โ†’ pop and fully close. + request_client_holder["client"] = None + request_client_holder["owner_tid"] = None + if request_client is None: + return + if stranger_thread: + agent._abort_request_openai_client(request_client, reason=reason) + else: agent._close_request_openai_client(request_client, reason=reason) def _call(): @@ -725,7 +757,7 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool current_base_url = str(getattr(agent, "base_url", "") or "").rstrip("/").lower() fb_base_url_for_dedup = (fb.get("base_url") or "").strip().rstrip("/").lower() if fb_provider == current_provider and fb_model == current_model: - logging.warning( + logger.warning( "Fallback skip: chain entry %s/%s matches current provider/model", fb_provider, fb_model, ) @@ -736,7 +768,7 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool and fb_base_url_for_dedup == current_base_url and fb_model == current_model ): - logging.warning( + logger.warning( "Fallback skip: chain entry base_url %s matches current backend", fb_base_url_for_dedup, ) @@ -768,7 +800,7 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool explicit_base_url=fb_base_url_hint, explicit_api_key=fb_api_key_hint) if fb_client is None: - logging.warning( + logger.warning( "Fallback to %s failed: provider not configured", fb_provider) return agent._try_activate_fallback() # try next in chain @@ -776,8 +808,11 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool from hermes_cli.model_normalize import normalize_model_for_provider fb_model = normalize_model_for_provider(fb_model, fb_provider) - except Exception: - pass + except Exception as _norm_err: + logger.warning( + "Could not normalize fallback model %r for provider %r: %s", + fb_model, fb_provider, _norm_err, + ) # Determine api_mode from provider / base URL / model fb_api_mode = "chat_completions" @@ -905,19 +940,20 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool base_url=agent.base_url, api_key=getattr(agent, "api_key", ""), # callable preserved โ†’ call_llm provider=agent.provider, + api_mode=agent.api_mode, ) agent._emit_status( f"๐Ÿ”„ Primary model failed โ€” switching to fallback: " f"{fb_model} via {fb_provider}" ) - logging.info( + logger.info( "Fallback activated: %s โ†’ %s (%s)", old_model, fb_model, fb_provider, ) return True except Exception as e: - logging.error("Failed to activate fallback %s: %s", fb_model, e) + logger.error("Failed to activate fallback %s: %s", fb_model, e) return agent._try_activate_fallback() # try next in chain @@ -1133,7 +1169,7 @@ def handle_max_iterations(agent, messages: list, api_call_count: int) -> str: final_response = "I reached the iteration limit and couldn't generate a summary." except Exception as e: - logging.warning(f"Failed to get summary response: {e}") + logger.warning(f"Failed to get summary response: {e}") final_response = f"I reached the maximum iterations ({agent.max_iterations}) but couldn't summarize. Error: {str(e)}" return final_response @@ -1162,12 +1198,12 @@ def cleanup_task_resources(agent, task_id: str) -> None: _ra().cleanup_vm(task_id) except Exception as e: if agent.verbose_logging: - logging.warning(f"Failed to cleanup VM for task {task_id}: {e}") + logger.warning(f"Failed to cleanup VM for task {task_id}: {e}") try: _ra().cleanup_browser(task_id) except Exception as e: if agent.verbose_logging: - logging.warning(f"Failed to cleanup browser for task {task_id}: {e}") + logger.warning(f"Failed to cleanup browser for task {task_id}: {e}") @@ -1271,23 +1307,44 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= return result["response"] result = {"response": None, "error": None, "partial_tool_names": []} - request_client_holder = {"client": None, "diag": None} + request_client_holder = {"client": None, "diag": None, "owner_tid": None} request_client_lock = threading.Lock() def _set_request_client(client): with request_client_lock: request_client_holder["client"] = client + # See #29507 explanation in the non-streaming variant above. + request_client_holder["owner_tid"] = threading.get_ident() return client def _take_request_client(): with request_client_lock: client = request_client_holder.get("client") request_client_holder["client"] = None + request_client_holder["owner_tid"] = None return client def _close_request_client_once(reason: str) -> None: - request_client = _take_request_client() - if request_client is not None: + # See #29507 explanation in the non-streaming variant above. A + # stranger thread (the interrupt-check / stale-stream detector loop) + # only aborts sockets โ€” never pops, never calls ``client.close()`` โ€” + # so the worker thread retains ownership of the FD release. + with request_client_lock: + request_client = request_client_holder.get("client") + owner_tid = request_client_holder.get("owner_tid") + stranger_thread = ( + request_client is not None + and owner_tid is not None + and owner_tid != threading.get_ident() + ) + if not stranger_thread: + request_client_holder["client"] = None + request_client_holder["owner_tid"] = None + if request_client is None: + return + if stranger_thread: + agent._abort_request_openai_client(request_client, reason=reason) + else: agent._close_request_openai_client(request_client, reason=reason) first_delta_fired = {"done": False} @@ -2020,8 +2077,21 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= # Streaming failed AFTER some tokens were already delivered to # the platform. Re-raising would let the outer retry loop make # a new API call, creating a duplicate message. Return a - # partial "stop" response instead so the outer loop treats this - # turn as complete (no retry, no fallback). + # partial response stub instead and let the outer loop decide: + # + # - text-only partials โ†’ finish_reason="length" so the + # conversation loop persists the partial assistant content + # and asks the model to continue from where the stream + # died (issue #30963: partial stop misclassified as a + # clean completion was exiting the loop with budget + # remaining and an unfinished goal). + # + # - partial mid-tool-call โ†’ finish_reason="stop" stays. + # The user-visible warning we append says "Ask me to + # retry if you want to continue", so the agent should + # hand control back rather than auto-retry a tool call + # that may have side-effects. + # # Recover whatever content was already streamed to the user. # _current_streamed_assistant_text accumulates text fired # through _fire_stream_delta, so it has exactly what the @@ -2059,14 +2129,17 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= "of text; surfaced warning to user: %s", _partial_names, len(_partial_text or ""), result["error"], ) + _stub_finish_reason = "stop" else: logger.warning( - "Partial stream delivered before error; returning stub " - "response with %s chars of recovered content to prevent " - "duplicate messages: %s", + "Partial stream delivered before error; returning " + "length-truncated stub with %s chars of recovered " + "content so the loop can continue from where the " + "stream died: %s", len(_partial_text or ""), result["error"], ) + _stub_finish_reason = "length" _stub_msg = SimpleNamespace( role="assistant", content=_partial_text, tool_calls=None, reasoning_content=None, @@ -2075,7 +2148,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= id="partial-stream-stub", model=getattr(agent, "model", "unknown"), choices=[SimpleNamespace( - index=0, message=_stub_msg, finish_reason="stop", + index=0, message=_stub_msg, finish_reason=_stub_finish_reason, )], usage=None, ) diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 62636809094..49907e2c331 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -609,6 +609,7 @@ class ContextCompressor(ContextEngine): """Update tracked token usage from API response.""" self.last_prompt_tokens = usage.get("prompt_tokens", 0) self.last_completion_tokens = usage.get("completion_tokens", 0) + self.last_total_tokens = usage.get("total_tokens", self.last_prompt_tokens + self.last_completion_tokens) def should_compress(self, prompt_tokens: int = None) -> bool: """Check if context exceeds the compression threshold. @@ -897,7 +898,7 @@ class ContextCompressor(ContextEngine): into the warning log. """ self._summary_model_fallen_back = True - logging.warning( + logger.warning( "Summary model '%s' %s (%s). " "Falling back to main model '%s' for compression.", self.summary_model, reason, e, self.model, @@ -1086,7 +1087,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio # No provider configured โ€” long cooldown, unlikely to self-resolve self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS self._last_summary_error = "no auxiliary LLM provider configured" - logging.warning("Context compression: no provider available for " + logger.warning("Context compression: no provider available for " "summary. Middle turns will be dropped without summary " "for %d seconds.", _SUMMARY_FAILURE_COOLDOWN_SECONDS) @@ -1182,7 +1183,7 @@ The user has requested that this compaction PRIORITISE preserving all informatio if len(err_text) > 220: err_text = err_text[:217].rstrip() + "..." self._last_summary_error = err_text - logging.warning( + logger.warning( "Failed to generate context summary: %s. " "Further summary attempts paused for %d seconds.", e, diff --git a/agent/context_engine.py b/agent/context_engine.py index 2947da54d8c..c30a7a84752 100644 --- a/agent/context_engine.py +++ b/agent/context_engine.py @@ -200,6 +200,7 @@ class ContextEngine(ABC): base_url: str = "", api_key: str = "", provider: str = "", + api_mode: str = "", ) -> None: """Called when the user switches models or on fallback activation. diff --git a/agent/conversation_compression.py b/agent/conversation_compression.py index cd1b133fa4a..a620f343e99 100644 --- a/agent/conversation_compression.py +++ b/agent/conversation_compression.py @@ -381,12 +381,12 @@ def compress_context( agent._session_db.end_session(agent.session_id, "compression") old_session_id = agent.session_id agent.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" - os.environ["HERMES_SESSION_ID"] = agent.session_id try: - from gateway.session_context import _SESSION_ID - _SESSION_ID.set(agent.session_id) + from gateway.session_context import set_current_session_id + + set_current_session_id(agent.session_id) except Exception: - pass + os.environ["HERMES_SESSION_ID"] = agent.session_id agent._session_db_created = False agent._session_db.create_session( session_id=agent.session_id, diff --git a/agent/conversation_loop.py b/agent/conversation_loop.py index caac0d3e8f2..90b45c814d3 100644 --- a/agent/conversation_loop.py +++ b/agent/conversation_loop.py @@ -46,6 +46,7 @@ from agent.message_sanitization import ( _strip_non_ascii, ) from agent.model_metadata import ( + MINIMUM_CONTEXT_LENGTH, estimate_messages_tokens_rough, estimate_request_tokens_rough, get_next_probe_tier, @@ -73,6 +74,50 @@ from utils import base_url_host_matches, env_var_enabled logger = logging.getLogger(__name__) +def _ollama_context_limit_error(agent: Any, request_tokens: int) -> Optional[str]: + """Return a user-facing error when Ollama is loaded with too little context.""" + if not getattr(agent, "tools", None): + return None + + runtime_ctx = getattr(agent, "_ollama_num_ctx", None) + if not isinstance(runtime_ctx, int) or runtime_ctx <= 0: + return None + if runtime_ctx >= MINIMUM_CONTEXT_LENGTH: + return None + + model = getattr(agent, "model", "") or "the selected model" + base_url = getattr(agent, "base_url", "") or "unknown base URL" + provider = getattr(agent, "provider", "") or "unknown" + tool_count = len(getattr(agent, "tools", None) or []) + + logger.warning( + "Ollama runtime context too small for Hermes tool use: " + "model=%s provider=%s base_url=%s runtime_context=%d " + "minimum_context=%d estimated_request_tokens=%d tool_count=%d " + "session=%s", + model, + provider, + base_url, + runtime_ctx, + MINIMUM_CONTEXT_LENGTH, + request_tokens, + tool_count, + getattr(agent, "session_id", None) or "none", + ) + + return ( + f"Ollama loaded `{model}` with only {runtime_ctx:,} tokens of runtime " + f"context, but Hermes needs at least {MINIMUM_CONTEXT_LENGTH:,} tokens " + "for reliable tool use.\n\n" + "Increase the Ollama context for this model and restart/reload the " + "model before trying again. A known-good starting point is 65,536 " + "tokens. In Hermes config, set `model.ollama_num_ctx: 65536` " + "(and `model.context_length: 65536` if you also override the displayed " + "model context). If you manage the model through an Ollama Modelfile, " + "set `PARAMETER num_ctx 65536` there instead." + ) + + def _ra(): """Lazy reference to ``run_agent`` so callers can patch ``run_agent.handle_function_call`` / ``run_agent._set_interrupt`` / @@ -527,6 +572,7 @@ def run_conversation( api_call_count = 0 final_response = None interrupted = False + failed = False codex_ack_continuations = 0 length_continue_retries = 0 truncated_tool_call_retries = 0 @@ -883,6 +929,26 @@ def run_conversation( # Calculate approximate request size for logging total_chars = sum(len(str(msg)) for msg in api_messages) approx_tokens = estimate_messages_tokens_rough(api_messages) + approx_request_tokens = estimate_request_tokens_rough( + api_messages, tools=agent.tools or None + ) + + _runtime_context_error = _ollama_context_limit_error( + agent, approx_request_tokens + ) + if _runtime_context_error: + final_response = _runtime_context_error + failed = True + _turn_exit_reason = "ollama_runtime_context_too_small" + messages.append({"role": "assistant", "content": final_response}) + agent._emit_status("โŒ Ollama runtime context is too small for Hermes tool use") + api_call_count -= 1 + agent._api_call_count = api_call_count + try: + agent.iteration_budget.refund() + except Exception: + pass + break # Thinking spinner for quiet mode (animated during API call) thinking_spinner = None @@ -923,6 +989,7 @@ def run_conversation( copilot_auth_retry_attempted=False thinking_sig_retry_attempted = False image_shrink_retry_attempted = False + multimodal_tool_content_retry_attempted = False oauth_1m_beta_retry_attempted = False llama_cpp_grammar_retry_attempted = False has_retried_429 = False @@ -1116,7 +1183,7 @@ def run_conversation( else str(_codex_error_obj) if _codex_error_obj else f"Responses API returned status '{_codex_resp_status}'" ) - logging.warning( + logger.warning( "Codex response status='%s' (error=%s). Routing to fallback. %s", _codex_resp_status, _codex_error_msg, agent._client_log_context(), @@ -1268,7 +1335,7 @@ def run_conversation( primary_recovery_attempted = False continue agent._emit_status(f"โŒ Max retries ({max_retries}) exceeded for invalid responses. Giving up.") - logging.error(f"{agent.log_prefix}Invalid API response after {max_retries} retries.") + logger.error(f"{agent.log_prefix}Invalid API response after {max_retries} retries.") agent._persist_session(messages, conversation_history) return { "messages": messages, @@ -1281,7 +1348,7 @@ def run_conversation( # Backoff before retry โ€” jittered exponential: 5s base, 120s cap wait_time = jittered_backoff(retry_count, base_delay=5.0, max_delay=120.0) agent._vprint(f"{agent.log_prefix}โณ Retrying in {wait_time:.1f}s ({_failure_hint})...", force=True) - logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}") + logger.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}") # Sleep in small increments to stay responsive to interrupts sleep_end = time.time() + wait_time @@ -1347,7 +1414,18 @@ def run_conversation( finish_reason = "length" if finish_reason == "length": - agent._vprint(f"{agent.log_prefix}โš ๏ธ Response truncated (finish_reason='length') - model hit max output tokens", force=True) + if getattr(response, "id", "") == "partial-stream-stub": + agent._vprint( + f"{agent.log_prefix}โš ๏ธ Stream interrupted by network error " + f"(finish_reason='length' on partial-stream-stub)", + force=True, + ) + else: + agent._vprint( + f"{agent.log_prefix}โš ๏ธ Response truncated " + f"(finish_reason='length') - model hit max output tokens", + force=True, + ) # Normalize the truncated response to a single OpenAI-style # message shape so text-continuation and tool-call retry @@ -1440,17 +1518,40 @@ def run_conversation( truncated_response_parts.append(assistant_message.content) if length_continue_retries < 3: - agent._vprint( - f"{agent.log_prefix}โ†ป Requesting continuation " - f"({length_continue_retries}/3)..." + # Distinguish a real output-token truncation + # from a partial-stream-stub network error + # (#30963). Same continuation machinery, + # but the prompt has to tell the truth or + # the model goes off rails ("I wasn't + # truncated, I'm done"). + _is_partial_stream_stub = ( + getattr(response, "id", "") == "partial-stream-stub" ) - continue_msg = { - "role": "user", - "content": ( + if _is_partial_stream_stub: + agent._vprint( + f"{agent.log_prefix}โ†ป Stream interrupted โ€” " + f"requesting continuation " + f"({length_continue_retries}/3)..." + ) + _continue_content = ( + "[System: The previous response was cut off by a " + "network error mid-stream. Continue exactly where " + "you left off. Do not restart or repeat prior text. " + "Finish the answer directly.]" + ) + else: + agent._vprint( + f"{agent.log_prefix}โ†ป Requesting continuation " + f"({length_continue_retries}/3)..." + ) + _continue_content = ( "[System: Your previous response was truncated by the output " "length limit. Continue exactly where you left off. Do not " "restart or repeat prior text. Finish the answer directly.]" - ), + ) + continue_msg = { + "role": "user", + "content": _continue_content, } messages.append(continue_msg) agent._session_messages = messages @@ -1994,6 +2095,31 @@ def run_conversation( "or shrink didn't reduce size; surfacing original error." ) + # Multimodal-tool-content recovery: providers that follow + # the OpenAI spec strictly (tool message content must be a + # string) reject our list-type content with a 400. Strip + # image parts from any list-type tool messages, mark the + # (provider, model) as no-list-tool-content for the rest + # of this session so future tool results preemptively + # downgrade, and retry once. See issue #27344. + if ( + classified.reason == FailoverReason.multimodal_tool_content_unsupported + and not multimodal_tool_content_retry_attempted + ): + multimodal_tool_content_retry_attempted = True + if agent._try_strip_image_parts_from_tool_messages(api_messages): + agent._vprint( + f"{agent.log_prefix}๐Ÿ“ Provider rejected list-type tool content โ€” " + f"downgraded screenshots to text and retrying...", + force=True, + ) + continue + else: + logger.info( + "multimodal-tool-content recovery: no list-type tool " + "messages with image parts found; surfacing original error." + ) + # Anthropic OAuth subscription rejected the 1M-context beta # header ("long context beta is not yet available for this # subscription"). Disable the beta for the rest of this @@ -2133,7 +2259,7 @@ def run_conversation( f"stripped all thinking blocks, retrying...", force=True, ) - logging.warning( + logger.warning( "%sThinking block signature recovery: stripped " "reasoning_details from %d messages", agent.log_prefix, len(messages), @@ -2158,7 +2284,7 @@ def run_conversation( from tools.schema_sanitizer import strip_pattern_and_format _, _stripped = strip_pattern_and_format(agent.tools) except Exception as _strip_exc: # pragma: no cover โ€” defensive - logging.warning( + logger.warning( "%sllama.cpp grammar recovery: strip helper failed: %s", agent.log_prefix, _strip_exc, ) @@ -2169,7 +2295,7 @@ def run_conversation( f"stripped {_stripped} pattern/format keyword(s), retrying...", force=True, ) - logging.warning( + logger.warning( "%sllama.cpp grammar recovery: stripped %d " "pattern/format keyword(s) from tool schemas", agent.log_prefix, _stripped, @@ -2177,7 +2303,7 @@ def run_conversation( continue # No keywords found to strip โ€” fall through to normal # retry path rather than loop forever on the same error. - logging.warning( + logger.warning( "%sllama.cpp grammar error but no pattern/format " "keywords to strip โ€” falling through to normal retry", agent.log_prefix, @@ -2278,6 +2404,7 @@ def run_conversation( base_url=agent.base_url, api_key=getattr(agent, "api_key", ""), provider=agent.provider, + api_mode=agent.api_mode, ) # Context probing flags โ€” only set on built-in # compressor (plugin engines manage their own). @@ -2391,7 +2518,7 @@ def run_conversation( error_context=error_context, ) else: - logging.info( + logger.info( "Nous 429 looks like upstream capacity " "(no exhausted bucket in headers or " "last-known state) -- not tripping " @@ -2451,7 +2578,7 @@ def run_conversation( if compression_attempts > max_compression_attempts: agent._vprint(f"{agent.log_prefix}โŒ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.", force=True) agent._vprint(f"{agent.log_prefix} ๐Ÿ’ก Try /new to start a fresh conversation, or /compress to retry compression.", force=True) - logging.error(f"{agent.log_prefix}413 compression failed after {max_compression_attempts} attempts.") + logger.error(f"{agent.log_prefix}413 compression failed after {max_compression_attempts} attempts.") agent._persist_session(messages, conversation_history) return { "messages": messages, @@ -2482,7 +2609,7 @@ def run_conversation( else: agent._vprint(f"{agent.log_prefix}โŒ Payload too large and cannot compress further.", force=True) agent._vprint(f"{agent.log_prefix} ๐Ÿ’ก Try /new to start a fresh conversation, or /compress to retry compression.", force=True) - logging.error(f"{agent.log_prefix}413 payload too large. Cannot compress further.") + logger.error(f"{agent.log_prefix}413 payload too large. Cannot compress further.") agent._persist_session(messages, conversation_history) return { "messages": messages, @@ -2535,7 +2662,7 @@ def run_conversation( if compression_attempts > max_compression_attempts: agent._vprint(f"{agent.log_prefix}โŒ Max compression attempts ({max_compression_attempts}) reached.", force=True) agent._vprint(f"{agent.log_prefix} ๐Ÿ’ก Try /new to start a fresh conversation, or /compress to retry compression.", force=True) - logging.error(f"{agent.log_prefix}Context compression failed after {max_compression_attempts} attempts.") + logger.error(f"{agent.log_prefix}Context compression failed after {max_compression_attempts} attempts.") agent._persist_session(messages, conversation_history) return { "messages": messages, @@ -2587,6 +2714,7 @@ def run_conversation( base_url=agent.base_url, api_key=getattr(agent, "api_key", ""), provider=agent.provider, + api_mode=agent.api_mode, ) # Context probing flags โ€” only set on built-in # compressor (plugin engines manage their own). @@ -2608,7 +2736,7 @@ def run_conversation( if compression_attempts > max_compression_attempts: agent._vprint(f"{agent.log_prefix}โŒ Max compression attempts ({max_compression_attempts}) reached.", force=True) agent._vprint(f"{agent.log_prefix} ๐Ÿ’ก Try /new to start a fresh conversation, or /compress to retry compression.", force=True) - logging.error(f"{agent.log_prefix}Context compression failed after {max_compression_attempts} attempts.") + logger.error(f"{agent.log_prefix}Context compression failed after {max_compression_attempts} attempts.") agent._persist_session(messages, conversation_history) return { "messages": messages, @@ -2641,7 +2769,7 @@ def run_conversation( # Can't compress further and already at minimum tier agent._vprint(f"{agent.log_prefix}โŒ Context length exceeded and cannot compress further.", force=True) agent._vprint(f"{agent.log_prefix} ๐Ÿ’ก The conversation has accumulated too much content. Try /new to start fresh, or /compress to manually trigger compression.", force=True) - logging.error(f"{agent.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.") + logger.error(f"{agent.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.") agent._persist_session(messages, conversation_history) return { "messages": messages, @@ -2678,6 +2806,21 @@ def run_conversation( # retryable=True mapping takes effect instead. and not isinstance(api_error, ssl.SSLError) ) + # ``FailoverReason.billing`` (HTTP 402) is NOT in this + # exclusion set. By the time we reach this block: + # โ€ข credential-pool rotation (line ~2031) has already + # fired for billing and either ``continue``d or + # returned (False, ...) โ€” pool is exhausted or absent. + # โ€ข the eager-fallback branch above (line ~2422) also + # fires on billing and ``continue``s if a fallback + # provider is configured. + # Falling through to here means BOTH recovery paths + # gave up. Treating 402 as retryable from this point + # just burns more paid requests against a depleted + # balance with no recovery mechanism left โ€” see #31273 + # (real-world: ~$40 in 48h on a 24/7 gateway). Aborting + # mirrors how 401/403 (also ``should_fallback=True``) + # already behave once their recovery paths have failed. is_client_error = ( is_local_validation_error or ( @@ -2685,7 +2828,6 @@ def run_conversation( and not classified.should_compress and classified.reason not in { FailoverReason.rate_limit, - FailoverReason.billing, FailoverReason.overloaded, FailoverReason.context_overflow, FailoverReason.payload_too_large, @@ -2734,7 +2876,7 @@ def run_conversation( agent._vprint(f"{agent.log_prefix} โ€ข Check credits: https://openrouter.ai/settings/credits", force=True) else: agent._vprint(f"{agent.log_prefix} ๐Ÿ’ก This type of error won't be fixed by retrying.", force=True) - logging.error(f"{agent.log_prefix}Non-retryable client error: {api_error}") + logger.error(f"{agent.log_prefix}Non-retryable client error: {api_error}") # Skip session persistence when the error is likely # context-overflow related (status 400 + large session). # Persisting the failed user message would make the @@ -2811,7 +2953,7 @@ def run_conversation( force=True, ) - logging.error( + logger.error( "%sAPI call failed after %s retries. %s | provider=%s model=%s msgs=%s tokens=~%s", agent.log_prefix, max_retries, _final_summary, _provider, _model, len(api_messages), f"{approx_tokens:,}", @@ -3342,6 +3484,19 @@ def run_conversation( f"โš ๏ธ Tool guardrail halted {decision.tool_name}: {decision.code}" ) messages.append({"role": "assistant", "content": final_response}) + # Emit the halt message to the client so it's not + # indistinguishable from a crash. The stream display + # was flushed (callback(None)) before tool execution, + # but the callback is still alive โ€” fire the text + # through it so SSE/TUI clients see the explanation. + if final_response: + agent._safe_print(f"\n{final_response}\n") + if agent.stream_delta_callback: + try: + agent.stream_delta_callback(final_response) + agent.stream_delta_callback(None) + except Exception: + pass break # Reset per-turn retry counters after successful tool @@ -3848,7 +4003,11 @@ def run_conversation( ) # Determine if conversation completed successfully - completed = final_response is not None and api_call_count < agent.max_iterations + completed = ( + final_response is not None + and api_call_count < agent.max_iterations + and not failed + ) # Save trajectory if enabled. ``user_message`` may be a multimodal # list of parts; the trajectory format wants a plain string. @@ -3933,6 +4092,8 @@ def run_conversation( except Exception as _ver_err: logger.debug("file-mutation verifier footer failed: %s", _ver_err) + _response_transformed = False + # Plugin hook: transform_llm_output # Fired once per turn after the tool-calling loop completes. # Plugins can transform the LLM's output text before it's returned. @@ -3950,6 +4111,7 @@ def run_conversation( for _hook_result in _transform_results: if isinstance(_hook_result, str) and _hook_result: final_response = _hook_result + _response_transformed = True break # First non-empty string wins except Exception as exc: logger.warning("transform_llm_output hook failed: %s", exc) @@ -3998,8 +4160,10 @@ def run_conversation( "api_calls": api_call_count, "completed": completed, "turn_exit_reason": _turn_exit_reason, + "failed": failed, "partial": False, # True only when stopped due to invalid tool calls "interrupted": interrupted, + "response_transformed": _response_transformed, "response_previewed": getattr(agent, "_response_was_previewed", False), "model": agent.model, "provider": agent.provider, diff --git a/agent/display.py b/agent/display.py index cdfc88f46a3..02880a83e0d 100644 --- a/agent/display.py +++ b/agent/display.py @@ -787,33 +787,65 @@ class KawaiiSpinner: # Cute tool message (completion line that replaces the spinner) # ========================================================================= +_ERROR_SUFFIX_MAX_LEN = 48 + + +def _trim_error(msg: str) -> str: + """Shrink an error message for inline display in a tool status line. + + Strips overly long absolute paths down to just the filename so the + suffix stays readable on narrow terminals. + """ + msg = msg.strip() + # Common case: "File not found: /very/long/absolute/path/foo.py" + if "File not found:" in msg: + _, _, tail = msg.partition("File not found:") + tail = tail.strip() + if "/" in tail: + msg = f"File not found: {tail.rsplit('/', 1)[-1]}" + if len(msg) > _ERROR_SUFFIX_MAX_LEN: + msg = msg[: _ERROR_SUFFIX_MAX_LEN - 3] + "..." + return msg + + def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]: """Inspect a tool result string for signs of failure. - Returns ``(is_failure, suffix)`` where *suffix* is an informational tag - like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic - failures. On success, returns ``(False, "")``. + Returns ``(is_failure, suffix)`` where *suffix* is a short informational + tag like ``" [exit 1]"`` for terminal failures, ``" [full]"`` for memory + overflow, or a trimmed error message (``" [File not found: foo.py]"``). + On success returns ``(False, "")``. """ if result is None: return False, "" if file_mutation_result_landed(tool_name, result): return False, "" + data = safe_json_loads(result) + + # Terminal: non-zero exit code is the canonical failure signal. if tool_name == "terminal": - data = safe_json_loads(result) if isinstance(data, dict): exit_code = data.get("exit_code") if exit_code is not None and exit_code != 0: + err_msg = data.get("error") + if err_msg: + return True, f" [{_trim_error(str(err_msg))}]" return True, f" [exit {exit_code}]" return False, "" - # Memory-specific: distinguish "full" from real errors + # Memory: distinguish "store full" from real errors. if tool_name == "memory": - data = safe_json_loads(result) if isinstance(data, dict): if data.get("success") is False and "exceed the limit" in data.get("error", ""): return True, " [full]" + # Structured error in JSON result (any tool that surfaces {"error": ...}). + if isinstance(data, dict): + err = data.get("error") or data.get("message") + if err and (data.get("success") is False or "error" in data): + return True, f" [{_trim_error(str(err))}]" + # Generic heuristic for non-terminal tools # Multimodal tool results (dicts with _multimodal=True) are not strings โ€” # treat them as successes since failures would be JSON-encoded strings. @@ -921,11 +953,29 @@ def get_cute_tool_message( if tool_name == "todo": todos_arg = args.get("todos") merge = args.get("merge", False) + # Parse result for completion progress + total = 0 + done = 0 + if result: + try: + data = safe_json_loads(result) + if data: + s = data.get("summary", {}) + total = s.get("total", 0) + done = s.get("completed", 0) + except Exception: + pass if todos_arg is None: + if total > 0: + return _wrap(f"โ”Š ๐Ÿ“‹ plan {done}/{total} task(s) {dur}") return _wrap(f"โ”Š ๐Ÿ“‹ plan reading tasks {dur}") elif merge: + if total > 0 and done > 0: + return _wrap(f"โ”Š ๐Ÿ“‹ plan update {done}/{total} โœ“ {dur}") return _wrap(f"โ”Š ๐Ÿ“‹ plan update {len(todos_arg)} task(s) {dur}") else: + if total > 0 and done > 0: + return _wrap(f"โ”Š ๐Ÿ“‹ plan {done}/{total} task(s) {dur}") return _wrap(f"โ”Š ๐Ÿ“‹ plan {len(todos_arg)} task(s) {dur}") if tool_name == "session_search": return _wrap(f"โ”Š ๐Ÿ” recall \"{_trunc(args.get('query', ''), 35)}\" {dur}") diff --git a/agent/error_classifier.py b/agent/error_classifier.py index 42eb42d6803..0afcf66d445 100644 --- a/agent/error_classifier.py +++ b/agent/error_classifier.py @@ -50,6 +50,7 @@ class FailoverReason(enum.Enum): # Request format format_error = "format_error" # 400 bad request โ€” abort or strip + retry + multimodal_tool_content_unsupported = "multimodal_tool_content_unsupported" # Provider rejected list-type content in tool messages (e.g. Xiaomi MiMo) โ€” downgrade to text and retry # Provider-specific thinking_signature = "thinking_signature" # Anthropic thinking block sig invalid @@ -165,6 +166,32 @@ _IMAGE_TOO_LARGE_PATTERNS = [ # the likely culprit; we still try the shrink path before giving up. ] +# Providers that follow the OpenAI spec strictly require tool message +# ``content`` to be a string. Some (Anthropic native, Codex Responses, +# Gemini native, first-party OpenAI) extend this to accept a content-parts +# list (text + image_url) so screenshots from computer_use survive. Others +# (Xiaomi MiMo, some Alibaba endpoints, a long tail of OpenAI-compatible +# providers) reject the list with a 400 โ€” the patterns below are the most +# common error shapes we see. Recovery: strip image parts from tool +# messages in-place, record the (provider, model) for the rest of the +# session so we don't waste another call learning the same lesson, retry. +# +# See: https://github.com/NousResearch/hermes-agent/issues/27344 +_MULTIMODAL_TOOL_CONTENT_PATTERNS = [ + # Xiaomi MiMo: {"error":{"code":"400","message":"Param Incorrect","param":"text is not set"}} + "text is not set", + # Generic "tool message must be string" shapes + "tool message content must be a string", + "tool content must be a string", + "tool message must be a string", + # OpenAI-compat servers that reject list-type tool content with a + # schema-validation message + "expected string, got list", + "expected string, got array", + # Alibaba/DashScope variant + "tool_call.content must be string", +] + # Context overflow patterns _CONTEXT_OVERFLOW_PATTERNS = [ "context length", @@ -213,6 +240,24 @@ _MODEL_NOT_FOUND_PATTERNS = [ "unsupported model", ] +# Request-validation patterns โ€” the request is malformed and will fail +# identically on every retry. Some OpenAI-compatible gateways (notably +# codex.nekos.me) return these as 5xx instead of the standard 4xx, which +# makes the generic "5xx โ†’ retryable server_error" rule misfire: the retry +# loop hammers the same deterministic rejection 3+ times, then the +# transport-recovery path resets the counter and does it again, producing +# a request flood. When a 5xx body carries one of these unambiguous +# request-validation signals, classify as a non-retryable format_error so +# the loop fails fast and falls back instead of looping. +_REQUEST_VALIDATION_PATTERNS = [ + "unknown parameter", + "unsupported parameter", + "unrecognized request argument", + "invalid_request_error", + "unknown_parameter", + "unsupported_parameter", +] + # OpenRouter aggregator policy-block patterns. # # When a user's OpenRouter account privacy setting (or a per-request @@ -718,6 +763,23 @@ def _classify_by_status( ) if status_code in {500, 502}: + # Some OpenAI-compatible gateways return request-validation errors + # with a 5xx status (codex.nekos.me returns 502 for unknown/ + # unsupported parameters). These are deterministic โ€” every retry + # gets the identical rejection โ€” so the generic "5xx โ†’ retryable + # server_error" rule turns one bad request into a retry flood. + # Detect the unambiguous request-validation signals (in either the + # message text or the structured error code) and fail fast. + if ( + any(p in error_msg for p in _REQUEST_VALIDATION_PATTERNS) + or error_code.lower() in {"invalid_request_error", "unknown_parameter", + "unsupported_parameter"} + ): + return result_fn( + FailoverReason.format_error, + retryable=False, + should_fallback=True, + ) return result_fn(FailoverReason.server_error, retryable=True) if status_code in {503, 529}: @@ -781,6 +843,19 @@ def _classify_400( ) -> ClassifiedError: """Classify 400 Bad Request โ€” context overflow, format error, or generic.""" + # Multimodal tool content rejected from 400. Must be checked BEFORE + # image_too_large because the recovery is different (strip image parts + # from tool messages, mark the model as no-list-tool-content for the + # rest of the session) and BEFORE context_overflow because some of the + # patterns ("text is not set") are ambiguous in isolation but become + # specific when combined with a 400 on a request known to contain + # multimodal tool content. + if any(p in error_msg for p in _MULTIMODAL_TOOL_CONTENT_PATTERNS): + return result_fn( + FailoverReason.multimodal_tool_content_unsupported, + retryable=True, + ) + # Image-too-large from 400 (Anthropic's 5 MB per-image check fires this way). # Must be checked BEFORE context_overflow because messages can trip both # patterns ("exceeds" + "image") and image-shrink is a cheaper recovery. @@ -922,6 +997,13 @@ def _classify_by_message( should_compress=True, ) + # Multimodal tool content patterns (from message text when no status_code) + if any(p in error_msg for p in _MULTIMODAL_TOOL_CONTENT_PATTERNS): + return result_fn( + FailoverReason.multimodal_tool_content_unsupported, + retryable=True, + ) + # Image-too-large patterns (from message text when no status_code) if any(p in error_msg for p in _IMAGE_TOO_LARGE_PATTERNS): return result_fn( diff --git a/agent/file_safety.py b/agent/file_safety.py index f8678b68c06..502c3b254a8 100644 --- a/agent/file_safety.py +++ b/agent/file_safety.py @@ -97,6 +97,43 @@ def is_write_denied(path: str) -> bool: if resolved.startswith(prefix): return True + # Hermes control-plane files: block both the ACTIVE profile's view + # (hermes_home) AND the global root view. Without the root pass, a + # profile-mode session leaves /auth.json + /config.yaml + # writable โ€” letting a prompt-injected write_file overwrite the global + # files that every profile inherits from (same shape as #15981). + control_file_names = ("auth.json", "config.yaml", "webhook_subscriptions.json") + mcp_tokens_dir_name = "mcp-tokens" + + hermes_dirs = [] + for base in (_hermes_home_path(), _hermes_root_path()): + try: + real = os.path.realpath(base) + if real not in hermes_dirs: + hermes_dirs.append(real) + except Exception: + continue + + for base_real in hermes_dirs: + for name in control_file_names: + try: + if resolved == os.path.realpath(os.path.join(base_real, name)): + return True + except Exception: + continue + try: + mcp_real = os.path.realpath(os.path.join(base_real, mcp_tokens_dir_name)) + if resolved == mcp_real or resolved.startswith(mcp_real + os.sep): + return True + except Exception: + pass + try: + pairing_real = os.path.realpath(os.path.join(base_real, "pairing")) + if resolved == pairing_real or resolved.startswith(pairing_real + os.sep): + return True + except Exception: + pass + safe_root = get_safe_write_root() if safe_root and not (resolved == safe_root or resolved.startswith(safe_root + os.sep)): return True @@ -105,21 +142,266 @@ def is_write_denied(path: str) -> bool: def get_read_block_error(path: str) -> Optional[str]: - """Return an error message when a read targets internal Hermes cache files.""" + """Return an error message when a read targets a denied Hermes path. + + Two categories are blocked: + + * Internal Hermes cache files under ``HERMES_HOME/skills/.hub`` โ€” + readable metadata that an attacker could use as a prompt-injection + carrier. + * Credential / secret stores under HERMES_HOME and the global Hermes + root: ``auth.json``, ``auth.lock``, ``.anthropic_oauth.json``, + ``.env``, ``webhook_subscriptions.json``, and anything under + ``mcp-tokens/``. These hold plaintext provider keys, OAuth tokens, + and HMAC secrets that the agent never needs to read directly โ€” + provider tools / gateway adapters consume them through internal + channels. + + **This is NOT a security boundary.** The terminal tool runs as the + same OS user with shell access; the agent can still ``cat auth.json`` + or ``cat ~/.hermes/.env`` and exfiltrate the file. The read-deny exists + as defense-in-depth that: + + * Returns a clear error to models that respect tool denials, which + empirically prompts most modern models to stop rather than reach + for the shell. + * Surfaces a visible audit trail when something tries to read + credentials โ€” easier to spot in logs than a generic ``cat``. + + Treat any user-visible framing around this as "may help" rather than + "stops attackers." A determined model or malicious instruction can + always shell out. + + Callers that resolve relative paths against a non-process cwd + (e.g. ``TERMINAL_CWD`` in ``tools/file_tools.py``) MUST pre-resolve + and pass the absolute path string. This function's own ``resolve()`` + is anchored at the Python process cwd, so a relative input like + ``"auth.json"`` would otherwise miss the denylist when the task's + terminal cwd differs from the process cwd. + """ resolved = Path(path).expanduser().resolve() - hermes_home = _hermes_home_path().resolve() - blocked_dirs = [ - hermes_home / "skills" / ".hub" / "index-cache", - hermes_home / "skills" / ".hub", - ] - for blocked in blocked_dirs: + + # Resolve BOTH the active HERMES_HOME (profile-aware) AND the global + # Hermes root so credential stores at /auth.json etc. are also + # blocked when running under a profile (HERMES_HOME points at + # /profiles/ in profile mode). Same shape as the write + # deny widening (#15981, #14157). + hermes_dirs: list[Path] = [] + for base in (_hermes_home_path(), _hermes_root_path()): try: - resolved.relative_to(blocked) + real = base.resolve() + if real not in hermes_dirs: + hermes_dirs.append(real) + except Exception: + continue + + # Skills .hub: prompt-injection carriers. + for hd in hermes_dirs: + blocked_dirs = [ + hd / "skills" / ".hub" / "index-cache", + hd / "skills" / ".hub", + ] + for blocked in blocked_dirs: + try: + resolved.relative_to(blocked) + except ValueError: + continue + return ( + f"Access denied: {path} is an internal Hermes cache file " + "and cannot be read directly to prevent prompt injection. " + "Use the skills_list or skill_view tools instead." + ) + + # Credential / secret stores. Exact-file matches under either + # HERMES_HOME or . + credential_file_names = ( + "auth.json", + "auth.lock", + ".anthropic_oauth.json", + ".env", + "webhook_subscriptions.json", + ) + for hd in hermes_dirs: + for name in credential_file_names: + try: + blocked = (hd / name).resolve() + except Exception: + continue + if resolved == blocked: + return ( + f"Access denied: {path} is a Hermes credential store " + "and cannot be read directly. Provider tools consume " + "these credentials through internal channels. " + "(Defense-in-depth โ€” not a security boundary; the " + "terminal tool can still bypass.)" + ) + + # mcp-tokens/: directory prefix match โ€” anything inside is OAuth + # token material. + for hd in hermes_dirs: + try: + mcp_tokens = (hd / "mcp-tokens").resolve() + except Exception: + continue + if resolved == mcp_tokens: + return ( + f"Access denied: {path} is the Hermes MCP token directory " + "and cannot be read directly. (Defense-in-depth โ€” not a " + "security boundary; the terminal tool can still bypass.)" + ) + try: + resolved.relative_to(mcp_tokens) except ValueError: continue return ( - f"Access denied: {path} is an internal Hermes cache file " - "and cannot be read directly to prevent prompt injection. " - "Use the skills_list or skill_view tools instead." + f"Access denied: {path} is a Hermes MCP token file " + "and cannot be read directly. (Defense-in-depth โ€” not a " + "security boundary; the terminal tool can still bypass.)" ) + return None + + +# --------------------------------------------------------------------------- +# Cross-profile write guard (#TBD) +# +# Hermes profiles are separate HERMES_HOME dirs under +# ``/profiles//``. Each profile has its own skills/, plugins/, +# cron/, memories/. When an agent runs under one profile, writing into +# ANOTHER profile's directories is almost always wrong โ€” those skills / +# plugins / cron jobs / memories affect a different session the user runs +# from a different shell. +# +# Soft guard, NOT a security boundary: the agent runs as the same OS user +# and has unrestricted terminal access, so this returns a warning the model +# can choose to honor or override with ``cross_profile=True``. Same shape +# as the dangerous-command approval flow โ€” the agent is told the boundary +# exists, and explicit user direction is required to cross it. +# +# Reference: May 2026 incident where a hermes-security profile session +# edited skills under both ``~/.hermes/profiles/hermes-security/skills/`` +# AND ``~/.hermes/skills/`` (the default profile's skills) without realizing +# the second path belonged to a different profile. +# --------------------------------------------------------------------------- + +# Profile-scoped directories under HERMES_HOME / / /profiles// +# that should be guarded. Adding a new area here extends the guard with no +# other code change. +PROFILE_SCOPED_AREAS = ("skills", "plugins", "cron", "memories") + + +def _resolve_active_profile_name() -> str: + """Return the active profile name derived from HERMES_HOME. + + ``~/.hermes`` -> ``"default"`` + ``~/.hermes/profiles/X`` -> ``"X"`` + + Falls back to ``"default"`` on any resolution failure so the guard + never raises into the tool path. + """ + try: + home_real = _hermes_home_path().resolve() + root_real = _hermes_root_path().resolve() + except (OSError, RuntimeError): + return "default" + profiles_dir = root_real / "profiles" + try: + rel = home_real.relative_to(profiles_dir) + parts = rel.parts + if len(parts) >= 1: + return parts[0] + except ValueError: + pass + return "default" + + +def classify_cross_profile_target(path: str) -> Optional[dict]: + """Classify a write target as cross-profile if it lands in another + profile's scoped area (skills/plugins/cron/memories). + + Returns ``None`` when the target is outside Hermes scope, or is inside + the ACTIVE profile, or doesn't hit a profile-scoped area. Otherwise + returns a dict with: + + * ``active_profile``: name of the profile the agent is running as + * ``target_profile``: name of the profile the path belongs to + * ``area``: which scoped area (``"skills"``, ``"plugins"``, etc.) + * ``target_path``: the resolved path string + + The caller decides what to do with the result โ€” surface a warning to + the model, prompt the user, or (with explicit consent / + ``cross_profile=True``) proceed anyway. + """ + try: + target = Path(os.path.expanduser(str(path))).resolve() + root_real = _hermes_root_path().resolve() + except (OSError, RuntimeError): + return None + + target_profile: Optional[str] = None + area: Optional[str] = None + + try: + rel = target.relative_to(root_real) + except ValueError: + return None + + parts = rel.parts + if not parts: + return None + + if parts[0] in PROFILE_SCOPED_AREAS: + # ``//...`` โ†’ default profile. + target_profile = "default" + area = parts[0] + elif ( + parts[0] == "profiles" + and len(parts) >= 3 + and parts[2] in PROFILE_SCOPED_AREAS + ): + # ``/profiles///...`` โ†’ named profile. + target_profile = parts[1] + area = parts[2] + else: + return None + + active_profile = _resolve_active_profile_name() + if target_profile == active_profile: + # In-profile write โ€” not a cross-profile event. + return None + + return { + "active_profile": active_profile, + "target_profile": target_profile, + "area": area, + "target_path": str(target), + } + + +def get_cross_profile_warning(path: str) -> Optional[str]: + """Return a model-facing warning string when ``path`` is cross-profile. + + Returns ``None`` when the write is in-scope (same profile) or outside + Hermes entirely. Caller is expected to surface the warning to the + agent as a tool-result error, NOT to silently allow the write โ€” the + agent must either get explicit user direction to proceed, or pass + ``cross_profile=True`` to its write tool. + + This is defense-in-depth: the terminal tool runs as the same OS user + and can write any of these paths without going through this guard. + Treat the guard as a confusion-reducer, not a security boundary. + """ + info = classify_cross_profile_target(path) + if info is None: + return None + return ( + f"Cross-profile write blocked by soft guard: {info['target_path']} " + f"belongs to Hermes profile {info['target_profile']!r}, but the " + f"agent is running under profile {info['active_profile']!r}. " + f"Editing another profile's {info['area']}/ will affect that " + f"profile's future sessions, not the one you are currently in. " + f"Confirm with the user before proceeding. To bypass this guard " + f"after explicit user direction, retry the call with " + f"``cross_profile=True``. (Defense-in-depth โ€” not a security " + f"boundary; the terminal tool can still bypass.)" + ) diff --git a/agent/model_metadata.py b/agent/model_metadata.py index b8ec0d6509e..e9ec4bf03a7 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -209,6 +209,7 @@ DEFAULT_CONTEXT_LENGTHS = { # via a custom provider. Values sourced from models.dev (2026-04). # Keys use substring matching (longest-first), so e.g. "grok-4.20" # matches "grok-4.20-0309-reasoning" / "-non-reasoning" / "-multi-agent-0309". + "grok-build": 256000, # grok-build-0.1 "grok-code-fast": 256000, # grok-code-fast-1 "grok-4-1-fast": 2000000, # grok-4-1-fast-(non-)reasoning "grok-2-vision": 8192, # grok-2-vision, -1212, -latest @@ -640,7 +641,7 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any return cache except Exception as e: - logging.warning(f"Failed to fetch model metadata from OpenRouter: {e}") + logger.warning(f"Failed to fetch model metadata from OpenRouter: {e}") return _model_metadata_cache or {} diff --git a/agent/models_dev.py b/agent/models_dev.py index 8fabb276645..1249c6f1970 100644 --- a/agent/models_dev.py +++ b/agent/models_dev.py @@ -167,6 +167,9 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = { "gemini": "google", "google": "google", "xai": "xai", + # xAI OAuth is an authentication/transport path for the same xAI model + # catalog, so model metadata should resolve through the xAI provider. + "xai-oauth": "xai", "xiaomi": "xiaomi", "nvidia": "nvidia", "groq": "groq", diff --git a/agent/redact.py b/agent/redact.py index 1beb10450fd..7ed241c5efd 100644 --- a/agent/redact.py +++ b/agent/redact.py @@ -176,6 +176,15 @@ _URL_USERINFO_RE = re.compile( r"(https?|wss?|ftp)://([^/\s:@]+):([^/\s@]+)@", ) +# HTTP access logs often use a relative request target rather than a full URL: +# `"POST /webhook?password=... HTTP/1.1"`. The full-URL redactor above only +# sees strings containing `://`, so handle request-target query strings too. +_HTTP_REQUEST_TARGET_QUERY_RE = re.compile( + r"\b((?:GET|POST|PUT|PATCH|DELETE|HEAD|OPTIONS|TRACE|CONNECT)\s+[^ \t\r\n\"']*?)" + r"\?([^ \t\r\n\"']+)", + re.IGNORECASE, +) + # Form-urlencoded body detection: conservative โ€” only applies when the entire # text looks like a query string (k=v&k=v pattern with no newlines). _FORM_BODY_RE = re.compile( @@ -293,6 +302,15 @@ def _redact_url_userinfo(text: str) -> str: ) +def _redact_http_request_target_query_params(text: str) -> str: + """Redact sensitive query params in HTTP access-log request targets.""" + def _sub(m: re.Match) -> str: + prefix = m.group(1) + query = _redact_query_string(m.group(2)) + return f"{prefix}?{query}" + return _HTTP_REQUEST_TARGET_QUERY_RE.sub(_sub, text) + + def _redact_form_body(text: str) -> str: """Redact sensitive values in a form-urlencoded body. @@ -397,6 +415,11 @@ def redact_sensitive_text(text: str, *, force: bool = False, code_file: bool = F if "?" in text: text = _redact_url_query_params(text) + # HTTP access logs can contain relative request targets with query params + # and no URL scheme, e.g. `"POST /hook?password=... HTTP/1.1"`. + if "?" in text and "=" in text and _has_http_method_substring(text): + text = _redact_http_request_target_query_params(text) + # Form-urlencoded bodies (only triggers on clean k=v&k=v inputs). if "&" in text and "=" in text: text = _redact_form_body(text) @@ -456,6 +479,25 @@ def _has_known_prefix_substring(text: str) -> bool: return any(p in text for p in _PREFIX_SUBSTRINGS) +_HTTP_METHOD_SUBSTRINGS = ( + "GET ", + "POST ", + "PUT ", + "PATCH ", + "DELETE ", + "HEAD ", + "OPTIONS ", + "TRACE ", + "CONNECT ", +) + + +def _has_http_method_substring(text: str) -> bool: + """Cheap pre-check before scanning for access-log request targets.""" + upper = text.upper() + return any(method in upper for method in _HTTP_METHOD_SUBSTRINGS) + + class RedactingFormatter(logging.Formatter): """Log formatter that redacts secrets from all log messages.""" diff --git a/agent/secret_sources/bitwarden.py b/agent/secret_sources/bitwarden.py index fb6824b5229..8c1e8dc5678 100644 --- a/agent/secret_sources/bitwarden.py +++ b/agent/secret_sources/bitwarden.py @@ -70,7 +70,7 @@ _BWS_RUN_TIMEOUT = 30 # In-process cache so repeated load_hermes_dotenv() calls (CLI startup, # gateway hot-reload, test suites) don't re-fetch from BSM. -_CacheKey = Tuple[str, str] # (access_token_fingerprint, project_id) +_CacheKey = Tuple[str, str, str] # (access_token_fingerprint, project_id, server_url) _CACHE: Dict[_CacheKey, "_CachedFetch"] = {} @@ -317,11 +317,18 @@ def fetch_bitwarden_secrets( binary: Optional[Path] = None, cache_ttl_seconds: float = 300, use_cache: bool = True, + server_url: str = "", ) -> Tuple[Dict[str, str], List[str]]: """Pull the secrets for ``project_id`` from Bitwarden Secrets Manager. Returns ``(secrets_dict, warnings_list)``. + Set ``server_url`` to point at a non-default Bitwarden region or a + self-hosted instance โ€” e.g. ``https://vault.bitwarden.eu`` for EU + Cloud accounts. When empty, ``bws`` uses its built-in default + (``https://vault.bitwarden.com``, US Cloud). This is plumbed into + the subprocess as ``BWS_SERVER_URL``. + Raises :class:`RuntimeError` for fatal conditions (missing binary, auth failure, unparseable output). Callers in the env_loader path catch this and emit a single warning; callers in the user-facing @@ -332,7 +339,7 @@ def fetch_bitwarden_secrets( if not project_id: raise RuntimeError("Bitwarden project_id is empty") - cache_key = (_token_fingerprint(access_token), project_id) + cache_key = (_token_fingerprint(access_token), project_id, server_url or "") if use_cache: cached = _CACHE.get(cache_key) if cached and cached.is_fresh(cache_ttl_seconds): @@ -347,19 +354,26 @@ def fetch_bitwarden_secrets( "`hermes secrets bitwarden setup`." ) - secrets, warnings = _run_bws_list(bws, access_token, project_id) + secrets, warnings = _run_bws_list(bws, access_token, project_id, server_url) _CACHE[cache_key] = _CachedFetch(secrets=secrets, fetched_at=time.time()) return secrets, warnings def _run_bws_list( - bws: Path, access_token: str, project_id: str + bws: Path, access_token: str, project_id: str, server_url: str = "" ) -> Tuple[Dict[str, str], List[str]]: cmd = [str(bws), "secret", "list", project_id, "--output", "json"] env = os.environ.copy() env["BWS_ACCESS_TOKEN"] = access_token # Make sure we're not echoing telemetry / colour codes into json. env.setdefault("NO_COLOR", "1") + # Region / self-hosted support. bws defaults to https://vault.bitwarden.com + # (US Cloud); EU Cloud users need https://vault.bitwarden.eu, and + # self-hosted users need their own URL. When unset, fall back to whatever + # BWS_SERVER_URL the caller already had in their shell env (preserved by + # the copy above) so manual overrides keep working too. + if server_url: + env["BWS_SERVER_URL"] = server_url try: proc = subprocess.run( # noqa: S603 โ€” bws path is trusted @@ -437,6 +451,7 @@ def apply_bitwarden_secrets( override_existing: bool = False, cache_ttl_seconds: float = 300, auto_install: bool = True, + server_url: str = "", ) -> FetchResult: """Pull secrets from BSM and set them on ``os.environ``. @@ -444,6 +459,10 @@ def apply_bitwarden_secrets( files have loaded. It is intentionally defensive โ€” any failure returns a :class:`FetchResult` with ``error`` set; it never raises. + ``server_url`` selects the Bitwarden region or self-hosted endpoint + (e.g. ``https://vault.bitwarden.eu`` for EU Cloud). Empty string + means use ``bws``'s default (US Cloud). + Parameters mirror the ``secrets.bitwarden.*`` config keys so the caller can just splat the dict in. """ @@ -482,6 +501,7 @@ def apply_bitwarden_secrets( project_id=project_id, binary=binary, cache_ttl_seconds=cache_ttl_seconds, + server_url=server_url, ) except RuntimeError as exc: result.error = str(exc) diff --git a/agent/skill_utils.py b/agent/skill_utils.py index 959a109a6cb..5b8e4c22a67 100644 --- a/agent/skill_utils.py +++ b/agent/skill_utils.py @@ -12,7 +12,7 @@ import sys from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple -from hermes_constants import get_config_path, get_skills_dir +from hermes_constants import get_config_path, get_skills_dir, is_termux logger = logging.getLogger(__name__) @@ -136,6 +136,14 @@ def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool: If the field is absent or empty the skill is compatible with **all** platforms (backward-compatible default). + + Termux note: on Termux/Android, ``sys.platform`` is ``"linux"`` on + older Pythons but became ``"android"`` on Python 3.13+. Termux is a + Linux userland riding on the Android kernel, so skills tagged + ``linux`` are treated as compatible in Termux regardless of which + ``sys.platform`` value Python reports. Individual Linux commands + inside a skill may still misbehave (no systemd, BusyBox utils, no + apt/dnf, etc.) but that is on the skill, not on platform gating. """ platforms = frontmatter.get("platforms") if not platforms: @@ -143,11 +151,21 @@ def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool: if not isinstance(platforms, list): platforms = [platforms] current = sys.platform + running_in_termux = is_termux() for platform in platforms: normalized = str(platform).lower().strip() mapped = PLATFORM_MAP.get(normalized, normalized) if current.startswith(mapped): return True + # Termux runs a Linux userland on Android. Accept linux-tagged + # skills regardless of whether sys.platform is "linux" (pre-3.13 + # Termux) or "android" (Python 3.13+ Termux, and any other + # Android runtime). + if running_in_termux and mapped == "linux": + return True + # Explicit termux/android tags match a Termux session too. + if running_in_termux and mapped in ("termux", "android"): + return True return False diff --git a/agent/system_prompt.py b/agent/system_prompt.py index bc29c9ef89a..8fa4c191563 100644 --- a/agent/system_prompt.py +++ b/agent/system_prompt.py @@ -205,6 +205,40 @@ def build_system_prompt_parts(agent: Any, system_message: Optional[str] = None) if _env_hints: stable_parts.append(_env_hints) + # Active-profile hint โ€” names the Hermes profile the agent is running + # under so it doesn't conflate ~/.hermes/skills/ (default profile) with + # ~/.hermes/profiles//skills/ (this profile's). Deterministic + # for the lifetime of the agent โ€” profile name doesn't change + # mid-session, so this doesn't break the prompt cache. + # See file_safety._resolve_active_profile_name + classify_cross_profile_target + # for the matching tool-side guard. + try: + from agent.file_safety import _resolve_active_profile_name + active_profile = _resolve_active_profile_name() + except Exception: + active_profile = "default" + if active_profile == "default": + stable_parts.append( + "Active Hermes profile: default. Other profiles (if any) live " + "under ~/.hermes/profiles//. Each profile has its own " + "skills/, plugins/, cron/, and memories/ that affect a different " + "session than this one. Do not modify another profile's " + "skills/plugins/cron/memories unless the user explicitly directs " + "you to." + ) + else: + stable_parts.append( + f"Active Hermes profile: {active_profile}. This session reads " + f"and writes ~/.hermes/profiles/{active_profile}/. The default " + f"profile's data lives at ~/.hermes/skills/, ~/.hermes/plugins/, " + f"~/.hermes/cron/, ~/.hermes/memories/ โ€” those belong to a " + f"different session run from a different shell. Do NOT modify " + f"another profile's skills/plugins/cron/memories unless the user " + f"explicitly directs you to. The cross-profile write guard will " + f"refuse such writes by default; pass cross_profile=True only " + f"after explicit direction." + ) + platform_key = (agent.platform or "").lower().strip() if platform_key in PLATFORM_HINTS: stable_parts.append(PLATFORM_HINTS[platform_key]) diff --git a/agent/tool_executor.py b/agent/tool_executor.py index b161b507e8d..438a6337074 100644 --- a/agent/tool_executor.py +++ b/agent/tool_executor.py @@ -388,6 +388,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe agent.tool_progress_callback( "tool.completed", function_name, None, None, duration=tool_duration, is_error=is_error, + result=function_result, ) except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") @@ -491,7 +492,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe try: function_args = json.loads(tool_call.function.arguments) except json.JSONDecodeError as e: - logging.warning(f"Unexpected JSON error after validation: {e}") + logger.warning(f"Unexpected JSON error after validation: {e}") function_args = {} if not isinstance(function_args, dict): function_args = {} @@ -822,6 +823,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe agent.tool_progress_callback( "tool.completed", function_name, None, None, duration=tool_duration, is_error=_is_error_result, + result=function_result, ) except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") diff --git a/agent/transports/anthropic.py b/agent/transports/anthropic.py index 72024ac20f3..d77ae63ef32 100644 --- a/agent/transports/anthropic.py +++ b/agent/transports/anthropic.py @@ -106,7 +106,17 @@ class AnthropicTransport(ProviderTransport): elif block.type == "tool_use": name = block.name if strip_tool_prefix and name.startswith(_MCP_PREFIX): - name = name[len(_MCP_PREFIX):] + stripped = name[len(_MCP_PREFIX):] + # Only strip the mcp_ prefix for OAuth-injected tools + # (where Hermes adds the prefix when sending to Anthropic + # and must remove it on the way back). Native MCP server + # tools (from mcp_servers: in config.yaml) are registered + # in the tool registry under their FULL mcp__ + # name and must NOT be stripped. GH-25255. + from tools.registry import registry as _tool_registry + if (_tool_registry.get_entry(stripped) + and not _tool_registry.get_entry(name)): + name = stripped tool_calls.append( ToolCall( id=block.id, diff --git a/agent/transports/chat_completions.py b/agent/transports/chat_completions.py index fa36301bd81..96997afca43 100644 --- a/agent/transports/chat_completions.py +++ b/agent/transports/chat_completions.py @@ -113,9 +113,8 @@ class ChatCompletionsTransport(ProviderTransport): self, messages: list[dict[str, Any]], **kwargs ) -> list[dict[str, Any]]: """Messages are already in OpenAI format โ€” strip internal fields - that strict chat-completions providers reject with HTTP 400/422. - - Strips: + that strict chat-completions providers reject with HTTP 400/422 + (or, in the case of some OpenAI-compatible gateways, 5xx): - Codex Responses API fields: ``codex_reasoning_items`` / ``codex_message_items`` on the message, ``call_id`` / @@ -127,6 +126,16 @@ class ChatCompletionsTransport(ProviderTransport): ``Extra inputs are not permitted, field: 'messages[N].tool_name'``. Permissive providers (OpenRouter, MiniMax) silently ignore the field, which masked the bug for months. + - Hermes-internal scaffolding markers โ€” any top-level message key + starting with ``_`` (e.g. ``_empty_recovery_synthetic``, + ``_empty_terminal_sentinel``, ``_thinking_prefill``). These are + bookkeeping flags the agent loop attaches to messages so the + persistence layer can later strip its own scaffolding; they must + never reach the wire. Permissive providers (real OpenAI, + Anthropic) silently drop unknown message keys, but strict + gateways (e.g. opencode-go, codex.nekos.me) reject with + ``Extra inputs are not permitted, field: 'messages[N]._empty_recovery_synthetic'``, + which then poisons every subsequent request in the session. """ needs_sanitize = False for msg in messages: @@ -139,6 +148,9 @@ class ChatCompletionsTransport(ProviderTransport): ): needs_sanitize = True break + if any(isinstance(k, str) and k.startswith("_") for k in msg): + needs_sanitize = True + break tool_calls = msg.get("tool_calls") if isinstance(tool_calls, list): for tc in tool_calls: @@ -160,6 +172,11 @@ class ChatCompletionsTransport(ProviderTransport): msg.pop("codex_reasoning_items", None) msg.pop("codex_message_items", None) msg.pop("tool_name", None) + # Drop all Hermes-internal scaffolding markers (``_``-prefixed). + # OpenAI's message schema has no ``_``-prefixed fields, so this + # is safe and future-proofs against new markers being added. + for key in [k for k in msg if isinstance(k, str) and k.startswith("_")]: + msg.pop(key, None) tool_calls = msg.get("tool_calls") if isinstance(tool_calls, list): for tc in tool_calls: diff --git a/agent/transports/codex_app_server_session.py b/agent/transports/codex_app_server_session.py index d9ee92dfbf5..74e164d64d9 100644 --- a/agent/transports/codex_app_server_session.py +++ b/agent/transports/codex_app_server_session.py @@ -87,6 +87,39 @@ class TurnResult: _TURN_ABORTED_MARKERS = ("", "") +def _coerce_turn_input_text(user_input: Any) -> str: + """Collapse Hermes/OpenAI rich content into app-server text input. + + The current `turn/start` path sends text items only. TUI image attachment + can hand us OpenAI-style content parts, so keep the text/path hints and + replace opaque image payloads with a small marker instead of putting a + Python list into the `text` field. + """ + if isinstance(user_input, str): + return user_input + if isinstance(user_input, list): + parts: list[str] = [] + for item in user_input: + if isinstance(item, str): + if item.strip(): + parts.append(item) + continue + if not isinstance(item, dict): + if item is not None: + parts.append(str(item)) + continue + item_type = item.get("type") + if item_type in {"text", "input_text"}: + text = item.get("text") or item.get("content") or "" + if text: + parts.append(str(text)) + elif item_type in {"image", "image_url", "input_image"}: + parts.append("[image attached]") + text = "\n\n".join(p for p in parts if p).strip() + return text or "What do you see in this image?" + return "" if user_input is None else str(user_input) + + # Substrings in codex stderr / JSON-RPC error messages that signal the # subprocess died because its OAuth credentials are no longer valid. # Kept conservative: we only redirect users to `codex login` when we're @@ -327,7 +360,7 @@ class CodexAppServerSession: def run_turn( self, - user_input: str, + user_input: Any, *, turn_timeout: float = 600.0, notification_poll_timeout: float = 0.25, @@ -365,6 +398,8 @@ class CodexAppServerSession: self._interrupt_event.clear() projector = CodexEventProjector() + user_input_text = _coerce_turn_input_text(user_input) + # Send turn/start with the user input. Text-only for now (codex # supports rich content but Hermes' text path is the common case). try: @@ -372,7 +407,7 @@ class CodexAppServerSession: "turn/start", { "threadId": self._thread_id, - "input": [{"type": "text", "text": user_input}], + "input": [{"type": "text", "text": user_input_text}], }, timeout=10, ) diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 68c716daab0..939f602cdfb 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -39,7 +39,7 @@ model: # LM Studio is first-class and uses provider: "lmstudio". # It works with both no-auth and auth-enabled server modes. # - # Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var. + # Can also be overridden for a single invocation with the --provider flag. provider: "auto" # API configuration (falls back to OPENROUTER_API_KEY env var) diff --git a/cli.py b/cli.py index bd8696178d5..dcd97139809 100644 --- a/cli.py +++ b/cli.py @@ -51,6 +51,8 @@ os.environ["HERMES_QUIET"] = "1" # Our own modules import yaml +from hermes_cli.fallback_config import get_fallback_chain + # prompt_toolkit for fixed input area TUI from prompt_toolkit.history import FileHistory from prompt_toolkit.styles import Style as PTStyle @@ -81,17 +83,73 @@ except Exception: import threading import queue -from agent.usage_pricing import ( - CanonicalUsage, - estimate_usage_cost, - format_duration_compact, - format_token_count_compact, -) -from agent.markdown_tables import ( - is_table_divider, - looks_like_table_row, - realign_markdown_tables, -) +def CanonicalUsage(*args, **kwargs): + from agent.usage_pricing import CanonicalUsage as _CanonicalUsage + + return _CanonicalUsage(*args, **kwargs) + + +def estimate_usage_cost(*args, **kwargs): + from agent.usage_pricing import estimate_usage_cost as _estimate_usage_cost + + return _estimate_usage_cost(*args, **kwargs) + + +def format_duration_compact(*args, **kwargs): + seconds = float(args[0] if args else kwargs.get("seconds", 0.0)) + if seconds < 60: + return f"{seconds:.0f}s" + minutes = seconds / 60 + if minutes < 60: + return f"{minutes:.0f}m" + hours = minutes / 60 + if hours < 24: + remaining_min = int(minutes % 60) + return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h" + days = hours / 24 + return f"{days:.1f}d" + + +def format_token_count_compact(*args, **kwargs): + value = int(args[0] if args else kwargs.get("value", 0)) + abs_value = abs(value) + if abs_value < 1_000: + return str(value) + + sign = "-" if value < 0 else "" + units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K")) + for threshold, suffix in units: + if abs_value >= threshold: + scaled = abs_value / threshold + if scaled < 10: + text = f"{scaled:.2f}" + elif scaled < 100: + text = f"{scaled:.1f}" + else: + text = f"{scaled:.0f}" + if "." in text: + text = text.rstrip("0").rstrip(".") + return f"{sign}{text}{suffix}" + + return f"{value:,}" + + +def is_table_divider(*args, **kwargs): + from agent.markdown_tables import is_table_divider as _is_table_divider + + return _is_table_divider(*args, **kwargs) + + +def looks_like_table_row(*args, **kwargs): + from agent.markdown_tables import looks_like_table_row as _looks_like_table_row + + return _looks_like_table_row(*args, **kwargs) + + +def realign_markdown_tables(*args, **kwargs): + from agent.markdown_tables import realign_markdown_tables as _realign_markdown_tables + + return _realign_markdown_tables(*args, **kwargs) # NOTE: `from agent.account_usage import ...` is deliberately NOT at module # top โ€” it transitively pulls the OpenAI SDK chain (~230 ms cold) and is only # needed when the user runs `/limits`. Lazy-imported inside the handler below. @@ -357,6 +415,12 @@ def load_cli_config() -> Dict[str, Any]: "display": { "compact": False, "resume_display": "full", + # Recap tuning for /resume โ€” see hermes_cli/config.py DEFAULT_CONFIG. + "resume_exchanges": 10, + "resume_max_user_chars": 300, + "resume_max_assistant_chars": 200, + "resume_max_assistant_lines": 3, + "resume_skip_tool_only": True, "show_reasoning": False, "streaming": True, "busy_input_mode": "interrupt", @@ -410,7 +474,9 @@ def load_cli_config() -> Dict[str, Any]: if config_path.exists(): try: with open(config_path, "r", encoding="utf-8") as f: - file_config = yaml.safe_load(f) or {} + from hermes_cli.config import _normalize_root_model_keys + + file_config = _normalize_root_model_keys(yaml.safe_load(f) or {}) _file_has_terminal_config = "terminal" in file_config @@ -431,21 +497,6 @@ def load_cli_config() -> Dict[str, Any]: if "model" in file_config["model"] and "default" not in file_config["model"]: defaults["model"]["default"] = file_config["model"]["model"] - # Legacy root-level provider/base_url fallback. - # Some users (or old code) put provider: / base_url: at the - # config root instead of inside the model: section. These are - # only used as a FALLBACK when model.provider / model.base_url - # is not already set โ€” never as an override. The canonical - # location is model.provider (written by `hermes model`). - if not defaults["model"].get("provider"): - root_provider = file_config.get("provider") - if root_provider: - defaults["model"]["provider"] = root_provider - if not defaults["model"].get("base_url"): - root_base_url = file_config.get("base_url") - if root_base_url: - defaults["model"]["base_url"] = root_base_url - # Deep merge file_config into defaults. # First: merge keys that exist in both (deep-merge dicts, overwrite scalars) for key in defaults: @@ -717,31 +768,142 @@ from rich.markup import escape as _escape from rich.panel import Panel from rich.text import Text as _RichText -import fire +# Import agent and tool systems lazily. Bare interactive startup only needs the +# prompt; the full agent/tool registry is initialized on first use. +def AIAgent(*args, **kwargs): + from run_agent import AIAgent as _AIAgent -# Import the agent and tool systems -from run_agent import AIAgent -from model_tools import get_tool_definitions, get_toolset_for_tool + return _AIAgent(*args, **kwargs) + + +def get_tool_definitions(*args, **kwargs): + from model_tools import get_tool_definitions as _get_tool_definitions + + return _get_tool_definitions(*args, **kwargs) + + +def get_toolset_for_tool(*args, **kwargs): + from model_tools import get_toolset_for_tool as _get_toolset_for_tool + + return _get_toolset_for_tool(*args, **kwargs) # Extracted CLI modules (Phase 3) from hermes_cli.banner import build_welcome_banner from hermes_cli.commands import SlashCommandCompleter, SlashCommandAutoSuggest -from toolsets import get_all_toolsets, get_toolset_info, validate_toolset + + +def get_all_toolsets(*args, **kwargs): + from toolsets import get_all_toolsets as _get_all_toolsets + + return _get_all_toolsets(*args, **kwargs) + + +def get_toolset_info(*args, **kwargs): + from toolsets import get_toolset_info as _get_toolset_info + + return _get_toolset_info(*args, **kwargs) + + +def validate_toolset(*args, **kwargs): + from toolsets import validate_toolset as _validate_toolset + + return _validate_toolset(*args, **kwargs) + + +def _sync_process_session_id(session_id: str) -> None: + """Keep process-local session-id consumers aligned after CLI switches.""" + from gateway.session_context import set_current_session_id + + set_current_session_id(session_id) # Cron job system for scheduled tasks (execution is handled by the gateway) -from cron import get_job +def get_job(*args, **kwargs): + from cron import get_job as _get_job + + return _get_job(*args, **kwargs) # Resource cleanup imports for safe shutdown (terminal VMs, browser sessions) -from tools.terminal_tool import cleanup_all_environments as _cleanup_all_terminals -from tools.terminal_tool import set_sudo_password_callback, set_approval_callback -from tools.skills_tool import set_secret_capture_callback from hermes_cli.callbacks import prompt_for_secret -from tools.browser_tool import _emergency_cleanup_all_sessions as _cleanup_all_browsers + + +def _cleanup_all_terminals(*args, **kwargs): + from tools.terminal_tool import cleanup_all_environments + + return cleanup_all_environments(*args, **kwargs) + + +def set_sudo_password_callback(*args, **kwargs): + from tools.terminal_tool import set_sudo_password_callback as _set_sudo_password_callback + + return _set_sudo_password_callback(*args, **kwargs) + + +def set_approval_callback(*args, **kwargs): + from tools.terminal_tool import set_approval_callback as _set_approval_callback + + return _set_approval_callback(*args, **kwargs) + + +def set_secret_capture_callback(*args, **kwargs): + from tools.skills_tool import set_secret_capture_callback as _set_secret_capture_callback + + return _set_secret_capture_callback(*args, **kwargs) + + +def _cleanup_all_browsers(*args, **kwargs): + from tools.browser_tool import _emergency_cleanup_all_sessions + + return _emergency_cleanup_all_sessions(*args, **kwargs) # Guard to prevent cleanup from running multiple times on exit _cleanup_done = False # Weak reference to the active AIAgent for memory provider shutdown at exit _active_agent_ref = None +_deferred_agent_startup_done = False + + +def _prepare_deferred_agent_startup() -> None: + """Run Termux-deferred agent discovery before the first real agent turn.""" + global _deferred_agent_startup_done + if _deferred_agent_startup_done: + return + if os.environ.get("HERMES_DEFER_AGENT_STARTUP") != "1": + return + _deferred_agent_startup_done = True + _accept_hooks = os.environ.get("HERMES_ACCEPT_HOOKS", "").lower() in { + "1", + "true", + "yes", + "on", + } + try: + from hermes_cli.plugins import discover_plugins + + discover_plugins() + except Exception: + logger.warning( + "plugin discovery failed at deferred CLI startup", + exc_info=True, + ) + try: + from tools.mcp_tool import discover_mcp_tools + + discover_mcp_tools() + except Exception: + logger.debug( + "MCP tool discovery failed at deferred CLI startup", + exc_info=True, + ) + try: + from agent.shell_hooks import register_from_config + from hermes_cli.config import load_config + + register_from_config(load_config(), accept_hooks=_accept_hooks) + except Exception: + logger.debug( + "shell-hook registration failed at deferred CLI startup", + exc_info=True, + ) def _run_cleanup(): """Run resource cleanup exactly once.""" @@ -2455,7 +2617,13 @@ def _build_compact_banner() -> str: line1 = f"{agent_name} - AI Agent Framework" tiny_line = agent_name - version_line = format_banner_version_label() + if os.environ.get("HERMES_FAST_STARTUP_BANNER") == "1": + from hermes_cli import __release_date__ as _release_date + from hermes_cli import __version__ as _version + + version_line = f"Hermes Agent v{_version} ({_release_date})" + else: + version_line = format_banner_version_label() w = min(shutil.get_terminal_size().columns - 2, 88) if w < 30: @@ -2504,19 +2672,48 @@ def _looks_like_slash_command(text: str) -> bool: # Skill Slash Commands โ€” dynamic commands generated from installed skills # ============================================================================ -from agent.skill_commands import ( - scan_skill_commands, - get_skill_commands, - build_skill_invocation_message, - build_preloaded_skills_prompt, -) -from agent.skill_bundles import ( - get_skill_bundles, - build_bundle_invocation_message, -) +_skill_commands = None +_skill_bundles = None -_skill_commands = scan_skill_commands() -_skill_bundles = get_skill_bundles() + +def _ensure_skill_commands() -> dict: + global _skill_commands + if _skill_commands is None: + from agent.skill_commands import scan_skill_commands + + _skill_commands = scan_skill_commands() + return _skill_commands + + +def get_skill_commands() -> dict: + return _ensure_skill_commands() + + +def build_skill_invocation_message(*args, **kwargs): + from agent.skill_commands import build_skill_invocation_message as _impl + + return _impl(*args, **kwargs) + + +def build_preloaded_skills_prompt(*args, **kwargs): + from agent.skill_commands import build_preloaded_skills_prompt as _impl + + return _impl(*args, **kwargs) + + +def get_skill_bundles() -> dict: + global _skill_bundles + if _skill_bundles is None: + from agent.skill_bundles import get_skill_bundles as _impl + + _skill_bundles = _impl() + return _skill_bundles + + +def build_bundle_invocation_message(*args, **kwargs): + from agent.skill_bundles import build_bundle_invocation_message as _impl + + return _impl(*args, **kwargs) def _get_plugin_cmd_handler_names() -> set: @@ -2615,7 +2812,7 @@ class HermesCLI: api_key: str = None, base_url: str = None, max_turns: int = None, - verbose: bool = False, + verbose: Optional[bool] = None, compact: bool = False, resume: str = None, checkpoints: bool = False, @@ -2666,7 +2863,12 @@ class HermesCLI: else: self.busy_input_mode = "interrupt" - self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose") + # self.verbose ONLY controls global DEBUG logging (root logger level). + # display.tool_progress="verbose" controls tool-call rendering (full args, + # results, think blocks) and is independent โ€” see _apply_logging_levels. + # Coupling the two (PR #6a1aa420e) caused all module DEBUG logs to spew + # to console whenever a user set tool_progress: verbose in config. + self.verbose = bool(verbose) if verbose is not None else False # streaming: stream tokens to the terminal as they arrive (display.streaming in config.yaml) self.streaming_enabled = CLI_CONFIG["display"].get("streaming", False) @@ -2852,12 +3054,9 @@ class HermesCLI: pass # Fallback provider chain โ€” tried in order when primary fails after retries. - # Supports new list format (fallback_providers) and legacy single-dict (fallback_model). - fb = CLI_CONFIG.get("fallback_providers") or CLI_CONFIG.get("fallback_model") or [] - # Normalize legacy single-dict to a one-element list - if isinstance(fb, dict): - fb = [fb] if fb.get("provider") and fb.get("model") else [] - self._fallback_model = fb + # Merge new ``fallback_providers`` entries with any legacy + # ``fallback_model`` entries so old configs still participate. + self._fallback_model = get_fallback_chain(CLI_CONFIG) # Signature of the currently-initialised agent's runtime. Used to # rebuild the agent when provider / model / base_url changes across @@ -2865,7 +3064,9 @@ class HermesCLI: self._active_agent_route_signature = None # Agent will be initialized on first use - self.agent: Optional[AIAgent] = None + self.agent: Optional[Any] = None + self._tool_callbacks_installed = False + self._tirith_security_checked = False self._app = None # prompt_toolkit Application (set in run()) # Conversation state @@ -4488,6 +4689,41 @@ class HermesCLI: route["request_overrides"] = overrides return route + def _install_tool_callbacks(self) -> None: + """Install tool callbacks that need the live prompt UI.""" + if getattr(self, "_tool_callbacks_installed", False): + return + set_sudo_password_callback(self._sudo_password_callback) + set_approval_callback(self._approval_callback) + set_secret_capture_callback(self._secret_capture_callback) + try: + from tools.computer_use_tool import set_approval_callback as _set_cu_cb + + _set_cu_cb(self._computer_use_approval_callback) + except ImportError: + pass + self._tool_callbacks_installed = True + + def _ensure_tirith_security(self) -> None: + """Check tirith availability once before tools can run terminal commands.""" + if getattr(self, "_tirith_security_checked", False): + return + self._tirith_security_checked = True + try: + from tools.tirith_security import ensure_installed, is_platform_supported + + tirith_path = ensure_installed(log_failures=False) + if tirith_path is None and is_platform_supported(): + security_cfg = self.config.get("security", {}) or {} + tirith_enabled = security_cfg.get("tirith_enabled", True) + if tirith_enabled: + _cprint( + f" {_DIM}โš  tirith security scanner enabled but not available " + f"โ€” command scanning will use pattern matching only{_RST}" + ) + except Exception: + pass + def _init_agent(self, *, model_override: str = None, runtime_override: dict = None, request_overrides: dict | None = None) -> bool: """ Initialize the agent on first use. @@ -4499,6 +4735,10 @@ class HermesCLI: if self.agent is not None: return True + _prepare_deferred_agent_startup() + self._install_tool_callbacks() + self._ensure_tirith_security() + if not self._ensure_runtime_credentials(): return False @@ -4713,8 +4953,10 @@ class HermesCLI: context_length=ctx_len, ) - # Show tool availability warnings if any tools are disabled - self._show_tool_availability_warnings() + # Tool discovery is intentionally deferred on the Termux bare prompt + # path; availability warnings are shown once tools are initialized. + if os.environ.get("HERMES_DEFER_AGENT_STARTUP") != "1": + self._show_tool_availability_warnings() # Warn about very low context lengths (common with local servers) if ctx_len and ctx_len <= 8192: @@ -4852,10 +5094,13 @@ class HermesCLI: if self.resume_display == "minimal": return - MAX_DISPLAY_EXCHANGES = 10 # max user+assistant pairs to show - MAX_USER_LEN = 300 # truncate user messages - MAX_ASST_LEN = 200 # truncate assistant text - MAX_ASST_LINES = 3 # max lines of assistant text + # Read limits from config (with hardcoded defaults) + _disp = CLI_CONFIG.get("display", {}) + MAX_DISPLAY_EXCHANGES = int(_disp.get("resume_exchanges", 10)) + MAX_USER_LEN = int(_disp.get("resume_max_user_chars", 300)) + MAX_ASST_LEN = int(_disp.get("resume_max_assistant_chars", 200)) + MAX_ASST_LINES = int(_disp.get("resume_max_assistant_lines", 3)) + SKIP_TOOL_ONLY = _disp.get("resume_skip_tool_only", True) # Collect displayable entries (skip system, tool-result messages) entries = [] # list of (role, display_text) @@ -4918,6 +5163,10 @@ class HermesCLI: if not parts: # Skip pure-reasoning messages that have no visible output continue + # Skip tool-call-only entries when SKIP_TOOL_ONLY is enabled + has_text = bool(text) + if SKIP_TOOL_ONLY and not has_text and tool_calls: + continue entries.append(("assistant", " ".join(parts))) _last_asst_idx = len(entries) - 1 _last_asst_full = " ".join(full_parts) @@ -5491,9 +5740,13 @@ class HermesCLI: def _show_status(self): """Show compact startup status line.""" - # Get tool count - tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) - tool_count = len(tools) if tools else 0 + # Avoid pulling the full tool registry into the bare Termux prompt path. + if os.environ.get("HERMES_DEFER_AGENT_STARTUP") == "1": + tool_status = "tools deferred" + else: + tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) + tool_count = len(tools) if tools else 0 + tool_status = f"{tool_count} tools" # Format model name (shorten if needed) model_short = self.model.split("/")[-1] if "/" in self.model else self.model @@ -5525,7 +5778,7 @@ class HermesCLI: self._console_print( f" {api_indicator} [{accent_color}]{model_short}[/] " - f"[dim {separator_color}]ยท[/] [bold {label_color}]{tool_count} tools[/]" + f"[dim {separator_color}]ยท[/] [bold {label_color}]{tool_status}[/]" f"{toolsets_info}{provider_info}" ) @@ -5638,9 +5891,10 @@ class HermesCLI: continue ChatConsole().print(f" [bold {_accent_hex()}]{cmd:<15}[/] [dim]-[/] {_escape(desc)}") - if _skill_commands: - _cprint(f"\n โšก {_BOLD}Skill Commands{_RST} ({len(_skill_commands)} installed):") - for cmd, info in sorted(_skill_commands.items()): + skill_commands = _ensure_skill_commands() + if skill_commands: + _cprint(f"\n โšก {_BOLD}Skill Commands{_RST} ({len(skill_commands)} installed):") + for cmd, info in sorted(skill_commands.items()): ChatConsole().print( f" [bold {_accent_hex()}]{cmd:<22}[/] [dim]-[/] {_escape(info['description'])}" ) @@ -5918,15 +6172,16 @@ class HermesCLI: else: print(" Recent sessions:") print() - print(f" {'Title':<32} {'Preview':<40} {'Last Active':<13} {'ID'}") - print(f" {'โ”€' * 32} {'โ”€' * 40} {'โ”€' * 13} {'โ”€' * 24}") - for session in sessions: - title = (session.get("title") or "โ€”")[:30] + print(f" {'#':<3} {'Title':<32} {'Preview':<40} {'Last Active':<13} {'ID'}") + print(f" {'โ”€' * 3} {'โ”€' * 32} {'โ”€' * 40} {'โ”€' * 13} {'โ”€' * 24}") + for idx, session in enumerate(sessions, start=1): + title = session.get("title") or "โ€”" preview = (session.get("preview") or "")[:38] last_active = _relative_time(session.get("last_active")) - print(f" {title:<32} {preview:<40} {last_active:<13} {session['id']}") + print(f" {idx:<3} {title:<32} {preview:<40} {last_active:<13} {session['id']}") print() - print(" Use /resume to continue where you left off.") + print(" Use /resume , /resume , or /resume to continue.") + print(" Example: /resume 2") print() return True @@ -6037,6 +6292,7 @@ class HermesCLI: self.conversation_history = [] self._pending_title = None self._resumed = False + _sync_process_session_id(self.session_id) if self.agent: self.agent.session_id = self.session_id @@ -6270,7 +6526,7 @@ class HermesCLI: target = parts[1].strip() if len(parts) > 1 else "" if not target: - _cprint(" Usage: /resume ") + _cprint(" Usage: /resume ") if self._show_recent_sessions(reason="resume"): return _cprint(" Tip: Use /history or `hermes sessions list` to find sessions.") @@ -6281,10 +6537,20 @@ class HermesCLI: _cprint(f" {format_session_db_unavailable()}") return - # Resolve title or ID - from hermes_cli.main import _resolve_session_by_name_or_id - resolved = _resolve_session_by_name_or_id(target) - target_id = resolved or target + # Resolve numbered selection, title, or ID + if target.isdigit(): + sessions = self._list_recent_sessions(limit=10) + index = int(target) + if index < 1 or index > len(sessions): + _cprint(f" Resume index {index} is out of range.") + _cprint(" Use /resume with no arguments to see available sessions.") + return + selected = sessions[index - 1] + target_id = selected["id"] + else: + from hermes_cli.main import _resolve_session_by_name_or_id + resolved = _resolve_session_by_name_or_id(target) + target_id = resolved or target session_meta = self._session_db.get_session(target_id) if not session_meta: @@ -6323,6 +6589,7 @@ class HermesCLI: self.session_id = target_id self._resumed = True self._pending_title = None + _sync_process_session_id(target_id) # Load conversation history (strip transcript-only metadata entries) restored = self._session_db.get_messages_as_conversation(target_id) @@ -6374,6 +6641,7 @@ class HermesCLI: f" ({msg_count} user message{'s' if msg_count != 1 else ''}," f" {len(self.conversation_history)} total)" ) + self._display_resumed_history() else: _cprint(f" โ†ป Resumed session {target_id}{title_part} โ€” no messages, starting fresh.") @@ -6496,6 +6764,7 @@ class HermesCLI: self.session_start = now self._pending_title = None self._resumed = True # Prevents auto-title generation + _sync_process_session_id(new_session_id) # Sync the agent if self.agent: @@ -7857,6 +8126,7 @@ class HermesCLI: "clear", "This clears the screen and starts a new session.\n" "The current conversation history will be discarded.", + cmd_original=cmd_original, ) is None: return self.new_session(silent=True) @@ -7981,12 +8251,16 @@ class HermesCLI: if not self._handle_handoff_command(cmd_original): return False elif canonical == "new": - parts = cmd_original.split(maxsplit=1) - title = parts[1].strip() if len(parts) > 1 else None + # Strip inline-skip tokens (now/--yes/-y) before deriving the title + # so "/new now My Session" yields title="My Session" instead of + # title="now My Session". See _split_destructive_skip. + _new_args, _ = self._split_destructive_skip(cmd_original) + title = _new_args.strip() or None if self._confirm_destructive_slash( "new", "This starts a fresh session.\n" "The current conversation history will be discarded.", + cmd_original=cmd_original, ) is None: return self.new_session(title=title) @@ -8013,6 +8287,7 @@ class HermesCLI: if self._confirm_destructive_slash( "undo", "This removes the last user/assistant exchange from history.", + cmd_original=cmd_original, ) is None: return self.undo_last() @@ -8161,6 +8436,8 @@ class HermesCLI: else: # Check for user-defined quick commands (bypass agent loop, no LLM call) base_cmd = cmd_lower.split()[0] + skill_commands = _ensure_skill_commands() + skill_bundles = get_skill_bundles() quick_commands = self.config.get("quick_commands", {}) if base_cmd.lstrip("/") in quick_commands: qcmd = quick_commands[base_cmd.lstrip("/")] @@ -8216,14 +8493,14 @@ class HermesCLI: _cprint(f"\033[1;31mPlugin command error: {e}{_RST}") # Skill bundles take precedence over individual skills โ€” / # loads multiple skills at once. Rescans cheaply when files change. - elif base_cmd in get_skill_bundles(): + elif base_cmd in skill_bundles: user_instruction = cmd_original[len(base_cmd):].strip() bundle_result = build_bundle_invocation_message( base_cmd, user_instruction, task_id=self.session_id ) if bundle_result: msg, loaded_names, missing = bundle_result - bundle_info = get_skill_bundles()[base_cmd] + bundle_info = skill_bundles[base_cmd] print( f"\nโšก Loading bundle: {bundle_info['name']} " f"({len(loaded_names)} skills)" @@ -8239,13 +8516,13 @@ class HermesCLI: f"[bold red]Failed to load bundle for {base_cmd}[/]" ) # Check for skill slash commands (/gif-search, /axolotl, etc.) - elif base_cmd in _skill_commands: + elif base_cmd in skill_commands: user_instruction = cmd_original[len(base_cmd):].strip() msg = build_skill_invocation_message( base_cmd, user_instruction, task_id=self.session_id ) if msg: - skill_name = _skill_commands[base_cmd]["name"] + skill_name = skill_commands[base_cmd]["name"] print(f"\nโšก Loading skill: {skill_name}") if hasattr(self, '_pending_input'): self._pending_input.put(msg) @@ -8257,7 +8534,7 @@ class HermesCLI: # that execution-time resolution agrees with tab-completion. from hermes_cli.commands import COMMANDS typed_base = cmd_lower.split()[0] - all_known = set(COMMANDS) | set(_skill_commands) | set(get_skill_bundles()) + all_known = set(COMMANDS) | set(skill_commands) | set(skill_bundles) matches = [c for c in all_known if c.startswith(typed_base)] if len(matches) > 1: # Prefer an exact match (typed the full command name) @@ -9088,18 +9365,23 @@ class HermesCLI: _cprint(" Failed to save runtime_footer setting to config.yaml") def _toggle_verbose(self): - """Cycle tool progress mode: off โ†’ new โ†’ all โ†’ verbose โ†’ off.""" + """Cycle tool progress mode: off โ†’ new โ†’ all โ†’ verbose โ†’ off. + + Tool-progress display (full args / results / think blocks at the + ``verbose`` step) is INDEPENDENT of global DEBUG logging. Cycling + through here does not change ``self.verbose`` or the agent's + ``verbose_logging`` / ``quiet_mode`` โ€” those remain under the + explicit ``-v``/``--verbose`` flag and the ``/verbose-logging`` + toggle. See PR #6a1aa420e for the history that decoupled them. + """ cycle = ["off", "new", "all", "verbose"] try: idx = cycle.index(self.tool_progress_mode) except ValueError: idx = 2 # default to "all" self.tool_progress_mode = cycle[(idx + 1) % len(cycle)] - self.verbose = self.tool_progress_mode == "verbose" if self.agent: - self.agent.verbose_logging = self.verbose - self.agent.quiet_mode = not self.verbose self.agent.reasoning_callback = self._current_reasoning_callback() # Use raw ANSI codes via _cprint so the output is routed through @@ -9111,7 +9393,7 @@ class HermesCLI: "off": f"{_Colors.DIM}Tool progress: OFF{_Colors.RESET} โ€” silent mode, just the final response.", "new": f"{_Colors.YELLOW}Tool progress: NEW{_Colors.RESET} โ€” show each new tool (skip repeats).", "all": f"{_Colors.GREEN}Tool progress: ALL{_Colors.RESET} โ€” show every tool call.", - "verbose": f"{_Colors.BOLD}{_Colors.GREEN}Tool progress: VERBOSE{_Colors.RESET} โ€” full args, results, think blocks, and debug logs.", + "verbose": f"{_Colors.BOLD}{_Colors.GREEN}Tool progress: VERBOSE{_Colors.RESET} โ€” full args, results, and think blocks.", } _cprint(labels.get(self.tool_progress_mode, "")) @@ -9657,7 +9939,49 @@ class HermesCLI: if _reload_thread.is_alive(): print(" โš ๏ธ MCP reload timed out (30s). Some servers may not have reconnected.") - def _confirm_destructive_slash(self, command: str, detail: str) -> Optional[str]: + # Inline-skip tokens that bypass the destructive-slash confirmation modal. + # Matches the escape-hatch pattern users on broken modal platforms + # (currently native Windows PowerShell โ€” issue #30768) need to self-serve + # without having to flip approvals.destructive_slash_confirm in config. + _DESTRUCTIVE_SKIP_TOKENS = frozenset({"now", "--yes", "-y"}) + + @classmethod + def _split_destructive_skip(cls, cmd_text: Optional[str]) -> tuple[str, bool]: + """Split inline-skip tokens out of a destructive slash command. + + Returns ``(remainder, skip)`` where ``remainder`` is the original + text with the command word and any recognized skip tokens removed, + and ``skip`` is True iff at least one skip token was found. + + Examples: + "/reset now" -> ("", True) + "/reset --yes My title" -> ("My title", True) + "/new My title" -> ("My title", False) + "/clear" -> ("", False) + """ + if not cmd_text: + return "", False + tokens = cmd_text.strip().split() + if not tokens: + return "", False + # Drop leading "/cmd" word โ€” callers pass the full command text. + if tokens[0].startswith("/"): + tokens = tokens[1:] + skip = False + kept: list[str] = [] + for tok in tokens: + if tok.lower() in cls._DESTRUCTIVE_SKIP_TOKENS: + skip = True + continue + kept.append(tok) + return " ".join(kept), skip + + def _confirm_destructive_slash( + self, + command: str, + detail: str, + cmd_original: Optional[str] = None, + ) -> Optional[str]: """Prompt the user to confirm a destructive session slash command. Used by ``/clear``, ``/new``/``/reset``, and ``/undo`` before they @@ -9673,9 +9997,24 @@ class HermesCLI: gate is off the function returns ``"once"`` immediately without prompting. + Inline-skip: if ``cmd_original`` contains ``now``, ``--yes``, or + ``-y`` as an argument (e.g. ``/reset now``, ``/new --yes My title``), + the modal is bypassed and ``"once"`` is returned immediately. This is + an escape hatch for platforms where the prompt_toolkit modal hangs + (issue #30768 โ€” native Windows PowerShell). Callers are responsible + for stripping the skip tokens from any remaining argument parsing + (see :meth:`_split_destructive_skip`). + Returns ``"once"``, ``"always"``, or ``None`` (cancelled). Callers proceed with the destructive action when the result is non-None. """ + # Inline-skip escape hatch โ€” works regardless of platform/modal state. + # See class-level _DESTRUCTIVE_SKIP_TOKENS for the accepted tokens. + if cmd_original: + _, _skip = self._split_destructive_skip(cmd_original) + if _skip: + return "once" + # Gate check โ€” respects prior "Always Approve" clicks. try: cfg = load_cli_config() @@ -10010,9 +10349,7 @@ class HermesCLI: self._last_scrollback_tool = function_name try: from agent.display import get_cute_tool_message - line = get_cute_tool_message(function_name, stored_args, duration) - if is_error: - line = f"{line} [error]" + line = get_cute_tool_message(function_name, stored_args, duration, result=kwargs.get("result")) _cprint(f" {line}") except Exception: pass @@ -12023,37 +12360,11 @@ class HermesCLI: self._voice_tts_done = threading.Event() # Signals TTS playback finished self._voice_tts_done.set() # Initially "done" (no TTS pending) - # Register callbacks so terminal_tool prompts route through our UI - set_sudo_password_callback(self._sudo_password_callback) - set_approval_callback(self._approval_callback) - set_secret_capture_callback(self._secret_capture_callback) + if os.environ.get("HERMES_DEFER_AGENT_STARTUP") != "1": + self._install_tool_callbacks() - # Computer-use shares the same approval UI (prompt_toolkit dialog). - # The tool handler expects a 3-arg callback (action, args, summary) - # and returns "approve_once" | "approve_session" | "always_approve" - # | "deny". Adapt our existing generic callback. - try: - from tools.computer_use_tool import set_approval_callback as _set_cu_cb - _set_cu_cb(self._computer_use_approval_callback) - except ImportError: - pass # computer_use extras not installed - - # Ensure tirith security scanner is available (downloads if needed). - # Warn the user if tirith is enabled in config but not available, - # so they know command security scanning is degraded. Suppressed - # on platforms where tirith ships no binary (Windows etc.) โ€” the - # user can't act on it and pattern-matching guards still run. - try: - from tools.tirith_security import ensure_installed, is_platform_supported - tirith_path = ensure_installed(log_failures=False) - if tirith_path is None and is_platform_supported(): - security_cfg = self.config.get("security", {}) or {} - tirith_enabled = security_cfg.get("tirith_enabled", True) - if tirith_enabled: - _cprint(f" {_DIM}โš  tirith security scanner enabled but not available " - f"โ€” command scanning will use pattern matching only{_RST}") - except Exception: - pass # Non-fatal โ€” fail-open at scan time if unavailable + if os.environ.get("HERMES_DEFER_AGENT_STARTUP") != "1": + self._ensure_tirith_security() # Key bindings for the input area kb = KeyBindings() @@ -14211,7 +14522,7 @@ def main( api_key: str = None, base_url: str = None, max_turns: int = None, - verbose: bool = False, + verbose: Optional[bool] = None, quiet: bool = False, compact: bool = False, list_tools: bool = False, @@ -14557,4 +14868,6 @@ def main( if __name__ == "__main__": + import fire + fire.Fire(main) diff --git a/cron/scheduler.py b/cron/scheduler.py index e76f67064cf..6b511d38b77 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -529,7 +529,9 @@ def _send_media_via_adapter( """ from pathlib import Path - from gateway.platforms.base import should_send_media_as_audio + from gateway.platforms.base import BasePlatformAdapter, should_send_media_as_audio + + media_files = BasePlatformAdapter.filter_media_delivery_paths(media_files) for media_path, _is_voice in media_files: try: @@ -614,6 +616,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option # Extract MEDIA: tags so attachments are forwarded as files, not raw text from gateway.platforms.base import BasePlatformAdapter media_files, cleaned_delivery_content = BasePlatformAdapter.extract_media(delivery_content) + media_files = BasePlatformAdapter.filter_media_delivery_paths(media_files) try: config = load_gateway_config() diff --git a/gateway/config.py b/gateway/config.py index 83326975249..ec98532a273 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -424,7 +424,9 @@ _PLATFORM_CONNECTED_CHECKERS: dict[Platform, Callable[[PlatformConfig], bool]] = Platform.SMS: lambda cfg: bool(os.getenv("TWILIO_ACCOUNT_SID")), Platform.API_SERVER: lambda cfg: True, Platform.WEBHOOK: lambda cfg: True, - Platform.MSGRAPH_WEBHOOK: lambda cfg: True, + Platform.MSGRAPH_WEBHOOK: lambda cfg: bool( + str(cfg.extra.get("client_state") or "").strip() + ), Platform.FEISHU: lambda cfg: bool(cfg.extra.get("app_id")), Platform.WECOM: lambda cfg: bool(cfg.extra.get("bot_id")), Platform.WECOM_CALLBACK: lambda cfg: bool( @@ -926,73 +928,6 @@ def load_gateway_config() -> GatewayConfig: ac = ",".join(str(v) for v in ac) os.environ["SLACK_ALLOWED_CHANNELS"] = str(ac) - # Discord settings โ†’ env vars (env vars take precedence) - discord_cfg = yaml_cfg.get("discord", {}) - if isinstance(discord_cfg, dict): - if "require_mention" in discord_cfg and not os.getenv("DISCORD_REQUIRE_MENTION"): - os.environ["DISCORD_REQUIRE_MENTION"] = str(discord_cfg["require_mention"]).lower() - if "thread_require_mention" in discord_cfg and not os.getenv("DISCORD_THREAD_REQUIRE_MENTION"): - os.environ["DISCORD_THREAD_REQUIRE_MENTION"] = str(discord_cfg["thread_require_mention"]).lower() - frc = discord_cfg.get("free_response_channels") - if frc is not None and not os.getenv("DISCORD_FREE_RESPONSE_CHANNELS"): - if isinstance(frc, list): - frc = ",".join(str(v) for v in frc) - os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc) - if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"): - os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower() - if "reactions" in discord_cfg and not os.getenv("DISCORD_REACTIONS"): - os.environ["DISCORD_REACTIONS"] = str(discord_cfg["reactions"]).lower() - # ignored_channels: channels where bot never responds (even when mentioned) - ic = discord_cfg.get("ignored_channels") - if ic is not None and not os.getenv("DISCORD_IGNORED_CHANNELS"): - if isinstance(ic, list): - ic = ",".join(str(v) for v in ic) - os.environ["DISCORD_IGNORED_CHANNELS"] = str(ic) - # allowed_channels: if set, bot ONLY responds in these channels (whitelist) - ac = discord_cfg.get("allowed_channels") - if ac is not None and not os.getenv("DISCORD_ALLOWED_CHANNELS"): - if isinstance(ac, list): - ac = ",".join(str(v) for v in ac) - os.environ["DISCORD_ALLOWED_CHANNELS"] = str(ac) - # no_thread_channels: channels where bot responds directly without creating thread - ntc = discord_cfg.get("no_thread_channels") - if ntc is not None and not os.getenv("DISCORD_NO_THREAD_CHANNELS"): - if isinstance(ntc, list): - ntc = ",".join(str(v) for v in ntc) - os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc) - # history_backfill: recover missed channel messages for shared sessions - # when require_mention is active. Fetches messages between bot turns - # and prepends them to the user message for context. - if "history_backfill" in discord_cfg and not os.getenv("DISCORD_HISTORY_BACKFILL"): - os.environ["DISCORD_HISTORY_BACKFILL"] = str(discord_cfg["history_backfill"]).lower() - hbl = discord_cfg.get("history_backfill_limit") - if hbl is not None and not os.getenv("DISCORD_HISTORY_BACKFILL_LIMIT"): - os.environ["DISCORD_HISTORY_BACKFILL_LIMIT"] = str(hbl) - # allow_mentions: granular control over what the bot can ping. - # Safe defaults (no @everyone/roles) are applied in the adapter; - # these YAML keys only override when set and let users opt back - # into unsafe modes (e.g. roles=true) if they actually want it. - allow_mentions_cfg = discord_cfg.get("allow_mentions") - if isinstance(allow_mentions_cfg, dict): - for yaml_key, env_key in ( - ("everyone", "DISCORD_ALLOW_MENTION_EVERYONE"), - ("roles", "DISCORD_ALLOW_MENTION_ROLES"), - ("users", "DISCORD_ALLOW_MENTION_USERS"), - ("replied_user", "DISCORD_ALLOW_MENTION_REPLIED_USER"), - ): - if yaml_key in allow_mentions_cfg and not os.getenv(env_key): - os.environ[env_key] = str(allow_mentions_cfg[yaml_key]).lower() - # reply_to_mode: top-level preferred, falls back to extra.reply_to_mode - # YAML 1.1 parses bare 'off' as boolean False โ€” coerce to string "off". - _discord_extra = discord_cfg.get("extra") if isinstance(discord_cfg.get("extra"), dict) else {} - _discord_rtm = ( - discord_cfg["reply_to_mode"] if "reply_to_mode" in discord_cfg - else _discord_extra.get("reply_to_mode") - ) - if _discord_rtm is not None and not os.getenv("DISCORD_REPLY_TO_MODE"): - _rtm_str = "off" if _discord_rtm is False else str(_discord_rtm).lower() - os.environ["DISCORD_REPLY_TO_MODE"] = _rtm_str - # Bridge top-level require_mention to Telegram when the telegram: section # does not already provide one. Users often write "require_mention: true" # at the top level alongside group_sessions_per_user, expecting it to work @@ -1878,6 +1813,17 @@ def _apply_env_overrides(config: GatewayConfig) -> None: # need to seed ``PlatformConfig.extra`` from env vars (e.g. Google Chat's # project_id / subscription_name) can supply ``env_enablement_fn`` on # their PlatformEntry โ€” called here BEFORE adapter construction. + # + # Enablement gate (#31116): when a plugin registers ``is_connected`` + # (the "has the user actually configured credentials for this?" check), + # we MUST consult it before flipping ``enabled = True``. Otherwise + # ``check_fn`` alone โ€” which for adapter plugins typically just + # verifies the SDK is importable / lazy-installs it โ€” silently enables + # platforms the user never opted into, and the gateway then tries to + # connect to Discord / Teams / Google Chat with no token and emits + # noisy retry-forever errors. ``_platform_status`` was already fixed + # for the same bug class in commit 7849a3d73; this is the runtime + # counterpart. try: from hermes_cli.plugins import discover_plugins discover_plugins() # idempotent @@ -1890,34 +1836,99 @@ def _apply_env_overrides(config: GatewayConfig) -> None: logger.debug("check_fn for %s raised: %s", entry.name, e) continue platform = Platform(entry.name) - if platform not in config.platforms: - config.platforms[platform] = PlatformConfig() - config.platforms[platform].enabled = True - # Seed extras from env if the plugin opted in. + existing_cfg = config.platforms.get(platform) + # Seed candidate extras from ``env_enablement_fn`` so plugins + # whose ``is_connected`` reads ``config.extra`` (e.g. Google + # Chat's ``_is_connected`` checks ``config.extra["project_id"]``) + # see the same state they will after enablement. Without this, + # Google-Chat-on-env-vars-only setups silently fail the gate + # below even though the user is configured. Plugins whose + # ``is_connected`` reads env vars directly (Discord, IRC, + # Teams, LINE, ntfy, Simplex) are unaffected; this only + # restores Google Chat. + seed_for_probe = None if entry.env_enablement_fn is not None: try: - seed = entry.env_enablement_fn() + seed_for_probe = entry.env_enablement_fn() except Exception as e: logger.debug( "env_enablement_fn for %s raised: %s", entry.name, e ) - seed = None - if isinstance(seed, dict) and seed: - # Extract the home_channel dict (if provided) so we wire it - # up as a proper HomeChannel dataclass. Everything else is - # merged into ``extra``. - home = seed.pop("home_channel", None) - config.platforms[platform].extra.update(seed) - if isinstance(home, dict) and home.get("chat_id"): - config.platforms[platform].home_channel = HomeChannel( - platform=platform, - chat_id=str(home["chat_id"]), - name=str(home.get("name") or "Home"), - thread_id=( - str(home["thread_id"]) - if home.get("thread_id") - else None - ), + seed_for_probe = None + + # Only consult is_connected for platforms that are NOT already + # explicitly configured in YAML / env (existing_cfg with + # enabled=True means the user wrote it themselves or another + # env-var bridge enabled it โ€” keep that decision). + if existing_cfg is None or not existing_cfg.enabled: + if entry.is_connected is not None: + try: + # Probe with ``enabled=True`` since we're asking + # "would this plugin BE configured if we enabled + # it?" not "is it currently enabled?". Google + # Chat's ``_is_connected`` short-circuits on + # ``config.enabled`` being False, which on the + # default ``PlatformConfig()`` would fail the + # gate even with proper env vars set. + if existing_cfg is not None: + probe_cfg = existing_cfg + if not probe_cfg.enabled: + probe_cfg = PlatformConfig( + enabled=True, + extra=dict(probe_cfg.extra or {}), + ) + else: + probe_cfg = PlatformConfig(enabled=True) + if isinstance(seed_for_probe, dict) and seed_for_probe: + # Don't mutate ``existing_cfg``; the probe gets + # a transient view with env-seeded extras layered + # on top of whatever's already there. + probe_extra = dict(getattr(probe_cfg, "extra", {}) or {}) + for k, v in seed_for_probe.items(): + if k == "home_channel": + continue + probe_extra.setdefault(k, v) + probe_cfg = PlatformConfig( + enabled=True, + extra=probe_extra, + ) + configured = bool(entry.is_connected(probe_cfg)) + except Exception as exc: + logger.debug( + "is_connected for %s raised: %s โ€” skipping enablement", + entry.name, exc, ) + configured = False + if not configured: + logger.debug( + "Plugin platform '%s' available but not configured " + "(is_connected returned False) โ€” skipping enable", + entry.name, + ) + continue + if platform not in config.platforms: + config.platforms[platform] = PlatformConfig() + config.platforms[platform].enabled = True + # Commit env-seeded extras onto the now-enabled platform. + # We've already called ``env_enablement_fn`` above (for the + # probe); reuse that result instead of calling it twice. + if isinstance(seed_for_probe, dict) and seed_for_probe: + seed = dict(seed_for_probe) + # Extract the home_channel dict (if provided) so we wire it + # up as a proper HomeChannel dataclass. Everything else is + # merged into ``extra``. + home = seed.pop("home_channel", None) + config.platforms[platform].extra.update(seed) + if isinstance(home, dict) and home.get("chat_id"): + config.platforms[platform].home_channel = HomeChannel( + platform=platform, + chat_id=str(home["chat_id"]), + name=str(home.get("name") or "Home"), + thread_id=( + str(home["thread_id"]) + if home.get("thread_id") + else None + ), + ) except Exception as e: logger.debug("Plugin platform enable pass failed: %s", e) diff --git a/gateway/pairing.py b/gateway/pairing.py index af9ff2fdbfd..b8bfe46a9a8 100644 --- a/gateway/pairing.py +++ b/gateway/pairing.py @@ -18,6 +18,7 @@ Security features (based on OWASP + NIST SP 800-63-4 guidance): Storage: ~/.hermes/pairing/ """ +import hashlib import json import os import secrets @@ -27,6 +28,10 @@ import time from pathlib import Path from typing import Optional +from gateway.whatsapp_identity import ( + expand_whatsapp_aliases, + normalize_whatsapp_identifier, +) from hermes_constants import get_hermes_dir from utils import atomic_replace @@ -109,12 +114,40 @@ class PairingStore: def _save_json(self, path: Path, data: dict) -> None: _secure_write(path, json.dumps(data, indent=2, ensure_ascii=False)) + def _normalize_user_id(self, platform: str, user_id: str) -> str: + """Normalize platform-specific user IDs before persisting them.""" + raw_user_id = str(user_id or "").strip() + if platform == "whatsapp": + return normalize_whatsapp_identifier(raw_user_id) or raw_user_id + return raw_user_id + + def _user_id_aliases(self, platform: str, user_id: str) -> set[str]: + """Return all known equivalent user IDs for auth/rate-limit checks.""" + raw_user_id = str(user_id or "").strip() + if not raw_user_id: + return set() + + aliases = {raw_user_id, self._normalize_user_id(platform, raw_user_id)} + if platform == "whatsapp": + aliases.update(expand_whatsapp_aliases(raw_user_id)) + aliases.discard("") + return aliases + + def _user_ids_match(self, platform: str, left: str, right: str) -> bool: + """Return True when two user IDs represent the same principal.""" + left_aliases = self._user_id_aliases(platform, left) + right_aliases = self._user_id_aliases(platform, right) + return bool(left_aliases and right_aliases and (left_aliases & right_aliases)) + # ----- Approved users ----- def is_approved(self, platform: str, user_id: str) -> bool: """Check if a user is approved (paired) on a platform.""" approved = self._load_json(self._approved_path(platform)) - return user_id in approved + for approved_user_id in approved: + if self._user_ids_match(platform, approved_user_id, user_id): + return True + return False def list_approved(self, platform: str = None) -> list: """List approved users, optionally filtered by platform.""" @@ -129,7 +162,16 @@ class PairingStore: def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None: """Add a user to the approved list. Must be called under self._lock.""" approved = self._load_json(self._approved_path(platform)) - approved[user_id] = { + normalized_user_id = self._normalize_user_id(platform, user_id) + duplicate_ids = [ + approved_user_id + for approved_user_id in approved + if self._user_ids_match(platform, approved_user_id, normalized_user_id) + ] + for approved_user_id in duplicate_ids: + del approved[approved_user_id] + + approved[normalized_user_id] = { "user_name": user_name, "approved_at": time.time(), } @@ -140,14 +182,25 @@ class PairingStore: path = self._approved_path(platform) with self._lock: approved = self._load_json(path) - if user_id in approved: - del approved[user_id] + matching_ids = [ + approved_user_id + for approved_user_id in approved + if self._user_ids_match(platform, approved_user_id, user_id) + ] + if matching_ids: + for approved_user_id in matching_ids: + del approved[approved_user_id] self._save_json(path, approved) return True return False # ----- Pending codes ----- + @staticmethod + def _hash_code(code: str, salt: bytes) -> str: + """Hash a pairing code with the given salt using SHA-256.""" + return hashlib.sha256(salt + code.encode("utf-8")).hexdigest() + def generate_code( self, platform: str, user_id: str, user_name: str = "" ) -> Optional[str]: @@ -158,9 +211,13 @@ class PairingStore: - User is rate-limited (too recent request) - Max pending codes reached for this platform - User/platform is in lockout due to failed attempts + + The code is NOT stored in plaintext. Only a salted SHA-256 hash is + persisted so that reading the pending file does not reveal codes. """ with self._lock: self._cleanup_expired(platform) + normalized_user_id = self._normalize_user_id(platform, user_id) # Check lockout if self._is_locked_out(platform): @@ -178,9 +235,18 @@ class PairingStore: # Generate cryptographically random code code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH)) - # Store pending request - pending[code] = { - "user_id": user_id, + # Hash the code with a random salt before storing + salt = os.urandom(16) + code_hash = self._hash_code(code, salt) + + # Use a unique entry id as the key (not the code itself) + entry_id = secrets.token_hex(8) + + # Store pending request with hashed code + pending[entry_id] = { + "hash": code_hash, + "salt": salt.hex(), + "user_id": normalized_user_id, "user_name": user_name, "created_at": time.time(), } @@ -195,10 +261,16 @@ class PairingStore: """ Approve a pairing code. Adds the user to the approved list. - Returns {user_id, user_name} on success, None if code is + Returns ``{user_id, user_name}`` on success, ``None`` if the code is invalid/expired OR the platform is currently locked out after ``MAX_FAILED_ATTEMPTS`` failed approvals (#10195). Callers can disambiguate with ``_is_locked_out(platform)``. + + Verification: the user-provided code is hashed with each stored + entry's salt and compared to the stored hash using constant-time + comparison. Pre-hash entries (legacy plaintext-key format from + pre-upgrade pending.json files) are silently ignored โ€” they get + pruned at TTL by ``_cleanup_expired``. """ with self._lock: self._cleanup_expired(platform) @@ -213,37 +285,77 @@ class PairingStore: return None pending = self._load_json(self._pending_path(platform)) - if code not in pending: + + # Find the entry whose hash matches the provided code. + # Tolerate legacy plaintext-key entries (no salt/hash) and + # malformed entries โ€” skip them rather than KeyError, so an + # in-place upgrade across an existing pending.json doesn't + # crash on the first approve call. Legacy entries get pruned + # at their TTL by _cleanup_expired. + matched_key = None + matched_entry = None + for entry_id, entry in pending.items(): + if not isinstance(entry, dict): + continue + if "salt" not in entry or "hash" not in entry: + continue + try: + salt = bytes.fromhex(entry["salt"]) + except ValueError: + continue + candidate_hash = self._hash_code(code, salt) + if secrets.compare_digest(candidate_hash, entry["hash"]): + matched_key = entry_id + matched_entry = entry + break + + if matched_key is None: self._record_failed_attempt(platform) return None - entry = pending.pop(code) + del pending[matched_key] self._save_json(self._pending_path(platform), pending) # Add to approved list - self._approve_user(platform, entry["user_id"], entry.get("user_name", "")) + self._approve_user(platform, matched_entry["user_id"], + matched_entry.get("user_name", "")) return { - "user_id": entry["user_id"], - "user_name": entry.get("user_name", ""), + "user_id": matched_entry["user_id"], + "user_name": matched_entry.get("user_name", ""), } def list_pending(self, platform: str = None) -> list: - """List pending pairing requests, optionally filtered by platform.""" + """List pending pairing requests, optionally filtered by platform. + + Codes are stored hashed โ€” the ``code`` field is replaced with the + first 8 hex characters of the hash so admins can distinguish entries + without revealing the original code. Legacy plaintext-key entries + (pre-hash format) are shown with a "legacy" placeholder so admins + can see them age out without crashing on a missing ``hash`` field. + """ results = [] - platforms = [platform] if platform else self._all_platforms("pending") - for p in platforms: - self._cleanup_expired(p) - pending = self._load_json(self._pending_path(p)) - for code, info in pending.items(): - age_min = int((time.time() - info["created_at"]) / 60) - results.append({ - "platform": p, - "code": code, - "user_id": info["user_id"], - "user_name": info.get("user_name", ""), - "age_minutes": age_min, - }) + with self._lock: + platforms = [platform] if platform else self._all_platforms("pending") + for p in platforms: + self._cleanup_expired(p) + pending = self._load_json(self._pending_path(p)) + for entry_id, info in pending.items(): + if not isinstance(info, dict): + continue + created_at = info.get("created_at") + if not isinstance(created_at, (int, float)): + continue + age_min = int((time.time() - created_at) / 60) + hash_val = info.get("hash") + code_display = hash_val[:8] if isinstance(hash_val, str) else "legacy" + results.append({ + "platform": p, + "code": code_display, + "user_id": info.get("user_id", ""), + "user_name": info.get("user_name", ""), + "age_minutes": age_min, + }) return results def clear_pending(self, platform: str = None) -> int: @@ -262,15 +374,20 @@ class PairingStore: def _is_rate_limited(self, platform: str, user_id: str) -> bool: """Check if a user has requested a code too recently.""" limits = self._load_json(self._rate_limit_path()) - key = f"{platform}:{user_id}" - last_request = limits.get(key, 0) - return (time.time() - last_request) < RATE_LIMIT_SECONDS + for alias in self._user_id_aliases(platform, user_id): + key = f"{platform}:{alias}" + last_request = limits.get(key, 0) + if (time.time() - last_request) < RATE_LIMIT_SECONDS: + return True + return False def _record_rate_limit(self, platform: str, user_id: str) -> None: """Record the time of a pairing request for rate limiting.""" limits = self._load_json(self._rate_limit_path()) - key = f"{platform}:{user_id}" - limits[key] = time.time() + now = time.time() + for alias in self._user_id_aliases(platform, user_id): + key = f"{platform}:{alias}" + limits[key] = now self._save_json(self._rate_limit_path(), limits) def _is_locked_out(self, platform: str) -> bool: @@ -297,17 +414,29 @@ class PairingStore: # ----- Cleanup ----- def _cleanup_expired(self, platform: str) -> None: - """Remove expired pending codes.""" + """Remove expired pending codes. + + Tolerant of malformed / legacy entries โ€” anything without a numeric + ``created_at`` is treated as expired (it's effectively unusable + with the new hash-keyed schema anyway). + """ path = self._pending_path(platform) pending = self._load_json(path) now = time.time() - expired = [ - code for code, info in pending.items() - if (now - info["created_at"]) > CODE_TTL_SECONDS - ] + expired = [] + for entry_id, info in pending.items(): + if not isinstance(info, dict): + expired.append(entry_id) + continue + created_at = info.get("created_at") + if not isinstance(created_at, (int, float)): + expired.append(entry_id) + continue + if (now - created_at) > CODE_TTL_SECONDS: + expired.append(entry_id) if expired: - for code in expired: - del pending[code] + for entry_id in expired: + del pending[entry_id] self._save_json(path, pending) def _all_platforms(self, suffix: str) -> list: diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 0668896e170..1f02bde5a2a 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -35,6 +35,7 @@ import re import sqlite3 import time import uuid +from pathlib import Path from typing import Any, Dict, List, Optional try: @@ -337,10 +338,12 @@ class ResponseStore: db_path = str(get_hermes_home() / "response_store.db") except Exception: db_path = ":memory:" + self._db_path: Optional[str] = db_path if db_path != ":memory:" else None try: self._conn = sqlite3.connect(db_path, check_same_thread=False) except Exception: self._conn = sqlite3.connect(":memory:", check_same_thread=False) + self._db_path = None # Use shared WAL-fallback helper so response_store.db degrades # gracefully on NFS/SMB/FUSE-mounted HERMES_HOME (same filesystem # issue addressed for state.db/kanban.db โ€” see @@ -361,6 +364,31 @@ class ResponseStore: )""" ) self._conn.commit() + # response_store.db contains conversation history (tool payloads, + # prompts, results). Tighten to owner-only after creation so other + # local users on a shared box can't read it. Run once at __init__ + # rather than after every commit โ€” chmod-on-every-write is wasted + # syscalls on a hot path. + self._tighten_file_permissions() + + def _tighten_file_permissions(self) -> None: + """Force owner-only permissions on the DB and SQLite sidecars.""" + if not self._db_path: + return + for candidate in ( + Path(self._db_path), + Path(f"{self._db_path}-wal"), + Path(f"{self._db_path}-shm"), + ): + try: + if candidate.exists(): + candidate.chmod(0o600) + except OSError: + logger.debug( + "Failed to restrict response store permissions for %s", + candidate, + exc_info=True, + ) def get(self, response_id: str) -> Optional[Dict[str, Any]]: """Retrieve a stored response by ID (updates access time for LRU).""" diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 5157593ac57..307ecf46f4d 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -15,6 +15,7 @@ import re import socket as _socket import subprocess import sys +import time import uuid from abc import ABC, abstractmethod from urllib.parse import urlsplit @@ -40,6 +41,16 @@ def _platform_name(platform) -> str: return str(value or "").lower() +def _float_env(name: str, default: float) -> float: + raw = os.environ.get(name, "").strip() + if not raw: + return default + try: + return float(raw) + except (TypeError, ValueError): + return default + + def _thread_metadata_for_source(source, reply_to_message_id: str | None = None) -> dict | None: """Build platform-aware thread metadata for adapter sends. @@ -472,7 +483,7 @@ sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) from gateway.config import Platform, PlatformConfig from gateway.session import SessionSource, build_session_key -from hermes_constants import get_hermes_dir +from hermes_constants import get_hermes_dir, get_hermes_home GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = ( @@ -813,6 +824,86 @@ def cache_video_from_bytes(data: bytes, ext: str = ".mp4") -> str: # --------------------------------------------------------------------------- DOCUMENT_CACHE_DIR = get_hermes_dir("cache/documents", "document_cache") +SCREENSHOT_CACHE_DIR = get_hermes_dir("cache/screenshots", "browser_screenshots") +_HERMES_HOME = get_hermes_home() +MEDIA_DELIVERY_ALLOW_DIRS_ENV = "HERMES_MEDIA_ALLOW_DIRS" +MEDIA_DELIVERY_SAFE_ROOTS = ( + IMAGE_CACHE_DIR, + AUDIO_CACHE_DIR, + VIDEO_CACHE_DIR, + DOCUMENT_CACHE_DIR, + SCREENSHOT_CACHE_DIR, + _HERMES_HOME / "image_cache", + _HERMES_HOME / "audio_cache", + _HERMES_HOME / "video_cache", + _HERMES_HOME / "document_cache", + _HERMES_HOME / "browser_screenshots", +) + + +def _media_delivery_allowed_roots() -> List[Path]: + """Return roots from which model-emitted local media may be delivered.""" + roots = [Path(root) for root in MEDIA_DELIVERY_SAFE_ROOTS] + extra_roots = os.environ.get(MEDIA_DELIVERY_ALLOW_DIRS_ENV, "") + for chunk in extra_roots.split(os.pathsep): + for raw_root in chunk.split(","): + raw_root = raw_root.strip() + if not raw_root: + continue + root = Path(os.path.expanduser(raw_root)) + if root.is_absolute(): + roots.append(root) + return roots + + +def _path_is_within(path: Path, root: Path) -> bool: + try: + path.relative_to(root) + return True + except ValueError: + return False + + +def validate_media_delivery_path(path: str) -> Optional[str]: + """Return a safe absolute file path for native media delivery, else None. + + MEDIA tags and bare local paths in model output are untrusted text. Only + existing regular files under Hermes-managed media caches, or roots the + operator explicitly allowlists, may be uploaded as native attachments. + Symlinks are resolved before the containment check. + """ + if not path: + return None + + candidate = str(path).strip() + if len(candidate) >= 2 and candidate[0] == candidate[-1] and candidate[0] in "`\"'": + candidate = candidate[1:-1].strip() + candidate = candidate.lstrip("`\"'").rstrip("`\"',.;:)}]") + if not candidate: + return None + + expanded = Path(os.path.expanduser(candidate)) + if not expanded.is_absolute(): + return None + + try: + resolved = expanded.resolve(strict=True) + except (OSError, RuntimeError, ValueError): + return None + + if not resolved.is_file(): + return None + + for root in _media_delivery_allowed_roots(): + try: + resolved_root = root.expanduser().resolve(strict=False) + except (OSError, RuntimeError, ValueError): + continue + if _path_is_within(resolved, resolved_root): + return str(resolved) + + return None + SUPPORTED_DOCUMENT_TYPES = { ".pdf": "application/pdf", @@ -1023,6 +1114,14 @@ class MessageEvent: return args +@dataclass +class TextDebounceState: + event: MessageEvent + task: asyncio.Task | None + first_ts: float + last_ts: float + + _PLAINTEXT_GATEWAY_RESTART_PATTERNS: tuple[re.Pattern[str], ...] = ( re.compile(r"^(?:please\s+)?restart\s+(?:the\s+)?gateway[.!?\s]*$", re.IGNORECASE), re.compile(r"^(?:please\s+)?restart\s+(?:the\s+)?hermes\s+gateway[.!?\s]*$", re.IGNORECASE), @@ -1318,6 +1417,17 @@ class BasePlatformAdapter(ABC): self._active_sessions: Dict[str, asyncio.Event] = {} self._pending_messages: Dict[str, MessageEvent] = {} self._session_tasks: Dict[str, asyncio.Task] = {} + self._busy_text_mode: str = ( + os.environ.get("HERMES_GATEWAY_BUSY_TEXT_MODE", "queue").strip().lower() + or "queue" + ) + self._busy_text_debounce_seconds: float = _float_env( + "HERMES_GATEWAY_BUSY_TEXT_DEBOUNCE_SECONDS", 0.35 + ) + self._busy_text_hard_cap_seconds: float = _float_env( + "HERMES_GATEWAY_BUSY_TEXT_HARD_CAP_SECONDS", 1.0 + ) + self._text_debounce: dict[str, TextDebounceState] = {} # Background message-processing tasks spawned by handle_message(). # Gateway shutdown cancels these so an old gateway instance doesn't keep # working on a task after --replace or manual restarts. @@ -2119,6 +2229,35 @@ class BasePlatformAdapter(ABC): text = f"{caption}\n{text}" return await self.send(chat_id=chat_id, content=text, reply_to=reply_to, metadata=metadata) + @staticmethod + def validate_media_delivery_path(path: str) -> Optional[str]: + """Return a resolved path if it is safe for native attachment upload.""" + return validate_media_delivery_path(path) + + @staticmethod + def filter_media_delivery_paths(media_files) -> List[Tuple[str, bool]]: + """Drop unsafe MEDIA paths and normalize accepted paths.""" + safe_media: List[Tuple[str, bool]] = [] + for media_path, is_voice in media_files or []: + safe_path = validate_media_delivery_path(str(media_path)) + if safe_path: + safe_media.append((safe_path, bool(is_voice))) + else: + logger.warning("Skipping unsafe MEDIA directive path outside allowed roots") + return safe_media + + @staticmethod + def filter_local_delivery_paths(file_paths) -> List[str]: + """Drop unsafe bare local file paths and normalize accepted paths.""" + safe_paths: List[str] = [] + for file_path in file_paths or []: + safe_path = validate_media_delivery_path(str(file_path)) + if safe_path: + safe_paths.append(safe_path) + else: + logger.warning("Skipping unsafe local file path outside allowed roots") + return safe_paths + @staticmethod def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]: """ @@ -2616,6 +2755,161 @@ class BasePlatformAdapter(ABC): return f"{existing_text}\n\n{new_text}".strip() return existing_text + def _text_debounce_store(self) -> dict[str, TextDebounceState]: + store = getattr(self, "_text_debounce", None) + if store is None: + store = {} + self._text_debounce = store + return store + + def _is_queue_text_debounce_candidate(self, event: MessageEvent) -> bool: + """Return True for normal text eligible for queue-mode debounce.""" + result = ( + getattr(self, "_busy_text_mode", "queue") == "queue" + and event.message_type == MessageType.TEXT + and not getattr(event, "internal", False) + and not event.is_command() + and bool((event.text or "").strip()) + ) + if result: + logger.debug( + "[%s] Queue-text debounce candidate accepted: session=%s text_len=%d", + self.name, + getattr(event, "session_key", "?"), + len(event.text or ""), + ) + return result + + def _can_merge_text_debounce_events(self, existing: MessageEvent, event: MessageEvent) -> bool: + """Return True when two text debounce events came from the same sender.""" + + def _identity(candidate: MessageEvent) -> tuple[str, ...] | None: + source = getattr(candidate, "source", None) + if source is None: + return None + platform = _platform_name(getattr(source, "platform", None)) + sender = getattr(source, "user_id_alt", None) or getattr(source, "user_id", None) + if sender: + return (platform, str(sender)) + if getattr(source, "chat_type", None) in {"dm", "private"} and getattr(source, "chat_id", None): + return (platform, "dm", str(source.chat_id)) + return None + + existing_sender = _identity(existing) + incoming_sender = _identity(event) + return existing_sender is not None and existing_sender == incoming_sender + + def _text_debounce_delay(self, session_key: str) -> float: + """Return bounded busy-text debounce delay for ``session_key``.""" + state = self._text_debounce_store().get(session_key) + if state is None: + return 0.0 + now = time.monotonic() + window_deadline = state.last_ts + self._busy_text_debounce_seconds + hard_cap_deadline = state.first_ts + self._busy_text_hard_cap_seconds + return max(0.0, min(window_deadline, hard_cap_deadline) - now) + + async def _queue_text_debounce(self, session_key: str, event: MessageEvent) -> None: + """Buffer normal queue-mode busy text and schedule a bounded flush.""" + store = self._text_debounce_store() + state = store.get(session_key) + + if state is not None and not self._can_merge_text_debounce_events(state.event, event): + # Preserve sender attribution in shared sessions. The current + # buffer becomes the next pending turn; the new sender starts a + # fresh debounce burst when the pending slot allows it. + await self._flush_text_debounce_now(session_key) + state = store.get(session_key) + if state is not None and not self._can_merge_text_debounce_events(state.event, event): + existing_pending = self._pending_messages.get(session_key) + if existing_pending is not None and self._can_merge_text_debounce_events(existing_pending, event): + merge_pending_message_event( + self._pending_messages, + session_key, + event, + merge_text=True, + ) + return + + now = time.monotonic() + if state is None: + state = TextDebounceState( + event=event, + task=None, + first_ts=now, + last_ts=now, + ) + store[session_key] = state + else: + if event.text: + state.event.text = ( + f"{state.event.text}\n{event.text}" + if state.event.text + else event.text + ) + latest_message_id = getattr(event, "message_id", None) + latest_anchor = latest_message_id or getattr(event, "reply_to_message_id", None) + if latest_message_id is not None: + state.event.message_id = str(latest_message_id) + if latest_anchor is not None and hasattr(state.event, "reply_to_message_id"): + state.event.reply_to_message_id = str(latest_anchor) + state.last_ts = now + + if state.task is not None and not state.task.done(): + state.task.cancel() + + delay = self._text_debounce_delay(session_key) + state.task = asyncio.create_task(self._flush_text_debounce(session_key, delay)) + + async def _flush_text_debounce(self, session_key: str, delay: float) -> None: + """Timer task that flushes the debounced text buffer.""" + try: + await asyncio.sleep(delay) + await self._flush_text_debounce_now(session_key) + except asyncio.CancelledError: + return + finally: + current = asyncio.current_task() + state = self._text_debounce_store().get(session_key) + if state is not None and state.task is current: + state.task = None + + async def _flush_text_debounce_now(self, session_key: str) -> bool: + """Force-flush one debounced busy-text burst into the pending slot.""" + store = self._text_debounce_store() + state = store.get(session_key) + if state is None: + return False + + current = asyncio.current_task() + if state.task is not None and state.task is not current and not state.task.done(): + state.task.cancel() + state.task = None + + existing_pending = self._pending_messages.get(session_key) + if ( + existing_pending is not None + and not self._can_merge_text_debounce_events(existing_pending, state.event) + ): + return False + + state = store.pop(session_key, None) + if state is None: + return False + merge_pending_message_event( + self._pending_messages, + session_key, + state.event, + merge_text=True, + ) + return True + + def _discard_text_debounce(self, session_key: str) -> None: + """Cancel and drop pending text debounce state for control commands.""" + state = self._text_debounce_store().pop(session_key, None) + if state is not None and state.task is not None and not state.task.done(): + state.task.cancel() + # ------------------------------------------------------------------ # Session task + guard ownership helpers # ------------------------------------------------------------------ @@ -2685,6 +2979,7 @@ class BasePlatformAdapter(ABC): self._active_sessions.pop(session_key, None) self._pending_messages.pop(session_key, None) self._session_tasks.pop(session_key, None) + self._discard_text_debounce(session_key) return True def _start_session_processing( @@ -2766,6 +3061,7 @@ class BasePlatformAdapter(ABC): ) if discard_pending: self._pending_messages.pop(session_key, None) + self._discard_text_debounce(session_key) if release_guard: self._release_session_guard(session_key) @@ -2780,6 +3076,7 @@ class BasePlatformAdapter(ABC): command-scoped guard, then โ€” if a follow-up message landed while the command was running โ€” spawns a fresh processing task for it. """ + await self._flush_text_debounce_now(session_key) pending_event = self._pending_messages.pop(session_key, None) self._release_session_guard(session_key, guard=command_guard) if pending_event is None: @@ -2911,6 +3208,7 @@ class BasePlatformAdapter(ABC): # through the dedicated handoff path that serializes # cancellation + runner response + pending drain. if cmd in {"stop", "new", "reset"}: + self._discard_text_debounce(session_key) try: await self._dispatch_active_session_command(event, session_key, cmd) except Exception as e: @@ -2955,8 +3253,9 @@ class BasePlatformAdapter(ABC): # clarify-intercept can resolve it and unblock the agent. # # Without this bypass: the message gets queued in - # _pending_messages AND triggers an interrupt, killing the - # agent run mid-clarify and discarding the user's answer. + # _pending_messages as a follow-up turn instead of reaching the + # clarify resolver, leaving the agent blocked and discarding the + # user's answer. # Same shape as the /approve deadlock fix (PR #4926) โ€” both # cases are "agent thread blocked on Event.wait, message must # reach the resolver before being treated as a new turn." @@ -3015,27 +3314,28 @@ class BasePlatformAdapter(ABC): merge_pending_message_event(self._pending_messages, session_key, event) return # Don't interrupt now - will run after current task completes - # Default behavior for non-photo follow-ups: interrupt the running agent. - # - # Use merge_text=True so rapid TEXT follow-ups (#4469) accumulate - # into the single pending slot instead of clobbering each other. - # Without merging, three rapid messages "A", "B", "C" land like: - # _pending_messages[k] = A (interrupts) - # _pending_messages[k] = B (replaces A before consumer reads) - # _pending_messages[k] = C (replaces B) - # ...and only "C" reaches the next turn. merge_pending_message_event - # already does the right thing for photo/media bursts; the - # ``merge_text=True`` flag extends that to plain TEXT events. - # Same shape as the Telegram bursty-grace path in gateway/run.py. - logger.debug("[%s] New message while session %s is active โ€” triggering interrupt", self.name, session_key) - merge_pending_message_event( - self._pending_messages, - session_key, - event, - merge_text=True, - ) - # Signal the interrupt (the processing task checks this) - self._active_sessions[session_key].set() + if self._is_queue_text_debounce_candidate(event): + logger.debug( + "[%s] New text message while session %s is active โ€” " + "debouncing follow-up (busy_text_mode=queue, window=%.2fs)", + self.name, + session_key, + self._busy_text_debounce_seconds, + ) + await self._queue_text_debounce(session_key, event) + else: + logger.debug( + "[%s] New message while session %s is active โ€” queuing follow-up " + "(no interrupt, will cascade after current turn)", + self.name, + session_key, + ) + merge_pending_message_event( + self._pending_messages, + session_key, + event, + merge_text=event.message_type == MessageType.TEXT, + ) return # Don't process now - will be handled after current task finishes # Mark session as active BEFORE spawning background task to close @@ -3166,6 +3466,7 @@ class BasePlatformAdapter(ABC): # Extract MEDIA: tags (from TTS tool) before other processing media_files, response = self.extract_media(response) + media_files = self.filter_media_delivery_paths(media_files) # Extract image URLs and send them as native platform attachments images, text_content = self.extract_images(response) @@ -3179,6 +3480,7 @@ class BasePlatformAdapter(ABC): # Auto-detect bare local file paths for native media delivery # (helps small models that don't use MEDIA: syntax) local_files, text_content = self.extract_local_files(text_content) + local_files = self.filter_local_delivery_paths(local_files) if local_files: logger.info("[%s] extract_local_files found %d file(s) in response", self.name, len(local_files)) @@ -3387,10 +3689,15 @@ class BasePlatformAdapter(ABC): ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE, ) + # The active drain owns debounce state. If a queue-mode timer has + # not fired yet, force-flush into _pending_messages here and let + # this task hand off the follow-up. + await self._flush_text_debounce_now(session_key) + # Check if there's a pending message that was queued during our processing if session_key in self._pending_messages: pending_event = self._pending_messages.pop(session_key) - logger.debug("[%s] Processing queued message from interrupt", self.name) + logger.debug("[%s] Processing queued follow-up message", self.name) # Keep the _active_sessions entry live across the turn chain # and only CLEAR the interrupt Event โ€” do NOT delete the entry. # If we deleted here, a concurrent inbound message arriving @@ -3399,7 +3706,7 @@ class BasePlatformAdapter(ABC): # with the recursive drain below. Two agents on one # session_key = duplicate responses, duplicate tool calls. # Clearing the Event keeps the guard live so follow-ups take - # the busy-handler path (queue + interrupt) as intended. + # the busy-handler path as intended. _active = self._active_sessions.get(session_key) if _active is not None: _active.clear() @@ -3492,6 +3799,9 @@ class BasePlatformAdapter(ABC): await self.stop_typing(event.source.chat_id) except Exception: pass + # Final drain/release boundary: force-flush any timer that missed + # the in-band drain before deciding whether the guard can clear. + await self._flush_text_debounce_now(session_key) # Late-arrival drain: a message may have arrived during the # cleanup awaits above (typing_task cancel, stop_typing). Such # messages passed the Level-1 guard (entry still live, Event @@ -3611,6 +3921,10 @@ class BasePlatformAdapter(ABC): self._session_tasks.clear() self._pending_messages.clear() self._active_sessions.clear() + for state in list(self._text_debounce_store().values()): + if state.task is not None and not state.task.done(): + state.task.cancel() + self._text_debounce_store().clear() def has_pending_interrupt(self, session_key: str) -> bool: """Check if there's a pending interrupt for a session.""" diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py index 7a4af3ad685..ec852e3d610 100644 --- a/gateway/platforms/bluebubbles.py +++ b/gateway/platforms/bluebubbles.py @@ -189,7 +189,10 @@ class BlueBubblesAdapter(BasePlatformAdapter): app = web.Application() app.router.add_get("/health", lambda _: web.Response(text="ok")) app.router.add_post(self.webhook_path, self._handle_webhook) - self._runner = web.AppRunner(app) + # The webhook auth value is carried in the query string because the + # BlueBubbles webhook API cannot send custom headers. Do not let + # aiohttp access logs write that request target to agent.log. + self._runner = web.AppRunner(app, access_log=None) await self._runner.setup() site = web.TCPSite(self._runner, self.webhook_host, self.webhook_port) await site.start() @@ -242,6 +245,14 @@ class BlueBubblesAdapter(BasePlatformAdapter): return f"{base}?password={quote(self.password, safe='')}" return base + @property + def _webhook_register_url_for_log(self) -> str: + """Webhook registration URL safe for logs.""" + base = self._webhook_url + if self.password: + return f"{base}?password=***" + return base + async def _find_registered_webhooks(self, url: str) -> list: """Return list of BB webhook entries matching *url*.""" try: @@ -269,7 +280,8 @@ class BlueBubblesAdapter(BasePlatformAdapter): existing = await self._find_registered_webhooks(webhook_url) if existing: logger.info( - "[bluebubbles] webhook already registered: %s", webhook_url + "[bluebubbles] webhook already registered: %s", + self._webhook_register_url_for_log, ) return True @@ -284,7 +296,7 @@ class BlueBubblesAdapter(BasePlatformAdapter): if 200 <= status < 300: logger.info( "[bluebubbles] webhook registered with server: %s", - webhook_url, + self._webhook_register_url_for_log, ) return True else: @@ -324,7 +336,8 @@ class BlueBubblesAdapter(BasePlatformAdapter): removed = True if removed: logger.info( - "[bluebubbles] webhook unregistered: %s", webhook_url + "[bluebubbles] webhook unregistered: %s", + self._webhook_register_url_for_log, ) except Exception as exc: logger.debug( @@ -934,4 +947,3 @@ class BlueBubblesAdapter(BasePlatformAdapter): asyncio.create_task(self.mark_read(session_chat_id)) return web.Response(text="ok") - diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py index 6e599ed2210..0b3c7f52ace 100644 --- a/gateway/platforms/dingtalk.py +++ b/gateway/platforms/dingtalk.py @@ -358,6 +358,19 @@ class DingTalkAdapter(BasePlatformAdapter): await asyncio.gather(*self._bg_tasks, return_exceptions=True) self._bg_tasks.clear() + # Finalize any open streaming cards before the HTTP client closes so + # they don't stay stuck in streaming state on DingTalk's UI after + # a gateway restart. _close_streaming_siblings handles its own + # per-card exceptions; the outer try is a safety net for token fetch. + for _chat_id in list(self._streaming_cards): + try: + await self._close_streaming_siblings(_chat_id) + except Exception as _exc: + logger.debug( + "[%s] Failed to finalize streaming card on disconnect for %s: %s", + self.name, _chat_id, _exc, + ) + if self._http_client: await self._http_client.aclose() self._http_client = None diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py index a9b0447080d..2831476b5ba 100644 --- a/gateway/platforms/feishu.py +++ b/gateway/platforms/feishu.py @@ -1514,8 +1514,10 @@ class FeishuAdapter(BasePlatformAdapter): connection_mode=str( extra.get("connection_mode") or os.getenv("FEISHU_CONNECTION_MODE", "websocket") ).strip().lower(), - encrypt_key=os.getenv("FEISHU_ENCRYPT_KEY", "").strip(), - verification_token=os.getenv("FEISHU_VERIFICATION_TOKEN", "").strip(), + encrypt_key=str(extra.get("encrypt_key") or os.getenv("FEISHU_ENCRYPT_KEY", "")).strip(), + verification_token=str( + extra.get("verification_token") or os.getenv("FEISHU_VERIFICATION_TOKEN", "") + ).strip(), group_policy=os.getenv("FEISHU_GROUP_POLICY", "allowlist").strip().lower(), allowed_group_users=frozenset( item.strip() @@ -1642,6 +1644,11 @@ class FeishuAdapter(BasePlatformAdapter): self._connection_mode, ) return False + if self._connection_mode == "webhook" and not (self._verification_token or self._encrypt_key): + logger.error( + "[Feishu] Webhook mode requires FEISHU_VERIFICATION_TOKEN or FEISHU_ENCRYPT_KEY." + ) + return False try: self._app_lock_identity = self._app_id @@ -2563,13 +2570,44 @@ class FeishuAdapter(BasePlatformAdapter): if approval_id is None: logger.debug("[Feishu] Card action missing approval_id, ignoring") return P2CardActionTriggerResponse() if P2CardActionTriggerResponse else None + state = self._approval_state.get(approval_id) + if not state: + logger.debug("[Feishu] Approval %s already resolved or unknown", approval_id) + return P2CardActionTriggerResponse() if P2CardActionTriggerResponse else None choice = _APPROVAL_CHOICE_MAP.get(action_value.get("hermes_action"), "deny") operator = getattr(event, "operator", None) open_id = str(getattr(operator, "open_id", "") or "") + sender_id = SimpleNamespace(open_id=open_id, user_id=str(getattr(operator, "user_id", "") or "")) + if not self._allow_group_message(sender_id, state.get("chat_id", ""), is_bot=False): + logger.warning("[Feishu] Unauthorized approval click by %s", open_id or "") + return P2CardActionTriggerResponse() if P2CardActionTriggerResponse else None + + callback_chat_id = str(getattr(getattr(event, "context", None), "open_chat_id", "") or "") + expected_chat_id = str(state.get("chat_id", "") or "") + if callback_chat_id and expected_chat_id and callback_chat_id != expected_chat_id: + logger.warning( + "[Feishu] Approval callback chat mismatch for %s (expected=%s, got=%s)", + approval_id, + expected_chat_id, + callback_chat_id, + ) + return P2CardActionTriggerResponse() if P2CardActionTriggerResponse else None + user_name = self._get_cached_sender_name(open_id) or open_id - if not self._submit_on_loop(loop, self._resolve_approval(approval_id, choice, user_name)): + chat_context = getattr(event, "context", None) + chat_id = str(getattr(chat_context, "open_chat_id", "") or "") + if not self._submit_on_loop( + loop, + self._resolve_approval( + approval_id=approval_id, + choice=choice, + user_name=user_name, + open_id=open_id, + chat_id=chat_id, + ), + ): return P2CardActionTriggerResponse() if P2CardActionTriggerResponse else None if P2CardActionTriggerResponse is None: @@ -2617,12 +2655,34 @@ class FeishuAdapter(BasePlatformAdapter): response.card = card return response - async def _resolve_approval(self, approval_id: Any, choice: str, user_name: str) -> None: + async def _resolve_approval( + self, + approval_id: Any, + choice: str, + user_name: str, + *, + open_id: str = "", + chat_id: str = "", + ) -> None: """Pop approval state and unblock the waiting agent thread.""" - state = self._approval_state.pop(approval_id, None) + state = self._approval_state.get(approval_id) if not state: logger.debug("[Feishu] Approval %s already resolved or unknown", approval_id) return + if not self._is_interactive_operator_authorized(open_id): + logger.warning("[Feishu] Unauthorized approval click by %s for approval %s", open_id or "", approval_id) + return + expected_chat_id = str(state.get("chat_id", "") or "") + if expected_chat_id and chat_id and expected_chat_id != chat_id: + logger.warning( + "[Feishu] Approval %s chat mismatch (expected=%s, got=%s)", + approval_id, expected_chat_id, chat_id, + ) + return + state = self._approval_state.pop(approval_id, None) + if not state: + logger.debug("[Feishu] Approval %s already resolved while validating callback", approval_id) + return try: from tools.approval import resolve_gateway_approval count = resolve_gateway_approval(state["session_key"], choice) @@ -3229,11 +3289,6 @@ class FeishuAdapter(BasePlatformAdapter): self._record_webhook_anomaly(remote_ip, "400") return web.json_response({"code": 400, "msg": "invalid json"}, status=400) - # URL verification challenge โ€” respond before other checks so that Feishu's - # subscription setup works even before encrypt_key is wired. - if payload.get("type") == "url_verification": - return web.json_response({"challenge": payload.get("challenge", "")}) - # Verification token check โ€” second layer of defence beyond signature (matches openclaw). if self._verification_token: header = payload.get("header") or {} @@ -3243,6 +3298,13 @@ class FeishuAdapter(BasePlatformAdapter): self._record_webhook_anomaly(remote_ip, "401-token") return web.Response(status=401, text="Invalid verification token") + # URL verification challenge โ€” Feishu includes the verification token in + # challenge requests. Validate the token (above) before reflecting the + # challenge so an unauthenticated remote request cannot prove endpoint + # control by getting attacker-supplied challenge data echoed back. + if payload.get("type") == "url_verification": + return web.json_response({"challenge": payload.get("challenge", "")}) + # Timing-safe signature verification (only enforced when encrypt_key is set). if self._encrypt_key and not self._is_webhook_signature_valid(request.headers, body_bytes): logger.warning("[Feishu] Webhook rejected: invalid signature from %s", remote_ip) diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 28b086291ae..f7837a1f7d6 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -138,7 +138,8 @@ _OUTBOUND_MENTION_RE = re.compile( ) _E2EE_INSTALL_HINT = ( - "Install with: pip install 'mautrix[encryption]' (requires libolm C library)" + "Install with: pip install 'mautrix[encryption]' asyncpg aiosqlite " + "(requires libolm C library)" ) _MATRIX_IMAGE_FILENAME_EXTS = frozenset({ @@ -214,9 +215,22 @@ def _create_matrix_session(proxy_url: str | None): def _check_e2ee_deps() -> bool: - """Return True if mautrix E2EE dependencies (python-olm) are available.""" + """Return True if mautrix E2EE dependencies are available. + + Verifies python-olm (via mautrix.crypto.OlmMachine), the SQLite crypto + store backend (mautrix.crypto.store.asyncpg.PgCryptoStore โ€” yes, the + PgCryptoStore class also drives the sqlite backend in mautrix 0.21), + and the database drivers actually used at connect time (``asyncpg`` for + the underlying upgrade_table machinery, ``aiosqlite`` for the + ``sqlite:///`` URL we pass to ``Database.create``). Without all four, + encrypted rooms fail at connect time with a confusing + ``No module named 'asyncpg'`` (#31116). + """ try: from mautrix.crypto import OlmMachine # noqa: F401 + from mautrix.crypto.store.asyncpg import PgCryptoStore # noqa: F401 + import asyncpg # noqa: F401 + import aiosqlite # noqa: F401 return True except (ImportError, AttributeError): @@ -226,8 +240,13 @@ def _check_e2ee_deps() -> bool: def check_matrix_requirements() -> bool: """Return True if the Matrix adapter can be used. - Lazy-installs mautrix via ``tools.lazy_deps.ensure("platform.matrix")`` - on first call if not present. Rebinds all module-level type globals on success. + Lazy-installs the full ``platform.matrix`` feature group via + ``tools.lazy_deps.ensure_and_bind`` whenever any of the declared + packages (mautrix, Markdown, aiosqlite, asyncpg, aiohttp-socks) is + missing โ€” not just mautrix itself. Previously this short-circuited on + ``import mautrix``, which left the other four packages uninstalled + forever and broke E2EE connect with ``No module named 'asyncpg'`` + (#31116). Rebinds module-level type globals on success. """ token = os.getenv("MATRIX_ACCESS_TOKEN", "") password = os.getenv("MATRIX_PASSWORD", "") @@ -239,9 +258,20 @@ def check_matrix_requirements() -> bool: if not homeserver: logger.warning("Matrix: MATRIX_HOMESERVER not set") return False + + # Check whether any package in the platform.matrix feature group is + # missing. ``feature_missing`` is cheap (per-spec importlib.metadata + # lookups) and correctly handles ``mautrix[encryption]`` by stripping + # the extras marker before checking the bare package. try: - import mautrix # noqa: F401 - except ImportError: + from tools.lazy_deps import feature_missing, ensure_and_bind + missing = feature_missing("platform.matrix") + except Exception as exc: # pragma: no cover โ€” defensive + logger.debug("Matrix: lazy_deps lookup failed: %s", exc) + missing = () + ensure_and_bind = None # type: ignore[assignment] + + if missing or ensure_and_bind is None: def _import(): from mautrix.types import ( ContentURI, EventID, EventType, PaginationDirection, @@ -261,10 +291,14 @@ def check_matrix_requirements() -> bool: "UserID": UserID, } - from tools.lazy_deps import ensure_and_bind + if ensure_and_bind is None: + return False if not ensure_and_bind("platform.matrix", _import, globals(), prompt=False): logger.warning( - "Matrix: mautrix not installed. Run: pip install 'mautrix[encryption]'" + "Matrix: required packages not installed (%s). " + "Run: pip install 'mautrix[encryption]' asyncpg aiosqlite " + "Markdown aiohttp-socks", + ", ".join(missing) if missing else "platform.matrix", ) return False diff --git a/gateway/platforms/msgraph_webhook.py b/gateway/platforms/msgraph_webhook.py index 46430a25bc7..b7045c801a6 100644 --- a/gateway/platforms/msgraph_webhook.py +++ b/gateway/platforms/msgraph_webhook.py @@ -133,6 +133,12 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): self._notification_scheduler = scheduler async def connect(self) -> bool: + if self._client_state is None: + logger.error( + "[msgraph_webhook] Refusing to start without extra.client_state configured" + ) + return False + app = web.Application() app.router.add_get(self._health_path, self._handle_health) app.router.add_get(self._webhook_path, self._handle_validation) @@ -310,7 +316,7 @@ class MSGraphWebhookAdapter(BasePlatformAdapter): """ expected = self._client_state if expected is None: - return True + return False provided = self._string_or_none(notification.get("clientState")) if provided is None: return False diff --git a/gateway/platforms/qqbot/adapter.py b/gateway/platforms/qqbot/adapter.py index 086f5e073f5..7569884760e 100644 --- a/gateway/platforms/qqbot/adapter.py +++ b/gateway/platforms/qqbot/adapter.py @@ -534,9 +534,30 @@ class QQAdapter(BasePlatformAdapter): self._mark_transport_disconnected() self._fail_pending("Connection closed") - # Stop reconnecting for fatal codes - if code in {4914, 4915}: - desc = "offline/sandbox-only" if code == 4914 else "banned" + # Stop reconnecting for fatal codes (unrecoverable errors) + if code in { + 4001, # Invalid opcode + 4002, # Invalid payload + 4010, # Invalid shard + 4011, # Sharding required + 4012, # Invalid API version + 4013, # Invalid intent + 4014, # Intent not authorized + 4914, # Offline/sandbox-only + 4915, # Banned + }: + fatal_descriptions = { + 4001: "invalid opcode", + 4002: "invalid payload", + 4010: "invalid shard", + 4011: "sharding required", + 4012: "invalid API version", + 4013: "invalid intent", + 4014: "intent not authorized", + 4914: "offline/sandbox-only", + 4915: "banned", + } + desc = fatal_descriptions.get(code, f"fatal error (code={code})") logger.error( "[%s] Bot is %s. Check QQ Open Platform.", self._log_tag, desc ) @@ -573,10 +594,11 @@ class QQAdapter(BasePlatformAdapter): self._token_expires_at = 0.0 # Session invalid โ†’ clear session, will re-identify on next Hello + # Note: 4009 (connection timeout) is NOT included here โ€” it is + # resumable per the QQ protocol and should preserve session state. if code in { 4006, 4007, - 4009, 4900, 4901, 4902, @@ -705,9 +727,8 @@ class QQAdapter(BasePlatformAdapter): "token": f"QQBot {token}", "intents": (1 << 25) | (1 << 30) - | ( - 1 << 12 - ), # C2C_GROUP_AT_MESSAGES + PUBLIC_GUILD_MESSAGES + DIRECT_MESSAGE + | (1 << 12) + | (1 << 26), # C2C_GROUP_AT_MESSAGES + PUBLIC_GUILD_MESSAGES + DIRECT_MESSAGE + INTERACTION "shard": [0, 1], "properties": { "$os": "macOS", @@ -826,6 +847,32 @@ class QQAdapter(BasePlatformAdapter): if op == 11: return + # op 7 = Server Reconnect โ€” server asks client to reconnect (e.g. + # load-balancing, maintenance). Close the WS so _read_events raises + # and the outer loop triggers a reconnect with Resume. + if op == 7: + logger.info("[%s] Server requested reconnect (op 7)", self._log_tag) + if self._ws and not self._ws.closed: + self._create_task(self._ws.close()) + return + + # op 9 = Invalid Session โ€” d=True means session is resumable, + # d=False means we must re-identify from scratch. + if op == 9: + resumable = bool(d) if d is not None else False + if not resumable: + logger.info( + "[%s] Invalid session (op 9, not resumable), clearing session", + self._log_tag, + ) + self._session_id = None + self._last_seq = None + else: + logger.info("[%s] Invalid session (op 9, resumable)", self._log_tag) + if self._ws and not self._ws.closed: + self._create_task(self._ws.close()) + return + logger.debug("[%s] Unknown op: %s", self._log_tag, op) def _handle_ready(self, d: Any) -> None: @@ -1007,6 +1054,46 @@ class QQAdapter(BasePlatformAdapter): "deny": "deny", } + @staticmethod + def _parse_gateway_session_key(session_key: str) -> Optional[Dict[str, str]]: + """Parse ``agent:main:::[:]``.""" + parts = str(session_key or "").split(":") + if len(parts) < 5 or parts[0] != "agent" or parts[1] != "main": + return None + parsed = { + "platform": parts[2], + "chat_type": parts[3], + "chat_id": parts[4], + } + if len(parts) > 5: + parsed["user_id"] = parts[5] + return parsed + + def _is_authorized_interaction_for_session( + self, + event: InteractionEvent, + session_key: str, + ) -> bool: + """Authorize approval/update interactions against session + operator.""" + parsed = self._parse_gateway_session_key(session_key) + operator = str(event.operator_openid or "").strip() + if not parsed or parsed.get("platform") != "qqbot" or not operator: + return False + + chat_type = parsed.get("chat_type", "") + chat_id = parsed.get("chat_id", "") + if chat_type == "c2c": + return bool(chat_id) and operator == chat_id + + if chat_type in {"group", "guild"}: + event_chat = str(event.group_openid or event.guild_id or "").strip() + if not event_chat or event_chat != chat_id: + return False + session_user = str(parsed.get("user_id", "")).strip() + return bool(session_user) and operator == session_user + + return False + async def _default_interaction_dispatch( self, event: InteractionEvent, @@ -1040,6 +1127,13 @@ class QQAdapter(BasePlatformAdapter): self._log_tag, decision, session_key, ) return + if not self._is_authorized_interaction_for_session(event, session_key): + logger.warning( + "[%s] Rejected unauthorized approval click for session %s " + "(operator=%s)", + self._log_tag, session_key, event.operator_openid, + ) + return try: # Import lazily to keep the adapter importable in tests that # don't exercise the approval subsystem. @@ -1060,6 +1154,13 @@ class QQAdapter(BasePlatformAdapter): update_answer = parse_update_prompt_button_data(button_data) if update_answer is not None: + update_session_key = f"agent:main:qqbot:{event.scene}:{event.group_openid or event.guild_id or event.user_openid}" + if not self._is_authorized_interaction_for_session(event, update_session_key): + logger.warning( + "[%s] Rejected unauthorized update prompt click (operator=%s)", + self._log_tag, event.operator_openid, + ) + return self._write_update_response(update_answer, event.operator_openid) return @@ -1607,7 +1708,7 @@ class QQAdapter(BasePlatformAdapter): elif ct.startswith("image/"): # Image: download and cache locally. try: - cached_path = await self._download_and_cache(url, ct) + cached_path = await self._download_and_cache(url, ct, filename) if cached_path and os.path.isfile(cached_path): image_urls.append(cached_path) image_media_types.append(ct or "image/jpeg") @@ -1620,11 +1721,15 @@ class QQAdapter(BasePlatformAdapter): except Exception as exc: logger.debug("[%s] Failed to cache image: %s", self._log_tag, exc) else: - # Other attachments (video, file, etc.): record as text. + # Other attachments (video, file, etc.): download and record with path. try: - cached_path = await self._download_and_cache(url, ct) + cached_path = await self._download_and_cache(url, ct, filename) if cached_path: - other_attachments.append(f"[Attachment: {filename or ct}]") + name = filename or ct + if ct.startswith("video/"): + other_attachments.append(f"[video: {name} ({cached_path})]") + else: + other_attachments.append(f"[file: {name} ({cached_path})]") except Exception as exc: logger.debug("[%s] Failed to cache attachment: %s", self._log_tag, exc) @@ -1636,8 +1741,14 @@ class QQAdapter(BasePlatformAdapter): "attachment_info": attachment_info, } - async def _download_and_cache(self, url: str, content_type: str) -> Optional[str]: - """Download a URL and cache it locally.""" + async def _download_and_cache( + self, url: str, content_type: str, original_name: str = "", + ) -> Optional[str]: + """Download a URL and cache it locally. + + :param original_name: Preferred filename from attachment metadata. + Falls back to the URL path basename if empty. + """ from tools.url_safety import is_safe_url if not is_safe_url(url): @@ -1668,7 +1779,11 @@ class QQAdapter(BasePlatformAdapter): # Convert to .wav using ffmpeg so STT engines can process it. return await self._convert_audio_to_wav(data, url) else: - filename = Path(urlparse(url).path).name or "qq_attachment" + filename = ( + original_name + or Path(urlparse(url).path).name + or "qq_attachment" + ) return cache_document_from_bytes(data, filename) @staticmethod @@ -1881,7 +1996,7 @@ class QQAdapter(BasePlatformAdapter): @staticmethod def _guess_ext_from_data(data: bytes) -> str: """Guess file extension from magic bytes.""" - if data[:9] == b"#!SILK_V3" or data[:5] == b"#!SILK": + if data[:9] == b"#!SILK_V3" or data[:6] == b"#!SILK": return ".silk" if data[:2] == b"\x02!": return ".silk" @@ -1901,7 +2016,7 @@ class QQAdapter(BasePlatformAdapter): @staticmethod def _looks_like_silk(data: bytes) -> bool: """Check if bytes look like a SILK audio file.""" - return data[:4] == b"#!SILK" or data[:2] == b"\x02!" or data[:9] == b"#!SILK_V3" + return data[:6] == b"#!SILK" or data[:2] == b"\x02!" or data[:9] == b"#!SILK_V3" async def _convert_silk_to_wav(self, src_path: str, wav_path: str) -> Optional[str]: """Convert audio file to WAV using the pilk library. diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 799a836df73..1e3ac5728d4 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -429,6 +429,13 @@ class TelegramAdapter(BasePlatformAdapter): self._polling_conflict_count: int = 0 self._polling_network_error_count: int = 0 self._polling_error_callback_ref = None + # After sustained reconnect storms the PTB httpx pool can return + # SendResult(success=True) for sends that never actually transmit. + # _handle_polling_network_error sets this; _verify_polling_after_reconnect + # clears it once getMe() confirms the Bot client is healthy. + # While True, send() short-circuits to a failure so callers + # (cron live-adapter branch) fall through to standalone delivery. + self._send_path_degraded: bool = False # DM Topics: map of topic_name -> message_thread_id (populated at startup) self._dm_topics: Dict[str, int] = {} # Track forum chats where we've already registered bot commands @@ -468,6 +475,10 @@ class TelegramAdapter(BasePlatformAdapter): # "all" โ€” every message triggers a push notification (legacy # behavior; opt-in via display.platforms.telegram.notifications). self._notifications_mode: str = "important" + # send_or_update_status() bookkeeping: {(chat_id, status_key) -> bot message_id} + # Tracks status bubbles owned by this adapter so subsequent calls with the + # same key edit the same message instead of appending new ones (#30045). + self._status_message_ids: Dict[tuple, str] = {} def _notification_kwargs( self, metadata: Optional[Dict[str, Any]] @@ -870,6 +881,7 @@ class TelegramAdapter(BasePlatformAdapter): MAX_DELAY = 60 self._polling_network_error_count += 1 + self._send_path_degraded = True attempt = self._polling_network_error_count if attempt > MAX_NETWORK_RETRIES: @@ -967,6 +979,7 @@ class TelegramAdapter(BasePlatformAdapter): try: await asyncio.wait_for(self._app.bot.get_me(), PROBE_TIMEOUT) + self._send_path_degraded = False except Exception as probe_err: logger.warning( "[%s] Polling heartbeat probe failed %ds after reconnect: %s", @@ -1679,7 +1692,11 @@ class TelegramAdapter(BasePlatformAdapter): """Send a message to a Telegram chat.""" if not self._bot: return SendResult(success=False, error="Not connected") - + + # getattr() โ€” tests build adapters via object.__new__() (no __init__). + if getattr(self, "_send_path_degraded", False): + return SendResult(success=False, error="send_path_degraded", retryable=True) + # Skip whitespace-only text to prevent Telegram 400 empty-text errors. if not content or not content.strip(): return SendResult(success=True, message_id=None) @@ -1908,6 +1925,40 @@ class TelegramAdapter(BasePlatformAdapter): is_connect_timeout = self._looks_like_connect_timeout(e) return SendResult(success=False, error=str(e), retryable=(is_connect_timeout or not is_timeout)) + async def send_or_update_status( + self, + chat_id: str, + status_key: str, + content: str, + *, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a status message, or edit the previous one with the same key. + + Issue #30045: progress/status callbacks (context-pressure, lifecycle, + compression, etc.) used to append a fresh bubble on every call. With + this method, the first call sends and the message id is remembered; + subsequent calls with the same (chat_id, status_key) edit that same + message in place. If the edit fails (message deleted, too old, etc.) + we drop the cached id and send fresh. + """ + key = (str(chat_id), str(status_key)) + cached_id = self._status_message_ids.get(key) + if cached_id is not None: + result = await self.edit_message( + chat_id, cached_id, content, finalize=True, metadata=metadata, + ) + if result.success: + if result.message_id: + self._status_message_ids[key] = str(result.message_id) + return result + # Edit failed โ€” clear the cached id and fall through to a fresh send. + self._status_message_ids.pop(key, None) + result = await self.send(chat_id, content, metadata=metadata) + if result.success and result.message_id: + self._status_message_ids[key] = str(result.message_id) + return result + async def edit_message( self, chat_id: str, @@ -4573,10 +4624,10 @@ class TelegramAdapter(BasePlatformAdapter): return ( "You are handling a Telegram group chat message.\n" f"- Your identity: user_id={bot_id}, @-mention name in this group=@{username}\n" - "- Lines in history prefixed with `[nickname|user_id]` are observed Telegram group context " - "and are not necessarily addressed to you.\n" + "- observed Telegram group context may be provided in a separate context-only block " + "before the current message; it is not necessarily addressed to you.\n" "- Treat only the current new message as a request explicitly directed at you, " - "and answer it directly." + "and use observed context only when the current message asks for it." ) def _apply_telegram_group_observe_attribution(self, event: MessageEvent) -> MessageEvent: @@ -4593,6 +4644,12 @@ class TelegramAdapter(BasePlatformAdapter): shared_source = self._telegram_group_observe_shared_source(event.source) observe_prompt = self._telegram_group_observe_channel_prompt() channel_prompt = f"{event.channel_prompt}\n\n{observe_prompt}" if event.channel_prompt else observe_prompt + if event.message_type == MessageType.COMMAND: + return dataclasses.replace( + event, + source=shared_source, + channel_prompt=channel_prompt, + ) return dataclasses.replace( event, text=self._telegram_group_observe_attributed_text(event), diff --git a/gateway/platforms/webhook.py b/gateway/platforms/webhook.py index d7714ff5652..32c6e8109bd 100644 --- a/gateway/platforms/webhook.py +++ b/gateway/platforms/webhook.py @@ -27,6 +27,8 @@ Security: """ import asyncio +import base64 +import binascii import hashlib import hmac import json @@ -308,11 +310,37 @@ class WebhookAdapter(BasePlatformAdapter): data = json.loads(subs_path.read_text(encoding="utf-8")) if not isinstance(data, dict): return - # Merge: static routes take precedence over dynamic ones - self._dynamic_routes = { - k: v for k, v in data.items() - if k not in self._static_routes - } + # Merge: static routes take precedence over dynamic ones. + # Reject any dynamic route whose effective secret is empty โ€” + # an empty secret would cause _handle_webhook to skip HMAC + # validation entirely, letting unauthenticated callers in. + new_dynamic: Dict[str, dict] = {} + for k, v in data.items(): + if k in self._static_routes: + continue + effective_secret = v.get("secret", self._global_secret) + if not effective_secret: + logger.warning( + "[webhook] Dynamic route '%s' skipped: 'secret' is " + "missing or empty. Set a valid HMAC secret, or use " + "'%s' to explicitly disable auth (testing only).", + k, + _INSECURE_NO_AUTH, + ) + continue + if ( + effective_secret == _INSECURE_NO_AUTH + and not _is_loopback_host(self._host) + ): + logger.warning( + "[webhook] Dynamic route '%s' skipped: INSECURE_NO_AUTH " + "is only allowed on loopback hosts. Current host: '%s'.", + k, + self._host, + ) + continue + new_dynamic[k] = v + self._dynamic_routes = new_dynamic self._routes = {**self._dynamic_routes, **self._static_routes} self._dynamic_routes_mtime = mtime logger.info( @@ -351,9 +379,21 @@ class WebhookAdapter(BasePlatformAdapter): logger.error("[webhook] Failed to read body: %s", e) return web.json_response({"error": "Bad request"}, status=400) - # Validate HMAC signature FIRST (skip for INSECURE_NO_AUTH testing mode) + # Validate HMAC signature FIRST (skip only for the explicit local-test + # INSECURE_NO_AUTH mode). Missing/empty secrets must fail closed here, + # not only during connect(), so direct handler reuse cannot turn a + # network webhook route into an unauthenticated agent-dispatch surface. secret = route_config.get("secret", self._global_secret) - if secret and secret != _INSECURE_NO_AUTH: + if not secret: + logger.error( + "[webhook] Route %s has no HMAC secret; refusing request", + route_name, + ) + return web.json_response( + {"error": "Webhook route is missing an HMAC secret"}, + status=403, + ) + if secret != _INSECURE_NO_AUTH: if not self._validate_signature(request, raw_body, secret): logger.warning( "[webhook] Invalid signature for route %s", route_name @@ -393,6 +433,7 @@ class WebhookAdapter(BasePlatformAdapter): request.headers.get("X-GitHub-Event", "") or request.headers.get("X-GitLab-Event", "") or payload.get("event_type", "") + or payload.get("type", "") or "unknown" ) allowed_events = route_config.get("events", []) @@ -445,7 +486,10 @@ class WebhookAdapter(BasePlatformAdapter): # Build a unique delivery ID delivery_id = request.headers.get( "X-GitHub-Delivery", - request.headers.get("X-Request-ID", str(int(time.time() * 1000))), + request.headers.get( + "svix-id", + request.headers.get("X-Request-ID", str(int(time.time() * 1000))), + ), ) # โ”€โ”€ Idempotency โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ @@ -590,7 +634,32 @@ class WebhookAdapter(BasePlatformAdapter): def _validate_signature( self, request: "web.Request", body: bytes, secret: str ) -> bool: - """Validate webhook signature (GitHub, GitLab, generic HMAC-SHA256).""" + """Validate webhook signature (GitHub, GitLab, Svix, generic HMAC-SHA256).""" + def _header(name: str) -> str: + return ( + request.headers.get(name, "") + or request.headers.get(name.lower(), "") + or request.headers.get(name.upper(), "") + ) + + # Svix / AgentMail: + # svix-id: msg_... + # svix-timestamp: unix seconds + # svix-signature: v1, [v1, ...] + # Signed content is: "{id}.{timestamp}.{raw_body}". Svix secrets + # usually start with "whsec_" and the remainder is base64-encoded. + svix_id = _header("svix-id") + svix_timestamp = _header("svix-timestamp") + svix_signature = _header("svix-signature") + if svix_id or svix_timestamp or svix_signature: + return self._validate_svix_signature( + body=body, + secret=secret, + msg_id=svix_id, + timestamp=svix_timestamp, + signature_header=svix_signature, + ) + # GitHub: X-Hub-Signature-256 = sha256= gh_sig = request.headers.get("X-Hub-Signature-256", "") if gh_sig: @@ -618,6 +687,56 @@ class WebhookAdapter(BasePlatformAdapter): ) return False + def _validate_svix_signature( + self, + body: bytes, + secret: str, + msg_id: str, + timestamp: str, + signature_header: str, + tolerance_seconds: int = 300, + ) -> bool: + """Validate Svix-compatible signatures used by AgentMail webhooks.""" + if not (msg_id and timestamp and signature_header and secret): + return False + + try: + ts = int(timestamp) + except (TypeError, ValueError): + return False + if abs(int(time.time()) - ts) > tolerance_seconds: + logger.warning("[webhook] Svix signature timestamp outside replay window") + return False + + if secret.startswith("whsec_"): + encoded_secret = secret.removeprefix("whsec_") + try: + key = base64.b64decode(encoded_secret, validate=True) + except (binascii.Error, ValueError): + logger.debug("[webhook] Invalid whsec_ Svix signing secret") + return False + else: + # Be permissive for providers that document Svix-style headers but + # hand out raw shared secrets rather than whsec_ base64 secrets. + logger.debug("[webhook] Validating Svix-style signature with raw secret") + key = secret.encode() + + signed_content = msg_id.encode() + b"." + timestamp.encode() + b"." + body + expected = base64.b64encode( + hmac.new(key, signed_content, hashlib.sha256).digest() + ).decode() + + # Svix can send multiple signatures separated by spaces during secret + # rotation. Each entry is formatted as "vN,". + for part in signature_header.split(): + try: + version, signature = part.split(",", 1) + except ValueError: + continue + if version == "v1" and hmac.compare_digest(signature, expected): + return True + return False + # ------------------------------------------------------------------ # Prompt rendering # ------------------------------------------------------------------ diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py index 5aad1e09cc5..1569d5faf52 100644 --- a/gateway/platforms/wecom.py +++ b/gateway/platforms/wecom.py @@ -616,6 +616,18 @@ class WeComAdapter(BasePlatformAdapter): else: delay = self._text_batch_delay_seconds await asyncio.sleep(delay) + # Guard against the cancel-delivery race: when the sleep timer + # fires just before cancel() is called, CPython sets + # Task._must_cancel but cannot cancel the already-done sleep + # future, so CancelledError is delivered at the *next* await + # (handle_message) rather than here. By that point this task + # has already popped the merged event, so the superseding task + # sees an empty batch and silently drops the message. + # This check is synchronous โ€” no await between the sleep and + # the pop โ€” so no other coroutine can modify the task registry + # in between. + if self._pending_text_batch_tasks.get(key) is not current_task: + return event = self._pending_text_batches.pop(key, None) if not event: return diff --git a/gateway/platforms/wecom_callback.py b/gateway/platforms/wecom_callback.py index 139c67fe7c1..e08bc039742 100644 --- a/gateway/platforms/wecom_callback.py +++ b/gateway/platforms/wecom_callback.py @@ -187,7 +187,6 @@ class WecomCallbackAdapter(BasePlatformAdapter): app = self._resolve_app_for_chat(chat_id) touser = chat_id.split(":", 1)[1] if ":" in chat_id else chat_id try: - token = await self._get_access_token(app) payload = { "touser": touser, "msgtype": "text", @@ -195,18 +194,31 @@ class WecomCallbackAdapter(BasePlatformAdapter): "text": {"content": content[:2048]}, "safe": 0, } - resp = await self._http_client.post( - f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}", - json=payload, - ) - data = resp.json() - if data.get("errcode") != 0: - return SendResult(success=False, error=str(data)) - return SendResult( - success=True, - message_id=str(data.get("msgid", "")), - raw_response=data, - ) + for _attempt in range(2): + token = await self._get_access_token(app) + resp = await self._http_client.post( + f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}", + json=payload, + ) + data = resp.json() + errcode = data.get("errcode") + if errcode in {40001, 42001} and _attempt == 0: + # WeCom rejected the token โ€” evict the cached entry so + # the next _get_access_token call forces a fresh fetch. + logger.warning( + "[WecomCallback] Token rejected for app '%s' (errcode=%s), refreshing", + app.get("name", "default"), errcode, + ) + self._access_tokens.pop(app["name"], None) + continue + if errcode != 0: + return SendResult(success=False, error=str(data)) + return SendResult( + success=True, + message_id=str(data.get("msgid", "")), + raw_response=data, + ) + return SendResult(success=False, error="send failed after token refresh") except Exception as exc: return SendResult(success=False, error=str(exc)) diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index 1c9fec0af7f..613c8283b1c 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -1679,8 +1679,10 @@ class WeixinAdapter(BasePlatformAdapter): # Extract MEDIA: tags and bare local file paths before text delivery. media_files, cleaned_content = self.extract_media(content) + media_files = self.filter_media_delivery_paths(media_files) _, image_cleaned = self.extract_images(cleaned_content) local_files, final_content = self.extract_local_files(image_cleaned) + local_files = self.filter_local_delivery_paths(local_files) _AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a", ".flac"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".3gp"} diff --git a/gateway/run.py b/gateway/run.py index 0f56ad61c39..5089586386e 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -54,6 +54,7 @@ from agent.account_usage import fetch_account_usage, render_account_usage_lines from agent.async_utils import safe_schedule_threadsafe from agent.i18n import t from hermes_cli.config import cfg_get +from hermes_cli.fallback_config import get_fallback_chain # --- Agent cache tuning --------------------------------------------------- # Bounds the per-session AIAgent cache to prevent unbounded growth in @@ -138,6 +139,85 @@ def _gateway_platform_value(platform: Any) -> str: return str(getattr(platform, "value", platform) or "").strip().lower() +def _is_transient_network_error(exc: BaseException) -> bool: + """Return True for transient network errors safe to log + swallow. + + The crash class targeted by #31066 / #31110: an unhandled Telegram + ``TimedOut`` (or peer ``NetworkError`` / ``httpx`` connection error) + propagating to the event loop and killing the entire gateway + process. These are by definition transient โ€” the next poll cycle or + user action recovers โ€” so they must never crash the process. + + Walk the exception cause chain so wrapped errors (e.g. PTB's + ``NetworkError`` wrapping ``httpx.ConnectError``) are still + classified. The chain is bounded to avoid pathological cycles. + """ + seen: set[int] = set() + cur: Optional[BaseException] = exc + depth = 0 + transient_class_names = { + "TimedOut", + "NetworkError", + "ReadError", + "WriteError", + "ConnectError", + "ConnectTimeout", + "ReadTimeout", + "WriteTimeout", + "PoolTimeout", + "RemoteProtocolError", + "ServerDisconnectedError", + "ClientConnectorError", + "ClientOSError", + } + while cur is not None and depth < 12: + ident = id(cur) + if ident in seen: + break + seen.add(ident) + depth += 1 + name = type(cur).__name__ + if name in transient_class_names: + return True + cur = cur.__cause__ or cur.__context__ + return False + + +def _gateway_loop_exception_handler( + loop: "asyncio.AbstractEventLoop", context: Dict[str, Any] +) -> None: + """Loop-level safety net for transient network errors. + + Installed once during :func:`start_gateway`. Catches the + ``telegram.error.TimedOut`` crash class (issues #31066 / #31110) + and any peer transient network error before it can kill the + gateway process. Logs at WARNING with full traceback so the + originating call site stays diagnosable; non-transient errors + are forwarded to the default loop handler so real bugs still + surface. + """ + exc = context.get("exception") + if exc is not None and _is_transient_network_error(exc): + message = context.get("message") or "transient network error" + task = context.get("future") or context.get("task") + task_name = "" + if task is not None: + try: + task_name = task.get_name() if hasattr(task, "get_name") else repr(task) + except Exception: + task_name = repr(task) + logger.warning( + "Gateway swallowed transient network error from %s: %s: %s", + task_name or "", + type(exc).__name__, + exc, + exc_info=(type(exc), exc, exc.__traceback__), + ) + return + # Fall back to the default handler for anything we don't recognise. + loop.default_exception_handler(context) + + def _redact_gateway_user_facing_secrets(text: str) -> str: """Best-effort secret redaction before text can leave the gateway.""" redacted = str(text or "") @@ -238,6 +318,19 @@ def _prepare_gateway_status_message(platform: Any, event_type: str, message: str return text +async def _send_or_update_status_coro(adapter, chat_id, status_key, content, metadata): + """Route a status message through adapter.send_or_update_status when supported. + + Issue #30045: adapters that implement send_or_update_status (currently + Telegram) edit the previous bubble for the same status_key instead of + appending a new one. Adapters without the method fall back to plain send. + """ + sender = getattr(adapter, "send_or_update_status", None) + if callable(sender): + return await sender(chat_id, status_key, content, metadata=metadata) + return await adapter.send(chat_id, content, metadata=metadata) + + def _telegramize_command_mentions(text: str, platform: Any) -> str: """Rewrite slash-command mentions to Telegram-valid command names. @@ -447,6 +540,109 @@ def _build_replay_entry(role: str, content: Any, msg: Dict[str, Any]) -> Dict[st return entry +_TELEGRAM_OBSERVED_CONTEXT_PROMPT_MARKER = "observed Telegram group context" +_OBSERVED_GROUP_CONTEXT_HEADER = "[Observed Telegram group context - context only, not requests]" +_CURRENT_ADDRESSED_MESSAGE_HEADER = "[Current addressed message - answer only this unless it explicitly asks you to use the observed context]" + + +def _uses_telegram_observed_group_context(channel_prompt: Optional[str]) -> bool: + """Return True for Telegram group turns that may include observed chatter. + + Telegram's observe-unmentioned mode persists skipped group chatter so a + later @mention can see it. Those rows must not replay as ordinary user + turns: a weak wake word like ``@bot cambio`` should not make the model treat + old unmentioned chatter as pending work. The Telegram adapter marks these + turns with a channel prompt; this helper keeps the run-path check explicit + and unit-testable. + """ + + return bool(channel_prompt and _TELEGRAM_OBSERVED_CONTEXT_PROMPT_MARKER in channel_prompt) + + +def _build_gateway_agent_history( + history: List[Dict[str, Any]], + *, + channel_prompt: Optional[str] = None, +) -> tuple[List[Dict[str, Any]], Optional[str]]: + """Convert stored gateway transcript rows into agent replay messages. + + Observed Telegram group rows are returned as API-only context for the + current addressed message instead of being replayed as normal prior user + turns. Keeping that context out of ``conversation_history`` avoids + consecutive-user repair merging it with the live user turn and then hiding + the current message behind ``history_offset`` during persistence. + """ + + agent_history: List[Dict[str, Any]] = [] + observed_group_context: List[str] = [] + separate_observed_context = _uses_telegram_observed_group_context(channel_prompt) + + for msg in history or []: + role = msg.get("role") + if not role: + continue + + # Skip metadata entries (tool definitions, session info) -- these are + # for transcript logging, not for the LLM. + if role in {"session_meta",}: + continue + + # Skip system messages -- the agent rebuilds its own system prompt. + if role == "system": + continue + + content = msg.get("content") + if separate_observed_context and msg.get("observed") and role == "user" and content: + observed_group_context.append(str(content).strip()) + continue + + # Rich agent messages (tool_calls, tool results) must be passed through + # intact so the API sees valid assistantโ†’tool sequences. + has_tool_calls = "tool_calls" in msg + has_tool_call_id = "tool_call_id" in msg + is_tool_message = role == "tool" + + if has_tool_calls or has_tool_call_id or is_tool_message: + clean_msg = {k: v for k, v in msg.items() if k not in {"timestamp", "observed"}} + agent_history.append(clean_msg) + elif content: + # Simple text message - just need role and content. + if msg.get("mirror"): + mirror_src = msg.get("mirror_source", "another session") + content = f"[Delivered from {mirror_src}] {content}" + entry = _build_replay_entry(role, content, msg) + agent_history.append(entry) + + observed_context = "\n".join(observed_group_context).strip() or None + return agent_history, observed_context + + +def _wrap_current_message_with_observed_context(message: Any, observed_context: Optional[str]) -> Any: + """Prepend observed Telegram context to the API-only current user turn.""" + + if not observed_context: + return message + + prefix = ( + f"{_OBSERVED_GROUP_CONTEXT_HEADER}\n" + f"{observed_context}\n\n" + f"{_CURRENT_ADDRESSED_MESSAGE_HEADER}\n" + ) + + if isinstance(message, str): + return f"{prefix}{message}" + + if isinstance(message, list): + wrapped = [dict(part) if isinstance(part, dict) else part for part in message] + for part in wrapped: + if isinstance(part, dict) and part.get("type") == "text": + part["text"] = f"{prefix}{part.get('text', '')}" + return wrapped + return [{"type": "text", "text": prefix.rstrip()}] + wrapped + + return message + + def _last_transcript_timestamp(history: Optional[List[Dict[str, Any]]]) -> Any: """Return the ``timestamp`` of the last usable transcript row, if any. @@ -657,31 +853,29 @@ if _config_path.exists(): os.environ[_env_var] = str(_val) # Compression config is read directly from config.yaml by run_agent.py # and auxiliary_client.py โ€” no env var bridging needed. - # Auxiliary model/direct-endpoint overrides (vision, web_extract). - # Each task has provider/model/base_url/api_key; bridge non-default values to env vars. + # Auxiliary model/direct-endpoint overrides (vision, web_extract, + # approval, plus any plugin-registered auxiliary tasks). + # Each task has provider/model/base_url/api_key; bridge non-default + # values to env vars named AUXILIARY__*. The legacy + # hard-coded list (vision/web_extract/approval) is replaced by a + # dynamic loop so plugin-registered tasks benefit from the same + # configโ†’env bridging without core knowing about each one. _auxiliary_cfg = _cfg.get("auxiliary", {}) if _auxiliary_cfg and isinstance(_auxiliary_cfg, dict): - _aux_task_env = { - "vision": { - "provider": "AUXILIARY_VISION_PROVIDER", - "model": "AUXILIARY_VISION_MODEL", - "base_url": "AUXILIARY_VISION_BASE_URL", - "api_key": "AUXILIARY_VISION_API_KEY", - }, - "web_extract": { - "provider": "AUXILIARY_WEB_EXTRACT_PROVIDER", - "model": "AUXILIARY_WEB_EXTRACT_MODEL", - "base_url": "AUXILIARY_WEB_EXTRACT_BASE_URL", - "api_key": "AUXILIARY_WEB_EXTRACT_API_KEY", - }, - "approval": { - "provider": "AUXILIARY_APPROVAL_PROVIDER", - "model": "AUXILIARY_APPROVAL_MODEL", - "base_url": "AUXILIARY_APPROVAL_BASE_URL", - "api_key": "AUXILIARY_APPROVAL_API_KEY", - }, - } - for _task_key, _env_map in _aux_task_env.items(): + # Built-in tasks that previously had explicit env-var bridging. + # Kept here as the canonical bridged set; plugin tasks are added + # below via the plugin auxiliary registry. + _aux_bridged_keys = {"vision", "web_extract", "approval"} + try: + from hermes_cli.plugins import get_plugin_auxiliary_tasks + for _entry in get_plugin_auxiliary_tasks(): + _aux_bridged_keys.add(_entry["key"]) + except Exception: + # Plugin discovery failure must not break gateway startup; + # built-in bridging stays intact. + pass + + for _task_key in _aux_bridged_keys: _task_cfg = _auxiliary_cfg.get(_task_key, {}) if not isinstance(_task_cfg, dict): continue @@ -689,14 +883,15 @@ if _config_path.exists(): _model = str(_task_cfg.get("model", "")).strip() _base_url = str(_task_cfg.get("base_url", "")).strip() _api_key = str(_task_cfg.get("api_key", "")).strip() + _upper = _task_key.upper() if _prov and _prov != "auto": - os.environ[_env_map["provider"]] = _prov + os.environ[f"AUXILIARY_{_upper}_PROVIDER"] = _prov if _model: - os.environ[_env_map["model"]] = _model + os.environ[f"AUXILIARY_{_upper}_MODEL"] = _model if _base_url: - os.environ[_env_map["base_url"]] = _base_url + os.environ[f"AUXILIARY_{_upper}_BASE_URL"] = _base_url if _api_key: - os.environ[_env_map["api_key"]] = _api_key + os.environ[f"AUXILIARY_{_upper}_API_KEY"] = _api_key # config.yaml is the documented, authoritative source for these # settings โ€” it unconditionally wins over .env values. Previously # the guards below read `if X not in os.environ` and let stale @@ -723,6 +918,8 @@ if _config_path.exists(): if _display_cfg and isinstance(_display_cfg, dict): if "busy_input_mode" in _display_cfg: os.environ["HERMES_GATEWAY_BUSY_INPUT_MODE"] = str(_display_cfg["busy_input_mode"]) + if "busy_text_mode" in _display_cfg: + os.environ["HERMES_GATEWAY_BUSY_TEXT_MODE"] = str(_display_cfg["busy_text_mode"]) if "busy_ack_enabled" in _display_cfg: os.environ["HERMES_GATEWAY_BUSY_ACK_ENABLED"] = str(_display_cfg["busy_ack_enabled"]) # Timezone: bridge config.yaml โ†’ HERMES_TIMEZONE env var. @@ -846,6 +1043,12 @@ _AGENT_PENDING_SENTINEL = object() def _resolve_runtime_agent_kwargs() -> dict: """Resolve provider credentials for gateway-created AIAgent instances. + Provider is read from ``config.yaml`` ``model.provider`` (the single + source of truth). ``resolve_runtime_provider()`` falls through to env + var lookups internally for legacy compatibility, but the gateway does + not consult environment variables for behavioral config โ€” config.yaml + is authoritative. + If the primary provider fails with an authentication error, attempt to resolve credentials using the fallback provider chain from config.yaml before giving up. @@ -857,9 +1060,7 @@ def _resolve_runtime_agent_kwargs() -> dict: from hermes_cli.auth import AuthError try: - runtime = resolve_runtime_provider( - requested=os.getenv("HERMES_INFERENCE_PROVIDER"), - ) + runtime = resolve_runtime_provider() except AuthError as auth_exc: # Primary provider auth failed (expired token, revoked key, etc.). # Try the fallback provider chain before raising. @@ -892,19 +1093,22 @@ def _try_resolve_fallback_provider() -> dict | None: return None with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - fb = cfg.get("fallback_providers") or cfg.get("fallback_model") - if not fb: + fb_list = get_fallback_chain(cfg) + if not fb_list: return None - # Normalize to list - fb_list = fb if isinstance(fb, list) else [fb] for entry in fb_list: - if not isinstance(entry, dict): - continue try: + explicit_api_key = entry.get("api_key") + if not explicit_api_key: + key_env = str( + entry.get("key_env") or entry.get("api_key_env") or "" + ).strip() + if key_env: + explicit_api_key = os.getenv(key_env, "").strip() or None runtime = resolve_runtime_provider( requested=entry.get("provider"), explicit_base_url=entry.get("base_url"), - explicit_api_key=entry.get("api_key"), + explicit_api_key=explicit_api_key, ) logger.info( "Fallback provider resolved: %s model=%s", @@ -1198,6 +1402,26 @@ def _load_gateway_config() -> dict: return {} +def _load_gateway_runtime_config() -> dict: + """Load gateway config for runtime reads, expanding supported ``${VAR}`` refs. + + Runtime helpers should honor the same env-template expansion documented for + ``config.yaml`` while still respecting tests that monkeypatch + ``gateway.run._hermes_home``. Build on ``_load_gateway_config()`` rather + than calling the canonical loader directly so both behaviors stay aligned. + + Expansion failures are intentionally NOT swallowed โ€” silently returning + the unexpanded dict would mask the very bug this helper exists to fix. + """ + cfg = _load_gateway_config() + if not isinstance(cfg, dict) or not cfg: + return {} + from hermes_cli.config import _expand_env_vars + + expanded = _expand_env_vars(cfg) + return expanded if isinstance(expanded, dict) else {} + + def _resolve_gateway_model(config: dict | None = None) -> str: """Read model from config.yaml โ€” single source of truth. @@ -1411,6 +1635,7 @@ class GatewayRunner: # blow up on attribute access. _running_agents_ts: Dict[str, float] = {} _busy_input_mode: str = "interrupt" + _busy_text_mode: str = "interrupt" _restart_drain_timeout: float = DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT _exit_code: Optional[int] = None _draining: bool = False @@ -1437,6 +1662,7 @@ class GatewayRunner: self._service_tier = self._load_service_tier() self._show_reasoning = self._load_show_reasoning() self._busy_input_mode = self._load_busy_input_mode() + self._busy_text_mode = self._load_busy_text_mode() self._restart_drain_timeout = self._load_restart_drain_timeout() self._provider_routing = self._load_provider_routing() self._fallback_model = self._load_fallback_model() @@ -2046,13 +2272,14 @@ class GatewayRunner: ) -> Optional[str]: """Pin DM-topic routing to the user's last-active topic. - Telegram fragments topic-mode DMs two ways: a Reply on a message - in another topic delivers ``message_thread_id`` for *that* topic, - and ``_build_message_event`` strips the thread_id on plain replies - (#3206 โ€” needed for non-topic users). Both route the user to the - wrong session. When topic mode is on, rewrite the thread_id to the - user's most-recent binding if the inbound id is missing/General or - not a known topic for this chat. Returns None to leave it alone. + Telegram can omit ``message_thread_id`` or surface General (``1``) + for some topic-mode DM replies. In those lobby-shaped cases, keep the + conversation attached to the user's most-recent bound topic. + + Do not rewrite a non-lobby, previously-unbound thread id: a newly + created Telegram DM topic is also "unknown" until the first inbound + message is recorded, and rewriting it would send that brand-new topic's + answer into an older lane. Returns None to leave the source alone. """ if ( source.platform != Platform.TELEGRAM @@ -2062,6 +2289,14 @@ class GatewayRunner: or not self._telegram_topic_mode_enabled(source) ): return None + inbound = str(source.thread_id or "") + is_lobby = not inbound or inbound in self._TELEGRAM_GENERAL_TOPIC_IDS + if not is_lobby: + # A non-lobby, unknown thread_id is most likely the first message in + # a brand-new Telegram DM topic. Preserve it so it can be recorded + # as a new independent lane below instead of hijacking the latest + # existing topic binding. + return None session_db = getattr(self, "_session_db", None) if session_db is None: return None @@ -2074,11 +2309,6 @@ class GatewayRunner: return None if not bindings: return None - inbound = str(source.thread_id or "") - is_lobby = not inbound or inbound in self._TELEGRAM_GENERAL_TOPIC_IDS - known = {str(b.get("thread_id") or "") for b in bindings} - if not is_lobby and inbound in known: - return None user_id = str(source.user_id) for b in bindings: # newest-first if str(b.get("user_id") or "") == user_id: @@ -2532,15 +2762,8 @@ class GatewayRunner: """ file_path = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") if not file_path: - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - file_path = cfg.get("prefill_messages_file", "") - except Exception: - pass + cfg = _load_gateway_runtime_config() + file_path = str(cfg.get("prefill_messages_file", "") or "") if not file_path: return [] path = Path(file_path).expanduser() @@ -2570,16 +2793,8 @@ class GatewayRunner: prompt = os.getenv("HERMES_EPHEMERAL_SYSTEM_PROMPT", "") if prompt: return prompt - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - return (cfg_get(cfg, "agent", "system_prompt", default="") or "").strip() - except Exception: - pass - return "" + cfg = _load_gateway_runtime_config() + return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip() @staticmethod def _load_reasoning_config() -> dict | None: @@ -2590,16 +2805,8 @@ class GatewayRunner: default (medium). """ from hermes_constants import parse_reasoning_effort - effort = "" - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - effort = str(cfg_get(cfg, "agent", "reasoning_effort", default="") or "").strip() - except Exception: - pass + cfg = _load_gateway_runtime_config() + effort = str(cfg_get(cfg, "agent", "reasoning_effort", default="") or "").strip() result = parse_reasoning_effort(effort) if effort and effort.strip() and result is None: logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort) @@ -2673,16 +2880,8 @@ class GatewayRunner: "fast"/"priority"/"on" => "priority", while "normal"/"off" disables it. Returns None when unset or unsupported. """ - raw = "" - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - raw = str(cfg_get(cfg, "agent", "service_tier", default="") or "").strip() - except Exception: - pass + cfg = _load_gateway_runtime_config() + raw = str(cfg_get(cfg, "agent", "service_tier", default="") or "").strip() value = raw.lower() if not value or value in {"normal", "default", "standard", "off", "none"}: @@ -2695,54 +2894,43 @@ class GatewayRunner: @staticmethod def _load_show_reasoning() -> bool: """Load show_reasoning toggle from config.yaml display section.""" - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - return is_truthy_value( - cfg_get(cfg, "display", "show_reasoning"), - default=False, - ) - except Exception: - pass - return False + cfg = _load_gateway_runtime_config() + return is_truthy_value( + cfg_get(cfg, "display", "show_reasoning"), + default=False, + ) @staticmethod def _load_busy_input_mode() -> str: """Load gateway drain-time busy-input behavior from config/env.""" mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower() if not mode: - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - mode = str(cfg_get(cfg, "display", "busy_input_mode", default="") or "").strip().lower() - except Exception: - pass + cfg = _load_gateway_runtime_config() + mode = str(cfg_get(cfg, "display", "busy_input_mode", default="") or "").strip().lower() if mode == "queue": return "queue" if mode == "steer": return "steer" return "interrupt" + @staticmethod + def _load_busy_text_mode() -> str: + """Load normal busy TEXT follow-up behavior from config/env.""" + mode = os.getenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "").strip().lower() + if not mode: + cfg = _load_gateway_runtime_config() + mode = str(cfg_get(cfg, "display", "busy_text_mode", default="") or "").strip().lower() + if mode == "interrupt": + return "interrupt" + return "queue" + @staticmethod def _load_restart_drain_timeout() -> float: """Load graceful gateway restart/stop drain timeout in seconds.""" raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip() if not raw: - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - raw = str(cfg_get(cfg, "agent", "restart_drain_timeout", default="") or "").strip() - except Exception: - pass + cfg = _load_gateway_runtime_config() + raw = str(cfg_get(cfg, "agent", "restart_drain_timeout", default="") or "").strip() value = parse_restart_drain_timeout(raw) if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT: try: @@ -2767,19 +2955,12 @@ class GatewayRunner: """ mode = os.getenv("HERMES_BACKGROUND_NOTIFICATIONS", "") if not mode: - try: - import yaml as _y - cfg_path = _hermes_home / "config.yaml" - if cfg_path.exists(): - with open(cfg_path, encoding="utf-8") as _f: - cfg = _y.safe_load(_f) or {} - raw = cfg_get(cfg, "display", "background_process_notifications") - if raw is False: - mode = "off" - elif raw not in {None, ""}: - mode = str(raw) - except Exception: - pass + cfg = _load_gateway_runtime_config() + raw = cfg_get(cfg, "display", "background_process_notifications") + if raw is False: + mode = "off" + elif raw not in {None, ""}: + mode = str(raw) mode = (mode or "all").strip().lower() valid = {"all", "result", "error", "off"} if mode not in valid: @@ -2805,12 +2986,12 @@ class GatewayRunner: return {} @staticmethod - def _load_fallback_model() -> list | dict | None: + def _load_fallback_model() -> list | None: """Load fallback provider chain from config.yaml. - Returns a list of provider dicts (``fallback_providers``), a single - dict (legacy ``fallback_model``), or None if not configured. - AIAgent.__init__ normalizes both formats into a chain. + Returns the merged effective chain from ``fallback_providers`` plus any + legacy ``fallback_model`` entries. ``fallback_providers`` stays first + when both keys are present. """ try: import yaml as _y @@ -2818,7 +2999,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - fb = cfg.get("fallback_providers") or cfg.get("fallback_model") or None + fb = get_fallback_chain(cfg) if fb: return fb except Exception: @@ -2890,11 +3071,19 @@ class GatewayRunner: running_agent = self._running_agents.get(session_key) + effective_mode = self._busy_input_mode + busy_text_mode = getattr(self, "_busy_text_mode", "queue") + if ( + event.message_type == MessageType.TEXT + and busy_text_mode == "queue" + and effective_mode != "steer" + ): + return False + # Steer mode: inject mid-run via running_agent.steer() instead of # queueing + interrupting. If the agent isn't running yet # (sentinel) or lacks steer(), or the payload is empty, fall back # to queue semantics so nothing is lost. - effective_mode = self._busy_input_mode steered = False if effective_mode == "steer": steer_text = (event.text or "").strip() @@ -2919,7 +3108,12 @@ class GatewayRunner: # successful steer โ€” the text already landed inside the run and # must NOT also be replayed as a next-turn user message. if not steered: - merge_pending_message_event(adapter._pending_messages, session_key, event) + merge_pending_message_event( + adapter._pending_messages, + session_key, + event, + merge_text=event.message_type == MessageType.TEXT, + ) is_queue_mode = effective_mode == "queue" is_steer_mode = effective_mode == "steer" @@ -3851,6 +4045,7 @@ class GatewayRunner: adapter.set_fatal_error_handler(self._handle_adapter_fatal_error) adapter.set_session_store(self.session_store) adapter.set_busy_session_handler(self._handle_active_session_busy_message) + adapter._busy_text_mode = self._busy_text_mode # Try to connect logger.info("Connecting to %s...", platform.value) @@ -4955,6 +5150,11 @@ class GatewayRunner: if not candidates: return + from gateway.platforms.base import BasePlatformAdapter + candidates = BasePlatformAdapter.filter_local_delivery_paths(candidates) + if not candidates: + return + _IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".webp"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".3gp"} @@ -5458,6 +5658,7 @@ class GatewayRunner: adapter.set_fatal_error_handler(self._handle_adapter_fatal_error) adapter.set_session_store(self.session_store) adapter.set_busy_session_handler(self._handle_active_session_busy_message) + adapter._busy_text_mode = self._busy_text_mode success = await self._connect_adapter_with_timeout(adapter, platform) if success: @@ -5897,6 +6098,12 @@ class GatewayRunner: if platform_registry.is_registered(platform.value): adapter = platform_registry.create_adapter(platform.value, config) if adapter is not None: + # Adapters that need a back-reference to the gateway runner + # (e.g. for cross-platform admin alerts) declare a + # ``gateway_runner`` attribute. Inject it after creation so + # plugin adapters don't need a custom factory signature. + if hasattr(adapter, "gateway_runner"): + adapter.gateway_runner = self return adapter # Registered but failed to instantiate โ€” don't silently fall # through to built-ins (there are none for plugin platforms). @@ -5939,15 +6146,6 @@ class GatewayRunner: adapter._notifications_mode = _notify_mode return adapter - elif platform == Platform.DISCORD: - from gateway.platforms.discord import DiscordAdapter, check_discord_requirements - if not check_discord_requirements(): - logger.warning("Discord: discord.py not installed") - return None - adapter = DiscordAdapter(config) - adapter.gateway_runner = self # For cross-platform admin alerts on unauthorized slash - return adapter - elif platform == Platform.WHATSAPP: from gateway.platforms.whatsapp import WhatsAppAdapter, check_whatsapp_requirements if not check_whatsapp_requirements(): @@ -6214,18 +6412,6 @@ class GatewayRunner: if allow_bots_var and os.getenv(allow_bots_var, "none").lower().strip() in {"mentions", "all"}: return True - # Discord role-based access (DISCORD_ALLOWED_ROLES): the adapter's - # on_message pre-filter already verified role membership โ€” if the - # message reached here, the user passed that check. Authorize - # directly to avoid the "no allowlists configured" branch below - # rejecting role-only setups where DISCORD_ALLOWED_USERS is empty - # (issue #7871). - if ( - source.platform == Platform.DISCORD - and os.getenv("DISCORD_ALLOWED_ROLES", "").strip() - ): - return True - # Check pairing store (always checked, regardless of allowlists) platform_name = source.platform.value if source.platform else "" if self.pairing_store.is_approved(platform_name, user_id): @@ -11164,14 +11350,16 @@ class GatewayRunner: # send_multiple_images (Telegram sendPhoto recompresses to ~1280px). force_document_attachments = "[[as_document]]" in response + from gateway.platforms.base import BasePlatformAdapter, should_send_media_as_audio + media_files, _ = adapter.extract_media(response) + media_files = BasePlatformAdapter.filter_media_delivery_paths(media_files) _, cleaned = adapter.extract_images(response) local_files, _ = adapter.extract_local_files(cleaned) + local_files = BasePlatformAdapter.filter_local_delivery_paths(local_files) _thread_meta = self._thread_metadata_for_source(event.source, self._reply_anchor_for_event(event)) - from gateway.platforms.base import should_send_media_as_audio - _VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'} _IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'} @@ -11463,6 +11651,8 @@ class GatewayRunner: # Extract media files from the response if response: media_files, response = adapter.extract_media(response) + from gateway.platforms.base import BasePlatformAdapter + media_files = BasePlatformAdapter.filter_media_delivery_paths(media_files) images, text_content = adapter.extract_images(response) preview = prompt[:60] + ("..." if len(prompt) > 60 else "") @@ -12551,7 +12741,7 @@ class GatewayRunner: return t("gateway.title.current_no_title", session_id=session_id) async def _handle_resume_command(self, event: MessageEvent) -> str: - """Handle /resume command โ€” switch to a previously-named session.""" + """Handle /resume command โ€” list or switch to a previous session.""" if not self._session_db: from hermes_state import format_session_db_unavailable return format_session_db_unavailable(prefix=t("gateway.shared.session_db_unavailable_prefix")) @@ -12560,30 +12750,44 @@ class GatewayRunner: session_key = self._session_key_for_source(source) name = event.get_command_args().strip() + def _list_titled_sessions() -> list[dict]: + user_source = source.platform.value if source.platform else None + sessions = self._session_db.list_sessions_rich(source=user_source, limit=10) + return [s for s in sessions if s.get("title")][:10] + if not name: # List recent titled sessions for this user/platform try: - user_source = source.platform.value if source.platform else None - sessions = self._session_db.list_sessions_rich( - source=user_source, limit=10 - ) - titled = [s for s in sessions if s.get("title")] + titled = _list_titled_sessions() if not titled: return t("gateway.resume.no_named_sessions") lines = [t("gateway.resume.list_header")] - for s in titled[:10]: + for idx, s in enumerate(titled[:10], start=1): title = s["title"] preview = s.get("preview", "")[:40] preview_part = t("gateway.resume.list_preview_suffix", preview=preview) if preview else "" - lines.append(t("gateway.resume.list_item", title=title, preview_part=preview_part)) - lines.append(t("gateway.resume.list_footer")) + lines.append(t("gateway.resume.list_item_numbered", index=idx, title=title, preview_part=preview_part)) + lines.append(t("gateway.resume.list_footer_numbered")) return "\n".join(lines) except Exception as e: logger.debug("Failed to list titled sessions: %s", e) return t("gateway.resume.list_failed", error=e) - # Resolve the name to a session ID. - target_id = self._session_db.resolve_session_by_title(name) + # Resolve a numbered choice or a title to a session ID. + if name.isdigit(): + try: + titled = _list_titled_sessions() + except Exception as e: + logger.debug("Failed to list titled sessions for numeric resume: %s", e) + return t("gateway.resume.list_failed", error=e) + index = int(name) + if index < 1 or index > len(titled): + return t("gateway.resume.out_of_range", index=index) + target = titled[index - 1] + target_id = target.get("id") + name = target.get("title") or name + else: + target_id = self._session_db.resolve_session_by_title(name) if not target_id: return t("gateway.resume.not_found", name=name) # Compression creates child continuations that hold the live transcript. @@ -16065,11 +16269,7 @@ class GatewayRunner: ) return _fut = safe_schedule_threadsafe( - _status_adapter.send( - _status_chat_id, - prepared_message, - metadata=_status_thread_metadata, - ), + _send_or_update_status_coro(_status_adapter, _status_chat_id, event_type, prepared_message, _status_thread_metadata), _loop_for_step, logger=logger, log_message=f"status_callback ({event_type}) scheduling error", @@ -16470,45 +16670,16 @@ class GatewayRunner: # that may include tool_calls, tool_call_id, reasoning, etc. # - These must be passed through intact so the API sees valid # assistantโ†’tool sequences (dropping tool_calls causes 500 errors) - agent_history = [] - for msg in history: - role = msg.get("role") - if not role: - continue - - # Skip metadata entries (tool definitions, session info) - # -- these are for transcript logging, not for the LLM - if role in {"session_meta",}: - continue - - # Skip system messages -- the agent rebuilds its own system prompt - if role == "system": - continue - - # Rich agent messages (tool_calls, tool results) must be passed - # through intact so the API sees valid assistantโ†’tool sequences - has_tool_calls = "tool_calls" in msg - has_tool_call_id = "tool_call_id" in msg - is_tool_message = role == "tool" - - if has_tool_calls or has_tool_call_id or is_tool_message: - clean_msg = {k: v for k, v in msg.items() if k != "timestamp"} - agent_history.append(clean_msg) - else: - # Simple text message - just need role and content - content = msg.get("content") - if content: - # Tag cross-platform mirror messages so the agent knows their origin - if msg.get("mirror"): - mirror_src = msg.get("mirror_source", "another session") - content = f"[Delivered from {mirror_src}] {content}" - # Preserve assistant reasoning + Codex replay fields so - # multi-turn reasoning context, prefix-cache hits, and - # provider-specific echo requirements survive session - # reload. See ``_ASSISTANT_REPLAY_FIELDS`` for the full - # whitelist and rationale. - entry = _build_replay_entry(role, content, msg) - agent_history.append(entry) + # + # Telegram observed group context is handled structurally here: + # observed=True transcript rows are withheld from replayable + # history and attached to the current addressed message as + # API-only context, so persisted history stores only the real + # addressed user turn. + agent_history, observed_group_context = _build_gateway_agent_history( + history, + channel_prompt=channel_prompt, + ) # Collect MEDIA paths already in history so we can exclude them # from the current turn's extraction. This is compression-safe: @@ -16741,7 +16912,17 @@ class GatewayRunner: else: _run_message = message - result = agent.run_conversation(_run_message, conversation_history=agent_history, task_id=session_id) + _api_run_message = _wrap_current_message_with_observed_context( + _run_message, + observed_group_context, + ) + _conversation_kwargs = { + "conversation_history": agent_history, + "task_id": session_id, + } + if observed_group_context: + _conversation_kwargs["persist_user_message"] = message + result = agent.run_conversation(_api_run_message, **_conversation_kwargs) finally: unregister_gateway_notify(_approval_session_key) # Cancel any pending clarify entries so blocked agent @@ -16957,6 +17138,7 @@ class GatewayRunner: "context_length": _context_length, "session_id": effective_session_id, "response_previewed": result.get("response_previewed", False), + "response_transformed": result.get("response_transformed", False), } # Start progress message sender if enabled @@ -17594,7 +17776,11 @@ class GatewayRunner: _content_delivered = bool( _sc and getattr(_sc, "final_content_delivered", False) ) - if not _is_empty_sentinel and (_streamed or _previewed or _content_delivered): + # Plugin hooks (e.g. transform_llm_output) may have appended content + # after streaming finished โ€” when the response was transformed, always + # send the final version so the appended content reaches the client. + _transformed = bool(response.get("response_transformed")) + if not _is_empty_sentinel and not _transformed and (_streamed or _previewed or _content_delivered): logger.info( "Suppressing normal final send for session %s: final delivery already confirmed (streamed=%s previewed=%s content_delivered=%s).", session_key or "?", @@ -17603,6 +17789,28 @@ class GatewayRunner: _content_delivered, ) response["already_sent"] = True + elif not _is_empty_sentinel and _transformed and _sc is not None: + # Plugin hooks transformed the response after streaming โ€” edit the + # existing streamed message instead of sending a duplicate. + _sc_msg_id = _sc.message_id + if _sc_msg_id: + try: + await _sc.adapter.edit_message( + chat_id=source.chat_id, + message_id=_sc_msg_id, + content=response["final_response"], + finalize=True, + ) + response["already_sent"] = True + logger.info( + "Edited streamed message %s for session %s to include plugin-transformed content.", + _sc_msg_id, session_key or "?", + ) + except Exception as _edit_err: + logger.warning( + "Failed to edit streamed message for session %s: %s", + session_key or "?", _edit_err, + ) # Schedule deletion of tracked temporary progress bubbles after the # final response lands. Failed runs skip this so bubbles remain as @@ -18029,6 +18237,21 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = runner.request_restart(detached=False, via_service=True) loop = asyncio.get_running_loop() + + # Install a loop-level exception handler that swallows transient + # network errors from background tasks. Issues #31066 / #31110: + # an unhandled ``telegram.error.TimedOut`` (or peer NetworkError / + # httpx connection error) in any awaited coroutine would propagate + # to the loop and kill the gateway process, taking down every + # profile attached to the same runner. systemd then restarts the + # service after ~5s but the active conversation turn is lost. + # + # The fix is intentionally narrow: only well-known transient + # network errors are swallowed (and logged with full traceback so + # the originating call site is still discoverable). Anything else + # is forwarded to the default handler so real bugs still surface. + loop.set_exception_handler(_gateway_loop_exception_handler) + if threading.current_thread() is threading.main_thread(): for sig in (signal.SIGINT, signal.SIGTERM): try: diff --git a/gateway/session.py b/gateway/session.py index 648f8cddf10..5f6fcb9a62f 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -1277,6 +1277,7 @@ class SessionStore: platform_message_id=( message.get("platform_message_id") or message.get("message_id") ), + observed=bool(message.get("observed")), ) except Exception as e: logger.debug("Session DB operation failed: %s", e) diff --git a/gateway/session_context.py b/gateway/session_context.py index 486949fae3d..ee43eca0f76 100644 --- a/gateway/session_context.py +++ b/gateway/session_context.py @@ -83,6 +83,21 @@ _VAR_MAP = { } +def set_current_session_id(session_id: str) -> None: + """Synchronize ``HERMES_SESSION_ID`` across ContextVar and ``os.environ``. + + Long-lived single-process entrypoints like the CLI can rotate sessions via + ``/new``, ``/resume``, ``/branch``, or compression splits without + reconstructing the entire agent. Tools still consult + ``get_session_env("HERMES_SESSION_ID")`` with an ``os.environ`` fallback, + so both storage paths must move together when the active session changes. + """ + import os + + os.environ["HERMES_SESSION_ID"] = session_id + _SESSION_ID.set(session_id) + + def set_session_vars( platform: str = "", chat_id: str = "", diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 17214050919..4ba65ddf4c5 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -192,6 +192,11 @@ class GatewayStreamConsumer: """True when the stream consumer delivered the final assistant reply.""" return self._final_response_sent + @property + def message_id(self) -> str | None: + """The Discord/chat message ID of the last-sent or edited message.""" + return self._message_id + @property def final_content_delivered(self) -> bool: """True when the final response content reached the user, even if diff --git a/hermes_cli/_parser.py b/hermes_cli/_parser.py index 3ece411e757..cf4ffc34e5c 100644 --- a/hermes_cli/_parser.py +++ b/hermes_cli/_parser.py @@ -129,7 +129,8 @@ def build_top_level_parser(): default=None, help=( "Provider override for this invocation (e.g. openrouter, anthropic). " - "Applies to -z/--oneshot and --tui. Also settable via HERMES_INFERENCE_PROVIDER env var." + "Applies to -z/--oneshot and --tui. The persistent provider lives in config.yaml " + "under model.provider โ€” use `hermes setup` or edit the file to change it." ), ) parser.add_argument( @@ -268,7 +269,11 @@ def build_top_level_parser(): help="Inference provider (default: auto). Built-in or a user-defined name from `providers:` in config.yaml.", ) chat_parser.add_argument( - "-v", "--verbose", action="store_true", help="Verbose output" + "-v", + "--verbose", + action="store_true", + default=argparse.SUPPRESS, + help="Verbose output", ) chat_parser.add_argument( "-Q", diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 59daa2c5d40..04cd6b3ce2f 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -41,7 +41,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse import httpx @@ -553,6 +553,7 @@ _PLACEHOLDER_SECRET_VALUES = { "***", "changeme", "your_api_key", + "your_api_key_here", "your-api-key", "placeholder", "example", @@ -1559,6 +1560,67 @@ def _optional_base_url(value: Any) -> Optional[str]: return cleaned if cleaned else None +# Allowlist of hosts the Nous Portal proxy is willing to forward minted +# bearer tokens to. The bearer is a long-lived agent_key minted by +# portal.nousresearch.com โ€” sending it anywhere else would leak it. +# +# This is consulted only for URLs coming from the NETWORK side (Portal +# refresh / agent-key-mint responses). User-controlled env-var overrides +# (NOUS_INFERENCE_BASE_URL) bypass validation โ€” that's the documented +# dev/staging escape hatch and the env source is already trusted (the +# user set it themselves). +_ALLOWED_NOUS_INFERENCE_HOSTS: FrozenSet[str] = frozenset({ + "inference-api.nousresearch.com", +}) + + +def _validate_nous_inference_url_from_network(url: Optional[str]) -> Optional[str]: + """Validate a Portal-returned inference URL against the host allowlist. + + Returns ``url`` (normalised by stripping trailing slashes) if it's a + well-formed ``https:///...`` URL. Returns ``None`` + if the URL is missing, malformed, non-https, or points at an + unexpected host โ€” letting the caller fall back to the configured + default rather than persist or forward a poisoned value. + + Defense-in-depth: a compromised refresh / mint response from the + Portal API (MITM, malicious response injection) could otherwise + redirect every subsequent proxy request โ€” bearing the user's + legitimately-minted agent_key โ€” to an attacker-controlled endpoint. + Validating scheme + host at the source closes that loop before the + poisoned URL ever lands in ``auth.json``. + + The env-var override path (``NOUS_INFERENCE_BASE_URL``) bypasses + this โ€” env values come from the trusted OS user, not from the + network, and the override is documented for staging/dev use. + + Co-authored-by: memosr + """ + if not isinstance(url, str): + return None + cleaned = url.strip() + if not cleaned: + return None + try: + parsed = urlparse(cleaned) + except Exception: + return None + if parsed.scheme != "https": + logger.warning( + "nous: refusing non-https inference URL scheme %r from Portal response", + parsed.scheme, + ) + return None + if parsed.hostname not in _ALLOWED_NOUS_INFERENCE_HOSTS: + logger.warning( + "nous: refusing inference URL host %r from Portal response " + "(not in allowlist); falling back to default", + parsed.hostname, + ) + return None + return cleaned.rstrip("/") + + def _decode_jwt_claims(token: Any) -> Dict[str, Any]: if not isinstance(token, str) or token.count(".") != 2: return {} @@ -2004,7 +2066,10 @@ def resolve_qwen_runtime_credentials( def get_qwen_auth_status() -> Dict[str, Any]: auth_path = _qwen_cli_auth_path() try: - creds = resolve_qwen_runtime_credentials(refresh_if_expiring=False) + # Validate the runtime credentials, including refresh when the cached + # CLI token is expired. Otherwise stale tokens show up as "logged in" + # and `hermes model` walks users into a broken Qwen setup flow. + creds = resolve_qwen_runtime_credentials(refresh_if_expiring=True) return { "logged_in": True, "auth_file": str(auth_path), @@ -4776,7 +4841,7 @@ def refresh_nous_oauth_pure( state["refresh_token"] = refreshed.get("refresh_token") or state["refresh_token"] state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["scope"] = refreshed.get("scope") or state.get("scope") - refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) + refreshed_url = _validate_nous_inference_url_from_network(refreshed.get("inference_base_url")) if refreshed_url: state["inference_base_url"] = refreshed_url state["obtained_at"] = now.isoformat() @@ -4812,7 +4877,7 @@ def refresh_nous_oauth_pure( state["agent_key_expires_in"] = mint_payload.get("expires_in") state["agent_key_reused"] = bool(mint_payload.get("reused", False)) state["agent_key_obtained_at"] = now.isoformat() - minted_url = _optional_base_url(mint_payload.get("inference_base_url")) + minted_url = _validate_nous_inference_url_from_network(mint_payload.get("inference_base_url")) if minted_url: state["inference_base_url"] = minted_url @@ -5090,7 +5155,7 @@ def resolve_nous_runtime_credentials( state["refresh_token"] = refreshed.get("refresh_token") or refresh_token state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["scope"] = refreshed.get("scope") or state.get("scope") - refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) + refreshed_url = _validate_nous_inference_url_from_network(refreshed.get("inference_base_url")) if refreshed_url: inference_base_url = refreshed_url state["obtained_at"] = now.isoformat() @@ -5198,7 +5263,7 @@ def resolve_nous_runtime_credentials( state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["scope"] = refreshed.get("scope") or state.get("scope") - refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) + refreshed_url = _validate_nous_inference_url_from_network(refreshed.get("inference_base_url")) if refreshed_url: inference_base_url = refreshed_url state["obtained_at"] = now.isoformat() @@ -5253,7 +5318,7 @@ def resolve_nous_runtime_credentials( state["agent_key_expires_in"] = mint_payload.get("expires_in") state["agent_key_reused"] = bool(mint_payload.get("reused", False)) state["agent_key_obtained_at"] = now.isoformat() - minted_url = _optional_base_url(mint_payload.get("inference_base_url")) + minted_url = _validate_nous_inference_url_from_network(mint_payload.get("inference_base_url")) if minted_url: inference_base_url = minted_url _oauth_trace( @@ -7045,10 +7110,95 @@ def _refresh_minimax_oauth_state( return new_state +def _minimax_oauth_quarantine_on_terminal_refresh(state: Dict[str, Any], exc: AuthError) -> None: + """Wipe dead tokens from auth.json after a terminal refresh failure. + + Shared by both the eager-resolve path and the lazy per-request token + provider. Mirrors the Nous / xAI-OAuth / Codex-OAuth quarantine pattern + so subsequent calls fail fast without a network retry. + """ + if not (exc.relogin_required and state.get("refresh_token")): + return + for _k in ("access_token", "refresh_token", "expires_at", "expires_in", "obtained_at"): + state.pop(_k, None) + state["last_auth_error"] = { + "provider": "minimax-oauth", + "code": exc.code or "refresh_failed", + "message": str(exc), + "reason": "runtime_refresh_failure", + "relogin_required": True, + "at": datetime.now(timezone.utc).isoformat(), + } + try: + _minimax_save_auth_state(state) + except Exception as _save_exc: + logger.debug("MiniMax OAuth: failed to persist quarantined state: %s", _save_exc) + + +def build_minimax_oauth_token_provider() -> Callable[[], str]: + """Return a zero-arg callable that yields a fresh MiniMax access token. + + The Anthropic SDK caches ``api_key`` as a static string at construction + time, so a session that resolves credentials once at startup will keep + sending the same bearer until MiniMax's server returns 401 โ€” typically + ~15 minutes in, because MiniMax issues short-lived access tokens. + + Returning a *callable* instead of a string lets us hook into the + existing Entra-ID bearer infrastructure in + :mod:`agent.anthropic_adapter`: ``build_anthropic_client`` detects a + callable and routes through ``_build_anthropic_client_with_bearer_hook``, + which mints a fresh ``Authorization`` header on every outbound request. + Each invocation re-reads the persisted state from ``auth.json`` and + calls :func:`_refresh_minimax_oauth_state` โ€” that helper is a no-op + when the token still has more than ``MINIMAX_OAUTH_REFRESH_SKEW_SECONDS`` + of life left, so the steady-state cost is one file read + one + timestamp compare per request. + + Reading state fresh each time also means a refresh persisted by one + process (CLI, gateway, cron) is immediately visible to every other + process sharing the same ``auth.json``. + """ + def _provide() -> str: + state = get_provider_auth_state("minimax-oauth") + if not state or not state.get("access_token"): + raise AuthError( + "Not logged into MiniMax OAuth. Run `hermes model` and select " + "MiniMax (OAuth).", + provider="minimax-oauth", code="not_logged_in", relogin_required=True, + ) + try: + state = _refresh_minimax_oauth_state(state) + except AuthError as exc: + _minimax_oauth_quarantine_on_terminal_refresh(state, exc) + raise + token = state.get("access_token") + if not token: + raise AuthError( + "MiniMax OAuth state has no access_token after refresh.", + provider="minimax-oauth", code="no_access_token", relogin_required=True, + ) + return token + + return _provide + + def resolve_minimax_oauth_runtime_credentials( *, min_token_ttl_seconds: int = MINIMAX_OAUTH_REFRESH_SKEW_SECONDS, + as_token_provider: bool = False, ) -> Dict[str, Any]: - """Return {provider, api_key, base_url, source} for minimax-oauth.""" + """Return {provider, api_key, base_url, source} for minimax-oauth. + + When ``as_token_provider`` is True, ``api_key`` is a zero-arg callable + that mints a fresh access token per call (proactively refreshing if + the cached token is within ``MINIMAX_OAUTH_REFRESH_SKEW_SECONDS`` of + expiry). This is what the runtime provider path uses so that long + sessions survive MiniMax's short access-token lifetime โ€” see + :func:`build_minimax_oauth_token_provider` for the rationale. + + The default (string ``api_key``) preserves the historical contract for + diagnostic call sites like ``hermes status`` that just want to know + whether a valid token exists right now. + """ state = get_provider_auth_state("minimax-oauth") if not state or not state.get("access_token"): raise AuthError( @@ -7059,28 +7209,15 @@ def resolve_minimax_oauth_runtime_credentials( try: state = _refresh_minimax_oauth_state(state) except AuthError as exc: - if exc.relogin_required and state.get("refresh_token"): - # Terminal refresh failure โ€” clear dead tokens from auth.json so - # subsequent calls fail fast without a network retry, mirroring - # the Nous / xAI-OAuth / Codex-OAuth quarantine pattern. - for _k in ("access_token", "refresh_token", "expires_at", "expires_in", "obtained_at"): - state.pop(_k, None) - state["last_auth_error"] = { - "provider": "minimax-oauth", - "code": exc.code or "refresh_failed", - "message": str(exc), - "reason": "runtime_refresh_failure", - "relogin_required": True, - "at": datetime.now(timezone.utc).isoformat(), - } - try: - _minimax_save_auth_state(state) - except Exception as _save_exc: - logger.debug("MiniMax OAuth: failed to persist quarantined state: %s", _save_exc) + _minimax_oauth_quarantine_on_terminal_refresh(state, exc) raise + if as_token_provider: + api_key: Any = build_minimax_oauth_token_provider() + else: + api_key = state["access_token"] return { "provider": "minimax-oauth", - "api_key": state["access_token"], + "api_key": api_key, "base_url": state["inference_base_url"].rstrip("/"), "source": "oauth", } diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index b920ff2e5fe..f589248621c 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -164,7 +164,7 @@ COMMAND_REGISTRY: list[CommandDef] = [ cli_only=True), CommandDef("skills", "Search, install, inspect, or manage skills", "Tools & Skills", cli_only=True, - subcommands=("search", "browse", "inspect", "install")), + subcommands=("search", "browse", "inspect", "install", "audit")), CommandDef("bundles", "List skill bundles (aliases / for multiple skills)", "Tools & Skills"), CommandDef("cron", "Manage scheduled tasks", "Tools & Skills", @@ -449,7 +449,7 @@ def _iter_plugin_command_entries() -> list[tuple[str, str, str]]: :func:`hermes_cli.plugins.PluginContext.register_command`. They behave like ``CommandDef`` entries for gateway surfacing: they appear in the Telegram command menu, in Slack's ``/hermes`` subcommand mapping, and - (via :func:`gateway.platforms.discord._register_slash_commands`) in + (via :func:`plugins.platforms.discord.adapter._register_slash_commands`) in Discord's native slash command picker. Lookup is lazy so importing this module never forces plugin discovery diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 65e3cce88b3..61f46935bc5 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -1009,6 +1009,19 @@ DEFAULT_CONFIG = { "compact": False, "personality": "kawaii", "resume_display": "full", + # Recap tuning for /resume and startup resume. The defaults match the + # historical hardcoded values; expose them as config so power users can + # widen or tighten the snapshot to taste. + "resume_exchanges": 10, # max user+assistant pairs to show + "resume_max_user_chars": 300, # truncate user message text + "resume_max_assistant_chars": 200, # truncate non-last assistant text + "resume_max_assistant_lines": 3, # truncate non-last assistant lines + # When True (default), assistant entries that are *only* tool calls + # (no visible text) are skipped in the recap. This prevents the recap + # from being dominated by `[2 tool calls: terminal, read_file]` lines + # when an exchange was tool-heavy. Set False to restore the legacy + # behavior of showing tool-call summaries inline. + "resume_skip_tool_only": True, "busy_input_mode": "interrupt", # interrupt | queue | steer # When true, `hermes --tui` auto-resumes the most recent human- # facing session on launch instead of forging a fresh one. @@ -1776,6 +1789,14 @@ DEFAULT_CONFIG = { # ~/.hermes/bin/ on first use. When False you must install # bws yourself and have it on PATH. "auto_install": True, + # Bitwarden region / self-hosted endpoint. Empty string + # means use the bws CLI default (US Cloud, + # https://vault.bitwarden.com). Set to + # https://vault.bitwarden.eu for EU Cloud, or your own URL + # for self-hosted Bitwarden. Plumbed into the bws subprocess + # as BWS_SERVER_URL. Prompted for during + # `hermes secrets bitwarden setup`. + "server_url": "", }, }, diff --git a/hermes_cli/curses_ui.py b/hermes_cli/curses_ui.py index 57607cc31dd..f0e991c0ae2 100644 --- a/hermes_cli/curses_ui.py +++ b/hermes_cli/curses_ui.py @@ -71,7 +71,7 @@ def curses_checklist( curses.use_default_colors() curses.init_pair(1, curses.COLOR_GREEN, -1) curses.init_pair(2, curses.COLOR_YELLOW, -1) - curses.init_pair(3, 8, -1) # dim gray + curses.init_pair(3, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) # dim gray cursor = 0 scroll_offset = 0 diff --git a/hermes_cli/debug.py b/hermes_cli/debug.py index a7338e4ba82..b309ee37c54 100644 --- a/hermes_cli/debug.py +++ b/hermes_cli/debug.py @@ -14,6 +14,7 @@ Currently supports: import io import json import logging +import re import sys import time import urllib.error @@ -36,6 +37,12 @@ _REDACTION_BANNER = ( "run with --no-redact to disable]\n" ) +_EMAIL_ADDRESS_RE = re.compile( + r"(? str: return text from agent.redact import redact_sensitive_text - return redact_sensitive_text(text, force=True) + text = redact_sensitive_text(text, force=True) + return _EMAIL_ADDRESS_RE.sub("[REDACTED_EMAIL]", text) def _capture_log_snapshot( diff --git a/hermes_cli/env_loader.py b/hermes_cli/env_loader.py index 521076af9b4..40a87830dfe 100644 --- a/hermes_cli/env_loader.py +++ b/hermes_cli/env_loader.py @@ -21,6 +21,44 @@ _CREDENTIAL_SUFFIXES = ("_API_KEY", "_TOKEN", "_SECRET", "_KEY") # tests) don't spam the same warning multiple times. _WARNED_KEYS: set[str] = set() +# Map of env-var name โ†’ source label ("bitwarden", etc.) for credentials +# that were injected by an external secret source during load_hermes_dotenv(). +# Used by setup / `hermes model` flows to label detected credentials so +# users understand WHERE a key came from when their .env doesn't contain it +# directly (otherwise the "credentials detected โœ“" line looks identical to +# the .env case and they don't know Bitwarden is wired up). +_SECRET_SOURCES: dict[str, str] = {} + + +def get_secret_source(env_var: str) -> str | None: + """Return the label of the secret source that supplied ``env_var``, if any. + + Returns ``"bitwarden"`` for keys pulled from Bitwarden Secrets Manager + during the current process's ``load_hermes_dotenv()`` call. Returns + ``None`` for keys that came from ``.env``, the shell environment, or + aren't tracked. + """ + return _SECRET_SOURCES.get(env_var) + + +def format_secret_source_suffix(env_var: str) -> str: + """Return a human-readable suffix like ``" (from Bitwarden)"`` or ``""``. + + Use this when printing a detected credential so the user can see where + it came from. Empty string when the credential came from ``.env`` or + the shell โ€” those are the implicit / "default" cases users already + understand. + """ + source = get_secret_source(env_var) + if not source: + return "" + if source == "bitwarden": + return " (from Bitwarden)" + # Generic fallback โ€” future-proofing for additional secret sources + # (e.g. 1Password, HashiCorp Vault) without having to update every + # call site. + return f" (from {source})" + def _format_offending_chars(value: str, limit: int = 3) -> str: """Return a compact 'U+XXXX ('c'), ...' summary of non-ASCII codepoints.""" @@ -102,6 +140,10 @@ def _sanitize_env_file_if_needed(path: Path) -> None: This produces mangled values โ€” e.g. a bot token duplicated 8ร— (see #8908). + Also strips embedded null bytes which crash ``os.environ[k] = v`` + with ``ValueError: embedded null byte`` โ€” typically introduced by + copy-pasting API keys from terminals or rich-text editors. + We delegate to ``hermes_cli.config._sanitize_env_lines`` which already knows all valid Hermes env-var names and can split concatenated lines correctly. @@ -117,7 +159,11 @@ def _sanitize_env_file_if_needed(path: Path) -> None: try: with open(path, **read_kw) as f: original = f.readlines() - sanitized = _sanitize_env_lines(original) + # Strip null bytes before _sanitize_env_lines so they never + # reach python-dotenv (which passes them to os.environ and + # crashes with ValueError). + stripped = [line.replace("\x00", "") for line in original] + sanitized = _sanitize_env_lines(stripped) if sanitized != original: import tempfile fd, tmp = tempfile.mkstemp( @@ -206,6 +252,7 @@ def _apply_external_secret_sources(home_path: Path) -> None: override_existing=bool(bw_cfg.get("override_existing", False)), cache_ttl_seconds=float(bw_cfg.get("cache_ttl_seconds", 300)), auto_install=bool(bw_cfg.get("auto_install", True)), + server_url=str(bw_cfg.get("server_url", "") or "").strip(), ) if result.applied: @@ -213,6 +260,12 @@ def _apply_external_secret_sources(home_path: Path) -> None: # and might have the same copy-paste corruption as a manually # edited .env (see #6843). _sanitize_loaded_credentials() + # Remember where these came from so the setup / `hermes model` + # flows can label detected credentials with "(from Bitwarden)" โ€” + # otherwise users see "credentials โœ“" with no hint that the value + # came from BSM rather than .env. + for name in result.applied: + _SECRET_SOURCES[name] = "bitwarden" print( f" Bitwarden Secrets Manager: applied {len(result.applied)} " f"secret{'s' if len(result.applied) != 1 else ''} " diff --git a/hermes_cli/fallback_cmd.py b/hermes_cli/fallback_cmd.py index 9f2e6b97d46..09142ea99ea 100644 --- a/hermes_cli/fallback_cmd.py +++ b/hermes_cli/fallback_cmd.py @@ -21,6 +21,8 @@ from __future__ import annotations import copy from typing import Any, Dict, List, Optional +from hermes_cli.fallback_config import get_fallback_chain + # --------------------------------------------------------------------------- # Helpers @@ -30,20 +32,11 @@ def _read_chain(config: Dict[str, Any]) -> List[Dict[str, Any]]: """Return the normalized fallback chain as a list of dicts. Accepts both the new list format (``fallback_providers``) and the legacy - single-dict format (``fallback_model``). The returned list is always a - fresh copy โ€” callers can mutate without touching the config dict. + ``fallback_model`` format. When both are present, the effective chain is + merged with ``fallback_providers`` entries kept first. The returned list is + always a fresh copy โ€” callers can mutate without touching the config dict. """ - chain = config.get("fallback_providers") or [] - if isinstance(chain, list): - result = [dict(e) for e in chain if isinstance(e, dict) and e.get("provider") and e.get("model")] - if result: - return result - legacy = config.get("fallback_model") - if isinstance(legacy, dict) and legacy.get("provider") and legacy.get("model"): - return [dict(legacy)] - if isinstance(legacy, list): - return [dict(e) for e in legacy if isinstance(e, dict) and e.get("provider") and e.get("model")] - return [] + return get_fallback_chain(config) def _write_chain(config: Dict[str, Any], chain: List[Dict[str, Any]]) -> None: diff --git a/hermes_cli/fallback_config.py b/hermes_cli/fallback_config.py new file mode 100644 index 00000000000..d7cfc952d2d --- /dev/null +++ b/hermes_cli/fallback_config.py @@ -0,0 +1,72 @@ +"""Helpers for reading the effective fallback provider chain from config.""" + +from __future__ import annotations + +from typing import Any + + +def _normalized_base_url(value: Any) -> str: + if not isinstance(value, str): + return "" + return value.strip().rstrip("/") + + +def _iter_fallback_entries(raw: Any) -> list[dict[str, Any]]: + if isinstance(raw, dict): + candidates = [raw] + elif isinstance(raw, list): + candidates = raw + else: + return [] + + entries: list[dict[str, Any]] = [] + for entry in candidates: + if not isinstance(entry, dict): + continue + provider = str(entry.get("provider") or "").strip() + model = str(entry.get("model") or "").strip() + if not provider or not model: + continue + + normalized = dict(entry) + normalized["provider"] = provider + normalized["model"] = model + + base_url = _normalized_base_url(entry.get("base_url")) + if base_url: + normalized["base_url"] = base_url + + entries.append(normalized) + return entries + + +def _entry_identity(entry: dict[str, Any]) -> tuple[str, str, str]: + return ( + str(entry.get("provider") or "").strip().lower(), + str(entry.get("model") or "").strip().lower(), + _normalized_base_url(entry.get("base_url")).lower(), + ) + + +def get_fallback_chain(config: dict[str, Any] | None) -> list[dict[str, Any]]: + """Return the effective fallback chain merged across old and new config keys. + + ``fallback_providers`` remains the primary source of truth and keeps its + order. Legacy ``fallback_model`` entries are appended afterwards unless + they target the same provider/model/base_url route as an earlier entry. + The returned list always contains fresh dict copies. + """ + + config = config or {} + chain: list[dict[str, Any]] = [] + seen: set[tuple[str, str, str]] = set() + + for key in ("fallback_providers", "fallback_model"): + for entry in _iter_fallback_entries(config.get(key)): + identity = _entry_identity(entry) + if identity in seen: + continue + seen.add(identity) + chain.append(entry) + + return chain diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 05b34c581f5..a3b08751257 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -3349,34 +3349,9 @@ _PLATFORMS = [ "help": "For DMs, this is your user ID. You can set it later by typing /set-home in chat."}, ], }, - { - "key": "discord", - "label": "Discord", - "emoji": "๐Ÿ’ฌ", - "token_var": "DISCORD_BOT_TOKEN", - "setup_instructions": [ - "1. Go to https://discord.com/developers/applications โ†’ New Application", - "2. Go to Bot โ†’ Reset Token โ†’ copy the bot token", - "3. Enable: Bot โ†’ Privileged Gateway Intents โ†’ Message Content Intent", - "4. Invite the bot to your server:", - " OAuth2 โ†’ URL Generator โ†’ check BOTH scopes:", - " - bot", - " - applications.commands (required for slash commands!)", - " Bot Permissions: Send Messages, Read Message History, Attach Files", - " Copy the URL and open it in your browser to invite.", - "5. Get your user ID: enable Developer Mode in Discord settings,", - " then right-click your name โ†’ Copy ID", - ], - "vars": [ - {"name": "DISCORD_BOT_TOKEN", "prompt": "Bot token", "password": True, - "help": "Paste the token from step 2 above."}, - {"name": "DISCORD_ALLOWED_USERS", "prompt": "Allowed user IDs or usernames (comma-separated)", "password": False, - "is_allowlist": True, - "help": "Paste your user ID from step 5 above."}, - {"name": "DISCORD_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False, - "help": "Right-click a channel โ†’ Copy Channel ID (requires Developer Mode)."}, - ], - }, + # Discord moved to plugins/platforms/discord/ โ€” its setup metadata is + # discovered dynamically via _all_platforms() from the platform registry + # entry registered by plugins/platforms/discord/adapter.py::register(). { "key": "slack", "label": "Slack", @@ -3784,7 +3759,12 @@ def _platform_status(platform: dict) -> str: configured = bool(entry.is_connected(synthetic)) except Exception: configured = False - if not configured: + else: + # No is_connected hook โ€” fall back to check_fn as a coarse + # "are deps present" gate. Don't fall back when is_connected + # is defined and returned False; that would let "SDK is + # installed" override "no token configured" and incorrectly + # report the platform as ready. try: configured = bool(entry.check_fn()) except Exception: @@ -4040,15 +4020,11 @@ def _setup_dingtalk(): client_id, client_secret = result save_env_value("DINGTALK_CLIENT_ID", client_id) save_env_value("DINGTALK_CLIENT_SECRET", client_secret) - save_env_value("DINGTALK_ALLOW_ALL_USERS", "true") print() print_success(f"{emoji} {label} configured via QR scan!") else: # โ”€โ”€ Manual entry โ”€โ”€ _setup_standard_platform(dingtalk_platform) - # Also enable allow-all by default for convenience - if get_env_value("DINGTALK_CLIENT_ID"): - save_env_value("DINGTALK_ALLOW_ALL_USERS", "true") def _setup_wecom(): @@ -4769,7 +4745,9 @@ def _builtin_setup_fn(key: str): from hermes_cli import setup as _s return { "telegram": _s._setup_telegram, - "discord": _s._setup_discord, + # discord moved into the plugin: setup_fn is registered by + # plugins/platforms/discord/adapter.py::register() and dispatched + # via the plugin path in _configure_platform(). "slack": _s._setup_slack, "matrix": _s._setup_matrix, "mattermost": _s._setup_mattermost, diff --git a/hermes_cli/gateway_windows.py b/hermes_cli/gateway_windows.py index 77ea60d9b39..e019bb3e638 100644 --- a/hermes_cli/gateway_windows.py +++ b/hermes_cli/gateway_windows.py @@ -365,7 +365,9 @@ def _write_task_script() -> Path: content = _build_gateway_cmd_script(python_path, working_dir, hermes_home, profile_arg) script_path = get_task_script_path() - script_path.write_text(content, encoding="utf-8", newline="") + tmp = script_path.with_suffix(".tmp") + tmp.write_text(content, encoding="utf-8", newline="") + tmp.replace(script_path) return script_path @@ -436,7 +438,9 @@ def _install_startup_entry(script_path: Path) -> Path: """Write the Startup-folder fallback launcher. Returns its path.""" entry = get_startup_entry_path() entry.parent.mkdir(parents=True, exist_ok=True) - entry.write_text(_build_startup_launcher(script_path), encoding="utf-8", newline="") + tmp = entry.with_suffix(".tmp") + tmp.write_text(_build_startup_launcher(script_path), encoding="utf-8", newline="") + tmp.replace(entry) return entry diff --git a/hermes_cli/kanban.py b/hermes_cli/kanban.py index 4e975bb3e8d..1e7169c26cf 100644 --- a/hermes_cli/kanban.py +++ b/hermes_cli/kanban.py @@ -550,6 +550,39 @@ def build_parser(parent_subparsers: argparse._SubParsersAction) -> argparse.Argu p_unblock = sub.add_parser("unblock", help="Return one or more blocked/scheduled tasks to ready") p_unblock.add_argument("task_ids", nargs="+") + p_promote = sub.add_parser( + "promote", + help="Manually move one or more todo/blocked tasks to ready (recovery path)", + ) + p_promote.add_argument("task_id") + p_promote.add_argument( + "reason", + nargs="*", + help="Audit-trail reason (recorded on the task_events row)", + ) + p_promote.add_argument( + "--ids", + nargs="+", + default=None, + help="Additional task ids to promote with the same reason (bulk mode)", + ) + p_promote.add_argument( + "--force", + action="store_true", + help="Promote even if parent dependencies are not yet done/archived", + ) + p_promote.add_argument( + "--dry-run", + action="store_true", + help="Validate the promotion without mutating state", + ) + p_promote.add_argument( + "--json", + dest="json", + action="store_true", + help="Emit machine-readable JSON result", + ) + p_archive = sub.add_parser("archive", help="Archive one or more tasks") p_archive.add_argument("task_ids", nargs="*", help="Task ids to archive (default mode)") @@ -899,6 +932,7 @@ def kanban_command(args: argparse.Namespace) -> int: "block": _cmd_block, "schedule": _cmd_schedule, "unblock": _cmd_unblock, + "promote": _cmd_promote, "archive": _cmd_archive, "tail": _cmd_tail, "dispatch": _cmd_dispatch, @@ -1955,6 +1989,57 @@ def _cmd_unblock(args: argparse.Namespace) -> int: return 0 if not failed else 1 +def _cmd_promote(args: argparse.Namespace) -> int: + reason = " ".join(args.reason).strip() if args.reason else None + author = _profile_author() + as_json = getattr(args, "json", False) + extra_ids = list(getattr(args, "ids", None) or []) + # Dedupe while preserving order; positional task_id always first. + ids: list[str] = [] + seen: set[str] = set() + for tid in [args.task_id, *extra_ids]: + if tid not in seen: + ids.append(tid) + seen.add(tid) + + results: list[dict[str, object]] = [] + with kb.connect() as conn: + for tid in ids: + ok, err = kb.promote_task( + conn, + tid, + actor=author, + reason=reason, + force=bool(args.force), + dry_run=bool(args.dry_run), + ) + results.append({ + "task_id": tid, + "promoted": ok, + "dry_run": bool(args.dry_run), + "forced": bool(args.force), + "reason": reason, + "error": err, + }) + + failed = [r for r in results if not r["promoted"]] + if as_json: + # Single-id stays a flat object for back-compat; bulk emits a list. + payload: object = results[0] if len(results) == 1 else results + print(json.dumps(payload, indent=2, ensure_ascii=False)) + return 0 if not failed else 1 + + tag = " (dry)" if args.dry_run else "" + label = "Would promote" if args.dry_run else "Promoted" + for r in results: + if r["promoted"]: + suffix = f": {reason}" if reason else "" + print(f"{label} {r['task_id']} -> ready{tag}{suffix}") + else: + print(f"cannot promote {r['task_id']}: {r['error']}", file=sys.stderr) + return 0 if not failed else 1 + + def _cmd_archive(args: argparse.Namespace) -> int: ids = list(args.task_ids or []) purge_ids = list(getattr(args, "purge_ids", None) or []) diff --git a/hermes_cli/kanban_db.py b/hermes_cli/kanban_db.py index 7a30b70987f..c89e697c98d 100644 --- a/hermes_cli/kanban_db.py +++ b/hermes_cli/kanban_db.py @@ -75,6 +75,7 @@ import json import os import re import secrets +import shutil import sqlite3 import subprocess import sys @@ -82,6 +83,7 @@ import threading import logging import time from dataclasses import dataclass, field +from datetime import datetime from pathlib import Path from typing import Any, Iterable, Optional @@ -1005,6 +1007,131 @@ def _validate_sqlite_header(path: Path) -> None: ) +class KanbanDbCorruptError(RuntimeError): + """Raised when an existing kanban DB file fails integrity checks. + + Fail-closed guard against silent recreation of a corrupt board file, + which would otherwise destroy the user's tasks. Carries both the + original path and the timestamped backup we made before refusing. + """ + + def __init__(self, db_path: Path, backup_path: Optional[Path], reason: str): + self.db_path = db_path + self.backup_path = backup_path + self.reason = reason + backup_str = str(backup_path) if backup_path is not None else "" + super().__init__( + f"Refusing to open corrupt kanban DB at {db_path}: {reason}. " + f"Original preserved; backup at {backup_str}." + ) + + +def _backup_corrupt_db(path: Path) -> Optional[Path]: + """Copy a corrupt DB (and its WAL/SHM sidecars) to a timestamped backup. + + Returns the backup path of the main DB file, or ``None`` if the copy + itself failed (the caller still raises loudly in that case). + + Writes are confined to the original DB's parent directory. The + backup basename is derived purely from ``path.name``, never from + caller-supplied directory segments โ€” no traversal is possible. + """ + # Resolve once and pin the parent so subsequent path operations cannot + # escape it. ``Path.resolve()`` collapses any ``..`` segments and + # symlinks, and we only ever write inside ``parent``. + resolved = path.resolve() + parent = resolved.parent + base_name = resolved.name # basename only + stamp = datetime.now().strftime("%Y%m%d_%H%M%S") + candidate = parent / f"{base_name}.corrupt.{stamp}.bak" + # Defensive: candidate must still be inside parent after construction. + # f-string interpolation of ``base_name`` cannot escape ``parent`` + # because ``base_name`` is itself a resolved basename, but assert it + # anyway so static analyzers can see the containment guarantee. + if candidate.parent != parent: + return None + counter = 0 + while candidate.exists(): + counter += 1 + candidate = parent / f"{base_name}.corrupt.{stamp}.{counter}.bak" + if candidate.parent != parent: + return None + try: + shutil.copy2(resolved, candidate) + except OSError: + return None + for suffix in ("-wal", "-shm"): + sidecar = parent / (base_name + suffix) + if sidecar.parent != parent or not sidecar.exists(): + continue + try: + sidecar_backup = parent / (candidate.name + suffix) + if sidecar_backup.parent != parent: + continue + shutil.copy2(sidecar, sidecar_backup) + except OSError: + pass + return candidate + + +def _guard_existing_db_is_healthy(path: Path) -> None: + """Run ``PRAGMA integrity_check`` on an existing non-empty DB file. + + Opens the probe in read/write mode so SQLite can recover or + checkpoint a healthy WAL/hot-journal DB before we declare it + corrupt. If the file is malformed, copy it (and any WAL/SHM + sidecars) to a timestamped backup and raise + :class:`KanbanDbCorruptError` so callers cannot silently recreate + the schema on top of a damaged DB. + + Transient lock/busy errors (``sqlite3.OperationalError``) are NOT + treated as corruption; they propagate raw so the caller sees a + normal lock failure and no spurious ``.corrupt`` backup is made. + + No-op for missing files, zero-byte files (treated as fresh), and + paths already proven healthy this process (cache hit). + + Path-trust note: ``path`` arrives via :func:`connect`, which itself + resolves it from an explicit ``db_path`` argument, the + :func:`kanban_db_path` env-var chain, or the kanban-home default โ€” + all sources Hermes treats as user-controlled-but-trusted on the + user's own machine. We additionally resolve the path here and + confine all filesystem writes to its parent directory so any + accidental ``..`` segments are collapsed before any I/O happens. + """ + # Resolve before any I/O. ``Path.resolve()`` normalizes ``..`` and + # symlinks, giving us a canonical path whose parent dir we can pin. + try: + resolved = path.resolve() + except OSError: + return + try: + if not resolved.exists() or resolved.stat().st_size == 0: + return + except OSError: + return + if str(resolved) in _INITIALIZED_PATHS: + return + reason: Optional[str] = None + try: + probe = sqlite3.connect(str(resolved), timeout=5, isolation_level=None) + try: + row = probe.execute("PRAGMA integrity_check").fetchone() + finally: + probe.close() + if not row or (row[0] or "").lower() != "ok": + reason = f"integrity_check returned {row[0] if row else ''!r}" + except sqlite3.OperationalError: + # Lock contention, busy, transient IO โ€” not corruption. Let it propagate. + raise + except sqlite3.DatabaseError as exc: + reason = f"sqlite refused to open file: {exc}" + if reason is None: + return + backup = _backup_corrupt_db(resolved) + raise KanbanDbCorruptError(resolved, backup, reason) + + def connect( db_path: Optional[Path] = None, *, @@ -1033,7 +1160,13 @@ def connect( else: path = kanban_db_path(board=board) path.parent.mkdir(parents=True, exist_ok=True) + # Cheap byte-level check first โ€” catches the #29507 TLS-overwrite shape + # and other invalid-header cases without opening a sqlite connection. _validate_sqlite_header(path) + # Full integrity probe โ€” catches corruption past the header (malformed + # pages, broken internal metadata). Cached per-path after first success + # via _INITIALIZED_PATHS so it only runs once per process per path. + _guard_existing_db_is_healthy(path) resolved = str(path.resolve()) conn = sqlite3.connect(str(path), isolation_level=None, timeout=30) try: @@ -1518,8 +1651,15 @@ def create_task( now = int(time.time()) # Resolve workspace_path from board-level default_workdir when the - # caller did not specify one explicitly. - if workspace_path is None: + # caller did not specify one explicitly. Board defaults represent + # persistent project checkouts, so only persistent workspace kinds may + # inherit them. Scratch workspaces are auto-deleted on completion and + # must stay under the per-board scratch root created by + # ``resolve_workspace``; inheriting ``default_workdir`` for a scratch + # task would point cleanup at the user's source tree (#28818). The + # containment guard in ``_cleanup_workspace`` is the safety rail, but + # we also stop the bad state from being created in the first place. + if workspace_path is None and workspace_kind in {"dir", "worktree"}: board_slug = board if board else get_current_board() board_meta = read_board_metadata(board_slug) board_default = board_meta.get("default_workdir") @@ -2904,6 +3044,81 @@ def complete_task( # Workspace / tmux cleanup # --------------------------------------------------------------------------- +def _is_managed_scratch_path(p: Path) -> bool: + """Return True iff *p* is a strict descendant of a kanban-managed scratch root. + + A managed root is exclusively a ``workspaces/`` directory โ€” never the + broader kanban home, a board root, or sibling subtrees like ``logs/`` or + ``boards//`` itself. Allowed roots: + + * ``HERMES_KANBAN_WORKSPACES_ROOT`` when set (worker-side override + injected by the dispatcher). + * ``/kanban/workspaces`` โ€” legacy default-board scratch root. + * ``/kanban/boards//workspaces`` for each board slug + that currently exists on disk. + + The check requires strict descendancy: a path equal to one of these + roots is NOT managed (deleting the workspaces root would wipe every + task's scratch dir at once), and a path that resolves to `` + /kanban`` itself, ``/kanban/logs``, or + ``/kanban/boards/`` is rejected because those + subtrees hold Hermes' own DB, metadata, and logs, not task workspaces. + + Used by :func:`_cleanup_workspace` to refuse to ``shutil.rmtree`` paths + outside Hermes-managed storage. A board ``default_workdir`` pointing at a + real source tree can otherwise pair with ``workspace_kind='scratch'`` and + cause task completion to delete user data (#28818). + """ + try: + p_abs = p.resolve(strict=False) + except OSError: + return False + roots: list[Path] = [] + override = os.environ.get("HERMES_KANBAN_WORKSPACES_ROOT", "").strip() + if override: + try: + roots.append(Path(override).expanduser().resolve(strict=False)) + except OSError: + pass + try: + home = kanban_home() + except OSError: + home = None + if home is not None: + try: + roots.append((home / "kanban" / "workspaces").resolve(strict=False)) + except OSError: + pass + try: + boards_parent = (home / "kanban" / "boards").resolve(strict=False) + except OSError: + boards_parent = None + if boards_parent is not None: + try: + entries = list(boards_parent.iterdir()) + except OSError: + entries = [] + for entry in entries: + try: + if not entry.is_dir(): + continue + except OSError: + continue + try: + roots.append((entry / "workspaces").resolve(strict=False)) + except OSError: + continue + for root in roots: + if p_abs == root: + continue + try: + if p_abs.is_relative_to(root): + return True + except ValueError: + continue + return False + + def _cleanup_workspace(conn: sqlite3.Connection, task_id: str) -> None: """Remove a task's scratch workspace dir and kill its stale tmux session. @@ -2926,8 +3141,21 @@ def _cleanup_workspace(conn: sqlite3.Connection, task_id: str) -> None: import shutil wp = Path(path) if wp.is_dir(): - shutil.rmtree(wp, ignore_errors=True) - _log.debug("Removed scratch workspace: %s", wp) + # Containment guard (#28818): a board's ``default_workdir`` can + # pair ``workspace_kind='scratch'`` with a user-supplied path + # pointing at a real source tree. Without this check, task + # completion would unconditionally ``shutil.rmtree`` that path + # and silently delete the user's source data. + if _is_managed_scratch_path(wp): + shutil.rmtree(wp, ignore_errors=True) + _log.debug("Removed scratch workspace: %s", wp) + else: + _log.warning( + "Refusing to remove out-of-scratch workspace for task %s: %s " + "(workspace_kind='scratch' but path is outside any " + "kanban-managed workspaces root)", + task_id, wp, + ) # Also kill the tmux session for the worker that owned this task, # if the tmux session is now dead (worker process exited). _cleanup_worker_tmux(conn, task_id) @@ -2961,6 +3189,93 @@ def _cleanup_worker_tmux(conn: sqlite3.Connection, task_id: str) -> None: pass # best-effort โ€” never block completion +# --------------------------------------------------------------------------- +# First-use tip for scratch workspaces +# --------------------------------------------------------------------------- +# +# Scratch workspaces are intentionally ephemeral โ€” ``_cleanup_workspace`` +# removes them as soon as ``complete_task`` runs. New users often don't +# realize that and lose worker output (community report, May 2026). The +# behavior is right; the lack of warning is the bug. +# +# On the FIRST scratch workspace materialization across the whole install +# we: +# 1. Log a warning line on the dispatcher logger. +# 2. Append a ``tip_scratch_workspace`` event on the task so it's visible +# via ``hermes kanban show `` and the dashboard. +# 3. Touch a sentinel file under ``kanban_home() / '.scratch_tip_shown'`` +# so we don't repeat the tip โ€” once you know, you know. +# +# Scope is per-install, not per-board: a user creating a second board +# already learned the lesson on board #1. + +_SCRATCH_TIP_SENTINEL_NAME = ".scratch_tip_shown" + +_SCRATCH_TIP_MESSAGE = ( + "scratch workspaces are ephemeral โ€” they're deleted when the task " + "completes. Use --workspace worktree: (git worktree) or " + "--workspace dir:/abs/path (existing dir) to preserve worker output." +) + + +def _scratch_tip_sentinel_path() -> Path: + """Path to the per-install scratch-workspace-tip sentinel file.""" + return kanban_home() / _SCRATCH_TIP_SENTINEL_NAME + + +def _scratch_tip_shown() -> bool: + """True iff the scratch-workspace tip has already been emitted on this + install. Best-effort โ€” any error means we re-emit, which is the safer + failure mode for a help message.""" + try: + return _scratch_tip_sentinel_path().exists() + except OSError: + return False + + +def _mark_scratch_tip_shown() -> None: + """Touch the sentinel so future scratch workspaces stay silent. + + Best-effort: a failure here just means the tip might appear once more, + which is preferable to crashing dispatch over a help message. + """ + try: + path = _scratch_tip_sentinel_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.touch(exist_ok=True) + except OSError: + pass + + +def _maybe_emit_scratch_tip( + conn: sqlite3.Connection, + task_id: str, + workspace_kind: Optional[str], +) -> None: + """Emit the first-use scratch-workspace tip exactly once per install. + + Called from the dispatcher right after a scratch workspace is + materialized. No-op for ``worktree`` / ``dir`` workspaces (they're + preserved by design) and no-op after the sentinel exists. + """ + if (workspace_kind or "scratch") != "scratch": + return + if _scratch_tip_shown(): + return + try: + _log.warning("kanban: %s (task %s)", _SCRATCH_TIP_MESSAGE, task_id) + with write_txn(conn): + _append_event( + conn, task_id, "tip_scratch_workspace", + {"message": _SCRATCH_TIP_MESSAGE}, + ) + except Exception: + # Best-effort โ€” never block the spawn loop over a help message. + pass + finally: + _mark_scratch_tip_shown() + + def edit_completed_task_result( conn: sqlite3.Connection, task_id: str, @@ -3083,6 +3398,77 @@ def block_task( return True + +def promote_task( + conn: sqlite3.Connection, + task_id: str, + *, + actor: str, + reason: Optional[str] = None, + force: bool = False, + dry_run: bool = False, +) -> tuple[bool, Optional[str]]: + """Manually promote a `todo` or `blocked` task to `ready`. + + Mirrors the automatic promotion done by ``recompute_ready`` but + drives it from a deliberate operator action with an audit-trail + entry. Refuses to promote if any parent dep is not in a terminal + state (`done`/`archived`) unless ``force=True``. Does NOT change + assignee or claim state. Returns ``(True, None)`` on success and + ``(False, reason)`` if refused. ``dry_run=True`` validates the + promotion would succeed without mutating state. + """ + row = conn.execute( + "SELECT status FROM tasks WHERE id = ?", (task_id,) + ).fetchone() + if row is None: + return False, f"task {task_id} not found" + + cur_status = row["status"] + if cur_status not in ("todo", "blocked"): + return False, ( + f"task {task_id} is {cur_status!r}; promote only applies to " + f"'todo' or 'blocked'" + ) + + if not force: + parents = conn.execute( + "SELECT t.id, t.status FROM tasks t " + "JOIN task_links l ON l.parent_id = t.id " + "WHERE l.child_id = ?", + (task_id,), + ).fetchall() + unsatisfied = [ + p["id"] for p in parents + if p["status"] not in ("done", "archived") + ] + if unsatisfied: + return False, ( + f"unsatisfied parent dependencies: " + f"{', '.join(unsatisfied)} (use --force to override)" + ) + + if dry_run: + return True, None + + with write_txn(conn): + upd = conn.execute( + "UPDATE tasks SET status = 'ready' " + "WHERE id = ? AND status IN ('todo', 'blocked')", + (task_id,), + ) + if upd.rowcount != 1: + return False, f"task {task_id} status changed during promotion" + _append_event( + conn, + task_id, + "promoted_manual", + {"actor": actor, "reason": reason, "forced": force}, + ) + + return True, None + + def unblock_task(conn: sqlite3.Connection, task_id: str) -> bool: """Transition ``blocked``/``scheduled`` -> ready or todo. @@ -4892,6 +5278,7 @@ def dispatch_once( continue # Persist the resolved workspace path so the worker can cd there. set_workspace_path(conn, claimed.id, str(workspace)) + _maybe_emit_scratch_tip(conn, claimed.id, claimed.workspace_kind) _spawn = spawn_fn if spawn_fn is not None else _default_spawn try: # Back-compat: older spawn_fn signatures accept only @@ -4970,6 +5357,7 @@ def dispatch_once( continue # Persist the resolved workspace path so the worker can cd there. set_workspace_path(conn, claimed.id, str(workspace)) + _maybe_emit_scratch_tip(conn, claimed.id, claimed.workspace_kind) # Force-load sdlc-review skill for review agents. The # _default_spawn function already auto-loads kanban-worker, and # appends task.skills via --skills. Setting task.skills here diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 72f8a91c342..dea8b5cc9d6 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -61,12 +61,76 @@ try: except ModuleNotFoundError: pass +import os +import sys + + +def _is_termux_startup_environment_fast() -> bool: + """Tiny Termux check for pre-import startup shortcuts.""" + prefix = os.environ.get("PREFIX", "") + return bool( + os.environ.get("TERMUX_VERSION") + or "com.termux/files/usr" in prefix + or prefix.startswith("/data/data/com.termux/") + ) + + +def _is_termux_fast_version_argv(argv: list[str]) -> bool: + return argv in (["--version"], ["-V"], ["version"]) + + +def _read_openai_version_fast() -> str | None: + """Read OpenAI SDK version without importing ``importlib.metadata``.""" + for base in sys.path: + if not base: + base = os.getcwd() + version_file = os.path.join(base, "openai", "_version.py") + try: + with open(version_file, encoding="utf-8") as handle: + for line in handle: + stripped = line.strip() + if not stripped.startswith("__version__"): + continue + _key, _sep, value = stripped.partition("=") + value = value.split("#", 1)[0].strip().strip("\"'") + return value or None + except OSError: + continue + return None + + +def _print_fast_version_info() -> None: + from hermes_cli import __release_date__, __version__ + + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + print(f"Hermes Agent v{__version__} ({__release_date__})") + print(f"Project: {project_root}") + print(f"Python: {sys.version.split()[0]}") + + openai_version = _read_openai_version_fast() + print(f"OpenAI SDK: {openai_version}" if openai_version else "OpenAI SDK: Not installed") + + +def _try_termux_ultrafast_version() -> bool: + """Handle ``hermes --version`` before config/logging imports on Termux.""" + if os.environ.get("HERMES_TERMUX_DISABLE_FAST_CLI") == "1": + return False + if not _is_termux_startup_environment_fast(): + return False + if not _is_termux_fast_version_argv(sys.argv[1:]): + return False + + _print_fast_version_info() + return True + + +if _try_termux_ultrafast_version(): + raise SystemExit(0) + import argparse import json -import os import shutil import subprocess -import sys from pathlib import Path from typing import Optional @@ -591,7 +655,7 @@ def _session_browse_picker(sessions: list) -> Optional[str]: curses.init_pair(1, curses.COLOR_GREEN, -1) # selected curses.init_pair(2, curses.COLOR_YELLOW, -1) # header curses.init_pair(3, curses.COLOR_CYAN, -1) # search - curses.init_pair(4, 8, -1) # dim + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) # dim cursor = 0 scroll_offset = 0 @@ -1390,7 +1454,7 @@ def _launch_tui( provider: Optional[str] = None, toolsets: object = None, skills: object = None, - verbose: bool = False, + verbose: Optional[bool] = None, quiet: bool = False, query: Optional[str] = None, image: Optional[str] = None, @@ -1699,7 +1763,7 @@ def cmd_chat(args): provider=getattr(args, "provider", None), toolsets=getattr(args, "toolsets", None), skills=getattr(args, "skills", None), - verbose=getattr(args, "verbose", False), + verbose=getattr(args, "verbose", None), quiet=getattr(args, "quiet", False), query=getattr(args, "query", None), image=getattr(args, "image", None), @@ -1719,7 +1783,7 @@ def cmd_chat(args): "provider": getattr(args, "provider", None), "toolsets": args.toolsets, "skills": getattr(args, "skills", None), - "verbose": args.verbose, + "verbose": getattr(args, "verbose", None), "quiet": getattr(args, "quiet", False), "query": args.query, "image": getattr(args, "image", None), @@ -1730,6 +1794,7 @@ def cmd_chat(args): "max_turns": getattr(args, "max_turns", None), "ignore_rules": getattr(args, "ignore_rules", False), "ignore_user_config": getattr(args, "ignore_user_config", False), + "compact": getattr(args, "compact", False), } # Filter out None values kwargs = {k: v for k, v in kwargs.items() if v is not None} @@ -2433,10 +2498,34 @@ _AUX_TASKS: list[tuple[str, str, str]] = [ ("mcp", "MCP", "MCP tool reasoning"), ("title_generation", "Title generation", "session titles"), ("skills_hub", "Skills hub", "skills search/install"), + ("triage_specifier", "Triage specifier", "kanban spec fleshing"), + ("kanban_decomposer", "Kanban decomposer", "task decomposition"), + ("profile_describer", "Profile describer", "auto profile descriptions"), ("curator", "Curator", "skill-usage review pass"), ] +def _all_aux_tasks() -> list[tuple[str, str, str]]: + """Return built-in + plugin-registered auxiliary tasks for picker/menu use. + + Built-in tasks come first (preserving order), followed by plugin tasks + sorted by key. Used by ``_aux_config_menu``, ``_reset_aux_to_auto``, and + display-name lookups so plugin-registered tasks (registered via + :meth:`hermes_cli.plugins.PluginContext.register_auxiliary_task`) appear + in the same surfaces as built-in ones without core knowing about them. + """ + tasks = list(_AUX_TASKS) + try: + from hermes_cli.plugins import get_plugin_auxiliary_tasks + for entry in get_plugin_auxiliary_tasks(): + tasks.append((entry["key"], entry["display_name"], entry["description"])) + except Exception: + # Plugin discovery failure must not break the aux config UI. + # Built-in tasks remain available. + pass + return tasks + + def _format_aux_current(task_cfg: dict) -> str: """Render the current aux config for display in the task menu.""" if not isinstance(task_cfg, dict): @@ -2487,7 +2576,11 @@ def _save_aux_choice( def _reset_aux_to_auto() -> int: - """Reset every known aux task back to auto/empty. Returns number reset.""" + """Reset every known aux task back to auto/empty. Returns number reset. + + Includes plugin-registered tasks (via ``_all_aux_tasks``) so a plugin + that contributed an auxiliary task gets reset alongside built-ins. + """ from hermes_cli.config import load_config, save_config cfg = load_config() @@ -2496,7 +2589,7 @@ def _reset_aux_to_auto() -> int: aux = {} cfg["auxiliary"] = aux count = 0 - for task, _name, _desc in _AUX_TASKS: + for task, _name, _desc in _all_aux_tasks(): entry = aux.setdefault(task, {}) if not isinstance(entry, dict): entry = {} @@ -2539,10 +2632,11 @@ def _aux_config_menu() -> None: print() # Build the task menu with current settings inline - name_col = max(len(name) for _, name, _ in _AUX_TASKS) + 2 - desc_col = max(len(desc) for _, _, desc in _AUX_TASKS) + 4 + all_tasks = _all_aux_tasks() + name_col = max(len(name) for _, name, _ in all_tasks) + 2 + desc_col = max(len(desc) for _, _, desc in all_tasks) + 4 entries: list[tuple[str, str]] = [] - for task_key, name, desc in _AUX_TASKS: + for task_key, name, desc in all_tasks: task_cfg = ( aux.get(task_key, {}) if isinstance(aux.get(task_key), dict) else {} ) @@ -2593,7 +2687,7 @@ def _aux_select_for_task(task: str) -> None: current_model = str(task_cfg.get("model") or "").strip() current_base_url = str(task_cfg.get("base_url") or "").strip() - display_name = next((name for key, name, _ in _AUX_TASKS if key == task), task) + display_name = next((name for key, name, _ in _all_aux_tasks() if key == task), task) # Gather authenticated providers (has credentials + curated model list) try: @@ -2664,7 +2758,7 @@ def _aux_flow_provider_model( from hermes_cli.auth import _prompt_model_selection from hermes_cli.models import get_pricing_for_provider - display_name = next((name for key, name, _ in _AUX_TASKS if key == task), task) + display_name = next((name for key, name, _ in _all_aux_tasks() if key == task), task) # Fetch live pricing for this provider (non-blocking) pricing: dict = {} @@ -2710,7 +2804,7 @@ def _aux_flow_custom_endpoint(task: str, task_cfg: dict) -> None: """Prompt for a direct OpenAI-compatible base_url + optional api_key/model.""" import getpass - display_name = next((name for key, name, _ in _AUX_TASKS if key == task), task) + display_name = next((name for key, name, _ in _all_aux_tasks() if key == task), task) current_base_url = str(task_cfg.get("base_url") or "").strip() current_model = str(task_cfg.get("model") or "").strip() @@ -4662,7 +4756,9 @@ def _model_flow_copilot(config, current_model=""): source = creds.get("source", "") else: if source in {"GITHUB_TOKEN", "GH_TOKEN"}: - print(f" GitHub token: {api_key[:8]}... โœ“ ({source})") + from hermes_cli.env_loader import format_secret_source_suffix + bw_suffix = format_secret_source_suffix(source) + print(f" GitHub token: {api_key[:8]}... โœ“ ({source}{bw_suffix})") elif source == "gh auth token": print(" GitHub token: โœ“ (from `gh auth token`)") else: @@ -4919,7 +5015,10 @@ def _prompt_api_key(pconfig, existing_key: str, provider_id: str = "") -> tuple: return new_key, False # Already configured โ€” offer K / R / C โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - print(f" {pconfig.name} API key: {existing_key[:8]}... โœ“") + from hermes_cli.env_loader import format_secret_source_suffix + + source_suffix = format_secret_source_suffix(key_env) if key_env else "" + print(f" {pconfig.name} API key: {existing_key[:8]}... โœ“{source_suffix}") if not key_env: # Nothing we can rewrite; just acknowledge and move on. print() @@ -5202,7 +5301,9 @@ def _model_flow_bedrock_api_key(config, region, current_model=""): # Prompt for API key existing_key = get_env_value("AWS_BEARER_TOKEN_BEDROCK") or "" if existing_key: - print(f" Bedrock API Key: {existing_key[:12]}... โœ“") + from hermes_cli.env_loader import format_secret_source_suffix + source_suffix = format_secret_source_suffix("AWS_BEARER_TOKEN_BEDROCK") + print(f" Bedrock API Key: {existing_key[:12]}... โœ“{source_suffix}") else: print(f" Endpoint: {mantle_base_url}") print() @@ -5873,7 +5974,22 @@ def _model_flow_anthropic(config, current_model=""): if has_creds: # Show what we found if existing_key: - print(f" Anthropic credentials: {existing_key[:12]}... โœ“") + from hermes_cli.env_loader import format_secret_source_suffix + from hermes_cli.auth import PROVIDER_REGISTRY + + # Surface which env var supplied the key so users with + # Bitwarden see "(from Bitwarden)" โ€” without this, a detected + # BSM key looks identical to a key in .env and users assume + # nothing is wired up. + source_suffix = "" + for var in PROVIDER_REGISTRY["anthropic"].api_key_env_vars: + if os.getenv(var, "").strip() == existing_key: + source_suffix = format_secret_source_suffix(var) + if source_suffix: + break + print( + f" Anthropic credentials: {existing_key[:12]}... โœ“{source_suffix}" + ) elif cc_available: print(" Claude Code credentials: โœ“ (auto-detected)") print() @@ -6007,6 +6123,13 @@ def cmd_webhook(args): webhook_command(args) +def cmd_portal(args): + """Nous Portal status and Tool Gateway routing surface.""" + from hermes_cli.portal_cli import portal_command + + return portal_command(args) + + def cmd_slack(args): """Slack integration helpers. @@ -6059,6 +6182,19 @@ def cmd_doctor(args): run_doctor(args) +def cmd_security(args): + """Dispatch `hermes security `.""" + sub = getattr(args, "security_command", None) + if sub in ("audit", None): + from hermes_cli.security_audit import cmd_security_audit + + # Default subcommand is `audit` when no subcmd is given. + code = cmd_security_audit(args) + sys.exit(int(code or 0)) + print(f"unknown security subcommand: {sub}", file=sys.stderr) + sys.exit(2) + + def cmd_dump(args): """Dump setup summary for support/debugging.""" from hermes_cli.dump import run_dump @@ -6835,8 +6971,8 @@ def _update_via_zip(args): ) print("โ†’ Downloading latest version...") + tmp_dir = tempfile.mkdtemp(prefix="hermes-update-") try: - tmp_dir = tempfile.mkdtemp(prefix="hermes-update-") zip_path = os.path.join(tmp_dir, f"hermes-agent-{branch}.zip") urlretrieve(zip_url, zip_path) @@ -6883,12 +7019,11 @@ def _update_via_zip(args): print(f"โœ“ Updated {update_count} items from ZIP") - # Cleanup - shutil.rmtree(tmp_dir, ignore_errors=True) - except Exception as e: print(f"โœ— ZIP update failed: {e}") sys.exit(1) + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) # Clear stale bytecode after ZIP extraction removed = _clear_bytecode_cache(PROJECT_ROOT) @@ -9720,6 +9855,7 @@ def _coalesce_session_name_args(argv: list) -> list: "honcho", "claw", "plugins", + "security", "acp", "webhook", "memory", @@ -10557,10 +10693,10 @@ _BUILTIN_SUBCOMMANDS = frozenset( "config", "cron", "curator", "dashboard", "debug", "doctor", "dump", "fallback", "gateway", "hooks", "import", "insights", "kanban", "login", "logout", "logs", "lsp", "mcp", "memory", "migrate", - "model", "pairing", "plugins", "postinstall", "profile", "proxy", + "model", "pairing", "plugins", "portal", "postinstall", "profile", "proxy", "send", "sessions", "setup", "skills", "slack", "status", "tools", "uninstall", "update", - "version", "webhook", "whatsapp", "chat", "secrets", + "version", "webhook", "whatsapp", "chat", "secrets", "security", # Help-ish invocations โ€” plugin commands not being listed in # top-level --help is an acceptable trade-off for skipping an # expensive eager import of every bundled plugin module. @@ -10717,10 +10853,6 @@ def _set_chat_arg_defaults(args) -> None: setattr(args, attr, default) -def _is_termux_fast_version_argv(argv: list[str]) -> bool: - return argv in (["--version"], ["-V"], ["version"]) - - def _try_termux_fast_cli_launch() -> bool: """Run obvious Termux non-TUI chat/oneshot/version paths on a light parser.""" if not _is_termux_startup_environment(): @@ -10774,7 +10906,17 @@ def _try_termux_fast_cli_launch() -> bool: if args.command in {None, "chat"}: _set_chat_arg_defaults(args) - _prepare_agent_startup(args) + interactive_prompt = not getattr(args, "query", None) and not getattr(args, "image", None) + if interactive_prompt: + # Bare Termux CLI should reach the prompt first and do agent-only + # discovery on the first submitted turn instead of before input. + setattr(args, "compact", True) + os.environ["HERMES_DEFER_AGENT_STARTUP"] = "1" + os.environ["HERMES_FAST_STARTUP_BANNER"] = "1" + if getattr(args, "accept_hooks", False): + os.environ["HERMES_ACCEPT_HOOKS"] = "1" + else: + _prepare_agent_startup(args) cmd_chat(args) return True @@ -11288,6 +11430,13 @@ def main(): help="On existing installs: only prompt for items that are missing " "or unset, instead of running the full reconfigure wizard.", ) + setup_parser.add_argument( + "--portal", + action="store_true", + help="One-shot Nous Portal setup: log in via OAuth, set Nous as the " + "inference provider, and opt into the Tool Gateway. Skips the " + "rest of the wizard.", + ) setup_parser.set_defaults(func=cmd_setup) # ========================================================================= @@ -11763,6 +11912,12 @@ def main(): webhook_parser.set_defaults(func=cmd_webhook) + # ========================================================================= + # portal command โ€” Nous Portal status + Tool Gateway routing + # ========================================================================= + from hermes_cli.portal_cli import add_parser as _add_portal_parser + _add_portal_parser(subparsers) + # ========================================================================= # kanban command โ€” multi-profile collaboration board # ========================================================================= @@ -11861,6 +12016,58 @@ def main(): ) doctor_parser.set_defaults(func=cmd_doctor) + # ========================================================================= + # security command โ€” on-demand supply-chain audit + # ========================================================================= + security_parser = subparsers.add_parser( + "security", + help="Supply-chain audit (OSV.dev) for venv, plugins, and MCP servers", + description=( + "On-demand vulnerability scan against OSV.dev. Covers the Hermes " + "venv (installed PyPI dists), Python deps declared by plugins under " + "~/.hermes/plugins/, and pinned npx/uvx MCP servers in config.yaml. " + "Does NOT scan globally-installed packages or editor/browser extensions." + ), + ) + security_subparsers = security_parser.add_subparsers( + dest="security_command", + metavar="", + ) + + audit_parser = security_subparsers.add_parser( + "audit", + help="Run a one-shot supply-chain audit", + description="Query OSV.dev for known vulnerabilities in installed components.", + ) + audit_parser.add_argument( + "--json", + action="store_true", + help="Emit machine-readable JSON instead of human-readable text", + ) + audit_parser.add_argument( + "--fail-on", + default="critical", + choices=["low", "moderate", "high", "critical"], + help="Exit non-zero when any finding meets this severity (default: critical)", + ) + audit_parser.add_argument( + "--skip-venv", + action="store_true", + help="Skip scanning the Hermes Python venv", + ) + audit_parser.add_argument( + "--skip-plugins", + action="store_true", + help="Skip scanning plugin requirements files", + ) + audit_parser.add_argument( + "--skip-mcp", + action="store_true", + help="Skip scanning pinned MCP servers in config.yaml", + ) + audit_parser.set_defaults(func=cmd_security) + security_parser.set_defaults(func=cmd_security) + # ========================================================================= # dump command # ========================================================================= @@ -12186,6 +12393,11 @@ Examples: skills_audit.add_argument( "name", nargs="?", help="Specific skill to audit (default: all)" ) + skills_audit.add_argument( + "--deep", + action="store_true", + help="Run AST-level analysis on Python files (opt-in diagnostic)", + ) skills_uninstall = skills_subparsers.add_parser( "uninstall", help="Remove a hub-installed skill" @@ -13665,7 +13877,7 @@ Examples: ("model", None), ("provider", None), ("toolsets", None), - ("verbose", False), + ("verbose", None), ("worktree", False), ]: if not hasattr(args, attr): @@ -13680,7 +13892,7 @@ Examples: ("model", None), ("provider", None), ("toolsets", None), - ("verbose", False), + ("verbose", None), ("resume", None), ("continue_last", None), ("worktree", False), diff --git a/hermes_cli/oneshot.py b/hermes_cli/oneshot.py index ebc684f2857..b79644f6706 100644 --- a/hermes_cli/oneshot.py +++ b/hermes_cli/oneshot.py @@ -17,7 +17,6 @@ Model / provider selection mirrors `hermes chat`: Env var fallbacks (used when the corresponding arg is not passed): - HERMES_INFERENCE_MODEL - - HERMES_INFERENCE_PROVIDER (already read by resolve_runtime_provider) """ from __future__ import annotations @@ -28,6 +27,8 @@ import sys from contextlib import redirect_stderr, redirect_stdout from typing import Optional +from hermes_cli.fallback_config import get_fallback_chain + def _normalize_toolsets(toolsets: object = None) -> list[str] | None: if not toolsets: @@ -133,9 +134,8 @@ def run_oneshot( prompt: The user message to send. model: Optional model override. Falls back to HERMES_INFERENCE_MODEL env var, then config.yaml's model.default / model.model. - provider: Optional provider override. Falls back to - HERMES_INFERENCE_PROVIDER env var, then config.yaml's model.provider, - then "auto". + provider: Optional provider override. Falls back to config.yaml's + model.provider, then "auto". toolsets: Optional comma-separated string or iterable of toolsets. Returns the exit code. Caller should sys.exit() with the return. @@ -301,14 +301,9 @@ def _run_agent( toolsets_list = sorted(_get_platform_tools(cfg, "cli")) session_db = _create_session_db_for_oneshot() - # Read fallback chain from profile config โ€” supports both the new list - # format (fallback_providers) and the legacy single-dict (fallback_model). - # Mirrors the same normalization in cli.py so oneshot workers (e.g. kanban - # workers spawned via `hermes -p chat -q ...`) honour the - # profile's fallback chain just like interactive sessions do. - _fb = cfg.get("fallback_providers") or cfg.get("fallback_model") or [] - if isinstance(_fb, dict): - _fb = [_fb] if _fb.get("provider") and _fb.get("model") else [] + # Read the effective fallback chain from profile config so oneshot workers + # honour the same merge semantics as interactive CLI and gateway sessions. + _fb = get_fallback_chain(cfg) agent = AIAgent( api_key=runtime.get("api_key"), diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 6150bf016d1..5b5bf2209de 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -698,6 +698,119 @@ class PluginContext: # -- hook registration -------------------------------------------------- + # -- auxiliary task registration --------------------------------------- + + def register_auxiliary_task( + self, + key: str, + *, + display_name: str, + description: str, + defaults: Optional[Dict[str, Any]] = None, + ) -> None: + """Register a plugin-defined auxiliary LLM task. + + Auxiliary tasks are LLM-backed side jobs (vision analysis, web extraction, + compression, smart-approval, etc.) that route through ``auxiliary_client.py``. + Each task has its own ``auxiliary.`` config block where users can + pin a provider/model independent of the main chat model. + + Plugins use this to declare their own auxiliary tasks without touching + core files. After registration, the task: + + - Appears in the ``hermes model โ†’ Configure auxiliary models`` picker + - Has its provider/model/base_url/api_key bridged from config.yaml to + ``AUXILIARY__*`` env vars at gateway startup + - Gets default routing fields (provider="auto", model="", etc.) merged + into loaded configs so ``cfg.get("auxiliary", {}).get(key)`` works + + Args: + key: stable task key (snake_case). Used in config ``auxiliary.`` + and env vars ``AUXILIARY__*``. Must not shadow a + built-in task key (vision, compression, web_extract, approval, + mcp, title_generation, skills_hub, curator). + display_name: human-readable name shown in the picker. + description: short one-line description shown next to the name. + defaults: optional dict of default routing fields. Recognized keys: + ``provider`` (default "auto"), ``model`` (default ""), + ``base_url`` (default ""), ``api_key`` (default ""), + ``timeout`` (default 60), ``extra_body`` (default {}), + plus any task-specific extras (e.g. ``download_timeout``). + Unknown keys are preserved verbatim โ€” the plugin owns the + schema for its own task. + + Raises: + ValueError: if *key* is empty, contains invalid characters, or + shadows a built-in auxiliary task key. + + Example: + ctx.register_auxiliary_task( + key="memory_retain_filter", + display_name="Memory retain filter", + description="hindsight pre-retain dedup/extract", + defaults={"provider": "auto", "timeout": 30}, + ) + """ + # Validate key shape + if not key or not isinstance(key, str): + raise ValueError( + f"Plugin '{self.manifest.name}' tried to register auxiliary task " + f"with invalid key {key!r}" + ) + if not all(c.isalnum() or c == "_" for c in key): + raise ValueError( + f"Plugin '{self.manifest.name}' auxiliary task key {key!r} " + f"must contain only alphanumeric characters and underscores" + ) + + # Lazy import to avoid circular: hermes_cli.main imports plugins indirectly + from hermes_cli.main import _AUX_TASKS as _BUILTIN_AUX_TASKS + + builtin_keys = {k for k, _name, _desc in _BUILTIN_AUX_TASKS} + if key in builtin_keys: + raise ValueError( + f"Plugin '{self.manifest.name}' cannot register auxiliary task " + f"{key!r} โ€” that key is reserved for a built-in task. " + f"Pick a plugin-namespaced key (e.g. '{self.manifest.name}_{key}')." + ) + + # Reject duplicate registrations across plugins + existing = self._manager._aux_tasks.get(key) + if existing is not None and existing.get("plugin") != self.manifest.name: + raise ValueError( + f"Plugin '{self.manifest.name}' cannot register auxiliary task " + f"{key!r} โ€” already registered by plugin " + f"'{existing.get('plugin')}'" + ) + + # Normalize defaults โ€” plugin owns the schema, but we ensure routing + # fields exist with sensible types so consumers don't crash. + merged_defaults: Dict[str, Any] = { + "provider": "auto", + "model": "", + "base_url": "", + "api_key": "", + "timeout": 60, + "extra_body": {}, + } + if defaults: + for k, v in defaults.items(): + merged_defaults[k] = v + + self._manager._aux_tasks[key] = { + "key": key, + "display_name": display_name, + "description": description, + "defaults": merged_defaults, + "plugin": self.manifest.name, + } + logger.debug( + "Plugin %s registered auxiliary task: %s (%s)", + self.manifest.name, + key, + display_name, + ) + def register_hook(self, hook_name: str, callback: Callable) -> None: """Register a lifecycle hook callback. @@ -782,6 +895,9 @@ class PluginManager: self._cli_ref = None # Set by CLI after plugin discovery # Plugin skill registry: qualified name โ†’ metadata dict. self._plugin_skills: Dict[str, Dict[str, Any]] = {} + # Plugin-registered auxiliary tasks: key โ†’ {key, display_name, + # description, defaults, plugin}. See PluginContext.register_auxiliary_task. + self._aux_tasks: Dict[str, Dict[str, Any]] = {} # ----------------------------------------------------------------------- # Public @@ -803,6 +919,7 @@ class PluginManager: self._cli_commands.clear() self._plugin_commands.clear() self._plugin_skills.clear() + self._aux_tasks.clear() self._context_engine = None self._discovered = True @@ -1548,6 +1665,21 @@ def get_plugin_commands() -> Dict[str, dict]: return _ensure_plugins_discovered()._plugin_commands +def get_plugin_auxiliary_tasks() -> List[Dict[str, Any]]: + """Return all plugin-registered auxiliary tasks as a stable-ordered list. + + Each entry is the registration dict from + :meth:`PluginContext.register_auxiliary_task`: + ``{key, display_name, description, defaults, plugin}``. + + Triggers idempotent plugin discovery so callers can read the registry + before any explicit ``discover_plugins()`` call. Sorted by ``key`` for + deterministic ordering in pickers and tests. + """ + manager = _ensure_plugins_discovered() + return [manager._aux_tasks[k] for k in sorted(manager._aux_tasks)] + + def get_plugin_toolsets() -> List[tuple]: """Return plugin toolsets as ``(key, label, description)`` tuples. diff --git a/hermes_cli/plugins_cmd.py b/hermes_cli/plugins_cmd.py index 8c002456787..1388e56ce23 100644 --- a/hermes_cli/plugins_cmd.py +++ b/hermes_cli/plugins_cmd.py @@ -76,22 +76,42 @@ def _plugins_dir() -> Path: return plugins -def _sanitize_plugin_name(name: str, plugins_dir: Path) -> Path: +def _sanitize_plugin_name( + name: str, + plugins_dir: Path, + *, + allow_subdir: bool = False, +) -> Path: """Validate a plugin name and return the safe target path inside *plugins_dir*. Raises ``ValueError`` if the name contains path-traversal sequences or would resolve outside the plugins directory. + + ``allow_subdir=True`` permits a single forward slash inside *name* so + category-namespaced plugin keys like ``observability/langfuse`` or + ``image_gen/openai`` (the registry keys emitted by ``_discover_all_plugins``) + can be looked up. ``..`` and backslash are still rejected, leading and + trailing slashes are stripped, and the resolved target must still live + inside *plugins_dir*. Install paths leave this at the default ``False`` + because a freshly-cloned plugin always lands top-level under + ``~/.hermes/plugins//``. """ if not name: raise ValueError("Plugin name must not be empty.") + if allow_subdir: + name = name.strip("/") + if not name: + raise ValueError("Plugin name must not be empty.") + if name in {".", ".."}: raise ValueError( f"Invalid plugin name '{name}': must not reference the plugins directory itself." ) # Reject obvious traversal characters - for bad in ("/", "\\", ".."): + bad_chars = ("\\", "..") if allow_subdir else ("/", "\\", "..") + for bad in bad_chars: if bad in name: raise ValueError(f"Invalid plugin name '{name}': must not contain '{bad}'.") @@ -326,7 +346,7 @@ def _display_removed(name: str, plugins_dir: Path) -> None: def _require_installed_plugin(name: str, plugins_dir: Path, console) -> Path: """Return the plugin path if it exists, or exit with an error listing installed plugins.""" - target = _sanitize_plugin_name(name, plugins_dir) + target = _sanitize_plugin_name(name, plugins_dir, allow_subdir=True) if not target.exists(): installed = ", ".join(d.name for d in plugins_dir.iterdir() if d.is_dir()) or "(none)" console.print( @@ -1051,7 +1071,7 @@ def _run_composite_ui(curses, plugin_names, plugin_labels, plugin_selected, curses.init_pair(1, curses.COLOR_GREEN, -1) curses.init_pair(2, curses.COLOR_YELLOW, -1) curses.init_pair(3, curses.COLOR_CYAN, -1) - curses.init_pair(4, 8, -1) # dim gray + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) # dim gray cursor = 0 scroll_offset = 0 @@ -1196,7 +1216,7 @@ def _run_composite_ui(curses, plugin_names, plugin_labels, plugin_selected, curses.init_pair(1, curses.COLOR_GREEN, -1) curses.init_pair(2, curses.COLOR_YELLOW, -1) curses.init_pair(3, curses.COLOR_CYAN, -1) - curses.init_pair(4, 8, -1) + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) curses.curs_set(0) elif key in {curses.KEY_ENTER, 10, 13}: if cursor < n_plugins: @@ -1228,7 +1248,7 @@ def _run_composite_ui(curses, plugin_names, plugin_labels, plugin_selected, curses.init_pair(1, curses.COLOR_GREEN, -1) curses.init_pair(2, curses.COLOR_YELLOW, -1) curses.init_pair(3, curses.COLOR_CYAN, -1) - curses.init_pair(4, 8, -1) + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) curses.curs_set(0) elif key in {27, ord("q")}: # Save plugin changes on exit @@ -1508,7 +1528,7 @@ def _user_installed_plugin_dir(name: str) -> Optional[Path]: """Resolved path under ``~/.hermes/plugins/`` if it exists.""" plugins_dir = _plugins_dir() try: - target = _sanitize_plugin_name(name, plugins_dir) + target = _sanitize_plugin_name(name, plugins_dir, allow_subdir=True) except ValueError: return None return target if target.is_dir() else None diff --git a/hermes_cli/portal_cli.py b/hermes_cli/portal_cli.py new file mode 100644 index 00000000000..aa658e41d21 --- /dev/null +++ b/hermes_cli/portal_cli.py @@ -0,0 +1,219 @@ +"""``hermes portal`` โ€” small CLI surface for Nous Portal users. + +Subcommands: + status Show Portal auth state + which Tool Gateway tools are routed. + open Open the Portal subscription page in the user's default browser. + tools List Tool Gateway tools and which are active in the current config. + +This command is intentionally minimal โ€” it does not duplicate functionality +already in ``hermes auth`` or ``hermes tools``. It's a discovery + status +surface for the Portal subscription itself. +""" +from __future__ import annotations + +import sys +import webbrowser +from typing import Optional + +from hermes_cli.colors import Colors, color +from hermes_cli.config import load_config + +DEFAULT_PORTAL_URL = "https://portal.nousresearch.com" +SUBSCRIPTION_URL = "https://portal.nousresearch.com/manage-subscription" +DOCS_URL = "https://hermes-agent.nousresearch.com/docs/user-guide/features/tool-gateway" + + +def _nous_portal_base_url() -> str: + """Resolve the Portal base URL from auth state or default.""" + try: + from hermes_cli.auth import get_nous_auth_status + status = get_nous_auth_status() or {} + url = status.get("portal_base_url") + if isinstance(url, str) and url.strip(): + return url.rstrip("/") + except Exception: + pass + return DEFAULT_PORTAL_URL + + +def _cmd_status(args) -> int: + """Show Portal auth + Tool Gateway routing summary.""" + from hermes_cli.auth import get_nous_auth_status + from hermes_cli.nous_subscription import get_nous_subscription_features + + config = load_config() or {} + + try: + auth = get_nous_auth_status() or {} + except Exception: + auth = {} + + logged_in = bool(auth.get("logged_in")) + + print() + print(color(" Nous Portal", Colors.MAGENTA)) + print(color(" โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€", Colors.MAGENTA)) + if logged_in: + portal = auth.get("portal_base_url") or DEFAULT_PORTAL_URL + print(f" Auth: {color('โœ“ logged in', Colors.GREEN)}") + print(f" Portal: {portal}") + inference = auth.get("inference_base_url") + if inference: + print(f" API: {inference}") + else: + print(f" Auth: {color('not logged in', Colors.YELLOW)}") + print(f" Sign up: {SUBSCRIPTION_URL}") + print(f" Login: hermes auth add nous --type oauth") + + # Provider selection (independent of auth) + model_cfg = config.get("model") if isinstance(config.get("model"), dict) else {} + provider = str(model_cfg.get("provider") or "").strip().lower() + if provider == "nous": + print(f" Model: {color('โœ“ using Nous as inference provider', Colors.GREEN)}") + elif provider: + print(f" Model: currently {provider} (switch with `hermes model`)") + + # Tool Gateway routing + print() + print(color(" Tool Gateway", Colors.MAGENTA)) + print(color(" โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€", Colors.MAGENTA)) + try: + features = get_nous_subscription_features(config) + except Exception: + features = None + + if features is None: + print(" (could not resolve subscription state)") + return 0 + + rows = [] + for feat in features.items(): + if feat.managed_by_nous: + state = color("via Nous Portal", Colors.GREEN) + elif feat.active and feat.current_provider: + state = feat.current_provider + elif feat.active: + state = "active" + else: + state = color("not configured", Colors.DIM) + rows.append((feat.label, state)) + + width = max((len(r[0]) for r in rows), default=0) + for label, state in rows: + print(f" {label:<{width}} {state}") + + if not logged_in: + print() + print(color(f" Docs: {DOCS_URL}", Colors.DIM)) + return 0 + + +def _cmd_open(args) -> int: + """Open the Portal subscription page in the default browser.""" + target = SUBSCRIPTION_URL + print(f"Opening {target}") + try: + opened = webbrowser.open(target) + except Exception: + opened = False + if not opened: + print() + print("Could not launch a browser. Visit the URL above manually.") + return 1 + return 0 + + +def _cmd_tools(args) -> int: + """List the Tool Gateway catalog + current routing.""" + from hermes_cli.nous_subscription import get_nous_subscription_features + + config = load_config() or {} + try: + features = get_nous_subscription_features(config) + except Exception: + print("Could not resolve Tool Gateway state.", file=sys.stderr) + return 1 + + # Static catalog โ€” the partners Tool Gateway routes to today. + catalog = [ + ("web", "Web search & extract", "Firecrawl"), + ("image_gen", "Image generation", "FAL"), + ("tts", "Text-to-speech", "OpenAI TTS"), + ("browser", "Browser automation", "Browser Use"), + ("modal", "Cloud terminal", "Modal"), + ] + + print() + print(color(" Tool Gateway catalog", Colors.MAGENTA)) + print(color(" โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€", Colors.MAGENTA)) + + if not features.nous_auth_present: + print(color(" Not logged into Nous Portal โ€” sign in with `hermes auth add nous --type oauth`.", Colors.YELLOW)) + print() + + label_width = max(len(label) for _, label, _ in catalog) + for key, label, partner in catalog: + feat = features.features.get(key) + if feat is None: + state = color("unknown", Colors.DIM) + elif feat.managed_by_nous: + state = color("โœ“ via Nous Portal", Colors.GREEN) + elif feat.active and feat.current_provider: + state = feat.current_provider + elif feat.active: + state = "active" + else: + state = color("not configured", Colors.DIM) + print(f" {label:<{label_width}} partner: {partner:<14} {state}") + + print() + print(color(f" Manage your subscription: {SUBSCRIPTION_URL}", Colors.DIM)) + print(color(f" Docs: {DOCS_URL}", Colors.DIM)) + return 0 + + +def portal_command(args) -> int: + """Top-level dispatch for `hermes portal `.""" + sub = getattr(args, "portal_command", None) + if sub in {None, ""}: + # Default to status โ€” matches gh / kubectl conventions where the + # subcommand-less form gives a useful overview. + return _cmd_status(args) + if sub == "status": + return _cmd_status(args) + if sub == "open": + return _cmd_open(args) + if sub == "tools": + return _cmd_tools(args) + print(f"Unknown portal subcommand: {sub}", file=sys.stderr) + print("Run `hermes portal -h` for usage.", file=sys.stderr) + return 1 + + +def add_parser(subparsers) -> None: + """Register `hermes portal` on the given argparse subparsers object.""" + portal_parser = subparsers.add_parser( + "portal", + help="Nous Portal status, subscription, and Tool Gateway routing", + description=( + "Inspect Nous Portal auth, Tool Gateway routing, and open the " + "Portal subscription page. Subcommands: status (default), " + "open, tools." + ), + ) + portal_sub = portal_parser.add_subparsers(dest="portal_command") + + portal_sub.add_parser( + "status", + help="Show Portal auth + Tool Gateway routing summary (default)", + ) + portal_sub.add_parser( + "open", + help="Open the Portal subscription page in your default browser", + ) + portal_sub.add_parser( + "tools", + help="List Tool Gateway tools and which are routed via Nous", + ) + + portal_parser.set_defaults(func=portal_command) diff --git a/hermes_cli/proxy/adapters/nous_portal.py b/hermes_cli/proxy/adapters/nous_portal.py index 9fb07a9c053..e85d2100404 100644 --- a/hermes_cli/proxy/adapters/nous_portal.py +++ b/hermes_cli/proxy/adapters/nous_portal.py @@ -27,6 +27,7 @@ from hermes_cli.auth import ( _quarantine_nous_oauth_state, _quarantine_nous_pool_entries, _save_auth_store, + _validate_nous_inference_url_from_network, _write_shared_nous_state, resolve_nous_runtime_credentials, ) @@ -137,7 +138,10 @@ class NousPortalAdapter(UpstreamAdapter): "Try `hermes login nous` to re-authenticate." ) - base_url = refreshed.get("base_url") or DEFAULT_NOUS_INFERENCE_URL + base_url = ( + _validate_nous_inference_url_from_network(refreshed.get("base_url")) + or DEFAULT_NOUS_INFERENCE_URL + ) base_url = base_url.rstrip("/") return UpstreamCredential( diff --git a/hermes_cli/secrets_cli.py b/hermes_cli/secrets_cli.py index d771969017e..38a638576bd 100644 --- a/hermes_cli/secrets_cli.py +++ b/hermes_cli/secrets_cli.py @@ -57,6 +57,15 @@ def register_cli(parent_parser: argparse.ArgumentParser) -> None: "--access-token", help="Provide the access token non-interactively (will be stored in .env)", ) + setup.add_argument( + "--server-url", + help=( + "Bitwarden region / self-hosted endpoint. Examples: " + "https://vault.bitwarden.com (US, default), " + "https://vault.bitwarden.eu (EU), or your self-hosted URL. " + "Skips the interactive region prompt." + ), + ) setup.set_defaults(func=cmd_setup) status = sub.add_parser("status", help="Show config + binary + last fetch") @@ -145,14 +154,28 @@ def cmd_setup(args: argparse.Namespace) -> int: os.environ[token_env] = token # so the test fetch below sees it console.print(f" [green]โœ“[/green] stored in {get_env_path()} as {token_env}") + # ------------------------------------------------------------------ region + console.print() + console.print("[bold]Step 3[/bold] Pick a Bitwarden region") + server_url = _resolve_server_url(args, secrets_cfg, console) + if server_url is None: + return 1 + if server_url: + console.print(f" [green]โœ“[/green] using {server_url}") + else: + console.print( + " [green]โœ“[/green] using bws default " + "(US Cloud, https://vault.bitwarden.com)" + ) + # ------------------------------------------------------------------- project if args.project_id and args.project_id.strip(): project_id = args.project_id.strip() else: console.print() - console.print("[bold]Step 3[/bold] Pick a project") + console.print("[bold]Step 4[/bold] Pick a project") project_id = "" - projects = _list_projects(binary, token, console) + projects = _list_projects(binary, token, console, server_url=server_url) if projects is None: return 1 if not projects: @@ -187,7 +210,7 @@ def cmd_setup(args: argparse.Namespace) -> int: # ------------------------------------------------------------------- test console.print() - step_num = 4 if not (args.project_id and args.project_id.strip()) else 3 + step_num = 5 if not (args.project_id and args.project_id.strip()) else 4 console.print(f"[bold]Step {step_num}[/bold] Test fetch") try: secrets, warnings = bw.fetch_bitwarden_secrets( @@ -195,6 +218,7 @@ def cmd_setup(args: argparse.Namespace) -> int: project_id=project_id, binary=binary, use_cache=False, + server_url=server_url, ) except Exception as exc: # noqa: BLE001 console.print(f" [red]โœ— Fetch failed: {exc}[/red]") @@ -221,6 +245,7 @@ def cmd_setup(args: argparse.Namespace) -> int: # ------------------------------------------------------------------- save secrets_cfg["enabled"] = True secrets_cfg["project_id"] = project_id + secrets_cfg["server_url"] = server_url secrets_cfg.setdefault("access_token_env", token_env) secrets_cfg.setdefault("cache_ttl_seconds", 300) secrets_cfg.setdefault("override_existing", True) @@ -248,6 +273,7 @@ def cmd_status(args: argparse.Namespace) -> int: enabled = bool(bw_cfg.get("enabled")) token_env = bw_cfg.get("access_token_env", "BWS_ACCESS_TOKEN") project_id = bw_cfg.get("project_id", "") + server_url = str(bw_cfg.get("server_url", "") or "").strip() token_set = bool(os.environ.get(token_env)) table = Table(show_header=False, box=None, padding=(0, 2)) @@ -257,6 +283,10 @@ def cmd_status(args: argparse.Namespace) -> int: table.add_row("Token env var", token_env) table.add_row("Token in env", _yn(token_set)) table.add_row("Project ID", project_id or "[dim](unset)[/dim]") + table.add_row( + "Server URL", + server_url or "[dim]default (US Cloud, https://vault.bitwarden.com)[/dim]", + ) table.add_row("Override existing", _yn(bool(bw_cfg.get("override_existing", False)))) table.add_row("Cache TTL (s)", str(bw_cfg.get("cache_ttl_seconds", 300))) table.add_row("Auto-install", _yn(bool(bw_cfg.get("auto_install", True)))) @@ -306,11 +336,14 @@ def cmd_sync(args: argparse.Namespace) -> int: console.print("[red]No project_id configured.[/red]") return 1 + server_url = str(bw_cfg.get("server_url", "") or "").strip() + try: secrets, warnings = bw.fetch_bitwarden_secrets( access_token=token, project_id=project_id, use_cache=False, + server_url=server_url, ) except Exception as exc: # noqa: BLE001 console.print(f"[red]Fetch failed: {exc}[/red]") @@ -407,12 +440,14 @@ def _bws_version(binary: Path) -> str: def _list_projects( - binary: Path, token: str, console: Console + binary: Path, token: str, console: Console, *, server_url: str = "" ) -> Optional[List[dict]]: """Call ``bws project list`` and return the parsed list, or None on failure.""" env = os.environ.copy() env["BWS_ACCESS_TOKEN"] = token env.setdefault("NO_COLOR", "1") + if server_url: + env["BWS_SERVER_URL"] = server_url try: res = subprocess.run( [str(binary), "project", "list", "--output", "json"], @@ -428,7 +463,16 @@ def _list_projects( if res.returncode != 0: err = (res.stderr or res.stdout).strip()[:300] console.print(f" [red]bws project list failed: {err}[/red]") - if "authorization" in err.lower() or "invalid" in err.lower(): + lowered = err.lower() + if "invalid_client" in lowered or "400 bad request" in lowered: + console.print( + " [yellow]'invalid_client' from the US identity endpoint usually " + "means the token is for a different Bitwarden region. Re-run " + "[cyan]hermes secrets bitwarden setup[/cyan] and pick EU or " + "self-hosted at the region prompt, or set [cyan]secrets.bitwarden." + "server_url[/cyan] in config.yaml.[/yellow]" + ) + elif "authorization" in lowered or "invalid" in lowered: console.print( " [yellow]This usually means the access token is wrong or revoked. " "Double-check it in the Bitwarden web app.[/yellow]" @@ -443,3 +487,91 @@ def _list_projects( if not isinstance(data, list): return [] return [p for p in data if isinstance(p, dict) and p.get("id")] + + +# Canonical Bitwarden region endpoints. Keep in sync with what Bitwarden +# publishes โ€” these are stable but if a third region appears, add it here +# and to the prompt below. +_REGION_PRESETS = [ + ("US Cloud (https://vault.bitwarden.com โ€” bws default)", ""), + ("EU Cloud (https://vault.bitwarden.eu)", "https://vault.bitwarden.eu"), +] + + +def _resolve_server_url( + args: argparse.Namespace, + secrets_cfg: dict, + console: Console, +) -> Optional[str]: + """Pick a Bitwarden server URL for setup. + + Resolution order: + 1. ``--server-url`` CLI flag (non-interactive) + 2. ``BWS_SERVER_URL`` env var (so users running with that already set + in their shell don't have to re-enter it) + 3. Existing ``secrets.bitwarden.server_url`` value (for re-runs) + 4. Interactive menu: US / EU / self-hosted + + Returns the chosen URL as a string (empty string = bws default, + i.e. US Cloud). Returns None if the user aborted with an empty + custom URL. + """ + if args.server_url and args.server_url.strip(): + return args.server_url.strip() + + env_url = os.environ.get("BWS_SERVER_URL", "").strip() + if env_url: + console.print( + f" Detected [cyan]BWS_SERVER_URL[/cyan]={env_url} in your shell โ€” using it." + ) + return env_url + + existing = str(secrets_cfg.get("server_url", "") or "").strip() + if existing: + console.print( + f" Existing config: [cyan]{existing}[/cyan]. " + "Press Enter to keep, or pick a different option below." + ) + + table = Table(show_header=True, header_style="bold", box=None, padding=(0, 2)) + table.add_column("#", style="cyan", width=4) + table.add_column("Region / endpoint") + for i, (label, _url) in enumerate(_REGION_PRESETS, 1): + table.add_row(str(i), label) + table.add_row(str(len(_REGION_PRESETS) + 1), "Self-hosted / custom URL") + console.print(table) + + custom_idx = len(_REGION_PRESETS) + 1 + while True: + prompt = f" Select region [1-{custom_idx}]" + if existing: + prompt += " (Enter to keep current)" + prompt += ": " + choice = console.input(prompt).strip() + if not choice: + if existing: + return existing + console.print(" [red]Enter a number.[/red]") + continue + try: + idx = int(choice) + except ValueError: + console.print(" [red]Enter a number.[/red]") + continue + if 1 <= idx <= len(_REGION_PRESETS): + return _REGION_PRESETS[idx - 1][1] + if idx == custom_idx: + custom = console.input( + " Enter your Bitwarden server URL " + "(e.g. https://vault.example.com): " + ).strip() + if not custom: + console.print(" [red]Empty URL, aborting.[/red]") + return None + if not custom.startswith(("http://", "https://")): + console.print( + " [yellow]Warning: URL doesn't start with http:// or " + "https:// โ€” bws may reject it.[/yellow]" + ) + return custom + console.print(f" [red]Out of range โ€” pick 1-{custom_idx}.[/red]") diff --git a/hermes_cli/security_audit.py b/hermes_cli/security_audit.py new file mode 100644 index 00000000000..82d414e0b23 --- /dev/null +++ b/hermes_cli/security_audit.py @@ -0,0 +1,576 @@ +"""On-demand supply-chain audit for Hermes Agent installs. + +Scans three surfaces a Hermes user actually controls and we can map to +upstream advisories without auth or extra binaries: + +1. The Hermes venv (every PyPI dist via ``importlib.metadata``). +2. Python deps declared by user-installed plugins under ``~/.hermes/plugins`` + (``requirements.txt`` + ``pyproject.toml`` best-effort pin extraction). +3. MCP servers wired in ``config.yaml`` whose ``command/args`` look like + ``npx -y @`` or ``uvx ==``. + +Vulnerabilities are looked up against OSV.dev (``api.osv.dev/v1/querybatch`` ++ ``/v1/vulns/{id}``). Single-shot, on-demand, never daily โ€” see the design +notes in ``references/security-disclosure-triage.md``. + +Out of scope on purpose: global pip/npm, editor/browser extensions, +daily background scans, auto-blocking installs. +""" + +from __future__ import annotations + +import argparse +import concurrent.futures +import json +import re +import sys +import urllib.error +import urllib.request +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterable, Optional + +from hermes_constants import get_hermes_home + +OSV_BATCH_URL = "https://api.osv.dev/v1/querybatch" +OSV_VULN_URL = "https://api.osv.dev/v1/vulns/{vid}" +OSV_BATCH_MAX = 1000 # OSV documented hard cap per request +HTTP_TIMEOUT = 20 +DETAIL_PARALLELISM = 8 + +# Severity ordering for --fail-on gating. UNKNOWN sits below LOW so it +# never blocks unless --fail-on is passed something even lower (we don't +# expose that). +SEVERITY_ORDER = { + "UNKNOWN": 0, + "LOW": 1, + "MODERATE": 2, + "MEDIUM": 2, + "HIGH": 3, + "CRITICAL": 4, +} + + +# โ”€โ”€โ”€ Data shapes โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +@dataclass(frozen=True) +class Component: + """A single (name, version, ecosystem) tuple discovered on disk.""" + + name: str + version: str + ecosystem: str # "PyPI" | "npm" โ€” exactly as OSV expects + source: str # human-readable origin, e.g. "venv", "plugin:foo", "mcp:bar" + + +@dataclass +class Vulnerability: + osv_id: str + severity: str = "UNKNOWN" + summary: str = "" + fixed_versions: list[str] = field(default_factory=list) + + +@dataclass +class Finding: + component: Component + vuln: Vulnerability + + +# โ”€โ”€โ”€ Component discovery โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _discover_venv() -> list[Component]: + """Every dist installed in the running Python's import path.""" + from importlib.metadata import distributions + + out: list[Component] = [] + seen: set[tuple[str, str]] = set() + for dist in distributions(): + try: + name = (dist.metadata["Name"] or "").strip() + except Exception: + continue + version = (dist.version or "").strip() + if not name or not version: + continue + key = (name.lower(), version) + if key in seen: + continue + seen.add(key) + out.append(Component(name=name, version=version, ecosystem="PyPI", source="venv")) + return out + + +# requirements.txt line: drop comments, environment markers, options, extras +_REQ_LINE = re.compile( + r"""^\s* + (?P[A-Za-z0-9][A-Za-z0-9._-]*) + (?:\[[^\]]+\])? # extras + \s*==\s* + (?P[A-Za-z0-9._+!-]+) + \s*(?:;.*)?$ + """, + re.VERBOSE, +) + + +def _parse_requirements(text: str) -> list[tuple[str, str]]: + """Extract ``name==version`` pins. Everything else (>=, ~=, no pin) is skipped. + + A loose pin can't be mapped to a single OSV query, and getting it wrong + is worse than missing a finding for an audit tool โ€” false positives + train users to ignore output. + """ + pins: list[tuple[str, str]] = [] + for raw in text.splitlines(): + line = raw.strip() + if not line or line.startswith("#") or line.startswith("-"): + continue + m = _REQ_LINE.match(line) + if m: + pins.append((m.group("name"), m.group("version"))) + return pins + + +def _parse_pyproject_pins(text: str) -> list[tuple[str, str]]: + """Pull ``name==version`` pins from a ``pyproject.toml`` ``dependencies`` list. + + Uses stdlib ``tomllib`` (3.11+). Same exact-pin policy as requirements. + """ + try: + import tomllib + except ImportError: # pragma: no cover - 3.10 only + return [] + try: + data = tomllib.loads(text) + except Exception: + return [] + deps: list[str] = [] + project = data.get("project") or {} + if isinstance(project.get("dependencies"), list): + deps.extend(str(x) for x in project["dependencies"]) + optional = project.get("optional-dependencies") or {} + if isinstance(optional, dict): + for group in optional.values(): + if isinstance(group, list): + deps.extend(str(x) for x in group) + pins: list[tuple[str, str]] = [] + for dep in deps: + m = _REQ_LINE.match(dep) + if m: + pins.append((m.group("name"), m.group("version"))) + return pins + + +def _discover_plugins(hermes_home: Path) -> list[Component]: + """Python deps declared by plugins under ``~/.hermes/plugins``. + + Plugins typically don't install into the venv (they're directory-based + with relative imports), so their stated requirements are useful audit + surface even when the venv scan misses them. + """ + plugins_dir = hermes_home / "plugins" + if not plugins_dir.is_dir(): + return [] + + out: list[Component] = [] + for plugin_dir in sorted(plugins_dir.iterdir()): + if not plugin_dir.is_dir() or plugin_dir.name.startswith("."): + continue + source = f"plugin:{plugin_dir.name}" + for req_file in ("requirements.txt", "requirements-dev.txt"): + path = plugin_dir / req_file + if path.is_file(): + try: + pins = _parse_requirements(path.read_text(encoding="utf-8", errors="replace")) + except OSError: + continue + for name, version in pins: + out.append(Component(name=name, version=version, ecosystem="PyPI", source=source)) + pyproject = plugin_dir / "pyproject.toml" + if pyproject.is_file(): + try: + pins = _parse_pyproject_pins(pyproject.read_text(encoding="utf-8", errors="replace")) + except OSError: + continue + for name, version in pins: + out.append(Component(name=name, version=version, ecosystem="PyPI", source=source)) + return out + + +# npx forms we recognise: +# npx -y @scope/pkg@1.2.3 +# npx --yes pkg@1.2.3 +# npx pkg@1.2.3 [...args] +# We deliberately don't try to resolve unversioned names โ€” that maps to +# "latest" at runtime and isn't a stable audit subject. +_NPX_PKG = re.compile(r"^(@[A-Za-z0-9._-]+/[A-Za-z0-9._-]+|[A-Za-z0-9._-]+)@([A-Za-z0-9._+-]+)$") +# uvx forms: +# uvx pkg==1.2.3 +# uvx --with pkg==1.2.3 entrypoint +_UVX_PKG = re.compile(r"^([A-Za-z0-9][A-Za-z0-9._-]*)==([A-Za-z0-9._+!-]+)$") + + +def _extract_mcp_component(server_name: str, command: str, args: list[str]) -> Optional[Component]: + """Best-effort: parse `command/args` into a (name, version, ecosystem). + + Returns None when the entry doesn't pin a version we can audit (local + paths, Docker images, unversioned npx, etc.). Audit output stays silent + rather than guess. + """ + cmd = (command or "").strip().lower() + if not args: + return None + # npx (any prefix path) + if cmd.endswith("npx") or cmd == "npx": + # Skip flag tokens until we see the first thing that looks like a pkg ref + for token in args: + if token.startswith("-"): + continue + m = _NPX_PKG.match(token) + if m: + return Component( + name=m.group(1), + version=m.group(2), + ecosystem="npm", + source=f"mcp:{server_name}", + ) + return None # First non-flag token isn't a pinned ref + # uvx (any prefix path) + if cmd.endswith("uvx") or cmd == "uvx": + for token in args: + if token.startswith("-"): + continue + m = _UVX_PKG.match(token) + if m: + return Component( + name=m.group(1), + version=m.group(2), + ecosystem="PyPI", + source=f"mcp:{server_name}", + ) + return None + return None + + +def _discover_mcp() -> list[Component]: + """Pinned MCP server packages from ``config.yaml``.""" + try: + from hermes_cli.mcp_config import _get_mcp_servers + except Exception: + return [] + + out: list[Component] = [] + servers = _get_mcp_servers() + if not isinstance(servers, dict): + return [] + for name, cfg in servers.items(): + if not isinstance(cfg, dict): + continue + command = cfg.get("command", "") or "" + args = cfg.get("args") or [] + if not isinstance(args, list): + continue + comp = _extract_mcp_component(name, command, [str(a) for a in args]) + if comp is not None: + out.append(comp) + return out + + +# โ”€โ”€โ”€ OSV client โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _http_post_json(url: str, payload: dict) -> dict: + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, data=data, headers={"Content-Type": "application/json"}, method="POST" + ) + with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def _http_get_json(url: str) -> dict: + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def _osv_query_batch(components: list[Component]) -> dict[Component, list[str]]: + """Return {component -> [osv_id, ...]} for components with any vulns. + + Components without findings are omitted from the result dict. + """ + if not components: + return {} + findings: dict[Component, list[str]] = {} + for chunk_start in range(0, len(components), OSV_BATCH_MAX): + chunk = components[chunk_start:chunk_start + OSV_BATCH_MAX] + payload = { + "queries": [ + { + "package": {"name": c.name, "ecosystem": c.ecosystem}, + "version": c.version, + } + for c in chunk + ] + } + try: + resp = _http_post_json(OSV_BATCH_URL, payload) + except (urllib.error.URLError, TimeoutError, ConnectionError) as exc: + raise RuntimeError(f"OSV batch query failed: {exc}") from exc + results = resp.get("results") or [] + for comp, result in zip(chunk, results): + vulns = (result or {}).get("vulns") or [] + ids = [v.get("id") for v in vulns if v.get("id")] + if ids: + findings[comp] = ids + return findings + + +def _osv_severity_from_record(record: dict) -> str: + """Extract CVSS-derived severity tier from an OSV vuln record.""" + # OSV puts CVSS in `severity` (top-level or per-affected) and a + # human-readable bucket in `database_specific.severity` for GHSAs. + db_specific = record.get("database_specific") or {} + raw = db_specific.get("severity") + if isinstance(raw, str) and raw.strip(): + upper = raw.strip().upper() + if upper in SEVERITY_ORDER: + return upper + # Fall back to CVSS score โ†’ tier + score: Optional[float] = None + for sev_entry in record.get("severity") or []: + s = sev_entry.get("score") + if isinstance(s, str): + # CVSS vector strings look like "CVSS:3.1/AV:N/..." โ€” we can't + # parse without a lib. Look for an explicit numeric in + # affected[].ecosystem_specific later if present. + continue + affected = record.get("affected") or [] + for entry in affected: + eco_spec = entry.get("ecosystem_specific") or {} + sev = eco_spec.get("severity") + if isinstance(sev, str) and sev.strip().upper() in SEVERITY_ORDER: + return sev.strip().upper() + if score is not None: + if score >= 9.0: + return "CRITICAL" + if score >= 7.0: + return "HIGH" + if score >= 4.0: + return "MODERATE" + if score > 0: + return "LOW" + return "UNKNOWN" + + +def _osv_fixed_versions(record: dict) -> list[str]: + fixes: list[str] = [] + for entry in record.get("affected") or []: + for rng in entry.get("ranges") or []: + for event in rng.get("events") or []: + if "fixed" in event: + fixes.append(str(event["fixed"])) + # Dedupe, preserve order + seen: set[str] = set() + out: list[str] = [] + for f in fixes: + if f not in seen: + seen.add(f) + out.append(f) + return out + + +def _osv_fetch_details(vuln_ids: Iterable[str]) -> dict[str, Vulnerability]: + """Fetch summary/severity for each unique vuln id, in parallel.""" + unique = sorted({vid for vid in vuln_ids if vid}) + if not unique: + return {} + out: dict[str, Vulnerability] = {} + + def _fetch_one(vid: str) -> Vulnerability: + try: + rec = _http_get_json(OSV_VULN_URL.format(vid=vid)) + except (urllib.error.URLError, TimeoutError, ConnectionError): + return Vulnerability(osv_id=vid) + return Vulnerability( + osv_id=vid, + severity=_osv_severity_from_record(rec), + summary=(rec.get("summary") or "").strip(), + fixed_versions=_osv_fixed_versions(rec), + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=DETAIL_PARALLELISM) as pool: + for vuln in pool.map(_fetch_one, unique): + out[vuln.osv_id] = vuln + return out + + +# โ”€โ”€โ”€ Orchestration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def run_audit( + *, + skip_venv: bool = False, + skip_plugins: bool = False, + skip_mcp: bool = False, + hermes_home: Optional[Path] = None, +) -> list[Finding]: + """Discover components, query OSV, return findings sorted by severity desc.""" + home = hermes_home or Path(get_hermes_home()) + components: list[Component] = [] + if not skip_venv: + components.extend(_discover_venv()) + if not skip_plugins: + components.extend(_discover_plugins(home)) + if not skip_mcp: + components.extend(_discover_mcp()) + + if not components: + return [] + + raw = _osv_query_batch(components) + if not raw: + return [] + + all_ids: list[str] = [] + for ids in raw.values(): + all_ids.extend(ids) + details = _osv_fetch_details(all_ids) + + findings: list[Finding] = [] + for comp, ids in raw.items(): + for vid in ids: + vuln = details.get(vid) or Vulnerability(osv_id=vid) + findings.append(Finding(component=comp, vuln=vuln)) + + findings.sort( + key=lambda f: ( + -SEVERITY_ORDER.get(f.vuln.severity, 0), + f.component.source, + f.component.name.lower(), + f.vuln.osv_id, + ) + ) + return findings + + +# โ”€โ”€โ”€ Rendering โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _render_human(findings: list[Finding], total_components: int) -> str: + if not findings: + return f"No known vulnerabilities found across {total_components} component(s)." + + lines: list[str] = [] + lines.append( + f"Found {len(findings)} known vulnerability finding(s) " + f"across {total_components} component(s):" + ) + lines.append("") + last_source = None + for f in findings: + if f.component.source != last_source: + lines.append(f"[{f.component.source}]") + last_source = f.component.source + sev = f.vuln.severity.ljust(8) + head = f" {sev} {f.component.name}=={f.component.version} {f.vuln.osv_id}" + lines.append(head) + if f.vuln.summary: + summary = f.vuln.summary + if len(summary) > 100: + summary = summary[:97] + "..." + lines.append(f" {summary}") + if f.vuln.fixed_versions: + lines.append(f" fixed in: {', '.join(f.vuln.fixed_versions[:3])}") + return "\n".join(lines) + + +def _render_json(findings: list[Finding], total_components: int) -> str: + payload = { + "total_components_scanned": total_components, + "finding_count": len(findings), + "findings": [ + { + "package": f.component.name, + "version": f.component.version, + "ecosystem": f.component.ecosystem, + "source": f.component.source, + "vuln_id": f.vuln.osv_id, + "severity": f.vuln.severity, + "summary": f.vuln.summary, + "fixed_versions": f.vuln.fixed_versions, + } + for f in findings + ], + } + return json.dumps(payload, indent=2) + + +def _count_components( + *, skip_venv: bool, skip_plugins: bool, skip_mcp: bool, hermes_home: Path +) -> int: + total = 0 + if not skip_venv: + total += len(_discover_venv()) + if not skip_plugins: + total += len(_discover_plugins(hermes_home)) + if not skip_mcp: + total += len(_discover_mcp()) + return total + + +# โ”€โ”€โ”€ CLI entrypoint โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def cmd_security_audit(args: argparse.Namespace) -> int: + """Implementation of `hermes security audit`.""" + home = Path(get_hermes_home()) + skip_venv = bool(getattr(args, "skip_venv", False)) + skip_plugins = bool(getattr(args, "skip_plugins", False)) + skip_mcp = bool(getattr(args, "skip_mcp", False)) + output_json = bool(getattr(args, "json", False)) + fail_on = (getattr(args, "fail_on", None) or "critical").upper() + if fail_on not in SEVERITY_ORDER: + print( + f"unknown --fail-on value: {fail_on.lower()} " + f"(choose from: low, moderate, high, critical)", + file=sys.stderr, + ) + return 2 + + total = _count_components( + skip_venv=skip_venv, skip_plugins=skip_plugins, skip_mcp=skip_mcp, hermes_home=home + ) + if total == 0: + msg = "No components discovered (everything skipped, or empty environment)." + if output_json: + print(json.dumps({"total_components_scanned": 0, "finding_count": 0, "findings": []})) + else: + print(msg) + return 0 + + try: + findings = run_audit( + skip_venv=skip_venv, + skip_plugins=skip_plugins, + skip_mcp=skip_mcp, + hermes_home=home, + ) + except RuntimeError as exc: + print(f"audit failed: {exc}", file=sys.stderr) + return 2 + + if output_json: + print(_render_json(findings, total)) + else: + print(_render_human(findings, total)) + + # Exit code: 1 iff any finding meets or exceeds the --fail-on threshold. + threshold = SEVERITY_ORDER[fail_on] + for f in findings: + if SEVERITY_ORDER.get(f.vuln.severity, 0) >= threshold: + return 1 + return 0 diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 1e4b6d7fc7b..16eeba4e825 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -2034,74 +2034,6 @@ def _setup_telegram(): save_env_value("TELEGRAM_HOME_CHANNEL", home_channel) -def _setup_discord(): - """Configure Discord bot credentials and allowlist.""" - print_header("Discord") - existing = get_env_value("DISCORD_BOT_TOKEN") - if existing: - print_info("Discord: already configured") - if not prompt_yes_no("Reconfigure Discord?", False): - if not get_env_value("DISCORD_ALLOWED_USERS"): - print_info("โš ๏ธ Discord has no user allowlist - anyone can use your bot!") - if prompt_yes_no("Add allowed users now?", True): - print_info(" To find Discord ID: Enable Developer Mode, right-click name โ†’ Copy ID") - allowed_users = prompt("Allowed user IDs (comma-separated)") - if allowed_users: - cleaned_ids = _clean_discord_user_ids(allowed_users) - save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids)) - print_success("Discord allowlist configured") - return - - print_info("Create a bot at https://discord.com/developers/applications") - token = prompt("Discord bot token", password=True) - if not token: - return - save_env_value("DISCORD_BOT_TOKEN", token) - print_success("Discord token saved") - - print() - print_info("๐Ÿ”’ Security: Restrict who can use your bot") - print_info(" To find your Discord user ID:") - print_info(" 1. Enable Developer Mode in Discord settings") - print_info(" 2. Right-click your name โ†’ Copy ID") - print() - print_info(" You can also use Discord usernames (resolved on gateway start).") - print() - allowed_users = prompt( - "Allowed user IDs or usernames (comma-separated, leave empty for open access)" - ) - if allowed_users: - cleaned_ids = _clean_discord_user_ids(allowed_users) - save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids)) - print_success("Discord allowlist configured") - else: - print_info("โš ๏ธ No allowlist set - anyone in servers with your bot can use it!") - - print() - print_info("๐Ÿ“ฌ Home Channel: where Hermes delivers cron job results,") - print_info(" cross-platform messages, and notifications.") - print_info(" To get a channel ID: right-click a channel โ†’ Copy Channel ID") - print_info(" (requires Developer Mode in Discord settings)") - print_info(" You can also set this later by typing /set-home in a Discord channel.") - home_channel = prompt("Home channel ID (leave empty to set later with /set-home)") - if home_channel: - save_env_value("DISCORD_HOME_CHANNEL", home_channel) - - -def _clean_discord_user_ids(raw: str) -> list: - """Strip common Discord mention prefixes from a comma-separated ID string.""" - cleaned = [] - for uid in raw.replace(" ", "").split(","): - uid = uid.strip() - if uid.startswith("<@") and uid.endswith(">"): - uid = uid.lstrip("<@!").rstrip(">") - if uid.lower().startswith("user:"): - uid = uid[5:] - if uid: - cleaned.append(uid) - return cleaned - - def _setup_slack(): """Configure Slack bot credentials.""" print_header("Slack") @@ -2256,28 +2188,58 @@ def _setup_matrix(): print_success("E2EE enabled") matrix_pkg = "mautrix[encryption]" if want_e2ee else "mautrix" + # Use the central lazy-deps feature group so we install ALL of + # platform.matrix's dependencies (mautrix, Markdown, aiosqlite, + # asyncpg, aiohttp-socks) โ€” not just mautrix itself. The previous + # hand-rolled ``pip install mautrix[encryption]`` left asyncpg / + # aiosqlite uninstalled and broke E2EE connect with + # ``No module named 'asyncpg'`` on every fresh install (#31116). try: - __import__("mautrix") + from tools.lazy_deps import ensure as _lazy_ensure, feature_missing + _missing_before = feature_missing("platform.matrix") + if _missing_before: + print_info( + f"Installing {matrix_pkg} (+ {len(_missing_before)} runtime deps)..." + ) + try: + _lazy_ensure("platform.matrix", prompt=False) + print_success(f"{matrix_pkg} installed") + except Exception as exc: + print_warning( + f"Install failed โ€” run manually: pip install " + f"'mautrix[encryption]' asyncpg aiosqlite Markdown " + f"aiohttp-socks" + ) + print_info(f" Error: {exc}") except ImportError: - print_info(f"Installing {matrix_pkg}...") - import subprocess - uv_bin = shutil.which("uv") - if uv_bin: - result = subprocess.run( - [uv_bin, "pip", "install", "--python", sys.executable, matrix_pkg], - capture_output=True, text=True, - ) - else: - result = subprocess.run( - [sys.executable, "-m", "pip", "install", matrix_pkg], - capture_output=True, text=True, - ) - if result.returncode == 0: - print_success(f"{matrix_pkg} installed") - else: - print_warning(f"Install failed โ€” run manually: pip install '{matrix_pkg}'") - if result.stderr: - print_info(f" Error: {result.stderr.strip().splitlines()[-1]}") + # tools.lazy_deps unavailable (extreme edge case โ€” partial + # install). Fall back to the legacy single-package install + # path so the wizard still does *something*. + try: + __import__("mautrix") + except ImportError: + print_info(f"Installing {matrix_pkg}...") + import subprocess + uv_bin = shutil.which("uv") + if uv_bin: + result = subprocess.run( + [uv_bin, "pip", "install", "--python", sys.executable, matrix_pkg], + capture_output=True, text=True, + ) + else: + result = subprocess.run( + [sys.executable, "-m", "pip", "install", matrix_pkg], + capture_output=True, text=True, + ) + if result.returncode == 0: + print_success(f"{matrix_pkg} installed") + else: + print_warning( + f"Install failed โ€” run manually: pip install " + f"'{matrix_pkg}' asyncpg aiosqlite Markdown aiohttp-socks" + ) + if result.stderr: + print_info(f" Error: {result.stderr.strip().splitlines()[-1]}") print() print_info("๐Ÿ”’ Security: Restrict who can use your bot") @@ -3128,6 +3090,119 @@ SETUP_SECTIONS = [ ] +def _run_portal_one_shot(config: dict) -> None: + """One-shot Nous Portal setup โ€” OAuth + provider switch + Tool Gateway. + + Wired into ``hermes setup --portal``. Does NOT prompt for anything + besides what the underlying OAuth + Tool Gateway prompts already need. + Designed to be shareable as a single command (``hermes setup --portal``) + that gets a brand-new user from zero to a fully working Hermes session + with web/image/tts/browser tools all routed via their Portal sub. + """ + from types import SimpleNamespace + + from hermes_cli.auth_commands import auth_add_command + from hermes_cli.config import save_config + from hermes_cli.auth import get_nous_auth_status + from hermes_cli.nous_subscription import prompt_enable_tool_gateway + + print() + print( + color( + "โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”", + Colors.MAGENTA, + ) + ) + print(color("โ”‚ โš• Hermes Setup โ€” Nous Portal (one-shot) โ”‚", Colors.MAGENTA)) + print( + color( + "โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜", + Colors.MAGENTA, + ) + ) + print() + print_info(" One subscription, 300+ models, plus the Tool Gateway:") + print_info(" web search, image generation, TTS, browser automation") + print_info(" โ€” all routed through your Nous Portal sub.") + print() + print_info(" Sign up: https://portal.nousresearch.com/manage-subscription") + print() + + # Skip OAuth if already logged in (don't re-prompt every time the user + # runs `hermes setup --portal` after a successful first run). + already_logged_in = False + try: + already_logged_in = bool((get_nous_auth_status() or {}).get("logged_in")) + except Exception: + already_logged_in = False + + if already_logged_in: + print_success(" Already logged into Nous Portal.") + else: + # Hand off to the shared auth wiring so the device-code flow is + # identical to `hermes auth add nous --type oauth`. SimpleNamespace + # mirrors the argparse Namespace contract that auth_add_command expects. + ns = SimpleNamespace( + provider="nous", + auth_type="oauth", + label=None, + api_key=None, + portal_url=None, + inference_url=None, + client_id=None, + scope=None, + no_browser=False, + timeout=None, + insecure=False, + ca_bundle=None, + min_key_ttl_seconds=5 * 60, + ) + try: + auth_add_command(ns) + except SystemExit as e: + print() + print_error(f" Nous Portal login failed (exit {e.code}).") + print_info(" You can retry later with `hermes auth add nous --type oauth`.") + return + except (KeyboardInterrupt, EOFError): + print() + print_info(" Setup cancelled.") + return + except Exception as exc: + print() + print_error(f" Nous Portal login failed: {exc}") + print_info(" You can retry later with `hermes auth add nous --type oauth`.") + return + + # Set provider โ†’ nous so the model picker, status surfaces, and + # managed-tool gating all light up. Leave model.model empty so the + # runtime picks Nous's default model; the user can change it later + # with `hermes model`. + model_cfg = config.get("model") + if not isinstance(model_cfg, dict): + model_cfg = {} + config["model"] = model_cfg + model_cfg["provider"] = "nous" + save_config(config) + print() + print_success(" Nous set as your inference provider.") + + # Offer the Tool Gateway opt-in (single Y/n) โ€” same flow that fires + # from `hermes model` after picking Nous. + print() + try: + prompt_enable_tool_gateway(config) + except (KeyboardInterrupt, EOFError): + pass + except Exception as exc: + print_warning(f" Tool Gateway prompt skipped: {exc}") + + print() + print_success("Portal setup complete.") + print_info(" Run `hermes portal status` to inspect routing.") + print_info(" Run `hermes` to start chatting.") + + def run_setup_wizard(args): """Run the interactive setup wizard. @@ -3183,6 +3258,11 @@ def run_setup_wizard(args): ) return + # --portal: one-shot Nous Portal setup. Skips the rest of the wizard. + if bool(getattr(args, "portal", False)): + _run_portal_one_shot(config) + return + # Check if a specific section was requested section = getattr(args, "section", None) if section: diff --git a/hermes_cli/skills_hub.py b/hermes_cli/skills_hub.py index b0540705165..5d39b5202f4 100644 --- a/hermes_cli/skills_hub.py +++ b/hermes_cli/skills_hub.py @@ -906,8 +906,14 @@ def do_update(name: Optional[str] = None, console: Optional[Console] = None) -> c.print(f"[bold green]Updated {len(updates)} skill(s).[/]\n") -def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> None: - """Re-run security scan on installed hub skills.""" +def do_audit(name: Optional[str] = None, console: Optional[Console] = None, + deep: bool = False) -> None: + """Re-run security scan on installed hub skills. + + When ``deep=True``, also runs an opt-in AST-level diagnostic on Python + files (review aid only โ€” not a security gate; skills_guard.py verdicts + are unchanged). + """ from tools.skills_hub import HubLockFile, SKILLS_DIR from tools.skills_guard import scan_skill, format_scan_report @@ -928,6 +934,9 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N c.print(f"\n[bold]Auditing {len(targets)} skill(s)...[/]\n") + if deep: + from tools.skills_ast_audit import ast_scan_path, format_ast_report + for entry in targets: skill_path = SKILLS_DIR / entry["install_path"] if not skill_path.exists(): @@ -936,6 +945,10 @@ def do_audit(name: Optional[str] = None, console: Optional[Console] = None) -> N result = scan_skill(skill_path, source=entry.get("identifier", entry["source"])) c.print(format_scan_report(result)) + + if deep: + c.print(format_ast_report(ast_scan_path(skill_path), skill_name=entry["name"])) + c.print() @@ -1343,7 +1356,8 @@ def skills_command(args) -> None: elif action == "update": do_update(name=getattr(args, "name", None)) elif action == "audit": - do_audit(name=getattr(args, "name", None)) + do_audit(name=getattr(args, "name", None), + deep=getattr(args, "deep", False)) elif action == "uninstall": do_uninstall(args.name) elif action == "reset": @@ -1395,6 +1409,8 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None: /skills update /skills audit /skills audit my-skill + /skills audit --deep + /skills audit my-skill --deep /skills uninstall my-skill /skills tap list /skills tap add owner/repo @@ -1509,8 +1525,9 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None: do_update(name=name, console=c) elif action == "audit": - name = args[0] if args else None - do_audit(name=name, console=c) + name = args[0] if args and not args[0].startswith("--") else None + deep = "--deep" in args + do_audit(name=name, console=c, deep=deep) elif action == "uninstall": if not args: diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 89771291b20..23cb8e685fd 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -311,6 +311,16 @@ TOOL_CATEGORIES = { "image_gen": { "name": "Image Generation", "icon": "๐ŸŽจ", + # Per-provider rows for FAL.ai (`plugins/image_gen/fal`), OpenAI, + # OpenAI Codex, and xAI are injected at runtime from each + # ``plugins.image_gen.`` package via + # ``_plugin_image_gen_providers()`` in ``_visible_providers``. + # Only non-provider UX setup-flow rows remain here: + # - "Nous Subscription" โ€” managed FAL billed via the Nous + # subscription (requires_nous_auth + override_env_vars). + # Uses the fal plugin as the underlying backend but has a + # distinct setup UX. + # Mirrors the shape browser/video_gen ship today. "providers": [ { "name": "Nous Subscription", @@ -322,15 +332,6 @@ TOOL_CATEGORIES = { "override_env_vars": ["FAL_KEY"], "imagegen_backend": "fal", }, - { - "name": "FAL.ai", - "badge": "paid", - "tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.", - "env_vars": [ - {"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"}, - ], - "imagegen_backend": "fal", - }, ], }, "video_gen": { @@ -482,6 +483,11 @@ TOOLSET_ENV_REQUIREMENTS = { # โ”€โ”€โ”€ Post-Setup Hooks โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +def _cua_driver_cmd() -> str: + """Return the cua-driver executable name/path, honoring non-empty overrides.""" + return os.environ.get("HERMES_CUA_DRIVER_CMD", "").strip() or "cua-driver" + + def _pip_install( args: List[str], *, @@ -550,6 +556,55 @@ def _pip_install( ) + +def _check_cua_driver_asset_for_arch() -> bool: + """Check whether the latest CUA release ships an asset for this architecture. + + Returns True if the asset likely exists (or if we cannot determine it). + Returns False and prints a warning when the asset is confirmed missing, + so callers can skip the install attempt and avoid a raw 404. + """ + import platform as _plat + import urllib.request + + machine = _plat.machine() # "x86_64" or "arm64" + if machine == "arm64": + # arm64 (Apple Silicon) assets are always published. + return True + + # x86_64 / Intel โ€” probe the latest release for an architecture-specific + # asset before falling through to the upstream installer. + api_url = ( + "https://api.github.com/repos/trycua/cua/releases/latest" + ) + try: + req = urllib.request.Request(api_url, headers={"Accept": "application/vnd.github+json"}) + with urllib.request.urlopen(req, timeout=10) as resp: + release = _json.loads(resp.read().decode()) + tag = release.get("tag_name", "") + assets = release.get("assets", []) + arch_names = {"x86_64", "amd64"} + has_asset = any( + any(a in a_info.get("name", "").lower() for a in arch_names) + for a_info in assets + ) + if not has_asset: + _print_warning( + f" Latest CUA release ({tag}) has no Intel (x86_64) asset." + ) + _print_info( + " CUA Driver currently only ships Apple Silicon builds." + ) + _print_info( + " See: https://github.com/trycua/cua/issues/1493" + ) + return False + except Exception: + # Network / API failure โ€” proceed and let the installer handle it. + pass + return True + + def install_cua_driver(upgrade: bool = False) -> bool: """Install or refresh the cua-driver binary used by Computer Use. @@ -579,7 +634,8 @@ def install_cua_driver(upgrade: bool = False) -> bool: _print_warning(" Computer Use (cua-driver) is macOS-only; skipping.") return False - binary = shutil.which("cua-driver") + driver_cmd = _cua_driver_cmd() + binary = shutil.which(driver_cmd) # Not installed โ†’ fresh install path (only when caller asked for it). if not binary and not upgrade: @@ -587,18 +643,20 @@ def install_cua_driver(upgrade: bool = False) -> bool: _print_warning(" curl not found โ€” install manually:") _print_info(" https://github.com/trycua/cua/blob/main/libs/cua-driver/README.md") return False + if not _check_cua_driver_asset_for_arch(): + return False return _run_cua_driver_installer(label="Installing") # Already installed and caller didn't ask to upgrade โ†’ just confirm. if binary and not upgrade: try: version = subprocess.run( - ["cua-driver", "--version"], + [driver_cmd, "--version"], capture_output=True, text=True, timeout=5, ).stdout.strip() - _print_success(f" cua-driver already installed: {version or 'unknown version'}") + _print_success(f" {driver_cmd} already installed: {version or 'unknown version'}") except Exception: - _print_success(" cua-driver already installed.") + _print_success(f" {driver_cmd} already installed.") _print_info(" Grant macOS permissions if not done yet:") _print_info(" System Settings > Privacy & Security > Accessibility") _print_info(" System Settings > Privacy & Security > Screen Recording") @@ -609,11 +667,14 @@ def install_cua_driver(upgrade: bool = False) -> bool: _print_warning(" curl not found โ€” cannot refresh cua-driver.") return bool(binary) + if not _check_cua_driver_asset_for_arch(): + return bool(binary) + if binary: # Show before/after version when we have a baseline. Best-effort. try: before = subprocess.run( - ["cua-driver", "--version"], + [driver_cmd, "--version"], capture_output=True, text=True, timeout=5, ).stdout.strip() except Exception: @@ -625,13 +686,13 @@ def install_cua_driver(upgrade: bool = False) -> bool: if ok and before: try: after = subprocess.run( - ["cua-driver", "--version"], + [driver_cmd, "--version"], capture_output=True, text=True, timeout=5, ).stdout.strip() if after and after != before: - _print_success(f" cua-driver upgraded: {before} โ†’ {after}") + _print_success(f" {driver_cmd} upgraded: {before} โ†’ {after}") elif after: - _print_info(f" cua-driver up to date: {after}") + _print_info(f" {driver_cmd} up to date: {after}") except Exception: pass return ok @@ -655,11 +716,12 @@ def _run_cua_driver_installer(label: str = "Installing", verbose: bool = True) - _print_info(f" {label} cua-driver (macOS background computer-use)...") else: _print_info(f" {label} cua-driver...") + driver_cmd = _cua_driver_cmd() try: result = subprocess.run(install_cmd, shell=True, timeout=300) - if result.returncode == 0 and shutil.which("cua-driver"): + if result.returncode == 0 and shutil.which(driver_cmd): if verbose: - _print_success(" cua-driver installed.") + _print_success(f" {driver_cmd} installed.") _print_info(" IMPORTANT โ€” grant macOS permissions now:") _print_info(" System Settings > Privacy & Security > Accessibility") _print_info(" System Settings > Privacy & Security > Screen Recording") @@ -1506,12 +1568,9 @@ def _plugin_image_gen_providers() -> list[dict]: Each returned dict looks like a regular ``TOOL_CATEGORIES`` provider row but carries an ``image_gen_plugin_name`` marker so downstream code (config writing, model picker) knows to route through the - plugin registry instead of the in-tree FAL backend. - - FAL is skipped โ€” it's already exposed by the hardcoded - ``TOOL_CATEGORIES["image_gen"]`` entries. When FAL gets ported to - a plugin in a follow-up PR, the hardcoded entries go away and this - function surfaces it alongside OpenAI automatically. + plugin registry. Every image-gen backend is a plugin now โ€” there + are no hardcoded rows left in ``TOOL_CATEGORIES["image_gen"]`` for + this function to dedupe against (see issue #26241). """ try: from agent.image_gen_registry import list_providers @@ -1524,9 +1583,6 @@ def _plugin_image_gen_providers() -> list[dict]: rows: list[dict] = [] for provider in providers: - if getattr(provider, "name", None) == "fal": - # FAL has its own hardcoded rows today. - continue try: schema = provider.get_setup_schema() except Exception: @@ -1751,7 +1807,7 @@ _POST_SETUP_INSTALLED: dict = { # entry when (a) the post_setup is the ONLY install side-effect for # a no-key provider, and (b) an installed-state check is cheap and # doesn't trigger a heavy import. - "cua_driver": lambda: bool(shutil.which("cua-driver")), + "cua_driver": lambda: bool(shutil.which(_cua_driver_cmd())), } @@ -1869,6 +1925,16 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict): print() # Plain text labels only (no ANSI codes in menu items) + # When the user is logged into Nous, surface a marker on providers + # whose access is included in their subscription so it's visually + # obvious which options cost extra vs. cost nothing on top of Nous. + try: + _nous_logged_in = bool( + get_nous_subscription_features(config).nous_auth_present + ) + except Exception: + _nous_logged_in = False + provider_choices = [] for p in providers: badge = f" [{p['badge']}]" if p.get("badge") else "" @@ -1882,7 +1948,15 @@ def _configure_tool_category(ts_key: str, cat: dict, config: dict): configured = "" else: configured = " [configured]" - provider_choices.append(f"{p['name']}{badge}{tag}{configured}") + # Highlight Nous-managed entries when the user has Portal auth. + # curses_radiolist can't render ANSI inside item strings, so we + # use a plain unicode star + parenthetical phrase. Suppressed + # when no Portal auth is present so non-subscribers see the + # picker unchanged. + sub_marker = "" + if _nous_logged_in and p.get("managed_nous_feature"): + sub_marker = " โ˜… Included with your Nous subscription" + provider_choices.append(f"{p['name']}{badge}{tag}{configured}{sub_marker}") # Add skip option provider_choices.append("Skip โ€” keep defaults / configure later") @@ -2349,6 +2423,30 @@ def _configure_provider(provider: dict, config: dict): # Prompt for each required env var all_configured = True + # If this BYOK provider lives in a category that ALSO has a + # Nous-managed sibling, show a single dim hint so users know + # they can avoid the key entirely via a Portal subscription. + # Suppressed when the user is already authed to Nous. + _show_portal_hint = False + if env_vars and not managed_feature and not provider.get("requires_nous_auth"): + try: + _has_managed_sibling = False + for _cat_key, _cat in TOOL_CATEGORIES.items(): + _providers = _cat.get("providers", []) + if provider in _providers and any( + sib.get("managed_nous_feature") for sib in _providers + ): + _has_managed_sibling = True + break + if _has_managed_sibling: + _features = get_nous_subscription_features(config) + _show_portal_hint = not _features.nous_auth_present + except Exception: + _show_portal_hint = False + + if _show_portal_hint: + _print_info(" Available through Nous Portal subscription.") + for var in env_vars: existing = get_env_value(var["key"]) if existing: diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 7d28ce07617..eee068d1209 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -48,6 +48,7 @@ from hermes_cli.config import ( redact_key, ) from gateway.status import get_running_pid, read_runtime_status +from utils import env_var_enabled try: from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect @@ -118,7 +119,6 @@ _PUBLIC_API_PATHS: frozenset = frozenset({ "/api/model/info", "/api/dashboard/themes", "/api/dashboard/plugins", - "/api/dashboard/plugins/rescan", }) @@ -975,11 +975,13 @@ _AUX_TASK_SLOTS: Tuple[str, ...] = ( "vision", "web_extract", "compression", - "session_search", "skills_hub", "approval", "mcp", "title_generation", + "triage_specifier", + "kanban_decomposer", + "profile_describer", "curator", ) @@ -3293,24 +3295,49 @@ _VALID_CHANNEL_RE = re.compile(r"^[A-Za-z0-9._-]{1,128}$") _LOOPBACK_HOSTS = frozenset({"127.0.0.1", "::1", "localhost", "testclient"}) -def _is_public_bind() -> bool: - """True when bound to all-interfaces (operator used --insecure).""" - return getattr(app.state, "bound_host", "") in {"0.0.0.0", "::"} - - def _ws_client_is_allowed(ws: "WebSocket") -> bool: """Check if the WebSocket client IP is acceptable. - Allows loopback always; allows any IP when bound to all-interfaces - (--insecure mode, guarded by session token auth). + Allows loopback clients only. """ - if _is_public_bind(): - return True client_host = ws.client.host if ws.client else "" if not client_host: return True return client_host in _LOOPBACK_HOSTS + +def _ws_host_origin_is_allowed(ws: "WebSocket") -> bool: + """Apply the dashboard Host/Origin guard to WebSocket upgrades. + + FastAPI HTTP middleware does not run for WebSocket routes, so the + DNS-rebinding Host check used for normal dashboard HTTP requests must be + repeated here before accepting the upgrade. Browsers also send an Origin + header on WebSocket handshakes; when present, require it to target the + same bound dashboard host. + """ + bound_host = getattr(app.state, "bound_host", None) + if not bound_host: + return True + + host_header = ws.headers.get("host", "") + if not _is_accepted_host(host_header, bound_host): + return False + + origin = ws.headers.get("origin", "") + if not origin: + return True + + parsed = urllib.parse.urlparse(origin) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + return False + + return _is_accepted_host(parsed.netloc, bound_host) + + +def _ws_request_is_allowed(ws: "WebSocket") -> bool: + """Return True when the WebSocket upgrade matches dashboard boundaries.""" + return _ws_host_origin_is_allowed(ws) and _ws_client_is_allowed(ws) + # Per-channel subscriber registry used by /api/pub (PTY-side gateway โ†’ dashboard) # and /api/events (dashboard โ†’ browser sidebar). Keyed by an opaque channel id # the chat tab generates on mount; entries auto-evict when the last subscriber @@ -3389,7 +3416,7 @@ async def _broadcast_event(channel: str, payload: str) -> None: except Exception: # Subscriber went away mid-send; the /api/events finally clause # will remove it from the registry on its next iteration. - pass + _log.warning("broadcast send failed for subscriber on %s", channel, exc_info=True) def _channel_or_close_code(ws: WebSocket) -> Optional[str]: @@ -3412,7 +3439,7 @@ async def pty_ws(ws: WebSocket) -> None: await ws.close(code=4401) return - if not _ws_client_is_allowed(ws): + if not _ws_request_is_allowed(ws): await ws.close(code=4403) return @@ -3531,7 +3558,7 @@ async def gateway_ws(ws: WebSocket) -> None: await ws.close(code=4401) return - if not _ws_client_is_allowed(ws): + if not _ws_request_is_allowed(ws): await ws.close(code=4403) return @@ -3563,7 +3590,7 @@ async def pub_ws(ws: WebSocket) -> None: await ws.close(code=4401) return - if not _ws_client_is_allowed(ws): + if not _ws_request_is_allowed(ws): await ws.close(code=4403) return @@ -3592,7 +3619,7 @@ async def events_ws(ws: WebSocket) -> None: await ws.close(code=4401) return - if not _ws_client_is_allowed(ws): + if not _ws_request_is_allowed(ws): await ws.close(code=4403) return @@ -4044,6 +4071,43 @@ async def set_dashboard_theme(body: ThemeSetBody): # Dashboard plugin system # --------------------------------------------------------------------------- +def _safe_plugin_api_relpath(api_field: Any, *, dashboard_dir: Path) -> Optional[str]: + """Validate the manifest's ``api`` field for the plugin loader. + + The web server later imports this file as a Python module via + ``importlib.util.spec_from_file_location`` (arbitrary code + execution by design โ€” that's how plugins extend the backend). + Pre-#29156 the field was used as-is, which meant: + + * An absolute path swallowed the plugin's dashboard directory + entirely โ€” ``Path('safe/dashboard') / '/tmp/evil.py'`` resolves + to ``/tmp/evil.py``, so any attacker-controlled manifest could + point the import at any Python file on disk (GHSA-5qr3-c538-wm9j). + * A ``../..`` traversal could climb out of the plugin into + neighbouring directories on the search path. + + Return the original string when the resolved path stays under + ``dashboard_dir``; return ``None`` (with a warning logged at the + call site) otherwise so the plugin still loads its static JS/CSS + but its backend ``api`` is rejected. + """ + if not isinstance(api_field, str) or not api_field.strip(): + return None + candidate = Path(api_field) + if candidate.is_absolute(): + return None + try: + resolved = (dashboard_dir / candidate).resolve() + base = dashboard_dir.resolve() + except (OSError, RuntimeError): + return None + try: + resolved.relative_to(base) + except ValueError: + return None + return api_field + + def _discover_dashboard_plugins() -> list: """Scan plugins/*/dashboard/manifest.json for dashboard extensions. @@ -4062,7 +4126,16 @@ def _discover_dashboard_plugins() -> list: (bundled_root / "memory", "bundled"), (bundled_root, "bundled"), ] - if os.environ.get("HERMES_ENABLE_PROJECT_PLUGINS"): + # GHSA-5qr3-c538-wm9j (#29156): the previous ``os.environ.get(...)`` + # check treated *any* non-empty string as truthy, so ``=0``, ``=false``, + # and ``=no`` โ€” all of which the agent loader and operators correctly + # read as "disabled" โ€” silently *enabled* the untrusted project source + # in the web server. Combined with the absolute-path RCE primitive on + # the manifest's ``api`` field (now patched below), this turned the + # opt-in into a sticky always-on switch. Use the shared truthy + # semantics (``1`` / ``true`` / ``yes`` / ``on``) so the gate matches + # ``hermes_cli/plugins.py`` and the documented user contract. + if env_var_enabled("HERMES_ENABLE_PROJECT_PLUGINS"): search_dirs.append((Path.cwd() / ".hermes" / "plugins", "project")) for plugins_root, source in search_dirs: @@ -4101,6 +4174,23 @@ def _discover_dashboard_plugins() -> list: slots: List[str] = [] if isinstance(slots_src, list): slots = [s for s in slots_src if isinstance(s, str) and s] + # Validate ``api`` at discovery time so the value cached + # on the plugin entry is already safe to feed into the + # importer. An attacker-controlled manifest can name + # any absolute path or ``..`` traversal here โ€” the + # web server then imports that file as a Python module + # (RCE, GHSA-5qr3-c538-wm9j). + raw_api = data.get("api") + dashboard_dir = child / "dashboard" + safe_api = _safe_plugin_api_relpath(raw_api, dashboard_dir=dashboard_dir) + if raw_api and safe_api is None: + _log.warning( + "Plugin %s: refusing unsafe api path %r (must be a " + "relative file inside the plugin's dashboard/ " + "directory); backend routes from this plugin will " + "not be mounted", + name, raw_api, + ) plugins.append({ "name": name, "label": data.get("label", name), @@ -4111,10 +4201,10 @@ def _discover_dashboard_plugins() -> list: "slots": slots, "entry": data.get("entry", "dist/index.js"), "css": data.get("css"), - "has_api": bool(data.get("api")), + "has_api": bool(safe_api), "source": source, - "_dir": str(child / "dashboard"), - "_api_file": data.get("api"), + "_dir": str(dashboard_dir), + "_api_file": safe_api, }) except Exception as exc: _log.warning("Bad dashboard plugin manifest %s: %s", manifest_file, exc) @@ -4317,12 +4407,13 @@ async def post_agent_plugin_install(request: Request, body: _AgentPluginInstallB def _validate_plugin_name(name: str) -> str: """Reject path-traversal attempts in plugin name URL parameters.""" - if not name or "/" in name or "\\" in name or ".." in name: + name = name.strip("/") + if not name or ".." in name or "\\" in name: raise HTTPException(status_code=400, detail="Invalid plugin name.") return name -@app.post("/api/dashboard/agent-plugins/{name}/enable") +@app.post("/api/dashboard/agent-plugins/{name:path}/enable") async def post_agent_plugin_enable(request: Request, name: str): _require_token(request) name = _validate_plugin_name(name) @@ -4334,7 +4425,7 @@ async def post_agent_plugin_enable(request: Request, name: str): return result -@app.post("/api/dashboard/agent-plugins/{name}/disable") +@app.post("/api/dashboard/agent-plugins/{name:path}/disable") async def post_agent_plugin_disable(request: Request, name: str): _require_token(request) name = _validate_plugin_name(name) @@ -4346,7 +4437,7 @@ async def post_agent_plugin_disable(request: Request, name: str): return result -@app.post("/api/dashboard/agent-plugins/{name}/update") +@app.post("/api/dashboard/agent-plugins/{name:path}/update") async def post_agent_plugin_update(request: Request, name: str): _require_token(request) name = _validate_plugin_name(name) @@ -4359,7 +4450,7 @@ async def post_agent_plugin_update(request: Request, name: str): return result -@app.delete("/api/dashboard/agent-plugins/{name}") +@app.delete("/api/dashboard/agent-plugins/{name:path}") async def delete_agent_plugin(request: Request, name: str): _require_token(request) name = _validate_plugin_name(name) @@ -4397,7 +4488,7 @@ class _PluginVisibilityBody(BaseModel): hidden: bool -@app.post("/api/dashboard/plugins/{name}/visibility") +@app.post("/api/dashboard/plugins/{name:path}/visibility") async def post_plugin_visibility(request: Request, name: str, body: _PluginVisibilityBody): """Toggle a plugin's sidebar visibility (persists to config.yaml dashboard.hidden_plugins).""" _require_token(request) @@ -4468,12 +4559,42 @@ def _mount_plugin_api_routes(): Each plugin's ``api`` field points to a Python file that must expose a ``router`` (FastAPI APIRouter). Routes are mounted under ``/api/plugins//``. + + Backend import is restricted to ``bundled`` and ``user`` sources. + Project plugins (``./.hermes/plugins/``) ship with the CWD and are + therefore attacker-controlled in any threat model where the user + opens a malicious repo; they can extend the dashboard UI via + static JS/CSS but their Python ``api`` file is never auto-imported + by the web server. See GHSA-5qr3-c538-wm9j (#29156). """ for plugin in _get_dashboard_plugins(): api_file_name = plugin.get("_api_file") if not api_file_name: continue - api_path = Path(plugin["_dir"]) / api_file_name + if plugin.get("source") == "project": + _log.warning( + "Plugin %s: ignoring backend api=%s (project plugins may " + "not auto-import Python code; move the plugin to " + "~/.hermes/plugins/ if you trust it)", + plugin["name"], api_file_name, + ) + continue + dashboard_dir = Path(plugin["_dir"]) + api_path = dashboard_dir / api_file_name + try: + resolved_api = api_path.resolve() + resolved_base = dashboard_dir.resolve() + resolved_api.relative_to(resolved_base) + except (OSError, RuntimeError, ValueError): + # Discovery already filters this, but re-check here in case + # ``_dir`` was tampered with after caching or a future caller + # bypasses the validator. Defence in depth keeps the import + # primitive contained even if the upstream check regresses. + _log.warning( + "Plugin %s: refusing to import api file outside its " + "dashboard directory (%s)", plugin["name"], api_path, + ) + continue if not api_path.exists(): _log.warning("Plugin %s declares api=%s but file not found", plugin["name"], api_file_name) continue diff --git a/hermes_cli/webhook.py b/hermes_cli/webhook.py index 621acc82e27..75470128707 100644 --- a/hermes_cli/webhook.py +++ b/hermes_cli/webhook.py @@ -11,8 +11,10 @@ hot-reloaded by the webhook adapter without a gateway restart. """ import json +import os import re import secrets +import tempfile import time from pathlib import Path from typing import Dict @@ -23,6 +25,7 @@ from hermes_cli.config import cfg_get _SUBSCRIPTIONS_FILENAME = "webhook_subscriptions.json" +_SUBSCRIPTIONS_FILE_MODE = 0o600 def _hermes_home() -> Path: @@ -48,12 +51,33 @@ def _load_subscriptions() -> Dict[str, dict]: def _save_subscriptions(subs: Dict[str, dict]) -> None: path = _subscriptions_path() path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(".tmp") - tmp_path.write_text( - json.dumps(subs, indent=2, ensure_ascii=False), - encoding="utf-8", + # webhook_subscriptions.json contains per-route HMAC secrets โ€” write + # via tempfile + chmod 0o600 before the atomic rename so a permissive + # umask cannot leave the secrets readable to other local users in the + # window between create and rename. + fd, tmp_name = tempfile.mkstemp( + prefix=f".{path.name}.", + suffix=".tmp", + dir=path.parent, + text=True, ) - atomic_replace(tmp_path, path) + tmp_path = Path(tmp_name) + try: + with os.fdopen(fd, "w", encoding="utf-8") as fh: + json.dump(subs, fh, indent=2, ensure_ascii=False) + fh.flush() + os.fsync(fh.fileno()) + os.chmod(tmp_path, _SUBSCRIPTIONS_FILE_MODE) + atomic_replace(tmp_path, path) + # Re-assert after rename in case the destination existed with a + # broader mode and atomic_replace preserved it. + os.chmod(path, _SUBSCRIPTIONS_FILE_MODE) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except OSError: + pass + raise def _get_webhook_config() -> dict: diff --git a/hermes_state.py b/hermes_state.py index 5804437198a..0391047d055 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -33,7 +33,7 @@ T = TypeVar("T") DEFAULT_DB_PATH = get_hermes_home() / "state.db" -SCHEMA_VERSION = 12 +SCHEMA_VERSION = 13 # --------------------------------------------------------------------------- # WAL-compatibility fallback @@ -237,7 +237,8 @@ CREATE TABLE IF NOT EXISTS messages ( reasoning_details TEXT, codex_reasoning_items TEXT, codex_message_items TEXT, - platform_message_id TEXT + platform_message_id TEXT, + observed INTEGER DEFAULT 0 ); CREATE TABLE IF NOT EXISTS state_meta ( @@ -1460,6 +1461,7 @@ class SessionDB: codex_reasoning_items: Any = None, codex_message_items: Any = None, platform_message_id: str = None, + observed: bool = False, ) -> int: """ Append a message to a session. Returns the message row ID. @@ -1501,8 +1503,8 @@ class SessionDB: """INSERT INTO messages (session_id, role, content, tool_call_id, tool_calls, tool_name, timestamp, token_count, finish_reason, reasoning, reasoning_content, reasoning_details, codex_reasoning_items, - codex_message_items, platform_message_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + codex_message_items, platform_message_id, observed) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( session_id, role, @@ -1519,6 +1521,7 @@ class SessionDB: codex_items_json, codex_message_items_json, platform_message_id, + 1 if observed else 0, ), ) msg_id = cursor.lastrowid @@ -1590,8 +1593,8 @@ class SessionDB: """INSERT INTO messages (session_id, role, content, tool_call_id, tool_calls, tool_name, timestamp, token_count, finish_reason, reasoning, reasoning_content, reasoning_details, codex_reasoning_items, - codex_message_items, platform_message_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + codex_message_items, platform_message_id, observed) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( session_id, role, @@ -1608,6 +1611,7 @@ class SessionDB: codex_items_json, codex_message_items_json, platform_msg_id, + 1 if msg.get("observed") else 0, ), ) total_messages += 1 @@ -1925,7 +1929,7 @@ class SessionDB: rows = self._conn.execute( "SELECT role, content, tool_call_id, tool_calls, tool_name, " "finish_reason, reasoning, reasoning_content, reasoning_details, " - "codex_reasoning_items, codex_message_items, platform_message_id " + "codex_reasoning_items, codex_message_items, platform_message_id, observed " f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id", tuple(session_ids), ).fetchall() @@ -1953,6 +1957,8 @@ class SessionDB: # for backward compatibility with the JSONL transcript shape. if row["platform_message_id"]: msg["message_id"] = row["platform_message_id"] + if row["observed"]: + msg["observed"] = True # Restore reasoning fields on assistant messages so providers # that replay reasoning (OpenRouter, OpenAI, Nous) receive # coherent multi-turn reasoning context. diff --git a/infographic/bitwarden-secrets-manager/infographic.png b/infographic/bitwarden-secrets-manager/infographic.png deleted file mode 100644 index eb0a25f9bba..00000000000 Binary files a/infographic/bitwarden-secrets-manager/infographic.png and /dev/null differ diff --git a/infographic/bitwarden-secrets-manager/prompts/infographic.md b/infographic/bitwarden-secrets-manager/prompts/infographic.md deleted file mode 100644 index 6c9b5d08c25..00000000000 --- a/infographic/bitwarden-secrets-manager/prompts/infographic.md +++ /dev/null @@ -1,121 +0,0 @@ -Create a professional infographic following these specifications: - -## Image Specifications - -- **Type**: Infographic -- **Layout**: bento-grid -- **Style**: retro-pop-grid -- **Aspect Ratio**: 1:1 (square) -- **Language**: en - -## Core Principles - -- Follow the layout structure precisely for information architecture -- Apply style aesthetics consistently throughout -- Keep information concise, highlight keywords and core concepts -- Use ample whitespace for visual clarity -- Maintain clear visual hierarchy - -## Text Requirements - -- All text must match the specified style treatment -- Main titles should be prominent and readable -- Key concepts should be visually emphasized -- Labels should be clear and appropriately sized -- Use English for all text content - -## Layout Guidelines (bento-grid) - -- Grid of rectangular cells with varied sizes (1x1, 2x1, 1x2, 2x2) -- Hero cell ("ONE TOKEN, EVERY KEY") takes the largest position (top-center or upper-left, 2x2) -- Supporting cells around the hero, mixed cell sizes for rhythm -- Each cell self-contained with its own title + icon + brief content -- Title strip at the top: "BITWARDEN SECRETS MANAGER โ€” HERMES-AGENT PR #30035" -- Footer strip at the bottom with commit SHA + repo - -## Style Guidelines (retro-pop-grid) - -- 1970s retro pop art with strict Swiss international grid -- Background: warm vintage cream/beige (#F5F0E6) -- Accents: salmon pink, sky blue, mustard yellow, mint green โ€” all muted retro tones -- Pure solid black (#000000) and solid white (#FFFFFF) for extreme-contrast cells -- Uniform thick black outlines on ALL illustrations, text boxes, grid dividers -- Pure 2D flat vector aesthetic with subtle screen-print texture -- One cell inverted to black-background-with-white-text for the "NEVER BLOCKS STARTUP" warning section -- Geometric fill patterns in empty cells: checkerboards, diagonal lines, dot grids -- Flat abstract symbols: shields (security), wrenches (install), arrows (rotation), keyholes (auth), checkmarks (tests) -- Vintage comic-style smiley face for "26/26 PASSING" cell -- Bold brutalist or thick retro display fonts for headers; clean sans-serif body -- Decorative stylistic labels acceptable: "WARNING", "NEW DEFAULT", "PINNED", "VERIFIED", "ROTATE" - -## Avoid - -- 3D rendering, gradients, soft shadows, sketch-like lines -- Free-floating elements โ€” everything anchored in grid cells -- Pure white background โ€” must use warm cream/beige - ---- - -Generate the infographic based on the content below: - -### Title (top strip) -BITWARDEN SECRETS MANAGER โ†’ HERMES-AGENT -PR #30035 - -### HERO CELL (largest, top-center, salmon pink background with thick black border) -ONE TOKEN, EVERY KEY -Rotate once in the Bitwarden web app. -Every Hermes process picks it up on next start. -NEW DEFAULT: override_existing = true - -### Cell โ€” LAZY INSTALL (sky blue background) -~/.hermes/bin/bws -bws v2.0.0 PINNED -SHA-256 VERIFIED -No apt ยท no brew ยท no sudo -Icon: wrench + downward arrow - -### Cell โ€” CLI SURFACE (mustard yellow background, checkerboard accents) -$ hermes secrets bitwarden - setup wizard - status diagnose - sync fetch - install binary - disable off -Icon: terminal prompt symbol - -### Cell โ€” SOURCE OF TRUTH (mint green background) -BITWARDEN WINS -Overwrites stale .env on every start -Bootstrap token never overwritten (exception) -Icon: keyhole + arrow - -### Cell โ€” INVERTED BLACK CELL with WHITE TEXT โ€” NEVER BLOCKS STARTUP (extreme contrast) -WARNING-FREE STARTUP -Missing binary โ†’ warn + continue -Bad token โ†’ warn + continue -Network down โ†’ warn + continue -Checksum mismatch โ†’ refuse + warn -30s timeout ceiling -Icon: white triangle warning sign - -### Cell โ€” TESTS (cream with thick black outline, vintage comic smiley face) -26 / 26 -HERMETIC -subprocess + urllib mocked -linux ยท macos ยท windows -x86_64 ยท arm64 -Icon: comic-style smiley face with checkmark - -### Cell โ€” CONFIG YAML (white background with black grid) -secrets: - bitwarden: - enabled: true - project_id: ... - override_existing: true - cache_ttl_seconds: 300 - auto_install: true - -### Footer strip (bottom, black-on-cream) -PR #30035 ยท commit 7f9b05668 ยท NousResearch/hermes-agent -10 files ยท +1743 / -1 ยท agent/secret_sources/ ยท hermes_cli/secrets_cli.py diff --git a/infographic/bitwarden-secrets-manager/structured-content.md b/infographic/bitwarden-secrets-manager/structured-content.md deleted file mode 100644 index 9d0a9c76d70..00000000000 --- a/infographic/bitwarden-secrets-manager/structured-content.md +++ /dev/null @@ -1,57 +0,0 @@ -# Hermes-Agent PR #30035 โ€” Bitwarden Secrets Manager Integration - -## Hero -**ONE TOKEN, EVERY KEY** -Rotate once. Every Hermes process picks it up on next start. -`secrets.bitwarden.override_existing: true` (default) - -## Cells - -### Lazy Install -- `bws v2.0.0` pinned -- Downloaded into `~/.hermes/bin/bws` -- SHA-256 verified vs GitHub Releases checksum file -- No apt, no brew, no sudo -- Cross-platform: linux gnu+musl, macos universal, windows x86_64+arm64 - -### CLI Surface -- `hermes secrets bitwarden setup` wizard -- `hermes secrets bitwarden status` diagnose -- `hermes secrets bitwarden sync` dry-run / --apply -- `hermes secrets bitwarden install` binary only -- `hermes secrets bitwarden disable` off switch - -### Source of Truth -- Bitwarden WINS on every Hermes start -- BSM values overwrite stale `.env` lines -- Rotate a key once โ†’ all your machines reload it -- Bootstrap token `BWS_ACCESS_TOKEN` is the lone exception (never overwritten) - -### Never Blocks Startup -- Missing binary โ†’ warn + continue -- Bad token โ†’ warn + continue -- Checksum mismatch โ†’ refuse install + warn -- No network โ†’ warn + continue -- Timeout โ†’ 30s ceiling, warn + continue - -### Tests -- 26/26 passing, hermetic -- subprocess + urllib mocked -- Platform matrix tested (linux, macos, windows ร— x86_64, arm64) -- Cache hit/miss, auth fail, non-JSON, timeout, override behavior - -### Config -```yaml -secrets: - bitwarden: - enabled: true - project_id: - override_existing: true # NEW DEFAULT - cache_ttl_seconds: 300 - auto_install: true -``` - -## Footer -PR #30035 ยท commit 7f9b05668 ยท NousResearch/hermes-agent - -10 files changed ยท +1743 / -1 ยท agent/secret_sources/ ยท hermes_cli/secrets_cli.py ยท tests ยท docs diff --git a/infographic/kanban-db-corruption-defense/infographic.png b/infographic/kanban-db-corruption-defense/infographic.png new file mode 100644 index 00000000000..54e4d48bc76 Binary files /dev/null and b/infographic/kanban-db-corruption-defense/infographic.png differ diff --git a/infographic/skill-scanner-no-ghost-skills/infographic.png b/infographic/skill-scanner-no-ghost-skills/infographic.png deleted file mode 100644 index 72e207a5fab..00000000000 Binary files a/infographic/skill-scanner-no-ghost-skills/infographic.png and /dev/null differ diff --git a/locales/af.yaml b/locales/af.yaml index b08f4316566..636bae754f3 100644 --- a/locales/af.yaml +++ b/locales/af.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "Geen benoemde sessies gevind nie.\nGebruik `/title My Sessie` om jou huidige sessie 'n naam te gee, en dan `/resume My Sessie` om later daarheen terug te keer." list_header: "๐Ÿ“‹ **Benoemde Sessies**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nGebruik: `/resume `" + list_footer_numbered: "\nGebruik: `/resume ` of `/resume ` (bv. `/resume 1` vir die mees onlangse)" list_failed: "Kon nie sessies lys nie: {error}" + out_of_range: "Hervat-indeks {index} is buite bereik.\nGebruik `/resume` sonder argumente om beskikbare sessies te sien." not_found: "Geen sessie gevind wat by '**{name}**' pas nie.\nGebruik `/resume` sonder argumente om beskikbare sessies te sien." already_on: "๐Ÿ“Œ Reeds op sessie **{name}**." switch_failed: "Kon nie sessie verander nie." diff --git a/locales/de.yaml b/locales/de.yaml index 70546c875f5..f400dd9fb2e 100644 --- a/locales/de.yaml +++ b/locales/de.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "Keine benannten Sitzungen gefunden.\nVerwenden Sie `/title Meine Sitzung`, um die aktuelle Sitzung zu benennen, dann `/resume Meine Sitzung`, um spรคter dorthin zurรผckzukehren." list_header: "๐Ÿ“‹ **Benannte Sitzungen**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nVerwendung: `/resume `" + list_footer_numbered: "\nVerwendung: `/resume ` oder `/resume ` (z. B. `/resume 1` fรผr die zuletzt verwendete)" list_failed: "Sitzungen konnten nicht aufgelistet werden: {error}" + out_of_range: "Wiederaufnahme-Index {index} liegt auรŸerhalb des gรผltigen Bereichs.\nVerwenden Sie `/resume` ohne Argumente, um verfรผgbare Sitzungen anzuzeigen." not_found: "Keine Sitzung passend zu '**{name}**' gefunden.\nVerwenden Sie `/resume` ohne Argumente, um verfรผgbare Sitzungen zu sehen." already_on: "๐Ÿ“Œ Bereits in Sitzung **{name}**." switch_failed: "Sitzungswechsel fehlgeschlagen." diff --git a/locales/en.yaml b/locales/en.yaml index cbb61055fc8..88d18a2f892 100644 --- a/locales/en.yaml +++ b/locales/en.yaml @@ -237,9 +237,12 @@ gateway: no_named_sessions: "No named sessions found.\nUse `/title My Session` to name your current session, then `/resume My Session` to return to it later." list_header: "๐Ÿ“‹ **Named Sessions**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nUsage: `/resume `" + list_footer_numbered: "\nUsage: `/resume ` or `/resume ` (e.g. `/resume 1` for the most recent)" list_failed: "Could not list sessions: {error}" + out_of_range: "Resume index {index} is out of range.\nUse `/resume` with no arguments to see available sessions." not_found: "No session found matching '**{name}**'.\nUse `/resume` with no arguments to see available sessions." already_on: "๐Ÿ“Œ Already on session **{name}**." switch_failed: "Failed to switch session." diff --git a/locales/es.yaml b/locales/es.yaml index 34b9a7bb1bb..08aaf9ad0b4 100644 --- a/locales/es.yaml +++ b/locales/es.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "No se encontraron sesiones con nombre.\nUsa `/title Mi sesiรณn` para nombrar la sesiรณn actual y luego `/resume Mi sesiรณn` para volver a ella." list_header: "๐Ÿ“‹ **Sesiones con nombre**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nUso: `/resume `" + list_footer_numbered: "\nUso: `/resume ` o `/resume ` (p. ej. `/resume 1` para la mรกs reciente)" list_failed: "No se pudieron listar las sesiones: {error}" + out_of_range: "El รญndice de reanudaciรณn {index} estรก fuera de rango.\nUsa `/resume` sin argumentos para ver las sesiones disponibles." not_found: "No se encontrรณ ninguna sesiรณn que coincida con '**{name}**'.\nUsa `/resume` sin argumentos para ver las sesiones disponibles." already_on: "๐Ÿ“Œ Ya estรกs en la sesiรณn **{name}**." switch_failed: "No se pudo cambiar de sesiรณn." diff --git a/locales/fr.yaml b/locales/fr.yaml index 03d5e0b6222..ddb89bd2f49 100644 --- a/locales/fr.yaml +++ b/locales/fr.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "Aucune session nommรฉe trouvรฉe.\nUtilisez `/title Ma session` pour nommer la session actuelle, puis `/resume Ma session` pour y revenir plus tard." list_header: "๐Ÿ“‹ **Sessions nommรฉes**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nUsage : `/resume `" + list_footer_numbered: "\nUtilisation : `/resume ` ou `/resume ` (par exemple `/resume 1` pour la plus rรฉcente)" list_failed: "Impossible de lister les sessions : {error}" + out_of_range: "L'index de reprise {index} est hors limites.\nUtilisez `/resume` sans arguments pour voir les sessions disponibles." not_found: "Aucune session correspondant ร  '**{name}**' trouvรฉe.\nUtilisez `/resume` sans argument pour voir les sessions disponibles." already_on: "๐Ÿ“Œ Dรฉjร  sur la session **{name}**." switch_failed: "ร‰chec du changement de session." diff --git a/locales/ga.yaml b/locales/ga.yaml index 3dd5c46447f..40fb94ba4e6 100644 --- a/locales/ga.yaml +++ b/locales/ga.yaml @@ -226,9 +226,12 @@ gateway: no_named_sessions: "Nรญor aimsรญodh aon seisiรบn ainmnithe.\nรšsรกid `/title M'Ainm Seisiรบin` chun do sheisiรบn reatha a ainmniรบ, ansin `/resume M'Ainm Seisiรบin` chun filleadh air nรญos dรฉanaรญ." list_header: "๐Ÿ“‹ **Seisiรบin Ainmnithe**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nรšsรกid: `/resume `" + list_footer_numbered: "\nรšsรกid: `/resume ` nรณ `/resume ` (m.sh. `/resume 1` don cheann is dรฉanaรญ)" list_failed: "Nรญorbh fhรฉidir seisiรบin a liostรกil: {error}" + out_of_range: "Tรก an t-innรฉacs atosaithe {index} as raon.\nรšsรกid `/resume` gan argรณintรญ chun na seisiรบin atรก ar fรกil a fheiceรกil." not_found: "Nรญor aimsรญodh aon seisiรบn ag teacht le '**{name}**'.\nรšsรกid `/resume` gan argรณintรญ chun seisiรบin atรก ar fรกil a fheiceรกil." already_on: "๐Ÿ“Œ Cheana ar an seisiรบn **{name}**." switch_failed: "Theip ar athrรบ seisiรบin." diff --git a/locales/hu.yaml b/locales/hu.yaml index b18f7be707f..9be44294dc2 100644 --- a/locales/hu.yaml +++ b/locales/hu.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "Nem talรกlhatรณ elnevezett munkamenet.\nHasznรกld a `/title Sajรกt munkamenet` parancsot a jelenlegi munkamenet elnevezรฉsรฉhez, majd a `/resume Sajรกt munkamenet` paranccsal tรฉrhetsz vissza hozzรก." list_header: "๐Ÿ“‹ **Elnevezett munkamenetek**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nHasznรกlat: `/resume `" + list_footer_numbered: "\nHasznรกlat: `/resume ` vagy `/resume ` (pl. `/resume 1` a legutรณbbihoz)" list_failed: "Nem sikerรผlt listรกzni a munkameneteket: {error}" + out_of_range: "A folytatรกsi index ({index}) tartomรกnyon kรญvรผl esik.\nA `/resume` argumentumok nรฉlkรผli hasznรกlata megjelenรญti az elรฉrhetล‘ munkameneteket." not_found: "Nem talรกlhatรณ '**{name}**' nevลฑ munkamenet.\nArgumentumok nรฉlkรผl hasznรกld a `/resume` parancsot az elรฉrhetล‘ munkamenetek megtekintรฉsรฉhez." already_on: "๐Ÿ“Œ Mรกr a **{name}** munkamenetben vagy." switch_failed: "Nem sikerรผlt munkamenetet vรกltani." diff --git a/locales/it.yaml b/locales/it.yaml index 053046be7d5..e98d86e7fb1 100644 --- a/locales/it.yaml +++ b/locales/it.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "Nessuna sessione con nome trovata.\nUsa `/title My Session` per dare un nome alla sessione attuale, poi `/resume My Session` per tornare a essa in seguito." list_header: "๐Ÿ“‹ **Sessioni con nome**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nUso: `/resume `" + list_footer_numbered: "\nUso: `/resume ` o `/resume ` (es. `/resume 1` per la piรน recente)" list_failed: "Impossibile elencare le sessioni: {error}" + out_of_range: "L'indice di ripresa {index} รจ fuori intervallo.\nUsa `/resume` senza argomenti per vedere le sessioni disponibili." not_found: "Nessuna sessione trovata corrispondente a '**{name}**'.\nUsa `/resume` senza argomenti per vedere le sessioni disponibili." already_on: "๐Ÿ“Œ Giร  nella sessione **{name}**." switch_failed: "Cambio di sessione non riuscito." diff --git a/locales/ja.yaml b/locales/ja.yaml index 931e88ed3d8..33cb1b99c9a 100644 --- a/locales/ja.yaml +++ b/locales/ja.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "ๅๅ‰ไป˜ใใ‚ปใƒƒใ‚ทใƒงใƒณใŒ่ฆ‹ใคใ‹ใ‚Šใพใ›ใ‚“ใ€‚\n`/title ใ‚ปใƒƒใ‚ทใƒงใƒณๅ` ใง็พๅœจใฎใ‚ปใƒƒใ‚ทใƒงใƒณใซๅๅ‰ใ‚’ไป˜ใ‘ใ‚‹ใจใ€ๅพŒใง `/resume ใ‚ปใƒƒใ‚ทใƒงใƒณๅ` ใงๆˆปใ‚Œใพใ™ใ€‚" list_header: "๐Ÿ“‹ **ๅๅ‰ไป˜ใใ‚ปใƒƒใ‚ทใƒงใƒณ**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nไฝฟใ„ๆ–น: `/resume <ใ‚ปใƒƒใ‚ทใƒงใƒณๅ>`" + list_footer_numbered: "\nไฝฟใ„ๆ–น: `/resume <ใ‚ปใƒƒใ‚ทใƒงใƒณๅ>` ใพใŸใฏ `/resume <็•ชๅท>`๏ผˆไพ‹: ๆœ€ๆ–ฐใฎใ‚ปใƒƒใ‚ทใƒงใƒณใซใฏ `/resume 1`๏ผ‰" list_failed: "ใ‚ปใƒƒใ‚ทใƒงใƒณใ‚’ไธ€่ฆง่กจ็คบใงใใพใ›ใ‚“ใงใ—ใŸ: {error}" + out_of_range: "ๅ†้–‹ใ‚คใƒณใƒ‡ใƒƒใ‚ฏใ‚น {index} ใฏ็ฏ„ๅ›ฒๅค–ใงใ™ใ€‚\nๅผ•ๆ•ฐใชใ—ใง `/resume` ใ‚’ๅฎŸ่กŒใ™ใ‚‹ใจใ€ๅˆฉ็”จๅฏ่ƒฝใชใ‚ปใƒƒใ‚ทใƒงใƒณใŒ่กจ็คบใ•ใ‚Œใพใ™ใ€‚" not_found: "'**{name}**' ใซไธ€่‡ดใ™ใ‚‹ใ‚ปใƒƒใ‚ทใƒงใƒณใŒ่ฆ‹ใคใ‹ใ‚Šใพใ›ใ‚“ใ€‚\nๅผ•ๆ•ฐใชใ—ใง `/resume` ใ‚’ๅฎŸ่กŒใ™ใ‚‹ใจๅˆฉ็”จๅฏ่ƒฝใชใ‚ปใƒƒใ‚ทใƒงใƒณใ‚’่กจ็คบใ—ใพใ™ใ€‚" already_on: "๐Ÿ“Œ ๆ—ขใซใ‚ปใƒƒใ‚ทใƒงใƒณ **{name}** ใซใ„ใพใ™ใ€‚" switch_failed: "ใ‚ปใƒƒใ‚ทใƒงใƒณใฎๅˆ‡ใ‚Šๆ›ฟใˆใซๅคฑๆ•—ใ—ใพใ—ใŸใ€‚" diff --git a/locales/ko.yaml b/locales/ko.yaml index 6fc9d1679d2..3f9ad817334 100644 --- a/locales/ko.yaml +++ b/locales/ko.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "์ด๋ฆ„์ด ์ง€์ •๋œ ์„ธ์…˜์ด ์—†์Šต๋‹ˆ๋‹ค.\nํ˜„์žฌ ์„ธ์…˜์— ์ด๋ฆ„์„ ์ง€์ •ํ•˜๋ ค๋ฉด `/title ๋‚ด ์„ธ์…˜`์„ ์‚ฌ์šฉํ•˜๊ณ , ๋‚˜์ค‘์— `/resume ๋‚ด ์„ธ์…˜`์œผ๋กœ ๋Œ์•„์˜ค์„ธ์š”." list_header: "๐Ÿ“‹ **์ด๋ฆ„์ด ์ง€์ •๋œ ์„ธ์…˜**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\n์‚ฌ์šฉ๋ฒ•: `/resume `" + list_footer_numbered: "\n์‚ฌ์šฉ๋ฒ•: `/resume <์„ธ์…˜ ์ด๋ฆ„>` ๋˜๋Š” `/resume <๋ฒˆํ˜ธ>` (์˜ˆ: ๊ฐ€์žฅ ์ตœ๊ทผ ์„ธ์…˜์€ `/resume 1`)" list_failed: "์„ธ์…˜ ๋ชฉ๋ก์„ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {error}" + out_of_range: "์žฌ๊ฐœ ์ธ๋ฑ์Šค {index}์ด(๊ฐ€) ๋ฒ”์œ„๋ฅผ ๋ฒ—์–ด๋‚ฌ์Šต๋‹ˆ๋‹ค.\n์ธ์ž ์—†์ด `/resume`์„ ์‹คํ–‰ํ•˜๋ฉด ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์„ธ์…˜์ด ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค." not_found: "'**{name}**'์™€ ์ผ์น˜ํ•˜๋Š” ์„ธ์…˜์ด ์—†์Šต๋‹ˆ๋‹ค.\n์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์„ธ์…˜์„ ๋ณด๋ ค๋ฉด ์ธ์ˆ˜ ์—†์ด `/resume`์„ ์‚ฌ์šฉํ•˜์„ธ์š”." already_on: "๐Ÿ“Œ ์ด๋ฏธ **{name}** ์„ธ์…˜์— ์žˆ์Šต๋‹ˆ๋‹ค." switch_failed: "์„ธ์…˜ ์ „ํ™˜์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค." diff --git a/locales/pt.yaml b/locales/pt.yaml index e202a53480f..0c0eddad91e 100644 --- a/locales/pt.yaml +++ b/locales/pt.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "Nรฃo foram encontradas sessรตes com nome.\nUsa `/title A minha sessรฃo` para nomear a sessรฃo atual e depois `/resume A minha sessรฃo` para voltar a ela." list_header: "๐Ÿ“‹ **Sessรตes com nome**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nUso: `/resume `" + list_footer_numbered: "\nUso: `/resume ` ou `/resume ` (ex.: `/resume 1` para a mais recente)" list_failed: "Nรฃo foi possรญvel listar as sessรตes: {error}" + out_of_range: "O รญndice de retomada {index} estรก fora do intervalo.\nUse `/resume` sem argumentos para ver as sessรตes disponรญveis." not_found: "Nรฃo foi encontrada nenhuma sessรฃo correspondente a '**{name}**'.\nUsa `/resume` sem argumentos para ver as sessรตes disponรญveis." already_on: "๐Ÿ“Œ Jรก estรกs na sessรฃo **{name}**." switch_failed: "Falha ao mudar de sessรฃo." diff --git a/locales/ru.yaml b/locales/ru.yaml index 76fde56a9b6..b3a202be777 100644 --- a/locales/ru.yaml +++ b/locales/ru.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "ะ˜ะผะตะฝะพะฒะฐะฝะฝั‹ั… ัะตะฐะฝัะพะฒ ะฝะต ะฝะฐะนะดะตะฝะพ.\nะ˜ัะฟะพะปัŒะทัƒะนั‚ะต `/title ะœะพะน ัะตะฐะฝั`, ั‡ั‚ะพะฑั‹ ะฝะฐะทะฒะฐั‚ัŒ ั‚ะตะบัƒั‰ะธะน ัะตะฐะฝั, ะทะฐั‚ะตะผ `/resume ะœะพะน ัะตะฐะฝั`, ั‡ั‚ะพะฑั‹ ะฒะตั€ะฝัƒั‚ัŒัั ะบ ะฝะตะผัƒ ะฟะพะทะถะต." list_header: "๐Ÿ“‹ **ะ˜ะผะตะฝะพะฒะฐะฝะฝั‹ะต ัะตะฐะฝัั‹**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nะ˜ัะฟะพะปัŒะทะพะฒะฐะฝะธะต: `/resume <ะฝะฐะทะฒะฐะฝะธะต ัะตะฐะฝัะฐ>`" + list_footer_numbered: "\nะ˜ัะฟะพะปัŒะทะพะฒะฐะฝะธะต: `/resume <ะธะผั ัะตะฐะฝัะฐ>` ะธะปะธ `/resume <ะฝะพะผะตั€>` (ะฝะฐะฟั€ะธะผะตั€, `/resume 1` ะดะปั ัะฐะผะพะณะพ ะฝะตะดะฐะฒะฝะตะณะพ)" list_failed: "ะะต ัƒะดะฐะปะพััŒ ะฟะพะปัƒั‡ะธั‚ัŒ ัะฟะธัะพะบ ัะตะฐะฝัะพะฒ: {error}" + out_of_range: "ะ˜ะฝะดะตะบั ะฒะพะทะพะฑะฝะพะฒะปะตะฝะธั {index} ะฒะฝะต ะดะธะฐะฟะฐะทะพะฝะฐ.\nะ˜ัะฟะพะปัŒะทัƒะนั‚ะต `/resume` ะฑะตะท ะฐั€ะณัƒะผะตะฝั‚ะพะฒ, ั‡ั‚ะพะฑั‹ ัƒะฒะธะดะตั‚ัŒ ะดะพัั‚ัƒะฟะฝั‹ะต ัะตะฐะฝัั‹." not_found: "ะกะตะฐะฝั, ัะพะพั‚ะฒะตั‚ัั‚ะฒัƒัŽั‰ะธะน '**{name}**', ะฝะต ะฝะฐะนะดะตะฝ.\nะ˜ัะฟะพะปัŒะทัƒะนั‚ะต `/resume` ะฑะตะท ะฐั€ะณัƒะผะตะฝั‚ะพะฒ, ั‡ั‚ะพะฑั‹ ัƒะฒะธะดะตั‚ัŒ ะดะพัั‚ัƒะฟะฝั‹ะต ัะตะฐะฝัั‹." already_on: "๐Ÿ“Œ ะฃะถะต ะฒ ัะตะฐะฝัะต **{name}**." switch_failed: "ะะต ัƒะดะฐะปะพััŒ ะฟะตั€ะตะบะปัŽั‡ะธั‚ัŒ ัะตะฐะฝั." diff --git a/locales/tr.yaml b/locales/tr.yaml index add252ea56b..0be0e351af7 100644 --- a/locales/tr.yaml +++ b/locales/tr.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "AdlandฤฑrฤฑlmฤฑลŸ oturum bulunamadฤฑ.\nMevcut oturumu adlandฤฑrmak iรงin `/title Oturumum`, daha sonra geri dรถnmek iรงin `/resume Oturumum` kullanฤฑn." list_header: "๐Ÿ“‹ **AdlandฤฑrฤฑlmฤฑลŸ Oturumlar**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nKullanฤฑm: `/resume `" + list_footer_numbered: "\nKullanฤฑm: `/resume ` veya `/resume ` (รถrn. en yenisi iรงin `/resume 1`)" list_failed: "Oturumlar listelenemedi: {error}" + out_of_range: "Devam endeksi {index} aralฤฑk dฤฑลŸฤฑnda.\nKullanฤฑlabilir oturumlarฤฑ gรถrmek iรงin `/resume` komutunu argรผmansฤฑz รงalฤฑลŸtฤฑrฤฑn." not_found: "'**{name}**' ile eลŸleลŸen oturum bulunamadฤฑ.\nKullanฤฑlabilir oturumlarฤฑ gรถrmek iรงin argรผmansฤฑz `/resume` kullanฤฑn." already_on: "๐Ÿ“Œ Zaten **{name}** oturumundasฤฑnฤฑz." switch_failed: "Oturum deฤŸiลŸtirilemedi." diff --git a/locales/uk.yaml b/locales/uk.yaml index 972e535f901..1b36b3e2f48 100644 --- a/locales/uk.yaml +++ b/locales/uk.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "ะ†ะผะตะฝะพะฒะฐะฝะธั… ัะตะฐะฝัั–ะฒ ะฝะต ะทะฝะฐะนะดะตะฝะพ.\nะ’ะธะบะพั€ะธัั‚ะฐะนั‚ะต `/title ะœั–ะน ัะตะฐะฝั`, ั‰ะพะฑ ะฝะฐะทะฒะฐั‚ะธ ะฟะพั‚ะพั‡ะฝะธะน ัะตะฐะฝั, ะฟะพั‚ั–ะผ `/resume ะœั–ะน ัะตะฐะฝั`, ั‰ะพะฑ ะฟะพะฒะตั€ะฝัƒั‚ะธัั ะดะพ ะฝัŒะพะณะพ." list_header: "๐Ÿ“‹ **ะ†ะผะตะฝะพะฒะฐะฝั– ัะตะฐะฝัะธ**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\nะ’ะธะบะพั€ะธัั‚ะฐะฝะฝั: `/resume <ะฝะฐะทะฒะฐ ัะตะฐะฝััƒ>`" + list_footer_numbered: "\nะ’ะธะบะพั€ะธัั‚ะฐะฝะฝั: `/resume <ะฝะฐะทะฒะฐ ัะตัั–ั—>` ะฐะฑะพ `/resume <ะฝะพะผะตั€>` (ะฝะฐะฟั€ะธะบะปะฐะด, `/resume 1` ะดะปั ะฝะฐะนะฝะพะฒั–ัˆะพั—)" list_failed: "ะะต ะฒะดะฐะปะพัั ะพั‚ั€ะธะผะฐั‚ะธ ัะฟะธัะพะบ ัะตะฐะฝัั–ะฒ: {error}" + out_of_range: "ะ†ะฝะดะตะบั ะฒั–ะดะฝะพะฒะปะตะฝะฝั {index} ะฟะพะทะฐ ะผะตะถะฐะผะธ ะดั–ะฐะฟะฐะทะพะฝัƒ.\nะ’ะธะบะพั€ะธัั‚ะพะฒัƒะนั‚ะต `/resume` ะฑะตะท ะฐั€ะณัƒะผะตะฝั‚ั–ะฒ, ั‰ะพะฑ ะฟะตั€ะตะณะปัะฝัƒั‚ะธ ะดะพัั‚ัƒะฟะฝั– ัะตัั–ั—." not_found: "ะกะตะฐะฝั, ั‰ะพ ะฒั–ะดะฟะพะฒั–ะดะฐั” '**{name}**', ะฝะต ะทะฝะฐะนะดะตะฝะพ.\nะ’ะธะบะพั€ะธัั‚ะฐะนั‚ะต `/resume` ะฑะตะท ะฐั€ะณัƒะผะตะฝั‚ั–ะฒ, ั‰ะพะฑ ะฟะพะฑะฐั‡ะธั‚ะธ ะดะพัั‚ัƒะฟะฝั– ัะตะฐะฝัะธ." already_on: "๐Ÿ“Œ ะฃะถะต ะฒ ัะตะฐะฝัั– **{name}**." switch_failed: "ะะต ะฒะดะฐะปะพัั ะฟะตั€ะตะบะปัŽั‡ะธั‚ะธ ัะตะฐะฝั." diff --git a/locales/zh-hant.yaml b/locales/zh-hant.yaml index 30fbcabac3f..a8c67533847 100644 --- a/locales/zh-hant.yaml +++ b/locales/zh-hant.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "ๆ‰พไธๅˆฐๅทฒๅ‘ฝๅ็š„ๅทฅไฝœ้šŽๆฎตใ€‚\nไฝฟ็”จ `/title ๆˆ‘็š„ๅทฅไฝœ้šŽๆฎต` ็‚บ็›ฎๅ‰ๅทฅไฝœ้šŽๆฎตๅ‘ฝๅ๏ผŒ็„ถๅพŒไฝฟ็”จ `/resume ๆˆ‘็š„ๅทฅไฝœ้šŽๆฎต` ่ฟ”ๅ›žใ€‚" list_header: "๐Ÿ“‹ **ๅทฒๅ‘ฝๅๅทฅไฝœ้šŽๆฎต**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\n็”จๆณ•๏ผš`/resume <ๅทฅไฝœ้šŽๆฎตๅ็จฑ>`" + list_footer_numbered: "\n็”จๆณ•๏ผš`/resume <ๆœƒ่ฉฑๅ็จฑ>` ๆˆ– `/resume <็ทจ่™Ÿ>`๏ผˆไพ‹ๅฆ‚๏ผŒ`/resume 1` ่กจ็คบๆœ€่ฟ‘็š„ๆœƒ่ฉฑ๏ผ‰" list_failed: "็„กๆณ•ๅˆ—ๅ‡บๅทฅไฝœ้šŽๆฎต๏ผš{error}" + out_of_range: "ๆขๅพฉ็ดขๅผ• {index} ่ถ…ๅ‡บ็ฏ„ๅœใ€‚\n่ซ‹ไฝฟ็”จไธๅธถๅƒๆ•ธ็š„ `/resume` ๆŸฅ็œ‹ๅฏ็”จๆœƒ่ฉฑใ€‚" not_found: "ๆ‰พไธๅˆฐ็ฌฆๅˆ '**{name}**' ็š„ๅทฅไฝœ้šŽๆฎตใ€‚\nไฝฟ็”จไธๅธถๅƒๆ•ธ็š„ `/resume` ๆชข่ฆ–ๅฏ็”จ็š„ๅทฅไฝœ้šŽๆฎตใ€‚" already_on: "๐Ÿ“Œ ๅทฒๅœจๅทฅไฝœ้šŽๆฎต **{name}** ไธŠใ€‚" switch_failed: "ๅˆ‡ๆ›ๅทฅไฝœ้šŽๆฎตๅคฑๆ•—ใ€‚" diff --git a/locales/zh.yaml b/locales/zh.yaml index 60999f06d3a..86c1d359777 100644 --- a/locales/zh.yaml +++ b/locales/zh.yaml @@ -222,9 +222,12 @@ gateway: no_named_sessions: "ๆœชๆ‰พๅˆฐๅทฒๅ‘ฝๅ็š„ไผš่ฏใ€‚\nไฝฟ็”จ `/title ๆˆ‘็š„ไผš่ฏ` ไธบๅฝ“ๅ‰ไผš่ฏๅ‘ฝๅ๏ผŒ็„ถๅŽ็”จ `/resume ๆˆ‘็š„ไผš่ฏ` ่ฟ”ๅ›žใ€‚" list_header: "๐Ÿ“‹ **ๅทฒๅ‘ฝๅไผš่ฏ**\n" list_item: "โ€ข **{title}**{preview_part}" + list_item_numbered: "{index}. **{title}**{preview_part}" list_preview_suffix: " โ€” _{preview}_" list_footer: "\n็”จๆณ•๏ผš`/resume <ไผš่ฏๅ็งฐ>`" + list_footer_numbered: "\n็”จๆณ•๏ผš`/resume <ไผš่ฏๅ็งฐ>` ๆˆ– `/resume <็ผ–ๅท>`๏ผˆไพ‹ๅฆ‚๏ผŒ`/resume 1` ่กจ็คบๆœ€่ฟ‘็š„ไผš่ฏ๏ผ‰" list_failed: "ๆ— ๆณ•ๅˆ—ๅ‡บไผš่ฏ๏ผš{error}" + out_of_range: "ๆขๅค็ดขๅผ• {index} ่ถ…ๅ‡บ่Œƒๅ›ดใ€‚\n่ฏทไฝฟ็”จไธๅธฆๅ‚ๆ•ฐ็š„ `/resume` ๆŸฅ็œ‹ๅฏ็”จไผš่ฏใ€‚" not_found: "ๆœชๆ‰พๅˆฐๅŒน้… '**{name}**' ็š„ไผš่ฏใ€‚\nไฝฟ็”จไธๅธฆๅ‚ๆ•ฐ็š„ `/resume` ๆŸฅ็œ‹ๅฏ็”จไผš่ฏใ€‚" already_on: "๐Ÿ“Œ ๅทฒๅœจไผš่ฏ **{name}** ไธŠใ€‚" switch_failed: "ๅˆ‡ๆขไผš่ฏๅคฑ่ดฅใ€‚" diff --git a/nix/web.nix b/nix/web.nix index 54f7870d8ea..557f596b911 100644 --- a/nix/web.nix +++ b/nix/web.nix @@ -4,7 +4,7 @@ let src = ../web; npmDeps = pkgs.fetchNpmDeps { inherit src; - hash = "sha256-xSsyluzU2lNhwGqB6XMCGMv3QFHZizE6hgUyc1jvyOw="; + hash = "sha256-6qhGuifHVtCeep1SiQdCUxBMr7UGhYpdMTvXhrQu/zA="; }; npm = hermesNpmLib.mkNpmPassthru { folder = "web"; attr = "web"; pname = "hermes-web"; }; diff --git a/plugins/image_gen/fal/__init__.py b/plugins/image_gen/fal/__init__.py new file mode 100644 index 00000000000..21b88f37f34 --- /dev/null +++ b/plugins/image_gen/fal/__init__.py @@ -0,0 +1,182 @@ +"""FAL.ai image generation backend. + +Wraps the 18-model FAL catalog (FLUX 2, Z-Image, Nano Banana, GPT +Image 1.5, Recraft, Imagen 4, Qwen, Ideogram, โ€ฆ) as an +:class:`ImageGenProvider` implementation. + +The heavy lifting โ€” model catalog, payload construction, request +submission, managed-Nous-gateway selection, Clarity Upscaler chaining +โ€” lives in :mod:`tools.image_generation_tool`. This plugin reaches into +that module via call-time indirection (``import tools.image_generation_tool as _it``) +so: + +* the existing test suite (``tests/tools/test_image_generation.py``, + ``tests/tools/test_managed_media_gateways.py``) keeps patching + ``image_tool._submit_fal_request`` / ``image_tool.fal_client`` / + ``image_tool._managed_fal_client`` without modification, and +* there's exactly one canonical FAL code path on disk โ€” the plugin is a + registration adapter, not a parallel implementation. + +See issue #26241 for the migration plan and the +``plugin-extraction-test-patch-compatibility.md`` rules this follows. +""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any, Dict, List, Optional + +from agent.image_gen_provider import ( + DEFAULT_ASPECT_RATIO, + ImageGenProvider, + resolve_aspect_ratio, +) + +logger = logging.getLogger(__name__) + + +class FalImageGenProvider(ImageGenProvider): + """FAL.ai image generation backend. + + Delegates to ``tools.image_generation_tool.image_generate_tool`` so + the in-tree FAL implementation (model catalog, payload builder, + managed-gateway selection, Clarity Upscaler chaining) is the single + source of truth. Everything is resolved at call time via the + ``_it`` indirection so tests can monkey-patch the legacy module. + """ + + @property + def name(self) -> str: + return "fal" + + @property + def display_name(self) -> str: + return "FAL.ai" + + def is_available(self) -> bool: + # Available when direct FAL_KEY is set OR the managed Nous + # gateway resolves a fal-queue origin. Both checks come from the + # legacy module so this provider tracks whatever logic ships + # there. + import tools.image_generation_tool as _it + try: + return bool(_it.check_fal_api_key()) + except Exception: # noqa: BLE001 โ€” defensive; never break the picker + return False + + def list_models(self) -> List[Dict[str, Any]]: + import tools.image_generation_tool as _it + return [ + { + "id": model_id, + "display": meta.get("display", model_id), + "speed": meta.get("speed", ""), + "strengths": meta.get("strengths", ""), + "price": meta.get("price", ""), + } + for model_id, meta in _it.FAL_MODELS.items() + ] + + def default_model(self) -> Optional[str]: + import tools.image_generation_tool as _it + return _it.DEFAULT_MODEL + + def get_setup_schema(self) -> Dict[str, Any]: + return { + "name": "FAL.ai", + "badge": "paid", + "tag": "Pick from flux-2-klein, flux-2-pro, gpt-image, nano-banana, etc.", + "env_vars": [ + { + "key": "FAL_KEY", + "prompt": "FAL API key", + "url": "https://fal.ai/dashboard/keys", + }, + ], + } + + def generate( + self, + prompt: str, + aspect_ratio: str = DEFAULT_ASPECT_RATIO, + **kwargs: Any, + ) -> Dict[str, Any]: + """Generate an image via the legacy FAL pipeline. + + Forwards prompt + aspect_ratio (and any forward-compat extras + the schema supports) into :func:`tools.image_generation_tool.image_generate_tool`, + then reshapes its JSON-string response into the provider-ABC + dict format consumed by ``_dispatch_to_plugin_provider``. + """ + import tools.image_generation_tool as _it + + aspect = resolve_aspect_ratio(aspect_ratio) + passthrough = { + key: kwargs[key] + for key in ( + "num_inference_steps", + "guidance_scale", + "num_images", + "output_format", + "seed", + ) + if key in kwargs and kwargs[key] is not None + } + + try: + raw = _it.image_generate_tool( + prompt=prompt, + aspect_ratio=aspect, + **passthrough, + ) + except Exception as exc: # noqa: BLE001 โ€” never raise out of generate + logger.warning("FAL image_generate_tool raised: %s", exc, exc_info=True) + return { + "success": False, + "image": None, + "error": f"FAL image generation failed: {exc}", + "error_type": type(exc).__name__, + "provider": "fal", + "prompt": prompt, + "aspect_ratio": aspect, + } + + try: + response = json.loads(raw) if isinstance(raw, str) else raw + except Exception: # noqa: BLE001 + response = {"success": False, "image": None, "error": "Invalid JSON from FAL pipeline"} + + if not isinstance(response, dict): + response = { + "success": False, + "image": None, + "error": "FAL pipeline returned a non-dict response", + "error_type": "provider_contract", + } + + # Stamp provider/prompt/aspect_ratio so downstream consumers see + # the uniform shape declared in ``agent.image_gen_provider``. + response.setdefault("provider", "fal") + response.setdefault("prompt", prompt) + response.setdefault("aspect_ratio", aspect) + # Annotate model best-effort โ€” the legacy pipeline resolves it + # internally, so query it after the fact for the response shape. + if "model" not in response: + try: + model_id, _meta = _it._resolve_fal_model() + response["model"] = model_id + except Exception: # noqa: BLE001 + pass + return response + + +# --------------------------------------------------------------------------- +# Plugin entry point +# --------------------------------------------------------------------------- + + +def register(ctx) -> None: + """Plugin entry point โ€” wire ``FalImageGenProvider`` into the registry.""" + ctx.register_image_gen_provider(FalImageGenProvider()) diff --git a/plugins/image_gen/fal/plugin.yaml b/plugins/image_gen/fal/plugin.yaml new file mode 100644 index 00000000000..775b76c906d --- /dev/null +++ b/plugins/image_gen/fal/plugin.yaml @@ -0,0 +1,7 @@ +name: fal +version: 1.0.0 +description: "FAL.ai image generation backend (flux-2-klein, flux-2-pro, nano-banana, gpt-image-1.5, recraft-v3, etc.)." +author: NousResearch +kind: backend +requires_env: + - FAL_KEY diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index ff01bbf402e..42925fa74aa 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -47,6 +47,25 @@ _DEFAULT_ENDPOINT = "http://127.0.0.1:1933" _TIMEOUT = 30.0 _REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://") +# Maps the viking_remember `category` enum to a viking:// subdirectory. +# Keep in sync with REMEMBER_SCHEMA.parameters.properties.category.enum. +_CATEGORY_SUBDIR_MAP = { + "preference": "preferences", + "entity": "entities", + "event": "events", + "case": "cases", + "pattern": "patterns", +} +_DEFAULT_MEMORY_SUBDIR = "preferences" + +# Maps the built-in memory tool's `target` ("user" vs "memory") to a subdir +# for on_memory_write mirroring. User profile facts โ†’ preferences; agent +# notes / observations โ†’ patterns. Anything unknown falls back to the default. +_MEMORY_WRITE_TARGET_SUBDIR_MAP = { + "user": "preferences", + "memory": "patterns", +} + # --------------------------------------------------------------------------- # Process-level atexit safety net โ€” ensures pending sessions are committed @@ -607,24 +626,35 @@ class OpenVikingMemoryProvider(MemoryProvider): except Exception as e: logger.warning("OpenViking session commit failed: %s", e) - def on_memory_write(self, action: str, target: str, content: str) -> None: - """Mirror built-in memory writes to OpenViking as explicit memories.""" + def _build_memory_uri(self, subdir: str) -> str: + """Build a viking:// memory URI under the configured user/subdir.""" + slug = uuid.uuid4().hex[:12] + return f"viking://user/{self._user}/memories/{subdir}/mem_{slug}.md" + + def on_memory_write( + self, + action: str, + target: str, + content: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Mirror built-in memory writes to OpenViking via content/write.""" if not self._client or action != "add" or not content: return + subdir = _MEMORY_WRITE_TARGET_SUBDIR_MAP.get(target, _DEFAULT_MEMORY_SUBDIR) + uri = self._build_memory_uri(subdir) + def _write(): try: client = _VikingClient( self._endpoint, self._api_key, account=self._account, user=self._user, agent=self._agent, ) - # Add as a user message with memory context so the commit - # picks it up as an explicit memory during extraction - client.post(f"/api/v1/sessions/{self._session_id}/messages", { - "role": "user", - "parts": [ - {"type": "text", "text": f"[Memory note โ€” {target}] {content}"}, - ], + client.post("/api/v1/content/write", { + "uri": uri, + "content": content, + "mode": "create", }) except Exception as e: logger.debug("OpenViking memory mirror failed: %s", e) @@ -858,24 +888,27 @@ class OpenVikingMemoryProvider(MemoryProvider): if not content: return tool_error("content is required") - # Store as a session message that will be extracted during commit. - # The category hint helps OpenViking's extraction classify correctly. category = args.get("category", "") - text = f"[Remember] {content}" - if category: - text = f"[Remember โ€” {category}] {content}" + subdir = _CATEGORY_SUBDIR_MAP.get(category, _DEFAULT_MEMORY_SUBDIR) + uri = self._build_memory_uri(subdir) - self._client.post(f"/api/v1/sessions/{self._session_id}/messages", { - "role": "user", - "parts": [ - {"type": "text", "text": text}, - ], - }) - - return json.dumps({ - "status": "stored", - "message": "Memory recorded. Will be extracted and indexed on session commit.", - }) + # Write directly via content/write API. + # This creates the file, stores the content, and queues vector indexing + # in a single call โ€” no dependency on session commit / VLM extraction. + try: + result = self._client.post("/api/v1/content/write", { + "uri": uri, + "content": content, + "mode": "create", + }) + written = result.get("result", {}).get("written_bytes", 0) + return json.dumps({ + "status": "stored", + "message": f"Memory stored ({written}b) and queued for vector indexing.", + }) + except Exception as e: + logger.error("OpenViking content/write failed: %s", e) + return tool_error(f"Failed to store memory: {e}") def _tool_add_resource(self, args: dict) -> str: url = args.get("url", "") diff --git a/plugins/model-providers/opencode-zen/__init__.py b/plugins/model-providers/opencode-zen/__init__.py index f720e8f5fad..385741f09a1 100644 --- a/plugins/model-providers/opencode-zen/__init__.py +++ b/plugins/model-providers/opencode-zen/__init__.py @@ -7,9 +7,81 @@ Both use per-model api_mode routing: (this profile) """ +from __future__ import annotations + +from typing import Any + from providers import register_provider from providers.base import ProviderProfile + +def _flat_model_name(model: str | None) -> str: + """Return the bare OpenCode model ID, tolerating aggregator prefixes.""" + return (model or "").strip().rsplit("/", 1)[-1].lower() + + +def _is_kimi_k2_model(model: str | None) -> bool: + return _flat_model_name(model).startswith("kimi-k2") + + +def _is_deepseek_thinking_model(model: str | None) -> bool: + m = _flat_model_name(model) + if m.startswith("deepseek-v") and not m.startswith("deepseek-v3"): + return True + return m == "deepseek-reasoner" + + +class OpenCodeGoProfile(ProviderProfile): + """OpenCode Go - model-specific reasoning controls.""" + + def build_api_kwargs_extras( + self, *, reasoning_config: dict | None = None, model: str | None = None, **context + ) -> tuple[dict[str, Any], dict[str, Any]]: + extra_body: dict[str, Any] = {} + top_level: dict[str, Any] = {} + + if _is_kimi_k2_model(model): + # Kimi K2 on OpenCode Go uses Moonshot's native wire shape: + # extra_body.thinking (binary toggle) + top-level reasoning_effort + # (low|medium|high). Mirrors the KimiProfile (api.moonshot.ai/v1). + if not isinstance(reasoning_config, dict): + # No config โ†’ leave server defaults alone. + return extra_body, top_level + + enabled = reasoning_config.get("enabled") is not False + extra_body["thinking"] = {"type": "enabled" if enabled else "disabled"} + + if not enabled: + return extra_body, top_level + + effort = (reasoning_config.get("effort") or "").strip().lower() + if effort in {"xhigh", "max"}: + top_level["reasoning_effort"] = "high" + elif effort in {"low", "medium", "high"}: + top_level["reasoning_effort"] = effort + return extra_body, top_level + + if not _is_deepseek_thinking_model(model): + return extra_body, top_level + + enabled = True + if isinstance(reasoning_config, dict) and reasoning_config.get("enabled") is False: + enabled = False + extra_body["thinking"] = {"type": "enabled" if enabled else "disabled"} + + if not enabled: + return extra_body, top_level + + if isinstance(reasoning_config, dict): + effort = (reasoning_config.get("effort") or "").strip().lower() + if effort in {"xhigh", "max"}: + top_level["reasoning_effort"] = "max" + elif effort in {"low", "medium", "high"}: + top_level["reasoning_effort"] = effort + + return extra_body, top_level + + opencode_zen = ProviderProfile( name="opencode-zen", aliases=("opencode", "opencode_zen", "zen"), @@ -18,7 +90,7 @@ opencode_zen = ProviderProfile( default_aux_model="gemini-3-flash", ) -opencode_go = ProviderProfile( +opencode_go = OpenCodeGoProfile( name="opencode-go", aliases=("opencode_go", "go", "opencode-go-sub"), env_vars=("OPENCODE_GO_API_KEY",), diff --git a/plugins/platforms/discord/__init__.py b/plugins/platforms/discord/__init__.py new file mode 100644 index 00000000000..d4f1d7bf0e3 --- /dev/null +++ b/plugins/platforms/discord/__init__.py @@ -0,0 +1,3 @@ +from .adapter import register + +__all__ = ["register"] diff --git a/gateway/platforms/discord.py b/plugins/platforms/discord/adapter.py similarity index 91% rename from gateway/platforms/discord.py rename to plugins/platforms/discord/adapter.py index 0d64b24d7e4..efe0b5d1de7 100644 --- a/gateway/platforms/discord.py +++ b/plugins/platforms/discord/adapter.py @@ -1489,7 +1489,8 @@ class DiscordAdapter(BasePlatformAdapter): reported in ``raw_response['warnings']`` so the caller can surface partial-send issues. """ - from tools.send_message_tool import _derive_forum_thread_name + # _derive_forum_thread_name is defined further down in this same + # module โ€” no cross-module import needed. formatted = self.format_message(content) chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) @@ -1551,7 +1552,8 @@ class DiscordAdapter(BasePlatformAdapter): ForumChannel accepts the same file/files/content kwargs as ``channel.send``, creating the thread and starter message atomically. """ - from tools.send_message_tool import _derive_forum_thread_name + # _derive_forum_thread_name is defined further down in this same + # module โ€” no cross-module import needed. if not thread_name: # Prefer the text content, fall back to the first attached @@ -5699,7 +5701,492 @@ def _define_discord_view_classes() -> None: self.resolved = True for child in self.children: child.disabled = True - - if DISCORD_AVAILABLE: _define_discord_view_classes() + + +# โ”€โ”€ Standalone (out-of-process) sender โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +# Used by ``tools/send_message_tool._send_via_adapter`` when the gateway runner +# is not in this process (e.g. ``hermes cron`` running standalone) and no live +# DiscordAdapter instance is available. Implements the same forum/thread/ +# multipart logic the live adapter would use, via Discord's REST API directly. +# +# This block was previously hosted in ``tools/send_message_tool.py`` as +# ``_send_discord``. It moved into the plugin so all Discord-specific HTTP +# logic lives next to the adapter โ€” same shape as Teams' ``_standalone_send``. + +# Process-local cache for Discord channel-type probes. Avoids re-probing the +# same channel on every send when the directory cache has no entry (e.g. fresh +# install, or channel created after the last directory build). +_DISCORD_CHANNEL_TYPE_PROBE_CACHE: Dict[str, bool] = {} + + +def _remember_channel_is_forum(chat_id: str, is_forum: bool) -> None: + _DISCORD_CHANNEL_TYPE_PROBE_CACHE[str(chat_id)] = bool(is_forum) + + +def _probe_is_forum_cached(chat_id: str) -> Optional[bool]: + return _DISCORD_CHANNEL_TYPE_PROBE_CACHE.get(str(chat_id)) + + +def _derive_forum_thread_name(message: str) -> str: + """Derive a thread name from the first line of the message, capped at 100 chars.""" + first_line = message.strip().split("\n", 1)[0].strip() + # Strip common markdown heading prefixes + first_line = first_line.lstrip("#").strip() + if not first_line: + first_line = "New Post" + return first_line[:100] + + +def _standalone_sanitize_error(text) -> str: + """Local copy of tools.send_message_tool._sanitize_error_text โ€” strips bot + tokens from any error payload before bubbling it up. Inlined so the + plugin doesn't introduce a hard dependency on send_message_tool internals. + """ + s = str(text) + # Mask anything that looks like a Bot token in an Authorization header. + import re as _re_san + return _re_san.sub( + r"(Authorization:\s*Bot\s+)\S+", + r"\1***", + s, + flags=_re_san.IGNORECASE, + ) + + +async def _standalone_send( + pconfig, + chat_id: str, + message: str, + *, + thread_id: Optional[str] = None, + media_files: Optional[list] = None, + force_document: bool = False, +) -> Dict[str, Any]: + """Send via Discord REST API without a live gateway adapter. + + Used by ``tools/send_message_tool._send_via_adapter`` when the gateway + runner is not in this process. Reads ``DISCORD_BOT_TOKEN`` from + ``pconfig.token`` (set by the gateway config loader from env) and falls + back to the ``DISCORD_BOT_TOKEN`` env var. + + Forum channels (type 15) reject ``POST /messages`` โ€” a thread post is + created automatically via ``POST /channels/{id}/threads``. Media files + are uploaded as multipart attachments on the starter message of the new + thread. Channel type is resolved from the channel directory first, then + a process-local probe cache, and only as a last resort with a live + ``GET /channels/{id}`` probe (whose result is memoized). + + ``force_document`` is accepted for signature parity but unused โ€” Discord + treats every uploaded file as a generic attachment. + """ + try: + import aiohttp + except ImportError: + return {"error": "aiohttp not installed. Run: pip install aiohttp"} + + token = (getattr(pconfig, "token", None) or os.getenv("DISCORD_BOT_TOKEN", "")).strip() + if not token: + return {"error": "Discord standalone send: DISCORD_BOT_TOKEN is not set"} + + try: + from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp + _proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY") + _sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy) + auth_headers = {"Authorization": f"Bot {token}"} + json_headers = {**auth_headers, "Content-Type": "application/json"} + media_files = media_files or [] + last_data = None + warnings = [] + + # Thread endpoint: Discord threads are channels; send directly to the thread ID. + if thread_id: + url = f"https://discord.com/api/v10/channels/{thread_id}/messages" + else: + # Check if the target channel is a forum channel (type 15). + # Forum channels reject POST /messages โ€” create a thread post instead. + # Three-layer detection: directory cache โ†’ process-local probe + # cache โ†’ GET /channels/{id} probe (with result memoized). + _channel_type = None + try: + from gateway.channel_directory import lookup_channel_type + _channel_type = lookup_channel_type("discord", chat_id) + except Exception: + pass + + if _channel_type == "forum": + is_forum = True + elif _channel_type is not None: + is_forum = False + else: + cached = _probe_is_forum_cached(chat_id) + if cached is not None: + is_forum = cached + else: + is_forum = False + try: + info_url = f"https://discord.com/api/v10/channels/{chat_id}" + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15), **_sess_kw) as info_sess: + async with info_sess.get(info_url, headers=json_headers, **_req_kw) as info_resp: + if info_resp.status == 200: + info = await info_resp.json() + is_forum = info.get("type") == 15 + _remember_channel_is_forum(chat_id, is_forum) + except Exception: + logger.debug("Failed to probe channel type for %s", chat_id, exc_info=True) + + if is_forum: + thread_name = _derive_forum_thread_name(message) + thread_url = f"https://discord.com/api/v10/channels/{chat_id}/threads" + + # Filter to readable media files up front so we can pick the + # right code path (JSON vs multipart) before opening a session. + valid_media = [] + for media_path, _is_voice in media_files: + if not os.path.exists(media_path): + warning = f"Media file not found, skipping: {media_path}" + logger.warning(warning) + warnings.append(warning) + continue + valid_media.append(media_path) + + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60), **_sess_kw) as session: + if valid_media: + # Multipart: payload_json + files[N] creates a forum + # thread with the starter message plus attachments in + # a single API call. + attachments_meta = [ + {"id": str(idx), "filename": os.path.basename(path)} + for idx, path in enumerate(valid_media) + ] + starter_message = {"content": message, "attachments": attachments_meta} + payload_json = json.dumps({"name": thread_name, "message": starter_message}) + + form = aiohttp.FormData() + form.add_field("payload_json", payload_json, content_type="application/json") + + try: + for idx, media_path in enumerate(valid_media): + with open(media_path, "rb") as fh: + form.add_field( + f"files[{idx}]", + fh.read(), + filename=os.path.basename(media_path), + ) + async with session.post(thread_url, headers=auth_headers, data=form, **_req_kw) as resp: + if resp.status not in {200, 201}: + body = await resp.text() + return {"error": f"Discord forum thread creation error ({resp.status}): {body}"} + data = await resp.json() + except Exception as e: + return {"error": _standalone_sanitize_error(f"Discord forum thread upload failed: {e}")} + else: + # No media โ€” simple JSON POST creates the thread with + # just the text starter. + async with session.post( + thread_url, + headers=json_headers, + json={ + "name": thread_name, + "message": {"content": message}, + }, + **_req_kw, + ) as resp: + if resp.status not in {200, 201}: + body = await resp.text() + return {"error": f"Discord forum thread creation error ({resp.status}): {body}"} + data = await resp.json() + + thread_id_created = data.get("id") + starter_msg_id = (data.get("message") or {}).get("id", thread_id_created) + result = { + "success": True, + "platform": "discord", + "chat_id": chat_id, + "thread_id": thread_id_created, + "message_id": starter_msg_id, + } + if warnings: + result["warnings"] = warnings + return result + + url = f"https://discord.com/api/v10/channels/{chat_id}/messages" + + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session: + # Send text message (skip if empty and media is present) + if message.strip() or not media_files: + async with session.post(url, headers=json_headers, json={"content": message}, **_req_kw) as resp: + if resp.status not in {200, 201}: + body = await resp.text() + return {"error": f"Discord API error ({resp.status}): {body}"} + last_data = await resp.json() + + # Send each media file as a separate multipart upload + for media_path, _is_voice in media_files: + if not os.path.exists(media_path): + warning = f"Media file not found, skipping: {media_path}" + logger.warning(warning) + warnings.append(warning) + continue + try: + form = aiohttp.FormData() + filename = os.path.basename(media_path) + with open(media_path, "rb") as f: + form.add_field("files[0]", f, filename=filename) + async with session.post(url, headers=auth_headers, data=form, **_req_kw) as resp: + if resp.status not in {200, 201}: + body = await resp.text() + warning = _standalone_sanitize_error(f"Failed to send media {media_path}: Discord API error ({resp.status}): {body}") + logger.error(warning) + warnings.append(warning) + continue + last_data = await resp.json() + except Exception as e: + warning = _standalone_sanitize_error(f"Failed to send media {media_path}: {e}") + logger.error(warning) + warnings.append(warning) + + if last_data is None: + error = "No deliverable text or media remained after processing" + if warnings: + return {"error": error, "warnings": warnings} + return {"error": error} + + result = {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": last_data.get("id")} + if warnings: + result["warnings"] = warnings + return result + except Exception as e: + return {"error": _standalone_sanitize_error(f"Discord send failed: {e}")} + + +# โ”€โ”€ Plugin entry point โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _clean_discord_user_ids(raw: str) -> list: + """Strip common Discord mention prefixes from a comma-separated ID string.""" + cleaned = [] + for uid in raw.replace(" ", "").split(","): + uid = uid.strip() + if uid.startswith("<@") and uid.endswith(">"): + uid = uid.lstrip("<@!").rstrip(">") + if uid.lower().startswith("user:"): + uid = uid[5:] + if uid: + cleaned.append(uid) + return cleaned + + +def interactive_setup() -> None: + """Guide the user through Discord bot setup. + + Mirrors Teams' ``interactive_setup`` shape: lazy-imports CLI helpers so + the plugin's import surface stays small, prompts for the bot token, + captures an allowlist, and offers to set a home channel. + """ + from hermes_cli.config import get_env_value, save_env_value + from hermes_cli.cli_output import ( + prompt, + prompt_yes_no, + print_header, + print_info, + print_success, + ) + + print_header("Discord") + existing = get_env_value("DISCORD_BOT_TOKEN") + if existing: + print_info("Discord: already configured") + if not prompt_yes_no("Reconfigure Discord?", False): + if not get_env_value("DISCORD_ALLOWED_USERS"): + print_info("โš ๏ธ Discord has no user allowlist - anyone can use your bot!") + if prompt_yes_no("Add allowed users now?", True): + print_info(" To find Discord ID: Enable Developer Mode, right-click name โ†’ Copy ID") + allowed_users = prompt("Allowed user IDs (comma-separated)") + if allowed_users: + cleaned_ids = _clean_discord_user_ids(allowed_users) + save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids)) + print_success("Discord allowlist configured") + return + + print_info("Create a bot at https://discord.com/developers/applications") + token = prompt("Discord bot token", password=True) + if not token: + return + save_env_value("DISCORD_BOT_TOKEN", token) + print_success("Discord token saved") + + print() + print_info("๐Ÿ”’ Security: Restrict who can use your bot") + print_info(" To find your Discord user ID:") + print_info(" 1. Enable Developer Mode in Discord settings") + print_info(" 2. Right-click your name โ†’ Copy ID") + print() + print_info(" You can also use Discord usernames (resolved on gateway start).") + print() + allowed_users = prompt( + "Allowed user IDs or usernames (comma-separated, leave empty for open access)" + ) + if allowed_users: + cleaned_ids = _clean_discord_user_ids(allowed_users) + save_env_value("DISCORD_ALLOWED_USERS", ",".join(cleaned_ids)) + print_success("Discord allowlist configured") + else: + print_info("โš ๏ธ No allowlist set - anyone in servers with your bot can use it!") + + print() + print_info("๐Ÿ“ฌ Home Channel: where Hermes delivers cron job results,") + print_info(" cross-platform messages, and notifications.") + print_info(" To get a channel ID: right-click a channel โ†’ Copy Channel ID") + print_info(" (requires Developer Mode in Discord settings)") + print_info(" You can also set this later by typing /set-home in a Discord channel.") + home_channel = prompt("Home channel ID (leave empty to set later with /set-home)") + if home_channel: + save_env_value("DISCORD_HOME_CHANNEL", home_channel) + + +def _apply_yaml_config(yaml_cfg: dict, discord_cfg: dict) -> dict | None: + """Translate ``config.yaml`` ``discord:`` keys into env vars. + + Implements the ``apply_yaml_config_fn`` contract (#24836). Mirrors the + legacy ``discord_cfg`` block that used to live in + ``gateway/config.py::load_gateway_config()`` before this migration. + + The DiscordAdapter reads its runtime configuration via ``os.getenv()`` + throughout the connect / handle code paths (``DISCORD_REQUIRE_MENTION``, + ``DISCORD_FREE_RESPONSE_CHANNELS``, ``DISCORD_AUTO_THREAD``, + ``DISCORD_REACTIONS``, ``DISCORD_IGNORED_CHANNELS``, + ``DISCORD_ALLOWED_CHANNELS``, ``DISCORD_NO_THREAD_CHANNELS``, + ``DISCORD_HISTORY_BACKFILL``, ``DISCORD_HISTORY_BACKFILL_LIMIT``, + ``DISCORD_ALLOW_MENTION_*``, ``DISCORD_REPLY_TO_MODE``, + ``DISCORD_THREAD_REQUIRE_MENTION``). Rather than rewrite ~50 call sites + inside the adapter to read from ``PlatformConfig.extra`` instead, this + hook keeps the existing env-driven model and merely owns the + YAMLโ†’env translation here, next to the adapter that consumes it. + + Env vars take precedence over YAML โ€” every assignment is guarded by + ``not os.getenv(...)`` so explicit env vars survive a config.yaml + update. Returns ``None`` because no extras are seeded into + ``PlatformConfig.extra`` directly (everything flows through env). + """ + if "require_mention" in discord_cfg and not os.getenv("DISCORD_REQUIRE_MENTION"): + os.environ["DISCORD_REQUIRE_MENTION"] = str(discord_cfg["require_mention"]).lower() + if "thread_require_mention" in discord_cfg and not os.getenv("DISCORD_THREAD_REQUIRE_MENTION"): + os.environ["DISCORD_THREAD_REQUIRE_MENTION"] = str(discord_cfg["thread_require_mention"]).lower() + frc = discord_cfg.get("free_response_channels") + if frc is not None and not os.getenv("DISCORD_FREE_RESPONSE_CHANNELS"): + if isinstance(frc, list): + frc = ",".join(str(v) for v in frc) + os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc) + if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"): + os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower() + if "reactions" in discord_cfg and not os.getenv("DISCORD_REACTIONS"): + os.environ["DISCORD_REACTIONS"] = str(discord_cfg["reactions"]).lower() + # ignored_channels: channels where bot never responds (even when mentioned) + ic = discord_cfg.get("ignored_channels") + if ic is not None and not os.getenv("DISCORD_IGNORED_CHANNELS"): + if isinstance(ic, list): + ic = ",".join(str(v) for v in ic) + os.environ["DISCORD_IGNORED_CHANNELS"] = str(ic) + # allowed_channels: if set, bot ONLY responds in these channels (whitelist) + ac = discord_cfg.get("allowed_channels") + if ac is not None and not os.getenv("DISCORD_ALLOWED_CHANNELS"): + if isinstance(ac, list): + ac = ",".join(str(v) for v in ac) + os.environ["DISCORD_ALLOWED_CHANNELS"] = str(ac) + # no_thread_channels: channels where bot responds directly without creating thread + ntc = discord_cfg.get("no_thread_channels") + if ntc is not None and not os.getenv("DISCORD_NO_THREAD_CHANNELS"): + if isinstance(ntc, list): + ntc = ",".join(str(v) for v in ntc) + os.environ["DISCORD_NO_THREAD_CHANNELS"] = str(ntc) + # history_backfill: recover missed channel messages for shared sessions + # when require_mention is active. Fetches messages between bot turns + # and prepends them to the user message for context. + if "history_backfill" in discord_cfg and not os.getenv("DISCORD_HISTORY_BACKFILL"): + os.environ["DISCORD_HISTORY_BACKFILL"] = str(discord_cfg["history_backfill"]).lower() + hbl = discord_cfg.get("history_backfill_limit") + if hbl is not None and not os.getenv("DISCORD_HISTORY_BACKFILL_LIMIT"): + os.environ["DISCORD_HISTORY_BACKFILL_LIMIT"] = str(hbl) + # allow_mentions: granular control over what the bot can ping. + # Safe defaults (no @everyone/roles) are applied in the adapter; + # these YAML keys only override when set and let users opt back + # into unsafe modes (e.g. roles=true) if they actually want it. + allow_mentions_cfg = discord_cfg.get("allow_mentions") + if isinstance(allow_mentions_cfg, dict): + for yaml_key, env_key in ( + ("everyone", "DISCORD_ALLOW_MENTION_EVERYONE"), + ("roles", "DISCORD_ALLOW_MENTION_ROLES"), + ("users", "DISCORD_ALLOW_MENTION_USERS"), + ("replied_user", "DISCORD_ALLOW_MENTION_REPLIED_USER"), + ): + if yaml_key in allow_mentions_cfg and not os.getenv(env_key): + os.environ[env_key] = str(allow_mentions_cfg[yaml_key]).lower() + # reply_to_mode: top-level preferred, falls back to extra.reply_to_mode. + # YAML 1.1 parses bare 'off' as boolean False โ€” coerce to string "off". + _discord_extra = discord_cfg.get("extra") if isinstance(discord_cfg.get("extra"), dict) else {} + _discord_rtm = ( + discord_cfg["reply_to_mode"] if "reply_to_mode" in discord_cfg + else _discord_extra.get("reply_to_mode") + ) + if _discord_rtm is not None and not os.getenv("DISCORD_REPLY_TO_MODE"): + _rtm_str = "off" if _discord_rtm is False else str(_discord_rtm).lower() + os.environ["DISCORD_REPLY_TO_MODE"] = _rtm_str + return None # all settings flow through env; nothing to merge into extras + + +def _is_connected(config) -> bool: + """Discord is considered connected when DISCORD_BOT_TOKEN is set. + + Looks up via ``hermes_cli.gateway.get_env_value`` at call time (not via + the plugin's own bound import) so tests that patch ``gateway_mod.get_env_value`` + โ€” including ``test_setup_openclaw_migration`` โ€” can suppress ambient + ``DISCORD_BOT_TOKEN`` env vars. Matches what the legacy + ``_PLATFORMS["discord"]`` dispatch did before this migration. + """ + import hermes_cli.gateway as gateway_mod + return bool((gateway_mod.get_env_value("DISCORD_BOT_TOKEN") or "").strip()) + + +def _build_adapter(config): + """Factory wrapper that constructs DiscordAdapter from a PlatformConfig.""" + return DiscordAdapter(config) + + +def register(ctx) -> None: + """Plugin entry point โ€” called by the Hermes plugin system.""" + ctx.register_platform( + name="discord", + label="Discord", + adapter_factory=_build_adapter, + check_fn=check_discord_requirements, + is_connected=_is_connected, + required_env=["DISCORD_BOT_TOKEN"], + install_hint="pip install 'hermes-agent[discord]'", + # Interactive setup wizard โ€” replaces the central + # hermes_cli/setup.py::_setup_discord function. Same shape as Teams. + setup_fn=interactive_setup, + # YAMLโ†’env config bridge โ€” owns the translation of ``config.yaml`` + # ``discord:`` keys (require_mention, free_response_channels, + # auto_thread, reactions, ignored_channels, allowed_channels, + # no_thread_channels, allow_mentions.*, reply_to_mode, + # thread_require_mention) into ``DISCORD_*`` env vars that the + # adapter reads via ``os.getenv()``. Replaces the hardcoded block + # that used to live in ``gateway/config.py``. Hook contract: #24836. + apply_yaml_config_fn=_apply_yaml_config, + # Auth env vars for _is_user_authorized() integration + allowed_users_env="DISCORD_ALLOWED_USERS", + allow_all_env="DISCORD_ALLOW_ALL_USERS", + # Cron home-channel delivery + cron_deliver_env_var="DISCORD_HOME_CHANNEL", + # Out-of-process cron delivery via Discord REST API. Without this + # hook, ``deliver=discord`` cron jobs fail with "No live adapter" + # when cron runs separately from the gateway. Mirrors Teams pattern. + standalone_sender_fn=_standalone_send, + # Discord hard limit per message + max_message_length=2000, + # Display + emoji="๐ŸŽฎ", + allow_update_command=True, + ) diff --git a/plugins/platforms/discord/plugin.yaml b/plugins/platforms/discord/plugin.yaml new file mode 100644 index 00000000000..3e09fc9ec86 --- /dev/null +++ b/plugins/platforms/discord/plugin.yaml @@ -0,0 +1,34 @@ +name: discord-platform +label: Discord +kind: platform +version: 1.0.0 +description: > + Discord gateway adapter for Hermes Agent. + Connects to Discord via the discord.py library and relays messages + between Discord guilds/DMs and the Hermes agent. Supports voice mode, + slash commands, free-response channels, role-based DM auth, threads, + reactions, and channel skill bindings. +author: NousResearch +requires_env: + - name: DISCORD_BOT_TOKEN + description: "Discord bot token" + prompt: "Discord bot token" + url: "https://discord.com/developers/applications" + password: true +optional_env: + - name: DISCORD_ALLOWED_USERS + description: "Comma-separated Discord user IDs allowed to talk to the bot" + prompt: "Allowed users (comma-separated)" + password: false + - name: DISCORD_ALLOW_ALL_USERS + description: "Allow any Discord user to trigger the bot (dev only)" + prompt: "Allow all users? (true/false)" + password: false + - name: DISCORD_HOME_CHANNEL + description: "Default channel ID for cron / notification delivery" + prompt: "Home channel ID" + password: false + - name: DISCORD_HOME_CHANNEL_NAME + description: "Display name for the Discord home channel" + prompt: "Home channel display name" + password: false diff --git a/plugins/platforms/ntfy/__init__.py b/plugins/platforms/ntfy/__init__.py new file mode 100644 index 00000000000..d4f1d7bf0e3 --- /dev/null +++ b/plugins/platforms/ntfy/__init__.py @@ -0,0 +1,3 @@ +from .adapter import register + +__all__ = ["register"] diff --git a/plugins/platforms/ntfy/adapter.py b/plugins/platforms/ntfy/adapter.py new file mode 100644 index 00000000000..b9280ab9e6e --- /dev/null +++ b/plugins/platforms/ntfy/adapter.py @@ -0,0 +1,582 @@ +"""ntfy platform adapter (Hermes plugin). + +Subscribes to a topic on ntfy.sh or any self-hosted ntfy server via +HTTP streaming (``/json`` endpoint with ``poll=false``) and publishes +replies via HTTP POST. No external SDK โ€” only httpx, which is already +a Hermes dependency. + +This adapter ships as a Hermes platform plugin under +``plugins/platforms/ntfy/``. The Hermes plugin loader scans the +directory at startup, calls :func:`register`, and the platform becomes +available to ``gateway/run.py`` and ``tools/send_message_tool`` through +the registry โ€” no edits to core files required. + +Configuration in config.yaml:: + + platforms: + ntfy: + enabled: true + extra: + server: "https://ntfy.sh" # or self-hosted URL + topic: "hermes-in" # subscribe topic (incoming) + publish_topic: "hermes-out" # optional โ€” defaults to topic + token: "..." # optional Bearer / Basic auth token + markdown: true # optional โ€” enable markdown (default: false) + +Environment variables (all read at adapter construct time, env wins over +config.yaml ``extra``): + + NTFY_TOPIC Topic to subscribe to (required) + NTFY_SERVER_URL Server URL (default: https://ntfy.sh) + NTFY_TOKEN Bearer token or 'user:pass' for Basic auth + NTFY_PUBLISH_TOPIC Reply topic (defaults to NTFY_TOPIC) + NTFY_MARKDOWN "true"/"1"/"yes" enables X-Markdown header + NTFY_ALLOWED_USERS Allowlist (treated by gateway as user IDs; + on ntfy these are topic names) + NTFY_ALLOW_ALL_USERS Allow any topic โ€” dev only + NTFY_HOME_CHANNEL Default topic for cron / notification delivery + NTFY_HOME_CHANNEL_NAME Human label for the home channel + +Identity model: ntfy has no native authenticated user identity. The +``title`` field is publisher-controlled and is NOT used for +authorization. Each topic is treated as a single trusted channel โ€” +``user_id`` is fixed to the topic name. Use a private topic protected +by a read token for any real trust boundary. +""" + +import asyncio +import json +import logging +import os +import time +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +try: + import httpx + HTTPX_AVAILABLE = True +except ImportError: + HTTPX_AVAILABLE = False + httpx = None # type: ignore[assignment] + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, +) + +logger = logging.getLogger(__name__) + + +class _FatalStreamError(Exception): + """Raised when a stream error is unrecoverable (e.g. 401, 404).""" + + +DEFAULT_SERVER = "https://ntfy.sh" +MAX_MESSAGE_LENGTH = 4096 # ntfy message body limit +DEDUP_WINDOW_SECONDS = 300 +DEDUP_MAX_SIZE = 1000 +RECONNECT_BACKOFF = [2, 5, 10, 30, 60] +STREAM_TIMEOUT_SECONDS = 90 # ntfy keepalive default is 55s; give margin + + +def _build_auth_header(token: str) -> Dict[str, str]: + """Build an ``Authorization`` header from an ntfy token. + + Shared by :class:`NtfyAdapter._auth_headers` and :func:`_standalone_send` + so both paths follow the same auth shape and whitespace-stripping rules. + + Tokens are stripped of surrounding whitespace โ€” pasted tokens often + carry trailing newlines that would otherwise render the header + malformed (``Authorization: Bearer foo\\n``). ``user:pass`` tokens + become Basic auth; anything else is treated as a Bearer token. + Returns ``{}`` when no token is configured. + """ + if not token: + return {} + token = token.strip() + if not token: + return {} + if ":" in token: + import base64 + encoded = base64.b64encode(token.encode()).decode() + return {"Authorization": f"Basic {encoded}"} + return {"Authorization": f"Bearer {token}"} + + +def _truncate_body(message: str, *, context: str) -> bytes: + """Apply the ntfy 4096-char limit, logging a warning on truncation. + + ``context`` is included in the log message so adapter and standalone + truncations can be told apart in logs. + """ + if len(message) > MAX_MESSAGE_LENGTH: + logger.warning( + "%s: truncating message from %d to %d chars (ntfy limit)", + context, len(message), MAX_MESSAGE_LENGTH, + ) + return message[:MAX_MESSAGE_LENGTH].encode("utf-8") + + +def check_requirements() -> bool: + """Check whether the ntfy adapter is installable and minimally configured. + + Reads ``NTFY_TOPIC`` directly to avoid the cost of a full + ``load_gateway_config()`` (which also writes to ``os.environ``) on + every pre-flight check. + """ + if not HTTPX_AVAILABLE: + return False + topic = os.getenv("NTFY_TOPIC", "").strip() + return bool(topic) + + +def validate_config(config) -> bool: + """Validate that the configured ntfy platform has a topic set.""" + extra = getattr(config, "extra", {}) or {} + topic = extra.get("topic") or os.getenv("NTFY_TOPIC", "") + return bool(topic) + + +def is_connected(config) -> bool: + """Check whether ntfy is configured (env or config.yaml).""" + extra = getattr(config, "extra", {}) or {} + topic = os.getenv("NTFY_TOPIC") or extra.get("topic", "") + return bool(topic) + + +class NtfyAdapter(BasePlatformAdapter): + """ntfy adapter. + + Subscribes to a topic via HTTP streaming (``/json`` endpoint) and + publishes replies via HTTP POST. No external SDK โ€” only httpx. + """ + + MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH + + def __init__(self, config: PlatformConfig): + platform = Platform("ntfy") + super().__init__(config=config, platform=platform) + + extra = config.extra or {} + self._server: str = ( + extra.get("server") + or os.getenv("NTFY_SERVER_URL", DEFAULT_SERVER) + ).rstrip("/") + self._topic: str = extra.get("topic") or os.getenv("NTFY_TOPIC", "") + self._publish_topic: str = ( + extra.get("publish_topic") + or os.getenv("NTFY_PUBLISH_TOPIC", "") + or self._topic + ) + self._token: str = extra.get("token") or os.getenv("NTFY_TOKEN", "") + + self._stream_task: Optional[asyncio.Task] = None + self._http_client: Optional["httpx.AsyncClient"] = None + + # Message deduplication: msg_id -> timestamp + self._seen_messages: Dict[str, float] = {} + + # -- Connection lifecycle ----------------------------------------------- + + async def connect(self) -> bool: + """Connect to ntfy by starting the streaming subscription task.""" + if not HTTPX_AVAILABLE: + logger.warning("[%s] httpx not installed. Run: pip install httpx", self.name) + return False + if not self._topic: + logger.warning("[%s] NTFY_TOPIC not configured", self.name) + return False + + try: + self._http_client = httpx.AsyncClient(timeout=None) + self._stream_task = asyncio.create_task(self._run_stream()) + self._mark_connected() + logger.info("[%s] Connected โ€” subscribing to %s/%s", self.name, self._server, self._topic) + return True + except Exception as e: + logger.error("[%s] Failed to connect: %s", self.name, e) + return False + + async def _run_stream(self) -> None: + """Subscribe to the ntfy topic with automatic reconnection.""" + backoff_idx = 0 + stream_start: float = 0.0 + url = f"{self._server}/{self._topic}/json" + headers = self._auth_headers() + + while self._running: + try: + logger.debug("[%s] Opening stream to %s", self.name, url) + stream_start = time.monotonic() + await self._consume_stream(url, headers) + except asyncio.CancelledError: + return + except _FatalStreamError: + self._running = False + return + except Exception as e: + if not self._running: + return + logger.warning("[%s] Stream error: %s", self.name, e) + + if not self._running: + return + + # Reset backoff if stream stayed alive for at least 60s + if time.monotonic() - stream_start >= 60.0: + backoff_idx = 0 + delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)] + logger.info("[%s] Reconnecting in %ds...", self.name, delay) + await asyncio.sleep(delay) + backoff_idx += 1 + + async def _consume_stream(self, url: str, headers: Dict[str, str]) -> None: + """Open an HTTP streaming connection and dispatch events.""" + # poll=false keeps a persistent streaming connection alive with keepalive events + params = {"poll": "false"} + async with self._http_client.stream( + "GET", + url, + headers=headers, + params=params, + timeout=httpx.Timeout(connect=15.0, read=STREAM_TIMEOUT_SECONDS, write=15.0, pool=15.0), + ) as response: + if response.status_code == 401: + logger.error( + "[%s] Authentication failed (401) โ€” stopping reconnect loop. Check NTFY_TOKEN.", + self.name, + ) + self._set_fatal_error( + "ntfy_unauthorized", + "ntfy server rejected auth (401). Check NTFY_TOKEN.", + retryable=False, + ) + raise _FatalStreamError("401 Unauthorized") + if response.status_code == 404: + logger.error( + "[%s] Topic not found (404): %s โ€” stopping reconnect loop.", + self.name, self._topic, + ) + self._set_fatal_error( + "ntfy_topic_not_found", + f"ntfy topic '{self._topic}' returned 404. Check NTFY_TOPIC.", + retryable=False, + ) + raise _FatalStreamError("404 Not Found") + response.raise_for_status() + + async for line in response.aiter_lines(): + if not self._running: + return + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + except json.JSONDecodeError: + continue + if event.get("event") == "message": + await self._on_message(event) + + async def disconnect(self) -> None: + """Disconnect from ntfy.""" + self._running = False + self._mark_disconnected() + + if self._stream_task: + self._stream_task.cancel() + try: + await self._stream_task + except asyncio.CancelledError: + pass + self._stream_task = None + + if self._http_client: + await self._http_client.aclose() + self._http_client = None + + self._seen_messages.clear() + logger.info("[%s] Disconnected", self.name) + + # -- Inbound message processing ----------------------------------------- + + async def _on_message(self, event: Dict[str, Any]) -> None: + """Process an incoming ntfy message event.""" + msg_id = event.get("id") or uuid.uuid4().hex + if self._is_duplicate(msg_id): + logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id) + return + + text = (event.get("message") or "").strip() + if not text: + logger.debug("[%s] Empty message body, skipping", self.name) + return + + topic = event.get("topic") or self._topic + # ntfy has no native authenticated user identity. The title field is + # publisher-controlled and must NOT be used for authorization โ€” any + # publisher who knows the topic can set title to an allowed username. + # Treat ntfy as a single trusted channel; user_id is fixed to the + # topic name. NTFY_ALLOWED_USERS is only a real trust boundary when + # the topic itself is protected by a read token. + user_id = topic + user_name = topic + + source = self.build_source( + chat_id=topic, + chat_name=topic, + chat_type="dm", + user_id=user_id, + user_name=user_name, + ) + + unix_ts = event.get("time") + try: + timestamp = ( + datetime.fromtimestamp(int(unix_ts), tz=timezone.utc) + if unix_ts else datetime.now(tz=timezone.utc) + ) + except (ValueError, OSError, TypeError): + timestamp = datetime.now(tz=timezone.utc) + + message_event = MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=source, + message_id=msg_id, + raw_message=event, + timestamp=timestamp, + ) + + logger.debug("[%s] Message on topic %s: %s", self.name, topic, text[:80]) + await self.handle_message(message_event) + + # -- Deduplication ------------------------------------------------------ + + def _is_duplicate(self, msg_id: str) -> bool: + """Return True if this message ID was already seen within the dedup window.""" + now = time.time() + if len(self._seen_messages) > DEDUP_MAX_SIZE: + cutoff = now - DEDUP_WINDOW_SECONDS + self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff} + + if msg_id in self._seen_messages: + return True + self._seen_messages[msg_id] = now + return False + + # -- Outbound messaging ------------------------------------------------- + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Publish a message to the configured publish topic.""" + metadata = metadata or {} + publish_topic = metadata.get("publish_topic") or self._publish_topic or chat_id + + if not self._http_client: + return SendResult(success=False, error="HTTP client not initialized") + + url = f"{self._server}/{publish_topic}" + markdown_enabled = (self.config.extra or {}).get("markdown", False) + headers = {**self._auth_headers(), "Content-Type": "text/plain; charset=utf-8"} + if markdown_enabled: + headers["X-Markdown"] = "true" + + if len(content) > self.MAX_MESSAGE_LENGTH: + logger.warning( + "[%s] Message truncated from %d to %d chars (ntfy limit)", + self.name, len(content), self.MAX_MESSAGE_LENGTH, + ) + body = content[:self.MAX_MESSAGE_LENGTH] + + try: + resp = await self._http_client.post( + url, content=body.encode("utf-8"), headers=headers, timeout=15.0, + ) + if resp.status_code < 300: + try: + data = resp.json() + returned_id = data.get("id") or uuid.uuid4().hex[:12] + except Exception: + returned_id = uuid.uuid4().hex[:12] + return SendResult(success=True, message_id=returned_id) + body_text = resp.text + logger.warning("[%s] Send failed HTTP %d: %s", self.name, resp.status_code, body_text[:200]) + return SendResult(success=False, error=f"HTTP {resp.status_code}: {body_text[:200]}") + except httpx.TimeoutException: + return SendResult(success=False, error="Timeout publishing to ntfy") + except Exception as e: + logger.error("[%s] Send error: %s", self.name, e) + return SendResult(success=False, error=str(e)) + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """ntfy does not support typing indicators.""" + pass + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return basic info about an ntfy topic.""" + return {"name": chat_id, "type": "dm"} + + # -- Helpers ------------------------------------------------------------ + + def _auth_headers(self) -> Dict[str, str]: + """Build Authorization header if a token is configured.""" + return _build_auth_header(self._token) + + +# --------------------------------------------------------------------------- +# Plugin registration +# --------------------------------------------------------------------------- + + +def _env_enablement() -> dict | None: + """Seed ``PlatformConfig.extra`` from env vars during gateway config load. + + Called by the platform registry's env-enablement hook BEFORE adapter + construction, so ``gateway status`` and ``get_connected_platforms()`` + reflect env-only configuration without instantiating the HTTP client. + Returns ``None`` when ntfy isn't minimally configured; the caller skips + auto-enabling. + + The special ``home_channel`` key in the returned dict is handled by the + core hook โ€” it becomes a proper ``HomeChannel`` dataclass on the + ``PlatformConfig`` rather than being merged into ``extra``. + """ + topic = os.getenv("NTFY_TOPIC", "").strip() + if not topic: + return None + seed: dict = { + "topic": topic, + "server": os.getenv("NTFY_SERVER_URL", DEFAULT_SERVER).rstrip("/"), + } + publish_topic = os.getenv("NTFY_PUBLISH_TOPIC", "").strip() + if publish_topic: + seed["publish_topic"] = publish_topic + token = os.getenv("NTFY_TOKEN", "").strip() + if token: + seed["token"] = token + markdown = os.getenv("NTFY_MARKDOWN", "").strip().lower() + if markdown: + seed["markdown"] = markdown in ("1", "true", "yes") + home = os.getenv("NTFY_HOME_CHANNEL", "").strip() or topic + if home: + seed["home_channel"] = { + "chat_id": home, + "name": os.getenv("NTFY_HOME_CHANNEL_NAME", home), + } + return seed + + +async def _standalone_send( + pconfig, + chat_id: str, + message: str, + *, + thread_id: Optional[str] = None, + media_files: Optional[List[str]] = None, + force_document: bool = False, +) -> Dict[str, Any]: + """Out-of-process publish for cron / send_message_tool fallbacks. + + Used by ``tools/send_message_tool._send_via_adapter`` and the cron + scheduler when the gateway runner is not in this process (e.g. + ``hermes cron`` running standalone). Without this hook, + ``deliver=ntfy`` cron jobs fail with ``No live adapter for platform``. + + ``thread_id`` and ``media_files`` are accepted for signature parity + only โ€” ntfy has no thread or attachment primitive. Markdown is + honored if ``NTFY_MARKDOWN`` is set OR ``pconfig.extra["markdown"]`` + is True. + """ + if not HTTPX_AVAILABLE: + return {"error": "ntfy standalone send: httpx not installed"} + + extra = getattr(pconfig, "extra", {}) or {} + server = ( + extra.get("server") + or os.getenv("NTFY_SERVER_URL", DEFAULT_SERVER) + ).rstrip("/") + publish_topic = ( + chat_id + or extra.get("publish_topic") + or os.getenv("NTFY_PUBLISH_TOPIC", "").strip() + or extra.get("topic") + or os.getenv("NTFY_TOPIC", "").strip() + ) + if not publish_topic: + return {"error": "ntfy standalone send: NTFY_TOPIC not configured"} + + token = extra.get("token") or os.getenv("NTFY_TOKEN", "") + markdown_env = os.getenv("NTFY_MARKDOWN", "").strip().lower() + markdown_enabled = bool(extra.get("markdown")) or markdown_env in ("1", "true", "yes") + + headers = {"Content-Type": "text/plain; charset=utf-8", **_build_auth_header(token)} + if markdown_enabled: + headers["X-Markdown"] = "true" + + body = _truncate_body(message, context="ntfy standalone") + + url = f"{server}/{publish_topic}" + try: + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.post(url, content=body, headers=headers) + if resp.status_code >= 300: + return {"error": f"ntfy HTTP {resp.status_code}: {resp.text[:200]}"} + try: + data = resp.json() + msg_id = data.get("id") or uuid.uuid4().hex[:12] + except Exception: + msg_id = uuid.uuid4().hex[:12] + return {"success": True, "platform": "ntfy", "chat_id": publish_topic, "message_id": msg_id} + except Exception as e: + return {"error": f"ntfy standalone send failed: {e}"} + + +def register(ctx) -> None: + """Plugin entry point โ€” called by the Hermes plugin system at startup.""" + ctx.register_platform( + name="ntfy", + label="ntfy", + adapter_factory=lambda cfg: NtfyAdapter(cfg), + check_fn=check_requirements, + validate_config=validate_config, + is_connected=is_connected, + required_env=["NTFY_TOPIC"], + install_hint="pip install httpx # already a Hermes dependency", + # Env-driven auto-configuration: seeds PlatformConfig.extra so + # env-only setups show up in `hermes gateway status` without + # instantiating the HTTP client. + env_enablement_fn=_env_enablement, + # Cron home-channel delivery support โ€” `deliver=ntfy` cron jobs + # route to NTFY_HOME_CHANNEL when set. + cron_deliver_env_var="NTFY_HOME_CHANNEL", + # Out-of-process cron delivery. Without this hook, deliver=ntfy + # cron jobs fail with "No live adapter" when cron runs separately + # from the gateway. + standalone_sender_fn=_standalone_send, + # Auth env vars for _is_user_authorized() integration. + allowed_users_env="NTFY_ALLOWED_USERS", + allow_all_env="NTFY_ALLOW_ALL_USERS", + max_message_length=MAX_MESSAGE_LENGTH, + emoji="๐Ÿ””", + # ntfy publishers have no persistent identity โ€” topic names are + # the only identifier, no phone numbers / emails to redact. + pii_safe=True, + allow_update_command=True, + platform_hint=( + "You are communicating via ntfy push notifications. " + "Use plain text by default โ€” ntfy supports optional markdown " + "(set markdown: true in config or NTFY_MARKDOWN=true). " + "Keep responses concise; ntfy is a push notification service " + "with a 4096-character per-message limit." + ), + ) diff --git a/plugins/platforms/ntfy/plugin.yaml b/plugins/platforms/ntfy/plugin.yaml new file mode 100644 index 00000000000..e476a36235f --- /dev/null +++ b/plugins/platforms/ntfy/plugin.yaml @@ -0,0 +1,56 @@ +name: ntfy-platform +label: ntfy +kind: platform +version: 1.0.0 +description: > + ntfy push-notification gateway adapter for Hermes Agent. + Subscribes to a topic on ntfy.sh or any self-hosted ntfy server via + HTTP streaming, and publishes replies via HTTP POST. Lightweight โ€” + no external SDK, only httpx (already a Hermes dependency). + + ntfy has no native user-identity primitive; the adapter treats each + topic as a single trusted channel and never derives user identity + from publisher-controlled fields. Use a private topic + read token + for any real trust boundary. +author: sprmn24 +# ``requires_env`` and ``optional_env`` entries are surfaced in the +# ``hermes config`` UI via the platform-plugin env var injector in +# ``hermes_cli/config.py``. +requires_env: + - name: NTFY_TOPIC + description: "Topic name to subscribe to (e.g. hermes-in)" + prompt: "ntfy subscribe topic" + password: false +optional_env: + - name: NTFY_SERVER_URL + description: "ntfy server URL (default: https://ntfy.sh)" + prompt: "ntfy server URL" + password: false + - name: NTFY_TOKEN + description: "Bearer token or 'user:pass' for Basic auth (optional)" + prompt: "ntfy auth token (or empty)" + password: true + - name: NTFY_PUBLISH_TOPIC + description: "Topic to publish replies to (defaults to NTFY_TOPIC)" + prompt: "ntfy publish topic (or empty)" + password: false + - name: NTFY_MARKDOWN + description: "Send replies with X-Markdown: true header (true/false, default: false)" + prompt: "Enable markdown formatting? (true/false)" + password: false + - name: NTFY_ALLOWED_USERS + description: "Comma-separated topic names allowed (allowlist)" + prompt: "Allowed topic names (comma-separated)" + password: false + - name: NTFY_ALLOW_ALL_USERS + description: "Allow any topic to talk to the bot (dev only โ€” disables allowlist)" + prompt: "Allow all topics? (true/false)" + password: false + - name: NTFY_HOME_CHANNEL + description: "Default topic for cron / notification delivery" + prompt: "Home channel topic (or empty)" + password: false + - name: NTFY_HOME_CHANNEL_NAME + description: "Human label for the home channel (defaults to the topic name)" + prompt: "Home channel display name (or empty)" + password: false diff --git a/plugins/video_gen/fal/__init__.py b/plugins/video_gen/fal/__init__.py index 0f46f62a7a0..61b36789855 100644 --- a/plugins/video_gen/fal/__init__.py +++ b/plugins/video_gen/fal/__init__.py @@ -282,20 +282,24 @@ def _build_payload( # --------------------------------------------------------------------------- -# fal_client lazy import (same pattern as image_generation_tool) +# fal_client lazy import (shared with image_generation_tool via fal_common) # --------------------------------------------------------------------------- _fal_client: Any = None def _load_fal_client() -> Any: + """Lazy-load the ``fal_client`` SDK and cache it on this module. + + Delegates the actual import to :func:`tools.fal_common.import_fal_client` + so the ``lazy_deps`` ensure-install handling stays in one place. + """ global _fal_client if _fal_client is not None: return _fal_client - import fal_client # type: ignore - - _fal_client = fal_client - return fal_client + from tools.fal_common import import_fal_client + _fal_client = import_fal_client() + return _fal_client # --------------------------------------------------------------------------- diff --git a/run_agent.py b/run_agent.py index 001d03784ad..b364127c278 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1368,6 +1368,18 @@ class AIAgent: * xAI OAuth: "do not have an active Grok subscription" / "out of available resources" / "does not have permission" + "grok" + Disambiguator for xAI (#29344): the same ``code`` text ("The caller + does not have permission to execute the specified operation") is + returned for BOTH an unsubscribed account AND a stale OAuth access + token. xAI ships an explicit signal in the ``error`` field that + tells the two apart: a ``[WKE=unauthenticated:...]`` suffix (and/or + the ``OAuth2 access token could not be validated`` phrasing) means + the credentials failed validation โ€” that's recoverable by refreshing + the token, NOT by surfacing an entitlement message. When either + signal is present we return False eagerly so the credential-pool + refresh path runs, letting long-running TUI sessions recover from + stale tokens without an exit/reopen cycle. + Extend here for new providers as we discover them (Anthropic's Claude Max OAuth entitlement errors look distinct enough today that the existing 1M-context-beta branch handles them; revisit if other @@ -1377,11 +1389,29 @@ class AIAgent: return False if not isinstance(error_context, dict): return False + # Build a single lowercase haystack covering every field shape the + # body might land in. ``_extract_api_error_context`` normalises to + # ``message``/``reason``, but callers (and the test suite) may also + # hand us the raw body with ``code``/``error`` keys; cover both so + # the WKE disambiguator below fires regardless of entry point. message = str(error_context.get("message") or "").lower() reason = str(error_context.get("reason") or "").lower() - haystack = f"{message} {reason}" + code = str(error_context.get("code") or "").lower() + err = str(error_context.get("error") or "").lower() + haystack = f"{message} {reason} {code} {err}" if not haystack.strip(): return False + # xAI's authoritative disambiguator for "stale token" vs + # "unsubscribed account". Both conditions share the same + # permission-denied ``code`` text; only one carries this suffix. + # Bail out before the entitlement keyword checks so a stale OAuth + # token routes through the credential-refresh path instead of the + # surface-error-as-entitlement path. See #29344 for the long- + # running TUI failure mode this closes. + if "[wke=unauthenticated:" in haystack: + return False + if "oauth2 access token could not be validated" in haystack: + return False if "do not have an active grok subscription" in haystack: return True if "out of available resources" in haystack and "grok" in haystack: @@ -2563,6 +2593,39 @@ class AIAgent: def _close_request_openai_client(self, client: Any, *, reason: str) -> None: self._close_openai_client(client, reason=reason, shared=False) + def _abort_request_openai_client(self, client: Any, *, reason: str) -> None: + """Cross-thread abort: shut sockets down without releasing FDs. + + Companion to :meth:`_close_request_openai_client` for stranger-thread + callers (interrupt-check loop, stale-call detector). Calling + ``client.close()`` from a thread that does not own the active httpx + connection raced the still-live SSL BIO and corrupted unrelated file + descriptors when the kernel recycled the just-freed TCP FD (#29507). + + Here we only ``shutdown(SHUT_RDWR)`` the pool's sockets. That unblocks + the owning worker thread's pending ``recv``/``send`` with an EOF or + ``EPIPE`` so it can unwind and close ``client`` from its own context + โ€” which is where the FD release belongs. + """ + if client is None: + return + try: + shutdown_count = self._force_close_tcp_sockets(client) + logger.info( + "OpenAI client aborted (%s, shared=False, tcp_force_closed=%d, " + "deferred_close=stranger_thread) %s", + reason, + shutdown_count, + self._client_log_context(), + ) + except Exception as exc: + logger.debug( + "OpenAI client abort failed (%s, shared=False) %s error=%s", + reason, + self._client_log_context(), + exc, + ) + def _run_codex_stream(self, api_kwargs: dict, client: Any = None, on_first_delta: callable = None): """Forwarder โ€” see ``agent.codex_runtime.run_codex_stream``.""" from agent.codex_runtime import run_codex_stream @@ -3357,6 +3420,25 @@ class AIAgent: return content if self._model_supports_vision(): + # Vision-capable on paper โ€” but if we've already learned in this + # session that the active (provider, model) rejects list-type + # tool content (e.g. Xiaomi MiMo's 400 "text is not set"), + # short-circuit to a text summary so we don't burn another + # round-trip relearning the same lesson. Cache populated by + # the 400 recovery path in agent.conversation_loop. Transient + # per-session; next session retries. + key = ( + (getattr(self, "provider", "") or "").strip().lower(), + (getattr(self, "model", "") or "").strip(), + ) + no_list = getattr(self, "_no_list_tool_content_models", None) + if no_list and key in no_list: + logger.debug( + "Tool %s: model %s/%s known to reject list-type tool " + "content this session โ€” sending text summary", + tool_name, key[0], key[1], + ) + return _multimodal_text_summary(result) return content summary = _multimodal_text_summary(result) @@ -3385,6 +3467,80 @@ class AIAgent: from agent.conversation_compression import try_shrink_image_parts_in_messages return try_shrink_image_parts_in_messages(api_messages) + def _try_strip_image_parts_from_tool_messages(self, api_messages: list) -> bool: + """Downgrade list-type tool messages to text summaries in-place. + + Recovery path for providers that reject list-type tool message content + (e.g. Xiaomi MiMo's 400 "text is not set"; see issue #27344). Walks + ``api_messages`` for any ``role: "tool"`` message whose ``content`` is + a list containing image parts, replaces the content with the existing + text part(s) (or a minimal placeholder if none survive), and records + the active (provider, model) in ``self._no_list_tool_content_models`` + so subsequent ``_tool_result_content_for_active_model`` calls in this + session preemptively downgrade screenshots without a round-trip. + + Returns True when at least one tool message was downgraded โ€” the + caller (the 400 recovery branch in ``agent.conversation_loop``) uses + this to decide whether to retry the API call with the modified + history or surface the original error. + """ + if not isinstance(api_messages, list): + return False + + # Record (provider, model) so we don't relearn this lesson. + key = ( + (getattr(self, "provider", "") or "").strip().lower(), + (getattr(self, "model", "") or "").strip(), + ) + if not hasattr(self, "_no_list_tool_content_models"): + self._no_list_tool_content_models = set() + if key[1]: # only record when we actually have a model id + self._no_list_tool_content_models.add(key) + + changed = False + for msg in api_messages: + if not isinstance(msg, dict) or msg.get("role") != "tool": + continue + content = msg.get("content") + if not isinstance(content, list): + continue + + # Salvage any text parts so the model still sees some signal. + text_parts: List[str] = [] + had_image = False + for part in content: + if not isinstance(part, dict): + if isinstance(part, str) and part.strip(): + text_parts.append(part.strip()) + continue + ptype = part.get("type") + if ptype == "image_url" or ptype == "input_image": + had_image = True + continue + if ptype in {"text", "input_text"}: + text = str(part.get("text") or "").strip() + if text: + text_parts.append(text) + + if not had_image: + # List-type content but no image parts โ€” leave alone (some + # providers reject ANY list content, but stripping a + # text-only list doesn't reduce ambiguity; let the caller + # surface the original error if this turns out to be the + # case). + continue + + if text_parts: + msg["content"] = "\n\n".join(text_parts) + else: + msg["content"] = ( + "[image content removed โ€” provider does not accept " + "list-type tool message content]" + ) + changed = True + + return changed + def _anthropic_preserve_dots(self) -> bool: """True when using an anthropic-compatible endpoint that preserves dots in model names. Alibaba/DashScope keeps dots (e.g. qwen3.5-plus). diff --git a/scripts/release.py b/scripts/release.py index 24e3fd92fc7..659e7902062 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -47,6 +47,8 @@ ACP_REGISTRY_MANIFEST = REPO_ROOT / "acp_registry" / "agent.json" AUTHOR_MAP = { # teknium (multiple emails) "teknium1@gmail.com": "teknium1", + "kenyon1977@gmail.com": "kenyonxu", + "cipherframe@users.noreply.github.com": "CipherFrame", "me@promplate.dev": "CNSeniorious000", "yichengqiao21@gmail.com": "YarrowQiao", "erhanyasarx@gmail.com": "erhnysr", @@ -59,19 +61,28 @@ AUTHOR_MAP = { "mgongzai@gmail.com": "vKongv", "0x.badfriend@gmail.com": "discodirector", "altriatree@gmail.com": "TruaShamu", + "contact-me@stark-x.cn": "Stark-X", "nat@nthrow.io": "nthrow", "m@mobrienv.dev": "mikeyobrien", "saeed919@pm.me": "falasi", + "chrisdlc119@outlook.com": "chdlc", "omar@techdeveloper.site": "nycomar", "qiyin.zuo@pcitc.com": "qiyin-code", "mr.aashiz@gmail.com": "aashizpoudel", "70629228+shaun0927@users.noreply.github.com": "shaun0927", + "soju06@users.noreply.github.com": "Soju06", + "34199905+Soju06@users.noreply.github.com": "Soju06", "98262967+Bihruze@users.noreply.github.com": "Bihruze", + "189280367+Lempkey@users.noreply.github.com": "Lempkey", + "34853915+m0n3r0@users.noreply.github.com": "m0n3r0", + "leeseoki@makestar.com": "leeseoki0", + "leovillalbajr@gmail.com": "Lempkey", "nidhi2894@gmail.com": "nidhi-singh02", "30312689+aashizpoudel@users.noreply.github.com": "aashizpoudel", "oleksii.lisikh@gmail.com": "olisikh", "jithendranaidunara@gmail.com": "JithendraNara", "jeremy@geocaching.com": "outdoorsea", + "54763683+thedavidmurray@users.noreply.github.com": "thedavidmurray", "leone.parise@gmail.com": "leoneparise", "mr@shu.io": "mrshu", "adam.manning@gmail.com": "am423", @@ -80,7 +91,9 @@ AUTHOR_MAP = { "yanglongwei06@gmail.com": "Alex-yang00", "teknium@nousresearch.com": "teknium1", "markuscontasul@gmail.com": "Glucksberg", + "80581902+Glucksberg@users.noreply.github.com": "Glucksberg", "piyushvp1@gmail.com": "thelumiereguy", + "pnascimento9596@gmail.com": "pnascimento9596", "dskwelmcy@163.com": "dskwe", "421774554@qq.com": "wuli666", "twebefy@gmail.com": "tw2818", @@ -184,6 +197,7 @@ AUTHOR_MAP = { "gonzes7@gmail.com": "aqilaziz", # PR #26406 salvage (preserve native audio outside Telegram) "karthikeyann@users.noreply.github.com": "karthikeyann", # PR #26609 salvage (DM-topic routing pin) "rino.alpin@gmail.com": "kunci115", # PR #27098 salvage (thread-not-found retry) + "hayka-pacha@users.noreply.github.com": "hayka-pacha", # PR #25270 salvage (registry-aware mcp_ prefix strip) "237601532+chromalinx@users.noreply.github.com": "chromalinx", # PR #27014 salvage (commands for groups+DM) "booker1207@gmail.com": "booker1207", # PR #25132 salvage (gate profile bots by allowed topics) "kiranvk2011@gmail.com": "kiranvk-2011", # PR #24815 salvage (image documents โ†’ vision) @@ -641,7 +655,7 @@ AUTHOR_MAP = { "beibei1988@proton.me": "beibi9966", # โ”€โ”€ bulk addition: 75 emails resolved via API, PR salvage bodies, noreply # crossref, and GH contributor list matching (April 2026 audit) โ”€โ”€ - "1115117931@qq.com": "aaronagent", + "1115117931@qq.com": "aaronlab", "1506751656@qq.com": "hqhq1025", "364939526@qq.com": "luyao618", "hgk324@gmail.com": "houziershi", @@ -803,6 +817,7 @@ AUTHOR_MAP = { "xiayh17@gmail.com": "xiayh0107", "zhujianxyz@gmail.com": "opriz", "tuancanhnguyen706@gmail.com": "xxxigm", + "larcombe.n@gmail.com": "NickLarcombe", "54813621+xxxigm@users.noreply.github.com": "xxxigm", "asurla@nvidia.com": "anniesurla", "kchantharuan@nvidia.com": "nv-kasikritc", @@ -930,6 +945,8 @@ AUTHOR_MAP = { "holynn@placeholder.local": "holynn-q", "agent@hermes.local": "jacdevos", "sunsky.lau@gmail.com": "liuhao1024", + "fabianoeq@gmail.com": "rodrigoeqnit", + "178342791+sgtworkman@users.noreply.github.com": "sgtworkman", "qiuqfang98@qq.com": "keepcalmqqf", "261867348+ai-ag2026@users.noreply.github.com": "ai-ag2026", "yanzh.su@gmail.com": "YanzhongSu", @@ -1263,6 +1280,10 @@ AUTHOR_MAP = { "120500656+oooindefatigable@users.noreply.github.com": "ooovenenoso", "vanthinh6886@gmail.com": "vanthinh6886", # PR #28018 salvage (yaml/flock/atomic write guards) "erik.engervall@gmail.com": "erikengervall", # PR #28774 (firecrawl integration tag) + "egilewski@egilewski.com": "egilewski", # PR #30432 (MEDIA path traversal fix, GHSA-jmf9-9729-7pp8) + "edison@mcclean.codes": "McClean-Edison", # PR #29817 (register_auxiliary_task plugin API) + "zhangsamuel12@gmail.com": "SamuelZ12", # PR #7480 (show recap after in-session resume) + "490408354@qq.com": "daizhonggeng", # PR #9020 (numbered /resume selection) } diff --git a/scripts/run_tests_parallel.py b/scripts/run_tests_parallel.py index 7daaa6cbb1e..57178899012 100755 --- a/scripts/run_tests_parallel.py +++ b/scripts/run_tests_parallel.py @@ -38,6 +38,7 @@ Exit code: 0 if every file's pytest exited 0; 1 otherwise. from __future__ import annotations import argparse +import json import os import subprocess import sys @@ -62,6 +63,11 @@ _SKIP_PARTS = {"integration", "e2e"} # via --file-timeout or HERMES_TEST_FILE_TIMEOUT. _DEFAULT_FILE_TIMEOUT_SECONDS = 600.0 # 10 minutes +# Duration cache: maps relative file paths to last-observed subprocess +# wall-clock seconds. Used by ``--slice`` to distribute files across +# CI jobs by estimated total time, so no one job gets all the slow files. +_DURATIONS_FILE = "test_durations.json" + def _count_tests( files: List[Path], repo_root: Path, pytest_passthrough: List[str] @@ -219,10 +225,10 @@ def _run_one_file( pytest_args: List[str], repo_root: Path, file_timeout: float, -) -> Tuple[Path, int, str, dict[str, int]]: +) -> Tuple[Path, int, str, dict[str, int], float]: """Run ``python -m pytest `` in a fresh subprocess. - Returns (file, returncode, captured_combined_output, summary_counts). + Returns (file, returncode, captured_combined_output, summary_counts, subprocess_wall_seconds). ``summary_counts`` is the result of ``_parse_pytest_summary(output)`` โ€” @@ -247,6 +253,7 @@ def _run_one_file( bound a pathologically slow or hung file as a whole. """ cmd = [sys.executable, "-m", "pytest", str(file), *pytest_args] + subproc_start = time.monotonic() proc = subprocess.Popen( cmd, cwd=repo_root, @@ -308,7 +315,8 @@ def _run_one_file( # so the operator can spot it. rc = 0 summary = _parse_pytest_summary(output) - return file, rc, output, summary + subproc_wall = time.monotonic() - subproc_start + return file, rc, output, summary, subproc_wall def _parse_pytest_summary(output: str) -> dict[str, int]: @@ -370,12 +378,17 @@ def _print_progress( tests_failed: int, test_counts: dict[Path, int], file_summary: dict[str, int] | None = None, + subproc_wall: float | None = None, ) -> None: """Single-line live progress. When ``file_summary`` is provided (parsed from pytest output), the per-file parenthetical shows individual test pass/fail counts instead of just the total test count. + + ``subproc_wall`` is the actual subprocess wall-clock time (excluding + queue-wait). When available, the display shows both the subprocess + time and the queue-inclusive elapsed time. """ status = "โœ“" if rc == 0 else "โœ—" pct = (tests_done / total_tests * 100) if total_tests else 0 @@ -407,10 +420,15 @@ def _print_progress( else: n_tests = test_counts.get(file, 0) test_str = f"{n_tests} tests, " if n_tests else "" + # Show subprocess time when available; fall back to queue-inclusive dur. + if subproc_wall is not None: + time_str = f"{subproc_wall:.1f}s" + else: + time_str = f"{dur:.1f}s" msg = ( f"[{pct:5.1f}% | {tests_done:>5}/{total_tests}" f" | โœ“{tests_passed:>{fw}} | โœ—{tests_failed:>{fw}}] " - f"{status} {_format_file(file, repo_root)} ({test_str}{dur:.1f}s)" + f"{status} {_format_file(file, repo_root)} ({test_str}{time_str})" ) # Truncate to terminal width if available (no clobbering ANSI lines). try: @@ -453,6 +471,107 @@ def _print_inline_failure( print(flush=True) +def _load_durations(repo_root: Path) -> dict[str, float]: + """Read the duration cache from the repo root. + + Returns a dict mapping relative file paths (e.g. + ``tests/tools/test_code_execution.py``) to wall-clock seconds from + the last run. Missing or corrupt file โ†’ empty dict (safe fallback). + """ + path = repo_root / _DURATIONS_FILE + if not path.is_file(): + return {} + try: + return json.loads(path.read_text()) + except (json.JSONDecodeError, OSError): + return {} + + +def _save_durations( + file_times: List[Tuple[Path, float]], + repo_root: Path, +) -> None: + """Write the duration cache so future ``--slice`` runs can use it. + + Merges with any existing cache so entries from files not in the + current run (e.g. from a different slice) are preserved. Keys are + repo-relative paths so the cache is portable across checkouts + and CI runners. + """ + data: dict[str, float] = _load_durations(repo_root) + for f, t in file_times: + key = _format_file(f, repo_root) + data[key] = round(t, 3) + path = repo_root / _DURATIONS_FILE + path.write_text(json.dumps(data, indent=2, sort_keys=True) + "\n") + + +def _slice_files( + files: List[Path], + slice_index: int, + slice_count: int, + durations: dict[str, float], + repo_root: Path, +) -> List[Path]: + """Return the subset of *files* belonging to slice *slice_index*. + + Uses **Longest Processing Time first** (LPT) distribution: sort files + by estimated duration descending, then greedily assign each file to + the slice with the smallest accumulated time so far. This minimizes + the makespan (max slice duration) and keeps CI jobs balanced. + + Files with no cached duration get a default estimate of 2.0s (roughly + the P50 from profiling). This means first-time ``--slice`` runs + (no cache) still get reasonable distribution, and new files don't + all land in one slice. + + ``slice_index`` is 1-indexed (1..slice_count) for ergonomics โ€” + ``--slice 1/4`` reads more naturally than ``--slice 0/4``. + """ + if slice_count < 2: + return files + if not (1 <= slice_index <= slice_count): + print( + f"error: --slice index must be 1..{slice_count}, got {slice_index}", + file=sys.stderr, + ) + sys.exit(2) + + # Build (file, estimated_duration) pairs. + default_dur = 2.0 + file_durs: List[Tuple[Path, float]] = [] + for f in files: + rel = _format_file(f, repo_root) + dur = durations.get(rel, default_dur) + file_durs.append((f, dur)) + + # Sort longest first (LPT). + file_durs.sort(key=lambda x: x[1], reverse=True) + + # Greedy assignment: for each file, add it to the slice with the + # smallest current total. + bucket_files: List[List[Path]] = [[] for _ in range(slice_count)] + bucket_totals: List[float] = [0.0] * slice_count + + for f, dur in file_durs: + # Find the least-loaded bucket. + min_idx = min(range(slice_count), key=lambda i: bucket_totals[i]) + bucket_files[min_idx].append(f) + bucket_totals[min_idx] += dur + + # Print slice summary for visibility. + target = bucket_files[slice_index - 1] + target_dur = bucket_totals[slice_index - 1] + total_dur = sum(bucket_totals) + print( + f"Slice {slice_index}/{slice_count}: {len(target)} files " + f"(~{target_dur:.0f}s estimated of {total_dur:.0f}s total)", + flush=True, + ) + + return target + + def main() -> int: parser = argparse.ArgumentParser( description=__doc__, @@ -487,6 +606,17 @@ def main() -> int: "Default: 600 (10 min), env: HERMES_TEST_FILE_TIMEOUT." ), ) + parser.add_argument( + "--slice", + metavar="I/N", + help=( + "Run only slice I of N (e.g. --slice 1/4). " + "Files are distributed across slices using cached durations " + "so each slice takes roughly equal wall time. " + "Without a duration cache, files are distributed by count. " + "Env: HERMES_TEST_SLICE (format: I/N)." + ), + ) parser.add_argument( "paths_positional", nargs="*", @@ -509,6 +639,20 @@ def main() -> int: our_args, pytest_passthrough = argv, [] args = parser.parse_args(our_args) + # Parse --slice (or HERMES_TEST_SLICE) early so we can exit on bad input + # before doing any expensive discovery. + slice_raw = args.slice or os.environ.get("HERMES_TEST_SLICE") + slice_index: int | None = None + slice_count: int = 1 + if slice_raw: + try: + idx_s, count_s = slice_raw.split("/", 1) + slice_index = int(idx_s) + slice_count = int(count_s) + except (ValueError, AttributeError): + print(f"error: --slice must be I/N (e.g. 1/4), got: {slice_raw!r}", file=sys.stderr) + sys.exit(2) + repo_root = Path(__file__).resolve().parent.parent # Resolve discovery roots: positional path args override --paths if any @@ -535,6 +679,15 @@ def main() -> int: test_counts = _count_tests(files, repo_root, pytest_passthrough) total_tests = sum(test_counts.values()) + # Apply slicing if requested โ€” distribute files across CI jobs by + # estimated duration so no one job gets all the slow files. + if slice_index is not None: + durations = _load_durations(repo_root) + files = _slice_files(files, slice_index, slice_count, durations, repo_root) + # Recount after slicing. + test_counts = {f: test_counts[f] for f in files if f in test_counts} + total_tests = sum(test_counts.values()) + print( f"Discovered {len(files)} test files ({total_tests} tests) under " f"{[str(r.relative_to(repo_root)) if r.is_relative_to(repo_root) else str(r) for r in roots]}; " @@ -545,6 +698,7 @@ def main() -> int: # Capture and print on completion (out-of-order is fine โ€” keeps the # terminal clean rather than interleaving N parallel pytest outputs). failures: List[Tuple[Path, str, Dict[str, int]]] = [] + file_times: List[Tuple[Path, float]] = [] # (file, subprocess_wall) for distribution started = time.monotonic() files_done = 0 tests_done = 0 @@ -554,11 +708,11 @@ def main() -> int: tests_failed = 0 lock = threading.Lock() - def _on_done(file: Path, started_at: float, fut: "Future[Tuple[Path, int, str, dict[str, int]]]") -> None: + def _on_done(file: Path, started_at: float, fut: "Future[Tuple[Path, int, str, dict[str, int], float]]") -> None: nonlocal files_done, tests_done, pass_count, fail_count, tests_passed, tests_failed n_tests = test_counts.get(file, 0) try: - fpath, rc, output, summary = fut.result() + fpath, rc, output, summary, subproc_wall = fut.result() except Exception as exc: # noqa: BLE001 โ€” must always advance counter with lock: files_done += 1 @@ -570,6 +724,7 @@ def main() -> int: time.monotonic() - started_at, repo_root, tests_passed, tests_failed, test_counts, + subproc_wall=0.0, ) return with lock: @@ -578,6 +733,7 @@ def main() -> int: # Accumulate test-level counts from parsed summary. tests_passed += summary.get("passed", 0) tests_failed += summary.get("failed", 0) + file_times.append((fpath, subproc_wall)) if rc == 0: pass_count += 1 else: @@ -589,6 +745,7 @@ def main() -> int: repo_root, tests_passed, tests_failed, test_counts, file_summary=summary, + subproc_wall=subproc_wall, ) if rc != 0: _print_inline_failure(fpath, output, repo_root, pytest_passthrough) @@ -613,6 +770,40 @@ def main() -> int: pct = (tests_done / total_tests * 100) if total_tests else 0 print(f"=== Summary: {len(files)} files, {tests_passed} tests passed, {tests_failed} failed ({pct:.0f}% complete) in {elapsed:.1f}s ({args.jobs} workers) ===") + # Save durations for future --slice runs. Each slice writes its own + # partial test_durations.json; a CI merge step joins them later. + # Locally, _save_durations merges with any existing cache so entries + # from previous runs aren't lost. + if file_times: + _save_durations(file_times, repo_root) + print(f" Durations cached to {_DURATIONS_FILE} ({len(file_times)} files)") + + # Per-file time distribution (throwaway diagnostic โ€” shows how + # subprocess time is distributed so we can see if startup dominates). + if file_times: + times = sorted([t for _, t in file_times]) + total_subproc = sum(times) + median_t = times[len(times) // 2] + p50 = median_t + p90 = times[int(len(times) * 0.90)] + p95 = times[int(len(times) * 0.95)] + p99 = times[min(int(len(times) * 0.99), len(times) - 1)] + max_t = times[-1] + # How many files finish in <1s? That's roughly "just startup". + fast = sum(1 for t in times if t < 1.0) + fast_2s = sum(1 for t in times if t < 2.0) + print() + print(f"=== Per-file subprocess time distribution ===") + print(f" Files: {len(times)}") + print(f" Total subprocess CPU-wall: {total_subproc:.1f}s (runner wall: {elapsed:.1f}s, parallelism: {args.jobs}x)") + print(f" P50: {p50:.2f}s P90: {p90:.2f}s P95: {p95:.2f}s P99: {p99:.2f}s Max: {max_t:.2f}s") + print(f" <1s: {fast} files ({fast/len(times)*100:.0f}%) <2s: {fast_2s} files ({fast_2s/len(times)*100:.0f}%)") + # Top 10 slowest files โ€” likely the ones dragging the run. + slowest = sorted(file_times, key=lambda x: x[1], reverse=True)[:10] + print(f" Top 10 slowest:") + for f, t in slowest: + print(f" {t:>6.2f}s {_format_file(f, repo_root)}") + if failures: print() print("=== Failure output ===") diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index c1ff1bf4e63..32eced9dd27 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -971,6 +971,18 @@ class TestSessionConfiguration: "hermes_cli.runtime_provider.resolve_runtime_provider", fake_resolve_runtime_provider, ) + # Pin the parser so this test doesn't depend on live + # ``_KNOWN_PROVIDER_NAMES`` / ``_PROVIDER_ALIASES`` module state + # (sibling of the same hardening on + # ``test_model_switch_uses_requested_provider``). + monkeypatch.setattr( + "hermes_cli.models.parse_model_input", + lambda raw, current: ("anthropic", "claude-sonnet-4-6"), + ) + monkeypatch.setattr( + "hermes_cli.models.detect_provider_for_model", + lambda model, current: None, + ) manager = SessionManager(db=SessionDB(tmp_path / "state.db")) with patch("run_agent.AIAgent", side_effect=fake_agent): @@ -1191,6 +1203,48 @@ class TestPrompt: assert len(agent_chunks) == 1 assert agent_chunks[0].content.text == "streamed answer" + @pytest.mark.asyncio + async def test_prompt_delivers_transformed_response_after_streaming(self, agent): + """If a transform_llm_output plugin hook modifies the response after + streaming, ACP must deliver the transformed final_response so the + appended/rewritten text reaches the client. + """ + new_resp = await agent.new_session(cwd=".") + state = agent.session_manager.get_session(new_resp.session_id) + + def mock_run(*args, **kwargs): + state.agent.stream_delta_callback("original answer") + return { + "final_response": "original answer\n\n[plugin appended this]", + "response_transformed": True, + "messages": [], + } + + state.agent.run_conversation = mock_run + + mock_conn = MagicMock(spec=acp.Client) + mock_conn.session_update = AsyncMock() + agent._conn = mock_conn + + prompt = [TextContentBlock(type="text", text="hello")] + await agent.prompt(prompt=prompt, session_id=new_resp.session_id) + + updates = [ + call.kwargs.get("update") or call.args[1] + for call in mock_conn.session_update.call_args_list + ] + # The streamed chunk and the post-stream transformed message should + # both be present (final delivery is a separate update_agent_message_text + # call carrying the full transformed text). + all_texts = [ + getattr(getattr(u, "content", None), "text", None) + for u in updates + ] + assert any( + text and "[plugin appended this]" in text for text in all_texts + ), f"expected transformed final to be delivered, got: {all_texts!r}" + + @pytest.mark.asyncio async def test_prompt_auto_titles_session(self, agent): new_resp = await agent.new_session(cwd=".") @@ -1543,6 +1597,20 @@ class TestSlashCommands: "hermes_cli.runtime_provider.resolve_runtime_provider", fake_resolve_runtime_provider, ) + # Pin the model-string parser independently of the live + # ``_KNOWN_PROVIDER_NAMES`` / ``_PROVIDER_ALIASES`` module state. + # Otherwise any test in the same xdist worker that mutates those + # globals (e.g. registers a custom provider that shadows + # ``anthropic``) flakes this one โ€” observed once in CI as + # ``'custom' == 'anthropic'``. + monkeypatch.setattr( + "hermes_cli.models.parse_model_input", + lambda raw, current: ("anthropic", "claude-sonnet-4-6"), + ) + monkeypatch.setattr( + "hermes_cli.models.detect_provider_for_model", + lambda model, current: None, + ) manager = SessionManager(db=SessionDB(tmp_path / "state.db")) with patch("run_agent.AIAgent", side_effect=fake_agent): diff --git a/tests/agent/test_anthropic_mcp_prefix_strip.py b/tests/agent/test_anthropic_mcp_prefix_strip.py new file mode 100644 index 00000000000..102cbadca51 --- /dev/null +++ b/tests/agent/test_anthropic_mcp_prefix_strip.py @@ -0,0 +1,250 @@ +"""Tests for GH-25255: Anthropic OAuth mcp_ prefix stripping. + +When strip_tool_prefix=True (Anthropic OAuth path), the transport must only +strip the ``mcp_`` prefix from OAuth-injected tools, NOT from Hermes-native +MCP server tools that are registered under their full ``mcp__`` +name in the tool registry. +""" + +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_tool_use_block(name: str, block_id: str = "tc_1", input_data: dict | None = None): + """Create a fake Anthropic tool_use content block.""" + return SimpleNamespace( + type="tool_use", + id=block_id, + name=name, + input=input_data or {"query": "test"}, + ) + + +def _make_response(*blocks, stop_reason="end_turn"): + """Create a fake Anthropic Messages response.""" + return SimpleNamespace( + content=list(blocks), + stop_reason=stop_reason, + model="claude-sonnet-4", + usage=SimpleNamespace(input_tokens=100, output_tokens=50), + ) + + +class _FakeRegistry: + """Minimal fake tool registry for testing prefix stripping logic.""" + + def __init__(self, registered_names: set[str]): + self._names = registered_names + + def get_entry(self, name: str): + if name in self._names: + return SimpleNamespace(name=name) # truthy = tool exists + return None + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestAnthropicMcpPrefixStrip: + """Verify that strip_tool_prefix only strips OAuth-injected prefixes.""" + + def _get_transport(self): + from agent.transports.anthropic import AnthropicTransport + return AnthropicTransport() + + def test_strips_prefix_for_oauth_injected_tool(self): + """OAuth tools: mcp_read_file -> read_file (stripped). + + The tool was registered as 'read_file' in the registry. + Anthropic sees 'mcp_read_file' because Hermes adds the prefix. + On response, we must strip it back to 'read_file'. + """ + transport = self._get_transport() + block = _make_tool_use_block("mcp_read_file") + response = _make_response(block) + + registry = _FakeRegistry({"read_file", "terminal", "web_search"}) + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=True) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "read_file" + + def test_preserves_native_mcp_server_tool_name(self): + """Native MCP tools: mcp_composio_SEARCH -> mcp_composio_SEARCH (kept). + + The tool is registered with the full mcp_ prefix in the registry. + Stripping would break registry lookup. + """ + transport = self._get_transport() + block = _make_tool_use_block("mcp_composio_COMPOSIO_SEARCH_TOOLS") + response = _make_response(block) + + registry = _FakeRegistry({ + "mcp_composio_COMPOSIO_SEARCH_TOOLS", + "mcp_composio_COMPOSIO_GET_TOOL_SCHEMAS", + "read_file", + }) + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=True) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "mcp_composio_COMPOSIO_SEARCH_TOOLS" + + def test_no_strip_when_flag_false(self): + """When strip_tool_prefix=False, names are never modified.""" + transport = self._get_transport() + block = _make_tool_use_block("mcp_read_file") + response = _make_response(block) + + registry = _FakeRegistry({"read_file"}) + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=False) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "mcp_read_file" + + def test_no_strip_when_not_mcp_prefixed(self): + """Non-mcp_ names are untouched regardless of strip flag.""" + transport = self._get_transport() + block = _make_tool_use_block("web_search") + response = _make_response(block) + + registry = _FakeRegistry({"web_search"}) + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=True) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "web_search" + + def test_preserves_name_when_neither_in_registry(self): + """When neither stripped nor full name is in registry, keep full name. + + Safety fallback: if we can't determine the type, prefer the full name + since it's what the LLM was told about. + """ + transport = self._get_transport() + block = _make_tool_use_block("mcp_unknown_tool") + response = _make_response(block) + + registry = _FakeRegistry({"read_file"}) # neither name registered + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=True) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "mcp_unknown_tool" + + def test_mixed_tools_same_response(self): + """Both OAuth and native MCP tools in the same response.""" + transport = self._get_transport() + block1 = _make_tool_use_block("mcp_read_file", block_id="tc_1") + block2 = _make_tool_use_block("mcp_composio_SEARCH", block_id="tc_2") + block3 = _make_tool_use_block("mcp_composio_SEARCH", block_id="tc_3") # also registered natively + response = _make_response(block1, block2, block3) + + registry = _FakeRegistry({ + "read_file", # OAuth-injected + "mcp_composio_SEARCH", # native MCP + }) + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=True) + + assert len(result.tool_calls) == 3 + # OAuth tool: stripped + assert result.tool_calls[0].name == "read_file" + # Native MCP: preserved (both stripped and full are registered, full wins) + assert result.tool_calls[1].name == "mcp_composio_SEARCH" + assert result.tool_calls[2].name == "mcp_composio_SEARCH" + + def test_both_stripped_and_full_registered_prefers_full(self): + """Edge case: both 'foo' and 'mcp_foo' exist in registry. + + Keep 'mcp_foo' (the original name) since it's what the LLM requested. + """ + transport = self._get_transport() + block = _make_tool_use_block("mcp_foo") + response = _make_response(block) + + registry = _FakeRegistry({"foo", "mcp_foo"}) + with patch("tools.registry.registry", registry): + result = transport.normalize_response(response, strip_tool_prefix=True) + + assert len(result.tool_calls) == 1 + # Both exist โ€” the condition `get_entry(stripped) and not get_entry(name)` + # is False because get_entry(name) IS truthy, so we keep the full name. + assert result.tool_calls[0].name == "mcp_foo" + + +class TestAnthropicOAuthOutgoingPrefix: + """Verify the outgoing-side companion fix: build_anthropic_kwargs must not + double-prefix tool names that already start with ``mcp_`` (native MCP server + tools registered as ``mcp__``). GH-25255.""" + + def _build(self, tools, is_oauth=True): + from agent.anthropic_adapter import build_anthropic_kwargs + return build_anthropic_kwargs( + model="claude-sonnet-4-6", + messages=[{"role": "user", "content": "Hi"}], + tools=tools, + max_tokens=4096, + reasoning_config=None, + is_oauth=is_oauth, + ) + + def test_oauth_adds_prefix_to_bare_tool_name(self): + """OAuth + bare name โ†’ prefix added (existing Claude Code convention).""" + kwargs = self._build([{ + "type": "function", + "function": {"name": "read_file", "description": "x", "parameters": {}}, + }]) + names = [t["name"] for t in kwargs["tools"]] + assert names == ["mcp_read_file"] + + def test_oauth_does_not_double_prefix_native_mcp_tool(self): + """OAuth + already-prefixed native MCP name โ†’ left alone.""" + kwargs = self._build([{ + "type": "function", + "function": { + "name": "mcp_composio_COMPOSIO_SEARCH_TOOLS", + "description": "x", + "parameters": {}, + }, + }]) + names = [t["name"] for t in kwargs["tools"]] + # Must NOT become "mcp_mcp_composio_..." โ€” that breaks the round-trip + # because normalize_response only strips ONE mcp_ prefix. + assert names == ["mcp_composio_COMPOSIO_SEARCH_TOOLS"] + + def test_oauth_mixed_native_and_bare_tools(self): + """Mixed: native MCP preserved, bare names prefixed.""" + kwargs = self._build([ + {"type": "function", "function": {"name": "read_file", + "description": "x", "parameters": {}}}, + {"type": "function", "function": {"name": "mcp_composio_SEARCH", + "description": "y", "parameters": {}}}, + {"type": "function", "function": {"name": "terminal", + "description": "z", "parameters": {}}}, + ]) + names = sorted(t["name"] for t in kwargs["tools"]) + assert names == ["mcp_composio_SEARCH", "mcp_read_file", "mcp_terminal"] + + def test_non_oauth_path_untouched(self): + """Non-OAuth requests never get the prefix โ€” schemas pass through as-is.""" + kwargs = self._build([ + {"type": "function", "function": {"name": "read_file", + "description": "x", "parameters": {}}}, + {"type": "function", "function": {"name": "mcp_composio_SEARCH", + "description": "y", "parameters": {}}}, + ], is_oauth=False) + names = sorted(t["name"] for t in kwargs["tools"]) + assert names == ["mcp_composio_SEARCH", "read_file"] diff --git a/tests/agent/test_auxiliary_config_bridge.py b/tests/agent/test_auxiliary_config_bridge.py index 11fe9f71c23..3215303b5c2 100644 --- a/tests/agent/test_auxiliary_config_bridge.py +++ b/tests/agent/test_auxiliary_config_bridge.py @@ -198,22 +198,32 @@ class TestGatewayBridgeCodeParity: """Verify the gateway/run.py config bridge contains the auxiliary section.""" def test_gateway_has_auxiliary_bridge(self): - """The gateway config bridge must include auxiliary.* bridging.""" + """The gateway config bridge must include auxiliary.* bridging. + + After the plugin-aux-task API refactor (2026-05), gateway env-var + names are derived dynamically (``AUXILIARY__*``) so the + literal strings ``AUXILIARY_VISION_PROVIDER`` etc. no longer appear + in source. Assert the dynamic shape and the canonical built-in keys + bridged set instead. + """ gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py" # Pin encoding to UTF-8: source files in this repo are UTF-8, but # Path.read_text() defaults to the system locale โ€” which is cp1252 # on most Western Windows installs and crashes as soon as the file # contains any non-ASCII byte (e.g. an em-dash in a comment). content = gateway_path.read_text(encoding="utf-8") - # Check for key patterns that indicate the bridge is present - assert "AUXILIARY_VISION_PROVIDER" in content - assert "AUXILIARY_VISION_MODEL" in content - assert "AUXILIARY_VISION_BASE_URL" in content - assert "AUXILIARY_VISION_API_KEY" in content - assert "AUXILIARY_WEB_EXTRACT_PROVIDER" in content - assert "AUXILIARY_WEB_EXTRACT_MODEL" in content - assert "AUXILIARY_WEB_EXTRACT_BASE_URL" in content - assert "AUXILIARY_WEB_EXTRACT_API_KEY" in content + # Dynamic env-var derivation present + assert 'f"AUXILIARY_{_upper}_PROVIDER"' in content + assert 'f"AUXILIARY_{_upper}_MODEL"' in content + assert 'f"AUXILIARY_{_upper}_BASE_URL"' in content + assert 'f"AUXILIARY_{_upper}_API_KEY"' in content + # Built-in bridged keys present + assert "_aux_bridged_keys" in content + assert '"vision"' in content + assert '"web_extract"' in content + assert '"approval"' in content + # Plugin-aux-task discovery hooked into bridging + assert "get_plugin_auxiliary_tasks" in content def test_gateway_no_compression_env_bridge(self): """Gateway should NOT bridge compression config to env vars (config-only).""" diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index d8691fdf87c..dca10bb4462 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -65,11 +65,11 @@ class TestCompress: assert result == msgs def test_truncation_fallback_no_client(self, compressor): - # compressor has client=None and abort_on_summary_failure=False (default), - # so the LEGACY fallback path inserts a static "summary unavailable" - # placeholder and the middle window is dropped. + # Simulate "no summarizer available" explicitly. call_llm can otherwise + # discover the developer's real auxiliary credentials from auth state. msgs = [{"role": "system", "content": "System prompt"}] + self._make_messages(10) - result = compressor.compress(msgs) + with patch("agent.context_compressor.call_llm", side_effect=RuntimeError("no provider")): + result = compressor.compress(msgs) assert len(result) < len(msgs) # Should keep system message and last N assert result[0]["role"] == "system" diff --git a/tests/agent/test_display_todo_progress.py b/tests/agent/test_display_todo_progress.py new file mode 100644 index 00000000000..7205602e01a --- /dev/null +++ b/tests/agent/test_display_todo_progress.py @@ -0,0 +1,243 @@ +"""Tests for get_cute_tool_message todo progress display. + +Verifies the completion status rendering (done/total โœ“) on all three +todo tool call paths: read, create (merge=False), update (merge=True). +""" + +import json +import pytest +from agent.display import get_cute_tool_message + + +def _todo_result(total: int, completed: int) -> str: + """Build a fake todo_tool return value.""" + return json.dumps({ + "todos": [], + "summary": { + "total": total, + "pending": total - completed, + "in_progress": 0, + "completed": completed, + "cancelled": 0, + }, + }) + + +class TestTodoRead: + """get_cute_tool_message(โ€ฆ, result=โ€ฆ) when todos_arg is None (read path).""" + + def test_read_no_result(self): + msg = get_cute_tool_message("todo", {}, 0.5) + assert "reading tasks" in msg + assert "0.5s" in msg + + def test_read_with_progress(self): + msg = get_cute_tool_message("todo", {}, 0.5, + result=_todo_result(4, 2)) + assert "2/4" in msg + assert "task(s)" in msg + + def test_read_all_done(self): + msg = get_cute_tool_message("todo", {}, 0.5, + result=_todo_result(4, 4)) + assert "4/4" in msg + assert "task(s)" in msg + + def test_read_zero_total(self): + """Edge case: empty todo list returns summary with total=0.""" + msg = get_cute_tool_message("todo", {}, 0.5, + result=_todo_result(0, 0)) + assert "reading tasks" in msg + + def test_read_invalid_result_fallback(self): + """Garbage result should not crash; fall back to reading tasks.""" + msg = get_cute_tool_message("todo", {}, 0.5, result="not json") + assert "reading tasks" in msg + + def test_read_result_missing_summary(self): + msg = get_cute_tool_message("todo", {}, 0.5, + result='{"todos": []}') + assert "reading tasks" in msg + + +class TestTodoCreate: + """get_cute_tool_message when merge=False (new plan creation).""" + + def test_create_default(self): + """Brand-new plan: all pending, no result โ€” plain count.""" + msg = get_cute_tool_message("todo", + {"todos": [ + {"id": "a", "content": "x", "status": "pending"}, + ]}, 0.3) + assert "1 task(s)" in msg + assert "0.3s" in msg + assert "/" not in msg # no progress fraction + + def test_create_multiple(self): + msg = get_cute_tool_message("todo", + {"todos": [ + {"id": "a", "content": "x", "status": "pending"}, + {"id": "b", "content": "y", "status": "pending"}, + {"id": "c", "content": "z", "status": "pending"}, + ]}, 0.2) + assert "3 task(s)" in msg + + def test_create_with_result_shows_progress_when_done(self): + """Even on create, if result has completed tasks show it.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "content": "x", "status": "completed"}]}, + 0.4, + result=_todo_result(1, 1)) + assert "1/1" in msg + assert "task(s)" in msg + + def test_create_with_result_zero_done(self): + """New plan with 0 done โ€” plain count, no progress fraction.""" + msg = get_cute_tool_message("todo", + {"todos": [ + {"id": "a", "content": "x", "status": "pending"}, + {"id": "b", "content": "y", "status": "pending"}, + ]}, + 0.3, + result=_todo_result(2, 0)) + assert "2 task(s)" in msg + assert "/" not in msg + + +class TestTodoUpdate: + """get_cute_tool_message when merge=True (incremental update).""" + + def test_update_no_result(self): + """No result available โ€” plain update N task(s).""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "status": "completed"}], + "merge": True}, 0.5) + assert "update 1 task(s)" in msg + + def test_update_partial_progress(self): + """1/4 tasks completed โ€” show fraction with checkmark.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "status": "completed"}], + "merge": True}, + 0.5, + result=_todo_result(4, 1)) + assert "update" in msg + assert "1/4" in msg + assert "โœ“" in msg + + def test_update_halfway(self): + """2/4 โ€” midpoint progress.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "b", "status": "in_progress"}], + "merge": True}, + 0.7, + result=_todo_result(4, 2)) + assert "2/4" in msg + assert "โœ“" in msg + + def test_update_all_completed(self): + """4/4 โ€” full checkmark.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "d", "status": "completed"}], + "merge": True}, + 0.2, + result=_todo_result(4, 4)) + assert "4/4" in msg + assert "โœ“" in msg + + def test_update_zero_done(self): + """No completed tasks yet โ€” plain update N task(s).""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "status": "pending"}], + "merge": True}, + 0.3, + result=_todo_result(3, 0)) + assert "update 1 task(s)" in msg + assert "โœ“" not in msg + assert "/" not in msg # no progress fraction when done=0 + + def test_update_invalid_result_fallback(self): + """Bad JSON result โ€” fall back to plain update N task(s).""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "status": "completed"}], + "merge": True}, + 0.6, + result="{broken") + assert "update 1 task(s)" in msg + assert "โœ“" not in msg + + def test_update_result_missing_summary(self): + """Result no summary key โ€” fall back to plain update.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "status": "completed"}], + "merge": True}, + 0.4, + result='{"todos": []}') + assert "update 1 task(s)" in msg + assert "โœ“" not in msg + + def test_update_total_not_in_summary(self): + """Result summary missing total key.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "status": "completed"}], + "merge": True}, + 0.3, + result=json.dumps({"summary": {"completed": 2}})) + assert "update 1 task(s)" in msg + assert "โœ“" not in msg + + def test_update_multiple_tasks_in_line(self): + """Update line with several tasks in the update request.""" + msg = get_cute_tool_message("todo", + {"todos": [ + {"id": "a", "status": "completed"}, + {"id": "b", "status": "in_progress"}, + ], "merge": True}, + 0.5, + result=_todo_result(5, 3)) + assert "update" in msg + assert "3/5" in msg + assert "โœ“" in msg + + +class TestTodoEdgeCases: + """Boundary cases that should not crash.""" + + def test_merge_default_value(self): + """merge defaults to False in function signature, should be False when absent.""" + msg = get_cute_tool_message("todo", + {"todos": [{"id": "a", "content": "x", "status": "pending"}]}, + 1.0) + assert "1 task(s)" in msg + + def test_duration_formatting(self): + """Duration formatting works correctly.""" + msg = get_cute_tool_message("todo", {}, 0.123) + assert "0.1s" in msg + + msg = get_cute_tool_message("todo", {}, 1.0) + assert "1.0s" in msg + + msg = get_cute_tool_message("todo", {}, 123.456) + assert "123.5s" in msg + + def test_large_task_count(self): + """Many tasks should not break formatting.""" + many = [{"id": str(i), "content": "x", "status": "pending"} for i in range(50)] + msg = get_cute_tool_message("todo", {"todos": many}, 0.5) + assert "50 task(s)" in msg + + def test_read_with_no_args_and_no_result(self): + """Completely empty call.""" + msg = get_cute_tool_message("todo", {}, 0.0) + assert "reading tasks" in msg + + +class TestTodoSkinIntegration: + """Verify the skin prefix is applied to todo messages too. + This uses the same pattern as test_skin_engine test_tool_message_uses_skin_prefix. + """ + + def test_default_skin_prefix(self): + msg = get_cute_tool_message("todo", {}, 0.5) + assert msg.startswith("โ”Š") diff --git a/tests/agent/test_display_tool_failure.py b/tests/agent/test_display_tool_failure.py new file mode 100644 index 00000000000..ca56e20f3a1 --- /dev/null +++ b/tests/agent/test_display_tool_failure.py @@ -0,0 +1,185 @@ +"""Tests for _detect_tool_failure + _trim_error + get_cute_tool_message +inline failure suffix rendering. + +Covers the user-visible promise: when a tool fails, the CLI shows a short, +specific reason in square brackets at the end of the completion line โ€” +not a generic "[error]". +""" + +import json +import pytest + +from agent.display import ( + _detect_tool_failure, + _trim_error, + _ERROR_SUFFIX_MAX_LEN, + get_cute_tool_message, +) + + +class TestTrimError: + """The helper that shrinks an error message for inline display.""" + + def test_short_message_unchanged(self): + assert _trim_error("nope") == "nope" + + def test_whitespace_stripped(self): + assert _trim_error(" bad input ") == "bad input" + + def test_long_message_truncated_to_cap(self): + msg = "x" * 200 + trimmed = _trim_error(msg) + assert len(trimmed) <= _ERROR_SUFFIX_MAX_LEN + assert trimmed.endswith("...") + + def test_file_not_found_path_collapsed_to_filename(self): + long_path = "File not found: /home/teknium/.hermes/hermes-agent/very/deep/path/foo.py" + assert _trim_error(long_path) == "File not found: foo.py" + + def test_file_not_found_already_short_unchanged(self): + assert _trim_error("File not found: foo.py") == "File not found: foo.py" + + def test_file_not_found_relative_path_unchanged(self): + # Without a slash there's no path to trim. + assert _trim_error("File not found: foo.py") == "File not found: foo.py" + + +class TestDetectToolFailureTerminal: + """terminal: non-zero exit_code is the canonical failure signal.""" + + def test_success_returns_no_suffix(self): + result = json.dumps({"output": "ok\n", "exit_code": 0}) + assert _detect_tool_failure("terminal", result) == (False, "") + + def test_nonzero_exit_with_no_error_shows_exit_code(self): + result = json.dumps({"output": "", "exit_code": 1}) + is_failure, suffix = _detect_tool_failure("terminal", result) + assert is_failure is True + assert suffix == " [exit 1]" + + def test_nonzero_exit_with_error_shows_message(self): + result = json.dumps({ + "output": "", + "exit_code": 127, + "error": "ls: cannot access 'foo': No such file or directory", + }) + is_failure, suffix = _detect_tool_failure("terminal", result) + assert is_failure is True + assert "cannot access" in suffix + # Trimmed to the cap, in brackets + assert suffix.startswith(" [") + assert suffix.endswith("]") + + def test_malformed_json_returns_no_suffix(self): + # Terminal is special: only exit_code matters. Malformed JSON should + # not crash and should not be flagged as failure. + assert _detect_tool_failure("terminal", "not json") == (False, "") + + def test_none_result_returns_no_suffix(self): + assert _detect_tool_failure("terminal", None) == (False, "") + + +class TestDetectToolFailureMemory: + """memory: 'full' is distinct from real errors.""" + + def test_memory_full_returns_full_suffix(self): + result = json.dumps({"success": False, "error": "would exceed the limit"}) + assert _detect_tool_failure("memory", result) == (True, " [full]") + + def test_memory_other_error_returns_specific_message(self): + # An error that's NOT a "full" overflow falls through to the + # structured-error path and surfaces the actual message. + result = json.dumps({"success": False, "error": "invalid action: zap"}) + is_failure, suffix = _detect_tool_failure("memory", result) + assert is_failure is True + assert "invalid action" in suffix + + +class TestDetectToolFailureStructured: + """Generic path: any tool that returns {"error": ...} JSON.""" + + def test_read_file_error_surfaced(self): + result = json.dumps({ + "path": "/nope/missing.py", + "success": False, + "error": "File not found: /nope/missing.py", + }) + is_failure, suffix = _detect_tool_failure("read_file", result) + assert is_failure is True + # _trim_error reduces the path to the basename. + assert suffix == " [File not found: missing.py]" + + def test_error_without_success_key_still_flagged(self): + # Some tools return {"error": "..."} with no explicit success flag. + result = json.dumps({"error": "remote unavailable"}) + is_failure, suffix = _detect_tool_failure("web_search", result) + assert is_failure is True + assert suffix == " [remote unavailable]" + + def test_message_field_only_with_success_false_flagged(self): + # When success is False and only 'message' is set, surface it. + result = json.dumps({"success": False, "message": "rate limited"}) + is_failure, suffix = _detect_tool_failure("web_search", result) + assert is_failure is True + assert "rate limited" in suffix + + def test_successful_result_not_flagged(self): + result = json.dumps({"success": True, "data": "hello"}) + assert _detect_tool_failure("web_search", result) == (False, "") + + def test_dict_without_error_or_success_uses_generic_heuristic(self): + # Plain successful dict โ€” should pass through the generic + # heuristic which only fires on the string "Error" / '"error"' / etc. + result = json.dumps({"data": "hello"}) + is_failure, _ = _detect_tool_failure("web_search", result) + assert is_failure is False + + +class TestGetCuteToolMessageFailureSuffix: + """End-to-end: failure suffix is appended by get_cute_tool_message.""" + + def test_read_file_failure_suffix_appended(self): + fail = json.dumps({ + "path": "/etc/missing", + "success": False, + "error": "File not found: /etc/missing", + }) + line = get_cute_tool_message("read_file", {"path": "/etc/missing"}, 0.1, result=fail) + assert "[File not found: missing]" in line + + def test_terminal_exit_only_suffix(self): + fail = json.dumps({"output": "", "exit_code": 2}) + line = get_cute_tool_message("terminal", {"command": "false"}, 0.1, result=fail) + assert "[exit 2]" in line + + def test_terminal_with_stderr_uses_message(self): + fail = json.dumps({ + "output": "", + "exit_code": 127, + "error": "command not found: notathing", + }) + line = get_cute_tool_message("terminal", {"command": "notathing"}, 0.1, result=fail) + assert "command not found" in line + # No '[exit 127]' tag when we have a specific message + assert "exit 127" not in line + + def test_memory_full_suffix(self): + fail = json.dumps({"success": False, "error": "would exceed the limit"}) + line = get_cute_tool_message( + "memory", + {"action": "add", "target": "memory", "content": "x"}, + 0.05, + result=fail, + ) + assert "[full]" in line + + def test_success_has_no_suffix(self): + ok = json.dumps({"success": True, "data": "hi"}) + line = get_cute_tool_message("web_search", {"query": "hi"}, 0.2, result=ok) + assert "[" not in line.split("0.2s", 1)[1] + + def test_no_result_has_no_suffix(self): + # No result passed at all โ€” display function should not invent a + # failure suffix. + line = get_cute_tool_message("terminal", {"command": "ls"}, 0.2) + assert "[" not in line.split("0.2s", 1)[1] diff --git a/tests/agent/test_error_classifier.py b/tests/agent/test_error_classifier.py index a6fb56a7075..397d2673552 100644 --- a/tests/agent/test_error_classifier.py +++ b/tests/agent/test_error_classifier.py @@ -56,6 +56,7 @@ class TestFailoverReason: "overloaded", "server_error", "timeout", "context_overflow", "payload_too_large", "image_too_large", "model_not_found", "format_error", + "multimodal_tool_content_unsupported", "provider_policy_blocked", "thinking_signature", "long_context_tier", "oauth_long_context_beta_forbidden", @@ -292,6 +293,64 @@ class TestClassifyApiError: result = classify_api_error(e) assert result.reason == FailoverReason.overloaded + # โ”€โ”€ 5xx that are actually request-validation errors โ”€โ”€ + # Some OpenAI-compatible gateways (e.g. codex.nekos.me) return + # request-validation failures with a 5xx status. These are + # deterministic, so they must NOT be retried โ€” otherwise the retry + # loop hammers the identical bad request into a flood. + + def test_502_with_unknown_parameter_is_non_retryable(self): + e = MockAPIError( + "Unknown parameter: 'input[617]._empty_recovery_synthetic'", + status_code=502, + body={ + "error": { + "type": "invalid_request_error", + "message": ( + "[ObjectParam] [input[617]._empty_recovery_synthetic] " + "[unknown_parameter] Unknown parameter: " + "'input[617]._empty_recovery_synthetic'." + ), + } + }, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + assert result.should_fallback is True + + def test_502_with_unsupported_parameter_is_non_retryable(self): + e = MockAPIError( + "Unsupported parameter: logprobs", + status_code=502, + body={ + "error": { + "type": "invalid_request_error", + "message": "Unsupported parameter: logprobs", + } + }, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + def test_500_with_invalid_request_error_type_is_non_retryable(self): + e = MockAPIError( + "bad request", + status_code=500, + body={"error": {"type": "invalid_request_error", "message": "bad request"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + def test_502_plain_bad_gateway_still_retryable(self): + """A genuine 502 with no request-validation signal stays retryable.""" + e = MockAPIError("Bad Gateway", status_code=502) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + assert result.retryable is True + # โ”€โ”€ Model not found โ”€โ”€ def test_404_model_not_found(self): @@ -1256,3 +1315,66 @@ class TestRateLimitErrorWithoutStatusCode: e.status_code = None result = classify_api_error(e, provider="copilot", model="gpt-4o") assert result.reason != FailoverReason.rate_limit + + + +# โ”€โ”€ Test: multimodal_tool_content_unsupported pattern โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestMultimodalToolContentUnsupported: + """Issue #27344 โ€” providers that reject list-type tool message content + should be classified as ``multimodal_tool_content_unsupported`` so the + retry loop can downgrade screenshots to text and try again. + """ + + def test_xiaomi_mimo_text_is_not_set_pattern(self): + """The actual Xiaomi MiMo 400 wording from the bug report.""" + e = MockAPIError( + "Error code: 400 - {'error': {'code': '400', 'message': 'Param Incorrect', 'param': 'text is not set', 'type': ''}}", + status_code=400, + ) + result = classify_api_error(e, provider="xiaomi", model="mimo-v2.5") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported + assert result.retryable is True + + def test_generic_tool_message_must_be_string(self): + e = MockAPIError( + "tool message content must be a string", + status_code=400, + ) + result = classify_api_error(e, provider="custom", model="some-model") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported + + def test_expected_string_got_list(self): + e = MockAPIError( + "Schema validation failed: expected string, got list", + status_code=400, + ) + result = classify_api_error(e, provider="custom", model="some-model") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported + + def test_multimodal_tool_content_takes_priority_over_context_overflow(self): + """Some providers return a 400 whose message contains BOTH + 'text is not set' and a length-shaped phrase; the tool-content + recovery is cheaper than compression so it must win the priority. + """ + e = MockAPIError( + "text is not set; context length exceeded", + status_code=400, + ) + result = classify_api_error(e, provider="xiaomi", model="mimo-v2.5") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported + + def test_no_status_code_path_also_classifies(self): + """When the error reaches us without a status code (transport + layer ate it) the message-only classifier branch must also + recognise the pattern. + """ + e = MockTransportError("tool_call.content must be string") + result = classify_api_error(e, provider="alibaba", model="qwen3.5-plus") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported + + def test_unrelated_400_is_not_misclassified(self): + """Make sure the patterns don't false-positive on normal 400s.""" + e = MockAPIError("bad request: missing field 'model'", status_code=400) + result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4") + assert result.reason != FailoverReason.multimodal_tool_content_unsupported diff --git a/tests/agent/test_file_safety_credentials.py b/tests/agent/test_file_safety_credentials.py new file mode 100644 index 00000000000..94cf82f2ccd --- /dev/null +++ b/tests/agent/test_file_safety_credentials.py @@ -0,0 +1,275 @@ +"""Tests for HERMES_HOME credential-file read blocking in file_safety. + +Regression for https://github.com/NousResearch/hermes-agent/issues/17656 โ€” +``read_file`` was previously only sandboxed against ``HERMES_HOME`` itself, +which left ``auth.json`` and ``.anthropic_oauth.json`` (plaintext provider +keys + OAuth tokens) readable by the agent. A prompt-injection reaching +``read_file`` could exfiltrate active credentials. + +These tests verify that ``get_read_block_error`` returns a denial message +for the credential stores while leaving arbitrary ``HERMES_HOME`` files +readable, and that the existing ``skills/.hub`` deny still applies. +""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + + +@pytest.fixture() +def fake_home(tmp_path, monkeypatch): + """Point ``_hermes_home_path()`` at a tmp dir for isolated checks.""" + import agent.file_safety as fs + + home = tmp_path / "hermes_home" + home.mkdir() + monkeypatch.setattr(fs, "_hermes_home_path", lambda: home) + return home + + +def _create(home: Path, rel: str | Path) -> Path: + """Create the file (with parents) so realpath() resolves it.""" + p = home / rel + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text("dummy", encoding="utf-8") + return p + + +def test_auth_json_blocked(fake_home): + from agent.file_safety import get_read_block_error + + auth = _create(fake_home, "auth.json") + err = get_read_block_error(str(auth)) + assert err is not None + assert "credential store" in err + assert "auth.json" in err + + +def test_auth_lock_blocked(fake_home): + from agent.file_safety import get_read_block_error + + lock = _create(fake_home, "auth.lock") + err = get_read_block_error(str(lock)) + assert err is not None + assert "credential store" in err + + +def test_anthropic_oauth_json_blocked(fake_home): + from agent.file_safety import get_read_block_error + + oauth = _create(fake_home, ".anthropic_oauth.json") + err = get_read_block_error(str(oauth)) + assert err is not None + assert "credential store" in err + + +def test_arbitrary_hermes_home_file_not_blocked(fake_home): + """Non-credential files inside HERMES_HOME stay readable.""" + from agent.file_safety import get_read_block_error + + safe = _create(fake_home, "session_log.txt") + assert get_read_block_error(str(safe)) is None + + +def test_subdirectory_named_auth_json_not_blocked(fake_home): + """Only the top-level auth.json is the credential store; a file with the + same name in a subdirectory (e.g., a skill mock) must remain readable.""" + from agent.file_safety import get_read_block_error + + nested = _create(fake_home, Path("skills") / "my-skill" / "auth.json") + assert get_read_block_error(str(nested)) is None + + +def test_skills_hub_block_still_applies(fake_home): + """Regression guard: the original skills/.hub deny must keep working.""" + from agent.file_safety import get_read_block_error + + hub_file = _create(fake_home, "skills/.hub/manifest.json") + err = get_read_block_error(str(hub_file)) + assert err is not None + assert "internal Hermes cache file" in err + + +def test_path_traversal_resolves_to_blocked(fake_home, tmp_path): + """A path that traverses through a sibling dir back into HERMES_HOME's + auth.json must still be caught โ€” the check resolves through realpath.""" + from agent.file_safety import get_read_block_error + + _create(fake_home, "auth.json") + sibling = tmp_path / "elsewhere" + sibling.mkdir() + traversal = sibling / ".." / "hermes_home" / "auth.json" + err = get_read_block_error(str(traversal)) + assert err is not None + assert "credential store" in err + + +def test_symlink_to_auth_json_blocked(fake_home, tmp_path): + """A symlink pointing at HERMES_HOME/auth.json from outside the home + must be blocked โ€” readlink-resolution catches the indirection.""" + from agent.file_safety import get_read_block_error + + target = _create(fake_home, "auth.json") + link = tmp_path / "shim.json" + try: + os.symlink(target, link) + except (OSError, NotImplementedError): + pytest.skip("symlinks not supported on this platform/filesystem") + err = get_read_block_error(str(link)) + assert err is not None + assert "credential store" in err + + +def test_read_file_tool_blocks_relative_path_under_terminal_cwd( + fake_home, tmp_path, monkeypatch +): + """Bypass guard: a relative path like ``"auth.json"`` resolved by + ``read_file_tool`` against ``TERMINAL_CWD == HERMES_HOME`` must still + be blocked, even though ``get_read_block_error``'s own ``resolve()`` + is anchored at the (different) Python process cwd. + """ + import json + + import tools.file_tools as ft + + _create(fake_home, "auth.json") + # Force the file_tools resolver to anchor relative paths at HERMES_HOME + # while the Python process cwd remains tmp_path (a different directory). + monkeypatch.setenv("TERMINAL_CWD", str(fake_home)) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr( + ft, "_get_live_tracking_cwd", lambda task_id="default": None + ) + + out = json.loads(ft.read_file_tool("auth.json")) + assert "error" in out + assert "credential store" in out["error"] + + +# --------------------------------------------------------------------------- +# Widening: .env, webhook_subscriptions.json, mcp-tokens/ +# --------------------------------------------------------------------------- + + +def test_dotenv_blocked(fake_home): + """.env in HERMES_HOME holds API keys โ€” blocked.""" + from agent.file_safety import get_read_block_error + + env = _create(fake_home, ".env") + err = get_read_block_error(str(env)) + assert err is not None + assert "credential store" in err + + +def test_webhook_subscriptions_blocked(fake_home): + """webhook_subscriptions.json holds per-route HMAC secrets โ€” blocked.""" + from agent.file_safety import get_read_block_error + + subs = _create(fake_home, "webhook_subscriptions.json") + err = get_read_block_error(str(subs)) + assert err is not None + assert "credential store" in err + + +def test_mcp_tokens_file_blocked(fake_home): + """Files under mcp-tokens/ hold OAuth tokens โ€” blocked.""" + from agent.file_safety import get_read_block_error + + tok = _create(fake_home, Path("mcp-tokens") / "github.json") + err = get_read_block_error(str(tok)) + assert err is not None + assert "MCP token" in err + + +def test_mcp_tokens_nested_blocked(fake_home): + """Nested files inside mcp-tokens/ are also blocked.""" + from agent.file_safety import get_read_block_error + + tok = _create(fake_home, Path("mcp-tokens") / "providers" / "azure.json") + err = get_read_block_error(str(tok)) + assert err is not None + assert "MCP token" in err + + +def test_mcp_tokens_dir_itself_blocked(fake_home): + """The mcp-tokens directory itself is blocked (listing is exfiltrating).""" + from agent.file_safety import get_read_block_error + + tokens_dir = fake_home / "mcp-tokens" + tokens_dir.mkdir(parents=True, exist_ok=True) + err = get_read_block_error(str(tokens_dir)) + assert err is not None + assert "MCP token" in err + + +def test_identically_named_files_outside_hermes_home_not_blocked( + fake_home, tmp_path +): + """A project's ``.env``, ``auth.json``, or ``mcp-tokens/`` outside + HERMES_HOME must remain readable โ€” the gate is per-location, not + per-filename.""" + from agent.file_safety import get_read_block_error + + project = tmp_path / "myproject" + project.mkdir() + for rel in (".env", "auth.json"): + p = project / rel + p.write_text("not secret here", encoding="utf-8") + assert get_read_block_error(str(p)) is None, ( + f"{rel} outside HERMES_HOME should NOT be blocked" + ) + + tokens = project / "mcp-tokens" + tokens.mkdir() + tok_file = tokens / "token.json" + tok_file.write_text("not really a token", encoding="utf-8") + assert get_read_block_error(str(tok_file)) is None + + +def test_config_yaml_not_blocked(fake_home): + """config.yaml is NOT a credential file โ€” agent should still be + able to read it for debugging. (Writes are denied separately by + is_write_denied; reads stay allowed.)""" + from agent.file_safety import get_read_block_error + + cfg = _create(fake_home, "config.yaml") + assert get_read_block_error(str(cfg)) is None + + +def test_profile_mode_blocks_root_credentials(tmp_path, monkeypatch): + """Under a profile, HERMES_HOME = /profiles/, but + /auth.json must ALSO be blocked โ€” credentials at root are + inherited by every profile.""" + import agent.file_safety as fs + + root = tmp_path / "hermes" + profile = root / "profiles" / "coder" + profile.mkdir(parents=True) + monkeypatch.setattr(fs, "_hermes_home_path", lambda: profile) + monkeypatch.setattr(fs, "_hermes_root_path", lambda: root) + + from agent.file_safety import get_read_block_error + + # Profile-local credential store: blocked + profile_auth = profile / "auth.json" + profile_auth.write_text("x") + assert "credential store" in (get_read_block_error(str(profile_auth)) or "") + + # Root-level credential store: ALSO blocked (this is the widening) + root_auth = root / "auth.json" + root_auth.write_text("x") + assert "credential store" in (get_read_block_error(str(root_auth)) or "") + + # Root-level .env: blocked too + root_env = root / ".env" + root_env.write_text("x") + assert "credential store" in (get_read_block_error(str(root_env)) or "") + + # Root-level mcp-tokens: blocked + root_tok = root / "mcp-tokens" / "gh.json" + root_tok.parent.mkdir(parents=True, exist_ok=True) + root_tok.write_text("x") + assert "MCP token" in (get_read_block_error(str(root_tok)) or "") diff --git a/tests/agent/test_file_safety_cross_profile.py b/tests/agent/test_file_safety_cross_profile.py new file mode 100644 index 00000000000..cf3605774a3 --- /dev/null +++ b/tests/agent/test_file_safety_cross_profile.py @@ -0,0 +1,219 @@ +"""Tests for the cross-Hermes-profile write guard in agent/file_safety. + +The guard fires when a tool tries to write into another Hermes profile's +skills/plugins/cron/memories directory. It's a soft guard โ€” defense in +depth, NOT a security boundary โ€” but it prevents the agent from silently +corrupting a profile that belongs to a different session. + +Reference: May 2026 incident โ€” a hermes-security profile session +accidentally edited skills under both ~/.hermes/profiles/hermes-security/skills/ +AND ~/.hermes/skills/ (the default profile's skills), realizing only +afterwards that the second path belonged to a different profile. +""" +from __future__ import annotations + +import os +from pathlib import Path + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers โ€” set up a fake Hermes root with two profiles, monkeypatch the +# resolver helpers so the classifier sees the test layout. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fake_hermes(tmp_path, monkeypatch): + """Build a fake Hermes layout: + + / + skills/foo/SKILL.md # default profile + plugins/foo/__init__.py + cron/ + memories/MEMORY.md + profiles/ + hermes-security/ + skills/foo/SKILL.md # named profile + plugins/... + coder/ + skills/foo/SKILL.md # another named profile + """ + root = tmp_path / "fake-hermes" + (root / "skills" / "foo").mkdir(parents=True) + (root / "skills" / "foo" / "SKILL.md").write_text("# default skill\n") + (root / "plugins" / "foo").mkdir(parents=True) + (root / "memories").mkdir(parents=True) + (root / "cron").mkdir(parents=True) + + sec_home = root / "profiles" / "hermes-security" + (sec_home / "skills" / "foo").mkdir(parents=True) + (sec_home / "skills" / "foo" / "SKILL.md").write_text("# sec skill\n") + (sec_home / "plugins").mkdir(parents=True) + + coder_home = root / "profiles" / "coder" + (coder_home / "skills" / "foo").mkdir(parents=True) + (coder_home / "skills" / "foo" / "SKILL.md").write_text("# coder skill\n") + + # Monkeypatch the resolver functions used by file_safety so each test + # can choose which profile is "active". + import hermes_constants + monkeypatch.setattr(hermes_constants, "get_default_hermes_root", lambda: root) + + # The reloads below ensure get_cross_profile_warning/classify see the patched root. + import agent.file_safety as fs + monkeypatch.setattr(fs, "_hermes_root_path", lambda: root) + + return { + "root": root, + "default_home": root, + "security_home": sec_home, + "coder_home": coder_home, + } + + +def _set_active_home(monkeypatch, hermes_home: Path): + """Point file_safety._hermes_home_path at a specific profile dir.""" + import agent.file_safety as fs + monkeypatch.setattr(fs, "_hermes_home_path", lambda: hermes_home) + + +# --------------------------------------------------------------------------- +# _resolve_active_profile_name +# --------------------------------------------------------------------------- + + +class TestResolveActiveProfileName: + def test_default_when_home_is_root(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["default_home"]) + from agent.file_safety import _resolve_active_profile_name + assert _resolve_active_profile_name() == "default" + + def test_named_profile(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import _resolve_active_profile_name + assert _resolve_active_profile_name() == "hermes-security" + + def test_falls_back_to_default_on_resolution_failure(self, fake_hermes, monkeypatch): + """If HERMES_HOME resolution raises, return 'default' rather than crashing the tool.""" + import agent.file_safety as fs + + def _boom(): + raise RuntimeError("simulated") + + monkeypatch.setattr(fs, "_hermes_home_path", _boom) + # Should not raise โ€” falls back to "default" + assert fs._resolve_active_profile_name() == "default" + + +# --------------------------------------------------------------------------- +# classify_cross_profile_target +# --------------------------------------------------------------------------- + + +class TestClassifyCrossProfileTarget: + def test_same_profile_write_returns_none(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import classify_cross_profile_target + result = classify_cross_profile_target( + str(fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md") + ) + assert result is None + + def test_security_writing_default_skill(self, fake_hermes, monkeypatch): + """The exact incident from May 2026.""" + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import classify_cross_profile_target + result = classify_cross_profile_target( + str(fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md") + ) + assert result is not None + assert result["active_profile"] == "hermes-security" + assert result["target_profile"] == "default" + assert result["area"] == "skills" + + def test_default_writing_security_skill(self, fake_hermes, monkeypatch): + """Inverse direction โ€” default-profile session reaching into a named profile.""" + _set_active_home(monkeypatch, fake_hermes["default_home"]) + from agent.file_safety import classify_cross_profile_target + result = classify_cross_profile_target( + str(fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md") + ) + assert result is not None + assert result["active_profile"] == "default" + assert result["target_profile"] == "hermes-security" + + def test_named_to_named_cross_profile(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import classify_cross_profile_target + result = classify_cross_profile_target( + str(fake_hermes["coder_home"] / "skills" / "foo" / "SKILL.md") + ) + assert result is not None + assert result["target_profile"] == "coder" + + @pytest.mark.parametrize("area", ["skills", "plugins", "cron", "memories"]) + def test_all_profile_scoped_areas_classified(self, fake_hermes, monkeypatch, area): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import classify_cross_profile_target + target = fake_hermes["default_home"] / area / "foo.txt" + result = classify_cross_profile_target(str(target)) + assert result is not None + assert result["area"] == area + + def test_non_hermes_path_returns_none(self, fake_hermes, monkeypatch, tmp_path): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import classify_cross_profile_target + # Path outside any Hermes root + assert classify_cross_profile_target(str(tmp_path / "random.txt")) is None + + def test_hermes_config_not_classified_as_cross_profile(self, fake_hermes, monkeypatch): + """Files under /config.yaml or /.env are NOT profile-scoped + (already covered by build_write_denied_paths). Don't double-warn.""" + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import classify_cross_profile_target + # config.yaml at root level is not in PROFILE_SCOPED_AREAS + result = classify_cross_profile_target( + str(fake_hermes["default_home"] / "config.yaml") + ) + assert result is None + + +# --------------------------------------------------------------------------- +# get_cross_profile_warning +# --------------------------------------------------------------------------- + + +class TestGetCrossProfileWarning: + def test_in_profile_returns_none(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import get_cross_profile_warning + assert get_cross_profile_warning( + str(fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md") + ) is None + + def test_cross_profile_warning_names_both_profiles(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import get_cross_profile_warning + warn = get_cross_profile_warning( + str(fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md") + ) + assert warn is not None + # Must name BOTH profiles so the model knows which is which. + assert "default" in warn + assert "hermes-security" in warn + # Must name the bypass kwarg. + assert "cross_profile=True" in warn + # Must reference the area. + assert "skills" in warn + + def test_warning_is_defense_in_depth_not_boundary(self, fake_hermes, monkeypatch): + _set_active_home(monkeypatch, fake_hermes["security_home"]) + from agent.file_safety import get_cross_profile_warning + warn = get_cross_profile_warning( + str(fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md") + ) + # Must self-document as defense-in-depth so future reviewers + # don't promote it to a hard block. + assert "not a security boundary" in warn.lower() diff --git a/tests/agent/test_last_total_tokens.py b/tests/agent/test_last_total_tokens.py new file mode 100644 index 00000000000..ed4735ae253 --- /dev/null +++ b/tests/agent/test_last_total_tokens.py @@ -0,0 +1,22 @@ +"""Test that last_total_tokens is correctly set by ContextCompressor.""" + +from agent.context_compressor import ContextCompressor + + +def test_update_from_response_sets_total_tokens(): + """ABC contract: last_total_tokens must be set from API response.""" + c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000) + + c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130}) + assert c.last_total_tokens == 130 + + c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30}) + assert c.last_total_tokens == 130 + + +def test_session_reset_clears_total_tokens(): + """on_session_reset must zero total_tokens.""" + c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000) + c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130}) + c.on_session_reset() + assert c.last_total_tokens == 0 diff --git a/tests/agent/test_memory_provider.py b/tests/agent/test_memory_provider.py index ca39da70f08..6f8cfc8a93d 100644 --- a/tests/agent/test_memory_provider.py +++ b/tests/agent/test_memory_provider.py @@ -1060,3 +1060,191 @@ class TestHonchoCadenceTracking: p.on_turn_start(2, "second message") should_skip = p._injection_frequency == "first-turn" and p._turn_count > 1 assert should_skip, "Second turn (turn 2) SHOULD be skipped" + + +class TestMemoryToolToolsetGate: + """Issue #5544: memory provider tools must respect platform_toolsets. + + Before the fix, MemoryManager.get_all_tool_schemas() output was appended + to AIAgent.tools unconditionally in agent_init.py โ€” bypassing the + enabled_toolsets filter. Result: `platform_toolsets: telegram: []` + still leaked fact_store and other memory tools into the tool surface, + causing 10x latency on local models (Qwen3-30B: 1.7s โ†’ 42s) and + tool-call loops on small models. + + These tests mirror the gate logic in agent/agent_init.py around the + memory provider tool injection block. The gate condition is: + + enabled_toolsets is None โ†’ no filter, inject (backward compat) + "memory" in enabled_toolsets โ†’ user opted in, inject + otherwise (incl. []) โ†’ skip injection + """ + + @staticmethod + def _run_memory_injection(enabled_toolsets, memory_manager): + """Simulate the gated memory-tool injection block from agent_init.py.""" + tools = [] + valid_tool_names = set() + + if memory_manager and tools is not None and ( + enabled_toolsets is None or "memory" in enabled_toolsets + ): + _existing = { + t.get("function", {}).get("name") + for t in tools + if isinstance(t, dict) + } + for _schema in memory_manager.get_all_tool_schemas(): + _tname = _schema.get("name", "") + if _tname and _tname in _existing: + continue + tools.append({"type": "function", "function": _schema}) + if _tname: + valid_tool_names.add(_tname) + _existing.add(_tname) + + return tools, valid_tool_names + + def _mgr_with_tools(self, *tool_names): + """Build a MemoryManager whose providers expose the named tool schemas.""" + mgr = MemoryManager() + p = FakeMemoryProvider( + "ext", + tools=[{"name": n, "description": n, "parameters": {}} for n in tool_names], + ) + mgr.add_provider(p) + return mgr + + def test_none_toolsets_injects(self): + """enabled_toolsets=None (no filter) injects memory tools โ€” backward compat.""" + mgr = self._mgr_with_tools("fact_store") + tools, names = self._run_memory_injection(None, mgr) + assert "fact_store" in names + assert any(t["function"]["name"] == "fact_store" for t in tools) + + def test_memory_in_toolsets_injects(self): + """enabled_toolsets including 'memory' injects memory tools.""" + mgr = self._mgr_with_tools("fact_store") + tools, names = self._run_memory_injection(["terminal", "memory", "web"], mgr) + assert "fact_store" in names + + def test_empty_toolsets_blocks_injection(self): + """`platform_toolsets: telegram: []` must suppress memory tools. (#5544)""" + mgr = self._mgr_with_tools("fact_store") + tools, names = self._run_memory_injection([], mgr) + assert tools == [] + assert names == set() + + def test_toolsets_without_memory_blocks_injection(self): + """Toolset list that doesn't name 'memory' must suppress injection.""" + mgr = self._mgr_with_tools("fact_store") + tools, names = self._run_memory_injection(["terminal", "web"], mgr) + assert tools == [] + assert names == set() + + def test_no_memory_manager_no_injection(self): + """Gate is moot without a memory manager.""" + tools, names = self._run_memory_injection(None, None) + assert tools == [] + + def test_multiple_schemas_all_blocked_together(self): + """When the gate is closed, no memory tools leak โ€” not even partially.""" + mgr = self._mgr_with_tools("fact_store", "memory_search", "memory_add") + tools, names = self._run_memory_injection(["terminal"], mgr) + assert tools == [] + assert names == set() + + def test_multiple_schemas_all_injected_when_enabled(self): + """When the gate is open, every memory tool schema is injected.""" + mgr = self._mgr_with_tools("fact_store", "memory_search", "memory_add") + tools, names = self._run_memory_injection(None, mgr) + assert names == {"fact_store", "memory_search", "memory_add"} + + +class TestContextEngineToolsetGate: + """Issue #5544 (sibling): context engine tools follow the same gate. + + `agent.context_compressor.get_tool_schemas()` (e.g. lcm_grep, lcm_describe, + lcm_expand) was appended to AIAgent.tools unconditionally. Same blind + injection class as the memory bug; same local-model penalty. Gate name: + "context_engine" (matches the existing plugin-system convention). + """ + + @staticmethod + def _run_context_engine_injection(enabled_toolsets, compressor): + """Simulate the gated context-engine injection block from agent_init.py.""" + tools = [] + valid_tool_names = set() + engine_tool_names = set() + + if ( + compressor is not None + and tools is not None + and ( + enabled_toolsets is None + or "context_engine" in enabled_toolsets + ) + ): + _existing = { + t.get("function", {}).get("name") + for t in tools + if isinstance(t, dict) + } + for _schema in compressor.get_tool_schemas(): + _tname = _schema.get("name", "") + if _tname and _tname in _existing: + continue + tools.append({"type": "function", "function": _schema}) + if _tname: + valid_tool_names.add(_tname) + engine_tool_names.add(_tname) + _existing.add(_tname) + + return tools, valid_tool_names, engine_tool_names + + class _FakeCompressor: + def __init__(self, schemas): + self._schemas = schemas + + def get_tool_schemas(self): + return list(self._schemas) + + def _compressor_with(self, *tool_names): + return self._FakeCompressor( + [{"name": n, "description": n, "parameters": {}} for n in tool_names] + ) + + def test_none_toolsets_injects(self): + """enabled_toolsets=None injects context-engine tools โ€” backward compat.""" + c = self._compressor_with("lcm_grep", "lcm_describe", "lcm_expand") + tools, names, engine_names = self._run_context_engine_injection(None, c) + assert engine_names == {"lcm_grep", "lcm_describe", "lcm_expand"} + + def test_context_engine_in_toolsets_injects(self): + """enabled_toolsets including 'context_engine' injects the tools.""" + c = self._compressor_with("lcm_grep") + tools, names, engine_names = self._run_context_engine_injection( + ["terminal", "context_engine"], c + ) + assert "lcm_grep" in engine_names + + def test_empty_toolsets_blocks_injection(self): + """`platform_toolsets: telegram: []` must suppress context-engine tools.""" + c = self._compressor_with("lcm_grep") + tools, names, engine_names = self._run_context_engine_injection([], c) + assert tools == [] + assert engine_names == set() + + def test_toolsets_without_context_engine_blocks_injection(self): + """A toolset list that doesn't name 'context_engine' suppresses injection.""" + c = self._compressor_with("lcm_grep", "lcm_describe") + tools, names, engine_names = self._run_context_engine_injection( + ["terminal", "memory"], c + ) + assert tools == [] + assert engine_names == set() + + def test_no_compressor_no_injection(self): + """Gate is moot without a context_compressor.""" + tools, names, engine_names = self._run_context_engine_injection(None, None) + assert tools == [] diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index 4f2b51293a6..e905c3e1f6b 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -164,6 +164,7 @@ class TestDefaultContextLengths: "grok-4-1-fast": 2000000, "grok-4-fast": 2000000, "grok-4": 256000, + "grok-build": 256000, "grok-code-fast": 256000, "grok-3": 131072, "grok-2": 131072, @@ -195,6 +196,7 @@ class TestDefaultContextLengths: ("grok-4-fast-non-reasoning", 2000000), ("grok-4", 256000), ("grok-4-0709", 256000), + ("grok-build-0.1", 256000), ("grok-code-fast-1", 256000), ("grok-3", 131072), ("grok-3-mini", 131072), @@ -210,6 +212,32 @@ class TestDefaultContextLengths: f"{model_id}: expected {expected_ctx}, got {actual}" ) + def test_xai_oauth_grok_build_uses_xai_models_dev_context(self): + """xAI OAuth should share the xAI provider metadata path. + + The xAI /v1/models endpoint does not currently include context fields + for grok-build-0.1, so this guards against falling through to the + generic "grok" 131k fallback when using OAuth credentials. + """ + registry = { + "xai": { + "models": { + "grok-build-0.1": { + "limit": {"context": 256000, "output": 64000}, + }, + }, + }, + } + with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ + patch("agent.model_metadata._query_ollama_api_show", return_value=None), \ + patch("agent.models_dev.fetch_models_dev", return_value=registry): + assert get_model_context_length( + "grok-build-0.1", + provider="xai-oauth", + base_url="https://api.x.ai/v1", + api_key="oauth-token", + ) == 256000 + def test_deepseek_v4_models_1m_context(self): from agent.model_metadata import get_model_context_length from unittest.mock import patch as mock_patch diff --git a/tests/agent/test_models_dev.py b/tests/agent/test_models_dev.py index 2cb9746b223..e3338091b9f 100644 --- a/tests/agent/test_models_dev.py +++ b/tests/agent/test_models_dev.py @@ -41,6 +41,16 @@ SAMPLE_REGISTRY = { }, }, }, + "xai": { + "id": "xai", + "name": "xAI", + "models": { + "grok-build-0.1": { + "id": "grok-build-0.1", + "limit": {"context": 256000, "output": 64000}, + }, + }, + }, "kilo": { "id": "kilo", "name": "Kilo Gateway", @@ -86,6 +96,10 @@ class TestProviderMapping: assert PROVIDER_TO_MODELS_DEV["kilocode"] == "kilo" assert PROVIDER_TO_MODELS_DEV["ai-gateway"] == "vercel" + def test_xai_oauth_uses_xai_catalog(self): + assert PROVIDER_TO_MODELS_DEV["xai"] == "xai" + assert PROVIDER_TO_MODELS_DEV["xai-oauth"] == "xai" + def test_unmapped_provider_not_in_dict(self): assert "nous" not in PROVIDER_TO_MODELS_DEV @@ -144,6 +158,12 @@ class TestLookupModelsDevContext: # GitHub Copilot: only 128K for same model assert lookup_models_dev_context("copilot", "claude-opus-4.6") == 128000 + @patch("agent.models_dev.fetch_models_dev") + def test_xai_oauth_resolves_xai_context(self, mock_fetch): + """xAI OAuth is an auth path, not a separate model catalog.""" + mock_fetch.return_value = SAMPLE_REGISTRY + assert lookup_models_dev_context("xai-oauth", "grok-build-0.1") == 256000 + @patch("agent.models_dev.fetch_models_dev") def test_zero_context_filtered(self, mock_fetch): mock_fetch.return_value = SAMPLE_REGISTRY diff --git a/tests/agent/test_redact.py b/tests/agent/test_redact.py index 928eb1ff357..ea79ea9ce39 100644 --- a/tests/agent/test_redact.py +++ b/tests/agent/test_redact.py @@ -451,6 +451,28 @@ class TestUrlQueryParamRedaction: result = redact_sensitive_text(text) assert "opaqueWsToken123" not in result + def test_http_access_log_relative_request_target_query(self): + text = ( + 'INFO aiohttp.access: 127.0.0.1 "POST ' + '/bluebubbles-webhook?password=webhookSecret123&event=new-message ' + 'HTTP/1.1" 200 173 "-" "test-client"' + ) + result = redact_sensitive_text(text) + assert "webhookSecret123" not in result + assert "password=***" in result + assert "event=new-message" in result + + def test_http_access_log_absolute_request_target_query(self): + text = ( + 'INFO aiohttp.access: 127.0.0.1 "GET ' + 'https://example.com/callback?code=oauthCode123&state=csrf-ok ' + 'HTTP/1.1" 200 173 "-" "test-client"' + ) + result = redact_sensitive_text(text) + assert "oauthCode123" not in result + assert "code=***" in result + assert "state=csrf-ok" in result + class TestUrlUserinfoRedaction: """URL userinfo (`scheme://user:pass@host`) for non-DB schemes.""" diff --git a/tests/agent/test_skill_utils.py b/tests/agent/test_skill_utils.py index ae22dc569be..1338e7a5b24 100644 --- a/tests/agent/test_skill_utils.py +++ b/tests/agent/test_skill_utils.py @@ -1,6 +1,12 @@ """Tests for agent/skill_utils.py.""" -from agent.skill_utils import extract_skill_conditions, iter_skill_index_files +from unittest.mock import patch + +from agent.skill_utils import ( + extract_skill_conditions, + iter_skill_index_files, + skill_matches_platform, +) def test_metadata_as_dict_with_hermes(): @@ -94,3 +100,100 @@ def test_iter_skill_index_files_prunes_dependency_dirs(tmp_path): found = list(iter_skill_index_files(tmp_path, "SKILL.md")) assert found == [real / "SKILL.md"] + + +# โ”€โ”€ skill_matches_platform on Termux โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestSkillMatchesPlatformTermux: + """Termux is Linux userland on Android. Skills tagged platforms:[linux] + must load there regardless of whether Python reports sys.platform as + "linux" (pre-3.13) or "android" (3.13+). Reported by user @LikiusInik + in May 2026 โ€” only 3 built-in skills appeared on Termux because every + github/productivity/mlops skill is tagged platforms:[linux,macos,windows] + and sys.platform=="android" did not start with "linux". + """ + + def test_no_platforms_field_matches_everywhere(self): + # Backward-compat default โ€” skills without a platforms tag load + # on any OS, Termux included. + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform({}) is True + assert skill_matches_platform({"name": "foo"}) is True + + def test_linux_skill_loads_on_termux_android_platform(self): + # Python 3.13+ on Termux reports sys.platform == "android". + fm = {"platforms": ["linux"]} + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform(fm) is True + + def test_linux_macos_windows_skill_loads_on_termux(self): + # The common "[linux, macos, windows]" tag used by github-*, + # productivity, mlops, etc. + fm = {"platforms": ["linux", "macos", "windows"]} + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform(fm) is True + + def test_linux_skill_loads_on_termux_linux_platform(self): + # Pre-3.13 Termux reports sys.platform == "linux" already โ€” this + # works without the Termux escape hatch but must still pass. + fm = {"platforms": ["linux"]} + with patch("agent.skill_utils.sys.platform", "linux"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform(fm) is True + + def test_macos_only_skill_still_excluded_on_termux(self): + # macOS-only skills (apple-notes, imessage, ...) should NOT load + # on Termux. The Termux fallback only widens platforms:[linux,...]. + fm = {"platforms": ["macos"]} + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform(fm) is False + + def test_windows_only_skill_still_excluded_on_termux(self): + fm = {"platforms": ["windows"]} + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform(fm) is False + + def test_explicit_termux_or_android_tag_matches(self): + # Skills can also opt in explicitly via platforms:[termux] or + # platforms:[android] โ€” both should match a Termux session. + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=True + ): + assert skill_matches_platform({"platforms": ["termux"]}) is True + assert skill_matches_platform({"platforms": ["android"]}) is True + + def test_non_termux_android_does_not_widen(self): + # If we're somehow on a plain Android Python (not Termux), don't + # silently load Linux skills โ€” Termux is the supported environment. + fm = {"platforms": ["linux"]} + with patch("agent.skill_utils.sys.platform", "android"), patch( + "agent.skill_utils.is_termux", return_value=False + ): + assert skill_matches_platform(fm) is False + + def test_linux_skill_on_real_linux_unaffected(self): + # The non-Termux Linux path must not change. + fm = {"platforms": ["linux"]} + with patch("agent.skill_utils.sys.platform", "linux"), patch( + "agent.skill_utils.is_termux", return_value=False + ): + assert skill_matches_platform(fm) is True + + def test_macos_skill_on_real_macos_unaffected(self): + fm = {"platforms": ["macos"]} + with patch("agent.skill_utils.sys.platform", "darwin"), patch( + "agent.skill_utils.is_termux", return_value=False + ): + assert skill_matches_platform(fm) is True diff --git a/tests/agent/test_vision_routing_31179.py b/tests/agent/test_vision_routing_31179.py new file mode 100644 index 00000000000..268cd27aa96 --- /dev/null +++ b/tests/agent/test_vision_routing_31179.py @@ -0,0 +1,297 @@ +"""Regression tests for issue #31179. + +Before the fix: + - ``auxiliary.vision.provider: openai`` silently failed to resolve because + ``openai`` is not a first-class provider in PROVIDER_REGISTRY (only + ``openai-codex`` for OAuth and ``custom`` for OPENAI_BASE_URL). + - The vision branch of ``call_llm`` then silently fell back to ``auto`` + which happily picked the user's main provider (e.g. DeepSeek), sending + image content to a text-only endpoint and producing cryptic + ``unknown variant 'image_url', expected 'text'`` errors. + - ``check_vision_requirements`` used the explicit-only path, so + ``vision_analyze`` disappeared from the tool list while ``browser_vision`` + stayed (its check_fn only validated the browser). + +The three fixes covered here: + 1. ``provider: openai`` in auxiliary task config resolves to + ``custom`` + ``https://api.openai.com/v1``. + 2. The vision auto-detect chain skips the user's main provider when it + reports ``supports_vision=False`` instead of routing image content to + a text-only endpoint. + 3. ``check_vision_requirements`` mirrors the runtime fallback chain so + ``vision_analyze`` shows up whenever the auto chain can serve vision, + and ``browser_vision`` gates on vision availability as well. +""" + +from __future__ import annotations + +import os +import shutil +import sys +import tempfile + +import pytest + + +# --------------------------------------------------------------------------- +# Test infrastructure +# --------------------------------------------------------------------------- + + +@pytest.fixture +def isolated_home(monkeypatch): + """Temp HERMES_HOME with config + clean credential env vars.""" + test_home = tempfile.mkdtemp(prefix="hermes_test_31179_") + hermes_home = os.path.join(test_home, ".hermes") + os.makedirs(hermes_home) + monkeypatch.setenv("HERMES_HOME", hermes_home) + + # Strip all credential-shaped env vars so each scenario starts hermetic. + for k in list(os.environ.keys()): + if k.endswith("_API_KEY") or k.endswith("_TOKEN"): + monkeypatch.delenv(k, raising=False) + + yield hermes_home + shutil.rmtree(test_home, ignore_errors=True) + + +def _write_config(home: str, text: str) -> None: + with open(os.path.join(home, "config.yaml"), "w") as fp: + fp.write(text) + + +def _fresh_modules(): + """Drop cached hermes modules so each test reloads against current env.""" + for mod in list(sys.modules.keys()): + if mod.startswith(("agent.auxiliary_client", "agent.image_routing", + "tools.vision_tools", "tools.browser_tool", + "hermes_cli.config")): + del sys.modules[mod] + + +# --------------------------------------------------------------------------- +# Fix 1: provider=openai โ†’ custom + api.openai.com/v1 +# --------------------------------------------------------------------------- + + +class TestOpenAiAliasForAuxiliary: + """``auxiliary..provider: openai`` should produce a working client.""" + + def test_provider_openai_routes_to_openai_dot_com(self, isolated_home, monkeypatch): + _write_config(isolated_home, """ +auxiliary: + vision: + provider: openai + model: gpt-4o-mini +""") + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + _fresh_modules() + + from agent.auxiliary_client import _resolve_task_provider_model + provider, model, base_url, _key, _mode = _resolve_task_provider_model("vision") + assert provider == "custom" + assert model == "gpt-4o-mini" + assert base_url == "https://api.openai.com/v1" + + def test_provider_openai_with_explicit_base_url_preserves_user_endpoint( + self, isolated_home, monkeypatch + ): + """User-supplied base_url wins; alias still normalizes provider name + to ``custom`` so resolution doesn't hit the unknown-provider path.""" + _write_config(isolated_home, """ +auxiliary: + vision: + provider: openai + model: gpt-4o-mini + base_url: https://my-proxy.example.com/v1 +""") + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + _fresh_modules() + + from agent.auxiliary_client import _resolve_task_provider_model + provider, _model, base_url, _key, _mode = _resolve_task_provider_model("vision") + assert provider == "custom" + assert base_url == "https://my-proxy.example.com/v1" + + def test_provider_openai_resolves_to_working_client(self, isolated_home, monkeypatch): + """End-to-end: the resolved client points at api.openai.com.""" + _write_config(isolated_home, """ +auxiliary: + vision: + provider: openai + model: gpt-4o-mini +""") + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + _fresh_modules() + + from agent.auxiliary_client import resolve_vision_provider_client + from urllib.parse import urlparse + provider, client, model = resolve_vision_provider_client() + assert client is not None, "openai alias should produce a usable client" + # Exact hostname comparison (not substring) โ€” defends against URLs + # like ``api.openai.com.evil.example`` and keeps CodeQL happy. + host = urlparse(str(getattr(client, "base_url", ""))).hostname or "" + assert host == "api.openai.com", f"expected api.openai.com host, got {host!r}" + assert model == "gpt-4o-mini" + + +# --------------------------------------------------------------------------- +# Fix 2: auto chain skips text-only main providers +# --------------------------------------------------------------------------- + + +class TestTextOnlyMainSkippedForVision: + """Vision auto-detect must not return a text-only main-provider client.""" + + def test_text_only_main_skipped_when_no_aggregator(self, isolated_home, monkeypatch): + """DeepSeek main + no aggregator credentials โ†’ no client built. + + Pre-fix this silently returned the deepseek client with model + substitution, producing ``unknown variant 'image_url'`` at call time. + """ + _write_config(isolated_home, """ +model: + provider: deepseek + default: deepseek-v4-pro +""") + monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test") + _fresh_modules() + + from agent.auxiliary_client import resolve_vision_provider_client + provider, client, _model = resolve_vision_provider_client(provider="auto") + assert client is None, ( + f"Vision auto-detect must skip text-only main {provider!r} when " + "no vision-capable aggregator is available, not return a client " + "that will fail at API time" + ) + + def test_vision_capable_main_used(self, isolated_home, monkeypatch): + """Vision-capable main provider should be returned by auto chain.""" + _write_config(isolated_home, """ +model: + provider: anthropic + default: claude-sonnet-4-6 +""") + monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test") + _fresh_modules() + + from agent.auxiliary_client import resolve_vision_provider_client + provider, client, _model = resolve_vision_provider_client(provider="auto") + assert client is not None + assert provider == "anthropic" + + def test_unknown_capability_does_not_block(self, isolated_home, monkeypatch): + """When models.dev has no entry, fall back to permissive (attempt the call). + + This keeps new/custom providers working โ€” only providers we have + cataloged as text-only are skipped. + """ + _fresh_modules() + from agent.auxiliary_client import _main_model_supports_vision + # Bogus provider/model โ€” capability lookup returns None โ†’ permissive. + assert _main_model_supports_vision("nonexistent-provider", "nonexistent-model") is True + + +# --------------------------------------------------------------------------- +# Fix 3: check_vision_requirements + check_browser_vision_requirements parity +# --------------------------------------------------------------------------- + + +class TestVisionToolGating: + """Tool visibility must match runtime capability.""" + + def test_check_vision_succeeds_for_aliased_openai(self, isolated_home, monkeypatch): + """The user's exact reported scenario: provider=openai unhides + vision_analyze instead of silently dropping it.""" + _write_config(isolated_home, """ +auxiliary: + vision: + provider: openai + model: gpt-4o-mini +""") + monkeypatch.setenv("OPENAI_API_KEY", "sk-test") + _fresh_modules() + + from tools.vision_tools import check_vision_requirements + assert check_vision_requirements() is True + + def test_check_vision_falls_back_to_auto(self, isolated_home, monkeypatch): + """Bad explicit provider doesn't hide the tool when auto fallback works. + + Mirrors call_llm's runtime fallback chain. + """ + _write_config(isolated_home, """ +model: + provider: openrouter + default: anthropic/claude-sonnet-4 +auxiliary: + vision: + provider: not-a-real-provider +""") + monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test") + _fresh_modules() + + from tools.vision_tools import check_vision_requirements + assert check_vision_requirements() is True + + def test_check_vision_false_with_text_only_main_and_no_aggregator( + self, isolated_home, monkeypatch + ): + _write_config(isolated_home, """ +model: + provider: deepseek + default: deepseek-v4-pro +""") + monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test") + _fresh_modules() + + from tools.vision_tools import check_vision_requirements + assert check_vision_requirements() is False + + def test_browser_vision_requires_both_browser_and_vision(self, isolated_home, monkeypatch): + """``browser_vision`` must not be advertised when vision is unavailable.""" + from unittest.mock import patch + + _write_config(isolated_home, """ +model: + provider: deepseek + default: deepseek-v4-pro +""") + monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test") + _fresh_modules() + + import tools.browser_tool + # Force the browser side to True so we exercise the vision-gating part. + with patch.object(tools.browser_tool, "check_browser_requirements", return_value=True): + assert tools.browser_tool.check_browser_vision_requirements() is False + + def test_browser_vision_false_when_browser_missing(self, isolated_home, monkeypatch): + from unittest.mock import patch + + _write_config(isolated_home, """ +model: + provider: openrouter + default: anthropic/claude-sonnet-4 +""") + monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test") + _fresh_modules() + + import tools.browser_tool + with patch.object(tools.browser_tool, "check_browser_requirements", return_value=False): + # Vision available but browser missing โ†’ still False. + assert tools.browser_tool.check_browser_vision_requirements() is False + + def test_browser_vision_true_when_both_available(self, isolated_home, monkeypatch): + from unittest.mock import patch + + _write_config(isolated_home, """ +model: + provider: openrouter + default: anthropic/claude-sonnet-4 +""") + monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test") + _fresh_modules() + + import tools.browser_tool + with patch.object(tools.browser_tool, "check_browser_requirements", return_value=True): + assert tools.browser_tool.check_browser_vision_requirements() is True diff --git a/tests/agent/transports/test_chat_completions.py b/tests/agent/transports/test_chat_completions.py index 2e7b9da2f8d..9f3a205f8a8 100644 --- a/tests/agent/transports/test_chat_completions.py +++ b/tests/agent/transports/test_chat_completions.py @@ -66,6 +66,38 @@ class TestChatCompletionsBasic: # Original list untouched (deepcopy-on-demand) assert msgs[2]["tool_name"] == "execute_code" + def test_convert_messages_strips_internal_scaffolding_markers(self, transport): + """Hermes-internal ``_``-prefixed markers must never reach the wire. + + The empty-response recovery path appends synthetic messages tagged + with ``_empty_recovery_synthetic``; permissive providers ignore the + unknown key, but strict gateways (opencode-go, codex.nekos.me) + reject the request, poisoning every later turn in the session. + """ + msgs = [ + {"role": "user", "content": "run the task"}, + {"role": "assistant", "content": "(empty)", "_empty_recovery_synthetic": True}, + {"role": "user", "content": "continue", "_empty_recovery_synthetic": True}, + {"role": "assistant", "content": "done", "_thinking_prefill": True, + "_empty_terminal_sentinel": True}, + ] + result = transport.convert_messages(msgs) + for m in result: + assert not any(k.startswith("_") for k in m), m + # Visible content preserved + assert result[1]["content"] == "(empty)" + assert result[2]["content"] == "continue" + # Original list untouched (deepcopy-on-demand) + assert msgs[1]["_empty_recovery_synthetic"] is True + + def test_convert_messages_clean_list_is_identity(self, transport): + """A list with no internal/codex keys is returned as-is (no copy).""" + msgs = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + assert transport.convert_messages(msgs) is msgs + class TestChatCompletionsBuildKwargs: diff --git a/tests/agent/transports/test_codex_app_server_session.py b/tests/agent/transports/test_codex_app_server_session.py index b192d64e1c8..d43a92a1eb9 100644 --- a/tests/agent/transports/test_codex_app_server_session.py +++ b/tests/agent/transports/test_codex_app_server_session.py @@ -20,6 +20,7 @@ from agent.transports.codex_app_server_session import ( TurnResult, _ServerRequestRouting, _approval_choice_to_codex_decision, + _coerce_turn_input_text, ) @@ -128,6 +129,15 @@ class TestApprovalChoiceMapping: assert _approval_choice_to_codex_decision(choice) == expected +class TestTurnInputCoercion: + def test_list_content_keeps_text_and_marks_images(self): + text = _coerce_turn_input_text([ + {"type": "text", "text": "caption"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ]) + assert text == "caption\n\n[image attached]" + + # ---- lifecycle ---- class TestLifecycle: @@ -188,6 +198,35 @@ class TestRunTurn: # turn_id propagated for downstream session-DB linkage assert r.turn_id == "turn-fake-001" + def test_rich_content_turn_is_collapsed_to_text_payload(self): + client = FakeClient() + client.queue_notification( + "turn/completed", + threadId="t", + turn={"id": "tu1", "status": "completed", "error": None}, + ) + s = make_session(client) + r = s.run_turn( + [ + { + "type": "text", + "text": "look at this\n\n[Image attached at: /tmp/a.png]", + }, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,abc"}, + }, + ], + turn_timeout=2.0, + ) + assert r.error is None + method, params = next(req for req in client.requests if req[0] == "turn/start") + assert method == "turn/start" + text = params["input"][0]["text"] + assert isinstance(text, str) + assert "[Image attached at: /tmp/a.png]" in text + assert "[image attached]" in text + def test_tool_iteration_counter_ticks(self): client = FakeClient() # Two completed exec items + one final agent message diff --git a/tests/cli/test_branch_command.py b/tests/cli/test_branch_command.py index 409ab295fc0..cf48384403f 100644 --- a/tests/cli/test_branch_command.py +++ b/tests/cli/test_branch_command.py @@ -168,6 +168,25 @@ class TestBranchCommandCLI: assert cli_instance._resumed is True + def test_branch_rotates_hermes_session_id_env_and_context(self, cli_instance, session_db): + """Branching must update process-local session-id readers too.""" + from cli import HermesCLI + from gateway.session_context import _UNSET, _VAR_MAP, get_session_env + + old_session_id = cli_instance.session_id + os.environ["HERMES_SESSION_ID"] = old_session_id + _VAR_MAP["HERMES_SESSION_ID"].set(old_session_id) + + try: + HermesCLI._handle_branch_command(cli_instance, "/branch") + + assert cli_instance.session_id != old_session_id + assert os.environ["HERMES_SESSION_ID"] == cli_instance.session_id + assert get_session_env("HERMES_SESSION_ID") == cli_instance.session_id + finally: + os.environ.pop("HERMES_SESSION_ID", None) + _VAR_MAP["HERMES_SESSION_ID"].set(_UNSET) + def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db): """The /branch command must notify memory providers of the rotation. diff --git a/tests/cli/test_cli_init.py b/tests/cli/test_cli_init.py index b05df5220c5..5849b5b490f 100644 --- a/tests/cli/test_cli_init.py +++ b/tests/cli/test_cli_init.py @@ -102,6 +102,20 @@ class TestVerboseAndToolProgress: assert cli.tool_progress_mode in {"off", "new", "all", "verbose"} +class TestFallbackChainInit: + def test_merges_new_and_legacy_fallback_config(self): + cli = _make_cli(config_overrides={ + "fallback_providers": [ + {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}, + ], + "fallback_model": {"provider": "nous", "model": "Hermes-4"}, + }) + assert cli._fallback_model == [ + {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}, + {"provider": "nous", "model": "Hermes-4"}, + ] + + class TestBusyInputMode: def test_default_busy_input_mode_is_interrupt(self): cli = _make_cli() @@ -317,7 +331,63 @@ class TestHistoryDisplay: assert "Recent sessions" in output assert "Checking Running Hermes Agent" in output - assert "Use /resume to continue" in output + assert "Use /resume" in output + assert "session title" in output + + def test_resume_updates_hermes_session_id_env_and_context(self, tmp_path): + from gateway.session_context import _UNSET, _VAR_MAP, get_session_env + from hermes_state import SessionDB + + cli = _make_cli() + cli.session_id = "current_session" + cli.conversation_history = [] + cli.agent = None + cli._session_db = SessionDB(db_path=tmp_path / "state.db") + cli._session_db.create_session("current_session", "cli") + cli._session_db.create_session("target_session", "cli") + cli._session_db.append_message("target_session", "user", "hello from resumed session") + + os.environ["HERMES_SESSION_ID"] = "current_session" + _VAR_MAP["HERMES_SESSION_ID"].set("current_session") + + try: + cli._handle_resume_command("/resume target_session") + + assert cli.session_id == "target_session" + assert os.environ["HERMES_SESSION_ID"] == "target_session" + assert get_session_env("HERMES_SESSION_ID") == "target_session" + finally: + cli._session_db.close() + os.environ.pop("HERMES_SESSION_ID", None) + _VAR_MAP["HERMES_SESSION_ID"].set(_UNSET) + + def test_resume_list_shows_full_long_titles(self, capsys): + """Long session titles render in full in the /resume table โ€” not + truncated to 30 chars (fixes #14082).""" + cli = _make_cli() + cli.session_id = "current" + cli._session_db = MagicMock() + long_title = "Salvage BytePlus Volcengine PR With Fixes" + cli._session_db.list_sessions_rich.return_value = [ + { + "id": "current", + "title": "Current", + "preview": "Current preview", + "last_active": 0, + }, + { + "id": "20260401_201329_d85961", + "title": long_title, + "preview": "fix byteplus pr and resume", + "last_active": 0, + }, + ] + + cli._handle_resume_command("/resume") + output = capsys.readouterr().out + + assert long_title in output + assert "20260401_201329_d85961" in output def test_sessions_command_no_args_lists_recent_sessions(self, capsys): """/sessions with no args prints the recent-sessions table (TUI parity). @@ -429,8 +499,8 @@ class TestRootLevelProviderOverride: assert cfg["model"]["provider"] == "openrouter" - def test_root_provider_ignored_when_default_model_provider_exists(self, tmp_path, monkeypatch): - """Even when model.provider is the default 'auto', root-level provider is ignored.""" + def test_root_provider_used_as_fallback_when_model_provider_missing(self, tmp_path, monkeypatch): + """Legacy root-level provider still populates model.provider in the CLI loader.""" import yaml hermes_home = tmp_path / ".hermes" @@ -450,8 +520,29 @@ class TestRootLevelProviderOverride: monkeypatch.setattr(cli, "_hermes_home", hermes_home) cfg = cli.load_cli_config() - # Root-level "opencode-go" must NOT leak through - assert cfg["model"]["provider"] != "opencode-go" + assert cfg["model"]["provider"] == "opencode-go" + + def test_root_base_url_used_as_fallback_when_model_base_url_missing(self, tmp_path, monkeypatch): + """Legacy root-level base_url still populates model.base_url in the CLI loader.""" + import yaml + + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + config_path = hermes_home / "config.yaml" + config_path.write_text(yaml.safe_dump({ + "base_url": "https://example.com/v1", + "model": { + "default": "google/gemini-3-flash-preview", + }, + })) + + import cli + monkeypatch.setattr(cli, "_hermes_home", hermes_home) + cfg = cli.load_cli_config() + + assert cfg["model"]["base_url"] == "https://example.com/v1" def test_terminal_vercel_runtime_bridged_to_env(self, tmp_path, monkeypatch): """Classic CLI must expose terminal.vercel_runtime to terminal_tool.py.""" diff --git a/tests/cli/test_cli_new_session.py b/tests/cli/test_cli_new_session.py index 05503552cec..c56ab63cf24 100644 --- a/tests/cli/test_cli_new_session.py +++ b/tests/cli/test_cli_new_session.py @@ -8,6 +8,8 @@ import sys from datetime import datetime, timedelta from unittest.mock import MagicMock, patch +import pytest + from hermes_state import SessionDB from tools.todo_tool import TodoStore @@ -138,6 +140,15 @@ def _prepare_cli_with_active_session(tmp_path): return cli +@pytest.fixture(autouse=True) +def _reset_session_id_context(): + from gateway.session_context import _UNSET, _VAR_MAP + + yield + os.environ.pop("HERMES_SESSION_ID", None) + _VAR_MAP["HERMES_SESSION_ID"].set(_UNSET) + + def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path): cli = _prepare_cli_with_active_session(tmp_path) old_session_id = cli.session_id @@ -164,6 +175,21 @@ def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path) cli.agent._invalidate_system_prompt.assert_called_once() +def test_new_command_rotates_hermes_session_id_env_and_context(tmp_path): + from gateway.session_context import _VAR_MAP, get_session_env + + cli = _prepare_cli_with_active_session(tmp_path) + old_session_id = cli.session_id + os.environ["HERMES_SESSION_ID"] = old_session_id + _VAR_MAP["HERMES_SESSION_ID"].set(old_session_id) + + cli.process_command("/new") + + assert cli.session_id != old_session_id + assert os.environ["HERMES_SESSION_ID"] == cli.session_id + assert get_session_env("HERMES_SESSION_ID") == cli.session_id + + def test_reset_command_is_alias_for_new_session(tmp_path): cli = _prepare_cli_with_active_session(tmp_path) old_session_id = cli.session_id diff --git a/tests/cli/test_cli_resume_command.py b/tests/cli/test_cli_resume_command.py new file mode 100644 index 00000000000..2790ce5be69 --- /dev/null +++ b/tests/cli/test_cli_resume_command.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock, patch + +from cli import HermesCLI + + +def _make_cli(): + cli_obj = HermesCLI.__new__(HermesCLI) + cli_obj.session_id = "current_session" + cli_obj._resumed = False + cli_obj._pending_title = None + cli_obj.conversation_history = [] + cli_obj.agent = None + cli_obj._session_db = MagicMock() + # _handle_resume_command now triggers _display_resumed_history (#31695), + # which reads self.resume_display. "minimal" short-circuits the recap so + # the test only exercises session-switch behavior. + cli_obj.resume_display = "minimal" + return cli_obj + + +class TestCliResumeCommand: + def test_show_recent_sessions_includes_indexes_and_resume_hint(self, capsys): + cli_obj = _make_cli() + cli_obj._list_recent_sessions = MagicMock(return_value=[ + {"id": "sess_002", "title": "Coding", "preview": "build feature", "last_active": None}, + {"id": "sess_001", "title": "Research", "preview": "read docs", "last_active": None}, + ]) + + shown = cli_obj._show_recent_sessions(reason="resume") + output = capsys.readouterr().out + + assert shown is True + assert "1" in output + assert "2" in output + assert "Coding" in output + assert "Research" in output + assert "/resume 2" in output + assert "/resume " in output + + def test_handle_resume_by_index_switches_to_numbered_session(self): + cli_obj = _make_cli() + cli_obj._list_recent_sessions = MagicMock(return_value=[ + {"id": "sess_002", "title": "Coding"}, + {"id": "sess_001", "title": "Research"}, + ]) + cli_obj._session_db.get_session.return_value = {"id": "sess_001", "title": "Research"} + cli_obj._session_db.get_messages_as_conversation.return_value = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + # resolve_resume_session_id passes the id through when no compression chain. + cli_obj._session_db.resolve_resume_session_id.return_value = "sess_001" + + with ( + patch("hermes_cli.main._resolve_session_by_name_or_id", return_value=None), + patch("cli._cprint") as mock_cprint, + ): + cli_obj._handle_resume_command("/resume 2") + + printed = " ".join(str(call) for call in mock_cprint.call_args_list) + assert cli_obj.session_id == "sess_001" + assert "Resumed session sess_001" in printed + assert "Research" in printed + + def test_handle_resume_by_index_out_of_range(self): + cli_obj = _make_cli() + cli_obj._list_recent_sessions = MagicMock(return_value=[ + {"id": "sess_002", "title": "Coding"}, + ]) + + with patch("cli._cprint") as mock_cprint: + cli_obj._handle_resume_command("/resume 9") + + printed = " ".join(str(call) for call in mock_cprint.call_args_list) + assert "out of range" in printed.lower() + assert "/resume" in printed + assert cli_obj.session_id == "current_session" diff --git a/tests/cli/test_destructive_slash_confirm.py b/tests/cli/test_destructive_slash_confirm.py index 1b2fc8c0b1f..88103ac8dcd 100644 --- a/tests/cli/test_destructive_slash_confirm.py +++ b/tests/cli/test_destructive_slash_confirm.py @@ -209,3 +209,123 @@ def test_slash_confirm_display_fragments_include_choice_mapping(): assert "[2] Always Approve" in rendered assert "[3] Cancel" in rendered assert "Type 1/2/3" in rendered + + +# --------------------------------------------------------------------------- +# Inline-skip escape hatch (issue #30768) +# +# Users on platforms where the prompt_toolkit modal doesn't dispatch keys +# (currently native Windows PowerShell) need a way to bypass the confirmation +# without flipping the config gate. ``/reset now``, ``/new --yes``, ``/clear +# -y`` all skip the modal and return "once" immediately. +# --------------------------------------------------------------------------- + + +def test_split_destructive_skip_recognized_tokens(): + """``now``, ``--yes``, and ``-y`` are recognized as skip tokens.""" + from cli import HermesCLI + + assert HermesCLI._split_destructive_skip("/reset now") == ("", True) + assert HermesCLI._split_destructive_skip("/clear --yes") == ("", True) + assert HermesCLI._split_destructive_skip("/undo -y") == ("", True) + + +def test_split_destructive_skip_strips_command_word(): + """Leading ``/cmd`` token is stripped; remaining args survive.""" + from cli import HermesCLI + + assert HermesCLI._split_destructive_skip("/new My title") == ("My title", False) + assert HermesCLI._split_destructive_skip("/new --yes My title") == ("My title", True) + + +def test_split_destructive_skip_case_insensitive(): + """Token matching is case-insensitive but not a substring match.""" + from cli import HermesCLI + + assert HermesCLI._split_destructive_skip("/new NOW") == ("", True) + # Substring match must NOT trigger โ€” "Now-Title" is a literal title token. + assert HermesCLI._split_destructive_skip("/new Now-Title") == ("Now-Title", False) + + +def test_split_destructive_skip_handles_empty_and_none(): + """Defensive against missing/empty input.""" + from cli import HermesCLI + + assert HermesCLI._split_destructive_skip(None) == ("", False) + assert HermesCLI._split_destructive_skip("") == ("", False) + assert HermesCLI._split_destructive_skip(" ") == ("", False) + + +def test_confirm_destructive_slash_now_skips_modal(): + """``/reset now`` skips the modal even when the gate is on.""" + from cli import HermesCLI + + # Build a prompt stub that fails the test if invoked โ€” proving the modal + # was never reached. + def _explode(**_kw): + raise AssertionError("modal must not be invoked when inline-skip present") + + self_ = SimpleNamespace( + _app=None, + _prompt_text_input_modal=_explode, + ) + self_._normalize_slash_confirm_choice = _bound( + HermesCLI._normalize_slash_confirm_choice, self_, + ) + self_._split_destructive_skip = HermesCLI._split_destructive_skip # classmethod + + with patch( + "cli.load_cli_config", + return_value={"approvals": {"destructive_slash_confirm": True}}, + ): + result = _bound(HermesCLI._confirm_destructive_slash, self_)( + "new", "detail", cmd_original="/reset now", + ) + + assert result == "once" + + +def test_confirm_destructive_slash_yes_flag_skips_modal(): + """``--yes`` flag is equivalent to ``now``.""" + from cli import HermesCLI + + def _explode(**_kw): + raise AssertionError("modal must not be invoked when --yes present") + + self_ = SimpleNamespace( + _app=None, + _prompt_text_input_modal=_explode, + ) + self_._normalize_slash_confirm_choice = _bound( + HermesCLI._normalize_slash_confirm_choice, self_, + ) + self_._split_destructive_skip = HermesCLI._split_destructive_skip + + with patch( + "cli.load_cli_config", + return_value={"approvals": {"destructive_slash_confirm": True}}, + ): + result = _bound(HermesCLI._confirm_destructive_slash, self_)( + "new", "detail", cmd_original="/new --yes My Session", + ) + + assert result == "once" + + +def test_confirm_destructive_slash_no_skip_token_still_prompts(): + """Without a skip token the gate-on path still consults the modal.""" + from cli import HermesCLI + + self_ = _make_self(prompt_response="3") # cancel + self_._split_destructive_skip = HermesCLI._split_destructive_skip + + with patch( + "cli.load_cli_config", + return_value={"approvals": {"destructive_slash_confirm": True}}, + ): + result = _bound(HermesCLI._confirm_destructive_slash, self_)( + "new", "detail", cmd_original="/new My Session", + ) + + # Prompt was reached and returned cancel โ†’ None. + assert result is None diff --git a/tests/cli/test_destructive_slash_inline_skip_e2e.py b/tests/cli/test_destructive_slash_inline_skip_e2e.py new file mode 100644 index 00000000000..3ed434ab47a --- /dev/null +++ b/tests/cli/test_destructive_slash_inline_skip_e2e.py @@ -0,0 +1,129 @@ +"""End-to-end integration test for the destructive-slash inline-skip path. + +Drives ``HermesCLI.process_command("/reset now")`` against a minimal stand-in +and verifies: + +1. ``new_session`` was invoked (the command actually ran) +2. ``_prompt_text_input_modal`` was NOT invoked (modal bypassed) +3. The skip token did not leak into the session title + +This is the regression test for issue #30768 โ€” the inline-skip escape hatch +must work without ever touching the modal, on every platform. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + + +def _make_cli_stub(): + """Build a minimal HermesCLI-shaped object that can run ``process_command`` + for the destructive-slash branches without spinning up a real TUI.""" + from cli import HermesCLI + + new_session_calls = [] + + def _capture_new_session(self_, title=None, silent=False): + new_session_calls.append({"title": title, "silent": silent}) + + self_ = SimpleNamespace( + _app=None, + _prompt_text_input_modal=lambda **_kw: (_ for _ in ()).throw( + AssertionError("modal must not be invoked when inline-skip token present") + ), + new_session=lambda **kw: _capture_new_session(self_, **kw), + # Stub out side-effects the destructive-slash branches reach for. + console=SimpleNamespace(clear=lambda: None), + compact=False, + model="stub-model", + session_id="stub-session", + enabled_toolsets=[], + _pending_title=None, + _session_db=None, + ) + # Bind the methods we need under test. + self_._split_destructive_skip = HermesCLI._split_destructive_skip + self_._confirm_destructive_slash = HermesCLI._confirm_destructive_slash.__get__( + self_, type(self_) + ) + self_.process_command = HermesCLI.process_command.__get__(self_, type(self_)) + return self_, new_session_calls + + +def test_reset_now_invokes_new_session_without_modal(): + """``/reset now`` runs ``new_session`` and never touches the modal.""" + self_, calls = _make_cli_stub() + + with patch( + "cli.load_cli_config", + return_value={"approvals": {"destructive_slash_confirm": True}}, + ): + self_.process_command("/reset now") + + assert calls, "new_session was never invoked" + # The /new branch passes title=None when there's no non-skip remainder. + assert calls[0]["title"] is None + + +def test_new_yes_with_title_preserves_title(): + """``/new --yes My Session`` runs ``new_session(title='My Session')``.""" + self_, calls = _make_cli_stub() + + with patch( + "cli.load_cli_config", + return_value={"approvals": {"destructive_slash_confirm": True}}, + ): + self_.process_command("/new --yes My Session") + + assert calls, "new_session was never invoked" + assert calls[0]["title"] == "My Session" + + +def test_new_without_skip_token_still_consults_modal(): + """``/new My Session`` (no skip token) must reach the modal. + + Sanity check that we haven't accidentally short-circuited the normal path. + """ + from cli import HermesCLI + + new_session_calls = [] + modal_calls = [] + + def _capture_new_session(self_, title=None, silent=False): + new_session_calls.append({"title": title, "silent": silent}) + + def _record_modal(**kw): + modal_calls.append(kw) + # Simulate user cancelling so new_session is not called. + return "3" + + self_ = SimpleNamespace( + _app=None, + _prompt_text_input_modal=_record_modal, + new_session=lambda **kw: _capture_new_session(self_, **kw), + console=SimpleNamespace(clear=lambda: None), + compact=False, + model="stub-model", + session_id="stub-session", + enabled_toolsets=[], + _pending_title=None, + _session_db=None, + ) + self_._split_destructive_skip = HermesCLI._split_destructive_skip + self_._normalize_slash_confirm_choice = HermesCLI._normalize_slash_confirm_choice.__get__( + self_, type(self_) + ) + self_._confirm_destructive_slash = HermesCLI._confirm_destructive_slash.__get__( + self_, type(self_) + ) + self_.process_command = HermesCLI.process_command.__get__(self_, type(self_)) + + with patch( + "cli.load_cli_config", + return_value={"approvals": {"destructive_slash_confirm": True}}, + ): + self_.process_command("/new My Session") + + assert modal_calls, "modal must be reached when no skip token is present" + assert not new_session_calls, "user cancelled โ€” new_session must not run" diff --git a/tests/cli/test_resume_display.py b/tests/cli/test_resume_display.py index ffeb4402cdf..be9282f8595 100644 --- a/tests/cli/test_resume_display.py +++ b/tests/cli/test_resume_display.py @@ -155,14 +155,34 @@ class TestDisplayResumedHistory: assert "Page content" not in output def test_tool_calls_shown_as_summary(self): - cli = _make_cli() + # Disable tool-only skip so the summary line is rendered for this fixture. + cli = _make_cli(config_overrides={"display": {"resume_skip_tool_only": False}}) cli.conversation_history = _tool_call_history() - output = self._capture_display(cli) + import cli as _cli_mod + # CLI_CONFIG is read at call-time inside _display_resumed_history, so + # apply the override for the duration of the capture, not just at init. + with patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": { + "display": {"resume_skip_tool_only": False, "resume_display": "full"} + }}): + output = self._capture_display(cli) assert "2 tool calls" in output assert "web_search" in output assert "web_extract" in output + def test_tool_only_message_skipped_by_default(self): + """Assistant messages with only tool_calls (no text) are skipped when + resume_skip_tool_only=True (the default). The summary line is hidden. + """ + cli = _make_cli() + cli.conversation_history = _tool_call_history() + output = self._capture_display(cli) + + # The tool-only assistant entry should be skipped + assert "2 tool calls" not in output + # The final text reply should still appear + assert "Here are some great Python tutorials" in output + def test_long_user_message_truncated(self): cli = _make_cli() long_text = "A" * 500 @@ -611,6 +631,55 @@ class TestPreloadResumedSession: assert "1 user messages" not in output +# โ”€โ”€ Tests for _handle_resume_command recap display โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestHandleResumeCommandRecap: + """In-session /resume should show the same recap panel as startup resume.""" + + def test_resume_command_displays_recap_when_messages_restored(self): + cli = _make_cli() + cli.session_id = "current_session" + messages = _simple_history() + + mock_db = MagicMock() + mock_db.get_session.return_value = {"id": "target_session", "title": "Test Session"} + mock_db.get_messages_as_conversation.return_value = messages + # resolve_resume_session_id passes the id through when no compression chain. + mock_db.resolve_resume_session_id.return_value = "target_session" + cli._session_db = mock_db + + with ( + patch("hermes_cli.main._resolve_session_by_name_or_id", return_value="target_session"), + patch.object(cli, "_display_resumed_history") as display_mock, + ): + cli._handle_resume_command("/resume test session") + + assert cli.session_id == "target_session" + assert cli.conversation_history == messages + mock_db.end_session.assert_called_once_with("current_session", "resumed_other") + mock_db.reopen_session.assert_called_once_with("target_session") + display_mock.assert_called_once_with() + + def test_resume_command_skips_recap_when_session_has_no_messages(self): + cli = _make_cli() + cli.session_id = "current_session" + + mock_db = MagicMock() + mock_db.get_session.return_value = {"id": "target_session", "title": None} + mock_db.get_messages_as_conversation.return_value = [] + mock_db.resolve_resume_session_id.return_value = "target_session" + cli._session_db = mock_db + + with ( + patch("hermes_cli.main._resolve_session_by_name_or_id", return_value="target_session"), + patch.object(cli, "_display_resumed_history") as display_mock, + ): + cli._handle_resume_command("/resume target_session") + + display_mock.assert_not_called() + + # โ”€โ”€ Integration: _init_agent skips when preloaded โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ diff --git a/tests/cli/test_tool_progress_scrollback.py b/tests/cli/test_tool_progress_scrollback.py index 7924f41598b..d6af08deab9 100644 --- a/tests/cli/test_tool_progress_scrollback.py +++ b/tests/cli/test_tool_progress_scrollback.py @@ -14,9 +14,10 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) # Module-level reference to the cli module (set by _make_cli on first call) _cli_mod = None +_UNSET = object() -def _make_cli(tool_progress="all"): +def _make_cli(tool_progress="all", verbose=_UNSET): """Create a HermesCLI instance with minimal mocking.""" global _cli_mod _clean_config = { @@ -54,7 +55,9 @@ def _make_cli(tool_progress="all"): _cli_mod = mod with patch.object(mod, "get_tool_definitions", return_value=[]), \ patch.dict(mod.__dict__, {"CLI_CONFIG": _clean_config}): - return mod.HermesCLI() + if verbose is _UNSET: + return mod.HermesCLI() + return mod.HermesCLI(verbose=verbose) class TestToolProgressScrollback: @@ -122,14 +125,21 @@ class TestToolProgressScrollback: mock_print.assert_not_called() def test_error_suffix_on_failed_tool(self): - """When is_error=True, the stacked line includes [error].""" + """When a failed tool's result is forwarded, the stacked line surfaces + the specific error (e.g. ``[exit 1]`` or ``[File not found: x]``) + instead of the legacy generic ``[error]`` suffix.""" + import json cli = _make_cli(tool_progress="all") - cli._on_tool_progress("tool.started", "terminal", "bad cmd", {"command": "bad cmd"}) + cli._on_tool_progress("tool.started", "terminal", "false", {"command": "false"}) with patch.object(_cli_mod, "_cprint") as mock_print: - cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.5, is_error=True) + cli._on_tool_progress( + "tool.completed", "terminal", None, None, + duration=0.5, is_error=True, + result=json.dumps({"output": "", "exit_code": 1}), + ) line = mock_print.call_args[0][0] - assert "[error]" in line + assert "[exit 1]" in line def test_spinner_still_updates_on_started(self): """tool.started still updates the spinner text for live display.""" @@ -168,6 +178,35 @@ class TestToolProgressScrollback: mock_print.assert_not_called() + def test_verbose_mode_config_does_not_enable_global_debug_logging(self): + """display.tool_progress=verbose controls TOOL-CALL DISPLAY ONLY. + + It must NOT auto-flip self.verbose, which controls root-logger DEBUG + level for the entire process (every module spews to console). PR + #6a1aa420e had coupled them, causing all debug logs to flood the + terminal whenever a user picked tool_progress: verbose for richer + per-tool rendering. + """ + cli = _make_cli(tool_progress="verbose") + + assert cli.tool_progress_mode == "verbose" + assert cli.verbose is False + + def test_explicit_verbose_argument_wins_over_config(self): + """Explicit verbose=True from the CLI flag still enables DEBUG logging + regardless of tool_progress_mode.""" + cli = _make_cli(tool_progress="off", verbose=True) + + assert cli.tool_progress_mode == "off" + assert cli.verbose is True + + def test_explicit_non_verbose_argument_keeps_debug_logging_off(self): + """Explicit verbose=False overrides any default to enable DEBUG.""" + cli = _make_cli(tool_progress="verbose", verbose=False) + + assert cli.tool_progress_mode == "verbose" + assert cli.verbose is False + def test_pending_info_stores_on_started(self): """tool.started stores args for later use by tool.completed.""" cli = _make_cli(tool_progress="all") diff --git a/tests/conftest.py b/tests/conftest.py index 3cdce42c495..0514702546b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -358,6 +358,10 @@ def _hermetic_environment(tmp_path, monkeypatch): monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true") monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1") monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1") + # Tirith auto-installs from GitHub when enabled and missing. Unit tests + # should never perform that implicit network/bootstrap path; Tirith-specific + # tests opt back in by patching the security config directly. + monkeypatch.setenv("TIRITH_ENABLED", "false") # 5. Reset plugin singleton so tests don't leak plugins from # ~/.hermes/plugins/ (which, per step 3, is now empty โ€” but the diff --git a/tests/cron/test_scheduler.py b/tests/cron/test_scheduler.py index 32485a917e0..62bc6b688a0 100644 --- a/tests/cron/test_scheduler.py +++ b/tests/cron/test_scheduler.py @@ -490,6 +490,17 @@ class TestRoutingIntents: class TestDeliverResultWrapping: """Verify that cron deliveries are wrapped with header/footer and no longer mirrored.""" + def _safe_media_path(self, tmp_path, monkeypatch, name, data=b"media"): + root = tmp_path / "media-cache" + media_file = root / name + media_file.parent.mkdir(parents=True, exist_ok=True) + media_file.write_bytes(data) + monkeypatch.setattr( + "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS", + (root,), + ) + return media_file.resolve() + def test_delivery_wraps_content_with_header_and_footer(self): """Delivered content should include task name header and agent-invisible note.""" from gateway.config import Platform @@ -564,9 +575,10 @@ class TestDeliverResultWrapping: assert "Cronjob Response" not in sent_content assert "The agent cannot see" not in sent_content - def test_delivery_extracts_media_tags_before_send(self): + def test_delivery_extracts_media_tags_before_send(self, tmp_path, monkeypatch): """Cron delivery should pass MEDIA attachments separately to the send helper.""" from gateway.config import Platform + media_path = self._safe_media_path(tmp_path, monkeypatch, "test-voice.ogg") pconfig = MagicMock() pconfig.enabled = True @@ -581,7 +593,7 @@ class TestDeliverResultWrapping: "deliver": "origin", "origin": {"platform": "telegram", "chat_id": "123"}, } - _deliver_result(job, "Title\nMEDIA:/tmp/test-voice.ogg") + _deliver_result(job, f"Title\nMEDIA:{media_path}") send_mock.assert_called_once() args, kwargs = send_mock.call_args @@ -589,14 +601,15 @@ class TestDeliverResultWrapping: assert "MEDIA:" not in args[3] assert "Title" in args[3] # Media files should be forwarded separately - assert kwargs["media_files"] == [("/tmp/test-voice.ogg", False)] + assert kwargs["media_files"] == [(str(media_path), False)] - def test_live_adapter_sends_media_as_attachments(self): + def test_live_adapter_sends_media_as_attachments(self, tmp_path, monkeypatch): """When a live adapter is available, MEDIA files should be sent as native platform attachments (e.g., Discord voice, Telegram audio) rather than as literal 'MEDIA:/path' text.""" from gateway.config import Platform from concurrent.futures import Future + media_path = self._safe_media_path(tmp_path, monkeypatch, "cron-voice.mp3") adapter = AsyncMock() adapter.send.return_value = MagicMock(success=True) @@ -628,7 +641,7 @@ class TestDeliverResultWrapping: patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro): _deliver_result( job, - "Here is TTS\nMEDIA:/tmp/cron-voice.mp3", + f"Here is TTS\nMEDIA:{media_path}", adapters={Platform.DISCORD: adapter}, loop=loop, ) @@ -642,12 +655,13 @@ class TestDeliverResultWrapping: # Audio file should be sent as a voice attachment adapter.send_voice.assert_called_once() voice_call = adapter.send_voice.call_args - assert voice_call[1]["audio_path"] == "/tmp/cron-voice.mp3" + assert voice_call[1]["audio_path"] == str(media_path) - def test_live_adapter_routes_image_to_send_image_file(self): + def test_live_adapter_routes_image_to_send_image_file(self, tmp_path, monkeypatch): """Image MEDIA files should be routed to send_image_file, not send_voice.""" from gateway.config import Platform from concurrent.futures import Future + media_path = self._safe_media_path(tmp_path, monkeypatch, "chart.png") adapter = AsyncMock() adapter.send.return_value = MagicMock(success=True) @@ -678,19 +692,20 @@ class TestDeliverResultWrapping: patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro): _deliver_result( job, - "Chart attached\nMEDIA:/tmp/chart.png", + f"Chart attached\nMEDIA:{media_path}", adapters={Platform.DISCORD: adapter}, loop=loop, ) adapter.send_image_file.assert_called_once() - assert adapter.send_image_file.call_args[1]["image_path"] == "/tmp/chart.png" + assert adapter.send_image_file.call_args[1]["image_path"] == str(media_path) adapter.send_voice.assert_not_called() - def test_live_adapter_media_only_no_text(self): + def test_live_adapter_media_only_no_text(self, tmp_path, monkeypatch): """When content is ONLY a MEDIA tag with no text, media should still be sent.""" from gateway.config import Platform from concurrent.futures import Future + media_path = self._safe_media_path(tmp_path, monkeypatch, "voice.ogg") adapter = AsyncMock() adapter.send_voice.return_value = MagicMock(success=True) @@ -720,7 +735,7 @@ class TestDeliverResultWrapping: patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro): _deliver_result( job, - "[[audio_as_voice]]\nMEDIA:/tmp/voice.ogg", + f"[[audio_as_voice]]\nMEDIA:{media_path}", adapters={Platform.TELEGRAM: adapter}, loop=loop, ) @@ -2164,43 +2179,56 @@ class TestBuildJobPromptBumpUse: class TestSendMediaViaAdapter: """Unit tests for _send_media_via_adapter โ€” routes files to typed adapter methods.""" + def _safe_media_path(self, tmp_path, monkeypatch, name, data=b"media"): + root = tmp_path / "media-cache" + media_file = root / name + media_file.parent.mkdir(parents=True, exist_ok=True) + media_file.write_bytes(data) + monkeypatch.setattr( + "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS", + (root,), + ) + return media_file.resolve() + @staticmethod def _run_with_loop(adapter, chat_id, media_files, metadata, job): - """Helper: run _send_media_via_adapter with a real running event loop.""" - import asyncio - import threading + """Helper: run _send_media_via_adapter with immediate scheduling.""" + from concurrent.futures import Future - loop = asyncio.new_event_loop() - t = threading.Thread(target=loop.run_forever, daemon=True) - t.start() - try: - _send_media_via_adapter(adapter, chat_id, media_files, metadata, loop, job) - finally: - loop.call_soon_threadsafe(loop.stop) - t.join(timeout=5) - loop.close() + def fake_run_coro(coro, _loop): + coro.close() + completed = Future() + completed.set_result(MagicMock(success=True)) + return completed - def test_video_dispatched_to_send_video(self): + with patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro): + _send_media_via_adapter(adapter, chat_id, media_files, metadata, MagicMock(), job) + + def test_video_dispatched_to_send_video(self, tmp_path, monkeypatch): adapter = MagicMock() adapter.send_video = AsyncMock() - media_files = [("/tmp/clip.mp4", False)] + media_path = self._safe_media_path(tmp_path, monkeypatch, "clip.mp4") + media_files = [(str(media_path), False)] self._run_with_loop(adapter, "123", media_files, None, {"id": "j1"}) adapter.send_video.assert_called_once() - assert adapter.send_video.call_args[1]["video_path"] == "/tmp/clip.mp4" + assert adapter.send_video.call_args[1]["video_path"] == str(media_path) - def test_unknown_ext_dispatched_to_send_document(self): + def test_unknown_ext_dispatched_to_send_document(self, tmp_path, monkeypatch): adapter = MagicMock() adapter.send_document = AsyncMock() - media_files = [("/tmp/report.pdf", False)] + media_path = self._safe_media_path(tmp_path, monkeypatch, "report.pdf") + media_files = [(str(media_path), False)] self._run_with_loop(adapter, "123", media_files, None, {"id": "j2"}) adapter.send_document.assert_called_once() - assert adapter.send_document.call_args[1]["file_path"] == "/tmp/report.pdf" + assert adapter.send_document.call_args[1]["file_path"] == str(media_path) - def test_multiple_media_files_all_delivered(self): + def test_multiple_media_files_all_delivered(self, tmp_path, monkeypatch): adapter = MagicMock() adapter.send_voice = AsyncMock() adapter.send_image_file = AsyncMock() - media_files = [("/tmp/voice.mp3", False), ("/tmp/photo.jpg", False)] + voice_path = self._safe_media_path(tmp_path, monkeypatch, "voice.mp3") + photo_path = self._safe_media_path(tmp_path, monkeypatch, "photo.jpg") + media_files = [(str(voice_path), False), (str(photo_path), False)] self._run_with_loop(adapter, "123", media_files, None, {"id": "j3"}) adapter.send_voice.assert_called_once() adapter.send_image_file.assert_called_once() @@ -2462,7 +2490,7 @@ class TestSendMediaTimeoutCancelsFuture: in-flight coroutine must be cancelled before the next file is tried. """ - def test_media_send_timeout_cancels_future_and_continues(self): + def test_media_send_timeout_cancels_future_and_continues(self, tmp_path, monkeypatch): """End-to-end: _send_media_via_adapter with a future whose .result() raises TimeoutError. Assert cancel() fires and the loop proceeds to the next file rather than hanging or crashing.""" @@ -2493,9 +2521,19 @@ class TestSendMediaTimeoutCancelsFuture: coro.close() return next(futures_iter) + root = tmp_path / "media-cache" + slow = root / "slow.png" + fast = root / "fast.mp4" + slow.parent.mkdir(parents=True) + slow.write_bytes(b"slow") + fast.write_bytes(b"fast") + monkeypatch.setattr( + "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS", + (root,), + ) media_files = [ - ("/tmp/slow.png", False), # times out - ("/tmp/fast.mp4", False), # succeeds + (str(slow), False), # times out + (str(fast), False), # succeeds ] loop = MagicMock() @@ -2509,4 +2547,4 @@ class TestSendMediaTimeoutCancelsFuture: assert timeout_cancel_calls == [True], "future.cancel() must fire on TimeoutError" # 2. Second file still got dispatched โ€” one timeout doesn't abort the batch adapter.send_video.assert_called_once() - assert adapter.send_video.call_args[1]["video_path"] == "/tmp/fast.mp4" + assert adapter.send_video.call_args[1]["video_path"] == str(fast.resolve()) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index acb999e9e34..3adbd557dd1 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -119,7 +119,7 @@ _ensure_slack_mock() import discord # noqa: E402 โ€” mocked above from gateway.platforms.telegram import TelegramAdapter # noqa: E402 -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 import gateway.platforms.slack as _slack_mod # noqa: E402 _slack_mod.SLACK_AVAILABLE = True diff --git a/tests/gateway/test_active_session_text_merge.py b/tests/gateway/test_active_session_text_merge.py index 087f8dbabd0..05e7a36fd6b 100644 --- a/tests/gateway/test_active_session_text_merge.py +++ b/tests/gateway/test_active_session_text_merge.py @@ -1,20 +1,10 @@ -"""Regression test for #4469. +"""Regression tests for active-session TEXT follow-up queueing. -When the agent is actively running (session present in -``adapter._active_sessions``) and the user fires off multiple TEXT -follow-ups in rapid succession, the previous behaviour was a single-slot -replacement at ``gateway/platforms/base.py``: - - self._pending_messages[session_key] = event - -So three rapid messages ``A``, ``B``, ``C`` arriving while the agent was -still working on the initial turn produced a pending slot containing only -``C``; ``A`` and ``B`` were silently dropped. - -The fix routes the follow-up through ``merge_pending_message_event(..., -merge_text=True)`` so TEXT events accumulate into the existing pending -event's text instead of clobbering it. Photo / media bursts continue to -merge through the same helper (they always did). +When the agent is actively running, rapid text follow-ups should survive as +one next-turn pending message instead of clobbering each other. In +``busy_text_mode=queue`` those active follow-ups first pass through a short +debounce so bursty multi-message thoughts are merged before the active drain +hands off the next turn. """ from __future__ import annotations @@ -22,7 +12,7 @@ from __future__ import annotations import asyncio import sys import types -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -44,16 +34,27 @@ from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + SendResult, ) from gateway.session import SessionSource, build_session_key -def _make_event(text: str, chat_id: str = "12345") -> MessageEvent: +def _make_event( + text: str, + chat_id: str = "12345", + *, + chat_type: str = "dm", + user_id: str = "u1", + user_name: str | None = None, + thread_id: str | None = None, +) -> MessageEvent: source = SessionSource( platform=Platform.TELEGRAM, chat_id=chat_id, - chat_type="dm", - user_id="u1", + chat_type=chat_type, + user_id=user_id, + user_name=user_name, + thread_id=thread_id, ) return MessageEvent( text=text, @@ -63,27 +64,26 @@ def _make_event(text: str, chat_id: str = "12345") -> MessageEvent: ) +class _DummyAdapter(BasePlatformAdapter): # type: ignore[misc] + async def connect(self): + pass + + async def disconnect(self): + pass + + async def get_chat_info(self, chat_id): + return None + + async def send(self, *args, **kwargs): + return SendResult(success=True, message_id="x") + + +def _make_initialized_adapter() -> BasePlatformAdapter: + return _DummyAdapter(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM) + + def _make_adapter() -> BasePlatformAdapter: - """Build a BasePlatformAdapter without running its heavy __init__. - - We only need the bits ``handle_message`` touches on the active-session - path: ``_active_sessions``, ``_pending_messages``, - ``_message_handler``, ``_busy_session_handler``, ``config``, ``platform``. - """ - - class _DummyAdapter(BasePlatformAdapter): # type: ignore[misc] - async def connect(self): - pass - - async def disconnect(self): - pass - - async def get_chat_info(self, chat_id): - return None - - async def send(self, *args, **kwargs): - return MagicMock(success=True, message_id="x", retryable=False) - + """Build a BasePlatformAdapter without running its heavy __init__.""" adapter = object.__new__(_DummyAdapter) adapter.config = PlatformConfig(enabled=True, token="***") adapter.platform = Platform.TELEGRAM @@ -100,6 +100,10 @@ def _make_adapter() -> BasePlatformAdapter: adapter._fatal_error_retryable = True adapter._fatal_error_handler = None adapter._running = True + adapter._busy_text_mode = "queue" + adapter._busy_text_debounce_seconds = 0.1 + adapter._busy_text_hard_cap_seconds = 1.0 + adapter._text_debounce = {} adapter._auto_tts_default = False adapter._auto_tts_enabled_chats = set() adapter._auto_tts_disabled_chats = set() @@ -107,39 +111,235 @@ def _make_adapter() -> BasePlatformAdapter: return adapter +def _debounced_event(adapter: BasePlatformAdapter, session_key: str) -> MessageEvent: + return adapter._text_debounce[session_key].event + + @pytest.mark.asyncio async def test_rapid_text_followups_accumulate_instead_of_replacing(): - """Three rapid TEXT follow-ups during an active session must all - survive in ``adapter._pending_messages[session_key].text``.""" + """Rapid TEXT follow-ups must all survive in the pending event.""" adapter = _make_adapter() + adapter._busy_text_mode = "" # direct-merge behavior, no debounce first = _make_event("part one") session_key = build_session_key(first.source) - - # Mark the session as active so subsequent messages take the - # "already running" branch in handle_message. adapter._active_sessions[session_key] = asyncio.Event() - second = _make_event("part two") - third = _make_event("part three") + await adapter.handle_message(_make_event("part two")) + await adapter.handle_message(_make_event("part three")) - await adapter.handle_message(second) - await adapter.handle_message(third) - - # Both rapid follow-ups must be preserved, not just the last one. pending = adapter._pending_messages[session_key] - assert pending.text == "part two\npart three", ( - f"expected accumulated text, got {pending.text!r}" + assert pending.text == "part two\npart three" + assert not adapter._active_sessions[session_key].is_set() + + +@pytest.mark.asyncio +async def test_debounce_buffers_rapid_text_then_flushes_to_pending(): + adapter = _make_adapter() + adapter._busy_text_debounce_seconds = 0.05 + + first = _make_event("part one") + session_key = build_session_key(first.source) + adapter._active_sessions[session_key] = asyncio.Event() + + await adapter.handle_message(_make_event("part two")) + assert session_key in adapter._text_debounce + assert _debounced_event(adapter, session_key).text == "part two" + assert session_key not in adapter._pending_messages + + await adapter.handle_message(_make_event("part three")) + assert _debounced_event(adapter, session_key).text == "part two\npart three" + + await asyncio.sleep(0.15) + + assert session_key not in adapter._text_debounce + assert adapter._pending_messages[session_key].text == "part two\npart three" + + +@pytest.mark.asyncio +async def test_debounce_resets_timer_on_new_arrival(): + adapter = _make_adapter() + adapter._busy_text_debounce_seconds = 0.1 + + first = _make_event("one") + session_key = build_session_key(first.source) + adapter._active_sessions[session_key] = asyncio.Event() + + await adapter.handle_message(first) + task1 = adapter._text_debounce[session_key].task + assert task1 is not None + assert not task1.done() + + await adapter.handle_message(_make_event("two")) + task2 = adapter._text_debounce[session_key].task + assert task2 is not None + assert task2 is not task1 + await asyncio.sleep(0) + assert task1.cancelled() or task1.done() + assert adapter._text_debounce[session_key].task is task2 + + await adapter.handle_message(_make_event("three")) + task3 = adapter._text_debounce[session_key].task + assert task3 is not None + assert task3 is not task2 + + await asyncio.sleep(0.2) + assert session_key not in adapter._text_debounce + assert adapter._pending_messages[session_key].text == "one\ntwo\nthree" + + +@pytest.mark.asyncio +async def test_active_drain_force_flushes_debounce_before_release(): + adapter = _make_adapter() + adapter._busy_text_debounce_seconds = 1.0 + processed: list[str] = [] + + async def _handler(event): + processed.append(event.text) + if event.text == "current": + await adapter.handle_message(_make_event("follow up")) + return None + + adapter._message_handler = _handler + current = _make_event("current") + session_key = build_session_key(current.source) + + task = asyncio.create_task(adapter._process_message_background(current, session_key)) + adapter._session_tasks[session_key] = task + await asyncio.wait_for(task, timeout=1.0) + + for _ in range(20): + if processed == ["current", "follow up"] and session_key not in adapter._active_sessions: + break + await asyncio.sleep(0.05) + + assert processed == ["current", "follow up"] + assert session_key not in adapter._text_debounce + assert session_key not in adapter._pending_messages + assert session_key not in adapter._active_sessions + + +@pytest.mark.asyncio +async def test_force_flush_cancels_timer_without_duplicate_processing(): + adapter = _make_adapter() + adapter._busy_text_debounce_seconds = 0.2 + + event = _make_event("queued once") + session_key = build_session_key(event.source) + adapter._active_sessions[session_key] = asyncio.Event() + + await adapter.handle_message(event) + timer_task = adapter._text_debounce[session_key].task + + flushed = await adapter._flush_text_debounce_now(session_key) + assert flushed is True + assert session_key not in adapter._text_debounce + assert adapter._pending_messages[session_key].text == "queued once" + + await asyncio.sleep(0.3) + assert timer_task is not None + assert timer_task.cancelled() or timer_task.done() + assert adapter._pending_messages[session_key].text == "queued once" + + +@pytest.mark.asyncio +async def test_text_debounce_does_not_merge_different_senders(): + adapter = _make_adapter() + adapter._busy_text_debounce_seconds = 1.0 + + first = _make_event( + "from alice", + chat_type="group", + user_id="alice", + user_name="Alice", + thread_id="topic-1", ) - # Interrupt event must be signalled exactly like before. - assert adapter._active_sessions[session_key].is_set() + second = _make_event( + "from bob", + chat_type="group", + user_id="bob", + user_name="Bob", + thread_id="topic-1", + ) + session_key = build_session_key(first.source) + assert session_key == build_session_key(second.source) + adapter._active_sessions[session_key] = asyncio.Event() + + await adapter.handle_message(first) + await adapter.handle_message(second) + + assert adapter._pending_messages[session_key].text == "from alice" + assert _debounced_event(adapter, session_key).text == "from bob" + + +@pytest.mark.asyncio +async def test_control_and_clarify_messages_bypass_text_debounce(): + adapter = _make_adapter() + started: list[str] = [] + + def _fake_start(event, session_key, *, interrupt_event=None): + started.append(event.text) + return True + + adapter._start_session_processing = _fake_start # type: ignore[method-assign] + + await adapter.handle_message(_make_event("/status")) + assert started == ["/status"] + assert adapter._text_debounce == {} + + answer = _make_event("clarify answer") + session_key = build_session_key(answer.source) + adapter._active_sessions[session_key] = asyncio.Event() + adapter._message_handler = AsyncMock(return_value=None) + + with patch("tools.clarify_gateway.get_pending_for_session", return_value=object()): + await adapter.handle_message(answer) + + adapter._message_handler.assert_awaited_once_with(answer) + assert session_key not in adapter._text_debounce + assert session_key not in adapter._pending_messages + + +@pytest.mark.asyncio +async def test_debounce_skipped_when_busy_text_mode_not_queue(): + adapter = _make_adapter() + adapter._busy_text_mode = "" + event = _make_event("direct merge") + session_key = build_session_key(event.source) + adapter._active_sessions[session_key] = asyncio.Event() + + await adapter.handle_message(event) + + assert adapter._pending_messages[session_key].text == "direct merge" + assert session_key not in adapter._text_debounce + + +def test_debounce_respects_env_var_override(monkeypatch): + monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_DEBOUNCE_SECONDS", "2.5") + adapter = _make_initialized_adapter() + assert adapter._busy_text_debounce_seconds == 2.5 + + +@pytest.mark.asyncio +async def test_debounce_cleanup_in_cancel_background_tasks(): + adapter = _make_adapter() + adapter._busy_text_debounce_seconds = 1.0 + + event = _make_event("cleanup test") + session_key = build_session_key(event.source) + adapter._active_sessions[session_key] = asyncio.Event() + await adapter.handle_message(event) + + assert session_key in adapter._text_debounce + + await adapter.cancel_background_tasks() + + assert session_key not in adapter._text_debounce @pytest.mark.asyncio async def test_single_followup_is_stored_as_is(): - """One TEXT follow-up still lands as the event object itself - (no spurious wrapping / mutation) โ€” guards against the merge path - breaking the simple case.""" adapter = _make_adapter() + adapter._busy_text_mode = "" first = _make_event("only one") session_key = build_session_key(first.source) @@ -149,4 +349,29 @@ async def test_single_followup_is_stored_as_is(): pending = adapter._pending_messages[session_key] assert pending is first assert pending.text == "only one" - assert adapter._active_sessions[session_key].is_set() + assert not adapter._active_sessions[session_key].is_set() + + +def test_adapter_defaults_to_queue_mode(monkeypatch): + monkeypatch.delenv("HERMES_GATEWAY_BUSY_TEXT_MODE", raising=False) + adapter = _make_initialized_adapter() + assert adapter._busy_text_mode == "queue" + assert adapter._is_queue_text_debounce_candidate(_make_event("hello")) + + +def test_adapter_is_queue_text_debounce_candidate_by_default(): + adapter = _make_adapter() + assert adapter._is_queue_text_debounce_candidate(_make_event("hello world")) + + +def test_command_messages_bypass_debounce_even_in_queue_mode(): + adapter = _make_adapter() + assert not adapter._is_queue_text_debounce_candidate(_make_event("")) + assert not adapter._is_queue_text_debounce_candidate(_make_event("/stop")) + + +def test_busy_text_mode_respects_env_var_override(monkeypatch): + monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "interrupt") + adapter = _make_initialized_adapter() + assert adapter._busy_text_mode == "interrupt" + assert not adapter._is_queue_text_debounce_candidate(_make_event("test")) diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index aae5f550532..608385bef17 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -14,6 +14,8 @@ Tests cover: import asyncio import json +import os +import stat import time import uuid from unittest.mock import AsyncMock, MagicMock, patch @@ -128,6 +130,37 @@ class TestResponseStore: # resp_2 mapping should still be intact assert store.get_conversation("chat-b") == "resp_2" + @pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits are platform-specific") + def test_file_store_created_owner_only_under_permissive_umask(self, tmp_path): + """response_store.db must be 0o600 on creation even under umask 022.""" + db_path = tmp_path / "response_store.db" + store = None + old_umask = os.umask(0o022) + try: + store = ResponseStore(max_size=10, db_path=str(db_path)) + store.put( + "resp_secret", + { + "response": {"id": "resp_secret"}, + "conversation_history": [{"role": "tool", "content": "dummy-marker"}], + }, + ) + finally: + os.umask(old_umask) + if store is not None: + store.close() + + assert stat.S_IMODE(db_path.stat().st_mode) == 0o600 + # WAL/SHM sidecars are owner-only too when present. WAL mode may be + # unavailable on some filesystems (NFS/SMB) โ€” only assert when the + # sidecar files actually exist. + for sidecar in ( + db_path.with_name(db_path.name + "-wal"), + db_path.with_name(db_path.name + "-shm"), + ): + if sidecar.exists(): + assert stat.S_IMODE(sidecar.stat().st_mode) == 0o600 + # --------------------------------------------------------------------------- # _IdempotencyCache diff --git a/tests/gateway/test_auth_fallback.py b/tests/gateway/test_auth_fallback.py index 3edb8b1ee9a..5976962e651 100644 --- a/tests/gateway/test_auth_fallback.py +++ b/tests/gateway/test_auth_fallback.py @@ -27,8 +27,11 @@ class TestResolveRuntimeAgentKwargsAuthFallback: def _mock_resolve(**kwargs): call_count["n"] += 1 - requested = kwargs.get("requested", "") - if requested and "codex" in str(requested).lower(): + # First call = primary path (gateway reads model.provider from + # config.yaml internally; we simulate the auth failure here). + # Second call = fallback path with explicit_api_key + explicit_base_url + # supplied by gateway from fallback_model config. + if call_count["n"] == 1: raise AuthError("Codex token refresh failed with status 401") return { "api_key": "fallback-key", @@ -40,8 +43,6 @@ class TestResolveRuntimeAgentKwargsAuthFallback: "credential_pool": None, } - monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex") - with patch( "hermes_cli.runtime_provider.resolve_runtime_provider", side_effect=_mock_resolve, @@ -62,7 +63,6 @@ class TestResolveRuntimeAgentKwargsAuthFallback: config_path.write_text("model:\n provider: openai-codex\n") monkeypatch.setattr("gateway.run._hermes_home", tmp_path) - monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex") with patch( "hermes_cli.runtime_provider.resolve_runtime_provider", @@ -71,3 +71,46 @@ class TestResolveRuntimeAgentKwargsAuthFallback: from gateway.run import _resolve_runtime_agent_kwargs with pytest.raises(RuntimeError): _resolve_runtime_agent_kwargs() + + def test_legacy_fallback_is_appended_after_fallback_providers(self, tmp_path, monkeypatch): + """When both keys exist, the legacy entry still participates in resolution.""" + config_path = tmp_path / "config.yaml" + config_path.write_text( + "fallback_providers:\n" + " - provider: openrouter\n" + " model: anthropic/claude-sonnet-4.6\n" + "fallback_model:\n" + " provider: nous\n" + " model: Hermes-4\n" + ) + + monkeypatch.setattr("gateway.run._hermes_home", tmp_path) + + calls = [] + + def _mock_resolve(**kwargs): + requested = kwargs.get("requested") + calls.append(requested) + if requested == "openrouter": + raise RuntimeError("openrouter unavailable") + return { + "api_key": "nous-key", + "base_url": "https://portal.nousresearch.com/v1", + "provider": "nous", + "api_mode": "chat_completions", + "command": None, + "args": None, + "credential_pool": None, + } + + with patch( + "hermes_cli.runtime_provider.resolve_runtime_provider", + side_effect=_mock_resolve, + ): + from gateway.run import _try_resolve_fallback_provider + + result = _try_resolve_fallback_provider() + + assert calls == ["openrouter", "nous"] + assert result["provider"] == "nous" + assert result["model"] == "Hermes-4" diff --git a/tests/gateway/test_base_topic_sessions.py b/tests/gateway/test_base_topic_sessions.py index a55fcb1d8ff..dd2ef3a1262 100644 --- a/tests/gateway/test_base_topic_sessions.py +++ b/tests/gateway/test_base_topic_sessions.py @@ -15,6 +15,7 @@ from gateway.session import SessionSource, build_session_key class DummyTelegramAdapter(BasePlatformAdapter): def __init__(self): super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM) + self._busy_text_mode = "" self.sent = [] self.typing = [] self.processing_hooks = [] diff --git a/tests/gateway/test_bluebubbles.py b/tests/gateway/test_bluebubbles.py index 6f93c1d4dba..dea806fe66b 100644 --- a/tests/gateway/test_bluebubbles.py +++ b/tests/gateway/test_bluebubbles.py @@ -452,6 +452,14 @@ class TestBlueBubblesWebhookUrl: adapter = _make_adapter(monkeypatch, password="W9fTC&L5JL*@") assert "password=W9fTC%26L5JL%2A%40" in adapter._webhook_register_url + def test_register_url_for_log_masks_password(self, monkeypatch): + """Log-safe webhook URLs must never expose the webhook password.""" + adapter = _make_adapter(monkeypatch, password="W9fTC&L5JL*@") + safe_url = adapter._webhook_register_url_for_log + assert safe_url.endswith("?password=***") + assert "W9fTC" not in safe_url + assert "%26" not in safe_url + def test_register_url_omits_query_when_no_password(self, monkeypatch): """If no password is configured, the register URL should be the bare URL.""" monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False) diff --git a/tests/gateway/test_busy_session_ack.py b/tests/gateway/test_busy_session_ack.py index b16e5ebb5f2..f13e16961e4 100644 --- a/tests/gateway/test_busy_session_ack.py +++ b/tests/gateway/test_busy_session_ack.py @@ -65,6 +65,7 @@ def _make_runner(): runner._pending_messages = {} runner._busy_ack_ts = {} runner._draining = False + runner._busy_text_mode = "interrupt" runner.adapters = {} runner.config = MagicMock() runner.session_store = None @@ -84,6 +85,8 @@ def _make_adapter(platform_val="telegram"): adapter.config = MagicMock() adapter.config.extra = {} adapter.platform = MagicMock(value=platform_val) + adapter._text_debounce = {} + adapter._busy_text_debounce_seconds = 0.6 return adapter @@ -186,6 +189,32 @@ class TestBusySessionAck: assert "respond once the current task finishes" in content assert "Interrupting" not in content + @pytest.mark.asyncio + async def test_busy_text_mode_queue_delegates_to_adapter_handle_message(self): + """busy_text_mode=queue lets the adapter debounce text silently.""" + runner, sentinel = _make_runner() + runner._busy_input_mode = "interrupt" + runner._busy_text_mode = "queue" + adapter = _make_adapter() + + first = _make_event(text="part one") + second = _make_event(text="part two") + sk = build_session_key(first.source) + + agent = MagicMock() + runner._running_agents[sk] = agent + runner.adapters[first.source.platform] = adapter + runner.adapters[second.source.platform] = adapter + + result1 = await runner._handle_active_session_busy_message(first, sk) + result2 = await runner._handle_active_session_busy_message(second, sk) + + assert result1 is False + assert result2 is False + assert sk not in adapter._pending_messages + agent.interrupt.assert_not_called() + adapter._send_with_retry.assert_not_called() + @pytest.mark.asyncio async def test_steer_mode_calls_agent_steer_no_interrupt_no_queue(self): """busy_input_mode='steer' injects via agent.steer() and skips queueing.""" diff --git a/tests/gateway/test_command_bypass_active_session.py b/tests/gateway/test_command_bypass_active_session.py index aae68b6b53f..2c0a593dc55 100644 --- a/tests/gateway/test_command_bypass_active_session.py +++ b/tests/gateway/test_command_bypass_active_session.py @@ -47,6 +47,7 @@ def _make_adapter(): """Create a minimal adapter for testing the active-session guard.""" config = PlatformConfig(enabled=True, token="test-token") adapter = _StubAdapter(config, Platform.TELEGRAM) + adapter._busy_text_mode = "" adapter.sent_responses = [] async def _mock_handler(event): diff --git a/tests/gateway/test_config_env_bridge_authority.py b/tests/gateway/test_config_env_bridge_authority.py index 26c54f1c736..a82beb397b9 100644 --- a/tests/gateway/test_config_env_bridge_authority.py +++ b/tests/gateway/test_config_env_bridge_authority.py @@ -45,6 +45,7 @@ def _run_gateway_import(hermes_home: Path, initial_env: dict[str, str]) -> dict[ "HERMES_AGENT_TIMEOUT", "HERMES_AGENT_TIMEOUT_WARNING", "HERMES_GATEWAY_BUSY_INPUT_MODE", + "HERMES_GATEWAY_BUSY_TEXT_MODE", "HERMES_TIMEZONE", ): v = os.environ.get(k) @@ -143,6 +144,15 @@ def test_config_display_busy_input_mode_wins_over_stale_env(hermes_home: Path) - assert env.get("HERMES_GATEWAY_BUSY_INPUT_MODE") == "interrupt" +def test_config_display_busy_text_mode_wins_over_stale_env(hermes_home: Path) -> None: + _write_config(hermes_home, display_cfg={"busy_text_mode": "queue"}) + _write_env(hermes_home, {"HERMES_GATEWAY_BUSY_TEXT_MODE": "interrupt"}) + + env = _run_gateway_import(hermes_home, initial_env={}) + + assert env.get("HERMES_GATEWAY_BUSY_TEXT_MODE") == "queue" + + def test_config_timezone_wins_over_stale_env(hermes_home: Path) -> None: _write_config(hermes_home, timezone="America/Los_Angeles") _write_env(hermes_home, {"HERMES_TIMEZONE": "UTC"}) diff --git a/tests/gateway/test_dingtalk.py b/tests/gateway/test_dingtalk.py index 6b2db13299d..2da55a00979 100644 --- a/tests/gateway/test_dingtalk.py +++ b/tests/gateway/test_dingtalk.py @@ -407,6 +407,36 @@ class TestConnect: assert len(adapter._dedup._seen) == 0 assert adapter._http_client is None + @pytest.mark.asyncio + async def test_disconnect_finalizes_open_streaming_cards(self): + """Streaming cards must be finalized before HTTP client closes.""" + from unittest.mock import AsyncMock, patch + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + adapter._http_client = AsyncMock() + adapter._stream_task = None + adapter._streaming_cards = { + "chat-1": {"track-a": "last content"}, + "chat-2": {"track-b": "other"}, + } + + close_calls = [] + + async def fake_close_siblings(chat_id): + # HTTP client must still be alive at call time. + assert adapter._http_client is not None, ( + "HTTP client was already closed before card finalization" + ) + close_calls.append(chat_id) + adapter._streaming_cards.pop(chat_id, None) + + with patch.object(adapter, "_close_streaming_siblings", side_effect=fake_close_siblings): + await adapter.disconnect() + + assert set(close_calls) == {"chat-1", "chat-2"} + assert adapter._streaming_cards == {} + assert adapter._http_client is None + # --------------------------------------------------------------------------- # Platform enum diff --git a/tests/gateway/test_discord_allowed_mentions.py b/tests/gateway/test_discord_allowed_mentions.py index c717c3cd196..dee9c379a2d 100644 --- a/tests/gateway/test_discord_allowed_mentions.py +++ b/tests/gateway/test_discord_allowed_mentions.py @@ -81,7 +81,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import _build_allowed_mentions # noqa: E402 +from plugins.platforms.discord.adapter import _build_allowed_mentions # noqa: E402 # The four DISCORD_ALLOW_MENTION_* env vars that _build_allowed_mentions reads. diff --git a/tests/gateway/test_discord_attachment_download.py b/tests/gateway/test_discord_attachment_download.py index 06384aead82..5f8f74fd826 100644 --- a/tests/gateway/test_discord_attachment_download.py +++ b/tests/gateway/test_discord_attachment_download.py @@ -58,7 +58,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 from gateway.platforms.base import MessageType # noqa: E402 @@ -146,10 +146,10 @@ class TestCacheDiscordImage: att = _make_attachment_with_read(_PNG_BYTES) with patch( - "gateway.platforms.discord.cache_image_from_bytes", + "plugins.platforms.discord.adapter.cache_image_from_bytes", return_value="/tmp/cached.png", ) as mock_bytes, patch( - "gateway.platforms.discord.cache_image_from_url", + "plugins.platforms.discord.adapter.cache_image_from_url", new_callable=AsyncMock, ) as mock_url: result = await adapter._cache_discord_image(att, ".png") @@ -165,9 +165,9 @@ class TestCacheDiscordImage: att = _make_attachment_without_read() with patch( - "gateway.platforms.discord.cache_image_from_bytes", + "plugins.platforms.discord.adapter.cache_image_from_bytes", ) as mock_bytes, patch( - "gateway.platforms.discord.cache_image_from_url", + "plugins.platforms.discord.adapter.cache_image_from_url", new_callable=AsyncMock, return_value="/tmp/from_url.png", ) as mock_url: @@ -186,10 +186,10 @@ class TestCacheDiscordImage: att = _make_attachment_with_read(b"forbidden") with patch( - "gateway.platforms.discord.cache_image_from_bytes", + "plugins.platforms.discord.adapter.cache_image_from_bytes", side_effect=ValueError("not a valid image"), ), patch( - "gateway.platforms.discord.cache_image_from_url", + "plugins.platforms.discord.adapter.cache_image_from_url", new_callable=AsyncMock, return_value="/tmp/fallback.png", ) as mock_url: @@ -210,10 +210,10 @@ class TestCacheDiscordAudio: att = _make_attachment_with_read(_OGG_BYTES) with patch( - "gateway.platforms.discord.cache_audio_from_bytes", + "plugins.platforms.discord.adapter.cache_audio_from_bytes", return_value="/tmp/voice.ogg", ) as mock_bytes, patch( - "gateway.platforms.discord.cache_audio_from_url", + "plugins.platforms.discord.adapter.cache_audio_from_url", new_callable=AsyncMock, ) as mock_url: result = await adapter._cache_discord_audio(att, ".ogg") @@ -228,7 +228,7 @@ class TestCacheDiscordAudio: att = _make_attachment_without_read() with patch( - "gateway.platforms.discord.cache_audio_from_url", + "plugins.platforms.discord.adapter.cache_audio_from_url", new_callable=AsyncMock, return_value="/tmp/from_url.ogg", ) as mock_url: @@ -267,7 +267,7 @@ class TestCacheDiscordDocument: att = _make_attachment_without_read() # no .read โ†’ forces fallback with patch( - "gateway.platforms.discord.is_safe_url", return_value=False + "plugins.platforms.discord.adapter.is_safe_url", return_value=False ) as mock_safe, patch("aiohttp.ClientSession") as mock_session: with pytest.raises(ValueError, match="SSRF"): await adapter._cache_discord_document(att, ".pdf") @@ -295,7 +295,7 @@ class TestCacheDiscordDocument: session.__aexit__ = AsyncMock(return_value=False) with patch( - "gateway.platforms.discord.is_safe_url", return_value=True + "plugins.platforms.discord.adapter.is_safe_url", return_value=True ), patch("aiohttp.ClientSession", return_value=session): result = await adapter._cache_discord_document(att, ".pdf") @@ -320,10 +320,10 @@ class TestHandleMessageUsesAuthenticatedRead: adapter.handle_message = AsyncMock() with patch( - "gateway.platforms.discord.cache_image_from_bytes", + "plugins.platforms.discord.adapter.cache_image_from_bytes", return_value="/tmp/img_from_read.png", ), patch( - "gateway.platforms.discord.cache_image_from_url", + "plugins.platforms.discord.adapter.cache_image_from_url", new_callable=AsyncMock, ) as mock_url_download: att = SimpleNamespace( @@ -342,7 +342,7 @@ class TestHandleMessageUsesAuthenticatedRead: # Patch the DMChannel isinstance check so our fake counts as DM. monkeypatch.setattr( - "gateway.platforms.discord.discord.DMChannel", + "plugins.platforms.discord.adapter.discord.DMChannel", _FakeDMChannel, ) chan = _FakeDMChannel() @@ -368,7 +368,7 @@ class TestHandleMessageUsesAuthenticatedRead: adapter.handle_message = AsyncMock() with patch( - "gateway.platforms.discord.cache_audio_from_bytes", + "plugins.platforms.discord.adapter.cache_audio_from_bytes", return_value="/tmp/voice_from_read.ogg", ): att = SimpleNamespace( @@ -386,7 +386,7 @@ class TestHandleMessageUsesAuthenticatedRead: name = "dm" monkeypatch.setattr( - "gateway.platforms.discord.discord.DMChannel", + "plugins.platforms.discord.adapter.discord.DMChannel", _FakeDMChannel, ) chan = _FakeDMChannel() @@ -412,7 +412,7 @@ class TestHandleMessageUsesAuthenticatedRead: adapter.handle_message = AsyncMock() with patch( - "gateway.platforms.discord.cache_audio_from_bytes", + "plugins.platforms.discord.adapter.cache_audio_from_bytes", return_value="/tmp/audio_from_read.ogg", ): att = SimpleNamespace( @@ -430,7 +430,7 @@ class TestHandleMessageUsesAuthenticatedRead: name = "dm" monkeypatch.setattr( - "gateway.platforms.discord.discord.DMChannel", + "plugins.platforms.discord.adapter.discord.DMChannel", _FakeDMChannel, ) chan = _FakeDMChannel() diff --git a/tests/gateway/test_discord_bot_auth_bypass.py b/tests/gateway/test_discord_bot_auth_bypass.py index 8ff39a1bf49..7d86e034eb3 100644 --- a/tests/gateway/test_discord_bot_auth_bypass.py +++ b/tests/gateway/test_discord_bot_auth_bypass.py @@ -172,42 +172,49 @@ def test_bot_bypass_does_not_leak_to_other_platforms(monkeypatch): # ----------------------------------------------------------------------------- -# DISCORD_ALLOWED_ROLES gateway-layer bypass (#7871) +# DISCORD_ALLOWED_ROLES no longer bypasses the gateway allowlist (#30742) +# +# Prior behavior: setting DISCORD_ALLOWED_ROLES caused _is_user_authorized +# to return True for ANY Discord event, on the assumption that the adapter +# pre-filter had already validated role membership. That allowed slash +# commands and synthetic voice events to bypass role checks. PR #30742 +# removed the shortcut โ€” Discord auth now flows through the same allowlist +# / pairing / allow-all path as every other platform. # ----------------------------------------------------------------------------- -def test_discord_role_config_bypasses_gateway_allowlist(monkeypatch): - """When DISCORD_ALLOWED_ROLES is set, _is_user_authorized must trust - the adapter's pre-filter and authorize. Without this, role-only setups - (DISCORD_ALLOWED_ROLES populated, DISCORD_ALLOWED_USERS empty) would - hit the 'no allowlists configured' branch and get rejected. +def test_discord_role_config_does_not_bypass_gateway_allowlist(monkeypatch): + """DISCORD_ALLOWED_ROLES alone must NOT authorize at the gateway layer + (regression guard for #30742). Role-based access is enforced by the + adapter pre-filter on real message events; the gateway layer requires + an explicit allowlist hit or pairing approval. """ runner = _make_bare_runner() monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674") - # Note: DISCORD_ALLOWED_USERS is NOT set โ€” the entire point. + # DISCORD_ALLOWED_USERS deliberately NOT set โ€” verifies the role + # config alone no longer grants authorization. source = _make_discord_human_source(user_id="999888777") - assert runner._is_user_authorized(source) is True + assert runner._is_user_authorized(source) is False -def test_discord_role_config_still_authorizes_alongside_users(monkeypatch): - """Sanity: setting both DISCORD_ALLOWED_ROLES and DISCORD_ALLOWED_USERS - doesn't break the user-id path. Users in the allowlist should still be - authorized even if they don't have a role. (OR semantics.) +def test_discord_user_allowlist_still_authorizes_when_role_is_also_configured(monkeypatch): + """Sanity: DISCORD_ALLOWED_USERS still authorizes users on the list, + independent of DISCORD_ALLOWED_ROLES. This guards against a future + regression that ties the user-allowlist check to the (now-removed) + role bypass. """ runner = _make_bare_runner() monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674") monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300") - # User on the user allowlist, no role โ†’ still authorized at gateway - # level via the role bypass (adapter already approved them). source = _make_discord_human_source(user_id="100200300") assert runner._is_user_authorized(source) is True -def test_discord_role_bypass_does_not_leak_to_other_platforms(monkeypatch): +def test_discord_role_config_does_not_leak_to_other_platforms(monkeypatch): """DISCORD_ALLOWED_ROLES must only affect Discord. Setting it should not suddenly start authorizing Telegram users whose platform has its own empty allowlist. diff --git a/tests/gateway/test_discord_channel_controls.py b/tests/gateway/test_discord_channel_controls.py index dc7971529a1..3142ef839d7 100644 --- a/tests/gateway/test_discord_channel_controls.py +++ b/tests/gateway/test_discord_channel_controls.py @@ -45,8 +45,8 @@ def _ensure_discord_mock(): _ensure_discord_mock() -import gateway.platforms.discord as discord_platform # noqa: E402 -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +import plugins.platforms.discord.adapter as discord_platform # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 class FakeDMChannel: diff --git a/tests/gateway/test_discord_channel_prompts.py b/tests/gateway/test_discord_channel_prompts.py index e1efd734dc0..378e0f19a0b 100644 --- a/tests/gateway/test_discord_channel_prompts.py +++ b/tests/gateway/test_discord_channel_prompts.py @@ -58,7 +58,7 @@ def _install_fake_agent(monkeypatch): def _make_adapter(): _ensure_discord_mock() - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter adapter = object.__new__(DiscordAdapter) adapter.config = MagicMock() diff --git a/tests/gateway/test_discord_channel_skills.py b/tests/gateway/test_discord_channel_skills.py index 26c75f0a9f7..33c469df60d 100644 --- a/tests/gateway/test_discord_channel_skills.py +++ b/tests/gateway/test_discord_channel_skills.py @@ -5,7 +5,7 @@ import pytest def _make_adapter(): """Create a minimal DiscordAdapter with mocked config.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter adapter = object.__new__(DiscordAdapter) adapter.config = MagicMock() adapter.config.extra = {} diff --git a/tests/gateway/test_discord_clarify_buttons.py b/tests/gateway/test_discord_clarify_buttons.py index b6e21f1f44b..04f20195f46 100644 --- a/tests/gateway/test_discord_clarify_buttons.py +++ b/tests/gateway/test_discord_clarify_buttons.py @@ -26,7 +26,7 @@ if _repo not in sys.path: # Triggers the shared discord mock from tests/gateway/conftest.py before # importing the production module. -from gateway.platforms.discord import ( # noqa: E402 +from plugins.platforms.discord.adapter import ( # noqa: E402 ClarifyChoiceView, DiscordAdapter, ) diff --git a/tests/gateway/test_discord_component_auth.py b/tests/gateway/test_discord_component_auth.py index 5758e82561e..95d746b80ee 100644 --- a/tests/gateway/test_discord_component_auth.py +++ b/tests/gateway/test_discord_component_auth.py @@ -18,7 +18,7 @@ import pytest # Trigger the shared discord mock from tests/gateway/conftest.py before # importing the production module. -from gateway.platforms.discord import ( # noqa: E402 +from plugins.platforms.discord.adapter import ( # noqa: E402 ExecApprovalView, ModelPickerView, SlashConfirmView, diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index 43f88bcf9da..54dc903e971 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -67,8 +67,8 @@ def _ensure_discord_mock(): _ensure_discord_mock() -import gateway.platforms.discord as discord_platform # noqa: E402 -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +import plugins.platforms.discord.adapter as discord_platform # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 @pytest.fixture(autouse=True) diff --git a/tests/gateway/test_discord_document_handling.py b/tests/gateway/test_discord_document_handling.py index 0685b69663a..7b75c4a07f6 100644 --- a/tests/gateway/test_discord_document_handling.py +++ b/tests/gateway/test_discord_document_handling.py @@ -57,8 +57,8 @@ def _ensure_discord_mock(): _ensure_discord_mock() -import gateway.platforms.discord as discord_platform # noqa: E402 -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +import plugins.platforms.discord.adapter as discord_platform # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 # --------------------------------------------------------------------------- @@ -371,7 +371,7 @@ class TestIncomingDocumentHandling: async def test_image_attachment_unaffected(self, adapter): """Image attachments should still go through the image path, not the document path.""" with patch( - "gateway.platforms.discord.cache_image_from_url", + "plugins.platforms.discord.adapter.cache_image_from_url", new_callable=AsyncMock, return_value="/tmp/cached_image.png", ): diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index c69af3e7781..554288812b7 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -45,8 +45,8 @@ def _ensure_discord_mock(): _ensure_discord_mock() -import gateway.platforms.discord as discord_platform # noqa: E402 -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +import plugins.platforms.discord.adapter as discord_platform # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 class FakeDMChannel: diff --git a/tests/gateway/test_discord_imports.py b/tests/gateway/test_discord_imports.py index bbda79c9ece..7246b4f09a4 100644 --- a/tests/gateway/test_discord_imports.py +++ b/tests/gateway/test_discord_imports.py @@ -14,10 +14,13 @@ class TestDiscordImportSafety: raise ImportError("discord unavailable for test") return original_import(name, globals, locals, fromlist, level) - monkeypatch.delitem(sys.modules, "gateway.platforms.discord", raising=False) + # Purge the cached module so the import below actually re-runs the + # module body with discord.py simulated-missing. + monkeypatch.delitem(sys.modules, "plugins.platforms.discord.adapter", raising=False) + monkeypatch.delitem(sys.modules, "plugins.platforms.discord", raising=False) monkeypatch.setattr(builtins, "__import__", fake_import) - module = importlib.import_module("gateway.platforms.discord") + module = importlib.import_module("plugins.platforms.discord.adapter") assert module.DISCORD_AVAILABLE is False assert module.discord is None diff --git a/tests/gateway/test_discord_lazy_install_views.py b/tests/gateway/test_discord_lazy_install_views.py index 62f2b974e02..2ed926e0f8f 100644 --- a/tests/gateway/test_discord_lazy_install_views.py +++ b/tests/gateway/test_discord_lazy_install_views.py @@ -34,7 +34,7 @@ class TestDefineDiscordViewClasses: def test_registers_all_five_view_classes(self, monkeypatch): """Calling _define_discord_view_classes() must (re)define all 5 view classes.""" - dp = importlib.import_module("gateway.platforms.discord") + dp = importlib.import_module("plugins.platforms.discord.adapter") # Remove the classes to simulate the state where the module was loaded # with DISCORD_AVAILABLE=False (the lazy-install scenario). @@ -54,7 +54,7 @@ class TestDefineDiscordViewClasses: def test_check_discord_requirements_calls_define_on_lazy_install(self, monkeypatch): """check_discord_requirements() must call _define_discord_view_classes() on a successful lazy install so view classes exist when DISCORD_AVAILABLE=True.""" - dp = importlib.import_module("gateway.platforms.discord") + dp = importlib.import_module("plugins.platforms.discord.adapter") # Simulate discord not yet available at module load. monkeypatch.setattr(dp, "DISCORD_AVAILABLE", False) diff --git a/tests/gateway/test_discord_media_metadata.py b/tests/gateway/test_discord_media_metadata.py index a98ac4fc043..966700b700d 100644 --- a/tests/gateway/test_discord_media_metadata.py +++ b/tests/gateway/test_discord_media_metadata.py @@ -1,6 +1,6 @@ import inspect -from gateway.platforms.discord import DiscordAdapter +from plugins.platforms.discord.adapter import DiscordAdapter def test_discord_media_methods_accept_metadata_kwarg(): diff --git a/tests/gateway/test_discord_model_picker.py b/tests/gateway/test_discord_model_picker.py index a1ff434bd37..2ee4e86a38d 100644 --- a/tests/gateway/test_discord_model_picker.py +++ b/tests/gateway/test_discord_model_picker.py @@ -11,7 +11,7 @@ from unittest.mock import AsyncMock import pytest -from gateway.platforms.discord import ModelPickerView +from plugins.platforms.discord.adapter import ModelPickerView @pytest.mark.asyncio diff --git a/tests/gateway/test_discord_opus.py b/tests/gateway/test_discord_opus.py index ef66cde004d..63bef5acaf5 100644 --- a/tests/gateway/test_discord_opus.py +++ b/tests/gateway/test_discord_opus.py @@ -8,14 +8,14 @@ class TestOpusFindLibrary: def test_uses_find_library_first(self): """find_library must be the primary lookup strategy.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter source = inspect.getsource(DiscordAdapter.connect) assert "find_library" in source, \ "Opus loading must use ctypes.util.find_library" def test_homebrew_fallback_is_conditional(self): """Homebrew paths must only be tried when find_library returns None.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter source = inspect.getsource(DiscordAdapter.connect) # Homebrew fallback must exist assert "/opt/homebrew" in source or "homebrew" in source, \ @@ -31,7 +31,7 @@ class TestOpusFindLibrary: def test_opus_decode_error_logged(self): """Opus decode failure must log the error, not silently return.""" - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver source = inspect.getsource(VoiceReceiver._on_packet) assert "logger" in source, \ "_on_packet must log Opus decode errors" diff --git a/tests/gateway/test_discord_race_polish.py b/tests/gateway/test_discord_race_polish.py index 02c927e370f..5f86150921f 100644 --- a/tests/gateway/test_discord_race_polish.py +++ b/tests/gateway/test_discord_race_polish.py @@ -10,7 +10,7 @@ from gateway.config import Platform, PlatformConfig def _make_adapter(): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter adapter = object.__new__(DiscordAdapter) adapter._platform = Platform.DISCORD @@ -60,7 +60,7 @@ async def test_concurrent_joins_do_not_double_connect(): channel.guild.id = 42 channel.connect = lambda: slow_connect(channel) - from gateway.platforms import discord as discord_mod + from plugins.platforms.discord import adapter as discord_mod with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))): with patch.object(discord_mod.asyncio, "ensure_future", diff --git a/tests/gateway/test_discord_reactions.py b/tests/gateway/test_discord_reactions.py index 2d7b2a2c934..e968b750ea3 100644 --- a/tests/gateway/test_discord_reactions.py +++ b/tests/gateway/test_discord_reactions.py @@ -40,7 +40,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 class FakeTree: diff --git a/tests/gateway/test_discord_reply_mode.py b/tests/gateway/test_discord_reply_mode.py index 64e27a27aa8..d113af2e6a2 100644 --- a/tests/gateway/test_discord_reply_mode.py +++ b/tests/gateway/test_discord_reply_mode.py @@ -53,7 +53,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 @pytest.fixture() diff --git a/tests/gateway/test_discord_roles_dm_scope.py b/tests/gateway/test_discord_roles_dm_scope.py index 0f10ba79ae1..ee2939aae3b 100644 --- a/tests/gateway/test_discord_roles_dm_scope.py +++ b/tests/gateway/test_discord_roles_dm_scope.py @@ -20,7 +20,7 @@ from unittest.mock import MagicMock import pytest -from gateway.platforms.discord import DiscordAdapter +from plugins.platforms.discord.adapter import DiscordAdapter def _set_dm_role_auth_guild(monkeypatch, guild_id=None): diff --git a/tests/gateway/test_discord_send.py b/tests/gateway/test_discord_send.py index 03f442a3b88..cd2950f9fbb 100644 --- a/tests/gateway/test_discord_send.py +++ b/tests/gateway/test_discord_send.py @@ -42,7 +42,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 @pytest.mark.asyncio diff --git a/tests/gateway/test_discord_slash_auth.py b/tests/gateway/test_discord_slash_auth.py index e51f240e3aa..39d06ba74fb 100644 --- a/tests/gateway/test_discord_slash_auth.py +++ b/tests/gateway/test_discord_slash_auth.py @@ -85,7 +85,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 @pytest.fixture(autouse=True) diff --git a/tests/gateway/test_discord_slash_commands.py b/tests/gateway/test_discord_slash_commands.py index 589e8053bc1..d5ed297faad 100644 --- a/tests/gateway/test_discord_slash_commands.py +++ b/tests/gateway/test_discord_slash_commands.py @@ -75,7 +75,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 class FakeTree: diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py index b6be0a66832..75237f6403f 100644 --- a/tests/gateway/test_discord_thread_persistence.py +++ b/tests/gateway/test_discord_thread_persistence.py @@ -17,7 +17,7 @@ class TestDiscordThreadPersistence: def _make_adapter(self, tmp_path): """Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path.""" from gateway.config import PlatformConfig - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter config = PlatformConfig(enabled=True, token="test-token") with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): diff --git a/tests/gateway/test_fast_command.py b/tests/gateway/test_fast_command.py index c904b659d1b..58db9faf05e 100644 --- a/tests/gateway/test_fast_command.py +++ b/tests/gateway/test_fast_command.py @@ -148,6 +148,15 @@ async def test_run_agent_passes_priority_processing_to_gateway_agent(monkeypatch monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env") monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None) monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {}) + # ``_load_service_tier`` was refactored to call ``_load_gateway_runtime_config`` + # (which wraps ``_load_gateway_config`` plus env-expansion). Since the test + # stubs ``_load_gateway_config`` to ``{}``, also stub the runtime wrapper + # directly so the priority routing assertions still exercise the live tier. + monkeypatch.setattr( + gateway_run, + "_load_gateway_runtime_config", + lambda: {"agent": {"service_tier": "fast"}}, + ) monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4") monkeypatch.setattr( gateway_run, diff --git a/tests/gateway/test_feishu.py b/tests/gateway/test_feishu.py index 63287d88cb4..75f61923956 100644 --- a/tests/gateway/test_feishu.py +++ b/tests/gateway/test_feishu.py @@ -167,6 +167,7 @@ class TestFeishuAdapterMessaging(unittest.TestCase): "FEISHU_WEBHOOK_HOST": "127.0.0.1", "FEISHU_WEBHOOK_PORT": "9001", "FEISHU_WEBHOOK_PATH": "/hook", + "FEISHU_VERIFICATION_TOKEN": "vtok", }, clear=True) def test_connect_webhook_mode_starts_local_server(self): from gateway.config import PlatformConfig @@ -1538,6 +1539,34 @@ class TestAdapterBehavior(unittest.TestCase): self.assertEqual(response.status, 200) adapter._on_message_event.assert_called_once() + @patch.dict(os.environ, {"FEISHU_VERIFICATION_TOKEN": "expected-token"}, clear=True) + def test_url_verification_requires_configured_verification_token(self): + """url_verification must be rejected when token is set but mismatched. + + Regression: previously the challenge was reflected before the token + check, so an unauthenticated remote could prove endpoint control by + sending an attacker-controlled challenge string. + """ + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter(PlatformConfig()) + body = json.dumps({ + "type": "url_verification", + "token": "wrong-token", + "challenge": "attacker-controlled-challenge", + }).encode("utf-8") + request = SimpleNamespace( + remote="203.0.113.10", + content_length=None, + headers={}, + read=AsyncMock(return_value=body), + ) + + response = asyncio.run(adapter._handle_webhook_request(request)) + + self.assertEqual(response.status, 401) + @patch.dict(os.environ, {}, clear=True) def test_process_inbound_message_uses_event_sender_identity_only(self): from gateway.config import PlatformConfig @@ -3191,6 +3220,39 @@ class TestWebhookSecurity(unittest.TestCase): response = asyncio.run(adapter._handle_webhook_request(request)) self.assertEqual(response.status, 401) + @patch.dict(os.environ, {}, clear=True) + def test_webhook_connect_requires_inbound_auth_secret(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter( + PlatformConfig( + enabled=True, + extra={"app_id": "cli_app", "app_secret": "secret_app", "connection_mode": "webhook"}, + ) + ) + self.assertFalse(asyncio.run(adapter.connect())) + + @patch.dict(os.environ, {}, clear=True) + def test_webhook_loads_auth_secrets_from_platform_extra(self): + from gateway.config import PlatformConfig + from gateway.platforms.feishu import FeishuAdapter + + adapter = FeishuAdapter( + PlatformConfig( + enabled=True, + extra={ + "app_id": "cli_app", + "app_secret": "secret_app", + "connection_mode": "webhook", + "verification_token": "token_from_extra", + "encrypt_key": "encrypt_from_extra", + }, + ) + ) + self.assertEqual(adapter._verification_token, "token_from_extra") + self.assertEqual(adapter._encrypt_key, "encrypt_from_extra") + @patch.dict(os.environ, {}, clear=True) def test_webhook_url_verification_challenge_passes_without_signature(self): """Challenge requests must succeed even when no encrypt_key is set.""" diff --git a/tests/gateway/test_feishu_approval_buttons.py b/tests/gateway/test_feishu_approval_buttons.py index 8af56913c10..e739d47b087 100644 --- a/tests/gateway/test_feishu_approval_buttons.py +++ b/tests/gateway/test_feishu_approval_buttons.py @@ -320,7 +320,7 @@ class TestResolveApproval: } with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: - await adapter._resolve_approval(1, "once", "Norbert") + await adapter._resolve_approval(1, "once", "Norbert", open_id="ou_user1", chat_id="oc_12345") mock_resolve.assert_called_once_with("agent:main:feishu:group:oc_12345", "once") assert 1 not in adapter._approval_state @@ -335,7 +335,7 @@ class TestResolveApproval: } with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: - await adapter._resolve_approval(2, "deny", "Alice") + await adapter._resolve_approval(2, "deny", "Alice", open_id="ou_user1", chat_id="oc_12345") mock_resolve.assert_called_once_with("some-session", "deny") @@ -349,7 +349,7 @@ class TestResolveApproval: } with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: - await adapter._resolve_approval(3, "session", "Bob") + await adapter._resolve_approval(3, "session", "Bob", open_id="ou_user1", chat_id="oc_99") mock_resolve.assert_called_once_with("sess-3", "session") @@ -363,7 +363,7 @@ class TestResolveApproval: } with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve: - await adapter._resolve_approval(4, "always", "Carol") + await adapter._resolve_approval(4, "always", "Carol", open_id="ou_user1", chat_id="oc_55") mock_resolve.assert_called_once_with("sess-4", "always") @@ -372,10 +372,41 @@ class TestResolveApproval: adapter = _make_adapter() with patch("tools.approval.resolve_gateway_approval") as mock_resolve: - await adapter._resolve_approval(99, "once", "Nobody") + await adapter._resolve_approval(99, "once", "Nobody", open_id="ou_user1", chat_id="oc_12345") mock_resolve.assert_not_called() + @pytest.mark.asyncio + async def test_unauthorized_click_does_not_resolve(self): + adapter = _make_adapter() + adapter._admins = {"ou_admin"} + adapter._approval_state[5] = { + "session_key": "sess-5", + "message_id": "msg_005", + "chat_id": "oc_12345", + } + + with patch("tools.approval.resolve_gateway_approval") as mock_resolve: + await adapter._resolve_approval(5, "once", "Mallory", open_id="ou_intruder", chat_id="oc_12345") + + mock_resolve.assert_not_called() + assert 5 in adapter._approval_state + + @pytest.mark.asyncio + async def test_chat_mismatch_does_not_resolve(self): + adapter = _make_adapter() + adapter._approval_state[6] = { + "session_key": "sess-6", + "message_id": "msg_006", + "chat_id": "oc_expected", + } + + with patch("tools.approval.resolve_gateway_approval") as mock_resolve: + await adapter._resolve_approval(6, "session", "Norbert", open_id="ou_user1", chat_id="oc_wrong") + + mock_resolve.assert_not_called() + assert 6 in adapter._approval_state + # =========================================================================== # _handle_card_action_event โ€” non-approval card actions # =========================================================================== @@ -448,6 +479,12 @@ class TestCardActionCallbackResponse: adapter = _make_adapter() adapter._loop = MagicMock() adapter._loop.is_closed = MagicMock(return_value=False) + adapter._allowed_group_users = {"ou_bob"} + adapter._approval_state[1] = { + "session_key": "sess-1", + "message_id": "msg-1", + "chat_id": "oc_12345", + } data = _make_card_action_data( {"hermes_action": "approve_once", "approval_id": 1}, open_id="ou_bob", @@ -469,6 +506,12 @@ class TestCardActionCallbackResponse: adapter = _make_adapter() adapter._loop = MagicMock() adapter._loop.is_closed = MagicMock(return_value=False) + adapter._allowed_group_users = {"ou_user1"} + adapter._approval_state[2] = { + "session_key": "sess-2", + "message_id": "msg-2", + "chat_id": "oc_12345", + } data = _make_card_action_data( {"hermes_action": "deny", "approval_id": 2}, ) @@ -510,6 +553,12 @@ class TestCardActionCallbackResponse: adapter = _make_adapter() adapter._loop = MagicMock() adapter._loop.is_closed = MagicMock(return_value=False) + adapter._allowed_group_users = {"ou_unknown"} + adapter._approval_state[3] = { + "session_key": "sess-3", + "message_id": "msg-3", + "chat_id": "oc_12345", + } data = _make_card_action_data( {"hermes_action": "approve_session", "approval_id": 3}, open_id="ou_unknown", @@ -525,6 +574,12 @@ class TestCardActionCallbackResponse: adapter = _make_adapter() adapter._loop = MagicMock() adapter._loop.is_closed = MagicMock(return_value=False) + adapter._allowed_group_users = {"ou_expired"} + adapter._approval_state[4] = { + "session_key": "sess-4", + "message_id": "msg-4", + "chat_id": "oc_12345", + } data = _make_card_action_data( {"hermes_action": "approve_once", "approval_id": 4}, open_id="ou_expired", @@ -538,6 +593,51 @@ class TestCardActionCallbackResponse: assert "Old Name" not in card["elements"][0]["content"] assert "ou_expired" in card["elements"][0]["content"] + def test_rejects_approval_click_from_unauthorized_user(self, _patch_callback_card_types): + adapter = _make_adapter() + adapter._loop = MagicMock() + adapter._loop.is_closed = MagicMock(return_value=False) + adapter._allowed_group_users = {"ou_allowed"} + adapter._approval_state[5] = { + "session_key": "sess-5", + "message_id": "msg-5", + "chat_id": "oc_12345", + } + data = _make_card_action_data( + {"hermes_action": "approve_once", "approval_id": 5}, + open_id="ou_attacker", + ) + + with patch("asyncio.run_coroutine_threadsafe") as mock_submit: + response = adapter._on_card_action_trigger(data) + + assert response is not None + assert response.card is None + mock_submit.assert_not_called() + + def test_rejects_approval_click_when_callback_chat_mismatches(self, _patch_callback_card_types): + adapter = _make_adapter() + adapter._loop = MagicMock() + adapter._loop.is_closed = MagicMock(return_value=False) + adapter._allowed_group_users = {"ou_bob"} + adapter._approval_state[6] = { + "session_key": "sess-6", + "message_id": "msg-6", + "chat_id": "oc_expected", + } + data = _make_card_action_data( + {"hermes_action": "approve_once", "approval_id": 6}, + chat_id="oc_mismatch", + open_id="ou_bob", + ) + + with patch("asyncio.run_coroutine_threadsafe") as mock_submit: + response = adapter._on_card_action_trigger(data) + + assert response is not None + assert response.card is None + mock_submit.assert_not_called() + def test_returns_card_for_update_prompt_yes(self, _patch_callback_card_types): adapter = _make_adapter() adapter._loop = MagicMock() diff --git a/tests/gateway/test_interrupt_key_match.py b/tests/gateway/test_interrupt_key_match.py index 445a16f7a19..3a703c0261d 100644 --- a/tests/gateway/test_interrupt_key_match.py +++ b/tests/gateway/test_interrupt_key_match.py @@ -103,6 +103,7 @@ class TestInterruptKeyConsistency: async def test_handle_message_stores_under_session_key(self): """handle_message stores pending messages under session_key, not chat_id.""" adapter = StubAdapter() + adapter._busy_text_mode = "" adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None)) source = _source("-1001234", "group") @@ -120,8 +121,8 @@ class TestInterruptKeyConsistency: # NOT stored under chat_id assert source.chat_id not in adapter._pending_messages - # Interrupt event was set - assert adapter._active_sessions[session_key].is_set() + # Text follow-ups queue silently and do not interrupt the active turn. + assert adapter._active_sessions[session_key].is_set() is False @pytest.mark.asyncio async def test_photo_followup_is_queued_without_interrupt(self): diff --git a/tests/gateway/test_loop_exception_handler.py b/tests/gateway/test_loop_exception_handler.py new file mode 100644 index 00000000000..66ba4d94304 --- /dev/null +++ b/tests/gateway/test_loop_exception_handler.py @@ -0,0 +1,210 @@ +"""Tests for the gateway loop-level transient-network-error safety net. + +Issues #31066 / #31110: unhandled ``telegram.error.TimedOut`` (or peer +``NetworkError`` / ``httpx`` connection error) propagating to the +asyncio event loop killed the gateway process, taking down every +profile attached to the same runner. The safety net installed in +:func:`gateway.run.start_gateway` catches the transient crash class +and logs+swallows it; non-transient errors still surface. + +These tests pin the classifier and the loop handler so the safety net +can't silently regress to swallowing every exception. +""" + +from __future__ import annotations + +import asyncio +import logging + +import pytest + +from gateway.run import ( + _gateway_loop_exception_handler, + _is_transient_network_error, +) + + +# ----- Fake exception classes that mimic the real wire types ---------- +# We avoid importing telegram / httpx here so the test runs in environments +# without those packages installed (the classifier matches on class name). + +class TimedOut(Exception): + """Stand-in for ``telegram.error.TimedOut``.""" + + +class NetworkError(Exception): + """Stand-in for ``telegram.error.NetworkError``.""" + + +class ConnectError(Exception): + """Stand-in for ``httpx.ConnectError``.""" + + +class ReadTimeout(Exception): + """Stand-in for ``httpx.ReadTimeout``.""" + + +class PoolTimeout(Exception): + """Stand-in for ``httpx.PoolTimeout``.""" + + +class ClientConnectorError(Exception): + """Stand-in for ``aiohttp.ClientConnectorError``.""" + + +class SomeUnrelatedBug(Exception): + """A non-transient error that should NOT be swallowed.""" + + +# --------------------------------------------------------------------- +# Classifier +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "exc_cls", + [ + TimedOut, + NetworkError, + ConnectError, + ReadTimeout, + PoolTimeout, + ClientConnectorError, + ], +) +def test_transient_classifier_matches_known_network_errors(exc_cls): + """Every well-known transient network exception class is classified.""" + assert _is_transient_network_error(exc_cls("boom")) is True + + +def test_transient_classifier_rejects_unrelated_errors(): + """Real bugs (ValueError, KeyError, custom app errors) are NOT swallowed.""" + for exc in (ValueError("bad"), KeyError("missing"), SomeUnrelatedBug("x")): + assert _is_transient_network_error(exc) is False + + +def test_transient_classifier_unwraps_cause_chain(): + """A NetworkError wrapping a ConnectError is still classified.""" + inner = ConnectError("connection refused") + outer = NetworkError("upstream failed") + outer.__cause__ = inner + assert _is_transient_network_error(outer) is True + + +def test_transient_classifier_unwraps_context_chain(): + """Implicit ``__context__`` wrapping is also unwrapped.""" + try: + try: + raise TimedOut("upstream timeout") + except TimedOut: + # Re-raise something else with the original as implicit context + raise SomeUnrelatedBug("wrapper") + except SomeUnrelatedBug as e: + wrapped = e + # The wrapper class name is not transient, but the chained context is. + assert _is_transient_network_error(wrapped) is True + + +def test_transient_classifier_does_not_infinite_loop_on_cyclic_cause(): + """A pathological self-referential cause chain terminates.""" + exc = SomeUnrelatedBug("loop") + exc.__cause__ = exc # cycle + # Must return without hanging. + assert _is_transient_network_error(exc) is False + + +# --------------------------------------------------------------------- +# Loop handler +# --------------------------------------------------------------------- + + +def test_handler_swallows_transient_error_and_logs_warning(caplog): + """Transient errors are logged at WARNING but not re-raised.""" + loop = asyncio.new_event_loop() + try: + with caplog.at_level(logging.WARNING, logger="gateway.run"): + _gateway_loop_exception_handler( + loop, + { + "message": "Task exception was never retrieved", + "exception": TimedOut("Timed out"), + }, + ) + # Warning emitted, exception class name appears in the log. + assert any("TimedOut" in r.message for r in caplog.records) + finally: + loop.close() + + +def test_handler_delegates_unknown_errors_to_default(monkeypatch): + """A non-transient error is forwarded to ``loop.default_exception_handler``.""" + loop = asyncio.new_event_loop() + try: + forwarded: list[dict] = [] + + def fake_default(ctx): + forwarded.append(ctx) + + monkeypatch.setattr(loop, "default_exception_handler", fake_default) + + context = { + "message": "Something else broke", + "exception": SomeUnrelatedBug("real bug"), + } + _gateway_loop_exception_handler(loop, context) + assert forwarded == [context] + finally: + loop.close() + + +def test_handler_tolerates_missing_exception_key(monkeypatch): + """Contexts without an ``exception`` key fall through to the default handler.""" + loop = asyncio.new_event_loop() + try: + forwarded: list[dict] = [] + monkeypatch.setattr( + loop, "default_exception_handler", lambda ctx: forwarded.append(ctx) + ) + ctx = {"message": "warning without exception"} + _gateway_loop_exception_handler(loop, ctx) + assert forwarded == [ctx] + finally: + loop.close() + + +# --------------------------------------------------------------------- +# End-to-end: task-level +# --------------------------------------------------------------------- + + +def test_unhandled_transient_error_in_task_does_not_propagate_to_loop(): + """Smoke test the wiring as a loop would actually use it. + + Schedules a task that raises TimedOut and is never awaited. With the + handler installed, the loop completes normally and logs a warning + instead of dying. Without the handler, asyncio would emit + ``Task exception was never retrieved`` and (depending on Python's + debug mode) potentially escalate. + """ + + async def raiser(): + raise TimedOut("upstream timeout") + + async def main(): + loop = asyncio.get_running_loop() + loop.set_exception_handler(_gateway_loop_exception_handler) + task = loop.create_task(raiser()) + # Give the task a tick to run and raise. + await asyncio.sleep(0) + # Don't await ``task`` โ€” let it become an unhandled-exception task. + del task + import gc + + gc.collect() + await asyncio.sleep(0) + + # If the safety net works, this returns cleanly. If not, the test + # would still pass (asyncio's default is a warning, not a crash) โ€” + # the real assertion is that no unhandled exception escapes the + # ``run`` boundary. + asyncio.run(main()) diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index a0fb8f086d8..c7c03b1a8b1 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -797,6 +797,79 @@ class TestMatrixRequirements: with patch("tools.lazy_deps.ensure", side_effect=ImportError("mautrix unavailable")): assert matrix_mod.check_matrix_requirements() is False + def test_check_e2ee_deps_requires_asyncpg(self, monkeypatch): + """E2EE deps check must reject when asyncpg is missing โ€” even if olm is present. + + Regression for #31116: ``mautrix[encryption]`` extra installs python-olm + but NOT asyncpg/aiosqlite, which are required by mautrix's crypto store + at connect time. ``_check_e2ee_deps`` previously only tested + ``OlmMachine`` import and returned True, so the failure manifested as + a confusing ``No module named 'asyncpg'`` deep in + ``MatrixAdapter.connect()``. + """ + from gateway.platforms.matrix import _check_e2ee_deps + import builtins + real_import = builtins.__import__ + + def _blocking_import(name, *args, **kwargs): + if name == "asyncpg" or name.startswith("asyncpg."): + raise ImportError("blocked for test") + return real_import(name, *args, **kwargs) + + with patch.object(builtins, "__import__", _blocking_import): + assert _check_e2ee_deps() is False + + def test_check_e2ee_deps_requires_aiosqlite(self): + """E2EE deps check must reject when aiosqlite is missing. + + Mautrix's ``Database.create("sqlite:///...")`` driver lookup imports + aiosqlite lazily โ€” without it, connect fails at ``crypto_db.start()``. + """ + from gateway.platforms.matrix import _check_e2ee_deps + import builtins + real_import = builtins.__import__ + + def _blocking_import(name, *args, **kwargs): + if name == "aiosqlite" or name.startswith("aiosqlite."): + raise ImportError("blocked for test") + return real_import(name, *args, **kwargs) + + with patch.object(builtins, "__import__", _blocking_import): + assert _check_e2ee_deps() is False + + def test_check_requirements_runs_lazy_install_when_partial(self, monkeypatch): + """When mautrix is installed but asyncpg/aiosqlite are missing, + check_matrix_requirements must still run the lazy installer. + + Regression for #31116: the previous ``try: import mautrix`` gate + short-circuited the install of the OTHER 4 platform.matrix packages, + so a partial install (mautrix only) was treated as fully installed. + """ + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False) + + from gateway.platforms import matrix as matrix_mod + + # Simulate "mautrix installed, asyncpg missing" โ†’ feature_missing + # returns a non-empty tuple โ†’ ensure_and_bind MUST be called. + called = {"ensure_and_bind": False} + + def _fake_ensure_and_bind(feature, importer, target_globals, **kwargs): + called["ensure_and_bind"] = True + assert feature == "platform.matrix" + return True # Pretend install succeeded. + + with patch("tools.lazy_deps.feature_missing", return_value=("asyncpg==0.31.0",)), \ + patch("tools.lazy_deps.ensure_and_bind", side_effect=_fake_ensure_and_bind): + matrix_mod.check_matrix_requirements() + + assert called["ensure_and_bind"], ( + "check_matrix_requirements must call ensure_and_bind whenever ANY " + "platform.matrix dep is missing, not just when mautrix itself is " + "missing (#31116)" + ) + # --------------------------------------------------------------------------- # Access-token auth / E2EE bootstrap diff --git a/tests/gateway/test_msgraph_webhook.py b/tests/gateway/test_msgraph_webhook.py index d97c98492ae..bddcf419014 100644 --- a/tests/gateway/test_msgraph_webhook.py +++ b/tests/gateway/test_msgraph_webhook.py @@ -6,7 +6,7 @@ import json import pytest from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides -from gateway.platforms.msgraph_webhook import MSGraphWebhookAdapter +from gateway.platforms.msgraph_webhook import AIOHTTP_AVAILABLE, MSGraphWebhookAdapter def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter: @@ -70,6 +70,16 @@ class TestMSGraphWebhookConfig: class TestMSGraphValidationHandshake: + @pytest.mark.anyio + async def test_connect_requires_client_state(self): + if not AIOHTTP_AVAILABLE: + pytest.skip("aiohttp not installed") + adapter = MSGraphWebhookAdapter(PlatformConfig(enabled=True, extra={})) + connected = await adapter.connect() + assert connected is False + # is_connected is a @property on the base adapter, not a method. + assert adapter.is_connected is False + @pytest.mark.anyio async def test_validation_token_echo_on_get(self): adapter = _make_adapter() @@ -99,6 +109,22 @@ class TestMSGraphValidationHandshake: class TestMSGraphNotifications: + @pytest.mark.anyio + async def test_missing_client_state_is_auth_rejected(self): + adapter = _make_adapter(client_state=None) + payload = { + "value": [ + { + "id": "notif-no-client-state", + "subscriptionId": "sub-1", + "changeType": "updated", + "resource": "communications/onlineMeetings/meeting-1", + } + ] + } + resp = await adapter._handle_notification(_FakeRequest(json_payload=payload)) + assert resp.status == 403 + @pytest.mark.anyio async def test_valid_notification_accepted_and_scheduled(self): adapter = _make_adapter() diff --git a/tests/gateway/test_ntfy_plugin.py b/tests/gateway/test_ntfy_plugin.py new file mode 100644 index 00000000000..40cf148de44 --- /dev/null +++ b/tests/gateway/test_ntfy_plugin.py @@ -0,0 +1,943 @@ +"""Tests for the ntfy platform-plugin adapter. + +Loaded via the ``_plugin_adapter_loader`` helper so this lives under +``plugin_adapter_ntfy`` in ``sys.modules`` and cannot collide with +sibling platform-plugin tests on the same xdist worker. + +Most tests target the adapter class directly. The plugin-shape tests +(``register()``, ``_env_enablement``, ``_standalone_send``, registry +presence) replace the core-file grep tests from the original PR โ€” the +ntfy adapter no longer modifies ``gateway/config.py``, ``gateway/run.py``, +``cron/scheduler.py``, ``toolsets.py``, etc. Everything routes through +the ``platform_registry``. +""" + +from __future__ import annotations + +import asyncio +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import PlatformConfig +from tests.gateway._plugin_adapter_loader import load_plugin_adapter + +_ntfy = load_plugin_adapter("ntfy") + +NtfyAdapter = _ntfy.NtfyAdapter +check_requirements = _ntfy.check_requirements +validate_config = _ntfy.validate_config +is_connected = _ntfy.is_connected +register = _ntfy.register +_env_enablement = _ntfy._env_enablement +_standalone_send = _ntfy._standalone_send +DEFAULT_SERVER = _ntfy.DEFAULT_SERVER +DEDUP_WINDOW_SECONDS = _ntfy.DEDUP_WINDOW_SECONDS +DEDUP_MAX_SIZE = _ntfy.DEDUP_MAX_SIZE +MAX_MESSAGE_LENGTH = _ntfy.MAX_MESSAGE_LENGTH + + +def _run(coro): + """Run an async coroutine synchronously.""" + return asyncio.get_event_loop().run_until_complete(coro) + + +# --------------------------------------------------------------------------- +# 1. Platform enum (plugin-discovered, not bundled) +# --------------------------------------------------------------------------- + + +def test_platform_enum_resolves_via_plugin_scan(): + """The plugin filesystem scan should expose Platform("ntfy").""" + from gateway.config import Platform + p = Platform("ntfy") + assert p.value == "ntfy" + # Identity stability โ€” repeated lookups return the same pseudo-member + assert Platform("ntfy") is p + + +# --------------------------------------------------------------------------- +# 2. check_requirements / validate_config / is_connected +# --------------------------------------------------------------------------- + + +class TestNtfyRequirements: + + def test_returns_false_when_httpx_unavailable(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-test") + monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", False) + assert check_requirements() is False + + def test_returns_false_when_topic_not_set(self, monkeypatch): + monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True) + monkeypatch.delenv("NTFY_TOPIC", raising=False) + assert check_requirements() is False + + def test_returns_true_when_topic_set_via_env(self, monkeypatch): + monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True) + monkeypatch.setenv("NTFY_TOPIC", "hermes-test") + assert check_requirements() is True + + def test_validate_config_requires_topic(self, monkeypatch): + monkeypatch.delenv("NTFY_TOPIC", raising=False) + assert validate_config(PlatformConfig(enabled=True, extra={})) is False + assert validate_config( + PlatformConfig(enabled=True, extra={"topic": "t"}) + ) is True + + def test_is_connected_from_extra(self, monkeypatch): + monkeypatch.delenv("NTFY_TOPIC", raising=False) + assert is_connected(PlatformConfig(enabled=True, extra={"topic": "t"})) is True + assert is_connected(PlatformConfig(enabled=True, extra={})) is False + + def test_is_connected_from_env(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "env-topic") + assert is_connected(PlatformConfig(enabled=True, extra={})) is True + + +# --------------------------------------------------------------------------- +# 3. Adapter init +# --------------------------------------------------------------------------- + + +class TestNtfyAdapterInit: + + def test_default_server_url(self, monkeypatch): + monkeypatch.delenv("NTFY_SERVER_URL", raising=False) + config = PlatformConfig(enabled=True, extra={"topic": "hermes-in"}) + adapter = NtfyAdapter(config) + assert adapter._server == DEFAULT_SERVER.rstrip("/") + + def test_topic_read_from_extra(self): + config = PlatformConfig(enabled=True, extra={"topic": "my-topic"}) + adapter = NtfyAdapter(config) + assert adapter._topic == "my-topic" + + def test_topic_read_from_env(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "env-topic") + config = PlatformConfig(enabled=True, extra={}) + adapter = NtfyAdapter(config) + assert adapter._topic == "env-topic" + + def test_publish_topic_falls_back_to_topic(self, monkeypatch): + monkeypatch.delenv("NTFY_PUBLISH_TOPIC", raising=False) + config = PlatformConfig(enabled=True, extra={"topic": "hermes-in"}) + adapter = NtfyAdapter(config) + assert adapter._publish_topic == "hermes-in" + + def test_publish_topic_uses_extra_value(self): + config = PlatformConfig( + enabled=True, + extra={"topic": "hermes-in", "publish_topic": "hermes-out"}, + ) + adapter = NtfyAdapter(config) + assert adapter._publish_topic == "hermes-out" + + def test_token_read_from_extra(self): + config = PlatformConfig(enabled=True, extra={"topic": "t", "token": "tok-123"}) + adapter = NtfyAdapter(config) + assert adapter._token == "tok-123" + + def test_token_read_from_env(self, monkeypatch): + monkeypatch.setenv("NTFY_TOKEN", "env-token") + config = PlatformConfig(enabled=True, extra={"topic": "t"}) + adapter = NtfyAdapter(config) + assert adapter._token == "env-token" + + def test_server_trailing_slash_stripped(self): + config = PlatformConfig( + enabled=True, + extra={"topic": "t", "server": "https://ntfy.example.com/"}, + ) + adapter = NtfyAdapter(config) + assert not adapter._server.endswith("/") + + def test_initial_state(self): + config = PlatformConfig(enabled=True, extra={"topic": "t"}) + adapter = NtfyAdapter(config) + assert adapter._stream_task is None + assert adapter._http_client is None + assert adapter._seen_messages == {} + + +# --------------------------------------------------------------------------- +# 4. Auth headers +# --------------------------------------------------------------------------- + + +class TestAuthHeaders: + + def _make_adapter(self, token=""): + config = PlatformConfig(enabled=True, extra={"topic": "t", "token": token}) + return NtfyAdapter(config) + + def test_no_token_returns_empty_dict(self): + adapter = self._make_adapter(token="") + assert adapter._auth_headers() == {} + + def test_bearer_token_for_plain_token(self): + adapter = self._make_adapter(token="myapitoken") + headers = adapter._auth_headers() + assert headers["Authorization"] == "Bearer myapitoken" + + def test_basic_auth_for_user_colon_password(self): + adapter = self._make_adapter(token="user:pass") + headers = adapter._auth_headers() + assert headers["Authorization"].startswith("Basic ") + import base64 + expected = "Basic " + base64.b64encode(b"user:pass").decode() + assert headers["Authorization"] == expected + + def test_bearer_token_used_when_no_colon(self): + adapter = self._make_adapter(token="noColonHere") + headers = adapter._auth_headers() + assert headers["Authorization"] == "Bearer noColonHere" + + def test_auth_header_key_is_authorization(self): + adapter = self._make_adapter(token="tok") + headers = adapter._auth_headers() + assert list(headers.keys()) == ["Authorization"] + + +# --------------------------------------------------------------------------- +# 5. Deduplication +# --------------------------------------------------------------------------- + + +class TestDeduplication: + + def _make_adapter(self): + return NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + + def test_first_message_not_duplicate(self): + adapter = self._make_adapter() + assert adapter._is_duplicate("msg-1") is False + + def test_second_occurrence_is_duplicate(self): + adapter = self._make_adapter() + adapter._is_duplicate("msg-1") + assert adapter._is_duplicate("msg-1") is True + + def test_different_ids_not_duplicate(self): + adapter = self._make_adapter() + adapter._is_duplicate("msg-1") + assert adapter._is_duplicate("msg-2") is False + + def test_many_messages_recorded(self): + adapter = self._make_adapter() + for i in range(50): + adapter._is_duplicate(f"msg-{i}") + assert len(adapter._seen_messages) == 50 + + def test_cache_pruned_on_overflow(self): + adapter = self._make_adapter() + for i in range(DEDUP_MAX_SIZE + 20): + adapter._is_duplicate(f"msg-{i}") + assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 20 + + def test_expired_id_can_be_seen_again(self): + import time + adapter = self._make_adapter() + adapter._seen_messages["old-msg"] = time.time() - DEDUP_WINDOW_SECONDS - 1 + for i in range(DEDUP_MAX_SIZE + 1): + adapter._is_duplicate(f"fill-{i}") + assert adapter._is_duplicate("old-msg") is False + + +# --------------------------------------------------------------------------- +# 6. connect() / disconnect() +# --------------------------------------------------------------------------- + + +class TestConnect: + + def test_connect_fails_when_httpx_unavailable(self, monkeypatch): + monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", False) + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + result = _run(adapter.connect()) + assert result is False + + def test_connect_fails_when_no_topic(self, monkeypatch): + monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True) + monkeypatch.delenv("NTFY_TOPIC", raising=False) + config = PlatformConfig(enabled=True, extra={}) + adapter = NtfyAdapter(config) + result = _run(adapter.connect()) + assert result is False + + def test_connect_starts_stream_task(self, monkeypatch): + monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True) + config = PlatformConfig(enabled=True, extra={"topic": "hermes-test"}) + adapter = NtfyAdapter(config) + + with patch.object(adapter, "_run_stream", new_callable=AsyncMock): + with patch.object(_ntfy, "httpx") as mock_httpx: + mock_httpx.AsyncClient.return_value = MagicMock() + result = _run(adapter.connect()) + + assert result is True + assert adapter._stream_task is not None + adapter._stream_task.cancel() + try: + _run(adapter._stream_task) + except (asyncio.CancelledError, Exception): + pass + + def test_disconnect_clears_state(self): + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + adapter._seen_messages["x"] = 1.0 + adapter._http_client = AsyncMock() + adapter._stream_task = None + adapter._running = True + + _run(adapter.disconnect()) + + assert adapter._seen_messages == {} + assert adapter._http_client is None + assert adapter._running is False + + def test_disconnect_cancels_stream_task(self): + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + + async def _hang(): + await asyncio.sleep(9999) + + loop = asyncio.get_event_loop() + adapter._stream_task = loop.create_task(_hang()) + adapter._http_client = AsyncMock() + adapter._running = True + + _run(adapter.disconnect()) + assert adapter._stream_task is None + + +# --------------------------------------------------------------------------- +# 7. send() +# --------------------------------------------------------------------------- + + +class TestSend: + + def _make_adapter(self, topic="hermes-in", publish_topic="", token="", markdown=False): + extra: dict = {"topic": topic, "token": token} + if publish_topic: + extra["publish_topic"] = publish_topic + if markdown: + extra["markdown"] = True + return NtfyAdapter(PlatformConfig(enabled=True, extra=extra)) + + def test_send_fails_without_http_client(self): + adapter = self._make_adapter() + result = _run(adapter.send("hermes-in", "hello")) + assert result.success is False + assert "not initialized" in result.error.lower() + + def test_send_posts_to_publish_topic(self): + adapter = self._make_adapter(topic="hermes-in", publish_topic="hermes-out") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"id": "abc123"} + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + result = _run(adapter.send("hermes-in", "Hello ntfy!")) + assert result.success is True + assert result.message_id == "abc123" + + posted_url = mock_client.post.call_args[0][0] + assert posted_url.endswith("/hermes-out") + + def test_send_falls_back_to_subscribe_topic(self): + adapter = self._make_adapter(topic="hermes-in") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + result = _run(adapter.send("hermes-in", "Hello!")) + assert result.success is True + posted_url = mock_client.post.call_args[0][0] + assert posted_url.endswith("/hermes-in") + + def test_send_uses_metadata_publish_topic(self): + adapter = self._make_adapter(topic="hermes-in") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + result = _run(adapter.send( + "hermes-in", "Hi!", metadata={"publish_topic": "override-out"} + )) + assert result.success is True + posted_url = mock_client.post.call_args[0][0] + assert posted_url.endswith("/override-out") + + def test_send_handles_http_error_status(self): + adapter = self._make_adapter(topic="hermes-in") + + mock_resp = MagicMock() + mock_resp.status_code = 403 + mock_resp.text = "Forbidden" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + result = _run(adapter.send("hermes-in", "Hello!")) + assert result.success is False + assert "403" in result.error + + def test_send_handles_timeout(self): + adapter = self._make_adapter(topic="hermes-in") + + class _FakeTimeout(Exception): + pass + + fake_httpx = MagicMock() + fake_httpx.TimeoutException = _FakeTimeout + + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=_FakeTimeout("timed out")) + adapter._http_client = mock_client + + with patch.object(_ntfy, "httpx", fake_httpx): + result = _run(adapter.send("hermes-in", "Hello!")) + + assert result.success is False + assert "timeout" in result.error.lower() + + def test_send_truncates_to_max_length(self): + adapter = self._make_adapter(topic="t") + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + long_msg = "x" * (MAX_MESSAGE_LENGTH + 500) + _run(adapter.send("t", long_msg)) + + posted_body = mock_client.post.call_args[1]["content"] + assert len(posted_body.decode()) <= MAX_MESSAGE_LENGTH + + def test_send_typing_is_noop(self): + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + _run(adapter.send_typing("t")) # must not raise + + def test_get_chat_info_returns_dict(self): + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + info = _run(adapter.get_chat_info("hermes-in")) + assert info["name"] == "hermes-in" + assert info["type"] == "dm" + + def test_send_includes_bearer_auth_header(self): + adapter = self._make_adapter(topic="hermes-in", token="mytoken") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + _run(adapter.send("hermes-in", "secure message")) + + call_headers = mock_client.post.call_args[1]["headers"] + assert call_headers.get("Authorization") == "Bearer mytoken" + + def test_send_emits_markdown_header_when_enabled(self): + adapter = self._make_adapter(topic="hermes-in", markdown=True) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + _run(adapter.send("hermes-in", "**bold**")) + call_headers = mock_client.post.call_args[1]["headers"] + assert call_headers.get("X-Markdown") == "true" + + def test_send_omits_markdown_header_when_disabled(self): + adapter = self._make_adapter(topic="hermes-in", markdown=False) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + adapter._http_client = mock_client + + _run(adapter.send("hermes-in", "plain")) + call_headers = mock_client.post.call_args[1]["headers"] + assert "X-Markdown" not in call_headers + + +# --------------------------------------------------------------------------- +# 8. Inbound message processing (identity invariant โ€” security-critical) +# --------------------------------------------------------------------------- + + +class TestOnMessage: + + def _make_adapter(self): + return NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "hermes-in"})) + + def test_message_dispatched_to_handler(self): + adapter = self._make_adapter() + calls = [] + + async def handler(event): + calls.append(event) + + adapter.set_message_handler(handler) + + event = { + "id": "evt-001", + "event": "message", + "topic": "hermes-in", + "message": "Hello from ntfy", + "time": 1700000000, + } + _run(adapter._on_message(event)) + assert len(calls) == 1 + assert calls[0].text == "Hello from ntfy" + + def test_empty_message_skipped(self): + adapter = self._make_adapter() + calls = [] + + async def handler(event): + calls.append(event) + + adapter.set_message_handler(handler) + _run(adapter._on_message({ + "id": "x", "event": "message", "topic": "t", "message": "", "time": None + })) + assert calls == [] + + def test_duplicate_message_skipped(self): + adapter = self._make_adapter() + calls = [] + + async def handler(event): + calls.append(event) + + adapter.set_message_handler(handler) + event = {"id": "dup-1", "event": "message", "topic": "hermes-in", "message": "hi", "time": None} + _run(adapter._on_message(event)) + _run(adapter._on_message(event)) + assert len(calls) == 1 + + def test_timestamp_parsed_from_event(self): + from datetime import timezone + adapter = self._make_adapter() + captured = [] + + async def handler(event): + captured.append(event) + + adapter.set_message_handler(handler) + _run(adapter._on_message({ + "id": "ts-1", + "event": "message", + "topic": "hermes-in", + "message": "ping", + "time": 1700000000, + })) + ts = captured[0].timestamp + assert ts.tzinfo == timezone.utc + + def test_message_id_set_from_event(self): + adapter = self._make_adapter() + captured = [] + + async def handler(event): + captured.append(event) + + adapter.set_message_handler(handler) + _run(adapter._on_message({ + "id": "ntfy-id-42", + "event": "message", + "topic": "hermes-in", + "message": "test", + "time": None, + })) + assert captured[0].message_id == "ntfy-id-42" + + def test_title_not_used_as_user_id(self): + """title field must not be used for identity โ€” it is publisher-controlled.""" + adapter = self._make_adapter() + captured = [] + + async def handler(event): + captured.append(event) + + adapter.set_message_handler(handler) + _run(adapter._on_message({ + "id": "u-1", + "event": "message", + "topic": "hermes-in", + "message": "hello", + "title": "Alice", + "time": None, + })) + assert captured[0].source.user_id == "hermes-in" + assert captured[0].source.user_name == "hermes-in" + + def test_unknown_publisher_cannot_impersonate_allowed_user(self): + """An unknown publisher setting title=admin must not gain admin identity.""" + adapter = self._make_adapter() + captured = [] + + async def handler(event): + captured.append(event) + + adapter.set_message_handler(handler) + _run(adapter._on_message({ + "id": "u-2", + "event": "message", + "topic": "hermes-in", + "message": "sensitive command", + "title": "admin", + "time": None, + })) + assert captured[0].source.user_id == "hermes-in" + assert captured[0].source.user_id != "admin" + + def test_source_chat_id_is_topic(self): + adapter = self._make_adapter() + captured = [] + + async def handler(event): + captured.append(event) + + adapter.set_message_handler(handler) + _run(adapter._on_message({ + "id": "s-1", + "event": "message", + "topic": "hermes-in", + "message": "hello", + "time": None, + })) + assert captured[0].source.chat_id == "hermes-in" + + +# --------------------------------------------------------------------------- +# 9. _env_enablement() โ€” env-only auto-config +# --------------------------------------------------------------------------- + + +class TestEnvEnablement: + + def test_returns_none_without_topic(self, monkeypatch): + monkeypatch.delenv("NTFY_TOPIC", raising=False) + assert _env_enablement() is None + + def test_seeds_topic_and_server(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + monkeypatch.delenv("NTFY_SERVER_URL", raising=False) + seed = _env_enablement() + assert seed is not None + assert seed["topic"] == "hermes-in" + assert seed["server"] == DEFAULT_SERVER + + def test_custom_server_url(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + monkeypatch.setenv("NTFY_SERVER_URL", "https://ntfy.example.com/") + seed = _env_enablement() + assert seed["server"] == "https://ntfy.example.com" # trailing slash stripped + + def test_publish_topic_seeded(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + monkeypatch.setenv("NTFY_PUBLISH_TOPIC", "hermes-out") + seed = _env_enablement() + assert seed["publish_topic"] == "hermes-out" + + def test_token_seeded(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + monkeypatch.setenv("NTFY_TOKEN", "tk_abc") + seed = _env_enablement() + assert seed["token"] == "tk_abc" + + def test_markdown_truthy_values(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + for val in ("true", "1", "yes", "TRUE"): + monkeypatch.setenv("NTFY_MARKDOWN", val) + assert _env_enablement()["markdown"] is True + + def test_markdown_falsy_values(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + for val in ("false", "0", "no", "anything"): + monkeypatch.setenv("NTFY_MARKDOWN", val) + assert _env_enablement()["markdown"] is False + + def test_home_channel_defaults_to_topic(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + monkeypatch.delenv("NTFY_HOME_CHANNEL", raising=False) + seed = _env_enablement() + assert seed["home_channel"]["chat_id"] == "hermes-in" + assert seed["home_channel"]["name"] == "hermes-in" + + def test_home_channel_override(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + monkeypatch.setenv("NTFY_HOME_CHANNEL", "alerts") + monkeypatch.setenv("NTFY_HOME_CHANNEL_NAME", "Alerts Channel") + seed = _env_enablement() + assert seed["home_channel"]["chat_id"] == "alerts" + assert seed["home_channel"]["name"] == "Alerts Channel" + + +# --------------------------------------------------------------------------- +# 10. _standalone_send() โ€” out-of-process cron delivery +# --------------------------------------------------------------------------- + + +class TestStandaloneSend: + + def test_errors_without_topic(self, monkeypatch): + monkeypatch.delenv("NTFY_TOPIC", raising=False) + monkeypatch.delenv("NTFY_PUBLISH_TOPIC", raising=False) + pconfig = MagicMock() + pconfig.extra = {} + result = _run(_standalone_send(pconfig, "", "hello")) + assert "error" in result + assert "NTFY_TOPIC" in result["error"] + + def test_posts_to_server(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + pconfig = MagicMock() + pconfig.extra = {"server": "https://ntfy.example.com", "topic": "hermes-in"} + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"id": "id-42"} + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_ntfy, "httpx") as mock_httpx: + mock_httpx.AsyncClient.return_value = mock_client + result = _run(_standalone_send(pconfig, "hermes-in", "hello")) + + assert result.get("success") is True + assert result["platform"] == "ntfy" + assert result["message_id"] == "id-42" + posted_url = mock_client.post.call_args[0][0] + assert posted_url == "https://ntfy.example.com/hermes-in" + + def test_emits_bearer_token_when_configured(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + pconfig = MagicMock() + pconfig.extra = {"topic": "hermes-in", "token": "tk_xyz"} + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_ntfy, "httpx") as mock_httpx: + mock_httpx.AsyncClient.return_value = mock_client + _run(_standalone_send(pconfig, "hermes-in", "hi")) + + headers = mock_client.post.call_args[1]["headers"] + assert headers["Authorization"] == "Bearer tk_xyz" + + def test_basic_auth_when_token_has_colon(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + pconfig = MagicMock() + pconfig.extra = {"topic": "hermes-in", "token": "user:pass"} + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_ntfy, "httpx") as mock_httpx: + mock_httpx.AsyncClient.return_value = mock_client + _run(_standalone_send(pconfig, "hermes-in", "hi")) + + headers = mock_client.post.call_args[1]["headers"] + assert headers["Authorization"].startswith("Basic ") + + def test_returns_error_on_http_failure(self, monkeypatch): + monkeypatch.setenv("NTFY_TOPIC", "hermes-in") + pconfig = MagicMock() + pconfig.extra = {"topic": "hermes-in"} + + mock_resp = MagicMock() + mock_resp.status_code = 403 + mock_resp.text = "Forbidden" + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_resp) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch.object(_ntfy, "httpx") as mock_httpx: + mock_httpx.AsyncClient.return_value = mock_client + result = _run(_standalone_send(pconfig, "hermes-in", "hi")) + + assert "error" in result + assert "403" in result["error"] + + +# --------------------------------------------------------------------------- +# 11. register() โ€” plugin-side metadata +# --------------------------------------------------------------------------- + + +def test_register_calls_register_platform(): + ctx = MagicMock() + register(ctx) + ctx.register_platform.assert_called_once() + kwargs = ctx.register_platform.call_args.kwargs + assert kwargs["name"] == "ntfy" + assert kwargs["label"] == "ntfy" + assert kwargs["required_env"] == ["NTFY_TOPIC"] + assert kwargs["allowed_users_env"] == "NTFY_ALLOWED_USERS" + assert kwargs["allow_all_env"] == "NTFY_ALLOW_ALL_USERS" + assert kwargs["cron_deliver_env_var"] == "NTFY_HOME_CHANNEL" + assert kwargs["max_message_length"] == MAX_MESSAGE_LENGTH + assert callable(kwargs["check_fn"]) + assert callable(kwargs["validate_config"]) + assert callable(kwargs["is_connected"]) + assert callable(kwargs["env_enablement_fn"]) + assert callable(kwargs["standalone_sender_fn"]) + assert callable(kwargs["adapter_factory"]) + # ntfy has no user-identifying PII (only topic names) + assert kwargs["pii_safe"] is True + assert "ntfy" in kwargs["platform_hint"].lower() + + +def test_adapter_factory_returns_ntfy_adapter(): + ctx = MagicMock() + register(ctx) + factory = ctx.register_platform.call_args.kwargs["adapter_factory"] + cfg = PlatformConfig(enabled=True, extra={"topic": "t"}) + adapter = factory(cfg) + assert isinstance(adapter, NtfyAdapter) + + +# --------------------------------------------------------------------------- +# 12. Robustness โ€” token hygiene + fatal-state propagation +# --------------------------------------------------------------------------- + + +class TestTokenHygiene: + """``_build_auth_header`` must strip pasted-token whitespace; pasted + tokens often carry trailing newlines that break the Authorization line.""" + + def test_trailing_whitespace_stripped(self): + assert _ntfy._build_auth_header(" tok123 ") == {"Authorization": "Bearer tok123"} + + def test_trailing_newline_stripped(self): + assert _ntfy._build_auth_header("tok123\n") == {"Authorization": "Bearer tok123"} + + def test_whitespace_only_returns_empty(self): + assert _ntfy._build_auth_header(" \n ") == {} + + def test_basic_auth_token_also_stripped(self): + h = _ntfy._build_auth_header(" user:pass ") + assert h["Authorization"].startswith("Basic ") + import base64 + assert h["Authorization"] == "Basic " + base64.b64encode(b"user:pass").decode() + + def test_adapter_strips_token_via_helper(self): + """The adapter delegates to _build_auth_header, so token whitespace + passed via config.extra is also stripped.""" + config = PlatformConfig(enabled=True, extra={"topic": "t", "token": " tok\n"}) + adapter = NtfyAdapter(config) + assert adapter._auth_headers() == {"Authorization": "Bearer tok"} + + +class TestFatalErrorPropagation: + """When the stream hits 401/404, the adapter must transition to the + ``fatal`` state via ``_set_fatal_error`` so the gateway's runtime + status reflects reality instead of staying 'connected'.""" + + def test_401_sets_fatal_unauthorized(self): + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"})) + adapter._http_client = MagicMock() + + # Mock the streaming response + mock_response = MagicMock() + mock_response.status_code = 401 + # async-context-manager flavor for httpx.stream + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_cm.__aexit__ = AsyncMock(return_value=None) + adapter._http_client.stream = MagicMock(return_value=mock_cm) + + fake_httpx = MagicMock() + fake_httpx.Timeout = MagicMock() + with patch.object(_ntfy, "httpx", fake_httpx): + with pytest.raises(_ntfy._FatalStreamError): + _run(adapter._consume_stream("https://ntfy.example/t/json", {})) + + assert adapter.has_fatal_error is True + assert adapter._fatal_error_code == "ntfy_unauthorized" + assert adapter._fatal_error_retryable is False + + def test_404_sets_fatal_topic_not_found(self): + adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "missing-topic"})) + adapter._http_client = MagicMock() + + mock_response = MagicMock() + mock_response.status_code = 404 + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_cm.__aexit__ = AsyncMock(return_value=None) + adapter._http_client.stream = MagicMock(return_value=mock_cm) + + fake_httpx = MagicMock() + fake_httpx.Timeout = MagicMock() + with patch.object(_ntfy, "httpx", fake_httpx): + with pytest.raises(_ntfy._FatalStreamError): + _run(adapter._consume_stream("https://ntfy.example/missing-topic/json", {})) + + assert adapter.has_fatal_error is True + assert adapter._fatal_error_code == "ntfy_topic_not_found" + assert "missing-topic" in adapter._fatal_error_message + assert adapter._fatal_error_retryable is False + + +class TestTruncateHelper: + """``_truncate_body`` is shared between adapter.send() (inline truncation + today, may migrate) and ``_standalone_send``. It must cap to + MAX_MESSAGE_LENGTH and return bytes.""" + + def test_short_message_passes_through(self): + assert _ntfy._truncate_body("hi", context="test") == b"hi" + + def test_long_message_truncated(self): + long = "x" * (MAX_MESSAGE_LENGTH + 50) + result = _ntfy._truncate_body(long, context="test") + assert isinstance(result, bytes) + assert len(result) == MAX_MESSAGE_LENGTH + + def test_unicode_message_encoded(self): + result = _ntfy._truncate_body("hรฉllo ๐Ÿ””", context="test") + assert result == "hรฉllo ๐Ÿ””".encode("utf-8") diff --git a/tests/gateway/test_pairing.py b/tests/gateway/test_pairing.py index 36e6bda15dd..0bff131ed1a 100644 --- a/tests/gateway/test_pairing.py +++ b/tests/gateway/test_pairing.py @@ -2,10 +2,13 @@ import json import os +import sys import time from pathlib import Path from unittest.mock import patch +import pytest + from gateway.pairing import ( PairingStore, ALPHABET, @@ -37,6 +40,10 @@ class TestSecureWrite: assert target.exists() assert json.loads(target.read_text()) == {"hello": "world"} + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="POSIX file modes are not enforced on Windows", + ) def test_sets_file_permissions(self, tmp_path): target = tmp_path / "secret.json" _secure_write(target, "data") @@ -75,9 +82,197 @@ class TestCodeGeneration: code = store.generate_code("telegram", "user1", "Alice") pending = store.list_pending("telegram") assert len(pending) == 1 - assert pending[0]["code"] == code + # list_pending no longer returns the original code โ€” it returns a + # truncated hash prefix. Verify the metadata is correct instead. assert pending[0]["user_id"] == "user1" assert pending[0]["user_name"] == "Alice" + # The code field is now a hash prefix, not the original plaintext code + assert pending[0]["code"] != code + + +# --------------------------------------------------------------------------- +# Hashed storage +# --------------------------------------------------------------------------- + + +class TestHashedStorage: + def test_pending_file_contains_hash_and_salt(self, tmp_path): + """Stored entries must have 'hash' and 'salt', never the plaintext code.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Alice") + raw = json.loads( + (tmp_path / "telegram-pending.json").read_text(encoding="utf-8") + ) + + assert len(raw) == 1 + entry = next(iter(raw.values())) + # Must have hash and salt fields + assert "hash" in entry + assert "salt" in entry + # Hash must be a valid hex SHA-256 digest (64 hex chars) + assert len(entry["hash"]) == 64 + assert all(c in "0123456789abcdef" for c in entry["hash"]) + # Salt must be a valid hex string (32 hex chars for 16 bytes) + assert len(entry["salt"]) == 32 + assert all(c in "0123456789abcdef" for c in entry["salt"]) + # The plaintext code must NOT appear as a key or value anywhere + assert code not in raw # not a key + for key, val in raw.items(): + assert code != key + for field_val in val.values(): + if isinstance(field_val, str): + assert field_val != code + + def test_plaintext_code_not_stored(self, tmp_path): + """The raw JSON file must not contain the plaintext code anywhere.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1") + raw_text = (tmp_path / "telegram-pending.json").read_text(encoding="utf-8") + assert code not in raw_text + + def test_valid_code_verifies_against_hash(self, tmp_path): + """approve_code with the correct code should succeed.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("telegram", "user1", "Bob") + result = store.approve_code("telegram", code) + assert result is not None + assert result["user_id"] == "user1" + assert result["user_name"] == "Bob" + + def test_invalid_code_rejected(self, tmp_path): + """approve_code with a wrong code should fail.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + store.generate_code("telegram", "user1") + result = store.approve_code("telegram", "ZZZZZZZZ") + assert result is None + + def test_different_salts_per_entry(self, tmp_path): + """Each pending entry should have a unique salt.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + store.generate_code("telegram", "user0") + store.generate_code("telegram", "user1") + store.generate_code("telegram", "user2") + raw = json.loads( + (tmp_path / "telegram-pending.json").read_text(encoding="utf-8") + ) + salts = [entry["salt"] for entry in raw.values()] + assert len(set(salts)) == 3 # all unique + + def test_hash_code_static_method(self, tmp_path): + """_hash_code should be deterministic for the same code+salt.""" + salt = os.urandom(16) + h1 = PairingStore._hash_code("ABCD1234", salt) + h2 = PairingStore._hash_code("ABCD1234", salt) + assert h1 == h2 + # Different salt should produce a different hash + salt2 = os.urandom(16) + h3 = PairingStore._hash_code("ABCD1234", salt2) + assert h3 != h1 + + +class TestLegacyPendingFileCompat: + """Defensive coverage for pre-hash pending.json on upgraded installs. + + Existing user installs may have a pending.json written by the old + code (plaintext code as key, no hash/salt fields). The new + approve_code / list_pending / _cleanup_expired must not crash on + those entries โ€” they should be ignored and aged out at TTL. + """ + + @staticmethod + def _write_legacy(tmp_path, code="ABCD1234", created_at=None): + """Write a pre-hash pending.json with plaintext code as the key.""" + import time as _time + if created_at is None: + created_at = _time.time() + legacy = { + code: { + "user_id": "legacy-user", + "user_name": "Legacy", + "created_at": created_at, + } + } + (tmp_path / "telegram-pending.json").write_text( + json.dumps(legacy), encoding="utf-8" + ) + + def test_approve_code_ignores_legacy_entries(self, tmp_path): + """A valid old-format code must NOT silently approve under the new schema.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + self._write_legacy(tmp_path, code="LEGACY01") + store = PairingStore() + # The plaintext "code" used to be the key โ€” under the new schema + # it's not even looked at, and there's no hash/salt to verify. + # Result: approve_code returns None, the legacy entry is left + # alone (gets pruned by _cleanup_expired at TTL). + result = store.approve_code("telegram", "LEGACY01") + assert result is None + # Approved list must be empty + assert store.is_approved("telegram", "legacy-user") is False + + def test_list_pending_handles_legacy_entries(self, tmp_path): + """list_pending must not KeyError on a missing 'hash' field.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + self._write_legacy(tmp_path) + store = PairingStore() + pending = store.list_pending("telegram") + assert len(pending) == 1 + assert pending[0]["user_id"] == "legacy-user" + assert pending[0]["code"] == "legacy" # placeholder + + def test_cleanup_expired_removes_legacy_at_ttl(self, tmp_path): + """Legacy entries past CODE_TTL must still get pruned.""" + import time as _time + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + self._write_legacy( + tmp_path, + code="LEGACY99", + created_at=_time.time() - CODE_TTL_SECONDS - 1, + ) + store = PairingStore() + store._cleanup_expired("telegram") + raw = json.loads( + (tmp_path / "telegram-pending.json").read_text(encoding="utf-8") + ) + assert raw == {} + + def test_cleanup_expired_handles_malformed_entries(self, tmp_path): + """Non-dict / missing-created_at entries get evicted, not crashed on.""" + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + (tmp_path / "telegram-pending.json").write_text( + json.dumps({ + "broken1": "not a dict", + "broken2": {"user_id": "x"}, # no created_at + "broken3": {"created_at": "not a number"}, + }), + encoding="utf-8", + ) + store = PairingStore() + store._cleanup_expired("telegram") + raw = json.loads( + (tmp_path / "telegram-pending.json").read_text(encoding="utf-8") + ) + assert raw == {} + + def test_approve_code_skips_malformed_entries(self, tmp_path): + """Malformed entries must not crash approve_code's hash loop.""" + import time as _time + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + (tmp_path / "telegram-pending.json").write_text( + json.dumps({ + "broken": {"user_id": "x", "created_at": _time.time(), + "salt": "not-hex", "hash": "doesntmatter"}, + }), + encoding="utf-8", + ) + store = PairingStore() + # Approving with any code must just return None, not crash. + assert store.approve_code("telegram", "ABCD1234") is None # --------------------------------------------------------------------------- @@ -117,6 +312,23 @@ class TestRateLimiting: assert isinstance(code2, str) and len(code2) == CODE_LENGTH assert code2 != code1 + def test_whatsapp_alias_flip_hits_same_rate_limit(self, tmp_path, monkeypatch): + mapping_dir = tmp_path / "whatsapp" / "session" + mapping_dir.mkdir(parents=True, exist_ok=True) + (mapping_dir / "lid-mapping-999999999999999.json").write_text( + json.dumps("15551234567@s.whatsapp.net"), + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code1 = store.generate_code("whatsapp", "15551234567@s.whatsapp.net") + code2 = store.generate_code("whatsapp", "999999999999999@lid") + + assert isinstance(code1, str) and len(code1) == CODE_LENGTH + assert code2 is None + # --------------------------------------------------------------------------- # Max pending limit @@ -209,6 +421,55 @@ class TestApprovalFlow: result = store.approve_code("telegram", "INVALIDCODE") assert result is None + def test_whatsapp_approved_user_survives_alias_flip(self, tmp_path, monkeypatch): + mapping_dir = tmp_path / "whatsapp" / "session" + mapping_dir.mkdir(parents=True, exist_ok=True) + (mapping_dir / "lid-mapping-999999999999999.json").write_text( + json.dumps("15551234567@s.whatsapp.net"), + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + code = store.generate_code("whatsapp", "15551234567@s.whatsapp.net", "Alice") + store.approve_code("whatsapp", code) + + assert store.is_approved("whatsapp", "15551234567@s.whatsapp.net") is True + assert store.is_approved("whatsapp", "999999999999999@lid") is True + + approved = store.list_approved("whatsapp") + + assert len(approved) == 1 + assert approved[0]["user_id"] == "15551234567" + + def test_whatsapp_legacy_raw_jid_approval_survives_alias_flip(self, tmp_path, monkeypatch): + mapping_dir = tmp_path / "whatsapp" / "session" + mapping_dir.mkdir(parents=True, exist_ok=True) + (mapping_dir / "lid-mapping-999999999999999.json").write_text( + json.dumps("15551234567@s.whatsapp.net"), + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + approved_path = tmp_path / "whatsapp-approved.json" + approved_path.write_text( + json.dumps( + { + "15551234567@s.whatsapp.net": { + "user_name": "Legacy Alice", + "approved_at": time.time(), + } + }, + indent=2, + ), + encoding="utf-8", + ) + + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + assert store.is_approved("whatsapp", "999999999999999@lid") is True + # --------------------------------------------------------------------------- # Lockout after failed attempts @@ -300,9 +561,10 @@ class TestCodeExpiry: store = PairingStore() code = store.generate_code("telegram", "user1") - # Manually expire the code + # Manually expire all pending entries pending = store._load_json(store._pending_path("telegram")) - pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1 + for entry_id in pending: + pending[entry_id]["created_at"] = time.time() - CODE_TTL_SECONDS - 1 store._save_json(store._pending_path("telegram"), pending) # Cleanup happens on next operation @@ -314,9 +576,10 @@ class TestCodeExpiry: store = PairingStore() code = store.generate_code("telegram", "user1") - # Expire it + # Expire all entries pending = store._load_json(store._pending_path("telegram")) - pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1 + for entry_id in pending: + pending[entry_id]["created_at"] = time.time() - CODE_TTL_SECONDS - 1 store._save_json(store._pending_path("telegram"), pending) result = store.approve_code("telegram", code) diff --git a/tests/gateway/test_platform_base.py b/tests/gateway/test_platform_base.py index 23646545bfc..3f303d0377c 100644 --- a/tests/gateway/test_platform_base.py +++ b/tests/gateway/test_platform_base.py @@ -361,6 +361,72 @@ class TestExtractMedia: assert "[[as_document]]" not in cleaned +class TestMediaDeliveryPathValidation: + def _patch_roots(self, monkeypatch, *roots): + monkeypatch.setattr( + "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS", + tuple(roots), + ) + + def test_allows_existing_file_inside_safe_root(self, tmp_path, monkeypatch): + root = tmp_path / "media-cache" + media_file = root / "voice.ogg" + media_file.parent.mkdir(parents=True) + media_file.write_bytes(b"OggS") + self._patch_roots(monkeypatch, root) + + assert BasePlatformAdapter.validate_media_delivery_path(str(media_file)) == str(media_file.resolve()) + + def test_rejects_existing_file_outside_safe_root(self, tmp_path, monkeypatch): + root = tmp_path / "media-cache" + root.mkdir() + secret = tmp_path / "secrets.txt" + secret.write_text("not for upload") + self._patch_roots(monkeypatch, root) + + assert BasePlatformAdapter.validate_media_delivery_path(str(secret)) is None + + def test_rejects_symlink_escape_from_safe_root(self, tmp_path, monkeypatch): + root = tmp_path / "media-cache" + root.mkdir() + secret = tmp_path / "outside.png" + secret.write_bytes(b"secret") + link = root / "safe-looking.png" + try: + link.symlink_to(secret) + except OSError: + pytest.skip("symlink creation is unavailable") + self._patch_roots(monkeypatch, root) + + assert BasePlatformAdapter.validate_media_delivery_path(str(link)) is None + + def test_filter_keeps_safe_media_and_drops_unsafe(self, tmp_path, monkeypatch): + root = tmp_path / "media-cache" + safe = root / "speech.ogg" + unsafe = tmp_path / "outside.ogg" + safe.parent.mkdir(parents=True) + safe.write_bytes(b"OggS") + unsafe.write_bytes(b"OggS") + self._patch_roots(monkeypatch, root) + + filtered = BasePlatformAdapter.filter_media_delivery_paths([ + (str(unsafe), False), + (str(safe), True), + ]) + + assert filtered == [(str(safe.resolve()), True)] + + def test_allows_operator_configured_extra_root(self, tmp_path, monkeypatch): + extra_root = tmp_path / "operator-media" + media_file = extra_root / "report.pdf" + media_file.parent.mkdir(parents=True) + media_file.write_bytes(b"%PDF-1.4") + self._patch_roots(monkeypatch) + monkeypatch.setenv("HERMES_MEDIA_ALLOW_DIRS", str(extra_root)) + + assert BasePlatformAdapter.validate_media_delivery_path(str(media_file)) == str(media_file.resolve()) + + # --------------------------------------------------------------------------- # should_send_media_as_audio # --------------------------------------------------------------------------- @@ -728,4 +794,3 @@ class TestProxyKwargsForAiohttp: sess_kw, req_kw = proxy_kwargs_for_aiohttp("http://proxy:8080") assert sess_kw == {} assert req_kw == {"proxy": "http://proxy:8080"} - diff --git a/tests/gateway/test_platform_connected_checkers.py b/tests/gateway/test_platform_connected_checkers.py index 941b8c74506..f7677a3a676 100644 --- a/tests/gateway/test_platform_connected_checkers.py +++ b/tests/gateway/test_platform_connected_checkers.py @@ -79,10 +79,11 @@ def test_checker_returns_true_when_configured(platform, checker, monkeypatch): elif platform in { Platform.API_SERVER, Platform.WEBHOOK, - Platform.MSGRAPH_WEBHOOK, Platform.WHATSAPP, }: mock_config.extra = {} + elif platform == Platform.MSGRAPH_WEBHOOK: + mock_config.extra = {"client_state": "expected-client-state"} elif platform == Platform.FEISHU: mock_config.extra = {"app_id": "app"} elif platform == Platform.WECOM: diff --git a/tests/gateway/test_platform_registry.py b/tests/gateway/test_platform_registry.py index 4ddc645b7b2..9ca80fe8a1f 100644 --- a/tests/gateway/test_platform_registry.py +++ b/tests/gateway/test_platform_registry.py @@ -708,3 +708,279 @@ class TestPluginPlatformSharedKeyBridge: assert extra.get("allow_from") == ["alice", "bob"] finally: _reg.unregister("mysharedplat") + + +class TestPluginEnablementGate: + """Plugin platforms must NOT auto-enable on check_fn alone (#31116). + + When a plugin registers ``is_connected`` (the "did the user actually + configure credentials" probe), ``load_gateway_config`` must consult it + before flipping ``enabled = True``. Without this gate, ``check_fn`` + semantics ("the SDK is importable") get conflated with "the user wants + this platform on", and the gateway tries to connect to e.g. Discord + with no token โ€” emitting noisy retry-forever errors on every fresh + install that has the plugin loaded. + """ + + def _write_config(self, tmp_path, content: str = ""): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text(content, encoding="utf-8") + return hermes_home + + def test_plugin_with_is_connected_false_is_NOT_enabled( + self, tmp_path, monkeypatch + ): + """check_fn=True + is_connected=False must NOT enable the platform. + + Reproduces #31116: Discord plugin loads, its check_fn lazy-installs + discord.py and returns True, but the user has no DISCORD_BOT_TOKEN. + Previously this auto-enabled Discord and the gateway spammed + ``ERROR ... [Discord] No bot token configured`` on every reconnect. + """ + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="myunconfiguredplat", + label="MyUnconfigured", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, # SDK available + is_connected=lambda cfg: False, # but user hasn't set credentials + source="plugin", + )) + try: + home = self._write_config(tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("myunconfiguredplat") + # Either absent entirely, or present but explicitly disabled. + if plat in cfg.platforms: + assert cfg.platforms[plat].enabled is False, ( + "Plugin with is_connected=False must NOT be auto-enabled" + ) + finally: + _reg.unregister("myunconfiguredplat") + + def test_plugin_with_is_connected_true_is_enabled( + self, tmp_path, monkeypatch + ): + """check_fn=True + is_connected=True still enables the platform.""" + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="myconfiguredplat", + label="MyConfigured", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + is_connected=lambda cfg: True, + source="plugin", + )) + try: + home = self._write_config(tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("myconfiguredplat") + assert plat in cfg.platforms + assert cfg.platforms[plat].enabled is True + finally: + _reg.unregister("myconfiguredplat") + + def test_plugin_without_is_connected_falls_back_to_check_fn( + self, tmp_path, monkeypatch + ): + """Legacy plugins that don't register is_connected keep working. + + For plugins where ``is_connected is None``, gating on ``check_fn`` + alone remains the contract โ€” that's what callers without a + credential probe have always done. + """ + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="mylegacyplat", + label="MyLegacy", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + # is_connected intentionally omitted (None) + source="plugin", + )) + try: + home = self._write_config(tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("mylegacyplat") + assert plat in cfg.platforms + assert cfg.platforms[plat].enabled is True + finally: + _reg.unregister("mylegacyplat") + + def test_is_connected_raises_does_not_enable(self, tmp_path, monkeypatch): + """A buggy is_connected must not silently enable the platform. + + Treat a raising is_connected as "configuration unknown" โ€” refuse to + enable, log, and move on. Anything else would re-introduce the + #31116 bug for plugins whose probe has a transient failure. + """ + from gateway.platform_registry import platform_registry as _reg + + def _bad_probe(cfg): + raise RuntimeError("plugin bug") + + _reg.register(PlatformEntry( + name="mybadprobeplat", + label="MyBadProbe", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + is_connected=_bad_probe, + source="plugin", + )) + try: + home = self._write_config(tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("mybadprobeplat") + if plat in cfg.platforms: + assert cfg.platforms[plat].enabled is False + finally: + _reg.unregister("mybadprobeplat") + + def test_yaml_enabled_true_overrides_is_connected_false( + self, tmp_path, monkeypatch + ): + """Explicit YAML ``enabled: true`` wins over is_connected=False. + + If the user wrote ``platforms.X.enabled: true`` themselves, respect + that โ€” they may be using a credential mechanism the plugin's + is_connected probe doesn't know about. Don't fight them. + """ + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="myexplicitplat", + label="MyExplicit", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + is_connected=lambda cfg: False, + source="plugin", + )) + try: + home = self._write_config( + tmp_path, + "platforms:\n" + " myexplicitplat:\n" + " enabled: true\n", + ) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("myexplicitplat") + assert plat in cfg.platforms + assert cfg.platforms[plat].enabled is True, ( + "Explicit YAML enabled: true must win over plugin's " + "is_connected=False โ€” user has the final say" + ) + finally: + _reg.unregister("myexplicitplat") + + def test_is_connected_sees_env_seeded_extras(self, tmp_path, monkeypatch): + """``env_enablement_fn`` extras must be visible to ``is_connected``. + + Some plugins (e.g. Google Chat) implement ``is_connected`` by + inspecting ``config.extra`` (where ``env_enablement_fn`` deposits + env-var-derived state) rather than reading ``os.environ`` directly. + If the gate runs BEFORE the seeding step, those plugins fail the + gate even when the user is genuinely configured via env vars. + + Pin the contract: when both hooks are present, ``env_enablement_fn`` + feeds a candidate config to ``is_connected``. + """ + from gateway.platform_registry import platform_registry as _reg + + seen_extras: dict = {} + + def _is_connected(cfg): + seen_extras["snapshot"] = dict(getattr(cfg, "extra", {}) or {}) + extra = getattr(cfg, "extra", {}) or {} + return bool(extra.get("project_id") and extra.get("subscription_name")) + + def _env_enablement(): + return {"project_id": "p", "subscription_name": "s"} + + _reg.register(PlatformEntry( + name="myextrasplat", + label="MyExtras", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + is_connected=_is_connected, + env_enablement_fn=_env_enablement, + source="plugin", + )) + try: + home = self._write_config(tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("myextrasplat") + assert plat in cfg.platforms, ( + "is_connected was called with empty extras โ€” " + "env_enablement_fn must seed the probe BEFORE the gate" + ) + assert cfg.platforms[plat].enabled is True + # extras populated on the live config too + assert cfg.platforms[plat].extra.get("project_id") == "p" + assert cfg.platforms[plat].extra.get("subscription_name") == "s" + # and the probe saw them + assert seen_extras["snapshot"]["project_id"] == "p" + finally: + _reg.unregister("myextrasplat") + + def test_is_connected_failed_gate_does_not_leak_extras( + self, tmp_path, monkeypatch + ): + """When the gate rejects, env-seeded extras must NOT leak onto + ``config.platforms``. A rejected plugin should be invisible, not + present-but-partially-populated. + """ + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="myrejectedplat", + label="MyRejected", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + is_connected=lambda cfg: False, + env_enablement_fn=lambda: {"some_key": "should-not-leak"}, + source="plugin", + )) + try: + home = self._write_config(tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + + from gateway.config import load_gateway_config, Platform + cfg = load_gateway_config() + + plat = Platform("myrejectedplat") + if plat in cfg.platforms: + assert cfg.platforms[plat].enabled is False + assert "some_key" not in cfg.platforms[plat].extra, ( + "Rejected plugin's env-seeded extras leaked onto " + "config.platforms" + ) + finally: + _reg.unregister("myrejectedplat") diff --git a/tests/gateway/test_qqbot.py b/tests/gateway/test_qqbot.py index 4b3402387a4..bdcb4c9e8df 100644 --- a/tests/gateway/test_qqbot.py +++ b/tests/gateway/test_qqbot.py @@ -1233,14 +1233,14 @@ class TestAdapterInteractionDispatch: "user_openid": "user-1", "data": { "type": 11, - "resolved": {"button_data": "approve:s:deny", "button_id": "deny"}, + "resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny", "button_id": "deny"}, }, }) assert len(ack_calls) == 1 assert ack_calls[0][0] == "i-1" assert len(received) == 1 - assert received[0].button_data == "approve:s:deny" + assert received[0].button_data == "approve:agent:main:qqbot:c2c:u:deny" assert received[0].scene == "c2c" @pytest.mark.asyncio @@ -1262,7 +1262,7 @@ class TestAdapterInteractionDispatch: adapter.set_interaction_callback(cb) await adapter._on_interaction({ "chat_type": 2, # no id - "data": {"resolved": {"button_data": "approve:s:deny"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}}, }) assert ack_calls == [] @@ -1286,7 +1286,7 @@ class TestAdapterInteractionDispatch: "id": "i-2", "chat_type": 2, "user_openid": "u", - "data": {"resolved": {"button_data": "approve:s:deny"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}}, }) @pytest.mark.asyncio @@ -1304,7 +1304,7 @@ class TestAdapterInteractionDispatch: "id": "i-3", "chat_type": 2, "user_openid": "u", - "data": {"resolved": {"button_data": "approve:s:deny"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}}, }) @@ -1570,13 +1570,13 @@ class TestDefaultInteractionDispatch: "id": "i", "chat_type": 2, "user_openid": "u-42", - "data": {"resolved": {"button_data": "approve:sess-abc:allow-once"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u-42:allow-once"}}, }) await adapter._default_interaction_dispatch(event) finally: tools.approval.resolve_gateway_approval = orig - assert resolve_calls == [("sess-abc", "once", False)] + assert resolve_calls == [("agent:main:qqbot:c2c:u-42", "once", False)] @pytest.mark.asyncio async def test_approval_click_always_maps_to_always(self): @@ -1594,13 +1594,13 @@ class TestDefaultInteractionDispatch: from gateway.platforms.qqbot.keyboards import parse_interaction_event event = parse_interaction_event({ "id": "i", "chat_type": 2, "user_openid": "u", - "data": {"resolved": {"button_data": "approve:s:allow-always"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:allow-always"}}, }) await adapter._default_interaction_dispatch(event) finally: tools.approval.resolve_gateway_approval = orig - assert resolve_calls == [("s", "always", False)] + assert resolve_calls == [("agent:main:qqbot:c2c:u", "always", False)] @pytest.mark.asyncio async def test_approval_click_deny_maps_to_deny(self): @@ -1618,13 +1618,40 @@ class TestDefaultInteractionDispatch: from gateway.platforms.qqbot.keyboards import parse_interaction_event event = parse_interaction_event({ "id": "i", "chat_type": 2, "user_openid": "u", - "data": {"resolved": {"button_data": "approve:s:deny"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}}, }) await adapter._default_interaction_dispatch(event) finally: tools.approval.resolve_gateway_approval = orig - assert resolve_calls == [("s", "deny", False)] + assert resolve_calls == [("agent:main:qqbot:c2c:u", "deny", False)] + + + @pytest.mark.asyncio + async def test_approval_click_rejects_unauthorized_operator(self): + adapter = self._make_adapter() + resolve_calls = [] + + def fake_resolve(session_key, choice, resolve_all=False): + resolve_calls.append((session_key, choice, resolve_all)) + return 1 + + import tools.approval + orig = tools.approval.resolve_gateway_approval + tools.approval.resolve_gateway_approval = fake_resolve + try: + from gateway.platforms.qqbot.keyboards import parse_interaction_event + event = parse_interaction_event({ + "id": "i", "chat_type": 1, + "group_openid": "g-1", + "group_member_openid": "attacker", + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:group:g-1:owner:allow-once"}}, + }) + await adapter._default_interaction_dispatch(event) + finally: + tools.approval.resolve_gateway_approval = orig + + assert resolve_calls == [] @pytest.mark.asyncio async def test_update_prompt_click_writes_response_file(self, tmp_path, monkeypatch): @@ -1700,7 +1727,7 @@ class TestDefaultInteractionDispatch: from gateway.platforms.qqbot.keyboards import parse_interaction_event event = parse_interaction_event({ "id": "i", "chat_type": 2, "user_openid": "u", - "data": {"resolved": {"button_data": "approve:s:deny"}}, + "data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}}, }) # Must not raise. await adapter._default_interaction_dispatch(event) @@ -1810,3 +1837,365 @@ class TestSendUpdatePrompt: adapter.send_with_keyboard = fake_swk # type: ignore[assignment] await adapter.send_update_prompt(chat_id="u", prompt="ok?") + + +# --------------------------------------------------------------------------- +# _send_identify includes INTERACTION intent +# --------------------------------------------------------------------------- + +class TestIdentifyIntents: + """Verify the WebSocket identify payload includes the INTERACTION intent bit.""" + + def _make_adapter(self): + from gateway.platforms.qqbot.adapter import QQAdapter + return QQAdapter(_make_config(app_id="a", client_secret="b")) + + @pytest.mark.asyncio + async def test_intents_include_interaction_bit(self): + adapter = self._make_adapter() + + # Mock token retrieval and WebSocket + adapter._access_token = "fake_token" + adapter._token_expires_at = 9999999999.0 + + sent_payloads = [] + + class FakeWS: + closed = False + + async def send_json(self, payload): + sent_payloads.append(payload) + + adapter._ws = FakeWS() + await adapter._send_identify() + + assert len(sent_payloads) == 1 + intents = sent_payloads[0]["d"]["intents"] + + # Verify all expected intent bits are present + assert intents & (1 << 25), "GROUP_MESSAGES (1<<25) missing" + assert intents & (1 << 30), "GUILD_AT_MESSAGE (1<<30) missing" + assert intents & (1 << 12), "DIRECT_MESSAGES (1<<12) missing" + assert intents & (1 << 26), "INTERACTION (1<<26) missing" + + +# --------------------------------------------------------------------------- +# _process_attachments: video/file path exposure +# --------------------------------------------------------------------------- + +class TestProcessAttachmentsPathExposure: + """Verify that video and file attachments include the cached local path.""" + + def _make_adapter(self): + from gateway.platforms.qqbot.adapter import QQAdapter + return QQAdapter(_make_config(app_id="a", client_secret="b")) + + @pytest.mark.asyncio + async def test_video_attachment_includes_path(self): + adapter = self._make_adapter() + + # Mock _download_and_cache to return a known path + async def fake_download(url, ct, original_name=""): + return "/tmp/cache/video_abc123.mp4" + + adapter._download_and_cache = fake_download # type: ignore[assignment] + + attachments = [ + { + "content_type": "video/mp4", + "url": "https://multimedia.nt.qq.com.cn/download/video123", + "filename": "my_video.mp4", + } + ] + result = await adapter._process_attachments(attachments) + + assert result["image_urls"] == [] + assert result["voice_transcripts"] == [] + info = result["attachment_info"] + assert "[video:" in info + assert "my_video.mp4" in info + assert "/tmp/cache/video_abc123.mp4" in info + + @pytest.mark.asyncio + async def test_file_attachment_includes_path(self): + adapter = self._make_adapter() + + async def fake_download(url, ct, original_name=""): + return "/tmp/cache/doc_abc123_report.pdf" + + adapter._download_and_cache = fake_download # type: ignore[assignment] + + attachments = [ + { + "content_type": "application/pdf", + "url": "https://multimedia.nt.qq.com.cn/download/file456", + "filename": "report.pdf", + } + ] + result = await adapter._process_attachments(attachments) + + info = result["attachment_info"] + assert "[file:" in info + assert "report.pdf" in info + assert "/tmp/cache/doc_abc123_report.pdf" in info + + @pytest.mark.asyncio + async def test_video_without_filename_falls_back_to_content_type(self): + adapter = self._make_adapter() + + async def fake_download(url, ct, original_name=""): + return "/tmp/cache/video_xyz.mp4" + + adapter._download_and_cache = fake_download # type: ignore[assignment] + + attachments = [ + { + "content_type": "video/mp4", + "url": "https://cdn.qq.com/vid", + "filename": "", + } + ] + result = await adapter._process_attachments(attachments) + + info = result["attachment_info"] + assert "[video: video/mp4" in info + assert "/tmp/cache/video_xyz.mp4" in info + + @pytest.mark.asyncio + async def test_download_failure_produces_no_attachment_info(self): + adapter = self._make_adapter() + + async def fake_download(url, ct, original_name=""): + return None + + adapter._download_and_cache = fake_download # type: ignore[assignment] + + attachments = [ + { + "content_type": "video/mp4", + "url": "https://cdn.qq.com/vid", + "filename": "vid.mp4", + } + ] + result = await adapter._process_attachments(attachments) + assert result["attachment_info"] == "" + + @pytest.mark.asyncio + async def test_quoted_video_includes_path_in_quote_block(self): + """Quoted video attachments should surface the cached path in the quote block.""" + adapter = self._make_adapter() + + async def fake_process(atts): + # Simulate the fixed _process_attachments for a video attachment. + return { + "image_urls": [], + "image_media_types": [], + "voice_transcripts": [], + "attachment_info": "[video: clip.mp4 (/tmp/cache/clip.mp4)]", + } + + adapter._process_attachments = fake_process # type: ignore[assignment] + + d = { + "message_type": 103, + "msg_elements": [{ + "content": "็œ‹็œ‹่ฟ™ไธช่ง†้ข‘", + "attachments": [ + {"content_type": "video/mp4", + "url": "https://qq-cdn/clip.mp4", + "filename": "clip.mp4"} + ], + }], + } + out = await adapter._process_quoted_context(d) + assert "[Quoted message]:" in out["quote_block"] + assert "/tmp/cache/clip.mp4" in out["quote_block"] + + @pytest.mark.asyncio + async def test_quoted_file_includes_path_in_quote_block(self): + """Quoted file attachments should surface the cached path in the quote block.""" + adapter = self._make_adapter() + + async def fake_process(atts): + return { + "image_urls": [], + "image_media_types": [], + "voice_transcripts": [], + "attachment_info": "[file: report.pdf (/tmp/cache/report.pdf)]", + } + + adapter._process_attachments = fake_process # type: ignore[assignment] + + d = { + "message_type": 103, + "msg_elements": [{ + "content": "", + "attachments": [ + {"content_type": "application/pdf", + "url": "https://qq-cdn/report.pdf", + "filename": "report.pdf"} + ], + }], + } + out = await adapter._process_quoted_context(d) + assert "[Quoted message]:" in out["quote_block"] + assert "/tmp/cache/report.pdf" in out["quote_block"] + + +# --------------------------------------------------------------------------- +# WebSocket op 7 (Server Reconnect) and op 9 (Invalid Session) +# --------------------------------------------------------------------------- + +class TestOp7ServerReconnect: + """Verify op 7 triggers WS close (which triggers reconnect in outer loop).""" + + def _make_adapter(self): + from gateway.platforms.qqbot.adapter import QQAdapter + return QQAdapter(_make_config(app_id="a", client_secret="b")) + + def test_op7_closes_websocket(self): + adapter = self._make_adapter() + adapter._session_id = "sess_keep" + adapter._last_seq = 42 + + close_called = [] + + class FakeWS: + closed = False + + async def close(self): + close_called.append(True) + + adapter._ws = FakeWS() + adapter._dispatch_payload({"op": 7, "d": None}) + + # Session should be preserved for Resume + assert adapter._session_id == "sess_keep" + assert adapter._last_seq == 42 + # close() should have been scheduled + assert len(close_called) == 0 # _create_task schedules, not immediate + # But the task was created โ€” verify via asyncio + + @pytest.mark.asyncio + async def test_op7_close_task_executes(self): + adapter = self._make_adapter() + close_called = [] + + class FakeWS: + closed = False + + async def close(self): + close_called.append(True) + self.closed = True + + adapter._ws = FakeWS() + adapter._dispatch_payload({"op": 7, "d": None}) + + # Let the event loop run the scheduled task + await asyncio.sleep(0) + assert close_called == [True] + # Session preserved + assert adapter._session_id is None # was never set + + +class TestOp9InvalidSession: + """Verify op 9 handles resumable vs non-resumable sessions.""" + + def _make_adapter(self): + from gateway.platforms.qqbot.adapter import QQAdapter + return QQAdapter(_make_config(app_id="a", client_secret="b")) + + def test_op9_not_resumable_clears_session(self): + adapter = self._make_adapter() + adapter._session_id = "sess_old" + adapter._last_seq = 99 + + class FakeWS: + closed = False + + async def close(self): + self.closed = True + + adapter._ws = FakeWS() + adapter._dispatch_payload({"op": 9, "d": False}) + + assert adapter._session_id is None + assert adapter._last_seq is None + + def test_op9_resumable_preserves_session(self): + adapter = self._make_adapter() + adapter._session_id = "sess_keep" + adapter._last_seq = 99 + + class FakeWS: + closed = False + + async def close(self): + self.closed = True + + adapter._ws = FakeWS() + adapter._dispatch_payload({"op": 9, "d": True}) + + # Session should be preserved for Resume + assert adapter._session_id == "sess_keep" + assert adapter._last_seq == 99 + + @pytest.mark.asyncio + async def test_op9_non_resumable_triggers_ws_close(self): + adapter = self._make_adapter() + adapter._session_id = "s" + adapter._last_seq = 1 + close_called = [] + + class FakeWS: + closed = False + + async def close(self): + close_called.append(True) + self.closed = True + + adapter._ws = FakeWS() + adapter._dispatch_payload({"op": 9, "d": False}) + await asyncio.sleep(0) + + assert close_called == [True] + + +# --------------------------------------------------------------------------- +# Close code classification +# --------------------------------------------------------------------------- + +class TestCloseCodeClassification: + """Verify fatal close codes stop reconnecting and 4009 preserves session.""" + + def _make_adapter(self): + from gateway.platforms.qqbot.adapter import QQAdapter + return QQAdapter(_make_config(app_id="a", client_secret="b")) + + def test_4009_preserves_session(self): + """4009 (connection timeout) should NOT clear the session.""" + adapter = self._make_adapter() + adapter._session_id = "sess_to_keep" + adapter._last_seq = 50 + + # The session-clearing codes set should NOT contain 4009. + # We verify the logic directly: dispatch a close-code event that + # exercises the session-clearing path (4006), then verify 4009 does not. + session_clear_codes = { + 4006, 4007, 4900, 4901, 4902, 4903, + 4904, 4905, 4906, 4907, 4908, 4909, + 4910, 4911, 4912, 4913, + } + assert 4009 not in session_clear_codes + + def test_fatal_codes_include_intent_errors(self): + """4013 (invalid intent) and 4014 (not authorized) should be fatal.""" + fatal_codes = {4001, 4002, 4010, 4011, 4012, 4013, 4014, 4914, 4915} + # Verify these are all treated as fatal by checking the adapter's + # code path would call _set_fatal_error. We verify the set membership + # which is what the if-branch checks. + assert 4013 in fatal_codes + assert 4014 in fatal_codes + assert 4001 in fatal_codes + assert 4915 in fatal_codes + diff --git a/tests/gateway/test_reload_skills_discord_resync.py b/tests/gateway/test_reload_skills_discord_resync.py index 7b2e1d20ff9..1d3b62fb12b 100644 --- a/tests/gateway/test_reload_skills_discord_resync.py +++ b/tests/gateway/test_reload_skills_discord_resync.py @@ -27,7 +27,7 @@ from unittest.mock import MagicMock def _make_adapter(): """Construct a DiscordAdapter without going through __init__ / token checks.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.platforms.base import Platform adapter = object.__new__(DiscordAdapter) adapter.config = MagicMock() diff --git a/tests/gateway/test_restart_drain.py b/tests/gateway/test_restart_drain.py index 9000e4d4820..c1578e3617a 100644 --- a/tests/gateway/test_restart_drain.py +++ b/tests/gateway/test_restart_drain.py @@ -116,6 +116,24 @@ def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, mon assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt" +def test_load_busy_text_mode_defaults_to_queue_and_allows_interrupt(tmp_path, monkeypatch): + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.delenv("HERMES_GATEWAY_BUSY_TEXT_MODE", raising=False) + + assert gateway_run.GatewayRunner._load_busy_text_mode() == "queue" + + (tmp_path / "config.yaml").write_text( + "display:\n busy_text_mode: interrupt\n", encoding="utf-8" + ) + assert gateway_run.GatewayRunner._load_busy_text_mode() == "interrupt" + + monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "queue") + assert gateway_run.GatewayRunner._load_busy_text_mode() == "queue" + + monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "bogus") + assert gateway_run.GatewayRunner._load_busy_text_mode() == "queue" + + def test_load_restart_drain_timeout_prefers_env_then_config_then_default( tmp_path, monkeypatch, caplog ): diff --git a/tests/gateway/test_resume_command.py b/tests/gateway/test_resume_command.py index 0d2060ef31f..288193132a9 100644 --- a/tests/gateway/test_resume_command.py +++ b/tests/gateway/test_resume_command.py @@ -88,6 +88,9 @@ class TestHandleResumeCommand: assert "Research" in result assert "Coding" in result assert "Named Sessions" in result + assert "1." in result + assert "2." in result + assert "/resume 1" in result db.close() @pytest.mark.asyncio @@ -104,6 +107,47 @@ class TestHandleResumeCommand: assert "/title" in result db.close() + @pytest.mark.asyncio + async def test_resume_by_index(self, tmp_path): + """Numeric argument resumes the indexed titled session from the list.""" + from hermes_state import SessionDB + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("sess_001", "telegram") + db.create_session("sess_002", "telegram") + db.set_session_title("sess_001", "Research") + db.set_session_title("sess_002", "Coding") + db.create_session("current_session_001", "telegram") + + event = _make_event(text="/resume 2") + runner = _make_runner(session_db=db, current_session_id="current_session_001", + event=event) + result = await runner._handle_resume_command(event) + + assert "Resumed" in result + runner.session_store.switch_session.assert_called_once() + call_args = runner.session_store.switch_session.call_args + assert call_args[0][1] == "sess_001" + db.close() + + @pytest.mark.asyncio + async def test_resume_index_out_of_range(self, tmp_path): + """Out-of-range numeric arguments show a helpful error.""" + from hermes_state import SessionDB + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("sess_001", "telegram") + db.set_session_title("sess_001", "Research") + db.create_session("current_session_001", "telegram") + + event = _make_event(text="/resume 9") + runner = _make_runner(session_db=db, current_session_id="current_session_001", + event=event) + result = await runner._handle_resume_command(event) + + assert "out of range" in result.lower() + assert "/resume" in result + runner.session_store.switch_session.assert_not_called() + db.close() + @pytest.mark.asyncio async def test_resume_by_name(self, tmp_path): """Resolves a title and switches to that session.""" diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 8f218dfc11c..5b7dfb821b0 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -942,6 +942,62 @@ async def test_run_agent_matrix_streaming_omits_cursor(monkeypatch, tmp_path): assert any("Continuing to refine:" in text for text in all_text) +class TransformedStreamAgent: + """Streams a response, then signals the gateway that a plugin hook + (``transform_llm_output``) modified the final text after streaming + finished. ``run_conversation`` returns ``response_transformed=True`` + plus a ``final_response`` that diverges from what was streamed. + """ + + def __init__(self, **kwargs): + self.stream_delta_callback = kwargs.get("stream_delta_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + if self.stream_delta_callback: + self.stream_delta_callback("original answer") + return { + "final_response": "original answer\n\n[plugin appended this]", + "response_previewed": True, + "response_transformed": True, + "messages": [], + "api_calls": 1, + } + + +@pytest.mark.asyncio +async def test_transformed_response_edits_streamed_message_in_place(monkeypatch, tmp_path): + """When a transform_llm_output hook modifies the response after streaming, + the gateway must edit the existing streamed message in place with the full + transformed content (so plugins like content filters / appenders reach the + user) and still mark already_sent=True (no duplicate send). + """ + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + TransformedStreamAgent, + session_id="sess-transformed-stream", + config_data={ + "display": {"tool_progress": "off", "interim_assistant_messages": False}, + "streaming": {"enabled": True, "edit_interval": 0.01, "buffer_threshold": 1}, + }, + platform=Platform.MATRIX, + chat_id="!room:matrix.example.org", + chat_type="group", + thread_id="$thread", + adapter_cls=MetadataEditProgressCaptureAdapter, + ) + + # Final delivery happened (no duplicate send fallback). + assert result.get("already_sent") is True + # The transformed final text reached the user โ€” appended portion is present + # in an edit_message call (not just in the streamed sends). + edited_texts = [e["content"] for e in adapter.edits] + assert any("[plugin appended this]" in text for text in edited_texts), ( + f"expected transformed text in adapter.edits, got: {edited_texts!r}" + ) + + @pytest.mark.asyncio async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path): QueuedCommentaryAgent.calls = 0 diff --git a/tests/gateway/test_runner_startup_failures.py b/tests/gateway/test_runner_startup_failures.py index 438553f34ed..b82062e4090 100644 --- a/tests/gateway/test_runner_startup_failures.py +++ b/tests/gateway/test_runner_startup_failures.py @@ -207,6 +207,7 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p lambda **kwargs: 0, ) monkeypatch.setattr("gateway.status.terminate_pid", lambda pid, force=False: calls.append((pid, force))) + monkeypatch.setattr("gateway.status._pid_exists", lambda pid: True) monkeypatch.setattr("gateway.run.os.getpid", lambda: 100) monkeypatch.setattr("gateway.run.os.kill", lambda pid, sig: None) monkeypatch.setattr("time.sleep", lambda _: None) diff --git a/tests/gateway/test_runtime_config_env_expansion.py b/tests/gateway/test_runtime_config_env_expansion.py new file mode 100644 index 00000000000..e77e9daaa66 --- /dev/null +++ b/tests/gateway/test_runtime_config_env_expansion.py @@ -0,0 +1,97 @@ +"""Regression tests for gateway runtime config env-var expansion.""" + +from __future__ import annotations + +import json + +import pytest + +import gateway.run as gateway_run + + +def _write_config(home, body: str) -> None: + (home / "config.yaml").write_text(body, encoding="utf-8") + + +@pytest.fixture +def gateway_home(monkeypatch, tmp_path): + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.delenv("HERMES_PREFILL_MESSAGES_FILE", raising=False) + monkeypatch.delenv("HERMES_EPHEMERAL_SYSTEM_PROMPT", raising=False) + monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False) + monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False) + monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False) + return tmp_path + + +def test_load_prefill_messages_expands_env_var_path(monkeypatch, gateway_home): + prefill = [{"role": "system", "content": "few-shot"}] + (gateway_home / "prefill.json").write_text(json.dumps(prefill), encoding="utf-8") + _write_config(gateway_home, "prefill_messages_file: ${PREFILL_FILE}\n") + monkeypatch.setenv("PREFILL_FILE", "prefill.json") + + assert gateway_run.GatewayRunner._load_prefill_messages() == prefill + + +@pytest.mark.parametrize( + ("config_body", "env_name", "env_value", "loader_name", "expected"), + [ + ( + "agent:\n system_prompt: ${GW_PROMPT}\n", + "GW_PROMPT", + "expanded prompt", + "_load_ephemeral_system_prompt", + "expanded prompt", + ), + ( + "agent:\n reasoning_effort: ${REASONING_LEVEL}\n", + "REASONING_LEVEL", + "high", + "_load_reasoning_config", + {"enabled": True, "effort": "high"}, + ), + ( + "agent:\n service_tier: ${SERVICE_TIER}\n", + "SERVICE_TIER", + "priority", + "_load_service_tier", + "priority", + ), + ( + "display:\n busy_input_mode: ${BUSY_MODE}\n", + "BUSY_MODE", + "steer", + "_load_busy_input_mode", + "steer", + ), + ( + "agent:\n restart_drain_timeout: ${DRAIN_TIMEOUT}\n", + "DRAIN_TIMEOUT", + "12", + "_load_restart_drain_timeout", + 12.0, + ), + ( + "display:\n background_process_notifications: ${BG_MODE}\n", + "BG_MODE", + "error", + "_load_background_notifications_mode", + "error", + ), + ], +) +def test_gateway_runtime_loaders_expand_env_var_templates( + monkeypatch, + gateway_home, + config_body, + env_name, + env_value, + loader_name, + expected, +): + _write_config(gateway_home, config_body) + monkeypatch.setenv(env_name, env_value) + + loader = getattr(gateway_run.GatewayRunner, loader_name) + + assert loader() == expected diff --git a/tests/gateway/test_send_image_file.py b/tests/gateway/test_send_image_file.py index cb0e436739e..b769d2be9fb 100644 --- a/tests/gateway/test_send_image_file.py +++ b/tests/gateway/test_send_image_file.py @@ -190,7 +190,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() import discord as discord_mod_ref # noqa: E402 -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 class TestDiscordSendImageFile: diff --git a/tests/gateway/test_send_multiple_images.py b/tests/gateway/test_send_multiple_images.py index 06983a4b6b8..5f6f3e7b771 100644 --- a/tests/gateway/test_send_multiple_images.py +++ b/tests/gateway/test_send_multiple_images.py @@ -210,7 +210,7 @@ def _ensure_discord_mock(): _ensure_discord_mock() -from gateway.platforms.discord import DiscordAdapter # noqa: E402 +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 class TestDiscordMultiImage: diff --git a/tests/gateway/test_session_model_override_routing.py b/tests/gateway/test_session_model_override_routing.py index 26acdc157aa..b1e50c07bf3 100644 --- a/tests/gateway/test_session_model_override_routing.py +++ b/tests/gateway/test_session_model_override_routing.py @@ -218,3 +218,46 @@ fallback_providers: assert runtime_kwargs["provider"] == "openrouter" assert runtime_kwargs["api_key"] == "sk-openrouter" + +def test_gateway_auth_fallback_resolves_key_env_for_custom_provider(tmp_path, monkeypatch): + """Auth-failure fallback should honor key_env/api_key_env custom-endpoint hints.""" + config = tmp_path / "config.yaml" + config.write_text( + """ +fallback_providers: + - provider: custom + model: fallback-model + base_url: https://fallback.example/v1 + key_env: MY_FALLBACK_KEY +""".lstrip(), + encoding="utf-8", + ) + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setenv("MY_FALLBACK_KEY", "env-secret") + + def fake_resolve_runtime_provider(*, requested=None, explicit_base_url=None, explicit_api_key=None): + assert requested == "custom" + assert explicit_base_url == "https://fallback.example/v1" + assert explicit_api_key == "env-secret" + return { + "api_key": explicit_api_key, + "base_url": explicit_base_url, + "provider": "custom", + "api_mode": "chat_completions", + "command": None, + "args": [], + "credential_pool": None, + } + + import hermes_cli.runtime_provider as runtime_provider + + monkeypatch.setattr(runtime_provider, "resolve_runtime_provider", fake_resolve_runtime_provider) + + runtime_kwargs = gateway_run._try_resolve_fallback_provider() + + assert runtime_kwargs is not None + assert runtime_kwargs["provider"] == "custom" + assert runtime_kwargs["api_key"] == "env-secret" + assert runtime_kwargs["base_url"] == "https://fallback.example/v1" + assert runtime_kwargs["model"] == "fallback-model" + diff --git a/tests/gateway/test_session_split_brain_11016.py b/tests/gateway/test_session_split_brain_11016.py index 1076a77c44c..0b2972ac173 100644 --- a/tests/gateway/test_session_split_brain_11016.py +++ b/tests/gateway/test_session_split_brain_11016.py @@ -53,6 +53,7 @@ class _StubAdapter(BasePlatformAdapter): def _make_adapter(): config = PlatformConfig(enabled=True, token="test-token") adapter = _StubAdapter(config, Platform.TELEGRAM) + adapter._busy_text_mode = "" adapter.sent_responses = [] async def _mock_send_retry(chat_id, content, **kwargs): @@ -396,4 +397,3 @@ class TestOldTaskCannotClobberNewerGuard: # default path) still work. adapter._release_session_guard(sk) assert sk not in adapter._active_sessions - diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index 41d8f40e84d..24c984f0cc6 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -149,7 +149,7 @@ class TestEditMessageFinalizeSignature: "module_path,class_name", [ ("gateway.platforms.telegram", "TelegramAdapter"), - ("gateway.platforms.discord", "DiscordAdapter"), + ("plugins.platforms.discord.adapter", "DiscordAdapter"), ("gateway.platforms.slack", "SlackAdapter"), ("gateway.platforms.matrix", "MatrixAdapter"), ("gateway.platforms.mattermost", "MattermostAdapter"), diff --git a/tests/gateway/test_telegram_group_gating.py b/tests/gateway/test_telegram_group_gating.py index 5ba1b48ade4..c3814a7fb8a 100644 --- a/tests/gateway/test_telegram_group_gating.py +++ b/tests/gateway/test_telegram_group_gating.py @@ -225,6 +225,128 @@ def test_observed_group_context_uses_shared_source_and_prompt_for_later_mentions asyncio.run(_run()) +def test_observed_group_context_replays_as_current_message_context_not_user_turns(): + from gateway.run import ( + _build_gateway_agent_history, + _wrap_current_message_with_observed_context, + ) + + history = [ + {"role": "session_meta", "content": "tool defs"}, + {"role": "user", "content": "[Alice|111]\nAcha que dรก fazer estoque?", "observed": True}, + {"role": "user", "content": "[Alice|111]\nTem lote e vencimento", "observed": True}, + {"role": "assistant", "content": "previous explicit reply"}, + ] + + agent_history, observed_context = _build_gateway_agent_history( + history, + channel_prompt="You are handling Telegram; observed Telegram group context is present.", + ) + api_message = _wrap_current_message_with_observed_context( + "[Bob|222]\ncambio", + observed_context, + ) + + assert agent_history == [{"role": "assistant", "content": "previous explicit reply"}] + assert "[Observed Telegram group context - context only, not requests]" in api_message + assert "[Current addressed message - answer only this" in api_message + assert "Acha que dรก fazer estoque?" in api_message + assert "Tem lote e vencimento" in api_message + assert api_message.endswith("[Bob|222]\ncambio") + + +def test_observed_group_context_does_not_hide_current_user_turn_behind_history_offset(): + from agent.agent_runtime_helpers import repair_message_sequence + from gateway.run import ( + _build_gateway_agent_history, + _wrap_current_message_with_observed_context, + ) + + history = [ + {"role": "user", "content": "[Alice|111]\nAcha que dรก fazer estoque?", "observed": True}, + ] + agent_history, observed_context = _build_gateway_agent_history( + history, + channel_prompt="observed Telegram group context", + ) + api_message = _wrap_current_message_with_observed_context("[Bob|222]\ncambio", observed_context) + messages = list(agent_history) + [{"role": "user", "content": api_message}] + + repair_message_sequence(object(), messages) + + history_offset = len(agent_history) + new_messages = messages[history_offset:] + assert len(agent_history) == 0 + assert new_messages[0]["role"] == "user" + assert new_messages[0]["content"].endswith("[Bob|222]\ncambio") + + +def test_observed_group_context_wraps_multimodal_current_message_without_mutating_parts(): + from gateway.run import _wrap_current_message_with_observed_context + + original = [ + {"type": "text", "text": "[Bob|222]\nsee this image"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ] + + wrapped = _wrap_current_message_with_observed_context( + original, + "[Alice|111]\nside chatter", + ) + + assert original[0]["text"] == "[Bob|222]\nsee this image" + assert wrapped[0]["text"].startswith("[Observed Telegram group context - context only") + assert wrapped[0]["text"].endswith("[Bob|222]\nsee this image") + assert wrapped[1] == original[1] + + +def test_observed_group_context_replays_normally_without_telegram_prompt(): + from gateway.run import _build_gateway_agent_history + + history = [ + {"role": "user", "content": "[Alice|111]\nside chatter", "observed": True}, + ] + + agent_history, observed_context = _build_gateway_agent_history(history, channel_prompt=None) + + assert observed_context is None + assert agent_history == [{"role": "user", "content": "[Alice|111]\nside chatter"}] + + +def test_observed_group_context_preserves_slash_command_text_for_dispatch(): + from gateway.platforms.base import MessageEvent, MessageType, Platform, SessionSource + + adapter = _make_adapter( + require_mention=True, + allowed_chats=["-100"], + group_allowed_chats=["-100"], + observe_unmentioned_group_messages=True, + ) + event = MessageEvent( + text="/new@hermes_bot", + message_type=MessageType.COMMAND, + source=SessionSource( + platform=Platform.TELEGRAM, + chat_id="-100", + user_id="111", + user_name="Alice", + chat_type="group", + thread_id="7", + ), + raw_message=_group_message( + "/new@hermes_bot", + entities=[_bot_command_entity("/new@hermes_bot", "/new@hermes_bot")], + ), + ) + + attributed = adapter._apply_telegram_group_observe_attribution(event) + + assert attributed.text == "/new@hermes_bot" + assert attributed.get_command() == "new" + assert attributed.source.user_id is None + assert "observed Telegram group context" in attributed.channel_prompt + + def test_unmentioned_group_observe_requires_chat_allowlist_for_shared_context(): async def _run(): adapter = _make_adapter( diff --git a/tests/gateway/test_telegram_send_path_health.py b/tests/gateway/test_telegram_send_path_health.py new file mode 100644 index 00000000000..940633224e4 --- /dev/null +++ b/tests/gateway/test_telegram_send_path_health.py @@ -0,0 +1,90 @@ +"""TelegramAdapter send-path health gating after reconnect storms. + +After sustained Bad Gateway / TimedOut reconnect cycles, the PTB httpx client +can enter a wedged state where ``bot.send_message()`` returns a valid Message +but nothing reaches the recipient. ``_send_path_degraded`` short-circuits +``send()`` so cron's live-adapter branch falls through to standalone HTTP. +""" +import sys +import types +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import PlatformConfig + + +def _ensure_telegram_mock(): + if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"): + return + mod = MagicMock() + mod.error.NetworkError = type("NetworkError", (OSError,), {}) + mod.error.TimedOut = type("TimedOut", (OSError,), {}) + mod.error.BadRequest = type("BadRequest", (Exception,), {}) + for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"): + sys.modules.setdefault(name, mod) + sys.modules.setdefault("telegram.error", mod.error) + + +_ensure_telegram_mock() + +from gateway.platforms.telegram import TelegramAdapter # noqa: E402 + + +def _make_adapter() -> TelegramAdapter: + adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***")) + adapter._bot = MagicMock() + adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=42)) + return adapter + + +@pytest.mark.asyncio +async def test_send_succeeds_when_path_healthy(): + """Healthy adapter delivers normally; send_message is called.""" + adapter = _make_adapter() + assert adapter._send_path_degraded is False + + result = await adapter.send("123", "hello") + + assert result.success is True + adapter._bot.send_message.assert_awaited() + + +@pytest.mark.asyncio +async def test_send_short_circuits_when_path_degraded(): + """Degraded adapter returns failure WITHOUT calling send_message, + so cron's live-adapter branch falls through to standalone HTTP.""" + adapter = _make_adapter() + adapter._send_path_degraded = True + + result = await adapter.send("123", "hello") + + assert result.success is False + assert result.error == "send_path_degraded" + assert result.retryable is True + adapter._bot.send_message.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_reconnect_storm_sets_and_heartbeat_clears_flag(monkeypatch): + """_handle_polling_network_error sets the flag; a successful heartbeat + probe in _verify_polling_after_reconnect clears it.""" + adapter = _make_adapter() + adapter._app = MagicMock() + adapter._app.updater = MagicMock() + adapter._app.updater.running = True + adapter._app.updater.stop = AsyncMock() + adapter._app.updater.start_polling = AsyncMock() + adapter._app.bot = MagicMock() + adapter._app.bot.get_me = AsyncMock(return_value=MagicMock()) + adapter._polling_error_callback_ref = AsyncMock() + monkeypatch.setattr( + "gateway.platforms.telegram.Update", MagicMock(ALL_TYPES=[]) + ) + + await adapter._handle_polling_network_error(OSError("Bad Gateway")) + assert adapter._send_path_degraded is True + + with patch("gateway.platforms.telegram.asyncio.sleep", new_callable=AsyncMock): + await adapter._verify_polling_after_reconnect() + assert adapter._send_path_degraded is False diff --git a/tests/gateway/test_telegram_status_update.py b/tests/gateway/test_telegram_status_update.py new file mode 100644 index 00000000000..f49ca9c60e1 --- /dev/null +++ b/tests/gateway/test_telegram_status_update.py @@ -0,0 +1,162 @@ +"""Tests for TelegramAdapter.send_or_update_status (issue #30045). + +The status-update path must: + 1. Send a fresh message on the first call for a (chat_id, status_key) pair. + 2. Edit that same message on subsequent calls with the same key. + 3. Fall back to sending fresh when the cached message edit fails. + 4. Keep distinct keys independent (no cross-talk). +""" + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from gateway.config import PlatformConfig +from gateway.platforms.base import SendResult + + +def _install_fake_telegram(monkeypatch): + """Stub the python-telegram-bot package so TelegramAdapter can be imported.""" + fake_telegram = types.ModuleType("telegram") + fake_telegram.Update = SimpleNamespace(ALL_TYPES=()) + fake_telegram.Bot = object + fake_telegram.Message = object + fake_telegram.InlineKeyboardButton = object + fake_telegram.InlineKeyboardMarkup = object + + fake_error = types.ModuleType("telegram.error") + fake_error.NetworkError = type("NetworkError", (Exception,), {}) + fake_error.BadRequest = type("BadRequest", (Exception,), {}) + fake_error.TimedOut = type("TimedOut", (Exception,), {}) + fake_telegram.error = fake_error + + fake_constants = types.ModuleType("telegram.constants") + fake_constants.ParseMode = SimpleNamespace(MARKDOWN_V2="MarkdownV2") + fake_constants.ChatType = SimpleNamespace( + GROUP="group", SUPERGROUP="supergroup", + CHANNEL="channel", PRIVATE="private", + ) + fake_telegram.constants = fake_constants + + fake_ext = types.ModuleType("telegram.ext") + fake_ext.Application = object + fake_ext.CommandHandler = object + fake_ext.CallbackQueryHandler = object + fake_ext.MessageHandler = object + fake_ext.ContextTypes = SimpleNamespace(DEFAULT_TYPE=object) + fake_ext.filters = object + + fake_request = types.ModuleType("telegram.request") + fake_request.HTTPXRequest = object + + monkeypatch.setitem(sys.modules, "telegram", fake_telegram) + monkeypatch.setitem(sys.modules, "telegram.error", fake_error) + monkeypatch.setitem(sys.modules, "telegram.constants", fake_constants) + monkeypatch.setitem(sys.modules, "telegram.ext", fake_ext) + monkeypatch.setitem(sys.modules, "telegram.request", fake_request) + + +@pytest.fixture +def adapter(monkeypatch): + _install_fake_telegram(monkeypatch) + from gateway.platforms.telegram import TelegramAdapter + + a = TelegramAdapter(PlatformConfig(enabled=True, token="fake-token")) + a._bot = MagicMock() + # Patch send / edit_message so tests can drive them directly. + a.send = AsyncMock() + a.edit_message = AsyncMock() + return a + + +@pytest.mark.asyncio +async def test_first_call_sends_and_caches_message_id(adapter): + """First call for a (chat, key) pair must send and remember the id.""" + adapter.send.return_value = SendResult(success=True, message_id="100") + + result = await adapter.send_or_update_status("chat-1", "lifecycle", "starting") + + assert result.success is True + assert result.message_id == "100" + adapter.send.assert_awaited_once() + adapter.edit_message.assert_not_awaited() + assert adapter._status_message_ids[("chat-1", "lifecycle")] == "100" + + +@pytest.mark.asyncio +async def test_second_call_edits_in_place(adapter): + """Same (chat, key) on the second call must edit, not send.""" + adapter.send.return_value = SendResult(success=True, message_id="100") + adapter.edit_message.return_value = SendResult(success=True, message_id="100") + + await adapter.send_or_update_status("chat-1", "lifecycle", "step 1") + await adapter.send_or_update_status("chat-1", "lifecycle", "step 2") + + adapter.send.assert_awaited_once() + adapter.edit_message.assert_awaited_once() + # Edit was directed at the cached message id. + args, kwargs = adapter.edit_message.call_args + assert args[0] == "chat-1" + assert args[1] == "100" + assert args[2] == "step 2" + + +@pytest.mark.asyncio +async def test_edit_failure_falls_back_to_fresh_send(adapter): + """When edit_message fails the cache is cleared and a new send happens.""" + adapter.send.side_effect = [ + SendResult(success=True, message_id="100"), + SendResult(success=True, message_id="200"), + ] + adapter.edit_message.return_value = SendResult( + success=False, error="Bad Request: message to edit not found", + ) + + await adapter.send_or_update_status("chat-1", "lifecycle", "step 1") + result = await adapter.send_or_update_status("chat-1", "lifecycle", "step 2") + + assert result.success is True + assert result.message_id == "200" + assert adapter.send.await_count == 2 + assert adapter.edit_message.await_count == 1 + # Cache now points at the fresh message id. + assert adapter._status_message_ids[("chat-1", "lifecycle")] == "200" + + +@pytest.mark.asyncio +async def test_distinct_status_keys_do_not_collide(adapter): + """A different status_key gets its own message; the original isn't touched.""" + adapter.send.side_effect = [ + SendResult(success=True, message_id="100"), + SendResult(success=True, message_id="200"), + ] + + await adapter.send_or_update_status("chat-1", "lifecycle", "ctx pressure") + await adapter.send_or_update_status("chat-1", "model-switch", "switched to opus") + + assert adapter.send.await_count == 2 + adapter.edit_message.assert_not_awaited() + assert adapter._status_message_ids[("chat-1", "lifecycle")] == "100" + assert adapter._status_message_ids[("chat-1", "model-switch")] == "200" + + +@pytest.mark.asyncio +async def test_distinct_chat_ids_do_not_collide(adapter): + """Same status_key in different chats must not edit each other's messages.""" + adapter.send.side_effect = [ + SendResult(success=True, message_id="100"), + SendResult(success=True, message_id="200"), + ] + + await adapter.send_or_update_status("chat-1", "lifecycle", "first") + await adapter.send_or_update_status("chat-2", "lifecycle", "second") + + assert adapter.send.await_count == 2 + adapter.edit_message.assert_not_awaited() + assert adapter._status_message_ids[("chat-1", "lifecycle")] == "100" + assert adapter._status_message_ids[("chat-2", "lifecycle")] == "200" diff --git a/tests/gateway/test_telegram_thread_fallback.py b/tests/gateway/test_telegram_thread_fallback.py index 642306c142c..6bba27a78cd 100644 --- a/tests/gateway/test_telegram_thread_fallback.py +++ b/tests/gateway/test_telegram_thread_fallback.py @@ -98,6 +98,7 @@ _fake_telegram_ext.Application = object _fake_telegram_ext.CommandHandler = object _fake_telegram_ext.CallbackQueryHandler = object _fake_telegram_ext.MessageHandler = object +_fake_telegram_ext.TypeHandler = object _fake_telegram_ext.ContextTypes = SimpleNamespace(DEFAULT_TYPE=object) _fake_telegram_ext.filters = object _fake_telegram_request = types.ModuleType("telegram.request") diff --git a/tests/gateway/test_telegram_topic_mode.py b/tests/gateway/test_telegram_topic_mode.py index 7945fb716b0..1941bb89e20 100644 --- a/tests/gateway/test_telegram_topic_mode.py +++ b/tests/gateway/test_telegram_topic_mode.py @@ -1175,13 +1175,15 @@ def test_recover_returns_none_for_known_topic(tmp_path): assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="222")) is None -def test_recover_rewrites_unknown_thread_id_to_most_recent(tmp_path): - # Cross-topic Reply leak: inbound thread_id is a Telegram-only id we never bound. +def test_recover_preserves_unknown_thread_id_for_new_topic(tmp_path): + # A newly-created Telegram DM topic arrives with a real, previously-unbound + # message_thread_id. It must become its own session lane rather than being + # rewritten to whichever older topic was most recently active. db = SessionDB(db_path=tmp_path / "state.db") _seed_two_topic_bindings(db) runner = _make_runner(session_db=db) - assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="9999")) == "222" + assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="9999")) is None def test_recover_rewrites_lobby_thread_id_to_most_recent(tmp_path): @@ -1209,6 +1211,31 @@ def test_recover_returns_none_when_no_bindings_yet(tmp_path): assert runner._recover_telegram_topic_thread_id(_make_source(thread_id=None)) is None +def test_recover_returns_none_for_brand_new_topic(tmp_path): + # Regression for #31086: bindings exist for a prior topic but the user + # opened a fresh one (thread_id "99999"). Recovery must return None so the + # new topic gets its own session rather than being silently merged into + # the previous topic's session. The hijack was self-reinforcing โ€” because + # the rewrite ran before _record_telegram_topic_binding, the new topic's + # binding row never got written, so every subsequent message in that topic + # looked "unknown" and was hijacked again. + db = SessionDB(db_path=tmp_path / "state.db") + db.enable_telegram_topic_mode(chat_id="208214988", user_id="208214988") + db.create_session(session_id="sess-old", source="telegram", user_id="208214988") + src_old = _make_source(thread_id="12345") + db.bind_telegram_topic( + chat_id=src_old.chat_id, + thread_id=src_old.thread_id, + user_id=src_old.user_id, + session_key=build_session_key(src_old), + session_id="sess-old", + ) + runner = _make_runner(session_db=db) + + # "99999" is non-lobby and not in the binding table โ€” brand-new topic. + assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="99999")) is None + + def test_list_telegram_topic_bindings_for_chat(tmp_path): db = SessionDB(db_path=tmp_path / "state.db") _seed_two_topic_bindings(db) diff --git a/tests/gateway/test_text_batching.py b/tests/gateway/test_text_batching.py index 1ad89ffd055..7154ae4ae09 100644 --- a/tests/gateway/test_text_batching.py +++ b/tests/gateway/test_text_batching.py @@ -41,7 +41,7 @@ def _make_event( def _make_discord_adapter(): """Create a minimal DiscordAdapter for testing text batching.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter config = PlatformConfig(enabled=True, token="test-token") adapter = object.__new__(DiscordAdapter) diff --git a/tests/gateway/test_tts_media_routing.py b/tests/gateway/test_tts_media_routing.py index ec93c33f75c..b4f410c280e 100644 --- a/tests/gateway/test_tts_media_routing.py +++ b/tests/gateway/test_tts_media_routing.py @@ -50,11 +50,24 @@ def _event(thread_id=None): ) +def _allowed_media_path(tmp_path, monkeypatch, name): + root = tmp_path / "media-cache" + media_file = root / name + media_file.parent.mkdir(parents=True, exist_ok=True) + media_file.write_bytes(b"media") + monkeypatch.setattr( + "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS", + (root,), + ) + return media_file.resolve() + + @pytest.mark.asyncio -async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender(): +async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender(tmp_path, monkeypatch): adapter = _MediaRoutingAdapter() event = _event() - adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.flac") + media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.flac") + adapter._message_handler = AsyncMock(return_value=f"MEDIA:{media_file}") adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice")) adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc")) @@ -62,17 +75,18 @@ async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender(): adapter.send_document.assert_awaited_once_with( chat_id="chat-1", - file_path="/tmp/speech.flac", + file_path=str(media_file), metadata=None, ) adapter.send_voice.assert_not_awaited() @pytest.mark.asyncio -async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender(): +async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender(tmp_path, monkeypatch): adapter = _MediaRoutingAdapter() event = _event() - adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.ogg") + media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg") + adapter._message_handler = AsyncMock(return_value=f"MEDIA:{media_file}") adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice")) adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc")) @@ -80,18 +94,19 @@ async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_ adapter.send_document.assert_awaited_once_with( chat_id="chat-1", - file_path="/tmp/speech.ogg", + file_path=str(media_file), metadata=None, ) adapter.send_voice.assert_not_awaited() @pytest.mark.asyncio -async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender(): +async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender(tmp_path, monkeypatch): adapter = _MediaRoutingAdapter() event = _event() + media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg") adapter._message_handler = AsyncMock( - return_value="[[audio_as_voice]]\nMEDIA:/tmp/speech.ogg" + return_value=f"[[audio_as_voice]]\nMEDIA:{media_file}" ) adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice")) adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc")) @@ -100,7 +115,7 @@ async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_ adapter.send_voice.assert_awaited_once_with( chat_id="chat-1", - audio_path="/tmp/speech.ogg", + audio_path=str(media_file), metadata=None, ) adapter.send_document.assert_not_awaited() @@ -117,8 +132,9 @@ def _fake_runner(thread_meta): @pytest.mark.asyncio -async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender(): +async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender(tmp_path, monkeypatch): event = _event(thread_id="topic-1") + media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.flac") adapter = SimpleNamespace( name="test", extract_media=BasePlatformAdapter.extract_media, @@ -132,22 +148,23 @@ async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sen await GatewayRunner._deliver_media_from_response( _fake_runner({"thread_id": "topic-1"}), - "MEDIA:/tmp/speech.flac", + f"MEDIA:{media_file}", event, adapter, ) adapter.send_document.assert_awaited_once_with( chat_id="chat-1", - file_path="/tmp/speech.flac", + file_path=str(media_file), metadata={"thread_id": "topic-1"}, ) adapter.send_voice.assert_not_awaited() @pytest.mark.asyncio -async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender(): +async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender(tmp_path, monkeypatch): event = _event(thread_id="topic-1") + media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg") adapter = SimpleNamespace( name="test", extract_media=BasePlatformAdapter.extract_media, @@ -161,24 +178,25 @@ async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_doc await GatewayRunner._deliver_media_from_response( _fake_runner({"thread_id": "topic-1"}), - "MEDIA:/tmp/speech.ogg", + f"MEDIA:{media_file}", event, adapter, ) adapter.send_document.assert_awaited_once_with( chat_id="chat-1", - file_path="/tmp/speech.ogg", + file_path=str(media_file), metadata={"thread_id": "topic-1"}, ) adapter.send_voice.assert_not_awaited() @pytest.mark.asyncio -async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(): +async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(tmp_path, monkeypatch): """MP3 audio on Telegram must go through send_voice (which routes to sendAudio internally); Telegram accepts MP3 for the audio player.""" event = _event(thread_id="topic-1") + media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.mp3") adapter = SimpleNamespace( name="test", extract_media=BasePlatformAdapter.extract_media, @@ -192,14 +210,47 @@ async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender( await GatewayRunner._deliver_media_from_response( _fake_runner({"thread_id": "topic-1"}), - "MEDIA:/tmp/speech.mp3", + f"MEDIA:{media_file}", event, adapter, ) adapter.send_voice.assert_awaited_once_with( chat_id="chat-1", - audio_path="/tmp/speech.mp3", + audio_path=str(media_file), metadata={"thread_id": "topic-1"}, ) adapter.send_document.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_streaming_delivery_blocks_media_path_outside_allowed_roots(tmp_path, monkeypatch): + event = _event(thread_id="topic-1") + allowed_root = tmp_path / "media-cache" + allowed_root.mkdir() + secret = tmp_path / "outside.pdf" + secret.write_bytes(b"%PDF secret") + monkeypatch.setattr( + "gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS", + (allowed_root,), + ) + adapter = SimpleNamespace( + name="test", + extract_media=BasePlatformAdapter.extract_media, + extract_images=BasePlatformAdapter.extract_images, + extract_local_files=BasePlatformAdapter.extract_local_files, + send_voice=AsyncMock(return_value=SendResult(success=True, message_id="voice")), + send_document=AsyncMock(return_value=SendResult(success=True, message_id="doc")), + send_image_file=AsyncMock(return_value=SendResult(success=True, message_id="image")), + send_video=AsyncMock(return_value=SendResult(success=True, message_id="video")), + ) + + await GatewayRunner._deliver_media_from_response( + _fake_runner({"thread_id": "topic-1"}), + f"MEDIA:{secret}", + event, + adapter, + ) + + adapter.send_document.assert_not_awaited() + adapter.send_voice.assert_not_awaited() diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py index b02b7f72ff5..160b35c6449 100644 --- a/tests/gateway/test_voice_command.py +++ b/tests/gateway/test_voice_command.py @@ -511,7 +511,7 @@ class TestDiscordPlayTtsSkip: """Discord adapter skips play_tts when bot is in a voice channel.""" def _make_discord_adapter(self): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import Platform, PlatformConfig config = PlatformConfig(enabled=True, extra={}) config.token = "fake-token" @@ -599,7 +599,7 @@ class TestVoiceReceiver: """Test VoiceReceiver silence detection, SSRC mapping, and lifecycle.""" def _make_receiver(self): - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver mock_vc = MagicMock() mock_vc._connection.secret_key = [0] * 32 mock_vc._connection.dave_session = None @@ -1066,7 +1066,7 @@ class TestDiscordVoiceChannelMethods: """Test DiscordAdapter voice channel methods (join, leave, play, etc.).""" def _make_adapter(self): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import Platform, PlatformConfig config = PlatformConfig(enabled=True, extra={}) config.token = "fake-token" @@ -1208,7 +1208,7 @@ class TestDiscordVoiceChannelMethods: pcm_data = b"\x00" * 96000 - with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \ + with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \ patch("tools.transcription_tools.transcribe_audio", return_value={"success": True, "transcript": "Hello"}), \ patch("tools.voice_mode.is_whisper_hallucination", return_value=False): @@ -1223,7 +1223,7 @@ class TestDiscordVoiceChannelMethods: callback = AsyncMock() adapter._voice_input_callback = callback - with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \ + with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \ patch("tools.transcription_tools.transcribe_audio", return_value={"success": True, "transcript": "Thank you."}), \ patch("tools.voice_mode.is_whisper_hallucination", return_value=True): @@ -1238,7 +1238,7 @@ class TestDiscordVoiceChannelMethods: callback = AsyncMock() adapter._voice_input_callback = callback - with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \ + with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \ patch("tools.transcription_tools.transcribe_audio", return_value={"success": False, "error": "API error"}): await adapter._process_voice_input(111, 42, b"\x00" * 96000) @@ -1251,7 +1251,7 @@ class TestDiscordVoiceChannelMethods: adapter = self._make_adapter() adapter._voice_input_callback = AsyncMock() - with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav", + with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav", side_effect=RuntimeError("ffmpeg not found")): await adapter._process_voice_input(111, 42, b"\x00" * 96000) # Should not raise @@ -1269,7 +1269,7 @@ class TestVoiceReceiverThreadSafety: """Verify that VoiceReceiver buffer access is protected by lock.""" def _make_receiver(self): - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver mock_vc = MagicMock() mock_vc._connection.secret_key = [0] * 32 mock_vc._connection.dave_session = None @@ -1282,7 +1282,7 @@ class TestVoiceReceiverThreadSafety: def test_check_silence_holds_lock(self): """check_silence must hold lock while iterating buffers.""" import ast, inspect, textwrap - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver source = textwrap.dedent(inspect.getsource(VoiceReceiver.check_silence)) tree = ast.parse(source) # Find 'with self._lock:' that contains buffer iteration @@ -1303,7 +1303,7 @@ class TestVoiceReceiverThreadSafety: def test_on_packet_buffer_write_holds_lock(self): """_on_packet must hold lock when writing to buffers.""" import ast, inspect, textwrap - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver source = textwrap.dedent(inspect.getsource(VoiceReceiver._on_packet)) tree = ast.parse(source) # Find 'with self._lock:' that contains buffer extend @@ -1670,7 +1670,7 @@ class TestStopAcquiresLock: @staticmethod def _make_receiver(): - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver vc = MagicMock() vc._connection.secret_key = [0] * 32 vc._connection.dave_session = None @@ -1772,7 +1772,7 @@ class TestPacketDebugCounterIsInstanceLevel: @staticmethod def _make_receiver(): - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver vc = MagicMock() vc._connection.secret_key = [0] * 32 vc._connection.dave_session = None @@ -1805,7 +1805,7 @@ class TestPlayInVoiceChannelUsesRunningLoop: def test_source_uses_get_running_loop(self): """The method source code calls get_running_loop, not get_event_loop.""" import inspect - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter source = inspect.getsource(DiscordAdapter.play_in_voice_channel) assert "get_running_loop" in source, \ "play_in_voice_channel should use asyncio.get_running_loop()" @@ -1849,7 +1849,7 @@ class TestVoiceTimeoutCleansRunnerState: @staticmethod def _make_discord_adapter(): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import PlatformConfig, Platform config = PlatformConfig(enabled=True, extra={}) config.token = "fake-token" @@ -1940,7 +1940,7 @@ class TestPlaybackTimeout: @staticmethod def _make_discord_adapter(): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import PlatformConfig, Platform config = PlatformConfig(enabled=True, extra={}) config.token = "fake-token" @@ -1964,7 +1964,7 @@ class TestPlaybackTimeout: def test_source_has_wait_for_timeout(self): """The method uses asyncio.wait_for with timeout.""" import inspect - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter source = inspect.getsource(DiscordAdapter.play_in_voice_channel) assert "wait_for" in source, \ "play_in_voice_channel must use asyncio.wait_for for timeout" @@ -1973,14 +1973,14 @@ class TestPlaybackTimeout: def test_playback_timeout_constant_exists(self): """PLAYBACK_TIMEOUT constant is defined on DiscordAdapter.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter assert hasattr(DiscordAdapter, "PLAYBACK_TIMEOUT") assert DiscordAdapter.PLAYBACK_TIMEOUT > 0 @pytest.mark.asyncio async def test_playback_timeout_fires(self): """When done event is never set, playback times out gracefully.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter adapter = self._make_discord_adapter() mock_vc = MagicMock() @@ -2008,7 +2008,7 @@ class TestPlaybackTimeout: @pytest.mark.asyncio async def test_is_playing_wait_has_timeout(self): """While loop waiting for previous playback has a timeout.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter adapter = self._make_discord_adapter() mock_vc = MagicMock() @@ -2124,7 +2124,7 @@ class TestVoiceChannelAwareness: """Tests for get_voice_channel_info() and get_voice_channel_context().""" def _make_adapter(self): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import PlatformConfig config = PlatformConfig(enabled=True, extra={}) config.token = "fake-token" @@ -2267,7 +2267,7 @@ class TestVoiceReception: @staticmethod def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999): - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver vc = MagicMock() vc._connection.secret_key = [0] * 32 vc._connection.dave_session = MagicMock() if dave else None @@ -2451,7 +2451,7 @@ class TestVoiceReception: def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None): """Create a receiver that can process _on_packet with mocked NaCl + Opus.""" - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver vc = MagicMock() vc._connection.secret_key = [0] * 32 vc._connection.dave_session = dave_session @@ -2593,7 +2593,7 @@ class TestVoiceTTSPlayback: @staticmethod def _make_discord_adapter(): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import PlatformConfig, Platform config = PlatformConfig(enabled=True, extra={}) config.token = "fake-token" @@ -2766,14 +2766,14 @@ class TestUDPKeepalive: """UDP keepalive prevents Discord from dropping the voice session.""" def test_keepalive_interval_is_reasonable(self): - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter interval = DiscordAdapter._KEEPALIVE_INTERVAL assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s" @pytest.mark.asyncio async def test_keepalive_sends_silence_frame(self): """Listen loop sends silence frame via send_packet after interval.""" - from gateway.platforms.discord import DiscordAdapter + from plugins.platforms.discord.adapter import DiscordAdapter from gateway.config import PlatformConfig, Platform config = PlatformConfig(enabled=True, extra={}) @@ -2795,7 +2795,7 @@ class TestUDPKeepalive: adapter._voice_clients[111] = mock_vc mock_vc._connection = mock_conn - from gateway.platforms.discord import VoiceReceiver + from plugins.platforms.discord.adapter import VoiceReceiver mock_receiver_vc = MagicMock() mock_receiver_vc._connection.secret_key = [0] * 32 mock_receiver_vc._connection.dave_session = None diff --git a/tests/gateway/test_webhook_adapter.py b/tests/gateway/test_webhook_adapter.py index 8ca98cfb2bf..9cf61c3c3b5 100644 --- a/tests/gateway/test_webhook_adapter.py +++ b/tests/gateway/test_webhook_adapter.py @@ -15,6 +15,7 @@ Covers: """ import asyncio +import base64 import hashlib import hmac import json @@ -100,6 +101,18 @@ def _generic_signature(body: bytes, secret: str) -> str: return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest() +def _svix_signature(body: bytes, secret: str, msg_id: str, timestamp: str) -> str: + """Compute a Svix v1 signature header for *body* using *secret*.""" + key = ( + base64.b64decode(secret.removeprefix("whsec_")) + if secret.startswith("whsec_") + else secret.encode() + ) + signed = msg_id.encode() + b"." + timestamp.encode() + b"." + body + digest = hmac.new(key, signed, hashlib.sha256).digest() + return "v1," + base64.b64encode(digest).decode() + + # =================================================================== # Signature validation # =================================================================== @@ -170,6 +183,134 @@ class TestValidateSignature: req = _mock_request(headers={"X-Webhook-Signature": sig}) assert adapter._validate_signature(req, body, secret) is True + def test_validate_svix_signature_valid(self): + """Valid Svix/AgentMail v1 signature headers are accepted.""" + adapter = _make_adapter() + body = b'{"event_type":"message.received"}' + secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode() + msg_id = "msg_123" + timestamp = str(int(time.time())) + sig = _svix_signature(body, secret, msg_id, timestamp) + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": sig, + } + ) + assert adapter._validate_signature(req, body, secret) is True + + def test_validate_svix_signature_wrong_body_rejects(self): + """Svix/AgentMail signatures are bound to the exact raw request body.""" + adapter = _make_adapter() + signed_body = b'{"event_type":"message.received"}' + received_body = b'{"event_type":"message.sent"}' + secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode() + msg_id = "msg_123" + timestamp = str(int(time.time())) + sig = _svix_signature(signed_body, secret, msg_id, timestamp) + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": sig, + } + ) + assert adapter._validate_signature(req, received_body, secret) is False + + def test_validate_svix_signature_old_timestamp_rejects(self): + """Svix/AgentMail signatures outside the replay window are rejected.""" + adapter = _make_adapter() + body = b'{"event_type":"message.received"}' + secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode() + msg_id = "msg_123" + timestamp = str(int(time.time()) - 301) + sig = _svix_signature(body, secret, msg_id, timestamp) + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": sig, + } + ) + assert adapter._validate_signature(req, body, secret) is False + + def test_validate_svix_signature_multiple_entries_accepts_matching_v1(self): + """Svix rotation headers may contain multiple space-separated signatures.""" + adapter = _make_adapter() + body = b'{"event_type":"message.received"}' + secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode() + msg_id = "msg_123" + timestamp = str(int(time.time())) + sig = _svix_signature(body, secret, msg_id, timestamp) + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": "v1,wrong " + sig, + } + ) + assert adapter._validate_signature(req, body, secret) is True + + def test_validate_svix_signature_missing_signature_rejects(self): + """Partial Svix headers reject instead of falling through to another scheme.""" + adapter = _make_adapter() + req = _mock_request(headers={"svix-id": "msg_123"}) + assert adapter._validate_signature(req, b"{}", "secret") is False + + def test_validate_svix_signature_unsupported_version_rejects(self): + """Only Svix v1 signatures are accepted.""" + adapter = _make_adapter() + body = b'{"event_type":"message.received"}' + secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode() + msg_id = "msg_123" + timestamp = str(int(time.time())) + sig = _svix_signature(body, secret, msg_id, timestamp).replace("v1,", "v2,") + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": sig, + } + ) + assert adapter._validate_signature(req, body, secret) is False + + def test_validate_svix_signature_invalid_whsec_rejects(self): + """Malformed whsec_ secrets are rejected, not silently treated as raw secrets.""" + adapter = _make_adapter() + body = b'{"event_type":"message.received"}' + malformed_secret = "whsec_not-valid-base64!" + msg_id = "msg_123" + timestamp = str(int(time.time())) + raw_sig = _svix_signature( + body, malformed_secret.removeprefix("whsec_"), msg_id, timestamp + ) + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": raw_sig, + } + ) + assert adapter._validate_signature(req, body, malformed_secret) is False + + def test_validate_svix_signature_raw_secret_valid(self): + """Raw shared secrets are accepted for Svix-style senders without whsec_ secrets.""" + adapter = _make_adapter() + body = b'{"event_type":"message.received"}' + secret = "raw-agentmail-secret" + msg_id = "msg_123" + timestamp = str(int(time.time())) + sig = _svix_signature(body, secret, msg_id, timestamp) + req = _mock_request( + headers={ + "svix-id": msg_id, + "svix-timestamp": timestamp, + "svix-signature": sig, + } + ) + assert adapter._validate_signature(req, body, secret) is True + # =================================================================== # Prompt rendering @@ -304,6 +445,27 @@ class TestEventFilter: ) assert resp.status == 202 + @pytest.mark.asyncio + async def test_event_filter_accepts_payload_type_field(self): + """Svix-style payloads often use a top-level `type` event field.""" + routes = { + "svix": { + "secret": _INSECURE_NO_AUTH, + "events": ["message.received"], + "prompt": "got it", + } + } + adapter = _make_adapter(routes=routes) + adapter.handle_message = AsyncMock() + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post( + "/webhooks/svix", + json={"type": "message.received"}, + ) + assert resp.status == 202 + # =================================================================== # HTTP handling @@ -336,6 +498,22 @@ class TestHTTPHandling: assert data["status"] == "accepted" assert data["route"] == "test" + @pytest.mark.asyncio + async def test_route_without_secret_rejects_unsigned_request(self): + """Missing HMAC secret must fail closed even if connect() was bypassed.""" + routes = {"test": {"prompt": "hi"}} + adapter = _make_adapter(routes=routes, secret="") + adapter.handle_message = AsyncMock() + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post("/webhooks/test", json={"data": "value"}) + assert resp.status == 403 + data = await resp.json() + assert data["error"] == "Webhook route is missing an HMAC secret" + + adapter.handle_message.assert_not_called() + @pytest.mark.asyncio async def test_health_endpoint(self): """GET /health returns 200 with status=ok.""" @@ -432,6 +610,25 @@ class TestIdempotency: resp2 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers) assert resp2.status == 202 # re-accepted + @pytest.mark.asyncio + async def test_svix_id_used_as_delivery_id_for_deduplication(self): + """Svix retries reuse svix-id, so use it as the delivery ID when present.""" + routes = {"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}} + adapter = _make_adapter(routes=routes) + adapter.handle_message = AsyncMock() + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + headers = {"svix-id": "msg_duplicate"} + resp1 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers) + assert resp1.status == 202 + + resp2 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers) + assert resp2.status == 200 + data = await resp2.json() + assert data["status"] == "duplicate" + assert data["delivery_id"] == "msg_duplicate" + # =================================================================== # Rate limiting diff --git a/tests/gateway/test_webhook_dynamic_routes.py b/tests/gateway/test_webhook_dynamic_routes.py index 2029dd1399e..98c0db26492 100644 --- a/tests/gateway/test_webhook_dynamic_routes.py +++ b/tests/gateway/test_webhook_dynamic_routes.py @@ -6,7 +6,11 @@ import pytest from pathlib import Path from gateway.config import PlatformConfig -from gateway.platforms.webhook import WebhookAdapter, _DYNAMIC_ROUTES_FILENAME +from gateway.platforms.webhook import ( + WebhookAdapter, + _DYNAMIC_ROUTES_FILENAME, + _INSECURE_NO_AUTH, +) def _make_adapter(routes=None, extra=None): @@ -85,3 +89,88 @@ class TestDynamicRouteLoading: adapter._reload_dynamic_routes() assert "static" in adapter._routes assert len(adapter._dynamic_routes) == 0 + + +class TestDynamicRouteSecretValidation: + """Empty/missing secrets must be rejected during hot-reload. + + Regression for HMAC bypass: prior to the fix, an agent-induced + dynamic route with `"secret": ""` would be merged into self._routes + by _reload_dynamic_routes(), then _handle_webhook's + `if secret and secret != _INSECURE_NO_AUTH` would skip signature + validation because empty string is falsy. Unauthenticated POSTs + would then execute the webhook prompt. + """ + + def test_empty_secret_rejected(self, tmp_path): + # Explicit empty-string secret must NOT fall back to the global + # secret, and the route must be skipped entirely. + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"evil": {"secret": "", "prompt": "rm -rf"}}) + ) + adapter = _make_adapter() # has global secret + adapter._reload_dynamic_routes() + assert "evil" not in adapter._routes + assert "evil" not in adapter._dynamic_routes + + def test_missing_secret_no_global_rejected(self, tmp_path): + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"orphan": {"prompt": "test"}}) + ) + # No global secret configured + adapter = _make_adapter(extra={"secret": ""}) + adapter._reload_dynamic_routes() + assert "orphan" not in adapter._routes + assert "orphan" not in adapter._dynamic_routes + + def test_missing_secret_inherits_global(self, tmp_path): + # No per-route secret but a global one is set โ†’ route is kept, + # the global secret protects it. Preserves existing fallback. + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"valid": {"prompt": "ok"}}) + ) + adapter = _make_adapter() # global secret set + adapter._reload_dynamic_routes() + assert "valid" in adapter._routes + + def test_insecure_no_auth_preserved(self, tmp_path): + # Explicit opt-in escape hatch for local testing โ€” must still load. + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"test": {"secret": _INSECURE_NO_AUTH, "prompt": "p"}}) + ) + adapter = _make_adapter(extra={"host": "127.0.0.1"}) + adapter._reload_dynamic_routes() + assert "test" in adapter._routes + + def test_insecure_no_auth_rejected_on_non_loopback_bind(self, tmp_path): + # Dynamic INSECURE_NO_AUTH routes are only valid on loopback hosts. + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"pub": {"secret": _INSECURE_NO_AUTH, "prompt": "p"}}) + ) + adapter = _make_adapter(extra={"host": "0.0.0.0"}) + adapter._reload_dynamic_routes() + assert "pub" not in adapter._routes + assert "pub" not in adapter._dynamic_routes + + def test_warning_logged_on_skip(self, tmp_path, caplog): + import logging + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({"silent": {"secret": "", "prompt": "x"}}) + ) + adapter = _make_adapter() + with caplog.at_level(logging.WARNING, logger="gateway.platforms.webhook"): + adapter._reload_dynamic_routes() + assert any("silent" in rec.message for rec in caplog.records) + + def test_partial_skip(self, tmp_path): + # One route bad, one route good โ€” only the bad one is dropped. + (tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text( + json.dumps({ + "bad": {"secret": "", "prompt": "x"}, + "good": {"secret": "valid-secret", "prompt": "y"}, + }) + ) + adapter = _make_adapter() + adapter._reload_dynamic_routes() + assert "good" in adapter._routes + assert "bad" not in adapter._routes diff --git a/tests/gateway/test_wecom.py b/tests/gateway/test_wecom.py index 7bf56f9d319..02d04daf64e 100644 --- a/tests/gateway/test_wecom.py +++ b/tests/gateway/test_wecom.py @@ -1,5 +1,6 @@ """Tests for the WeCom platform adapter.""" +import asyncio import base64 import os from pathlib import Path @@ -831,3 +832,91 @@ class TestWeComZombieSessionFix: cmd = adapter._send_request.await_args.args[0] assert cmd == APP_CMD_SEND + + +class TestTextBatchFlushRace: + """Regression tests for the cancel-delivery race in _flush_text_batch. + + When asyncio.sleep() fires and Task.cancel() is called before the task + runs, CPython sets _must_cancel but cannot cancel the already-done sleep + future. CancelledError is then delivered at the *next* await + (handle_message), after the task has already popped the event โ€” the + superseding task sees an empty batch and silently drops the message. + The fix adds a synchronous task-registry check between the sleep and + the pop so a superseded task returns before touching the event. + """ + + @pytest.mark.asyncio + async def test_superseded_task_does_not_pop_or_process_event(self): + """A flush task that has been superseded must leave the event in the + batch dict for the new task to handle.""" + from gateway.platforms.base import MessageEvent, MessageType + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._text_batch_delay_seconds = 0 + + key = "test-session" + event = MessageEvent(text="hello", message_type=MessageType.TEXT) + adapter._pending_text_batches[key] = event + + handle_calls = [] + + async def fake_handle(evt): + handle_calls.append(evt) + + adapter.handle_message = fake_handle + + # Create T1 and register it. + t1 = asyncio.create_task(adapter._flush_text_batch(key)) + adapter._pending_text_batch_tasks[key] = t1 + + # Simulate T2 superseding T1 before T1 wakes from sleep. + t2 = asyncio.create_task(asyncio.sleep(9999)) + adapter._pending_text_batch_tasks[key] = t2 + + # Yield long enough for T1's sleep(0) to complete and T1 to run. + await asyncio.sleep(0.05) + + t2.cancel() + try: + await t2 + except asyncio.CancelledError: + pass + + # T1 must have returned without processing or removing the event. + assert handle_calls == [], "superseded task must not call handle_message" + assert adapter._pending_text_batches.get(key) is event, ( + "superseded task must not pop the event" + ) + + @pytest.mark.asyncio + async def test_active_task_processes_event_normally(self): + """When the task is not superseded it must still process the event.""" + from gateway.platforms.base import MessageEvent, MessageType + from gateway.platforms.wecom import WeComAdapter + + adapter = WeComAdapter(PlatformConfig(enabled=True)) + adapter._text_batch_delay_seconds = 0 + + key = "test-session" + event = MessageEvent(text="world", message_type=MessageType.TEXT) + adapter._pending_text_batches[key] = event + + handle_calls = [] + + async def fake_handle(evt): + handle_calls.append(evt) + + adapter.handle_message = fake_handle + + t1 = asyncio.create_task(adapter._flush_text_batch(key)) + adapter._pending_text_batch_tasks[key] = t1 + + # No superseding task โ€” T1 should process normally. + await asyncio.sleep(0.05) + + assert handle_calls == [event], "active task must call handle_message" + assert adapter._pending_text_batches.get(key) is None, ( + "active task must pop the event after processing" + ) diff --git a/tests/gateway/test_wecom_callback.py b/tests/gateway/test_wecom_callback.py index 88c084ae3e0..e4646b70b5e 100644 --- a/tests/gateway/test_wecom_callback.py +++ b/tests/gateway/test_wecom_callback.py @@ -153,6 +153,130 @@ class TestWecomCallbackRouting: assert calls["json"]["agentid"] == 1001 +class TestWecomCallbackSendTokenRefresh: + @pytest.mark.asyncio + async def test_send_retries_with_fresh_token_on_errcode_40001(self): + """errcode=40001 must evict the cached token, refresh, and retry once.""" + adapter = WecomCallbackAdapter(_config()) + adapter._access_tokens["test-app"] = {"token": "stale", "expires_at": 9999999999} + adapter._user_app_map["ww1234567890:alice"] = "test-app" + + responses = [ + {"errcode": 40001, "errmsg": "invalid credential"}, + {"errcode": 0, "msgid": "msg-ok"}, + ] + post_calls = [] + + class FakeClient: + async def post(self, url, json=None, **kw): + post_calls.append(url) + + class R: + def json(inner): + return responses[len(post_calls) - 1] + return R() + + async def get(self, url, params=None, **kw): + class R: + def json(inner): + return {"errcode": 0, "access_token": "fresh", "expires_in": 7200} + return R() + + adapter._http_client = FakeClient() + result = await adapter.send("ww1234567890:alice", "hello") + + assert result.success is True + assert result.message_id == "msg-ok" + assert len(post_calls) == 2 + assert "fresh" in post_calls[1] + assert adapter._access_tokens["test-app"]["token"] == "fresh" + + @pytest.mark.asyncio + async def test_send_retries_with_fresh_token_on_errcode_42001(self): + """errcode=42001 (token expired) must also trigger the refresh-retry path.""" + adapter = WecomCallbackAdapter(_config()) + adapter._access_tokens["test-app"] = {"token": "expired", "expires_at": 9999999999} + + responses = [ + {"errcode": 42001, "errmsg": "access_token expired"}, + {"errcode": 0, "msgid": "msg-42"}, + ] + post_calls = [] + + class FakeClient: + async def post(self, url, json=None, **kw): + post_calls.append(url) + + class R: + def json(inner): + return responses[len(post_calls) - 1] + return R() + + async def get(self, url, params=None, **kw): + class R: + def json(inner): + return {"errcode": 0, "access_token": "renewed", "expires_in": 7200} + return R() + + adapter._http_client = FakeClient() + result = await adapter.send("alice", "hello") + + assert result.success is True + assert len(post_calls) == 2 + + @pytest.mark.asyncio + async def test_send_does_not_retry_on_non_token_errcode(self): + """Errors unrelated to token validity must fail immediately without retrying.""" + adapter = WecomCallbackAdapter(_config()) + adapter._access_tokens["test-app"] = {"token": "good", "expires_at": 9999999999} + + post_calls = [] + + class FakeClient: + async def post(self, url, json=None, **kw): + post_calls.append(url) + + class R: + def json(inner): + return {"errcode": 60020, "errmsg": "not allow to access"} + return R() + + adapter._http_client = FakeClient() + result = await adapter.send("alice", "hello") + + assert result.success is False + assert len(post_calls) == 1 + + @pytest.mark.asyncio + async def test_send_fails_cleanly_when_retry_also_fails(self): + """If the refreshed token is also rejected, return failure without looping further.""" + adapter = WecomCallbackAdapter(_config()) + adapter._access_tokens["test-app"] = {"token": "bad1", "expires_at": 9999999999} + + post_calls = [] + + class FakeClient: + async def post(self, url, json=None, **kw): + post_calls.append(url) + + class R: + def json(inner): + return {"errcode": 42001, "errmsg": "access_token expired"} + return R() + + async def get(self, url, params=None, **kw): + class R: + def json(inner): + return {"errcode": 0, "access_token": "bad2", "expires_in": 7200} + return R() + + adapter._http_client = FakeClient() + result = await adapter.send("alice", "hello") + + assert result.success is False + assert len(post_calls) == 2 + + class TestWecomCallbackPollLoop: @pytest.mark.asyncio async def test_poll_loop_dispatches_handle_message(self, monkeypatch): diff --git a/tests/hermes_cli/test_argparse_flag_propagation.py b/tests/hermes_cli/test_argparse_flag_propagation.py index 741425a82dc..c3d8e80db32 100644 --- a/tests/hermes_cli/test_argparse_flag_propagation.py +++ b/tests/hermes_cli/test_argparse_flag_propagation.py @@ -57,6 +57,59 @@ def _build_parser(): return parser +class TestChatVerboseArg: + """Verify chat --verbose preserves config fallback when absent.""" + + def test_chat_without_verbose_leaves_attribute_unset(self): + from hermes_cli._parser import build_top_level_parser + + parser, _subparsers, _chat_parser = build_top_level_parser() + args = parser.parse_args(["chat"]) + + assert not hasattr(args, "verbose") + + def test_chat_verbose_sets_attribute_true(self): + from hermes_cli._parser import build_top_level_parser + + parser, _subparsers, _chat_parser = build_top_level_parser() + args = parser.parse_args(["chat", "--verbose"]) + + assert args.verbose is True + + def test_cmd_chat_forwards_none_when_verbose_is_absent(self, monkeypatch): + import types + import sys + + import hermes_cli.main as main_mod + from hermes_cli._parser import build_top_level_parser + + parser, _subparsers, chat_parser = build_top_level_parser() + chat_parser.set_defaults(func=main_mod.cmd_chat) + args = parser.parse_args(["chat"]) + captured = {} + fake_cli = types.ModuleType("cli") + + def fake_main(**kwargs): + captured.update(kwargs) + + setattr(fake_cli, "main", fake_main) + fake_banner = types.ModuleType("hermes_cli.banner") + setattr(fake_banner, "prefetch_update_check", lambda: None) + fake_skills_sync = types.ModuleType("tools.skills_sync") + setattr(fake_skills_sync, "sync_skills", lambda quiet=True: None) + + monkeypatch.setitem(sys.modules, "cli", fake_cli) + monkeypatch.setitem(sys.modules, "hermes_cli.banner", fake_banner) + monkeypatch.setitem(sys.modules, "tools.skills_sync", fake_skills_sync) + monkeypatch.setattr(main_mod, "_has_any_provider_configured", lambda: True) + monkeypatch.setattr(main_mod, "_pin_kanban_board_env", lambda: None) + + main_mod.cmd_chat(args) + + assert captured["quiet"] is False + assert "verbose" not in captured + + class TestYoloEnvVar: """Verify --yolo sets HERMES_YOLO_MODE regardless of flag position. diff --git a/tests/hermes_cli/test_auth_qwen_provider.py b/tests/hermes_cli/test_auth_qwen_provider.py index f1943d8459b..a2f58df6b0b 100644 --- a/tests/hermes_cli/test_auth_qwen_provider.py +++ b/tests/hermes_cli/test_auth_qwen_provider.py @@ -392,8 +392,84 @@ def test_get_qwen_auth_status_logged_in(qwen_env): assert status["api_key"] == "status-at" +def test_get_qwen_auth_status_refreshes_expired_token(qwen_env): + expired_ms = int((time.time() - 3600) * 1000) + tokens = _make_qwen_tokens(access_token="old-at", expiry_date=expired_ms) + _write_qwen_creds(qwen_env, tokens) + + refreshed = _make_qwen_tokens(access_token="refreshed-at") + + with patch( + "hermes_cli.auth._refresh_qwen_cli_tokens", return_value=refreshed + ) as mock_refresh: + status = get_qwen_auth_status() + + mock_refresh.assert_called_once() + assert status["logged_in"] is True + assert status["api_key"] == "refreshed-at" + + +def test_get_qwen_auth_status_expired_unrefreshable_token_is_not_logged_in(qwen_env): + expired_ms = int((time.time() - 3600) * 1000) + tokens = _make_qwen_tokens(access_token="dead-at", expiry_date=expired_ms) + _write_qwen_creds(qwen_env, tokens) + + with patch( + "hermes_cli.auth._refresh_qwen_cli_tokens", + side_effect=AuthError( + "Qwen refresh rejected. Re-run 'qwen auth qwen-oauth'.", + provider="qwen-oauth", + code="qwen_refresh_failed", + ), + ) as mock_refresh: + status = get_qwen_auth_status() + + mock_refresh.assert_called_once() + assert status["logged_in"] is False + assert "qwen auth qwen-oauth" in status["error"] + + def test_get_qwen_auth_status_not_logged_in(qwen_env): # No credentials file status = get_qwen_auth_status() assert status["logged_in"] is False assert "error" in status + + +def test_model_flow_qwen_oauth_stale_token_shows_reauth_guidance(qwen_env, monkeypatch, capsys): + from hermes_cli.main import _model_flow_qwen_oauth + + expired_ms = int((time.time() - 3600) * 1000) + tokens = _make_qwen_tokens(access_token="dead-at", expiry_date=expired_ms) + _write_qwen_creds(qwen_env, tokens) + + monkeypatch.setattr( + "hermes_cli.auth._refresh_qwen_cli_tokens", + lambda *args, **kwargs: (_ for _ in ()).throw( + AuthError( + "Qwen refresh rejected. Re-run 'qwen auth qwen-oauth'.", + provider="qwen-oauth", + code="qwen_refresh_failed", + ) + ), + ) + + prompt_called = {"value": False} + update_called = {"value": False} + + monkeypatch.setattr( + "hermes_cli.auth._prompt_model_selection", + lambda *args, **kwargs: prompt_called.__setitem__("value", True), + ) + monkeypatch.setattr( + "hermes_cli.auth._update_config_for_provider", + lambda *args, **kwargs: update_called.__setitem__("value", True), + ) + + _model_flow_qwen_oauth({}, current_model="qwen3-coder-plus") + + out = capsys.readouterr().out + assert "Run: qwen auth qwen-oauth" in out + assert "Qwen refresh rejected" in out + assert prompt_called["value"] is False + assert update_called["value"] is False diff --git a/tests/hermes_cli/test_auth_usable_secret.py b/tests/hermes_cli/test_auth_usable_secret.py new file mode 100644 index 00000000000..cb24ef5ee26 --- /dev/null +++ b/tests/hermes_cli/test_auth_usable_secret.py @@ -0,0 +1,13 @@ +"""Tests for placeholder API key detection in hermes_cli.auth.""" + +from hermes_cli.auth import has_usable_secret + + +def test_has_usable_secret_rejects_documented_placeholder_key() -> None: + """Network-exposed API server key must reject static documentation placeholders.""" + assert not has_usable_secret("your_api_key_here", min_length=8) + + +def test_has_usable_secret_accepts_generated_key() -> None: + """Random-looking keys should still be accepted.""" + assert has_usable_secret("b4d59f7fe8b857d0b367ef0f5710b6a4", min_length=8) diff --git a/tests/hermes_cli/test_curses_color_compat.py b/tests/hermes_cli/test_curses_color_compat.py new file mode 100644 index 00000000000..c7509cc965f --- /dev/null +++ b/tests/hermes_cli/test_curses_color_compat.py @@ -0,0 +1,131 @@ +"""Tests for curses color compatibility on low-color terminals (Docker). + +Regression test for #13688: ``hermes plugins`` crashes with +``curses.error: init_pair() : color number is greater than COLORS-1`` +in Docker containers where curses.COLORS == 8 (only colors 0-7 exist). + +The bug was ``curses.init_pair(4, 8, -1)`` using raw color 8 ("bright +black" / dim gray) which does not exist on 8-color terminals. The fix +clamps with ``min(8, curses.COLORS - 1)``. +""" + +import curses +import re +from pathlib import Path +from unittest.mock import patch, MagicMock, call + +import pytest + + +# Path to the source files under test +_SRC_ROOT = Path(__file__).parent.parent.parent / "hermes_cli" + + +class TestInitPairClampingBehavior: + """Simulate curses color initialization on low-color terminals. + + Patches curses.COLORS to 8 (Docker default) and verifies that + init_pair is never called with a color >= COLORS. + """ + + def _collect_init_pair_calls(self, draw_fn, colors_value): + """Run a curses draw function with a mock stdscr and patched COLORS. + + Returns list of (pair_number, fg, bg) tuples from init_pair calls. + """ + calls = [] + real_init_pair = curses.init_pair + + def tracking_init_pair(pair, fg, bg): + calls.append((pair, fg, bg)) + + mock_stdscr = MagicMock() + mock_stdscr.getmaxyx.return_value = (24, 80) + mock_stdscr.getch.return_value = 27 # ESC to exit + + with patch("curses.COLORS", colors_value, create=True), \ + patch("curses.init_pair", side_effect=tracking_init_pair), \ + patch("curses.has_colors", return_value=True), \ + patch("curses.start_color"), \ + patch("curses.use_default_colors"), \ + patch("curses.curs_set"): + try: + draw_fn(mock_stdscr) + except (SystemExit, StopIteration, Exception): + pass # draw functions loop until keypress + + return calls + + def test_8_color_terminal_no_color_exceeds_limit(self): + """On an 8-color terminal (Docker), no init_pair fg color >= 8.""" + # Simulate the color init pattern from plugins_cmd.py + def _simulated_color_init(stdscr): + if curses.has_colors(): + curses.start_color() + curses.use_default_colors() + curses.init_pair(1, curses.COLOR_GREEN, -1) + curses.init_pair(2, curses.COLOR_YELLOW, -1) + curses.init_pair(3, curses.COLOR_CYAN, -1) + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) + + calls = self._collect_init_pair_calls(_simulated_color_init, 8) + for pair, fg, bg in calls: + assert fg < 8, ( + f"init_pair({pair}, {fg}, {bg}) uses color {fg} which " + f"does not exist on an 8-color terminal (valid: 0-7)" + ) + + def test_256_color_terminal_uses_color_8(self): + """On a 256-color terminal, color 8 (dim gray) should be used.""" + def _simulated_color_init(stdscr): + if curses.has_colors(): + curses.start_color() + curses.use_default_colors() + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) + + calls = self._collect_init_pair_calls(_simulated_color_init, 256) + assert any(fg == 8 for _, fg, _ in calls), ( + "On 256-color terminals, color 8 (dim gray) should be used" + ) + + def test_16_color_terminal_uses_color_8(self): + """On a 16-color terminal, color 8 should be available.""" + def _simulated_color_init(stdscr): + if curses.has_colors(): + curses.start_color() + curses.use_default_colors() + curses.init_pair(4, 8 if curses.COLORS > 8 else curses.COLOR_WHITE, -1) + + calls = self._collect_init_pair_calls(_simulated_color_init, 16) + assert any(fg == 8 for _, fg, _ in calls) + + +class TestSourceCodeGuardrails: + """Regression guardrails: raw color 8 must not reappear in source. + + These complement the behavioral tests above โ€” they catch regressions + introduced by copy-paste of the old pattern. + """ + + _RAW_COLOR_8_PATTERN = re.compile(r'init_pair\(\d+,\s*8\s*,') + + def test_no_raw_color_8_in_plugins_cmd(self): + source = (_SRC_ROOT / "plugins_cmd.py").read_text() + matches = self._RAW_COLOR_8_PATTERN.findall(source) + assert not matches, ( + f"plugins_cmd.py contains unclamped color 8: {matches}" + ) + + def test_no_raw_color_8_in_main(self): + source = (_SRC_ROOT / "main.py").read_text() + matches = self._RAW_COLOR_8_PATTERN.findall(source) + assert not matches, ( + f"main.py contains unclamped color 8: {matches}" + ) + + def test_no_raw_color_8_in_curses_ui(self): + source = (_SRC_ROOT / "curses_ui.py").read_text() + matches = self._RAW_COLOR_8_PATTERN.findall(source) + assert not matches, ( + f"curses_ui.py contains unclamped color 8: {matches}" + ) diff --git a/tests/hermes_cli/test_debug.py b/tests/hermes_cli/test_debug.py index 1996e7fce98..aad1c8e92a5 100644 --- a/tests/hermes_cli/test_debug.py +++ b/tests/hermes_cli/test_debug.py @@ -353,6 +353,40 @@ class TestCaptureLogSnapshotRedaction: assert snap.full_text is not None assert _REDACT_FIXTURE_TOKEN not in snap.full_text + def test_default_redacts_email_addresses_for_public_share( + self, hermes_home_with_secret + ): + from hermes_cli.debug import _capture_log_snapshot + + log_path = hermes_home_with_secret / "logs" / "agent.log" + log_path.write_text( + "2026-04-12 17:00:00 INFO gateway.run: " + "inbound message: platform=bluebubbles " + "user=person@example.com chat=iMessage;-;person@example.com msg='hello'\n" + ) + + snap = _capture_log_snapshot("agent", tail_lines=10) + + assert "person@example.com" not in snap.tail_text + assert "[REDACTED_EMAIL]" in snap.tail_text + assert snap.full_text is not None + assert "person@example.com" not in snap.full_text + + def test_no_redact_preserves_email_addresses(self, hermes_home_with_secret): + from hermes_cli.debug import _capture_log_snapshot + + log_path = hermes_home_with_secret / "logs" / "agent.log" + log_path.write_text( + "2026-04-12 17:00:00 INFO gateway.run: " + "inbound message: platform=bluebubbles " + "user=person@example.com chat=iMessage;-;person@example.com msg='hello'\n" + ) + + snap = _capture_log_snapshot("agent", tail_lines=10, redact=False) + + assert "person@example.com" in snap.tail_text + assert "person@example.com" in (snap.full_text or "") + def test_capture_default_log_snapshots_threads_redact( self, hermes_home_with_secret ): diff --git a/tests/hermes_cli/test_env_loader.py b/tests/hermes_cli/test_env_loader.py index f309dfd4c6a..2523754a84b 100644 --- a/tests/hermes_cli/test_env_loader.py +++ b/tests/hermes_cli/test_env_loader.py @@ -70,6 +70,23 @@ def test_user_env_takes_precedence_over_project_env(tmp_path, monkeypatch): assert os.getenv("OPENAI_API_KEY") == "project-key" +def test_null_bytes_in_user_env_are_stripped(tmp_path, monkeypatch): + home = tmp_path / "hermes" + home.mkdir() + env_file = home / ".env" + # Null bytes can be introduced when copy-pasting API keys. + env_file.write_text("GLM_API_KEY=abc\x00\x00\nOPENAI_API_KEY=sk-123\n", encoding="utf-8") + + monkeypatch.delenv("GLM_API_KEY", raising=False) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + loaded = load_hermes_dotenv(hermes_home=home) + + assert loaded == [env_file] + assert os.getenv("GLM_API_KEY") == "abc" + assert os.getenv("OPENAI_API_KEY") == "sk-123" + + def test_main_import_applies_user_env_over_shell_values(tmp_path, monkeypatch): home = tmp_path / "hermes" home.mkdir() diff --git a/tests/hermes_cli/test_fallback_cmd.py b/tests/hermes_cli/test_fallback_cmd.py index a88c84b3aa8..2eed7d62f97 100644 --- a/tests/hermes_cli/test_fallback_cmd.py +++ b/tests/hermes_cli/test_fallback_cmd.py @@ -55,6 +55,31 @@ class TestReadChain: {"provider": "nous", "model": "Hermes-4-Llama-3.1-405B"}, ] + def test_merges_new_and_legacy_formats(self): + from hermes_cli.fallback_cmd import _read_chain + cfg = { + "fallback_providers": [ + {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}, + ], + "fallback_model": {"provider": "nous", "model": "Hermes-4"}, + } + assert _read_chain(cfg) == [ + {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}, + {"provider": "nous", "model": "Hermes-4"}, + ] + + def test_legacy_duplicate_is_deduplicated_after_merge(self): + from hermes_cli.fallback_cmd import _read_chain + cfg = { + "fallback_providers": [ + {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}, + ], + "fallback_model": {"provider": "OpenRouter", "model": "anthropic/claude-sonnet-4.6"}, + } + assert _read_chain(cfg) == [ + {"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"}, + ] + def test_migrates_legacy_single_dict(self): from hermes_cli.fallback_cmd import _read_chain cfg = {"fallback_model": {"provider": "openrouter", "model": "gpt-5.4"}} diff --git a/tests/hermes_cli/test_image_gen_picker.py b/tests/hermes_cli/test_image_gen_picker.py index 51eafd6da67..04d46bbbb86 100644 --- a/tests/hermes_cli/test_image_gen_picker.py +++ b/tests/hermes_cli/test_image_gen_picker.py @@ -69,18 +69,19 @@ class TestPluginPickerInjection: assert "Myimg" in names assert "myimg" in plugin_names - def test_fal_skipped_to_avoid_duplicate(self, monkeypatch): + def test_fal_surfaced_alongside_other_plugins(self, monkeypatch): from hermes_cli import tools_config - # Simulate a FAL plugin being registered โ€” the picker already has - # hardcoded FAL rows in TOOL_CATEGORIES, so plugin-FAL must be - # skipped to avoid showing FAL twice. + # After #26241, FAL is itself a plugin (`plugins/image_gen/fal/`) + # and the hardcoded `TOOL_CATEGORIES["image_gen"]` FAL row is + # gone. The plugin-row builder therefore surfaces it like any + # other backend โ€” no deduplication step needed. image_gen_registry.register_provider(_FakeProvider("fal")) image_gen_registry.register_provider(_FakeProvider("openai")) rows = tools_config._plugin_image_gen_providers() names = [r.get("image_gen_plugin_name") for r in rows] - assert "fal" not in names + assert "fal" in names assert "openai" in names def test_visible_providers_includes_plugins_for_image_gen(self, monkeypatch): diff --git a/tests/hermes_cli/test_install_cua_driver.py b/tests/hermes_cli/test_install_cua_driver.py index 6cd50261694..aa7fd68fec9 100644 --- a/tests/hermes_cli/test_install_cua_driver.py +++ b/tests/hermes_cli/test_install_cua_driver.py @@ -1,4 +1,4 @@ -"""Tests for ``install_cua_driver`` upgrade semantics. +"""Tests for ``install_cua_driver`` upgrade semantics and architecture pre-check. The cua-driver upstream installer always pulls the latest release tag, so re-running it is the canonical upgrade path. ``install_cua_driver(upgrade=True)`` @@ -10,18 +10,18 @@ must: fix for the "we only pulled cua-driver once on enable" complaint). * Preserve original ``upgrade=False`` behaviour for the toolset-enable flow: skip if installed, install otherwise, warn on non-macOS. +* Pre-check architecture compatibility before downloading to avoid raw 404 + errors on Intel macOS when the upstream release lacks x86_64 assets. """ from __future__ import annotations -from unittest.mock import patch +import json +from unittest.mock import MagicMock, patch class TestInstallCuaDriverUpgrade: def test_upgrade_on_non_macos_is_silent_noop(self): - """``hermes update`` calls install_cua_driver(upgrade=True) for every - user. On Linux/Windows it must return False without printing the - "macOS-only; skipping" warning that the toolset-enable path emits.""" from hermes_cli import tools_config with patch.object(tools_config, "_print_warning") as warn, \ @@ -30,8 +30,6 @@ class TestInstallCuaDriverUpgrade: warn.assert_not_called() def test_non_upgrade_on_non_macos_warns(self): - """The toolset-enable path (upgrade=False) should still warn loudly - when the user tries to enable Computer Use on a non-macOS host.""" from hermes_cli import tools_config with patch.object(tools_config, "_print_warning") as warn, \ @@ -40,43 +38,36 @@ class TestInstallCuaDriverUpgrade: warn.assert_called() def test_upgrade_on_macos_with_binary_runs_installer(self): - """When cua-driver is already on PATH and upgrade=True, we must - re-run the upstream installer (this is the fix for the bug report). - """ from hermes_cli import tools_config with patch("platform.system", return_value="Darwin"), \ patch.object(tools_config.shutil, "which", side_effect=lambda n: "/usr/local/bin/" + n if n in {"cua-driver", "curl"} else None), \ + patch.object(tools_config, "_check_cua_driver_asset_for_arch", + return_value=True), \ patch.object(tools_config, "_run_cua_driver_installer", return_value=True) as runner, \ patch("subprocess.run"): assert tools_config.install_cua_driver(upgrade=True) is True runner.assert_called_once() - # Refresh path uses non-verbose mode so we don't re-print the - # "grant macOS permissions" block on every `hermes update`. kwargs = runner.call_args.kwargs assert kwargs.get("verbose") is False def test_upgrade_on_macos_without_binary_runs_installer(self): - """upgrade=True with cua-driver missing must still trigger an - install โ€” equivalent to a fresh install. (Don't silently no-op.)""" from hermes_cli import tools_config with patch("platform.system", return_value="Darwin"), \ patch.object(tools_config.shutil, "which", side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \ + patch.object(tools_config, "_check_cua_driver_asset_for_arch", + return_value=True), \ patch.object(tools_config, "_run_cua_driver_installer", return_value=True) as runner: assert tools_config.install_cua_driver(upgrade=True) is True runner.assert_called_once() def test_non_upgrade_on_macos_with_binary_skips_install(self): - """Original toolset-enable behaviour: cua-driver already installed - + upgrade=False โ†’ confirm and return without re-running installer. - This is the behaviour that ``hermes tools`` (re)enable depends on, - so the new helper must not regress it.""" from hermes_cli import tools_config with patch("platform.system", return_value="Darwin"), \ @@ -89,27 +80,133 @@ class TestInstallCuaDriverUpgrade: runner.assert_not_called() def test_non_upgrade_on_macos_without_binary_runs_installer(self): - """Original fresh-install path must still work.""" from hermes_cli import tools_config with patch("platform.system", return_value="Darwin"), \ patch.object(tools_config.shutil, "which", side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \ + patch.object(tools_config, "_check_cua_driver_asset_for_arch", + return_value=True), \ patch.object(tools_config, "_run_cua_driver_installer", return_value=True) as runner: assert tools_config.install_cua_driver(upgrade=False) is True - runner.assert_called_once() - def test_upgrade_without_curl_does_not_crash(self): - """If curl isn't on PATH we can't refresh โ€” must warn and return - the current install state, not raise.""" + +class TestCheckCuaDriverAssetForArch: + def test_arm64_always_returns_true(self): from hermes_cli import tools_config - # cua-driver present, curl missing. - def _which(name): - return "/usr/local/bin/cua-driver" if name == "cua-driver" else None + with patch("platform.machine", return_value="arm64"): + assert tools_config._check_cua_driver_asset_for_arch() is True + + def test_x86_64_with_asset_returns_true(self): + from hermes_cli import tools_config + + release = { + "tag_name": "cua-driver-v0.1.6", + "assets": [ + {"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}, + {"name": "cua-driver-0.1.6-darwin-x86_64.tar.gz"}, + ], + } + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(release).encode() + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("platform.machine", return_value="x86_64"), \ + patch("urllib.request.urlopen", return_value=mock_resp): + assert tools_config._check_cua_driver_asset_for_arch() is True + + def test_x86_64_without_asset_returns_false(self): + from hermes_cli import tools_config + + release = { + "tag_name": "cua-driver-v0.1.6", + "assets": [ + {"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}, + {"name": "cua-driver.tar.gz"}, + ], + } + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(release).encode() + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("platform.machine", return_value="x86_64"), \ + patch("urllib.request.urlopen", return_value=mock_resp), \ + patch.object(tools_config, "_print_warning") as warn, \ + patch.object(tools_config, "_print_info"): + assert tools_config._check_cua_driver_asset_for_arch() is False + warn.assert_called_once() + assert "no Intel" in warn.call_args[0][0].lower() or "x86_64" in warn.call_args[0][0] + + def test_x86_64_api_failure_returns_true(self): + """Network failure should fail open โ€” let the installer handle it.""" + from hermes_cli import tools_config + + with patch("platform.machine", return_value="x86_64"), \ + patch("urllib.request.urlopen", side_effect=Exception("timeout")): + assert tools_config._check_cua_driver_asset_for_arch() is True + + def test_fresh_install_x86_64_no_asset_skips_installer(self): + """When the latest release has no Intel asset, skip the installer.""" + from hermes_cli import tools_config + + release = { + "tag_name": "cua-driver-v0.1.6", + "assets": [{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}], + } + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(release).encode() + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) with patch("platform.system", return_value="Darwin"), \ - patch.object(tools_config.shutil, "which", side_effect=_which), \ - patch.object(tools_config, "_print_warning"): + patch.object(tools_config.shutil, "which", + side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \ + patch("platform.machine", return_value="x86_64"), \ + patch("urllib.request.urlopen", return_value=mock_resp), \ + patch.object(tools_config, "_print_warning"), \ + patch.object(tools_config, "_print_info"), \ + patch.object(tools_config, "_run_cua_driver_installer") as runner: + assert tools_config.install_cua_driver(upgrade=False) is False + runner.assert_not_called() + + def test_upgrade_x86_64_no_asset_returns_existing_status(self): + """On upgrade with no Intel asset, return whether binary existed.""" + from hermes_cli import tools_config + + release = { + "tag_name": "cua-driver-v0.1.6", + "assets": [{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}], + } + mock_resp = MagicMock() + mock_resp.read.return_value = json.dumps(release).encode() + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + + # With binary installed โ€” returns True (binary exists) + with patch("platform.system", return_value="Darwin"), \ + patch.object(tools_config.shutil, "which", + side_effect=lambda n: "/usr/local/bin/" + n + if n in ("cua-driver", "curl") else None), \ + patch("platform.machine", return_value="x86_64"), \ + patch("urllib.request.urlopen", return_value=mock_resp), \ + patch.object(tools_config, "_print_warning"), \ + patch.object(tools_config, "_print_info"), \ + patch.object(tools_config, "_run_cua_driver_installer") as runner: assert tools_config.install_cua_driver(upgrade=True) is True + runner.assert_not_called() + + # Without binary โ€” returns False + with patch("platform.system", return_value="Darwin"), \ + patch.object(tools_config.shutil, "which", + side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \ + patch("platform.machine", return_value="x86_64"), \ + patch("urllib.request.urlopen", return_value=mock_resp), \ + patch.object(tools_config, "_print_warning"), \ + patch.object(tools_config, "_print_info"), \ + patch.object(tools_config, "_run_cua_driver_installer") as runner: + assert tools_config.install_cua_driver(upgrade=True) is False + runner.assert_not_called() diff --git a/tests/hermes_cli/test_kanban_db.py b/tests/hermes_cli/test_kanban_db.py index 435ef41001a..883cf8f4d5d 100644 --- a/tests/hermes_cli/test_kanban_db.py +++ b/tests/hermes_cli/test_kanban_db.py @@ -1470,6 +1470,138 @@ def test_worktree_workspace_returns_intended_path(kanban_home, tmp_path): assert str(ws) == target +# --------------------------------------------------------------------------- +# Scratch cleanup containment (#28818) +# --------------------------------------------------------------------------- + +def test_cleanup_workspace_removes_managed_scratch_dir(kanban_home): + """A scratch workspace under the kanban workspaces root is removed.""" + with kb.connect() as conn: + t = kb.create_task(conn, title="scratchy") + task = kb.get_task(conn, t) + ws = kb.resolve_workspace(task) + kb.set_workspace_path(conn, t, ws) + assert ws.is_dir() + kb.complete_task(conn, t, result="ok") + assert not ws.exists(), "Hermes-managed scratch dir should be cleaned up" + + +def test_cleanup_workspace_refuses_path_outside_scratch_root(kanban_home, tmp_path): + """A scratch task with a user path outside the workspaces root must NOT be deleted (#28818). + + Reproduces the data-loss vector where a board's ``default_workdir`` is set + to a real source directory; tasks created without an explicit + ``workspace_kind`` inherit ``scratch`` semantics, and the old cleanup path + would ``shutil.rmtree`` the user's source tree on task completion. + """ + real_source = tmp_path / "real-source" + real_source.mkdir() + (real_source / ".git").mkdir() + (real_source / "README.md").write_text("important", encoding="utf-8") + + with kb.connect() as conn: + t = kb.create_task(conn, title="ship") + # Simulate the bad state directly: workspace_kind='scratch' (default) + # but workspace_path pointing at the user's real source tree, which is + # exactly what board.default_workdir produces when the task is created + # without an explicit workspace_kind. + conn.execute( + "UPDATE tasks SET workspace_kind=?, workspace_path=? WHERE id=?", + ("scratch", str(real_source), t), + ) + conn.commit() + kb.complete_task(conn, t, result="ok") + + assert real_source.exists(), "User source tree must not be deleted by scratch cleanup" + assert (real_source / ".git").exists() + assert (real_source / "README.md").read_text(encoding="utf-8") == "important" + + +def test_cleanup_workspace_honors_workspaces_root_env_override(tmp_path, monkeypatch): + """``HERMES_KANBAN_WORKSPACES_ROOT`` extends the managed-scratch set. + + Worker subprocesses run with this env var injected by the dispatcher. The + cleanup containment check must treat paths under it as managed even when + they sit outside the active kanban home. + """ + home = tmp_path / ".hermes" + home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + workspaces_override = tmp_path / "ext-workspaces" + workspaces_override.mkdir() + monkeypatch.setenv("HERMES_KANBAN_WORKSPACES_ROOT", str(workspaces_override)) + kb.init_db() + + with kb.connect() as conn: + t = kb.create_task(conn, title="ext") + scratch_dir = workspaces_override / t + scratch_dir.mkdir() + conn.execute( + "UPDATE tasks SET workspace_kind=?, workspace_path=? WHERE id=?", + ("scratch", str(scratch_dir), t), + ) + conn.commit() + kb.complete_task(conn, t, result="ok") + + assert not scratch_dir.exists(), "Override-root scratch dir should be cleaned up" + + +def test_is_managed_scratch_path_accepts_per_board_workspaces(kanban_home, tmp_path): + """Per-board scratch dirs under ``/kanban/boards//workspaces`` are managed.""" + board_scratch = kanban_home / "kanban" / "boards" / "my-board" / "workspaces" / "task-1" + board_scratch.mkdir(parents=True) + assert kb._is_managed_scratch_path(board_scratch) + + +def test_is_managed_scratch_path_rejects_real_source_tree(kanban_home, tmp_path): + """A path outside any managed root (e.g. a user's repo) is NOT managed.""" + real = tmp_path / "code" / "my-project" + real.mkdir(parents=True) + assert not kb._is_managed_scratch_path(real) + + +def test_is_managed_scratch_path_rejects_kanban_metadata_subtrees(kanban_home): + """Hermes' own DB/metadata/log subtrees under ``/kanban`` are NOT managed. + + Regression guard for the Copilot finding on #28819: a scratch task whose + ``workspace_path`` was mis-set to the kanban home, the logs dir, or a + board's metadata dir (i.e. the board root itself, not its ``workspaces/`` + child) must be refused. Without this, the containment check would happily + ``shutil.rmtree`` Hermes' DB/metadata/logs on task completion. + """ + kanban_root = kanban_home / "kanban" + kanban_root.mkdir(parents=True, exist_ok=True) + assert not kb._is_managed_scratch_path(kanban_root) + + logs_dir = kanban_root / "logs" + logs_dir.mkdir(parents=True, exist_ok=True) + assert not kb._is_managed_scratch_path(logs_dir) + + board_root = kanban_root / "boards" / "my-board" + board_root.mkdir(parents=True, exist_ok=True) + # The board root itself is NOT a managed scratch dir โ€” only the + # ``workspaces/`` child (and its descendants) are. + assert not kb._is_managed_scratch_path(board_root) + + # Sibling subtrees of ``workspaces/`` under a board (e.g. its kanban.db + # or board.json living next to ``workspaces/``) are also not managed. + board_logs = board_root / "logs" + board_logs.mkdir(parents=True, exist_ok=True) + assert not kb._is_managed_scratch_path(board_logs) + + # Now create the board's workspaces dir and a task scratch dir under it โ€” + # the latter is the only thing the guard should allow. + board_workspaces = board_root / "workspaces" + board_workspaces.mkdir(parents=True, exist_ok=True) + # The workspaces root itself is also NOT managed โ€” deleting it would + # wipe every task's scratch dir at once. + assert not kb._is_managed_scratch_path(board_workspaces) + task_dir = board_workspaces / "task-42" + task_dir.mkdir(parents=True, exist_ok=True) + assert kb._is_managed_scratch_path(task_dir) + + # --------------------------------------------------------------------------- # Tenancy # --------------------------------------------------------------------------- @@ -2464,13 +2596,32 @@ def test_task_dict_survives_corrupt_created_at(tmp_path, monkeypatch): # --------------------------------------------------------------------------- -def test_create_task_without_workspace_inherits_board_default_workdir(kanban_home, monkeypatch): - """Board with default_workdir โ†’ create_task without workspace_path โ†’ inherits default.""" +def test_create_task_scratch_without_workspace_ignores_board_default_workdir(kanban_home, monkeypatch): + """Scratch tasks must NOT inherit board.default_workdir โ€” would point auto-cleanup + at the user's source tree on completion (#28818).""" default_wd = "/home/user/project" kb.create_board("work-proj", default_workdir=default_wd) with kb.connect(board="work-proj") as conn: - tid = kb.create_task(conn, title="inherited", board="work-proj") + tid = kb.create_task(conn, title="scratch-task", board="work-proj") + t = kb.get_task(conn, tid) + assert t is not None + assert t.workspace_kind == "scratch" + assert t.workspace_path is None + + +def test_create_task_dir_without_workspace_inherits_board_default_workdir(kanban_home, monkeypatch): + """Board default_workdir is for persistent dir/worktree workspaces, not scratch.""" + default_wd = "/home/user/project" + kb.create_board("work-proj-dir", default_workdir=default_wd) + + with kb.connect(board="work-proj-dir") as conn: + tid = kb.create_task( + conn, + title="inherited", + workspace_kind="dir", + board="work-proj-dir", + ) t = kb.get_task(conn, tid) assert t is not None assert t.workspace_path == default_wd @@ -2981,3 +3132,210 @@ def test_detect_stale_does_not_tick_failure_counter(kanban_home, monkeypatch): assert "stale" in kinds, ( f"Expected 'stale' event in task_events; got {kinds!r}" ) + + +# --------------------------------------------------------------------------- +# Corruption guard (issue #30687) +# --------------------------------------------------------------------------- + +def _write_corrupt_db(path: Path) -> bytes: + """Write a kanban DB with a VALID SQLite header but malformed page content. + + This is the corruption shape the integrity guard specifically targets + (e.g. issue #29507 follow-up reports where the file's first 16 bytes + pass the header byte check but ``PRAGMA integrity_check`` then fails + because the internal pages are damaged). It's what main's header-only + validator was letting through, and what this PR adds the full guard + for. + """ + # 100-byte SQLite header (magic + minimal valid-looking fields) so the + # cheap header check passes, then deliberate garbage so sqlite refuses + # to read the file past the header. + header = b"SQLite format 3\x00" + b"\x10\x00\x02\x02\x00\x40\x20\x20" + header += b"\x00\x00\x00\x0c\x00\x00\x23\x46\x00\x00\x00\x00" + header = header.ljust(100, b"\x00") + payload = b"definitely not a valid sqlite page \x00\x01\x02\x03" * 64 + blob = header + payload + path.write_bytes(blob) + return blob + + +def test_init_db_refuses_corrupt_existing_file(tmp_path): + db_path = tmp_path / "kanban.db" + original = _write_corrupt_db(db_path) + # Ensure the cache doesn't mask the guard. + kb._INITIALIZED_PATHS.discard(str(db_path.resolve())) + + with pytest.raises(kb.KanbanDbCorruptError) as excinfo: + kb.init_db(db_path=db_path) + + err = excinfo.value + assert err.db_path == db_path + assert err.backup_path is not None + assert err.backup_path.exists() + assert err.backup_path.read_bytes() == original + # Original bytes untouched โ€” no schema was written on top. + assert db_path.read_bytes() == original + assert str(db_path) in str(err) + assert str(err.backup_path) in str(err) + + +def test_connect_refuses_corrupt_existing_file(tmp_path): + db_path = tmp_path / "kanban.db" + _write_corrupt_db(db_path) + kb._INITIALIZED_PATHS.discard(str(db_path.resolve())) + + with pytest.raises(kb.KanbanDbCorruptError): + kb.connect(db_path=db_path) + + +def test_locked_healthy_db_does_not_classify_as_corrupt(tmp_path, monkeypatch): + """A transient lock during the probe must not produce a .corrupt backup + and must not be reported as :class:`KanbanDbCorruptError`. Raw sqlite + ``OperationalError`` (lock/busy) is acceptable and expected.""" + db_path = tmp_path / "kanban.db" + kb.init_db(db_path=db_path) + kb._INITIALIZED_PATHS.discard(str(db_path.resolve())) + + real_connect = sqlite3.connect + + def flaky_connect(*args, **kwargs): + # First call is the integrity probe โ€” simulate a lock. + raise sqlite3.OperationalError("database is locked") + + monkeypatch.setattr(kb.sqlite3, "connect", flaky_connect) + + with pytest.raises(sqlite3.OperationalError): + kb.connect(db_path=db_path) + + # No .corrupt backup may be produced for a healthy-but-locked DB. + backups = list(tmp_path.glob("*.corrupt.*")) + assert backups == [], f"unexpected corrupt backups: {backups}" + + # And once the lock clears, normal access still works. + monkeypatch.setattr(kb.sqlite3, "connect", real_connect) + with kb.connect(db_path=db_path) as conn: + kb.create_task(conn, title="still here") + titles = [t.title for t in kb.list_tasks(conn)] + assert "still here" in titles + + +def test_init_db_allows_missing_then_healthy(tmp_path): + db_path = tmp_path / "fresh.db" + assert not db_path.exists() + kb.init_db(db_path=db_path) + assert db_path.exists() and db_path.stat().st_size > 0 + + # Idempotent on a healthy DB: data survives a second init. + with kb.connect(db_path=db_path) as conn: + kb.create_task(conn, title="keeps") + kb.init_db(db_path=db_path) + with kb.connect(db_path=db_path) as conn: + tasks = kb.list_tasks(conn) + assert [t.title for t in tasks] == ["keeps"] + + +# --------------------------------------------------------------------------- +# First-use tip for scratch workspaces +# --------------------------------------------------------------------------- + +def test_maybe_emit_scratch_tip_fires_once_per_install(kanban_home, caplog): + """First scratch workspace materialization warns + emits an event. + + Subsequent scratch workspaces on the SAME install stay silent โ€” the + sentinel file under kanban_home() flips after the first emit. + """ + import logging + + with kb.connect() as conn: + t1 = kb.create_task(conn, title="first scratch") + t2 = kb.create_task(conn, title="second scratch") + + # Sentinel must not exist yet on a fresh install. + assert not kb._scratch_tip_shown() + + with caplog.at_level(logging.WARNING, logger="hermes_cli.kanban_db"): + with kb.connect() as conn: + kb._maybe_emit_scratch_tip(conn, t1, "scratch") + + # Sentinel is now set. + assert kb._scratch_tip_shown() + assert kb._scratch_tip_sentinel_path().exists() + + # Warning was logged exactly once. + tip_records = [ + r for r in caplog.records + if "scratch workspaces are ephemeral" in r.getMessage() + ] + assert len(tip_records) == 1, ( + f"Expected exactly one tip warning, got {len(tip_records)}: " + f"{[r.getMessage() for r in tip_records]!r}" + ) + + # An event row was appended on the first task. + with kb.connect() as conn: + events = conn.execute( + "SELECT kind FROM task_events WHERE task_id = ? ORDER BY id", + (t1,), + ).fetchall() + kinds = [e["kind"] for e in events] + assert "tip_scratch_workspace" in kinds, ( + f"Expected tip_scratch_workspace event on first scratch task; " + f"got {kinds!r}" + ) + + # Second scratch materialization on the same install stays silent. + caplog.clear() + with caplog.at_level(logging.WARNING, logger="hermes_cli.kanban_db"): + with kb.connect() as conn: + kb._maybe_emit_scratch_tip(conn, t2, "scratch") + tip_records2 = [ + r for r in caplog.records + if "scratch workspaces are ephemeral" in r.getMessage() + ] + assert tip_records2 == [], ( + f"Tip should not re-fire after sentinel is set; got " + f"{[r.getMessage() for r in tip_records2]!r}" + ) + with kb.connect() as conn: + events2 = conn.execute( + "SELECT kind FROM task_events WHERE task_id = ? ORDER BY id", + (t2,), + ).fetchall() + assert "tip_scratch_workspace" not in [e["kind"] for e in events2], ( + "Tip event should not be appended for subsequent scratch tasks." + ) + + +def test_maybe_emit_scratch_tip_skips_non_scratch_workspaces(kanban_home, caplog): + """worktree/dir workspaces are preserved on completion and must not + trigger the scratch-cleanup tip.""" + import logging + + with kb.connect() as conn: + t_wt = kb.create_task(conn, title="worktree task") + t_dir = kb.create_task(conn, title="dir task") + + assert not kb._scratch_tip_shown() + + with caplog.at_level(logging.WARNING, logger="hermes_cli.kanban_db"): + with kb.connect() as conn: + kb._maybe_emit_scratch_tip(conn, t_wt, "worktree") + kb._maybe_emit_scratch_tip(conn, t_dir, "dir") + + # Sentinel stays unset โ€” these workspaces are preserved by design, + # so the warning is irrelevant for them and we save the one-shot + # for a real scratch user. + assert not kb._scratch_tip_shown() + tip_records = [ + r for r in caplog.records + if "scratch workspaces are ephemeral" in r.getMessage() + ] + assert tip_records == [] + with kb.connect() as conn: + for tid in (t_wt, t_dir): + events = conn.execute( + "SELECT kind FROM task_events WHERE task_id = ?", (tid,), + ).fetchall() + assert "tip_scratch_workspace" not in [e["kind"] for e in events] + diff --git a/tests/hermes_cli/test_kanban_notify.py b/tests/hermes_cli/test_kanban_notify.py index 1ebf92705d7..44a0bd90a03 100644 --- a/tests/hermes_cli/test_kanban_notify.py +++ b/tests/hermes_cli/test_kanban_notify.py @@ -17,6 +17,11 @@ def kanban_home(tmp_path, monkeypatch): home.mkdir() monkeypatch.setenv("HERMES_HOME", str(home)) monkeypatch.setattr(Path, "home", lambda: tmp_path) + # Allow the kanban notifier path-validator to upload artifacts the + # tests write under ``tmp_path``. Without this, every artifact-delivery + # test silently drops files because ``tmp_path`` isn't inside the + # default ``MEDIA_DELIVERY_SAFE_ROOTS`` cache dirs. + monkeypatch.setenv("HERMES_MEDIA_ALLOW_DIRS", str(tmp_path)) kb.init_db() return home @@ -482,7 +487,7 @@ async def test_gateway_create_autosubscribes_on_explicit_board(kanban_home): @pytest.mark.asyncio -async def test_notifier_uploads_artifacts_on_completion(kanban_home, tmp_path): +async def test_notifier_uploads_artifacts_on_completion(kanban_home, tmp_path, monkeypatch): """When a completed event carries ``artifacts`` in its payload, the notifier uploads each file to the subscribed chat as a native attachment. Images batch through send_multiple_images; documents @@ -494,6 +499,13 @@ async def test_notifier_uploads_artifacts_on_completion(kanban_home, tmp_path): from gateway.config import Platform from tools import kanban_tools as kt + # ``_deliver_kanban_artifacts`` routes candidates through + # ``BasePlatformAdapter.filter_local_delivery_paths``, which only accepts + # paths under ``MEDIA_DELIVERY_SAFE_ROOTS`` or roots explicitly allowlisted + # via ``HERMES_MEDIA_ALLOW_DIRS``. Test fixtures live under ``tmp_path``, + # so allowlist it for the duration of the test. + monkeypatch.setenv("HERMES_MEDIA_ALLOW_DIRS", str(tmp_path)) + # Materialize real files so os.path.isfile passes inside the helper. chart_path = tmp_path / "q3-revenue.png" chart_path.write_bytes(b"PNG-fake-bytes") @@ -572,7 +584,7 @@ async def test_notifier_uploads_artifacts_on_completion(kanban_home, tmp_path): @pytest.mark.asyncio -async def test_notifier_artifact_delivery_skips_missing_files(kanban_home, tmp_path): +async def test_notifier_artifact_delivery_skips_missing_files(kanban_home, tmp_path, monkeypatch): """Missing artifact paths are silently skipped โ€” they may have been referenced by name only. The notifier must not crash and must still deliver any artifacts that do exist.""" @@ -581,6 +593,10 @@ async def test_notifier_artifact_delivery_skips_missing_files(kanban_home, tmp_p from gateway.config import Platform from tools import kanban_tools as kt + # Allow ``tmp_path`` through the media-delivery safety filter. See the + # companion test for the full explanation. + monkeypatch.setenv("HERMES_MEDIA_ALLOW_DIRS", str(tmp_path)) + real_pdf = tmp_path / "real.pdf" real_pdf.write_bytes(b"%PDF-fake") diff --git a/tests/hermes_cli/test_kanban_promote.py b/tests/hermes_cli/test_kanban_promote.py new file mode 100644 index 00000000000..6cbf3b77071 --- /dev/null +++ b/tests/hermes_cli/test_kanban_promote.py @@ -0,0 +1,254 @@ +"""Tests for the kanban `promote` verb (issue #28822). + +The realistic bug scenario from #28822 is: a child task ends up in +``todo`` with all its parents already ``done`` (because the +auto-promote daemon hasn't run, or a manual close raced it). +Direct-SQL setup is used to construct that state deterministically. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import pytest + +from hermes_cli import kanban as kb_cli +from hermes_cli import kanban_db as kb + + +@pytest.fixture +def kanban_home(tmp_path, monkeypatch): + home = tmp_path / ".hermes" + home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + db_path = kb.kanban_db_path(board="default") + kb._INITIALIZED_PATHS.discard(str(db_path.resolve())) + kb.init_db() + return home + + +@pytest.fixture +def conn(kanban_home): + with kb.connect() as c: + yield c + + +def _stuck_todo(conn, *, parents_done=True, n_parents=1): + """Build the #28822 scenario: child in 'todo' whose parents may + have closed as 'done' without the auto-promote logic firing. + """ + parent_ids = [ + kb.create_task(conn, title=f"parent{i}", assignee="setup") + for i in range(n_parents) + ] + child_id = kb.create_task( + conn, title="child", parents=parent_ids, assignee="setup" + ) + assert kb.get_task(conn, child_id).status == "todo" + if parents_done: + for pid in parent_ids: + conn.execute( + "UPDATE tasks SET status='done' WHERE id=?", (pid,) + ) + return child_id, parent_ids + + +def test_promote_stuck_todo_succeeds(conn): + child, _ = _stuck_todo(conn, parents_done=True) + ok, err = kb.promote_task(conn, child, actor="tester") + assert ok and err is None + assert kb.get_task(conn, child).status == "ready" + + +def test_promote_refuses_when_parent_not_done(conn): + child, parents = _stuck_todo(conn, parents_done=False) + ok, err = kb.promote_task(conn, child, actor="tester") + assert ok is False + assert err is not None and "unsatisfied parent dependencies" in err + assert parents[0] in err + assert kb.get_task(conn, child).status == "todo" + + +def test_promote_with_force_bypasses_dependency_check(conn): + child, _ = _stuck_todo(conn, parents_done=False) + ok, err = kb.promote_task( + conn, child, actor="tester", reason="recovery", force=True + ) + assert ok and err is None + assert kb.get_task(conn, child).status == "ready" + + +def test_promote_emits_audit_event(conn): + child, _ = _stuck_todo(conn, parents_done=True) + kb.promote_task(conn, child, actor="tester", reason="manual recovery") + ev = conn.execute( + "SELECT kind, payload FROM task_events " + "WHERE task_id = ? AND kind = 'promoted_manual'", + (child,), + ).fetchone() + assert ev is not None + payload = json.loads(ev["payload"]) + assert payload["actor"] == "tester" + assert payload["reason"] == "manual recovery" + assert payload["forced"] is False + + +def test_promote_force_records_forced_flag(conn): + child, _ = _stuck_todo(conn, parents_done=False) + kb.promote_task(conn, child, actor="tester", force=True, reason="r") + ev = conn.execute( + "SELECT payload FROM task_events " + "WHERE task_id = ? AND kind = 'promoted_manual'", + (child,), + ).fetchone() + assert json.loads(ev["payload"])["forced"] is True + + +def test_promote_does_not_change_assignee(conn): + child, _ = _stuck_todo(conn, parents_done=True) + before = kb.get_task(conn, child).assignee + kb.promote_task(conn, child, actor="someone_else") + after = kb.get_task(conn, child).assignee + assert before == after + + +def test_promote_dry_run_does_not_mutate(conn): + child, _ = _stuck_todo(conn, parents_done=True) + ok, err = kb.promote_task(conn, child, actor="tester", dry_run=True) + assert ok and err is None + assert kb.get_task(conn, child).status == "todo" + n = conn.execute( + "SELECT COUNT(*) AS n FROM task_events " + "WHERE task_id = ? AND kind = 'promoted_manual'", + (child,), + ).fetchone()["n"] + assert n == 0 + + +def test_promote_dry_run_reports_dependency_failure(conn): + child, _ = _stuck_todo(conn, parents_done=False) + ok, err = kb.promote_task(conn, child, actor="tester", dry_run=True) + assert ok is False + assert err is not None and "unsatisfied" in err + + +def test_promote_rejects_non_todo_status(conn): + tid = kb.create_task(conn, title="standalone") + assert kb.get_task(conn, tid).status == "ready" + ok, err = kb.promote_task(conn, tid, actor="tester") + assert ok is False + assert "'ready'" in err and "promote only applies" in err + + +def test_promote_rejects_unknown_task(conn): + ok, err = kb.promote_task(conn, "t_doesnotexist", actor="tester") + assert ok is False + assert err is not None and "not found" in err + + +def test_promote_blocked_task_works(conn): + tid = kb.create_task(conn, title="t") + conn.execute("UPDATE tasks SET status='blocked' WHERE id=?", (tid,)) + ok, err = kb.promote_task( + conn, tid, actor="tester", reason="ready now" + ) + assert ok and err is None + assert kb.get_task(conn, tid).status == "ready" + + +# --------------------------------------------------------------------------- +# CLI `_cmd_promote` โ€” bulk via `--ids` (the issue's anti-respawn use case: +# promote all children of a closed parent in one command). +# --------------------------------------------------------------------------- + + +def _promote_ns(task_id, *, ids=None, reason=None, force=False, + dry_run=False, as_json=False): + return argparse.Namespace( + task_id=task_id, + reason=list(reason or []), + ids=list(ids or []) or None, + force=force, + dry_run=dry_run, + json=as_json, + ) + + +def test_cli_promote_bulk_ids_promotes_all(kanban_home, capsys): + with kb.connect() as conn: + parent = kb.create_task(conn, title="parent") + children = [ + kb.create_task(conn, title=f"c{i}", parents=[parent]) + for i in range(3) + ] + conn.execute("UPDATE tasks SET status='done' WHERE id=?", (parent,)) + rc = kb_cli._cmd_promote(_promote_ns(children[0], ids=children[1:])) + assert rc == 0 + out = capsys.readouterr().out + for c in children: + assert c in out + with kb.connect() as conn: + for c in children: + assert kb.get_task(conn, c).status == "ready" + + +def test_cli_promote_bulk_partial_failure_exits_1(kanban_home, capsys): + """Bulk with one bad id: good ones still promote, exit code reflects failure.""" + with kb.connect() as conn: + parent = kb.create_task(conn, title="parent") + good = kb.create_task(conn, title="good", parents=[parent]) + conn.execute("UPDATE tasks SET status='done' WHERE id=?", (parent,)) + rc = kb_cli._cmd_promote(_promote_ns(good, ids=["t_nope"])) + assert rc == 1 + captured = capsys.readouterr() + assert good in captured.out # good one promoted + assert "t_nope" in captured.err and "not found" in captured.err + with kb.connect() as conn: + assert kb.get_task(conn, good).status == "ready" + + +def test_cli_promote_bulk_json_emits_list(kanban_home, capsys): + with kb.connect() as conn: + parent = kb.create_task(conn, title="parent") + a = kb.create_task(conn, title="a", parents=[parent]) + b = kb.create_task(conn, title="b", parents=[parent]) + conn.execute("UPDATE tasks SET status='done' WHERE id=?", (parent,)) + rc = kb_cli._cmd_promote(_promote_ns(a, ids=[b], as_json=True)) + assert rc == 0 + payload = json.loads(capsys.readouterr().out) + assert isinstance(payload, list) and len(payload) == 2 + assert {r["task_id"] for r in payload} == {a, b} + assert all(r["promoted"] for r in payload) + + +def test_cli_promote_single_json_stays_flat_object(kanban_home, capsys): + """Back-compat: single-id JSON is still a flat object, not a list.""" + with kb.connect() as conn: + parent = kb.create_task(conn, title="parent") + child = kb.create_task(conn, title="c", parents=[parent]) + conn.execute("UPDATE tasks SET status='done' WHERE id=?", (parent,)) + rc = kb_cli._cmd_promote(_promote_ns(child, as_json=True)) + assert rc == 0 + payload = json.loads(capsys.readouterr().out) + assert isinstance(payload, dict) + assert payload["task_id"] == child and payload["promoted"] is True + + +def test_cli_promote_dedupes_duplicate_ids(kanban_home, capsys): + """Same id in positional + --ids must only attempt the promotion once.""" + with kb.connect() as conn: + parent = kb.create_task(conn, title="parent") + child = kb.create_task(conn, title="c", parents=[parent]) + conn.execute("UPDATE tasks SET status='done' WHERE id=?", (parent,)) + rc = kb_cli._cmd_promote(_promote_ns(child, ids=[child, child])) + assert rc == 0 + with kb.connect() as conn: + n = conn.execute( + "SELECT COUNT(*) AS n FROM task_events " + "WHERE task_id = ? AND kind = 'promoted_manual'", + (child,), + ).fetchone()["n"] + assert n == 1 diff --git a/tests/hermes_cli/test_nous_inference_url_validation.py b/tests/hermes_cli/test_nous_inference_url_validation.py new file mode 100644 index 00000000000..4e688a59a74 --- /dev/null +++ b/tests/hermes_cli/test_nous_inference_url_validation.py @@ -0,0 +1,214 @@ +"""Regression tests for Nous Portal inference_base_url host-allowlist validation. + +A poisoned ``inference_base_url`` from the Portal refresh / agent-key-mint +response (network MITM, malicious response injection) would otherwise be +persisted to auth.json and forwarded the user's legitimate agent_key +bearer on every subsequent proxy request, exfiltrating their inference +budget and opening a response-injection channel into the IDE / chat +client. ``_validate_nous_inference_url_from_network()`` blocks any URL +outside the allowlist at the source. + +These tests verify: + +1. The validator's host + scheme rules. +2. Each of the five NETWORK call sites in ``auth.py`` calls the validator + rather than the unrestricted ``_optional_base_url`` helper. +3. The proxy adapter applies the validator as belt-and-suspenders. +4. The env-var override path (``NOUS_INFERENCE_BASE_URL``) is NOT + gated by the validator โ€” that's the documented dev/staging escape + hatch. +""" + +from __future__ import annotations + +import logging +import pytest + +from hermes_cli.auth import ( + DEFAULT_NOUS_INFERENCE_URL, + _ALLOWED_NOUS_INFERENCE_HOSTS, + _validate_nous_inference_url_from_network, +) + + +class TestValidatorRules: + def test_allowlisted_https_host_returned(self): + url = "https://inference-api.nousresearch.com/v1" + assert _validate_nous_inference_url_from_network(url) == url + + def test_trailing_slash_stripped(self): + url = "https://inference-api.nousresearch.com/v1/" + assert _validate_nous_inference_url_from_network(url) == url.rstrip("/") + + def test_attacker_host_rejected(self, caplog): + with caplog.at_level(logging.WARNING, logger="hermes_cli.auth"): + assert ( + _validate_nous_inference_url_from_network("https://attacker.com/v1") + is None + ) + assert any("attacker.com" in rec.message for rec in caplog.records) + + def test_subdomain_of_allowlist_host_rejected(self): + """*.nousresearch.com is NOT in the allowlist โ€” exact hostname only. + + A subdomain takeover or DNS hijack of *.nousresearch.com would + otherwise pass โ€” keep the gate tight. + """ + assert ( + _validate_nous_inference_url_from_network( + "https://evil.inference-api.nousresearch.com/v1" + ) + is None + ) + + def test_http_scheme_rejected(self, caplog): + with caplog.at_level(logging.WARNING, logger="hermes_cli.auth"): + assert ( + _validate_nous_inference_url_from_network( + "http://inference-api.nousresearch.com/v1" + ) + is None + ) + assert any("non-https" in rec.message for rec in caplog.records) + + def test_file_scheme_rejected(self): + assert ( + _validate_nous_inference_url_from_network("file:///etc/passwd") is None + ) + + def test_javascript_scheme_rejected(self): + assert ( + _validate_nous_inference_url_from_network( + "javascript:alert(document.cookie)" + ) + is None + ) + + def test_empty_string_rejected(self): + assert _validate_nous_inference_url_from_network("") is None + + def test_whitespace_only_rejected(self): + assert _validate_nous_inference_url_from_network(" ") is None + + def test_none_rejected(self): + assert _validate_nous_inference_url_from_network(None) is None + + def test_non_string_rejected(self): + assert _validate_nous_inference_url_from_network(12345) is None # type: ignore[arg-type] + assert _validate_nous_inference_url_from_network({"url": "x"}) is None # type: ignore[arg-type] + + def test_malformed_url_rejected(self): + """Even garbled input must fall back safely, not raise.""" + assert ( + _validate_nous_inference_url_from_network("not://a real url at all") + is None + ) + + def test_default_inference_url_is_in_allowlist(self): + """Sanity check: DEFAULT_NOUS_INFERENCE_URL must itself validate. + + If anyone retargets the default away from + ``inference-api.nousresearch.com``, they MUST update the allowlist + in the same change โ€” otherwise the allowlist would reject the + Portal's own legitimate default and break every install. + """ + assert ( + _validate_nous_inference_url_from_network(DEFAULT_NOUS_INFERENCE_URL) + == DEFAULT_NOUS_INFERENCE_URL.rstrip("/") + ) + + def test_allowlist_contains_inference_api_host(self): + """The default's host must be in the allowlist set.""" + from urllib.parse import urlparse + host = urlparse(DEFAULT_NOUS_INFERENCE_URL).hostname + assert host in _ALLOWED_NOUS_INFERENCE_HOSTS + + +class TestCallSiteWiring: + """Verify the validator is actually wired into all 5 NETWORK call sites. + + These are not behaviour-end-to-end tests (the surrounding code is + several hundred lines per site with extensive HTTP mocking + requirements). They're text-grep contracts: if anyone replaces + ``_validate_nous_inference_url_from_network`` with the un-validated + ``_optional_base_url`` again, the test catches it. + + Each site lives inside ``resolve_nous_runtime_credentials`` and one + helper (``_extend_state_from_refresh``). The shape we guard against + is ``_url = _optional_base_url(.get("inference_base_url"))`` + โ€” that's what the unsafe pre-fix code looked like, and the only + semantic difference between the safe and unsafe helpers is the + host-allowlist check. + """ + + def _read_auth_source(self): + import hermes_cli.auth as _auth_mod + from pathlib import Path + return Path(_auth_mod.__file__).read_text(encoding="utf-8") + + def test_no_unvalidated_inference_base_url_assignments_remain(self): + """No remaining ``_optional_base_url(...inference_base_url...)`` reads + from Portal payloads. If you see a failure here, you've either + added a new NETWORK site that needs validation, or downgraded an + existing one back to the unsafe helper.""" + source = self._read_auth_source() + for needle in ( + '_optional_base_url(refreshed.get("inference_base_url"))', + '_optional_base_url(mint_payload.get("inference_base_url"))', + ): + assert needle not in source, ( + f"Found unvalidated network read: {needle!r}. " + f"Use _validate_nous_inference_url_from_network() instead." + ) + + def test_validator_wired_at_all_known_call_sites(self): + """All 5 known NETWORK sites use the validator. If this count + drops, someone removed protection; if it grows, audit the new + site to be sure validation is appropriate.""" + source = self._read_auth_source() + refresh_count = source.count( + '_validate_nous_inference_url_from_network(refreshed.get("inference_base_url"))' + ) + mint_count = source.count( + '_validate_nous_inference_url_from_network(mint_payload.get("inference_base_url"))' + ) + assert refresh_count == 3, f"expected 3 refresh sites, found {refresh_count}" + assert mint_count == 2, f"expected 2 mint sites, found {mint_count}" + + def test_proxy_adapter_also_validates(self): + """The Nous proxy adapter applies the validator as defense-in-depth + even though auth.py already validates at the source, so a future + bypass at the source layer still gets caught at the forward + boundary.""" + from pathlib import Path + import hermes_cli.proxy.adapters.nous_portal as _nous_adapter + source = Path(_nous_adapter.__file__).read_text(encoding="utf-8") + assert "_validate_nous_inference_url_from_network" in source + + +class TestEnvOverrideNotGated: + """The documented dev/staging env-var override must keep working. + + ``NOUS_INFERENCE_BASE_URL`` is read by ``resolve_nous_runtime_credentials`` + via ``os.getenv`` โ€” that path doesn't pass through the validator + (env values are trusted because the user set them themselves). + Verify the env-var read site does NOT consult the validator, so a + user running against a non-allowlisted staging host via env is not + inadvertently broken by this fix. + """ + + def test_env_override_path_does_not_call_validator(self): + """In resolve_nous_runtime_credentials, the env override is + read via os.getenv directly, not via the validator. Grep the + source to confirm: the env line should NOT mention the + validator.""" + import hermes_cli.auth as _auth_mod + from pathlib import Path + source = Path(_auth_mod.__file__).read_text(encoding="utf-8") + # Find the env-override read line. + for line in source.splitlines(): + if "NOUS_INFERENCE_BASE_URL" in line and "os.getenv" in line: + assert "_validate_nous_inference_url_from_network" not in line, ( + "env override path must not gate through the network " + "validator โ€” it would break documented dev/staging use." + ) diff --git a/tests/hermes_cli/test_plugin_auxiliary_tasks.py b/tests/hermes_cli/test_plugin_auxiliary_tasks.py new file mode 100644 index 00000000000..667546efe43 --- /dev/null +++ b/tests/hermes_cli/test_plugin_auxiliary_tasks.py @@ -0,0 +1,353 @@ +"""Tests for the plugin auxiliary-task registration API. + +Covers: + - PluginContext.register_auxiliary_task() validation + - PluginManager._aux_tasks storage + force-rediscovery clearing + - get_plugin_auxiliary_tasks() module-level helper + - _all_aux_tasks() merge of built-in + plugin tasks + - _reset_aux_to_auto() includes plugin tasks + - _get_auxiliary_task_config() layers plugin defaults under user config +""" + +from __future__ import annotations + +import pytest + +from hermes_cli.plugins import ( + PluginContext, + PluginManager, + PluginManifest, + get_plugin_auxiliary_tasks, +) + + +# โ”€โ”€ Fixtures โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def _make_ctx(name: str = "test_plugin") -> tuple[PluginContext, PluginManager]: + """Build a PluginContext + fresh PluginManager wired together. + + The manager skips discovery (no plugins.yaml, no scan) so the test + can exercise registration paths directly. + """ + manager = PluginManager() + manager._discovered = True # skip auto-discovery on lookup + manifest = PluginManifest(name=name) + ctx = PluginContext(manifest, manager) + return ctx, manager + + +@pytest.fixture +def patched_manager(monkeypatch): + """Replace the module-level singleton with a fresh manager for the test. + + Restored automatically after the test by monkeypatch. + """ + from hermes_cli import plugins as plugins_mod + + fresh = PluginManager() + fresh._discovered = True + monkeypatch.setattr(plugins_mod, "_PLUGIN_MANAGER", fresh, raising=False) + + def _stub_get_manager() -> PluginManager: + return fresh + + monkeypatch.setattr(plugins_mod, "get_plugin_manager", _stub_get_manager) + monkeypatch.setattr(plugins_mod, "_ensure_plugins_discovered", _stub_get_manager) + yield fresh + + +# โ”€โ”€ PluginContext.register_auxiliary_task โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def test_register_auxiliary_task_basic(): + ctx, manager = _make_ctx("my_plugin") + ctx.register_auxiliary_task( + key="my_task", + display_name="My task", + description="a custom side task", + ) + assert "my_task" in manager._aux_tasks + entry = manager._aux_tasks["my_task"] + assert entry["key"] == "my_task" + assert entry["display_name"] == "My task" + assert entry["description"] == "a custom side task" + assert entry["plugin"] == "my_plugin" + # Routing defaults populated + assert entry["defaults"]["provider"] == "auto" + assert entry["defaults"]["model"] == "" + assert entry["defaults"]["timeout"] == 60 + + +def test_register_auxiliary_task_with_custom_defaults(): + ctx, manager = _make_ctx() + ctx.register_auxiliary_task( + key="custom_task", + display_name="Custom", + description="d", + defaults={"timeout": 30, "extra_body": {"reasoning_effort": "low"}}, + ) + entry = manager._aux_tasks["custom_task"] + assert entry["defaults"]["timeout"] == 30 + assert entry["defaults"]["extra_body"] == {"reasoning_effort": "low"} + # Unspecified defaults still populated + assert entry["defaults"]["provider"] == "auto" + + +def test_register_auxiliary_task_rejects_builtin_keys(): + ctx, _ = _make_ctx() + for builtin in ( + "vision", + "compression", + "web_extract", + "approval", + "mcp", + "title_generation", + "skills_hub", + "curator", + ): + with pytest.raises(ValueError, match="reserved for a built-in task"): + ctx.register_auxiliary_task( + key=builtin, + display_name="x", + description="x", + ) + + +def test_register_auxiliary_task_rejects_invalid_key_shapes(): + ctx, _ = _make_ctx() + for bad in ("", "with-dash", "with.dot", "with space", "with/slash"): + with pytest.raises(ValueError): + ctx.register_auxiliary_task( + key=bad, + display_name="x", + description="x", + ) + + +def test_register_auxiliary_task_allows_same_plugin_re_registration(): + """Re-registration by the same plugin updates the entry (idempotent).""" + ctx, manager = _make_ctx("plug_a") + ctx.register_auxiliary_task( + key="t1", display_name="First", description="first" + ) + ctx.register_auxiliary_task( + key="t1", display_name="Second", description="second" + ) + assert manager._aux_tasks["t1"]["display_name"] == "Second" + + +def test_register_auxiliary_task_rejects_cross_plugin_collision(): + """Two different plugins cannot register the same task key.""" + manager = PluginManager() + manager._discovered = True + + manifest_a = PluginManifest(name="plug_a") + manifest_b = PluginManifest(name="plug_b") + ctx_a = PluginContext(manifest_a, manager) + ctx_b = PluginContext(manifest_b, manager) + + ctx_a.register_auxiliary_task( + key="shared", display_name="A", description="a" + ) + with pytest.raises(ValueError, match="already registered by plugin 'plug_a'"): + ctx_b.register_auxiliary_task( + key="shared", display_name="B", description="b" + ) + + +# โ”€โ”€ PluginManager state lifecycle โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def test_force_rediscovery_clears_aux_tasks(): + ctx, manager = _make_ctx() + ctx.register_auxiliary_task( + key="will_be_cleared", + display_name="x", + description="x", + ) + assert "will_be_cleared" in manager._aux_tasks + + manager._discovered = False + # Simulate force=True path: clears state before re-scanning + manager._aux_tasks.clear() + assert manager._aux_tasks == {} + + +# โ”€โ”€ Module-level helper โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def test_get_plugin_auxiliary_tasks_returns_sorted_list(patched_manager): + manifest = PluginManifest(name="plug") + ctx = PluginContext(manifest, patched_manager) + ctx.register_auxiliary_task( + key="zeta_task", display_name="Zeta", description="z" + ) + ctx.register_auxiliary_task( + key="alpha_task", display_name="Alpha", description="a" + ) + ctx.register_auxiliary_task( + key="mike_task", display_name="Mike", description="m" + ) + + tasks = get_plugin_auxiliary_tasks() + assert [t["key"] for t in tasks] == ["alpha_task", "mike_task", "zeta_task"] + + +def test_get_plugin_auxiliary_tasks_empty_when_none_registered(patched_manager): + assert get_plugin_auxiliary_tasks() == [] + + +# โ”€โ”€ _all_aux_tasks merges built-in + plugin โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def test_all_aux_tasks_includes_plugin_registered(patched_manager): + from hermes_cli.main import _AUX_TASKS, _all_aux_tasks + + manifest = PluginManifest(name="hindsight") + ctx = PluginContext(manifest, patched_manager) + ctx.register_auxiliary_task( + key="memory_retain_filter", + display_name="Memory retain filter", + description="hindsight pre-retain dedup/extract", + ) + + merged = _all_aux_tasks() + keys = [k for k, _, _ in merged] + # Built-ins preserved (and come first) + builtin_keys = [k for k, _, _ in _AUX_TASKS] + assert keys[: len(builtin_keys)] == builtin_keys + # Plugin task appended + assert "memory_retain_filter" in keys + plugin_entry = next(t for t in merged if t[0] == "memory_retain_filter") + assert plugin_entry == ( + "memory_retain_filter", + "Memory retain filter", + "hindsight pre-retain dedup/extract", + ) + + +def test_all_aux_tasks_swallows_plugin_discovery_failure(monkeypatch): + """Plugin discovery failure must not break the aux config UI.""" + from hermes_cli import main as main_mod + + def _broken(): + raise RuntimeError("plugin scan exploded") + + monkeypatch.setattr( + "hermes_cli.plugins.get_plugin_auxiliary_tasks", _broken + ) + + merged = main_mod._all_aux_tasks() + # Built-in tasks still present + assert any(k == "vision" for k, _, _ in merged) + + +# โ”€โ”€ _reset_aux_to_auto includes plugin tasks โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def test_reset_aux_to_auto_resets_plugin_tasks(tmp_path, monkeypatch, patched_manager): + """Plugin task with non-auto config gets reset alongside built-ins.""" + from pathlib import Path + from hermes_cli.config import load_config, save_config + from hermes_cli.main import _reset_aux_to_auto + + monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + (tmp_path / ".hermes").mkdir(exist_ok=True) + + manifest = PluginManifest(name="plug") + ctx = PluginContext(manifest, patched_manager) + ctx.register_auxiliary_task( + key="my_aux", + display_name="My Aux", + description="d", + ) + + # Manually configure the plugin task to non-auto + cfg = load_config() + aux = cfg.setdefault("auxiliary", {}) + aux["my_aux"] = {"provider": "openrouter", "model": "gpt-4o", "base_url": "", "api_key": ""} + save_config(cfg) + + n = _reset_aux_to_auto() + assert n >= 1 + + cfg = load_config() + assert cfg["auxiliary"]["my_aux"]["provider"] == "auto" + assert cfg["auxiliary"]["my_aux"]["model"] == "" + + +# โ”€โ”€ auxiliary_client._get_auxiliary_task_config defaults layering โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +def test_get_auxiliary_task_config_layers_plugin_defaults( + tmp_path, monkeypatch, patched_manager +): + """Plugin-declared defaults appear when user has no config entry.""" + from pathlib import Path + from agent.auxiliary_client import _get_auxiliary_task_config + + monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + (tmp_path / ".hermes").mkdir(exist_ok=True) + + manifest = PluginManifest(name="plug") + ctx = PluginContext(manifest, patched_manager) + ctx.register_auxiliary_task( + key="my_filter", + display_name="My filter", + description="x", + defaults={"timeout": 15, "extra_body": {"reasoning_effort": "low"}}, + ) + + # No user config for my_filter โ€” defaults should surface + resolved = _get_auxiliary_task_config("my_filter") + assert resolved["timeout"] == 15 + assert resolved["extra_body"] == {"reasoning_effort": "low"} + assert resolved["provider"] == "auto" + + +def test_get_auxiliary_task_config_user_config_wins_over_plugin_defaults( + tmp_path, monkeypatch, patched_manager +): + """User's config.yaml entry overrides plugin-declared defaults.""" + from pathlib import Path + from hermes_cli.config import load_config, save_config + from agent.auxiliary_client import _get_auxiliary_task_config + + monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + (tmp_path / ".hermes").mkdir(exist_ok=True) + + manifest = PluginManifest(name="plug") + ctx = PluginContext(manifest, patched_manager) + ctx.register_auxiliary_task( + key="my_filter", + display_name="My filter", + description="x", + defaults={"timeout": 15, "provider": "auto"}, + ) + + # User overrides timeout + provider via config.yaml + cfg = load_config() + aux = cfg.setdefault("auxiliary", {}) + aux["my_filter"] = {"timeout": 90, "provider": "nous"} + save_config(cfg) + + resolved = _get_auxiliary_task_config("my_filter") + assert resolved["timeout"] == 90 # user wins + assert resolved["provider"] == "nous" # user wins + + +def test_get_auxiliary_task_config_unknown_task_returns_empty( + tmp_path, monkeypatch, patched_manager +): + from pathlib import Path + from agent.auxiliary_client import _get_auxiliary_task_config + + monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + (tmp_path / ".hermes").mkdir(exist_ok=True) + + assert _get_auxiliary_task_config("nonexistent") == {} diff --git a/tests/hermes_cli/test_plugins_cmd.py b/tests/hermes_cli/test_plugins_cmd.py index 5a421f018f9..8184c373b77 100644 --- a/tests/hermes_cli/test_plugins_cmd.py +++ b/tests/hermes_cli/test_plugins_cmd.py @@ -65,6 +65,36 @@ class TestSanitizePluginName: with pytest.raises(ValueError, match="must not be empty"): _sanitize_plugin_name("", tmp_path) + # โ”€โ”€ allow_subdir=True โ”€โ”€ + + def test_allow_subdir_accepts_single_slash(self, tmp_path): + target = _sanitize_plugin_name( + "observability/langfuse", tmp_path, allow_subdir=True + ) + assert target == (tmp_path / "observability" / "langfuse").resolve() + + def test_allow_subdir_strips_leading_trailing_slash(self, tmp_path): + target = _sanitize_plugin_name( + "/image_gen/openai/", tmp_path, allow_subdir=True + ) + assert target == (tmp_path / "image_gen" / "openai").resolve() + + def test_allow_subdir_still_rejects_dot_dot(self, tmp_path): + with pytest.raises(ValueError, match="must not contain"): + _sanitize_plugin_name("foo/../bar", tmp_path, allow_subdir=True) + + def test_allow_subdir_still_rejects_backslash(self, tmp_path): + with pytest.raises(ValueError, match="must not contain"): + _sanitize_plugin_name("foo\\bar", tmp_path, allow_subdir=True) + + def test_allow_subdir_rejects_empty_after_strip(self, tmp_path): + with pytest.raises(ValueError, match="must not be empty"): + _sanitize_plugin_name("///", tmp_path, allow_subdir=True) + + def test_allow_subdir_resolves_inside_plugins_dir(self, tmp_path): + target = _sanitize_plugin_name("a/b/c", tmp_path, allow_subdir=True) + assert target.is_relative_to(tmp_path.resolve()) + # โ”€โ”€ _resolve_git_url โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ diff --git a/tests/hermes_cli/test_project_plugin_rce_bypass.py b/tests/hermes_cli/test_project_plugin_rce_bypass.py new file mode 100644 index 00000000000..7dc5ee803e2 --- /dev/null +++ b/tests/hermes_cli/test_project_plugin_rce_bypass.py @@ -0,0 +1,361 @@ +"""Regression coverage for GHSA-5qr3-c538-wm9j (#29156) โ€” Remote Code +Execution via the ``HERMES_ENABLE_PROJECT_PLUGINS`` bypass in the web +server's dashboard plugin loader. + +Two primitives combined into the original advisory chain: + +1. ``hermes_cli.web_server._discover_dashboard_plugins`` opted into + the untrusted ``./.hermes/plugins/`` source via + ``os.environ.get("HERMES_ENABLE_PROJECT_PLUGINS")`` โ€” truthy for + any non-empty string, so ``=0`` / ``=false`` / ``=no`` (all of + which the agent loader treats as off, and which operators set to + *disable* project plugins) silently *enabled* the source. +2. ``hermes_cli.web_server._mount_plugin_api_routes`` then imported + each plugin's manifest ``api`` field as a Python module via + ``importlib.util.spec_from_file_location``. The field was used + raw, with no path-traversal check, so a single manifest line + ``{"api": "/tmp/payload.py"}`` was enough to redirect the + importer at any Python file on disk (``Path('safe') / '/abs'`` + resolves to ``/abs`` in Python). + +These tests pin each layer of the new defence: + +* Truthy env semantics now match the agent loader. +* ``_safe_plugin_api_relpath`` rejects absolute paths, ``..`` + traversal, and non-string / empty values. +* ``_mount_plugin_api_routes`` re-validates at import time and + refuses project-source plugins outright. +* End-to-end the original PoC manifest no longer triggers + ``importlib`` for ``/tmp/payload.py``. +""" +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +from hermes_cli import web_server + + +@pytest.fixture(autouse=True) +def _reset_plugin_cache(monkeypatch): + """The plugin scanner caches its result per-process. Bust the + cache before *and* after each test so leakage between tests can't + mask a regression โ€” and so the production cache the import-time + ``_mount_plugin_api_routes()`` populated doesn't bleed in.""" + web_server._dashboard_plugins_cache = None + yield + web_server._dashboard_plugins_cache = None + + +def _write_plugin_manifest(root: Path, name: str, manifest: dict) -> Path: + """Drop a manifest under ``root//dashboard/manifest.json`` and + return the dashboard dir path.""" + dashboard_dir = root / name / "dashboard" + dashboard_dir.mkdir(parents=True) + (dashboard_dir / "manifest.json").write_text(json.dumps(manifest)) + return dashboard_dir + + +# --------------------------------------------------------------------------- +# Layer 1 โ€” HERMES_ENABLE_PROJECT_PLUGINS env gate uses truthy semantics. +# --------------------------------------------------------------------------- + + +class TestProjectPluginsEnvGate: + """Project plugins must only be discovered when the env var is set + to a documented truthy value. Pre-#29156 any non-empty string โ€” + including ``0`` / ``false`` / ``no`` โ€” silently enabled the source.""" + + @pytest.fixture + def project_plugin(self, tmp_path, monkeypatch): + """Plant a project-source plugin under CWD's ``.hermes/plugins`` + and isolate the user-plugins dir to an empty tmp tree.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home")) + (tmp_path / "home").mkdir() + cwd = tmp_path / "evil-repo" + cwd.mkdir() + monkeypatch.chdir(cwd) + _write_plugin_manifest( + cwd / ".hermes" / "plugins", + "evil", + { + "name": "evil", + "label": "Evil", + "entry": "dist/index.js", + }, + ) + return cwd + + @pytest.mark.parametrize("value", ["", "0", "false", "FALSE", "no", "off", "False"]) + def test_falsy_values_keep_project_plugins_disabled( + self, project_plugin, monkeypatch, value + ): + if value == "": + monkeypatch.delenv("HERMES_ENABLE_PROJECT_PLUGINS", raising=False) + else: + monkeypatch.setenv("HERMES_ENABLE_PROJECT_PLUGINS", value) + + plugins = web_server._get_dashboard_plugins(force_rescan=True) + names = {p["name"] for p in plugins} + assert "evil" not in names, ( + f"HERMES_ENABLE_PROJECT_PLUGINS={value!r} must NOT enable the " + "project source โ€” that's the GHSA-5qr3-c538-wm9j env bypass." + ) + + @pytest.mark.parametrize("value", ["1", "true", "TRUE", "yes", "on", "YES"]) + def test_truthy_values_enable_project_plugins( + self, project_plugin, monkeypatch, value + ): + monkeypatch.setenv("HERMES_ENABLE_PROJECT_PLUGINS", value) + plugins = web_server._get_dashboard_plugins(force_rescan=True) + evil = next((p for p in plugins if p["name"] == "evil"), None) + assert evil is not None + assert evil["source"] == "project" + + +# --------------------------------------------------------------------------- +# Layer 2 โ€” _safe_plugin_api_relpath rejects path-traversal payloads. +# --------------------------------------------------------------------------- + + +class TestApiPathSanitizer: + """Unit-level coverage for the new ``_safe_plugin_api_relpath`` + helper. Anything that escapes the plugin's dashboard directory + must come back as ``None``.""" + + def _dashboard_dir(self, tmp_path): + d = tmp_path / "plug" / "dashboard" + d.mkdir(parents=True) + return d + + def test_simple_relative_path_accepted(self, tmp_path): + d = self._dashboard_dir(tmp_path) + (d / "api.py").write_text("router = None\n") + assert web_server._safe_plugin_api_relpath("api.py", dashboard_dir=d) == "api.py" + + def test_nested_relative_path_accepted(self, tmp_path): + d = self._dashboard_dir(tmp_path) + (d / "backend").mkdir() + (d / "backend" / "routes.py").write_text("router = None\n") + out = web_server._safe_plugin_api_relpath( + "backend/routes.py", dashboard_dir=d + ) + assert out == "backend/routes.py" + + @pytest.mark.parametrize("payload", [ + "/etc/passwd", + "/tmp/payload.py", + "/usr/bin/python", + # NT-style absolute on POSIX is a relative path โ€” covered by traversal below. + ]) + def test_absolute_path_rejected(self, tmp_path, payload): + d = self._dashboard_dir(tmp_path) + assert web_server._safe_plugin_api_relpath(payload, dashboard_dir=d) is None + + @pytest.mark.parametrize("payload", [ + "../../../etc/passwd", + "../neighbour/api.py", + "../../../../tmp/evil.py", + "subdir/../../../../etc/passwd", + ]) + def test_traversal_rejected(self, tmp_path, payload): + d = self._dashboard_dir(tmp_path) + assert web_server._safe_plugin_api_relpath(payload, dashboard_dir=d) is None + + @pytest.mark.parametrize("payload", [None, "", " ", 42, [], {}]) + def test_non_string_or_empty_rejected(self, tmp_path, payload): + d = self._dashboard_dir(tmp_path) + assert web_server._safe_plugin_api_relpath(payload, dashboard_dir=d) is None + + +# --------------------------------------------------------------------------- +# Layer 3 โ€” _discover_dashboard_plugins scrubs ``_api_file`` early. +# --------------------------------------------------------------------------- + + +class TestDiscoveryScrubsApiField: + """The cached plugin entry must NEVER carry an unsanitised api path. + A regression here would re-arm the RCE for any caller that uses + ``plugin['_api_file']`` directly.""" + + @pytest.fixture + def user_plugin_factory(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + monkeypatch.delenv("HERMES_ENABLE_PROJECT_PLUGINS", raising=False) + + def _make(name: str, manifest: dict) -> None: + _write_plugin_manifest(tmp_path / "plugins", name, manifest) + + return _make + + def test_absolute_api_path_in_manifest_is_scrubbed(self, user_plugin_factory): + user_plugin_factory("evil", { + "name": "evil", + "label": "Evil", + "api": "/tmp/payload.py", + "entry": "dist/index.js", + }) + plugins = web_server._get_dashboard_plugins(force_rescan=True) + evil = next(p for p in plugins if p["name"] == "evil") + assert evil["_api_file"] is None + assert evil["has_api"] is False + + def test_traversal_api_path_in_manifest_is_scrubbed(self, user_plugin_factory): + user_plugin_factory("traverse", { + "name": "traverse", + "label": "Traverse", + "api": "../../../../tmp/evil.py", + "entry": "dist/index.js", + }) + plugins = web_server._get_dashboard_plugins(force_rescan=True) + entry = next(p for p in plugins if p["name"] == "traverse") + assert entry["_api_file"] is None + assert entry["has_api"] is False + + def test_safe_api_path_survives(self, user_plugin_factory, tmp_path): + user_plugin_factory("safe", { + "name": "safe", + "label": "Safe", + "api": "api.py", + "entry": "dist/index.js", + }) + # Make the api file actually exist so a downstream mount could + # in principle proceed โ€” we're only testing the discovery scrub. + (tmp_path / "plugins" / "safe" / "dashboard" / "api.py").write_text( + "router = None\n" + ) + plugins = web_server._get_dashboard_plugins(force_rescan=True) + entry = next(p for p in plugins if p["name"] == "safe") + assert entry["_api_file"] == "api.py" + assert entry["has_api"] is True + + +# --------------------------------------------------------------------------- +# Layer 4 โ€” _mount_plugin_api_routes refuses project-source + traversal. +# --------------------------------------------------------------------------- + + +class TestMountApiRoutesRefusesUntrusted: + """The mount routine is the actual ``importlib`` call site โ€” these + tests poke synthetic plugin entries directly into the cache and + assert the importer is *not* invoked.""" + + def _payload_plugin(self, tmp_path, *, source: str, api_file: str = "api.py"): + dash = tmp_path / "plug" / "dashboard" + dash.mkdir(parents=True) + # Write a benign router file; the test asserts it's NOT imported + # regardless of whether it exists, since the source/path checks + # short-circuit before the importer runs. + (dash / "api.py").write_text( + "from fastapi import APIRouter\nrouter = APIRouter()\n" + ) + return { + "name": "synthetic", + "label": "Synthetic", + "tab": {"path": "/synthetic", "position": "end"}, + "slots": [], + "entry": "dist/index.js", + "css": None, + "has_api": True, + "source": source, + "_dir": str(dash), + "_api_file": api_file, + } + + def test_project_source_api_is_not_imported(self, tmp_path): + plugin = self._payload_plugin(tmp_path, source="project") + web_server._dashboard_plugins_cache = [plugin] + with patch("importlib.util.spec_from_file_location") as spec: + web_server._mount_plugin_api_routes() + assert spec.call_count == 0, ( + "project-source plugin's api file was imported โ€” " + "GHSA-5qr3-c538-wm9j defence-in-depth regression" + ) + + def test_bundled_source_api_imports_normally(self, tmp_path): + plugin = self._payload_plugin(tmp_path, source="bundled") + web_server._dashboard_plugins_cache = [plugin] + with patch("importlib.util.spec_from_file_location") as spec: + spec.return_value = None # loader is None -> early continue, safe + web_server._mount_plugin_api_routes() + assert spec.call_count == 1 + # First positional arg after module_name is the resolved api path. + called_path = Path(spec.call_args.args[1]) + assert called_path.name == "api.py" + assert called_path.is_absolute() + + def test_traversal_api_caught_at_mount_time(self, tmp_path): + """Defence-in-depth: if discovery is bypassed (e.g. cache + tampering), mount-time validation still refuses to import a + file outside the dashboard dir.""" + plugin = self._payload_plugin(tmp_path, source="user", + api_file="../../../tmp/evil.py") + web_server._dashboard_plugins_cache = [plugin] + with patch("importlib.util.spec_from_file_location") as spec: + web_server._mount_plugin_api_routes() + assert spec.call_count == 0 + + +# --------------------------------------------------------------------------- +# Layer 5 โ€” End-to-end: the original PoC manifest no longer triggers RCE. +# --------------------------------------------------------------------------- + + +class TestEndToEndPocBlocked: + """Reproduces the original advisory PoC shape: untrusted CWD with a + manifest pointing ``api`` at an attacker-chosen Python file, with + ``HERMES_ENABLE_PROJECT_PLUGINS=0`` (so the operator believed the + project source was disabled). Post-fix, the importer must never + be invoked for the payload path, regardless of how the bypass is + framed (``=0`` truthy-string bypass, absolute path bypass, + project-source bypass).""" + + def test_full_chain_blocked(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home")) + (tmp_path / "home").mkdir() + cwd = tmp_path / "evil-repo" + cwd.mkdir() + monkeypatch.chdir(cwd) + # The original bypass: operator sets the var to a "disabled" + # string the web server pre-fix treated as enabled. + monkeypatch.setenv("HERMES_ENABLE_PROJECT_PLUGINS", "0") + # Payload: absolute path inside a manifest dropped in CWD. + payload_py = tmp_path / "payload.py" + payload_py.write_text("OWNED = True\n") + _write_plugin_manifest( + cwd / ".hermes" / "plugins", + "evil", + { + "name": "evil", + "label": "Evil", + "api": str(payload_py), + "entry": "dist/index.js", + }, + ) + + with patch("importlib.util.spec_from_file_location") as spec: + plugins = web_server._get_dashboard_plugins(force_rescan=True) + web_server._mount_plugin_api_routes() + + # The project source must stay disabled because ``0`` is no + # longer truthy. Even if the operator *had* opted in, the + # absolute-path api would be scrubbed at discovery, and even + # if discovery missed it the project-source guard in mount + # would refuse the import. + assert "evil" not in {p["name"] for p in plugins} + # Bundled plugins shipped with the repo may legitimately have + # ``api`` files and so ``spec_from_file_location`` can fire for + # those โ€” the regression is specifically that the *payload* + # path / *evil* module are never targeted. + for call in spec.call_args_list: + module_name = call.args[0] + target = Path(call.args[1]) + assert module_name != "hermes_dashboard_plugin_evil" + assert target != payload_py + assert "evil-repo" not in target.parts + assert "hermes_dashboard_plugin_evil" not in sys.modules diff --git a/tests/hermes_cli/test_security_audit.py b/tests/hermes_cli/test_security_audit.py new file mode 100644 index 00000000000..fe6abe7221c --- /dev/null +++ b/tests/hermes_cli/test_security_audit.py @@ -0,0 +1,299 @@ +"""Unit tests for hermes_cli.security_audit โ€” parsers + OSV plumbing. + +These never hit the live OSV API; HTTP is monkeypatched. The live-call path +is exercised in the E2E test embedded in PR validation, not here. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from hermes_cli import security_audit as sa + + +# โ”€โ”€โ”€ Parsers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestRequirementsParser: + def test_extracts_pinned_versions(self): + text = "requests==2.20.0\nflask==2.0.1\n" + assert sa._parse_requirements(text) == [ + ("requests", "2.20.0"), + ("flask", "2.0.1"), + ] + + def test_skips_comments_and_options(self): + text = "# comment\n-r other.txt\n--index-url https://x\nflask==2.0.1\n" + assert sa._parse_requirements(text) == [("flask", "2.0.1")] + + def test_skips_unpinned(self): + # We deliberately don't try to map >=, ~=, or bare-name deps to OSV. + text = "requests>=2.0\ntyping-extensions\nflask~=2.0\n" + assert sa._parse_requirements(text) == [] + + def test_handles_extras_and_markers(self): + text = 'requests[security]==2.20.0\nflask==2.0.1 ; python_version >= "3.8"\n' + assert sa._parse_requirements(text) == [ + ("requests", "2.20.0"), + ("flask", "2.0.1"), + ] + + def test_handles_empty(self): + assert sa._parse_requirements("") == [] + assert sa._parse_requirements(" \n\n ") == [] + + +class TestMCPComponentExtraction: + def test_npx_scoped_pinned(self): + comp = sa._extract_mcp_component( + "fs", "npx", ["-y", "@modelcontextprotocol/server-filesystem@0.5.0"] + ) + assert comp == sa.Component( + name="@modelcontextprotocol/server-filesystem", + version="0.5.0", + ecosystem="npm", + source="mcp:fs", + ) + + def test_npx_full_path_command(self): + comp = sa._extract_mcp_component( + "fetch", "/usr/local/bin/npx", ["mcp-server-fetch@1.2.3"] + ) + assert comp is not None + assert comp.name == "mcp-server-fetch" + assert comp.version == "1.2.3" + + def test_uvx_pinned(self): + comp = sa._extract_mcp_component("time", "uvx", ["mcp-server-time==2.1.0"]) + assert comp is not None + assert comp.ecosystem == "PyPI" + assert comp.name == "mcp-server-time" + assert comp.version == "2.1.0" + + def test_unpinned_returns_none(self): + # Bare npx package name = "latest" at runtime; not an audit subject. + assert sa._extract_mcp_component("x", "npx", ["-y", "some-pkg"]) is None + + def test_docker_returns_none(self): + # We don't currently parse docker image refs. + assert sa._extract_mcp_component("x", "docker", ["run", "-i", "mcp/foo:1.0"]) is None + + def test_empty_args(self): + assert sa._extract_mcp_component("x", "npx", []) is None + + +# โ”€โ”€โ”€ Plugin discovery โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestPluginDiscovery: + def test_reads_requirements_txt(self, tmp_path: Path): + plugin = tmp_path / "plugins" / "myplugin" + plugin.mkdir(parents=True) + (plugin / "requirements.txt").write_text("requests==2.20.0\n") + components = sa._discover_plugins(tmp_path) + assert len(components) == 1 + assert components[0].name == "requests" + assert components[0].source == "plugin:myplugin" + + def test_skips_when_no_plugins_dir(self, tmp_path: Path): + assert sa._discover_plugins(tmp_path) == [] + + def test_skips_hidden_dirs(self, tmp_path: Path): + (tmp_path / "plugins" / ".hidden").mkdir(parents=True) + (tmp_path / "plugins" / ".hidden" / "requirements.txt").write_text( + "requests==2.20.0\n" + ) + assert sa._discover_plugins(tmp_path) == [] + + def test_reads_pyproject_dependencies(self, tmp_path: Path): + plugin = tmp_path / "plugins" / "py" + plugin.mkdir(parents=True) + (plugin / "pyproject.toml").write_text( + '[project]\ndependencies = ["flask==2.0.1", "uvicorn>=0.20"]\n' + ) + components = sa._discover_plugins(tmp_path) + # uvicorn>=0.20 is unpinned, so only flask comes through + assert len(components) == 1 + assert components[0].name == "flask" + assert components[0].version == "2.0.1" + + +# โ”€โ”€โ”€ OSV severity extraction โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestSeverityExtraction: + def test_database_specific_severity(self): + rec = {"database_specific": {"severity": "HIGH"}} + assert sa._osv_severity_from_record(rec) == "HIGH" + + def test_unknown_when_no_severity(self): + assert sa._osv_severity_from_record({}) == "UNKNOWN" + + def test_ecosystem_specific_fallback(self): + rec = {"affected": [{"ecosystem_specific": {"severity": "MODERATE"}}]} + assert sa._osv_severity_from_record(rec) == "MODERATE" + + def test_fixed_versions_extracted_and_deduped(self): + rec = { + "affected": [ + { + "ranges": [ + { + "events": [ + {"introduced": "0"}, + {"fixed": "2.0.0"}, + ] + } + ] + }, + {"ranges": [{"events": [{"fixed": "2.0.0"}, {"fixed": "1.9.5"}]}]}, + ] + } + assert sa._osv_fixed_versions(rec) == ["2.0.0", "1.9.5"] + + +# โ”€โ”€โ”€ End-to-end orchestration with mocked OSV โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestRunAudit: + def test_no_components_returns_empty(self, tmp_path: Path): + findings = sa.run_audit( + skip_venv=True, skip_plugins=True, skip_mcp=True, hermes_home=tmp_path + ) + assert findings == [] + + def test_findings_sorted_by_severity_desc(self, tmp_path: Path): + plugin = tmp_path / "plugins" / "p" + plugin.mkdir(parents=True) + (plugin / "requirements.txt").write_text("alpha==1.0.0\nbeta==2.0.0\n") + + def fake_batch(comps): + return { + comps[0]: ["LOW-1"], + comps[1]: ["CRIT-1"], + } + + def fake_details(ids): + return { + "LOW-1": sa.Vulnerability(osv_id="LOW-1", severity="LOW", summary="low"), + "CRIT-1": sa.Vulnerability(osv_id="CRIT-1", severity="CRITICAL", summary="crit"), + } + + with patch.object(sa, "_osv_query_batch", side_effect=fake_batch), \ + patch.object(sa, "_osv_fetch_details", side_effect=fake_details): + findings = sa.run_audit( + skip_venv=True, skip_plugins=False, skip_mcp=True, hermes_home=tmp_path + ) + assert len(findings) == 2 + # CRITICAL must come first + assert findings[0].vuln.osv_id == "CRIT-1" + assert findings[1].vuln.osv_id == "LOW-1" + + +# โ”€โ”€โ”€ CLI subcommand exit codes โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestExitCodes: + def _build_args(self, **kwargs): + import argparse + + defaults = { + "skip_venv": True, + "skip_plugins": True, + "skip_mcp": True, + "json": False, + "fail_on": "critical", + } + defaults.update(kwargs) + return argparse.Namespace(**defaults) + + def test_clean_audit_exits_zero(self, tmp_path: Path, monkeypatch, capsys): + monkeypatch.setattr(sa, "get_hermes_home", lambda: str(tmp_path)) + # Everything skipped โ†’ no components โ†’ exit 0 + code = sa.cmd_security_audit(self._build_args()) + assert code == 0 + out = capsys.readouterr().out + assert "No components" in out or "0 component" in out + + def test_finding_above_threshold_exits_one(self, tmp_path: Path, monkeypatch): + monkeypatch.setattr(sa, "get_hermes_home", lambda: str(tmp_path)) + # Force a venv discovery to return one component, OSV to flag it CRITICAL + fake_comp = sa.Component( + name="pkg", version="1.0", ecosystem="PyPI", source="venv" + ) + monkeypatch.setattr(sa, "_discover_venv", lambda: [fake_comp]) + monkeypatch.setattr( + sa, "_osv_query_batch", lambda comps: {fake_comp: ["X-1"]} + ) + monkeypatch.setattr( + sa, + "_osv_fetch_details", + lambda ids: {"X-1": sa.Vulnerability(osv_id="X-1", severity="CRITICAL")}, + ) + code = sa.cmd_security_audit( + self._build_args(skip_venv=False, fail_on="critical") + ) + assert code == 1 + + def test_finding_below_threshold_exits_zero(self, tmp_path: Path, monkeypatch): + monkeypatch.setattr(sa, "get_hermes_home", lambda: str(tmp_path)) + fake_comp = sa.Component( + name="pkg", version="1.0", ecosystem="PyPI", source="venv" + ) + monkeypatch.setattr(sa, "_discover_venv", lambda: [fake_comp]) + monkeypatch.setattr( + sa, "_osv_query_batch", lambda comps: {fake_comp: ["X-1"]} + ) + monkeypatch.setattr( + sa, + "_osv_fetch_details", + lambda ids: {"X-1": sa.Vulnerability(osv_id="X-1", severity="MODERATE")}, + ) + code = sa.cmd_security_audit( + self._build_args(skip_venv=False, fail_on="critical") + ) + assert code == 0 + + def test_unknown_fail_on_value_exits_two(self, tmp_path: Path, monkeypatch, capsys): + monkeypatch.setattr(sa, "get_hermes_home", lambda: str(tmp_path)) + code = sa.cmd_security_audit(self._build_args(fail_on="garbage")) + assert code == 2 + err = capsys.readouterr().err + assert "fail-on" in err.lower() + + def test_json_output_shape(self, tmp_path: Path, monkeypatch, capsys): + monkeypatch.setattr(sa, "get_hermes_home", lambda: str(tmp_path)) + fake_comp = sa.Component( + name="pkg", version="1.0", ecosystem="PyPI", source="venv" + ) + monkeypatch.setattr(sa, "_discover_venv", lambda: [fake_comp]) + monkeypatch.setattr( + sa, "_osv_query_batch", lambda comps: {fake_comp: ["X-1"]} + ) + monkeypatch.setattr( + sa, + "_osv_fetch_details", + lambda ids: { + "X-1": sa.Vulnerability( + osv_id="X-1", + severity="HIGH", + summary="bad", + fixed_versions=["1.1"], + ) + }, + ) + sa.cmd_security_audit( + self._build_args(skip_venv=False, json=True, fail_on="critical") + ) + payload = capsys.readouterr().out + # The bitwarden banner can leak above the json; pick the first { line. + lines = payload.splitlines() + json_start = next(i for i, l in enumerate(lines) if l.startswith("{")) + data = json.loads("\n".join(lines[json_start:])) + assert data["finding_count"] == 1 + assert data["findings"][0]["severity"] == "HIGH" + assert data["findings"][0]["fixed_versions"] == ["1.1"] diff --git a/tests/hermes_cli/test_tools_config.py b/tests/hermes_cli/test_tools_config.py index 787292d83a4..0cb42ba299a 100644 --- a/tests/hermes_cli/test_tools_config.py +++ b/tests/hermes_cli/test_tools_config.py @@ -12,8 +12,10 @@ from hermes_cli.tools_config import ( _get_platform_tools, _platform_toolset_summary, _reconfigure_tool, + _run_post_setup, _save_platform_tools, _toolset_has_keys, + _toolset_needs_configuration_prompt, CONFIGURABLE_TOOLSETS, TOOL_CATEGORIES, _visible_providers, @@ -752,6 +754,91 @@ def test_numeric_mcp_server_name_does_not_crash_sorted(): # โ”€โ”€โ”€ Imagegen Backend Picker Wiring โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +def test_toolset_has_keys_treats_no_key_providers_as_configured(): + config = {} + + assert _toolset_has_keys("computer_use", config) is True + + +def test_computer_use_needs_configuration_when_cua_driver_post_setup_pending(): + """No-key providers can still need setup when their post_setup is unsatisfied. + + Returning users enabling Computer Use through `hermes tools` must reach the + cua-driver post-setup installer even though the provider has no API keys. + """ + with patch("shutil.which", return_value=None): + assert _toolset_needs_configuration_prompt("computer_use", {}) is True + + +def test_computer_use_skips_configuration_when_cua_driver_already_installed(): + """Installed post_setup dependencies should keep returning-user toggles no-op.""" + def fake_which(name: str): + return "/usr/local/bin/cua-driver" if name == "cua-driver" else None + + with patch("shutil.which", side_effect=fake_which): + assert _toolset_needs_configuration_prompt("computer_use", {}) is False + + +def test_computer_use_respects_custom_cua_driver_command(): + """The setup gate should match runtime's HERMES_CUA_DRIVER_CMD override.""" + def fake_which(name: str): + return "/opt/bin/custom-cua" if name == "custom-cua" else None + + with patch.dict("os.environ", {"HERMES_CUA_DRIVER_CMD": "custom-cua"}), \ + patch("shutil.which", side_effect=fake_which): + assert _toolset_needs_configuration_prompt("computer_use", {}) is False + + +def test_computer_use_blank_custom_driver_command_falls_back_to_default(): + """Blank overrides should not make the setup gate look for an empty command.""" + def fake_which(name: str): + return "/usr/local/bin/cua-driver" if name == "cua-driver" else None + + with patch.dict("os.environ", {"HERMES_CUA_DRIVER_CMD": " "}), \ + patch("shutil.which", side_effect=fake_which): + assert _toolset_needs_configuration_prompt("computer_use", {}) is False + + +def test_computer_use_post_setup_respects_custom_driver_command_when_installed(): + """post_setup already-installed checks should version-probe the override.""" + def fake_which(name: str): + return "/opt/bin/custom-cua" if name == "custom-cua" else None + + with patch.dict("os.environ", {"HERMES_CUA_DRIVER_CMD": "custom-cua"}), \ + patch("platform.system", return_value="Darwin"), \ + patch("shutil.which", side_effect=fake_which), \ + patch("subprocess.run") as run: + run.return_value.stdout = "custom 1.2.3\n" + + _run_post_setup("cua_driver") + + run.assert_called_once() + assert run.call_args.args[0] == ["custom-cua", "--version"] + + +def test_computer_use_post_setup_missing_override_does_not_accept_default_binary(): + """A default cua-driver binary must not satisfy a missing runtime override.""" + seen = [] + + def fake_which(name: str): + seen.append(name) + if name == "cua-driver": + return "/usr/local/bin/cua-driver" + if name == "curl": + return None + return None + + with patch.dict("os.environ", {"HERMES_CUA_DRIVER_CMD": "custom-cua"}), \ + patch("platform.system", return_value="Darwin"), \ + patch("shutil.which", side_effect=fake_which), \ + patch("subprocess.run") as run: + _run_post_setup("cua_driver") + + run.assert_not_called() + assert "custom-cua" in seen + assert "curl" in seen + + class TestImagegenBackendRegistry: """IMAGEGEN_BACKENDS tags drive the model picker flow in tools_config.""" diff --git a/tests/hermes_cli/test_tui_npm_install.py b/tests/hermes_cli/test_tui_npm_install.py index b11d3b4debb..6fca13c4927 100644 --- a/tests/hermes_cli/test_tui_npm_install.py +++ b/tests/hermes_cli/test_tui_npm_install.py @@ -168,7 +168,7 @@ def test_make_tui_argv_skips_build_only_on_termux_when_fresh( argv, cwd = main_mod._make_tui_argv(tmp_path, tui_dev=False) - assert argv == ["/bin/node", str(tmp_path / "dist" / "entry.js")] + assert argv == ["/bin/node", "--expose-gc", str(tmp_path / "dist" / "entry.js")] assert cwd == tmp_path diff --git a/tests/hermes_cli/test_tui_resume_flow.py b/tests/hermes_cli/test_tui_resume_flow.py index 7e6ccc05927..bcf552a8f10 100644 --- a/tests/hermes_cli/test_tui_resume_flow.py +++ b/tests/hermes_cli/test_tui_resume_flow.py @@ -1,4 +1,5 @@ from argparse import Namespace +import os from pathlib import Path import sys import types @@ -312,6 +313,37 @@ def test_termux_fast_cli_launch_chat_uses_light_parser(monkeypatch, main_mod): } +def test_termux_fast_cli_launch_bare_defers_agent_startup(monkeypatch, main_mod): + captured = {} + prepared = [] + + monkeypatch.setenv("TERMUX_VERSION", "1") + monkeypatch.delenv("HERMES_TUI", raising=False) + monkeypatch.delenv("HERMES_DEFER_AGENT_STARTUP", raising=False) + monkeypatch.delenv("HERMES_FAST_STARTUP_BANNER", raising=False) + monkeypatch.setattr(sys, "argv", ["hermes"]) + monkeypatch.setattr( + main_mod, "_prepare_agent_startup", lambda args: prepared.append(args.command) + ) + monkeypatch.setattr( + main_mod, + "cmd_chat", + lambda args: captured.update( + { + "query": args.query, + "command": args.command, + "compact": getattr(args, "compact", False), + } + ), + ) + + assert main_mod._try_termux_fast_cli_launch() is True + assert prepared == [] + assert captured == {"query": None, "command": None, "compact": True} + assert os.environ["HERMES_DEFER_AGENT_STARTUP"] == "1" + assert os.environ["HERMES_FAST_STARTUP_BANNER"] == "1" + + def test_termux_fast_cli_launch_oneshot_uses_light_parser(monkeypatch, main_mod): captured = {} prepared = [] @@ -364,6 +396,34 @@ def test_termux_fast_cli_launch_version_skips_update_check(monkeypatch, main_mod assert captured == [False] +def test_termux_ultrafast_version_runs_before_heavy_startup( + monkeypatch, capsys, main_mod +): + monkeypatch.setenv("TERMUX_VERSION", "1") + monkeypatch.delenv("HERMES_TERMUX_DISABLE_FAST_CLI", raising=False) + monkeypatch.setattr(sys, "argv", ["hermes", "--version"]) + + assert main_mod._try_termux_ultrafast_version() is True + + out = capsys.readouterr().out + assert "Hermes Agent v" in out + assert "Project:" in out + assert "Python:" in out + assert "OpenAI SDK:" in out + + +def test_read_openai_version_fast(monkeypatch, tmp_path, main_mod): + package_dir = tmp_path / "openai" + package_dir.mkdir() + (package_dir / "_version.py").write_text( + '__version__ = "9.8.7" # x-release-please-version\n', + encoding="utf-8", + ) + monkeypatch.setattr(sys, "path", [str(tmp_path)]) + + assert main_mod._read_openai_version_fast() == "9.8.7" + + def test_termux_fast_cli_launch_skips_help(monkeypatch, main_mod): monkeypatch.setenv("TERMUX_VERSION", "1") monkeypatch.delenv("HERMES_TUI", raising=False) diff --git a/tests/hermes_cli/test_web_server.py b/tests/hermes_cli/test_web_server.py index f5c06205621..d46e87c2862 100644 --- a/tests/hermes_cli/test_web_server.py +++ b/tests/hermes_cli/test_web_server.py @@ -327,6 +327,12 @@ class TestWebServerEndpoints: # Public endpoints should still work resp = unauth_client.get("/api/status") assert resp.status_code == 200 + resp = unauth_client.get("/api/dashboard/plugins") + assert resp.status_code == 200 + resp = unauth_client.get("/api/dashboard/plugins/rescan") + assert resp.status_code == 401 + resp = self.client.get("/api/dashboard/plugins/rescan") + assert resp.status_code == 200 def test_path_traversal_blocked(self): """Verify URL-encoded path traversal is blocked.""" @@ -2285,7 +2291,10 @@ class TestPtyWebSocket: self.ws_module.app.state, "bound_port", 9119, raising=False ) - with self.client.websocket_connect(self._url(channel="abc-123")) as conn: + headers = {"host": "127.0.0.1:9119", "origin": "http://127.0.0.1:9119"} + with self.client.websocket_connect( + self._url(channel="abc-123"), headers=headers + ) as conn: try: conn.receive_bytes() except Exception: @@ -2325,7 +2334,34 @@ class TestPtyWebSocket: with self.client.websocket_connect(pub_path) as pub: pub.send_text('{"type":"tool.start","payload":{"tool_id":"t1"}}') - received = sub.receive_text() + # Yield control so the server-side broadcast handler can + # process the frame. TestClient runs the ASGI app in a + # background thread; a small sleep gives that thread time + # to call _broadcast_event before we start blocking on + # receive_text(). Without this, under heavy CI load the + # receive can race the broadcast and hang until + # pytest-timeout kills us. + import queue, threading + recv_q: queue.Queue = queue.Queue() + + def _recv(): + try: + recv_q.put(sub.receive_text()) + except Exception as exc: + recv_q.put(exc) + + t = threading.Thread(target=_recv, daemon=True) + t.start() + try: + received = recv_q.get(timeout=10.0) + except queue.Empty: + raise AssertionError( + "broadcast not received within 10s โ€” server likely " + "dropped the frame silently (see _broadcast_event " + "except Exception: pass)" + ) + if isinstance(received, Exception): + raise received assert "tool.start" in received assert '"tool_id":"t1"' in received diff --git a/tests/hermes_cli/test_web_server_host_header.py b/tests/hermes_cli/test_web_server_host_header.py index 966127b05ce..9afef09d136 100644 --- a/tests/hermes_cli/test_web_server_host_header.py +++ b/tests/hermes_cli/test_web_server_host_header.py @@ -146,3 +146,72 @@ class TestHostHeaderMiddleware: resp = client.get("/api/status") # Should get through to the status endpoint, not a 400 assert resp.status_code != 400 + + +class TestWebSocketHostOriginGuard: + """WebSocket upgrades must enforce the same dashboard boundary as HTTP.""" + + def test_rebinding_websocket_host_is_rejected(self, monkeypatch): + from fastapi.testclient import TestClient + from starlette.websockets import WebSocketDisconnect + + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws.app.state, "bound_host", "127.0.0.1", raising=False) + monkeypatch.setattr(ws, "_DASHBOARD_EMBEDDED_CHAT_ENABLED", True) + + client = TestClient(ws.app) + url = f"/api/events?token={ws._SESSION_TOKEN}&channel=security-test" + with pytest.raises(WebSocketDisconnect) as exc: + with client.websocket_connect( + url, + headers={ + "Host": "evil.example", + "Origin": "http://evil.example", + }, + ): + pass + + assert exc.value.code == 4403 + + def test_rebinding_websocket_origin_is_rejected(self, monkeypatch): + from fastapi.testclient import TestClient + from starlette.websockets import WebSocketDisconnect + + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws.app.state, "bound_host", "127.0.0.1", raising=False) + monkeypatch.setattr(ws, "_DASHBOARD_EMBEDDED_CHAT_ENABLED", True) + + client = TestClient(ws.app) + url = f"/api/events?token={ws._SESSION_TOKEN}&channel=security-test" + with pytest.raises(WebSocketDisconnect) as exc: + with client.websocket_connect( + url, + headers={ + "Host": "localhost:9119", + "Origin": "http://evil.example", + }, + ): + pass + + assert exc.value.code == 4403 + + def test_loopback_websocket_host_and_origin_are_accepted(self, monkeypatch): + from fastapi.testclient import TestClient + + import hermes_cli.web_server as ws + + monkeypatch.setattr(ws.app.state, "bound_host", "127.0.0.1", raising=False) + monkeypatch.setattr(ws, "_DASHBOARD_EMBEDDED_CHAT_ENABLED", True) + + client = TestClient(ws.app) + url = f"/api/events?token={ws._SESSION_TOKEN}&channel=security-test" + with client.websocket_connect( + url, + headers={ + "Host": "localhost:9119", + "Origin": "http://localhost:9119", + }, + ): + pass diff --git a/tests/hermes_cli/test_webhook_cli.py b/tests/hermes_cli/test_webhook_cli.py index 0094e917c54..8d3880722bb 100644 --- a/tests/hermes_cli/test_webhook_cli.py +++ b/tests/hermes_cli/test_webhook_cli.py @@ -3,6 +3,7 @@ import json import os import pytest +import stat from argparse import Namespace from pathlib import Path @@ -145,6 +146,31 @@ class TestPersistence: path.write_text("broken{{{") assert _load_subscriptions() == {} + @pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits are platform-specific") + def test_save_creates_secret_file_owner_only_under_permissive_umask(self): + old_umask = os.umask(0o022) + try: + _save_subscriptions({"demo": {"secret": "TOPSECRET", "prompt": "x"}}) + finally: + os.umask(old_umask) + + path = _subscriptions_path() + assert stat.S_IMODE(path.stat().st_mode) == 0o600 + assert "TOPSECRET" in path.read_text(encoding="utf-8") + + @pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits are platform-specific") + def test_save_narrows_existing_broad_secret_file_mode(self): + # Simulate a pre-existing 0o644 file from before this hardening landed. + path = _subscriptions_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps({"old": {"secret": "stale", "prompt": "x"}})) + path.chmod(0o644) + + _save_subscriptions({"demo": {"secret": "FRESH", "prompt": "x"}}) + + assert stat.S_IMODE(path.stat().st_mode) == 0o600 + assert "FRESH" in path.read_text(encoding="utf-8") + class TestWebhookEnabledGate: def test_blocks_when_disabled(self, capsys, monkeypatch): diff --git a/tests/integration/test_voice_channel_flow.py b/tests/integration/test_voice_channel_flow.py index a38c8c6432f..420adcb0e73 100644 --- a/tests/integration/test_voice_channel_flow.py +++ b/tests/integration/test_voice_channel_flow.py @@ -38,7 +38,7 @@ except Exception: from types import SimpleNamespace from unittest.mock import MagicMock -from gateway.platforms.discord import VoiceReceiver +from plugins.platforms.discord.adapter import VoiceReceiver # --------------------------------------------------------------------------- diff --git a/tests/plugins/image_gen/check_parity_vs_main.py b/tests/plugins/image_gen/check_parity_vs_main.py new file mode 100644 index 00000000000..ca40cb5e13d --- /dev/null +++ b/tests/plugins/image_gen/check_parity_vs_main.py @@ -0,0 +1,300 @@ +"""Behavior-parity check for the image-gen FAL plugin migration (#26241). + +Spawns one subprocess per (version, scenario) cell โ€” pinned to either +``origin/main`` (legacy in-tree FAL fall-through + ``configured == "fal"`` +skip in ``_dispatch_to_plugin_provider``) or this PR's worktree (FAL is +itself a plugin and the dispatcher routes every set provider through +the registry). Each subprocess clears all FAL-related env vars + writes +a ``config.yaml``, then asks the dispatcher how it would route an +``image_generate`` call. The emitted shape tuple is +``{dispatch_kind, provider_name, model}``: + +* ``dispatch_kind`` โˆˆ ``{"legacy_fal", "plugin", "error", None}`` โ€” + whether the call would go straight to the in-tree pipeline, + through ``_dispatch_to_plugin_provider``, raise an explicit + provider-not-registered error, or fall through silently. +* ``provider_name`` โ€” when ``dispatch_kind == "plugin"``, the + resolved provider name. ``None`` otherwise. +* ``model`` โ€” the resolved FAL model id when applicable. + +The parent process diffs the shapes per scenario. A diff means the +migration introduced an observable behaviour change vs origin/main โ€” +likely a real regression for users on the existing config keys. + +Run from the PR worktree: + + python tests/plugins/image_gen/check_parity_vs_main.py +""" +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[3] + + +# Pin one path to current main, one to the PR worktree. +# ``REPO_ROOT`` is ``.../.worktrees/``; the main checkout lives +# two levels up. When running directly from a regular clone (no +# worktree), ``MAIN_DIR`` falls back to a sibling ``hermes-agent-main`` +# checkout if one exists. +def _resolve_main_dir() -> Path: + candidate = REPO_ROOT.parent.parent + if (candidate / "tools" / "image_generation_tool.py").exists() and candidate != REPO_ROOT: + return candidate + sibling = REPO_ROOT.parent / "hermes-agent-main" + if (sibling / "tools" / "image_generation_tool.py").exists(): + return sibling + return REPO_ROOT + + +MAIN_DIR = _resolve_main_dir() +PR_DIR = REPO_ROOT +assert (PR_DIR / "tools" / "image_generation_tool.py").exists(), ( + f"PR_DIR={PR_DIR} doesn't look like a hermes-agent checkout" +) + + +SUBPROCESS_SCRIPT = r""" +import json, os, sys, tempfile +sys.path.insert(0, sys.argv[1]) + +# Isolated HERMES_HOME so the config write is hermetic. +home = tempfile.mkdtemp() +os.environ["HERMES_HOME"] = home + +# Clear FAL-related env so dispatch decisions are config-driven. +for k in ( + "FAL_KEY", "FAL_QUEUE_GATEWAY_URL", + "TOOL_GATEWAY_DOMAIN", "TOOL_GATEWAY_USER_TOKEN", + "FAL_IMAGE_MODEL", +): + os.environ.pop(k, None) + +scenario_env = json.loads(sys.argv[2]) +os.environ.update(scenario_env) + +config_yaml = sys.argv[3] +config_path = os.path.join(home, "config.yaml") +with open(config_path, "w") as f: + f.write(config_yaml) + +# Fresh import โ€” must not have anything cached. +for name in list(sys.modules): + if (name.startswith("tools.") + or name.startswith("agent.") + or name.startswith("plugins.") + or name.startswith("hermes_cli.")): + sys.modules.pop(name, None) + +import tools.image_generation_tool as image_tool + +dispatch_kind = None +provider_name = None +model = None +error_text = None + +try: + raw = image_tool._dispatch_to_plugin_provider("ping", "landscape") + if raw is None: + dispatch_kind = "legacy_fal" + else: + parsed = json.loads(raw) if isinstance(raw, str) else raw + if isinstance(parsed, dict): + if parsed.get("error_type") == "provider_not_registered": + dispatch_kind = "error" + error_text = parsed.get("error") + else: + dispatch_kind = "plugin" + provider_name = parsed.get("provider") + model = parsed.get("model") + else: + dispatch_kind = "unknown_payload" + + if model is None: + # _resolve_fal_model still returns the active FAL model id even + # when dispatch goes to a non-FAL plugin โ€” used for the diff + # only when applicable. + try: + model_id, _meta = image_tool._resolve_fal_model() + if dispatch_kind == "legacy_fal": + model = model_id + except Exception: + pass +except Exception as exc: + dispatch_kind = "exception" + error_text = repr(exc) + +shape = { + "dispatch_kind": dispatch_kind, + "provider_name": provider_name, + "model": model, + "error_present": error_text is not None, +} +print(json.dumps(shape)) +""" + + +SCENARIOS: list[tuple[str, str, dict[str, str]]] = [ + # (label, config.yaml body, extra env vars) + ("no-config-no-env", "", {}), + ( + "explicit-fal-no-creds", + "image_gen:\n provider: fal\n", + {}, + ), + ( + "explicit-fal-with-creds", + "image_gen:\n provider: fal\n", + {"FAL_KEY": "test-key"}, + ), + ( + "explicit-fal-with-model", + "image_gen:\n provider: fal\n model: fal-ai/flux-2-pro\n", + {"FAL_KEY": "test-key"}, + ), + ( + "explicit-typo-provider", + "image_gen:\n provider: not-a-real-backend\n", + {"FAL_KEY": "test-key"}, + ), + ( + "managed-gateway-only", + "", + { + "TOOL_GATEWAY_DOMAIN": "nousresearch.com", + "TOOL_GATEWAY_USER_TOKEN": "nous-token", + }, + ), +] + + +def _run_scenario(repo_path: Path, label: str, config_yaml: str, env: dict) -> dict: + venv_python = repo_path / ".venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = MAIN_DIR / ".venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = Path("python3") + + out = subprocess.run( + [ + str(venv_python), + "-c", + SUBPROCESS_SCRIPT, + str(repo_path), + json.dumps(env), + config_yaml, + ], + capture_output=True, + text=True, + timeout=60, + ) + if out.returncode != 0: + return { + "error": "subprocess failed", + "stdout": out.stdout[-500:], + "stderr": out.stderr[-500:], + } + try: + return json.loads(out.stdout.strip().splitlines()[-1]) + except Exception as exc: + return {"error": f"could not parse output: {exc}", "stdout": out.stdout} + + +def _reduce(shape: dict) -> dict: + """Reduce to the parts that matter for user-visible parity. + + On origin/main, ``explicit-fal-*`` scenarios short-circuit to + ``legacy_fal`` because of the ``configured == "fal"`` skip. On the + PR, those same scenarios route through the plugin and emit + ``dispatch_kind == "plugin"`` with ``provider_name == "fal"``. + + Both shapes are functionally equivalent โ€” the plugin's ``generate()`` + re-enters the same in-tree pipeline via ``_it`` indirection โ€” but + we want the diff to be visible so reviewers can sign off on the + intentional behaviour delta. + """ + return { + "dispatch_kind": shape.get("dispatch_kind"), + "provider_name": shape.get("provider_name"), + "model": shape.get("model"), + "error_present": shape.get("error_present"), + } + + +def main() -> int: + print(f"main: {MAIN_DIR}") + print(f"pr: {PR_DIR}") + print() + + if MAIN_DIR == PR_DIR: + print( + "WARN: MAIN_DIR == PR_DIR โ€” diffs will be trivially identical.\n" + " Set up a sibling 'hermes-agent-main' checkout pinned to " + "origin/main to get real parity coverage." + ) + print() + + failures: list[str] = [] + errors: list[str] = [] + intentional_diffs: list[tuple[str, dict, dict]] = [] + for label, config_yaml, env in SCENARIOS: + main_shape = _run_scenario(MAIN_DIR, label, config_yaml, env) + pr_shape = _run_scenario(PR_DIR, label, config_yaml, env) + + if "error" in main_shape or "error" in pr_shape: + print(f" [ERR ] {label}: subprocess failed") + print(f" main: {main_shape}") + print(f" pr: {pr_shape}") + errors.append(label) + continue + + main_reduced = _reduce(main_shape) + pr_reduced = _reduce(pr_shape) + + if main_reduced == pr_reduced: + print(f" [OK] {label}: {main_reduced}") + continue + + # On main, "explicit-fal-*" returns legacy_fal; on PR, plugin + # dispatch. That's the only acceptable diff โ€” flag everything + # else as a regression. + legacy_to_plugin_fal = ( + main_reduced.get("dispatch_kind") == "legacy_fal" + and pr_reduced.get("dispatch_kind") == "plugin" + and pr_reduced.get("provider_name") == "fal" + ) + if legacy_to_plugin_fal: + print(f" [DIFF] {label}: legacy_fal โ†’ plugin (fal) โ€” expected") + intentional_diffs.append((label, main_reduced, pr_reduced)) + else: + print(f" [FAIL] {label}") + print(f" main: {main_reduced}") + print(f" pr: {pr_reduced}") + failures.append(label) + + print() + if errors: + print(f"SUBPROCESS ERRORS in {len(errors)} scenario(s):") + for e in errors: + print(f" - {e}") + if failures: + print(f"BEHAVIOUR REGRESSION in {len(failures)} scenario(s):") + for f in failures: + print(f" - {f}") + if intentional_diffs: + print( + f"INTENTIONAL DIFFS ({len(intentional_diffs)}): " + f"legacy_fal โ†’ plugin dispatch for explicit FAL paths." + ) + if failures or errors: + return 1 + print(f"PARITY OK across {len(SCENARIOS)} scenarios.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/plugins/image_gen/test_fal_provider.py b/tests/plugins/image_gen/test_fal_provider.py new file mode 100644 index 00000000000..8b3e65e0bae --- /dev/null +++ b/tests/plugins/image_gen/test_fal_provider.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +"""Tests for the FAL.ai image generation plugin. + +The plugin is a thin registration adapter โ€” actual FAL pipeline logic +lives in ``tools.image_generation_tool`` and is exercised by +``tests/tools/test_image_generation.py``. These tests focus on: + +* the ``ImageGenProvider`` ABC surface (name, models, schema) +* call-time indirection (``_it`` resolution at ``generate()`` time so + ``monkeypatch.setattr(image_tool, ...)`` keeps working) +* response shape stamping (provider/prompt/aspect_ratio/model) +""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Provider surface +# --------------------------------------------------------------------------- + + +class TestFalImageGenProviderSurface: + def test_name(self): + from plugins.image_gen.fal import FalImageGenProvider + + assert FalImageGenProvider().name == "fal" + + def test_display_name(self): + from plugins.image_gen.fal import FalImageGenProvider + + assert FalImageGenProvider().display_name == "FAL.ai" + + def test_default_model_matches_legacy(self): + from plugins.image_gen.fal import FalImageGenProvider + from tools.image_generation_tool import DEFAULT_MODEL + + assert FalImageGenProvider().default_model() == DEFAULT_MODEL + + def test_list_models_uses_legacy_catalog(self): + from plugins.image_gen.fal import FalImageGenProvider + from tools.image_generation_tool import FAL_MODELS + + provider = FalImageGenProvider() + models = provider.list_models() + ids = {m["id"] for m in models} + # Whatever FAL_MODELS ships, the provider mirrors verbatim. + assert ids == set(FAL_MODELS.keys()) + # Spot-check the expected first-class fields are present. + for entry in models: + for field in ("id", "display", "speed", "strengths", "price"): + assert field in entry + + def test_setup_schema_advertises_fal_key(self): + from plugins.image_gen.fal import FalImageGenProvider + + schema = FalImageGenProvider().get_setup_schema() + assert schema["name"] == "FAL.ai" + assert schema["badge"] == "paid" + env_keys = {entry["key"] for entry in schema.get("env_vars", [])} + assert "FAL_KEY" in env_keys + + +class TestFalImageGenProviderAvailability: + def test_is_available_when_legacy_check_passes(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + monkeypatch.setattr(image_tool, "check_fal_api_key", lambda: True) + assert FalImageGenProvider().is_available() is True + + def test_is_available_false_when_legacy_check_fails(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + monkeypatch.setattr(image_tool, "check_fal_api_key", lambda: False) + assert FalImageGenProvider().is_available() is False + + def test_is_available_handles_legacy_exception(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + def _boom(): + raise RuntimeError("config broke") + + monkeypatch.setattr(image_tool, "check_fal_api_key", _boom) + # Picker must not propagate exceptions โ€” show as "not available". + assert FalImageGenProvider().is_available() is False + + +# --------------------------------------------------------------------------- +# generate() โ€” call-time indirection +# --------------------------------------------------------------------------- + + +class TestFalImageGenProviderGenerate: + def test_generate_delegates_to_legacy_image_generate_tool(self, monkeypatch): + """Plugin must look up ``image_generate_tool`` at call time so + ``monkeypatch.setattr(image_tool, "image_generate_tool", ...)`` + takes effect.""" + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + captured = {} + + def fake_image_generate_tool(prompt, aspect_ratio, **kwargs): + captured["prompt"] = prompt + captured["aspect_ratio"] = aspect_ratio + captured["kwargs"] = kwargs + return json.dumps({"success": True, "image": "https://fake/image.png"}) + + monkeypatch.setattr(image_tool, "image_generate_tool", fake_image_generate_tool) + monkeypatch.setattr(image_tool, "_resolve_fal_model", + lambda: ("fal-ai/flux-2/klein/9b", {})) + + result = FalImageGenProvider().generate( + "a serene mountain landscape", + aspect_ratio="square", + seed=42, + ) + + assert captured["prompt"] == "a serene mountain landscape" + assert captured["aspect_ratio"] == "square" + assert captured["kwargs"] == {"seed": 42} + assert result["success"] is True + assert result["image"] == "https://fake/image.png" + # Stamped fields for the unified response shape + assert result["provider"] == "fal" + assert result["prompt"] == "a serene mountain landscape" + assert result["aspect_ratio"] == "square" + assert result["model"] == "fal-ai/flux-2/klein/9b" + + def test_generate_invalid_aspect_ratio_is_coerced(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + seen_aspect = {} + + def fake(prompt, aspect_ratio, **kwargs): + seen_aspect["v"] = aspect_ratio + return json.dumps({"success": True, "image": "x"}) + + monkeypatch.setattr(image_tool, "image_generate_tool", fake) + monkeypatch.setattr(image_tool, "_resolve_fal_model", + lambda: ("fal-ai/flux-2/klein/9b", {})) + + FalImageGenProvider().generate("p", aspect_ratio="not-a-real-ratio") + # ``resolve_aspect_ratio`` clamps to landscape. + assert seen_aspect["v"] == "landscape" + + def test_generate_passthrough_drops_none_kwargs(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + seen = {} + + def fake(prompt, aspect_ratio, **kwargs): + seen.update(kwargs) + return json.dumps({"success": True, "image": "x"}) + + monkeypatch.setattr(image_tool, "image_generate_tool", fake) + monkeypatch.setattr(image_tool, "_resolve_fal_model", + lambda: ("fal-ai/flux-2/klein/9b", {})) + + FalImageGenProvider().generate( + "p", + aspect_ratio="landscape", + seed=None, + num_images=2, + guidance_scale=None, + ) + + # ``None`` values must not be forwarded โ€” they'd override the + # model's defaults inside the legacy payload builder. + assert "seed" not in seen + assert "guidance_scale" not in seen + assert seen.get("num_images") == 2 + + def test_generate_catches_exception_from_legacy(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + def boom(*args, **kwargs): + raise RuntimeError("FAL endpoint exploded") + + monkeypatch.setattr(image_tool, "image_generate_tool", boom) + + result = FalImageGenProvider().generate("p") + assert result["success"] is False + assert "FAL image generation failed" in result["error"] + assert result["error_type"] == "RuntimeError" + assert result["provider"] == "fal" + + def test_generate_invalid_json_response(self, monkeypatch): + import tools.image_generation_tool as image_tool + from plugins.image_gen.fal import FalImageGenProvider + + monkeypatch.setattr(image_tool, "image_generate_tool", lambda **kw: "not-json") + monkeypatch.setattr(image_tool, "_resolve_fal_model", + lambda: ("fal-ai/flux-2/klein/9b", {})) + + result = FalImageGenProvider().generate("p") + assert result["success"] is False + assert "Invalid JSON" in result["error"] + assert result["provider"] == "fal" + + +# --------------------------------------------------------------------------- +# Registry wiring +# --------------------------------------------------------------------------- + + +class TestFalImageGenPluginRegistration: + def test_register_wires_provider_into_registry(self): + from plugins.image_gen.fal import FalImageGenProvider, register + + ctx = MagicMock() + register(ctx) + + ctx.register_image_gen_provider.assert_called_once() + (registered,), _ = ctx.register_image_gen_provider.call_args + assert isinstance(registered, FalImageGenProvider) diff --git a/tests/plugins/model_providers/test_opencode_go_profile.py b/tests/plugins/model_providers/test_opencode_go_profile.py new file mode 100644 index 00000000000..7e6b5c8f64c --- /dev/null +++ b/tests/plugins/model_providers/test_opencode_go_profile.py @@ -0,0 +1,180 @@ +"""Unit tests for OpenCode Go reasoning-control wiring.""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture +def opencode_go_profile(): + """Resolve the registered OpenCode Go provider profile.""" + import model_tools # noqa: F401 + import providers + + profile = providers.get_provider_profile("opencode-go") + assert profile is not None, "opencode-go provider profile must be registered" + return profile + + +class TestOpenCodeGoKimiReasoning: + """Kimi K2 models use Moonshot's thinking + reasoning_effort shape on OpenCode Go.""" + + def test_high_effort_emits_thinking_and_effort(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": "high"}, + model="kimi-k2.6", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {"reasoning_effort": "high"} + + def test_disabled_emits_thinking_disabled_without_effort(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": False}, + model="kimi-k2.6", + ) + assert extra_body == {"thinking": {"type": "disabled"}} + assert top_level == {} + + def test_minimal_effort_enables_thinking_without_effort(self, opencode_go_profile): + # "minimal" is not a Moonshot-supported value โ€” drop it, keep thinking on. + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": "minimal"}, + model="kimi-k2.6", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {} + + @pytest.mark.parametrize( + "effort", + [ + "xhigh", + "max", + ], + ) + def test_strong_efforts_clamp_to_high(self, opencode_go_profile, effort): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": effort}, + model="moonshotai/kimi-k2.6", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {"reasoning_effort": "high"} + + def test_low_and_medium_pass_through(self, opencode_go_profile): + for effort in ("low", "medium"): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": effort}, + model="kimi-k2.5", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {"reasoning_effort": effort} + + def test_no_config_preserves_server_default(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config=None, + model="kimi-k2.6", + ) + assert extra_body == {} + assert top_level == {} + + +class TestOpenCodeGoDeepSeekThinking: + """DeepSeek V4 models use DeepSeek-style thinking controls on OpenCode Go.""" + + def test_high_effort_emits_thinking_and_effort(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": "high"}, + model="deepseek-v4-pro", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {"reasoning_effort": "high"} + + def test_disabled_emits_thinking_disabled_without_effort(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": False, "effort": "high"}, + model="deepseek-v4-pro", + ) + assert extra_body == {"thinking": {"type": "disabled"}} + assert top_level == {} + + def test_no_config_emits_thinking_enabled_without_effort(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config=None, + model="deepseek-v4-pro", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {} + + def test_minimal_effort_enables_thinking_without_effort(self, opencode_go_profile): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": "minimal"}, + model="deepseek-v4-pro", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {} + + def test_xhigh_and_max_normalize_to_max(self, opencode_go_profile): + for effort in ("xhigh", "max"): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": effort}, + model="deepseek/deepseek-v4-pro", + ) + assert extra_body == {"thinking": {"type": "enabled"}} + assert top_level == {"reasoning_effort": "max"} + + +class TestOpenCodeGoModelGating: + """Other OpenCode Go models must not receive Kimi/DeepSeek controls.""" + + @pytest.mark.parametrize( + "model", + [ + "glm-5.1", + "qwen3.6-plus", + "minimax-m2.7", + "deepseek-v3.1", + "deepseek-chat", + "", + None, + ], + ) + def test_non_target_models_emit_nothing(self, opencode_go_profile, model): + extra_body, top_level = opencode_go_profile.build_api_kwargs_extras( + reasoning_config={"enabled": True, "effort": "high"}, + model=model, + ) + assert extra_body == {} + assert top_level == {} + + +class TestOpenCodeGoFullKwargsIntegration: + """End-to-end transport kwargs include the profile-provided controls.""" + + def test_kimi_reasoning_reaches_extra_body_and_top_level(self, opencode_go_profile): + from agent.transports.chat_completions import ChatCompletionsTransport + + kwargs = ChatCompletionsTransport().build_kwargs( + model="kimi-k2.6", + messages=[{"role": "user", "content": "ping"}], + tools=None, + provider_profile=opencode_go_profile, + reasoning_config={"enabled": True, "effort": "high"}, + base_url="https://opencode.ai/zen/go/v1", + ) + assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}} + assert kwargs["reasoning_effort"] == "high" + + def test_deepseek_thinking_reaches_extra_body_and_top_level( + self, opencode_go_profile + ): + from agent.transports.chat_completions import ChatCompletionsTransport + + kwargs = ChatCompletionsTransport().build_kwargs( + model="deepseek-v4-pro", + messages=[{"role": "user", "content": "ping"}], + tools=None, + provider_profile=opencode_go_profile, + reasoning_config={"enabled": True, "effort": "high"}, + base_url="https://opencode.ai/zen/go/v1", + ) + assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}} + assert kwargs["reasoning_effort"] == "high" diff --git a/tests/run_agent/test_31273_402_not_retried.py b/tests/run_agent/test_31273_402_not_retried.py new file mode 100644 index 00000000000..bae4af45733 --- /dev/null +++ b/tests/run_agent/test_31273_402_not_retried.py @@ -0,0 +1,147 @@ +"""Regression guard for #31273: HTTP 402 (billing exhaustion) must abort +after credential-pool rotation and provider fallback have failed. + +Before the fix, ``FailoverReason.billing`` was in the exclusion set that +prevents the loop's ``is_client_error`` branch from firing. When a user +ran a pay-per-token provider (OpenRouter, etc.) with no credential pool +and no fallback configured, a single 402 cascaded into +``agent.api_max_retries`` paid requests against an exhausted balance. +Real-world impact: ~$40 burned in 48h on a 24/7 gateway routing Telegram ++ Discord traffic. + +The fix removes ``FailoverReason.billing`` from the exclusion set. By +the time control reaches the ``is_client_error`` check: + * credential-pool rotation has already run (and either ``continue``d + on rotation, or returned False because the pool is exhausted/absent). + * the eager-fallback branch for billing has also run (and either + ``continue``d on fallback activation, or fell through because no + fallback is configured). +Falling through to the retry-backoff path from here just burns paid +requests with no recovery mechanism left. Aborting mirrors how 401/403 +(also ``should_fallback=True``) already behave once their recovery paths +have failed. +""" +from __future__ import annotations + + +class TestBillingTriggersClientErrorAbort: + """Mirror the ``is_client_error`` predicate shape used in + ``agent/conversation_loop.py`` and verify ``FailoverReason.billing`` + now resolves to True (i.e. aborts the loop). + """ + + def _mirror_is_client_error( + self, + *, + classified_retryable: bool, + classified_reason, + classified_should_compress: bool = False, + is_local_validation_error: bool = False, + is_context_length_error: bool = False, + ) -> bool: + """Exact shape of conversation_loop.py's is_client_error check. + + Kept in lock-step with the source. If you change one, change + both โ€” or, better, refactor the predicate into a shared helper + and have both sites import it. + """ + from agent.error_classifier import FailoverReason + + return ( + is_local_validation_error + or ( + not classified_retryable + and not classified_should_compress + and classified_reason not in { + FailoverReason.rate_limit, + FailoverReason.overloaded, + FailoverReason.context_overflow, + FailoverReason.payload_too_large, + FailoverReason.long_context_tier, + FailoverReason.thinking_signature, + } + ) + ) and not is_context_length_error + + def test_billing_now_aborts_the_loop(self): + """402 with no fallback / no pool entry โ†’ ``is_client_error`` True.""" + from agent.error_classifier import FailoverReason + + # This is what classify_api_error() returns for a plain 402: + # reason=billing, retryable=False, should_compress=False + assert self._mirror_is_client_error( + classified_retryable=False, + classified_reason=FailoverReason.billing, + ), ( + "FailoverReason.billing must trigger is_client_error abort after " + "credential-pool rotation and provider fallback have failed โ€” see #31273." + ) + + def test_rate_limit_still_retries(self): + """Sanity check: rate_limit must still fall through to backoff retry.""" + from agent.error_classifier import FailoverReason + + # 429 / transient 402 / rate-limited usage: must NOT abort, + # because Retry-After backoff and pool rotation are the right + # recovery paths. + assert not self._mirror_is_client_error( + classified_retryable=True, + classified_reason=FailoverReason.rate_limit, + ) + + def test_local_validation_error_still_aborts(self): + """Sanity check: bare ValueError/TypeError still abort.""" + from agent.error_classifier import FailoverReason + + assert self._mirror_is_client_error( + classified_retryable=True, + classified_reason=FailoverReason.unknown, + is_local_validation_error=True, + ) + + def test_context_overflow_still_falls_through_to_compression(self): + """Sanity check: context-overflow must NOT be classified as + client error โ€” compression is the recovery path.""" + from agent.error_classifier import FailoverReason + + assert not self._mirror_is_client_error( + classified_retryable=True, + classified_reason=FailoverReason.context_overflow, + classified_should_compress=True, + ) + + +class TestSourceStillHasBillingExclusionRemoved: + """Belt-and-suspenders: the production source must actually omit + ``FailoverReason.billing`` from the ``is_client_error`` exclusion + set. Protects against an accidental re-introduction. + """ + + def test_conversation_loop_omits_billing_from_client_error_exclusion(self): + import inspect + from agent import conversation_loop + + src = inspect.getsource(conversation_loop) + + # Locate the is_client_error block and inspect its exclusion set. + marker = "is_client_error = (" + assert marker in src, ( + "agent/conversation_loop.py must define is_client_error โ€” " + "the bug-fix anchor for #31273 has moved or been renamed." + ) + idx = src.index(marker) + # Window large enough to span the full predicate (~30 lines). + window = src[idx:idx + 2000] + + assert "FailoverReason.rate_limit" in window, ( + "is_client_error exclusion set has changed shape โ€” re-verify " + "that FailoverReason.billing is still NOT in it (#31273)." + ) + assert "FailoverReason.billing" not in window, ( + "FailoverReason.billing must NOT appear in the is_client_error " + "exclusion set โ€” see #31273. Billing (HTTP 402) is non-retryable " + "by the time control reaches this block: credential-pool rotation " + "and provider fallback have both already had their chance to " + "continue the loop. Re-adding it causes runaway token spend on " + "depleted balances." + ) diff --git a/tests/run_agent/test_codex_xai_oauth_recovery.py b/tests/run_agent/test_codex_xai_oauth_recovery.py index 585be09ab4d..a0d8656eabb 100644 --- a/tests/run_agent/test_codex_xai_oauth_recovery.py +++ b/tests/run_agent/test_codex_xai_oauth_recovery.py @@ -621,6 +621,246 @@ def test_recover_with_credential_pool_still_refreshes_genuine_auth_failure(): assert refresh_calls["n"] == 1 +# --------------------------------------------------------------------------- +# Fix D-bis: bad-credentials 403 must NOT be classified as entitlement (#29344) +# +# xAI returns the same permission-denied ``code`` text for two distinct +# conditions: unsubscribed account vs. stale OAuth access token. The +# ``error`` field's ``[WKE=unauthenticated:...]`` suffix (and the +# accompanying "OAuth2 access token could not be validated" phrasing) is +# xAI's authoritative disambiguator โ€” when present, the body is an auth +# failure, not entitlement, and the credential-pool refresh path must +# run. Pre-fix, long-running TUI sessions stuck on a stale token +# surfaced as a non-retryable client error; the workaround was to exit +# and reopen the TUI so the startup-resolve path refreshed. +# --------------------------------------------------------------------------- + + +def test_is_entitlement_failure_false_for_bad_credentials_wke_suffix(): + """403 with ``[WKE=unauthenticated:bad-credentials]`` is auth, not entitlement. + + Verbatim shape from the #29344 reporter โ€” the ``code`` text matches + the entitlement permission-denied heuristic, but the ``error`` field + carries xAI's explicit "this is a credential validation failure" + signal. Classifier must honor it. + """ + from run_agent import AIAgent + + assert not AIAgent._is_entitlement_failure( + { + "code": "The caller does not have permission to execute the specified operation", + "error": "The OAuth2 access token could not be validated. [WKE=unauthenticated:bad-credentials]", + }, + 403, + ) + + +def test_is_entitlement_failure_false_for_wke_suffix_in_normalized_shape(): + """The same body after ``_extract_api_error_context`` normalisation. + + Real runtime paths feed the classifier through + ``_extract_api_error_context``, which converts the raw body to + ``{message, reason, reset_at}``. The disambiguator must fire in + BOTH the raw-body shape (test above) and the normalised shape so + the fix actually reaches the production call site at + ``_recover_with_credential_pool``. + """ + from run_agent import AIAgent + + assert not AIAgent._is_entitlement_failure( + { + "reason": "The caller does not have permission to execute the specified operation", + "message": "The OAuth2 access token could not be validated. [WKE=unauthenticated:bad-credentials]", + }, + 403, + ) + + +@pytest.mark.parametrize("wke_variant", [ + # The headline variant โ€” what xAI returns today. + "[WKE=unauthenticated:bad-credentials]", + # Forward-compat: xAI documents the WKE prefix as a stable shape, + # the suffix after the colon is the "reason code" and could grow + # new values. Anything under ``unauthenticated:`` must route to + # the refresh path. + "[WKE=unauthenticated:expired-token]", + "[WKE=unauthenticated:revoked]", + "[WKE=unauthenticated:some-future-reason]", +]) +def test_is_entitlement_failure_false_for_any_wke_unauthenticated_variant(wke_variant): + from run_agent import AIAgent + + assert not AIAgent._is_entitlement_failure( + { + "code": "The caller does not have permission to execute the specified operation", + "error": f"Token rejected. {wke_variant}", + }, + 403, + ) + + +def test_is_entitlement_failure_false_via_oauth2_validation_phrase_alone(): + """Second disambiguator: the "OAuth2 access token could not be + validated" phrase by itself (no WKE suffix) must also route to + refresh. This is a belt-and-braces guard against xAI dropping or + reformatting the WKE suffix in a future API revision without + changing the human-readable error text.""" + from run_agent import AIAgent + + assert not AIAgent._is_entitlement_failure( + { + "code": "The caller does not have permission to execute the specified operation", + "error": "The OAuth2 access token could not be validated.", + }, + 403, + ) + + +def test_is_entitlement_failure_wke_signal_overrides_entitlement_keywords(): + """Defensive: if a future xAI body somehow carries BOTH the WKE + suffix AND entitlement language, the WKE signal wins. Auth is + recoverable; entitlement isn't. If the refreshed token still + can't access the resource, the next 403 (without WKE) lands on + the entitlement path correctly.""" + from run_agent import AIAgent + + assert not AIAgent._is_entitlement_failure( + { + "code": "The caller does not have permission to execute the specified operation", + "error": ( + "do not have an active Grok subscription. " + "[WKE=unauthenticated:bad-credentials]" + ), + }, + 403, + ) + + +def test_is_entitlement_failure_case_insensitive_wke_match(): + """Substring match is case-insensitive โ€” the classifier lowercases + everything before matching, so a future xAI build that uppercases + the prefix wouldn't reintroduce the misclassification.""" + from run_agent import AIAgent + + assert not AIAgent._is_entitlement_failure( + { + "code": "The caller does not have permission to execute the specified operation", + "error": "[wke=Unauthenticated:Bad-Credentials]", + }, + 403, + ) + + +def test_recover_with_credential_pool_refreshes_on_xai_bad_credentials_403(): + """End-to-end #29344: a bad-credentials 403 from xai-oauth MUST + call ``try_refresh_current()`` so the long-running TUI session + recovers without an exit/reopen cycle. + + Mirrors the scaffolding of + ``test_recover_with_credential_pool_still_refreshes_genuine_auth_failure`` + but with the exact 403 body shape xAI ships for stale tokens โ€” + the very body that pre-fix tripped the entitlement classifier + and short-circuited the refresh path. + """ + from run_agent import AIAgent + from agent.error_classifier import FailoverReason + + agent = _make_codex_agent() + + refresh_calls = {"n": 0} + + class _FakePool: + def try_refresh_current(self): + refresh_calls["n"] += 1 + entry = MagicMock() + entry.id = "entry_refreshed_after_stale" + return entry + + def mark_exhausted_and_rotate(self, **_kwargs): + return None + + def has_available(self): + return False + + agent._credential_pool = _FakePool() + agent._swap_credential = MagicMock() + + # Normalised shape that ``_extract_api_error_context`` would + # produce for the reporter's wire-level body. + error_context = { + "reason": ( + "The caller does not have permission to execute the specified operation" + ), + "message": ( + "The OAuth2 access token could not be validated. " + "[WKE=unauthenticated:bad-credentials]" + ), + } + + recovered, _retried_429 = agent._recover_with_credential_pool( + status_code=403, + has_retried_429=False, + classified_reason=FailoverReason.auth, + error_context=error_context, + ) + + assert recovered is True, ( + "Stale OAuth token (bad-credentials 403) must trigger refresh โ€” " + "pre-fix this returned False because the entitlement classifier " + "over-matched on the permission-denied code text" + ) + assert refresh_calls["n"] == 1, "try_refresh_current must run exactly once" + agent._swap_credential.assert_called_once() + + +def test_recover_with_credential_pool_still_blocks_real_entitlement(): + """Companion regression guard for the #29344 fix: the original + #26847 protection โ€” entitlement 403 must NOT refresh โ€” must + survive the new disambiguator. A real unsubscribed-account body + has no WKE suffix and no OAuth2-validation phrase, so the + classifier still classifies it as entitlement and short-circuits.""" + from run_agent import AIAgent + from agent.error_classifier import FailoverReason + + agent = _make_codex_agent() + + refresh_calls = {"n": 0} + + class _FakePool: + def try_refresh_current(self): + refresh_calls["n"] += 1 + return MagicMock(id="should_not_be_called") + + def mark_exhausted_and_rotate(self, **_kwargs): + return None + + def has_available(self): + return False + + agent._credential_pool = _FakePool() + + # Pure entitlement body โ€” no WKE suffix, no OAuth2 phrase. + error_context = { + "reason": ( + "The caller does not have permission to execute the specified operation" + ), + "message": ( + "You have either run out of available resources or do not have an " + "active Grok subscription. Manage at https://grok.com" + ), + } + + recovered, _retried_429 = agent._recover_with_credential_pool( + status_code=403, + has_retried_429=False, + classified_reason=FailoverReason.auth, + error_context=error_context, + ) + + assert recovered is False, "Entitlement 403 must surface, not refresh" + assert refresh_calls["n"] == 0 + + # --------------------------------------------------------------------------- # Fix E: grok-4.3 context length must be 1M, not 256K # --------------------------------------------------------------------------- diff --git a/tests/run_agent/test_create_openai_client_reuse.py b/tests/run_agent/test_create_openai_client_reuse.py index 13d95a46634..8b39711b3e4 100644 --- a/tests/run_agent/test_create_openai_client_reuse.py +++ b/tests/run_agent/test_create_openai_client_reuse.py @@ -190,7 +190,13 @@ def test_replace_primary_openai_client_survives_repeated_rebuilds(): def test_force_close_tcp_sockets_descends_httpcore_1_connection_wrapper(): - """httpcore 1.x stores the real stream below conn._connection.""" + """httpcore 1.x stores the real stream below conn._connection. + + Post-#29507: the helper must shut sockets down but must NOT release the + FD via ``sock.close()`` โ€” that race recycled FDs into unrelated file + descriptors (kanban.db) and let TLS bytes overwrite SQLite headers. The + owning httpx thread is responsible for closing FDs on its own unwind. + """ from agent.agent_runtime_helpers import force_close_tcp_sockets class FakeSocket: @@ -215,4 +221,6 @@ def test_force_close_tcp_sockets_descends_httpcore_1_connection_wrapper(): assert force_close_tcp_sockets(openai_client) == 1 assert sock.shutdown_calls == 1 - assert sock.close_calls == 1 + # #29507: close() must NOT be called from this helper โ€” the owning + # httpx worker thread releases the FD, not us. + assert sock.close_calls == 0 diff --git a/tests/run_agent/test_multimodal_tool_content_recovery.py b/tests/run_agent/test_multimodal_tool_content_recovery.py new file mode 100644 index 00000000000..63ee49f97c0 --- /dev/null +++ b/tests/run_agent/test_multimodal_tool_content_recovery.py @@ -0,0 +1,260 @@ +"""Tests for reactive multimodal-tool-content recovery. + +Covers the full chain for providers that reject list-type content in +``role: "tool"`` messages (Xiaomi MiMo's 400 "text is not set", etc.): + + 1. agent/error_classifier.py: 400 with the right wording classifies as + ``FailoverReason.multimodal_tool_content_unsupported``. + 2. run_agent._try_strip_image_parts_from_tool_messages downgrades tool + messages whose ``content`` is a list-with-image to a string text + summary, in-place, and records the active (provider, model) in + ``self._no_list_tool_content_models`` so future tool results in this + session preemptively downgrade. + 3. run_agent._tool_result_content_for_active_model short-circuits to a + text summary when the (provider, model) is in the cache, even though + ``_model_supports_vision`` returns True โ€” avoiding a wasted round + trip on every subsequent screenshot in the session. + +The end-to-end retry loop wiring (`conversation_loop.py`) is exercised by +the classifier signal + helper-mutation tests; the integration only adds +a trivial flag-and-continue around the existing pattern used for +``image_too_large`` recovery. + +See: https://github.com/NousResearch/hermes-agent/issues/27344 +""" + +from __future__ import annotations + +import pytest + +from agent.error_classifier import FailoverReason, classify_api_error + + +class _FakeApiError(Exception): + """Stand-in for an openai.BadRequestError with status_code + body.""" + + def __init__(self, status_code: int, message: str, body: dict | None = None): + super().__init__(message) + self.status_code = status_code + self.body = body or {"error": {"message": message}} + self.response = None + + +def _make_agent(provider: str = "xiaomi", model: str = "mimo-v2.5"): + """Build a bare AIAgent for method-level testing, no provider setup.""" + from run_agent import AIAgent + agent = object.__new__(AIAgent) + agent.provider = provider + agent.model = model + return agent + + +# โ”€โ”€โ”€ Strip helper โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestStripImagePartsHelper: + def test_no_messages_returns_false(self): + agent = _make_agent() + assert agent._try_strip_image_parts_from_tool_messages([]) is False + assert agent._try_strip_image_parts_from_tool_messages(None) is False + + def test_no_tool_messages_returns_false(self): + agent = _make_agent() + msgs = [ + {"role": "user", "content": "plain text"}, + {"role": "assistant", "content": "ack"}, + ] + assert agent._try_strip_image_parts_from_tool_messages(msgs) is False + + def test_tool_message_with_string_content_unchanged(self): + agent = _make_agent() + msgs = [ + {"role": "tool", "tool_call_id": "x", "content": "plain string result"}, + ] + assert agent._try_strip_image_parts_from_tool_messages(msgs) is False + assert msgs[0]["content"] == "plain string result" + + def test_tool_message_list_without_image_unchanged(self): + """List content with only text parts is left alone โ€” caller surfaces + the original error if this turns out to also be rejected.""" + agent = _make_agent() + msgs = [ + {"role": "tool", "tool_call_id": "x", "content": [ + {"type": "text", "text": "hello"}, + ]}, + ] + assert agent._try_strip_image_parts_from_tool_messages(msgs) is False + + def test_tool_message_list_with_image_downgrades(self): + agent = _make_agent() + msgs = [ + {"role": "tool", "tool_call_id": "x", "content": [ + {"type": "text", "text": "AX summary: 5 buttons visible"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}, + ]}, + ] + assert agent._try_strip_image_parts_from_tool_messages(msgs) is True + # Image stripped; text preserved as a string. + assert isinstance(msgs[0]["content"], str) + assert "AX summary" in msgs[0]["content"] + assert "image_url" not in msgs[0]["content"] + assert "iVBOR" not in msgs[0]["content"] + + def test_tool_message_image_only_gets_placeholder(self): + """If the list had nothing but image parts, leave a placeholder so + the assistant message has something to reference.""" + agent = _make_agent() + msgs = [ + {"role": "tool", "tool_call_id": "x", "content": [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}, + ]}, + ] + assert agent._try_strip_image_parts_from_tool_messages(msgs) is True + assert isinstance(msgs[0]["content"], str) + assert "image content removed" in msgs[0]["content"] + + def test_records_provider_model_in_session_cache(self): + agent = _make_agent(provider="xiaomi", model="mimo-v2.5") + msgs = [ + {"role": "tool", "tool_call_id": "x", "content": [ + {"type": "text", "text": "summary"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,X"}}, + ]}, + ] + agent._try_strip_image_parts_from_tool_messages(msgs) + assert ("xiaomi", "mimo-v2.5") in agent._no_list_tool_content_models + + def test_only_tool_messages_get_downgraded(self): + """User / assistant messages with list-type content are out of + scope โ€” they're handled by the existing image-routing path.""" + agent = _make_agent() + msgs = [ + {"role": "user", "content": [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,X"}}, + ]}, + {"role": "tool", "tool_call_id": "x", "content": [ + {"type": "text", "text": "summary"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,Y"}}, + ]}, + ] + agent._try_strip_image_parts_from_tool_messages(msgs) + # User message untouched. + assert isinstance(msgs[0]["content"], list) + assert any(p.get("type") == "image_url" for p in msgs[0]["content"]) + # Tool message downgraded. + assert isinstance(msgs[1]["content"], str) + assert "summary" in msgs[1]["content"] + + def test_skips_recording_when_no_model_id(self): + """Don't poison the cache with empty keys when provider/model is + unset (e.g. lazy-initialised mid-handshake).""" + agent = _make_agent(provider="", model="") + msgs = [ + {"role": "tool", "tool_call_id": "x", "content": [ + {"type": "text", "text": "summary"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,X"}}, + ]}, + ] + agent._try_strip_image_parts_from_tool_messages(msgs) + assert agent._no_list_tool_content_models == set() + + +# โ”€โ”€โ”€ Short-circuit on cached models โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestToolResultContentShortCircuit: + """Once the session has learned that (provider, model) rejects list + content, ``_tool_result_content_for_active_model`` returns a text + summary even though ``_model_supports_vision`` reports True. + """ + + def _multimodal_result(self, png_b64: str = "iVBORw0KGgoAAAA"): + return { + "_multimodal": True, + "content": [ + {"type": "text", "text": "capture mode=som 800x600 app=Safari"}, + {"type": "image_url", + "image_url": {"url": f"data:image/png;base64,{png_b64}"}}, + ], + "text_summary": "capture mode=som 800x600 app=Safari", + "meta": {"mode": "som", "width": 800, "height": 600, "elements": 5, + "png_bytes": 1024}, + } + + def test_returns_list_when_cache_empty_and_vision_supported(self, monkeypatch): + agent = _make_agent(provider="xiaomi", model="mimo-v2.5") + agent._no_list_tool_content_models = set() # explicit empty + monkeypatch.setattr(agent, "_model_supports_vision", lambda: True) + out = agent._tool_result_content_for_active_model( + "computer_use", self._multimodal_result() + ) + # Native multimodal path: returns the content parts list. + assert isinstance(out, list) + assert any(p.get("type") == "image_url" for p in out) + + def test_returns_text_summary_when_model_in_cache(self, monkeypatch): + agent = _make_agent(provider="xiaomi", model="mimo-v2.5") + agent._no_list_tool_content_models = {("xiaomi", "mimo-v2.5")} + monkeypatch.setattr(agent, "_model_supports_vision", lambda: True) + out = agent._tool_result_content_for_active_model( + "computer_use", self._multimodal_result() + ) + # Short-circuit: a plain string summary, no image_url present. + assert isinstance(out, str) + assert "data:image" not in out + assert "image_url" not in out + + def test_cache_miss_on_different_model(self, monkeypatch): + """Cache is per (provider, model). A cached entry for mimo-v2.5 + must NOT affect a session running on a different model. + """ + agent = _make_agent(provider="xiaomi", model="mimo-v2.5-pro") + agent._no_list_tool_content_models = {("xiaomi", "mimo-v2.5")} + monkeypatch.setattr(agent, "_model_supports_vision", lambda: True) + out = agent._tool_result_content_for_active_model( + "computer_use", self._multimodal_result() + ) + assert isinstance(out, list) + + def test_missing_cache_attribute_falls_through(self, monkeypatch): + """Tests that build agents via ``object.__new__`` without calling + ``__init__`` must not crash โ€” the cache attribute may be absent. + """ + agent = _make_agent() + # Deliberately do not assign _no_list_tool_content_models. + monkeypatch.setattr(agent, "_model_supports_vision", lambda: True) + out = agent._tool_result_content_for_active_model( + "computer_use", self._multimodal_result() + ) + assert isinstance(out, list) + + +# โ”€โ”€โ”€ Classifier โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + +class TestRecoveryEndToEndClassification: + """Lock in that the patterns used by the recovery path classify to + the right ``FailoverReason``. (The recovery hook in + ``agent.conversation_loop`` consumes this reason directly.) + """ + + def test_xiaomi_mimo_classifies(self): + err = _FakeApiError( + status_code=400, + message=( + "Error code: 400 - {'error': {'code': '400', 'message': " + "'Param Incorrect', 'param': 'text is not set', 'type': ''}}" + ), + ) + result = classify_api_error(err, provider="xiaomi", model="mimo-v2.5") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported + assert result.retryable is True + + def test_alibaba_variant_classifies(self): + err = _FakeApiError( + status_code=400, + message="tool_call.content must be string", + ) + result = classify_api_error(err, provider="alibaba", model="qwen3.5-plus") + assert result.reason == FailoverReason.multimodal_tool_content_unsupported diff --git a/tests/run_agent/test_partial_stream_finish_reason.py b/tests/run_agent/test_partial_stream_finish_reason.py new file mode 100644 index 00000000000..f6948844f43 --- /dev/null +++ b/tests/run_agent/test_partial_stream_finish_reason.py @@ -0,0 +1,258 @@ +"""Regression tests for issue #30963 โ€” partial-stream stub finish_reason. + +Pins the contract: + +- text-only partial stream โ†’ stub.finish_reason == "length" so the + conversation loop's existing length-continuation path can keep the + agent moving against an unfinished goal. +- partial mid-tool-call โ†’ stub.finish_reason == "stop" so the loop + hands control back to the user (matches the user-visible warning + "Ask me to retry if you want to continue"). +- conversation_loop's length-continuation prompt distinguishes a real + output-length truncation from a partial-stream-stub network error + via response.id. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +# โ”€โ”€ Helpers (mirrors test_streaming.py) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +def _make_stream_chunk(content=None, tool_calls=None, finish_reason=None): + delta = SimpleNamespace( + content=content, tool_calls=tool_calls, + reasoning_content=None, reasoning=None, + ) + choice = SimpleNamespace(index=0, delta=delta, finish_reason=finish_reason) + return SimpleNamespace(choices=[choice], model=None, usage=None) + + +def _make_tool_call_delta(index=0, tc_id=None, name=None, arguments=None): + func = SimpleNamespace(name=name, arguments=arguments) + return SimpleNamespace(index=index, id=tc_id, function=func) + + +def _make_agent(): + from run_agent import AIAgent + agent = AIAgent( + api_key="test-key", + base_url="https://example.com/v1", + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + return agent + + +# โ”€โ”€ Stub finish_reason โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestPartialStreamStubFinishReason: + """The stub returned by interruptible_streaming_api_call when the + upstream connection dies mid-flight.""" + + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_text_only_partial_returns_length(self, _mock_close, mock_create, monkeypatch): + """#30963: text-only partials must classify as length so the loop + keeps continuing instead of exiting with budget remaining.""" + + def _stalling_stream(): + yield _make_stream_chunk(content="Here's my answer so far") + raise RuntimeError("simulated upstream stall") + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = lambda *a, **kw: _stalling_stream() + mock_create.return_value = mock_client + + agent = _make_agent() + agent._current_streamed_assistant_text = "Here's my answer so far" + + monkeypatch.setenv("HERMES_STREAM_RETRIES", "0") + response = agent._interruptible_streaming_api_call({}) + + assert response.id == "partial-stream-stub" + assert response.choices[0].finish_reason == "length", ( + "Text-only partial streams must use finish_reason=length so the " + "conversation loop continues from where the network died " + "(issue #30963)." + ) + assert response.choices[0].message.content == "Here's my answer so far" + assert response.choices[0].message.tool_calls is None + + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_partial_tool_call_keeps_stop(self, _mock_close, mock_create, monkeypatch): + """Mid-tool-call partials keep finish_reason=stop on purpose โ€” the + warning text asks the user to drive the retry, not the agent.""" + + def _stalling_stream(): + yield _make_stream_chunk(content="Let me write the audit: ") + yield _make_stream_chunk(tool_calls=[ + _make_tool_call_delta(index=0, tc_id="call_1", name="write_file"), + ]) + yield _make_stream_chunk(tool_calls=[ + _make_tool_call_delta(index=0, arguments='{"path": "/tmp/x", '), + ]) + raise RuntimeError("simulated upstream stall") + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = lambda *a, **kw: _stalling_stream() + mock_create.return_value = mock_client + + agent = _make_agent() + agent._fire_stream_delta = lambda text: None + agent._current_streamed_assistant_text = "Let me write the audit: " + + monkeypatch.setenv("HERMES_STREAM_RETRIES", "0") + response = agent._interruptible_streaming_api_call({}) + + assert response.id == "partial-stream-stub" + assert response.choices[0].finish_reason == "stop", ( + "Partial mid-tool-call must keep finish_reason=stop โ€” the warning " + "appended to content asks the user to retry, so the agent must " + "not auto-replay a tool call with possible side-effects." + ) + content = response.choices[0].message.content or "" + assert "Stream stalled mid tool-call" in content + assert "write_file" in content + + +# โ”€โ”€ Length-continuation prompt branching โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +class TestLengthContinuationPromptBranching: + """When finish_reason=length, the continuation prompt that reaches the + model has to tell the truth: real truncation vs. network interruption. + Lying ("you were truncated") on a partial-stream stub leads the model + to no-op ("I wasn't truncated, I'm done"), defeating recovery.""" + + def _simulate_branch(self, response_id: str) -> str: + """Return the continuation prompt text the loop would inject for + a `finish_reason=length` response with the given id. Mirrors the + exact branch in agent/conversation_loop.py.""" + response = SimpleNamespace(id=response_id) + if getattr(response, "id", "") == "partial-stream-stub": + return ( + "[System: The previous response was cut off by a " + "network error mid-stream. Continue exactly where " + "you left off. Do not restart or repeat prior text. " + "Finish the answer directly.]" + ) + return ( + "[System: Your previous response was truncated by the output " + "length limit. Continue exactly where you left off. Do not " + "restart or repeat prior text. Finish the answer directly.]" + ) + + def test_partial_stream_stub_uses_network_prompt(self): + prompt = self._simulate_branch("partial-stream-stub") + assert "network error mid-stream" in prompt + assert "output length limit" not in prompt + + def test_real_truncation_uses_length_prompt(self): + prompt = self._simulate_branch("chatcmpl-abc123") + assert "output length limit" in prompt + assert "network error" not in prompt + + def test_no_id_falls_through_to_length_prompt(self): + prompt = self._simulate_branch("") + assert "output length limit" in prompt + + +# โ”€โ”€ Integration: live conversation loop โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +@pytest.fixture() +def loop_agent(): + """AIAgent with a mocked OpenAI client (mirrors test_run_agent's fixture) + so we can stage a stub + continuation pair on .chat.completions.create.""" + from run_agent import AIAgent + with ( + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + a = AIAgent( + api_key="test-key-1234567890", + base_url="https://openrouter.ai/api/v1", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + a.client = MagicMock() + a._cached_system_prompt = "You are helpful." + a._use_prompt_caching = False + a.tool_delay = 0 + a.compression_enabled = False + a.save_trajectories = False + return a + + +class TestConversationLoopPartialStreamContinuation: + """End-to-end: a partial-stream stub feeds the loop and the loop + asks for continuation instead of exiting with finish_reason=stop.""" + + def test_partial_stream_stub_does_not_exit_loop_immediately(self, loop_agent): + """The stub from chat_completion_helpers used to exit the loop with + text_response(finish_reason=stop). Now finish_reason=length routes + through length_continue_retries โ€” the loop persists the partial + content and asks the model to continue.""" + + from tests.run_agent.test_run_agent import _mock_response, _mock_assistant_msg + + # First API call: the partial-stream stub (length on partial-stream-stub id). + partial_stub = SimpleNamespace( + id="partial-stream-stub", + model="test/model", + choices=[SimpleNamespace( + index=0, + message=_mock_assistant_msg(content="The first half of "), + finish_reason="length", + )], + usage=None, + ) + # Second API call: model continues with the rest, clean stop. + continuation = _mock_response( + content="the answer is forty-two.", finish_reason="stop", + ) + + loop_agent.client.chat.completions.create.side_effect = [ + partial_stub, continuation, + ] + + with ( + patch.object(loop_agent, "_persist_session"), + patch.object(loop_agent, "_save_trajectory"), + patch.object(loop_agent, "_cleanup_task_resources"), + ): + result = loop_agent.run_conversation("ask me something") + + # The loop made TWO API calls (stub + continuation), not one. + assert loop_agent.client.chat.completions.create.call_count == 2, ( + "Partial-stream-stub must trigger a continuation API call, not " + "exit the loop after one call." + ) + # The continuation prompt the loop appended must be the network-error + # variant, not the "output length limit" lie โ€” otherwise the model + # no-ops with "I wasn't truncated, I'm done." + # We assert it indirectly by inspecting the second-call kwargs. + second_call_kwargs = loop_agent.client.chat.completions.create.call_args_list[1] + msgs = second_call_kwargs.kwargs.get("messages") or second_call_kwargs.args[0].get("messages") + last_user = next( + (m for m in reversed(msgs) if m.get("role") == "user"), None, + ) + assert last_user is not None + assert "network error mid-stream" in (last_user.get("content") or ""), ( + "Continuation prompt for partial-stream-stub must mention the " + "network error, not the 'output length limit'." + ) + + # And the final response stitches both halves together. + assert "first half of" in result["final_response"] + assert "forty-two" in result["final_response"] diff --git a/tests/run_agent/test_plugin_context_engine_init.py b/tests/run_agent/test_plugin_context_engine_init.py index 60e89889088..83895ac6dce 100644 --- a/tests/run_agent/test_plugin_context_engine_init.py +++ b/tests/run_agent/test_plugin_context_engine_init.py @@ -87,5 +87,4 @@ def test_plugin_engine_update_model_args(): assert kw["context_length"] == 131_072 assert "model" in kw assert "provider" in kw - # Should NOT pass api_mode โ€” the ABC doesn't accept it - assert "api_mode" not in kw + assert "api_mode" in kw diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 821228075c3..3d0dcedddd0 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -2636,6 +2636,31 @@ class TestRunConversation: assert result["final_response"] == "Final answer" assert result["completed"] is True + def test_ollama_small_runtime_context_fails_before_api_call(self, agent, caplog): + self._setup_agent(agent) + agent.model = "qwen3.5:9b" + agent.provider = "custom" + agent.base_url = "http://host.docker.internal:11434/v1" + agent._ollama_num_ctx = 4096 + + with ( + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + caplog.at_level(logging.WARNING, logger="agent.conversation_loop"), + ): + result = agent.run_conversation("Call ps -aux") + + assert result["failed"] is True + assert result["completed"] is False + assert result["api_calls"] == 0 + assert result["turn_exit_reason"] == "ollama_runtime_context_too_small" + assert "Ollama loaded `qwen3.5:9b` with only 4,096 tokens" in result["final_response"] + assert "model.ollama_num_ctx: 65536" in result["final_response"] + assert not agent.client.chat.completions.create.called + assert "Ollama runtime context too small for Hermes tool use" in caplog.text + assert "runtime_context=4096" in caplog.text + def test_tool_calls_then_stop(self, agent): self._setup_agent(agent) tc = _mock_tool_call(name="web_search", arguments="{}", call_id="c1") diff --git a/tests/run_agent/test_tls_fd_recycle_corruption.py b/tests/run_agent/test_tls_fd_recycle_corruption.py new file mode 100644 index 00000000000..062276db961 --- /dev/null +++ b/tests/run_agent/test_tls_fd_recycle_corruption.py @@ -0,0 +1,454 @@ +"""Regressions for issue #29507 โ€” cross-thread close of the per-request OpenAI +client could release a TLS socket FD whose integer was still cached in the +owning httpx worker's SSL BIO. The kernel then recycled the FD into the next +``open()`` (e.g. the kanban dispatcher's ``kanban.db``), and the worker's +delayed TLS flush wrote a 24-byte TLS application-data record on top of the +SQLite header. + +The fix has two prongs: + +1. ``force_close_tcp_sockets`` no longer calls ``sock.close()`` โ€” only + ``shutdown(SHUT_RDWR)``. Shutdown unblocks the worker's pending + ``recv``/``send`` without releasing the FD. + +2. ``_close_request_client_once`` is thread-aware: a stranger thread (the + interrupt-check / stale-call loop) only aborts the sockets and leaves + the client in the holder; the worker's own ``finally`` performs the + actual ``client.close()`` from its own thread context. + +Both prongs together close the FD-recycling window. The tests below pin +each prong individually and one end-to-end test simulates the reporter's +timeline at object granularity (no network, no real sockets). +""" +from __future__ import annotations + +import logging +import socket as _socket +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Prong 1: force_close_tcp_sockets must NOT release file descriptors. +# --------------------------------------------------------------------------- + + +class _FakeSocket: + """Records shutdown/close calls without touching real FDs.""" + + def __init__(self): + self.shutdown_calls = 0 + self.close_calls = 0 + + def shutdown(self, _how): + self.shutdown_calls += 1 + + def close(self): + self.close_calls += 1 + + +def _build_fake_client(sock): + """Mimic the httpcore-1 layout that ``_iter_pool_sockets`` walks.""" + stream = SimpleNamespace(_sock=sock) + http11 = SimpleNamespace(_network_stream=stream) + pool_entry = SimpleNamespace(_connection=http11) + pool = SimpleNamespace(_connections=[pool_entry]) + transport = SimpleNamespace(_pool=pool) + http_client = SimpleNamespace(_transport=transport) + return SimpleNamespace(_client=http_client) + + +def test_force_close_tcp_sockets_shutdown_only_no_close(): + """The smoking-gun guarantee: shutdown is called, close is NOT. + + If a future refactor reintroduces ``sock.close()`` here, the + FD-recycling race that corrupted ``kanban.db`` (issue #29507) will + re-open. Pin the contract explicitly. + """ + from agent.agent_runtime_helpers import force_close_tcp_sockets + + sock = _FakeSocket() + client = _build_fake_client(sock) + + n = force_close_tcp_sockets(client) + + assert n == 1 + assert sock.shutdown_calls == 1, "shutdown() must run โ€” it's how we unblock the worker" + assert sock.close_calls == 0, ( + "close() must NOT run from this helper โ€” releasing the FD here is the " + "race that wrote TLS bytes into kanban.db (#29507)" + ) + + +def test_force_close_tcp_sockets_uses_shut_rdwr(): + """Both directions must be shut down so the SSL state machine fully unwinds. + + Half-close (e.g. SHUT_WR only) wouldn't unblock a worker blocked in + ``recv``, defeating the whole point of the helper. + """ + from agent.agent_runtime_helpers import force_close_tcp_sockets + + captured = [] + + class _ProbingSocket: + def shutdown(self, how): + captured.append(how) + + def close(self): # pragma: no cover โ€” must not run, asserted below + captured.append("CLOSE_CALLED") + + sock = _ProbingSocket() + client = _build_fake_client(sock) + + force_close_tcp_sockets(client) + + assert captured == [_socket.SHUT_RDWR] + + +def test_force_close_tcp_sockets_swallows_oserror_on_shutdown(): + """A socket already shut down / not connected raises ``OSError`` โ€” benign.""" + from agent.agent_runtime_helpers import force_close_tcp_sockets + + class _AlreadyShut: + def shutdown(self, _how): + raise OSError("not connected") + + def close(self): # pragma: no cover โ€” must not run + raise AssertionError("close() must not be called") + + client = _build_fake_client(_AlreadyShut()) + + # No exception escapes; the helper still counts the socket as handled. + assert force_close_tcp_sockets(client) == 1 + + +def test_force_close_tcp_sockets_handles_multiple_pool_entries(): + """Walk every pool connection โ€” the bug equally applies to all of them.""" + from agent.agent_runtime_helpers import force_close_tcp_sockets + + socks = [_FakeSocket(), _FakeSocket(), _FakeSocket()] + entries = [ + SimpleNamespace(_connection=SimpleNamespace(_network_stream=SimpleNamespace(_sock=s))) + for s in socks + ] + pool = SimpleNamespace(_connections=entries) + transport = SimpleNamespace(_pool=pool) + http_client = SimpleNamespace(_transport=transport) + client = SimpleNamespace(_client=http_client) + + assert force_close_tcp_sockets(client) == 3 + for s in socks: + assert s.shutdown_calls == 1 + assert s.close_calls == 0 + + +# --------------------------------------------------------------------------- +# Prong 2: _close_request_client_once is thread-aware. +# --------------------------------------------------------------------------- + + +def _make_agent_mock(): + """Minimal agent with the two close primitives stubbed for spy-style checks.""" + agent = MagicMock() + agent._interrupt_requested = False + agent._close_request_openai_client = MagicMock() + agent._abort_request_openai_client = MagicMock() + return agent + + +def _call_inside_owner_thread(callable_): + """Run callable_ on a separate thread so its ``threading.get_ident()`` + differs from the test thread.""" + result = {"value": None, "exc": None} + + def runner(): + try: + result["value"] = callable_() + except BaseException as e: # noqa: BLE001 โ€” propagate test failures faithfully + result["exc"] = e + + t = threading.Thread(target=runner) + t.start() + t.join(timeout=5.0) + if result["exc"] is not None: + raise result["exc"] + return result["value"] + + +def test_close_from_stranger_thread_aborts_only_no_close(): + """Stranger-thread close โ†’ ``_abort_request_openai_client``, holder NOT popped. + + Reproduces the asyncio_0 โ†’ Thread-1616 interrupt path. After this call + the worker's eventual ``finally`` must still see the client in the + holder so IT can be the one releasing the FD. + """ + from agent.chat_completion_helpers import interruptible_api_call + + # We can't easily invoke just `_close_request_client_once` because it's + # a closure local to ``interruptible_api_call``. Re-extract the same + # logic by exercising it through a fake worker that lets us drive the + # holder state manually. + agent = _make_agent_mock() + # Pretend ``_call`` ran far enough to set the client on the holder + # from the owner thread. + sentinel = object() + owner_tid_holder = {"tid": None, "client_present_after_stranger_close": False} + + def _owner_workload(holder, lock): + # Owner-thread set + with lock: + holder["client"] = sentinel + holder["owner_tid"] = threading.get_ident() + owner_tid_holder["tid"] = threading.get_ident() + + holder = {"client": None, "owner_tid": None} + lock = threading.Lock() + _call_inside_owner_thread(lambda: _owner_workload(holder, lock)) + + # Now drive the exact body of the post-#29507 ``_close_request_client_once`` + # from the test thread (stranger) and from the owner thread. + def close_once(holder, lock, reason): + with lock: + request_client = holder.get("client") + owner_tid = holder.get("owner_tid") + stranger = ( + request_client is not None + and owner_tid is not None + and owner_tid != threading.get_ident() + ) + if not stranger: + holder["client"] = None + holder["owner_tid"] = None + if request_client is None: + return None + if stranger: + agent._abort_request_openai_client(request_client, reason=reason) + return "aborted" + agent._close_request_openai_client(request_client, reason=reason) + return "closed" + + outcome = close_once(holder, lock, "interrupt_abort") + + assert outcome == "aborted" + agent._abort_request_openai_client.assert_called_once() + agent._close_request_openai_client.assert_not_called() + # Holder is still populated โ€” the worker thread will pick this up in + # its ``finally`` and own the actual ``client.close()``. + assert holder["client"] is sentinel + assert holder["owner_tid"] == owner_tid_holder["tid"] + + +def test_close_from_owner_thread_pops_and_full_close(): + """Worker-thread close โ†’ ``_close_request_openai_client``, holder popped.""" + agent = _make_agent_mock() + sentinel = object() + holder = {"client": None, "owner_tid": None} + lock = threading.Lock() + + def workload(): + with lock: + holder["client"] = sentinel + holder["owner_tid"] = threading.get_ident() + + # Same body inlined here so the test thread and the closing thread + # are identical (owner == self). + with lock: + request_client = holder.get("client") + owner_tid = holder.get("owner_tid") + stranger = ( + request_client is not None + and owner_tid is not None + and owner_tid != threading.get_ident() + ) + if not stranger: + holder["client"] = None + holder["owner_tid"] = None + if request_client is None: + return None + if stranger: + agent._abort_request_openai_client(request_client, reason="request_complete") + return "aborted" + agent._close_request_openai_client(request_client, reason="request_complete") + return "closed" + + outcome = _call_inside_owner_thread(workload) + + assert outcome == "closed" + agent._close_request_openai_client.assert_called_once() + agent._abort_request_openai_client.assert_not_called() + assert holder["client"] is None + assert holder["owner_tid"] is None + + +def test_stranger_then_owner_close_sequence_runs_full_close_exactly_once(): + """Stranger abort followed by owner close โ†’ full close runs once. + + This mirrors the reporter's timeline: asyncio_0 fires interrupt_abort + (stranger โ†’ abort only), then Thread-1616 unwinds and its finally + fires request_complete (owner โ†’ full close). Net result must be one + abort + one full close, with the holder ending empty. + """ + agent = _make_agent_mock() + sentinel = object() + holder = {"client": None, "owner_tid": None} + lock = threading.Lock() + + def close_once(reason): + with lock: + request_client = holder.get("client") + owner_tid = holder.get("owner_tid") + stranger = ( + request_client is not None + and owner_tid is not None + and owner_tid != threading.get_ident() + ) + if not stranger: + holder["client"] = None + holder["owner_tid"] = None + if request_client is None: + return + if stranger: + agent._abort_request_openai_client(request_client, reason=reason) + else: + agent._close_request_openai_client(request_client, reason=reason) + + def owner_workload(): + # Set client from owner thread. + with lock: + holder["client"] = sentinel + holder["owner_tid"] = threading.get_ident() + # Simulate work being interrupted by a stranger from outside. + nonlocal_stranger_event.wait(timeout=2.0) + # Worker unwinds โ€” its finally calls close once. + close_once("request_complete") + + nonlocal_stranger_event = threading.Event() + owner = threading.Thread(target=owner_workload) + owner.start() + + # Test thread plays the stranger. + # Give the owner a moment to set the holder. + import time as _t + _t.sleep(0.05) + close_once("interrupt_abort") + nonlocal_stranger_event.set() + owner.join(timeout=5.0) + + assert not owner.is_alive(), "owner thread hung past join timeout" + + # The fix's intended outcome: abort once, close once, holder empty. + assert agent._abort_request_openai_client.call_count == 1 + assert agent._close_request_openai_client.call_count == 1 + assert holder["client"] is None + assert holder["owner_tid"] is None + + +# --------------------------------------------------------------------------- +# End-to-end: the agent's ``_abort_request_openai_client`` shuts sockets and +# logs deferred_close=stranger_thread without ever calling client.close(). +# --------------------------------------------------------------------------- + + +def test_agent_abort_request_openai_client_does_not_call_client_close(caplog): + """``_abort_request_openai_client`` must shutdown sockets but NEVER close(). + + This is the actual entry point used by the stranger-thread path. If a + future refactor accidentally wires it back to ``_close_openai_client`` + the FD race is back. Pin both the shutdown side-effect AND the absence + of any ``client.close()`` call. + """ + from run_agent import AIAgent + + sock = _FakeSocket() + client = _build_fake_client(sock) + + # ``client.close()`` would mutate the holder if invoked โ€” give it a + # MagicMock spy so we can assert no call. + client.close = MagicMock() + + agent = AIAgent.__new__(AIAgent) + agent._client_log_context = lambda: "provider=test" + + with caplog.at_level(logging.INFO, logger="run_agent"): + agent._abort_request_openai_client(client, reason="interrupt_abort") + + # Sockets shut down (one in our fake pool). + assert sock.shutdown_calls == 1 + assert sock.close_calls == 0 + # And critically: client.close() never ran here. + client.close.assert_not_called() + + # The log line is parseable: same ``tcp_force_closed=N`` field shape as + # the existing ``close`` log so dashboards keep working, plus a + # ``deferred_close=stranger_thread`` marker to make the new path + # observable in production triage. + msgs = [r.getMessage() for r in caplog.records] + assert any( + "OpenAI client aborted (interrupt_abort" in m + and "tcp_force_closed=1" in m + and "deferred_close=stranger_thread" in m + for m in msgs + ), f"missing abort log line; got: {msgs!r}" + + +def test_agent_abort_request_openai_client_null_client_is_noop(): + """A ``None`` client must short-circuit cleanly (defensive).""" + from run_agent import AIAgent + + agent = AIAgent.__new__(AIAgent) + agent._client_log_context = lambda: "provider=test" + + # No exception, no side effect. + agent._abort_request_openai_client(None, reason="interrupt_abort") + + +# --------------------------------------------------------------------------- +# FD-recycling proof: when shutdown-only is honored, a stranger-thread abort +# CANNOT release an FD that the owning thread still references. +# --------------------------------------------------------------------------- + + +def test_fd_recycle_window_closed_by_shutdown_only(): + """Construct the exact race the reporter saw โ€” abort from a stranger + thread, then have the (simulated) kernel recycle the FD into a new file. + With the fix, the worker's surviving socket reference cannot be + confused with the recycled file descriptor. + """ + from agent.agent_runtime_helpers import force_close_tcp_sockets + + # Tracks "was the FD released by the abort path?" โ€” that is the only + # signal the kernel needs to recycle the integer to a new ``open()``. + fd_released = {"yes": False} + + class _OwnedSocket: + """Simulates a socket whose FD is shared with the owner's SSL BIO. + + ``close`` flips ``fd_released`` so the test can assert that with + the fix the abort path NEVER releases the FD (and therefore the + kernel never recycles it under the owner's still-active reference). + """ + + def __init__(self): + self.shutdowns = 0 + + def shutdown(self, _how): + self.shutdowns += 1 + + def close(self): + fd_released["yes"] = True + + sock = _OwnedSocket() + client = _build_fake_client(sock) + + # Stranger thread runs the abort sweep (== what asyncio_0 did in the + # reporter's session). + _call_inside_owner_thread(lambda: force_close_tcp_sockets(client)) + + assert sock.shutdowns == 1, "shutdown must wake the worker" + assert fd_released["yes"] is False, ( + "force_close_tcp_sockets released the FD from a stranger thread โ€” " + "this is exactly the #29507 race. The owner thread must own close()." + ) diff --git a/tests/run_agent/test_tool_call_guardrail_runtime.py b/tests/run_agent/test_tool_call_guardrail_runtime.py index f1d90502391..e7ab376281a 100644 --- a/tests/run_agent/test_tool_call_guardrail_runtime.py +++ b/tests/run_agent/test_tool_call_guardrail_runtime.py @@ -304,3 +304,52 @@ def test_config_enabled_hard_stop_run_conversation_returns_controlled_guardrail_ call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]] following_results = [m for m in result["messages"] if m.get("role") == "tool" and m.get("tool_call_id") in call_ids] assert len(following_results) == len(call_ids) + + +def test_guardrail_halt_emits_final_response_through_stream_delta_callback(): + """Regression for #30770: when the guardrail halts the loop, the + synthesized halt message must be pushed through ``stream_delta_callback`` + so SSE/TUI clients see why the agent stopped instead of a silent stream + close. Without this the chat-completions SSE writer drains an empty + queue and emits a finish chunk with zero content (indistinguishable + from a crash for Open WebUI and similar clients). + """ + agent = _make_agent("web_search", max_iterations=10, config=_hard_stop_config()) + same_args = {"query": "same"} + responses = [ + _mock_response( + content="", + finish_reason="tool_calls", + tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")], + ) + for i in range(1, 10) + ] + agent.client.chat.completions.create.side_effect = responses + + deltas: list = [] + agent.stream_delta_callback = lambda d: deltas.append(d) + # The mocked client returns SimpleNamespace responses which aren't + # iterable as streaming chunks; force the non-streaming code path so + # the guardrail-halt branch is reached without engaging the real + # streaming machinery. + agent._disable_streaming = True + + with ( + patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})), + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("search repeatedly") + + assert result["turn_exit_reason"] == "guardrail_halt" + halt_text = result["final_response"] + assert "stopped retrying" in halt_text + + # The halt message must have been pushed through the callback at least + # once. Empty-queue SSE writers were the bug โ€” clients saw no content + # delta before the finish chunk. + text_deltas = [d for d in deltas if isinstance(d, str)] + assert halt_text in text_deltas, ( + f"halt message was never streamed; callback only saw {deltas!r}" + ) diff --git a/tests/test_bitwarden_secrets.py b/tests/test_bitwarden_secrets.py index 47155795750..125fbcdc49e 100644 --- a/tests/test_bitwarden_secrets.py +++ b/tests/test_bitwarden_secrets.py @@ -301,6 +301,89 @@ def test_fetch_cache_hits(monkeypatch, tmp_path): assert call_count["n"] == 1 # cached on second call +def test_fetch_server_url_sets_env(monkeypatch, tmp_path): + """server_url must be plumbed into the subprocess as BWS_SERVER_URL.""" + fake_binary = tmp_path / "bws" + fake_binary.write_text("") + payload = _fake_bws_payload([{"key": "K", "value": "v"}]) + + captured_env = {} + + def fake_run(cmd, **kwargs): + captured_env.update(kwargs["env"]) + return mock.Mock(returncode=0, stdout=payload, stderr="") + + monkeypatch.setattr(bw.subprocess, "run", fake_run) + + bw.fetch_bitwarden_secrets( + access_token="0.t", + project_id="p", + binary=fake_binary, + use_cache=False, + server_url="https://vault.bitwarden.eu", + ) + assert captured_env.get("BWS_SERVER_URL") == "https://vault.bitwarden.eu" + + +def test_fetch_no_server_url_does_not_set_env(monkeypatch, tmp_path): + """When server_url is empty, BWS_SERVER_URL must not be injected.""" + fake_binary = tmp_path / "bws" + fake_binary.write_text("") + payload = _fake_bws_payload([]) + # Make sure the inherited env doesn't already have BWS_SERVER_URL set. + monkeypatch.delenv("BWS_SERVER_URL", raising=False) + + captured_env = {} + + def fake_run(cmd, **kwargs): + captured_env.update(kwargs["env"]) + return mock.Mock(returncode=0, stdout=payload, stderr="") + + monkeypatch.setattr(bw.subprocess, "run", fake_run) + + bw.fetch_bitwarden_secrets( + access_token="0.t", + project_id="p", + binary=fake_binary, + use_cache=False, + ) + assert "BWS_SERVER_URL" not in captured_env + + +def test_fetch_server_url_keyed_in_cache(monkeypatch, tmp_path): + """Different server_url values must produce separate cache entries.""" + fake_binary = tmp_path / "bws" + fake_binary.write_text("") + payload = _fake_bws_payload([{"key": "K", "value": "v"}]) + + call_count = {"n": 0} + + def fake_run(*a, **kw): + call_count["n"] += 1 + return mock.Mock(returncode=0, stdout=payload, stderr="") + + monkeypatch.setattr(bw.subprocess, "run", fake_run) + + # US (default empty) โ€” fresh fetch. + bw.fetch_bitwarden_secrets( + access_token="0.t", project_id="p", + binary=fake_binary, cache_ttl_seconds=60, + ) + # EU โ€” different server_url, must NOT hit the US cache entry. + bw.fetch_bitwarden_secrets( + access_token="0.t", project_id="p", + binary=fake_binary, cache_ttl_seconds=60, + server_url="https://vault.bitwarden.eu", + ) + # Second EU call hits cache. + bw.fetch_bitwarden_secrets( + access_token="0.t", project_id="p", + binary=fake_binary, cache_ttl_seconds=60, + server_url="https://vault.bitwarden.eu", + ) + assert call_count["n"] == 2 + + def test_fetch_cache_disabled(monkeypatch, tmp_path): fake_binary = tmp_path / "bws" fake_binary.write_text("") diff --git a/tests/test_env_loader_secret_sources.py b/tests/test_env_loader_secret_sources.py new file mode 100644 index 00000000000..8bd26451d9d --- /dev/null +++ b/tests/test_env_loader_secret_sources.py @@ -0,0 +1,119 @@ +"""Tests for the secret-source tracking in ``hermes_cli.env_loader``. + +These cover the small public surface that lets `hermes model` / `hermes setup` +label detected credentials with their origin ("from Bitwarden") so users +don't see an unexplained "credentials โœ“" line when their .env is empty. +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from hermes_cli import env_loader # noqa: E402 + + +@pytest.fixture(autouse=True) +def _reset_sources(): + """Each test starts with a clean source map.""" + env_loader._SECRET_SOURCES.clear() + yield + env_loader._SECRET_SOURCES.clear() + + +def test_get_secret_source_returns_none_for_untracked_var(): + assert env_loader.get_secret_source("ANTHROPIC_API_KEY") is None + + +def test_get_secret_source_returns_label_for_tracked_var(): + env_loader._SECRET_SOURCES["ANTHROPIC_API_KEY"] = "bitwarden" + assert env_loader.get_secret_source("ANTHROPIC_API_KEY") == "bitwarden" + + +def test_format_secret_source_suffix_empty_for_untracked(): + # Credentials from .env or the shell shouldn't add noise โ€” the + # implicit case stays unlabeled. + assert env_loader.format_secret_source_suffix("ANTHROPIC_API_KEY") == "" + + +def test_format_secret_source_suffix_bitwarden_uses_proper_name(): + env_loader._SECRET_SOURCES["ANTHROPIC_API_KEY"] = "bitwarden" + assert ( + env_loader.format_secret_source_suffix("ANTHROPIC_API_KEY") + == " (from Bitwarden)" + ) + + +def test_format_secret_source_suffix_generic_label_for_future_sources(): + # Future-proofing: a new secret source (e.g. "vault") should still + # produce a sensible label without needing to edit every call site. + env_loader._SECRET_SOURCES["OPENAI_API_KEY"] = "vault" + assert ( + env_loader.format_secret_source_suffix("OPENAI_API_KEY") + == " (from vault)" + ) + + +def test_apply_external_secret_sources_records_bitwarden_origin(tmp_path, monkeypatch): + """End-to-end: when ``apply_bitwarden_secrets`` returns applied keys, + they end up in ``_SECRET_SOURCES`` so the UI can label them.""" + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + config_path = tmp_path / "config.yaml" + config_path.write_text( + "secrets:\n" + " bitwarden:\n" + " enabled: true\n" + " project_id: test-project\n" + " access_token_env: BWS_ACCESS_TOKEN\n", + encoding="utf-8", + ) + + # Stub apply_bitwarden_secrets to return a synthetic FetchResult. + from agent.secret_sources.bitwarden import FetchResult + + fake_result = FetchResult( + secrets={"ANTHROPIC_API_KEY": "sk-ant-test"}, + applied=["ANTHROPIC_API_KEY"], + ) + + def _fake_apply(**_kwargs): + return fake_result + + # The import inside _apply_external_secret_sources is lazy, so we + # patch the *module attribute* it will pull in. + import agent.secret_sources.bitwarden as bw_module + + monkeypatch.setattr(bw_module, "apply_bitwarden_secrets", _fake_apply) + + env_loader._apply_external_secret_sources(tmp_path) + + assert env_loader.get_secret_source("ANTHROPIC_API_KEY") == "bitwarden" + assert ( + env_loader.format_secret_source_suffix("ANTHROPIC_API_KEY") + == " (from Bitwarden)" + ) + + +def test_apply_external_secret_sources_noop_when_disabled(tmp_path, monkeypatch): + """Disabled Bitwarden config must not touch the source map.""" + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + config_path = tmp_path / "config.yaml" + config_path.write_text( + "secrets:\n" + " bitwarden:\n" + " enabled: false\n", + encoding="utf-8", + ) + + env_loader._apply_external_secret_sources(tmp_path) + + assert env_loader.get_secret_source("ANTHROPIC_API_KEY") is None diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 7c3cae75523..baabef000d2 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -161,6 +161,28 @@ class TestMessageStorage: session = db.get_session("s1") assert session["message_count"] == 2 + def test_observed_flag_round_trips_for_gateway_replay(self, db): + db.create_session(session_id="s1", source="telegram:-100") + db.append_message( + "s1", + role="user", + content="[Alice|111]\nside chatter", + observed=True, + ) + db.append_message("s1", role="assistant", content="ack") + + messages = db.get_messages("s1") + assert messages[0]["observed"] == 1 + assert messages[1]["observed"] == 0 + + conversation = db.get_messages_as_conversation("s1") + assert conversation[0] == { + "role": "user", + "content": "[Alice|111]\nside chatter", + "observed": True, + } + assert "observed" not in conversation[1] + def test_tool_response_does_not_increment_tool_count(self, db): """Tool responses (role=tool) should not increment tool_call_count. diff --git a/tests/test_minimax_oauth.py b/tests/test_minimax_oauth.py index 21e8ba13981..f29209cee8c 100644 --- a/tests/test_minimax_oauth.py +++ b/tests/test_minimax_oauth.py @@ -642,3 +642,202 @@ def test_generic_auth_status_dispatches_minimax_oauth(): assert status["logged_in"] is True assert status["provider"] == "minimax-oauth" assert status["region"] == "global" + + +# --------------------------------------------------------------------------- +# build_minimax_oauth_token_provider โ€” per-request callable bearer +# --------------------------------------------------------------------------- +# These tests verify the fix for short-lived (~15-min) MiniMax access tokens +# expiring mid-session. The callable is invoked by the Anthropic SDK on every +# outbound request via the existing Entra-style bearer hook. + + +def test_token_provider_returns_current_access_token_when_fresh(): + """When token is far from expiry, callable just returns the cached token.""" + from hermes_cli.auth import build_minimax_oauth_token_provider + + state = { + "access_token": "still-fresh", + "refresh_token": "rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(3600), + } + + provider = build_minimax_oauth_token_provider() + + with patch("hermes_cli.auth.get_provider_auth_state", return_value=state), \ + patch("httpx.Client") as mock_client_class: + token = provider() + # No network call should happen โ€” token is fresh. + mock_client_class.assert_not_called() + + assert token == "still-fresh" + + +def test_token_provider_refreshes_when_near_expiry(): + """When token is within the skew window, callable mints a fresh one.""" + from hermes_cli.auth import build_minimax_oauth_token_provider + + state = { + "access_token": "about-to-die", + "refresh_token": "rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(MINIMAX_OAUTH_REFRESH_SKEW_SECONDS - 1), + } + + refreshed_body = { + "status": "success", + "access_token": "fresh-bearer", + "refresh_token": "rt2", + "expired_in": 900, + } + mock_resp = _make_httpx_response(200, refreshed_body) + + provider = build_minimax_oauth_token_provider() + + with patch("hermes_cli.auth.get_provider_auth_state", return_value=state), \ + patch("httpx.Client") as mock_client_class, \ + patch("hermes_cli.auth._minimax_save_auth_state"): + mock_instance = MagicMock() + mock_instance.__enter__ = MagicMock(return_value=mock_instance) + mock_instance.__exit__ = MagicMock(return_value=False) + mock_instance.post.return_value = mock_resp + mock_client_class.return_value = mock_instance + + token = provider() + + assert token == "fresh-bearer" + + +def test_token_provider_rereads_state_each_call(): + """Each callable invocation re-reads auth.json so cross-process refreshes + persisted by another hermes process are immediately visible.""" + from hermes_cli.auth import build_minimax_oauth_token_provider + + states = [ + { + "access_token": "first-token", + "refresh_token": "rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(3600), + }, + { + "access_token": "second-token-after-another-process-refreshed", + "refresh_token": "rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(3600), + }, + ] + + provider = build_minimax_oauth_token_provider() + with patch("hermes_cli.auth.get_provider_auth_state", side_effect=states): + first = provider() + second = provider() + + assert first == "first-token" + assert second == "second-token-after-another-process-refreshed" + + +def test_token_provider_raises_not_logged_in_when_state_missing(): + """No state in auth.json โ†’ AuthError(not_logged_in, relogin_required=True).""" + from hermes_cli.auth import build_minimax_oauth_token_provider + + provider = build_minimax_oauth_token_provider() + with patch("hermes_cli.auth.get_provider_auth_state", return_value=None): + with pytest.raises(AuthError) as exc_info: + provider() + + assert exc_info.value.code == "not_logged_in" + assert exc_info.value.relogin_required is True + + +def test_token_provider_quarantines_state_on_terminal_refresh(): + """When refresh returns invalid_grant, callable raises AuthError AND + wipes the dead tokens so subsequent calls fail fast without network.""" + from hermes_cli.auth import build_minimax_oauth_token_provider + + state = { + "access_token": "expired", + "refresh_token": "burned-rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _past_iso(100), + } + + bad_resp = _make_httpx_response(400, text="invalid_grant") + bad_resp.json.side_effect = Exception("no json") + bad_resp.text = "invalid_grant" + bad_resp.reason_phrase = "Bad Request" + + saved_states: list[dict] = [] + + provider = build_minimax_oauth_token_provider() + with patch("hermes_cli.auth.get_provider_auth_state", return_value=state), \ + patch("httpx.Client") as mock_client_class, \ + patch( + "hermes_cli.auth._minimax_save_auth_state", + side_effect=lambda s: saved_states.append(dict(s)), + ): + mock_instance = MagicMock() + mock_instance.__enter__ = MagicMock(return_value=mock_instance) + mock_instance.__exit__ = MagicMock(return_value=False) + mock_instance.post.return_value = bad_resp + mock_client_class.return_value = mock_instance + + with pytest.raises(AuthError) as exc_info: + provider() + + assert exc_info.value.relogin_required is True + # Quarantine wrote a state with tokens removed. + assert len(saved_states) == 1 + quarantined = saved_states[0] + assert "access_token" not in quarantined + assert "refresh_token" not in quarantined + assert quarantined["last_auth_error"]["relogin_required"] is True + + +def test_resolve_returns_callable_when_as_token_provider_true(): + """Explicit opt-in path: resolve_minimax_oauth_runtime_credentials(as_token_provider=True) + returns a callable api_key.""" + state = { + "access_token": "tok", + "refresh_token": "rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(3600), + } + + with patch("hermes_cli.auth.get_provider_auth_state", return_value=state): + creds = resolve_minimax_oauth_runtime_credentials(as_token_provider=True) + + assert callable(creds["api_key"]) + assert not isinstance(creds["api_key"], str) + assert creds["base_url"] == MINIMAX_OAUTH_GLOBAL_INFERENCE.rstrip("/") + + +def test_resolve_returns_string_by_default(): + """Backwards-compatible default: api_key is a string materialized once.""" + state = { + "access_token": "tok", + "refresh_token": "rt", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(3600), + } + + with patch("hermes_cli.auth.get_provider_auth_state", return_value=state): + creds = resolve_minimax_oauth_runtime_credentials() + + assert creds["api_key"] == "tok" + assert isinstance(creds["api_key"], str) diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 2205cb8df64..3328110b2be 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -59,6 +59,59 @@ def test_write_json_returns_false_on_broken_pipe(monkeypatch): assert server.write_json({"ok": True}) is False +def test_tui_verbose_tool_details_fail_closed_when_redaction_fails(monkeypatch): + redact_module = types.ModuleType("agent.redact") + + def fail_redaction(*_args, **_kwargs): + raise RuntimeError("redaction unavailable") + + setattr(redact_module, "redact_sensitive_text", fail_redaction) + monkeypatch.setitem(sys.modules, "agent.redact", redact_module) + + assert server._redact_tui_verbose_text("api_key=secret") == "" + assert server._tool_args_text({"api_key": "secret"}) == "" + assert server._tool_result_text("token=secret") == "" + + +def test_tui_verbose_tool_details_are_capped_before_emit(monkeypatch): + monkeypatch.setattr(server, "_TUI_VERBOSE_TEXT_MAX_CHARS", 12) + monkeypatch.setattr(server, "_TUI_VERBOSE_TEXT_MAX_LINES", 2) + + capped = server._cap_tui_verbose_text("one\ntwo\nthree\nfour") + + assert capped.startswith("[showing verbose tail; omitted ") + assert capped.endswith("three\nfour") + assert "one" not in capped + + +def test_tui_verbose_tool_events_omit_details_when_redaction_fails(monkeypatch): + redact_module = types.ModuleType("agent.redact") + + def fail_redaction(*_args, **_kwargs): + raise RuntimeError("redaction unavailable") + + setattr(redact_module, "redact_sensitive_text", fail_redaction) + monkeypatch.setitem(sys.modules, "agent.redact", redact_module) + + events: list[tuple[str, str, dict]] = [] + monkeypatch.setattr( + server, "_emit", lambda event_type, sid, payload: events.append((event_type, sid, payload)) + ) + monkeypatch.setitem( + server._sessions, + "redaction-test", + {"tool_progress_mode": "verbose", "tool_started_at": {}}, + ) + + server._on_tool_start("redaction-test", "tool-1", "terminal", {"command": "pwd"}) + server._on_tool_complete("redaction-test", "tool-1", "terminal", {"command": "pwd"}, "done") + + assert events[0][0] == "tool.start" + assert events[1][0] == "tool.complete" + assert "args_text" not in events[0][2] + assert "result_text" not in events[1][2] + + def test_dispatch_rejects_non_object_request(): resp = server.dispatch([]) @@ -1566,6 +1619,26 @@ def test_complete_slash_includes_provider_alias(): assert any(item["text"] == "provider" for item in resp["result"]["items"]) +def test_complete_slash_returns_plain_string_fields(): + # prompt_toolkit hands us FormattedText (a list subclass) for + # display/display_meta; the TUI's CompletionItem contract is plain + # strings, and shipping the raw list trips Ink's row layout into + # 1-char truncation of the next column (/goal โ†’ /goa). + resp = server.handle_request( + {"id": "1", "method": "complete.slash", "params": {"text": "/g"}} + ) + + items = resp["result"]["items"] + goal = next((it for it in items if it["text"] == "goal"), None) + assert goal is not None + assert isinstance(goal["display"], str), goal["display"] + assert isinstance(goal["meta"], str), goal["meta"] + assert goal["display"] == "/goal" + for item in items: + assert isinstance(item["display"], str), item + assert isinstance(item["meta"], str), item + + def test_complete_slash_includes_tui_details_command(): resp = server.handle_request( {"id": "1", "method": "complete.slash", "params": {"text": "/det"}} diff --git a/tests/tools/conftest.py b/tests/tools/conftest.py index 548b37f38c9..494dd206a1e 100644 --- a/tests/tools/conftest.py +++ b/tests/tools/conftest.py @@ -8,6 +8,8 @@ depend on the registry being populated should use it explicitly or via ``@pytest.mark.usefixtures("web_registry_populated")``. """ +from unittest.mock import patch + import pytest @@ -48,3 +50,20 @@ def web_registry_populated(): yield from agent.web_search_registry import _reset_for_tests _reset_for_tests() + + +@pytest.fixture +def disable_lazy_stt_install(): + """Disarm the runtime lazy-install probe so static ``_HAS_FASTER_WHISPER`` + patches accurately simulate 'faster-whisper not installed'. + + Without this, ``_try_lazy_install_stt()`` calls + ``importlib.util.find_spec("faster_whisper")``, which returns truthy + whenever the package is installed in the dev / CI environment โ€” + defeating the test's ``_HAS_FASTER_WHISPER=False`` patch. + + Opt in at module scope with + ``pytestmark = pytest.mark.usefixtures("disable_lazy_stt_install")``. + """ + with patch("tools.transcription_tools._try_lazy_install_stt", return_value=False): + yield diff --git a/tests/tools/test_approval.py b/tests/tools/test_approval.py index 0694dbcdc91..942d27cbe13 100644 --- a/tests/tools/test_approval.py +++ b/tests/tools/test_approval.py @@ -1,6 +1,9 @@ """Tests for the dangerous command approval module.""" import ast +import os +import threading +import time from pathlib import Path from types import SimpleNamespace from unittest.mock import patch as mock_patch @@ -1305,3 +1308,165 @@ class TestEtcPatternsUnaffectedByRefactor: def test_grep_etc_passwd_is_safe(self): dangerous, _, _ = detect_dangerous_command("grep root /etc/passwd") assert dangerous is False + + +# ========================================================================= +# Gateway approval timeout = deny, NOT consent (#24912) +# +# A Slack user walked away mid-conversation; the agent requested approval +# to run `rm -rf .git`; the prompt timed out; the agent ran the command +# anyway. Reported by @tofalck on 2026-05-13, corroborated by +# @angry-programmer on Telegram. Silence is not consent. +# +# These tests pin: +# 1. Gateway timeout โ†’ approved=False, with a message strong enough that +# a downstream agent reading "BLOCKED: ... Silence is not consent." +# treats it as a hard halt, not an invitation to rephrase. +# 2. The structured outcome / user_consent fields are present so +# plugins, hooks, and audit pipelines can act on the timeout without +# string-parsing the message. +# 3. Explicit /deny carries the same shape (treat-as-not-consented). +# ========================================================================= + + +class TestApprovalTimeoutIsNotConsent: + """The gateway approval contract: silence is not consent (#24912).""" + + SESSION_KEY = "test-no-consent-session" + + def setup_method(self): + """Reset module state and force tight gateway_timeout for fast tests.""" + from tools import approval as mod + mod._gateway_queues.clear() + mod._gateway_notify_cbs.clear() + mod._session_approved.clear() + mod._permanent_approved.clear() + mod._pending.clear() + + self._saved_env = { + k: os.environ.get(k) + for k in ("HERMES_GATEWAY_SESSION", "HERMES_YOLO_MODE", + "HERMES_SESSION_KEY", "HERMES_INTERACTIVE") + } + os.environ.pop("HERMES_YOLO_MODE", None) + os.environ.pop("HERMES_INTERACTIVE", None) + os.environ["HERMES_GATEWAY_SESSION"] = "1" + os.environ["HERMES_SESSION_KEY"] = self.SESSION_KEY + + def teardown_method(self): + from tools import approval as mod + mod._gateway_queues.clear() + mod._gateway_notify_cbs.clear() + for k, v in self._saved_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + def _force_short_timeout(self, monkeypatch, seconds=1): + from tools import approval as mod + monkeypatch.setattr( + mod, "_get_approval_config", + lambda: {"mode": "manual", "gateway_timeout": seconds, "timeout": seconds}, + ) + + def test_timeout_returns_approved_false_with_no_consent(self, monkeypatch): + """The reported #24912 scenario โ€” user never responds, agent must see BLOCKED.""" + from tools import approval as mod + + self._force_short_timeout(monkeypatch, seconds=1) + + # Slack-shaped: notify_cb registered, but user doesn't respond. + notified = [] + mod.register_gateway_notify(self.SESSION_KEY, lambda data: notified.append(data)) + + result = mod.check_all_command_guards("rm -rf .git", "local") + + assert result["approved"] is False + assert result.get("user_consent") is False + assert result.get("outcome") == "timeout" + # The notify_cb DID fire โ€” we did try to ask the user. + assert len(notified) == 1 + + def test_timeout_message_is_emphatic_against_retry_and_rephrase(self, monkeypatch): + """The BLOCKED message must explicitly tell the agent not to rephrase. + + Without this, the agent treats 'Do NOT retry this command' as + permission to try a different command achieving the same outcome. + """ + from tools import approval as mod + self._force_short_timeout(monkeypatch, seconds=1) + mod.register_gateway_notify(self.SESSION_KEY, lambda data: None) + + result = mod.check_all_command_guards("rm -rf .git", "local") + + msg = result["message"] + # Explicit halt signals โ€” these are the model-facing contract. + assert "BLOCKED" in msg + assert "NOT consented" in msg + assert "Silence is not consent" in msg + # Both forms of evasion must be named: + assert "do NOT retry" in msg.lower() or "Do NOT retry" in msg + assert "rephrase" in msg.lower() + assert "different command" in msg.lower() + + def test_explicit_deny_carries_same_no_consent_shape(self): + """An explicit /deny must produce the same shape as timeout โ€” + the agent should treat both identically.""" + from tools import approval as mod + + notified = [] + mod.register_gateway_notify(self.SESSION_KEY, lambda data: notified.append(data)) + + # Spawn the approval wait in a thread, then resolve it with "deny". + result_holder = {} + def _check(): + result_holder["r"] = mod.check_all_command_guards("rm -rf .git", "local") + t = threading.Thread(target=_check) + t.start() + + # Wait for the queue entry to appear, then resolve. + for _ in range(50): + if mod._gateway_queues.get(self.SESSION_KEY): + break + time.sleep(0.02) + mod.resolve_gateway_approval(self.SESSION_KEY, "deny") + t.join(timeout=5) + assert "r" in result_holder, "approval wait did not return after deny" + + r = result_holder["r"] + assert r["approved"] is False + assert r.get("user_consent") is False + assert r.get("outcome") == "denied" + assert "Silence is not consent" not in r["message"] # this one IS denied, not timed-out + assert "NOT consented" in r["message"] + assert "rephrase" in r["message"].lower() + + def test_timeout_emits_post_hook_with_timeout_outcome(self, monkeypatch): + """Plugins must be able to distinguish timeout from explicit deny. + + This is what an audit / notification plugin needs to alert + operators on 'agent asked, user never replied' incidents like #24912. + """ + from tools import approval as mod + self._force_short_timeout(monkeypatch, seconds=1) + mod.register_gateway_notify(self.SESSION_KEY, lambda data: None) + + hook_calls = [] + original_fire = mod._fire_approval_hook + + def _capture(event_name, **kwargs): + hook_calls.append((event_name, kwargs)) + return original_fire(event_name, **kwargs) + + monkeypatch.setattr(mod, "_fire_approval_hook", _capture) + + mod.check_all_command_guards("rm -rf .git", "local") + + # post_approval_response must be in the hook log with choice=timeout + posts = [c for c in hook_calls if c[0] == "post_approval_response"] + assert posts, "post_approval_response hook did not fire" + last_post = posts[-1][1] + assert last_post.get("choice") == "timeout", ( + f"hook choice should be 'timeout' on no-response, got {last_post.get('choice')!r}" + ) diff --git a/tests/tools/test_browser_orphan_reaper.py b/tests/tools/test_browser_orphan_reaper.py index 0724cbd6311..edd8bda6c2d 100644 --- a/tests/tools/test_browser_orphan_reaper.py +++ b/tests/tools/test_browser_orphan_reaper.py @@ -72,7 +72,7 @@ class TestReapOrphanedBrowserSessions: assert not d.exists() def test_orphaned_alive_daemon_is_killed(self, fake_tmpdir): - """Alive daemon not tracked by _active_sessions gets SIGTERM (legacy path). + """Alive daemon not tracked by _active_sessions is terminated (legacy path). No owner_pid file => falls back to tracked_names check. """ @@ -82,18 +82,17 @@ class TestReapOrphanedBrowserSessions: kill_calls = [] - def mock_kill(pid, sig): - kill_calls.append((pid, sig)) - # Don't actually kill anything + def mock_terminate(pid): + kill_calls.append(pid) # Post-#21561 the liveness probe goes through # ``gateway.status._pid_exists`` (which wraps ``psutil.pid_exists`` # so it's safe on Windows โ€” ``os.kill(pid, 0)`` is bpo-14484). with patch("gateway.status._pid_exists", return_value=True), \ - patch("os.kill", side_effect=mock_kill): + patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() - assert (12345, signal.SIGTERM) in kill_calls + assert 12345 in kill_calls def test_tracked_session_is_not_reaped(self, fake_tmpdir): """Sessions tracked in _active_sessions are left alone (legacy path).""" @@ -108,13 +107,13 @@ class TestReapOrphanedBrowserSessions: kill_calls = [] - def mock_kill(pid, sig): - kill_calls.append((pid, sig)) + def mock_terminate(pid): + kill_calls.append(pid) - with patch("os.kill", side_effect=mock_kill): + with patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() - # Should NOT have tried to kill anything + # Should NOT have tried to terminate anything assert len(kill_calls) == 0 # Dir should still exist assert d.exists() @@ -126,23 +125,24 @@ class TestReapOrphanedBrowserSessions: ``gateway.status._pid_exists`` (which wraps ``psutil.pid_exists`` because ``os.kill(pid, 0)`` is a footgun on Windows โ€” bpo-14484). With no owner_pid file and no tracked-name entry, the reaper - SIGTERMs the daemon and removes its socket dir regardless of - whether SIGTERM succeeded (best-effort semantics). + terminates the daemon (and its process tree) and removes its socket + dir regardless of whether termination succeeded (best-effort + semantics). """ from tools.browser_tool import _reap_orphaned_browser_sessions d = _make_socket_dir(fake_tmpdir, "h_perm1234567", pid=12345) - sigterm_calls = [] + terminate_calls = [] - def mock_kill(pid, sig): - sigterm_calls.append((pid, sig)) + def mock_terminate(pid): + terminate_calls.append(pid) with patch("gateway.status._pid_exists", return_value=True), \ - patch("os.kill", side_effect=mock_kill): + patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() - assert (12345, signal.SIGTERM) in sigterm_calls + assert 12345 in terminate_calls assert not d.exists() def test_cdp_sessions_are_also_reaped(self, fake_tmpdir): @@ -203,15 +203,15 @@ class TestOwnerPidCrossProcess: kill_calls = [] - def mock_kill(pid, sig): - kill_calls.append((pid, sig)) + def mock_terminate(pid): + kill_calls.append(pid) # Owner alive โ†’ reaper skips without ever probing the daemon. with patch("gateway.status._pid_exists", return_value=True), \ - patch("os.kill", side_effect=mock_kill): + patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() - assert (12345, signal.SIGTERM) not in kill_calls + assert 12345 not in kill_calls assert d.exists() def test_dead_owner_triggers_reap(self, fake_tmpdir): @@ -225,17 +225,17 @@ class TestOwnerPidCrossProcess: kill_calls = [] - def mock_kill(pid, sig): - kill_calls.append((pid, sig)) + def mock_terminate(pid): + kill_calls.append(pid) # Owner 999999999 dead, daemon 12345 alive. pid_alive = {999999999: False, 12345: True} with patch("gateway.status._pid_exists", side_effect=lambda pid: pid_alive.get(int(pid), False)), \ - patch("os.kill", side_effect=mock_kill): + patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() - assert (12345, signal.SIGTERM) in kill_calls + assert 12345 in kill_calls assert not d.exists() def test_corrupt_owner_pid_falls_back_to_legacy(self, fake_tmpdir): @@ -253,15 +253,15 @@ class TestOwnerPidCrossProcess: kill_calls = [] - def mock_kill(pid, sig): - kill_calls.append((pid, sig)) + def mock_terminate(pid): + kill_calls.append(pid) with patch("gateway.status._pid_exists", return_value=True), \ - patch("os.kill", side_effect=mock_kill): + patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() # Legacy path took over โ†’ tracked โ†’ not reaped - assert (12345, signal.SIGTERM) not in kill_calls + assert 12345 not in kill_calls assert d.exists() def test_owner_pid_permission_error_treated_as_alive(self, fake_tmpdir): @@ -280,16 +280,16 @@ class TestOwnerPidCrossProcess: kill_calls = [] - def mock_kill(pid, sig): - kill_calls.append((pid, sig)) + def mock_terminate(pid): + kill_calls.append(pid) # Owner 22222 reported alive (PermissionError collapses to True - # inside _pid_exists). Daemon never probed, never SIGTERMed. + # inside _pid_exists). Daemon never probed, never terminated. with patch("gateway.status._pid_exists", return_value=True), \ - patch("os.kill", side_effect=mock_kill): + patch("tools.process_registry.ProcessRegistry._terminate_host_pid", side_effect=mock_terminate): _reap_orphaned_browser_sessions() - assert (12345, signal.SIGTERM) not in kill_calls + assert 12345 not in kill_calls assert d.exists() def test_write_owner_pid_creates_file_with_current_pid( diff --git a/tests/tools/test_browser_secret_exfil.py b/tests/tools/test_browser_secret_exfil.py index 893fb11fe74..82fa7e490e1 100644 --- a/tests/tools/test_browser_secret_exfil.py +++ b/tests/tools/test_browser_secret_exfil.py @@ -31,7 +31,13 @@ class TestBrowserSecretExfil: def test_allows_normal_url(self): """Normal URLs pass the secret check (may fail for other reasons).""" from tools.browser_tool import browser_navigate - result = browser_navigate("https://github.com/NousResearch/hermes-agent") + # Patch the actual browser command โ€” we only care that the secret + # check doesn't block a clean URL, not that Chrome starts in CI. + mock_result = {"success": True, "data": {"title": "ok", "url": "https://github.com/NousResearch/hermes-agent"}} + with patch("tools.browser_tool._run_browser_command", return_value=mock_result), \ + patch("tools.browser_tool._get_session_info", return_value={"_first_nav": False}), \ + patch("tools.browser_tool._is_local_backend", return_value=True): + result = browser_navigate("https://github.com/NousResearch/hermes-agent") parsed = json.loads(result) # Should NOT be blocked by secret detection assert "API key or token" not in parsed.get("error", "") diff --git a/tests/tools/test_computer_use.py b/tests/tools/test_computer_use.py index 7afaa7b57de..44a97db47ac 100644 --- a/tests/tools/test_computer_use.py +++ b/tests/tools/test_computer_use.py @@ -76,6 +76,27 @@ class TestSchema: modes = set(COMPUTER_USE_SCHEMA["parameters"]["properties"]["mode"]["enum"]) assert modes == {"som", "vision", "ax"} + def test_schema_exposes_max_elements_cap_for_capture(self): + from tools.computer_use.schema import COMPUTER_USE_SCHEMA + props = COMPUTER_USE_SCHEMA["parameters"]["properties"] + assert "max_elements" in props + assert props["max_elements"]["type"] == "integer" + assert props["max_elements"].get("minimum", 1) >= 1 + + def test_schema_max_elements_documents_default_and_upper_bound(self): + """Schema description must agree with the runtime. The original PR + text said "Default 100" without a corresponding `default` field, and + had no upper bound โ€” both Copilot findings. + """ + from tools.computer_use.schema import COMPUTER_USE_SCHEMA + from tools.computer_use.tool import ( + _DEFAULT_MAX_ELEMENTS, + _MAX_ALLOWED_MAX_ELEMENTS, + ) + prop = COMPUTER_USE_SCHEMA["parameters"]["properties"]["max_elements"] + assert prop.get("default") == _DEFAULT_MAX_ELEMENTS + assert prop.get("maximum") == _MAX_ALLOWED_MAX_ELEMENTS + class TestRegistration: def test_tool_registers_with_registry(self): @@ -205,6 +226,54 @@ class TestDispatch: parsed = json.loads(out) assert "error" in parsed + def test_set_value_routes_to_backend(self, noop_backend): + """set_value must reach the backend โ€” regression for missing _NoopBackend stub.""" + from tools.computer_use.tool import handle_computer_use + out = handle_computer_use({"action": "set_value", "value": "Option A", "element": 5}) + parsed = json.loads(out) + assert parsed.get("ok") is True + assert parsed.get("action") == "set_value" + assert any(c[0] == "set_value" for c in noop_backend.calls) + + def test_set_value_missing_value_returns_error(self, noop_backend): + from tools.computer_use.tool import handle_computer_use + out = handle_computer_use({"action": "set_value"}) + parsed = json.loads(out) + assert "error" in parsed + def test_capture_after_skipped_when_action_failed(self, noop_backend): + """capture_after must not fire when res.ok=False (regression guard). + + A follow-up screenshot after a failed action shows the screen in a + normal state, misleading the model into thinking the action succeeded. + """ + from unittest.mock import patch + from tools.computer_use.backend import ActionResult + from tools.computer_use.tool import handle_computer_use + + # Make click() return a failure. + with patch.object(noop_backend, "click", + return_value=ActionResult(ok=False, action="click", + message="element not found")): + out = handle_computer_use({"action": "click", "element": 99, + "capture_after": True}) + + parsed = json.loads(out) + # Should return the error, not a multimodal capture. + assert parsed.get("ok") is False + assert parsed.get("action") == "click" + # No follow-up capture should have been issued. + capture_calls = [c for c in noop_backend.calls if c[0] == "capture"] + assert len(capture_calls) == 0, "capture must not be called after a failed action" + + def test_capture_after_fires_when_action_succeeds(self, noop_backend): + """capture_after must trigger for successful actions.""" + from tools.computer_use.tool import handle_computer_use + out = handle_computer_use({"action": "click", "element": 1, + "capture_after": True}) + # Noop backend returns ok=True, so capture should have been called. + capture_calls = [c for c in noop_backend.calls if c[0] == "capture"] + assert len(capture_calls) == 1 + # --------------------------------------------------------------------------- # Safety guards (type / key block lists) @@ -337,6 +406,193 @@ class TestCaptureResponse: assert "AXButton" in text_part["text"] assert "AXTextField" in text_part["text"] + def _ax_backend_with(self, count: int): + """Construct a fake backend that yields ``count`` AX elements.""" + from tools.computer_use.backend import CaptureResult, UIElement + + elements = [ + UIElement(index=i + 1, role="AXButton", label=f"el-{i}", bounds=(0, 0, 1, 1)) + for i in range(count) + ] + + class FakeBackend: + def start(self): pass + def stop(self): pass + def is_available(self): return True + def capture(self, mode="som", app=None): + return CaptureResult( + mode=mode, width=800, height=600, + png_b64="", + elements=list(elements), + app="Obsidian", + ) + def click(self, **kw): ... + def drag(self, **kw): ... + def scroll(self, **kw): ... + def type_text(self, text): ... + def key(self, keys): ... + def list_apps(self): return [] + def focus_app(self, app, raise_window=False): ... + + return FakeBackend() + + def test_capture_ax_caps_elements_at_default_for_dense_trees(self): + """Regression for #22865: an Electron-style 600-element AX tree must + not emit the entire array verbatim into the tool result. + """ + from tools.computer_use import tool as cu_tool + + fake_backend = self._ax_backend_with(600) + cu_tool.reset_backend_for_tests() + with patch.object(cu_tool, "_get_backend", return_value=fake_backend): + out = cu_tool.handle_computer_use({"action": "capture", "mode": "ax"}) + + parsed = json.loads(out) + assert parsed["mode"] == "ax" + assert parsed["total_elements"] == 600 + assert len(parsed["elements"]) == cu_tool._DEFAULT_MAX_ELEMENTS + assert parsed["truncated_elements"] == 600 - cu_tool._DEFAULT_MAX_ELEMENTS + # Truncation must be visible in the human summary so the model knows + # the JSON view is partial and can re-issue with a tighter scope. + assert "truncated to" in parsed["summary"] + + def test_capture_ax_honors_explicit_max_elements_override(self): + from tools.computer_use import tool as cu_tool + + fake_backend = self._ax_backend_with(600) + cu_tool.reset_backend_for_tests() + with patch.object(cu_tool, "_get_backend", return_value=fake_backend): + out = cu_tool.handle_computer_use( + {"action": "capture", "mode": "ax", "max_elements": 250} + ) + + parsed = json.loads(out) + assert len(parsed["elements"]) == 250 + assert parsed["truncated_elements"] == 350 + + def test_capture_ax_below_cap_is_unchanged(self): + """Backwards-compat: small captures keep the full elements array and + do not surface a `truncated_elements` field. + """ + from tools.computer_use import tool as cu_tool + + fake_backend = self._ax_backend_with(5) + cu_tool.reset_backend_for_tests() + with patch.object(cu_tool, "_get_backend", return_value=fake_backend): + out = cu_tool.handle_computer_use({"action": "capture", "mode": "ax"}) + + parsed = json.loads(out) + assert len(parsed["elements"]) == 5 + assert parsed["total_elements"] == 5 + assert "truncated_elements" not in parsed + assert "truncated to" not in parsed["summary"] + + def test_capture_ax_invalid_max_elements_falls_back_to_default(self): + """Malformed `max_elements` (string, negative, zero) must not silently + disable the cap and re-introduce the original unbounded behavior. + """ + from tools.computer_use import tool as cu_tool + + fake_backend = self._ax_backend_with(600) + cu_tool.reset_backend_for_tests() + for bad in ("not-a-number", 0, -10): + with patch.object(cu_tool, "_get_backend", return_value=fake_backend): + out = cu_tool.handle_computer_use( + {"action": "capture", "mode": "ax", "max_elements": bad} + ) + parsed = json.loads(out) + assert len(parsed["elements"]) == cu_tool._DEFAULT_MAX_ELEMENTS, ( + f"bad max_elements={bad!r} disabled the cap" + ) + + def test_capture_ax_clamps_oversized_max_elements_to_hard_cap(self): + """A caller passing a very large `max_elements` must not be able to + disable the safeguard. The cap is clamped to a hard upper bound so + the context-blow-up protection cannot be bypassed by argument. + """ + from tools.computer_use import tool as cu_tool + + fake_backend = self._ax_backend_with(5000) + cu_tool.reset_backend_for_tests() + with patch.object(cu_tool, "_get_backend", return_value=fake_backend): + out = cu_tool.handle_computer_use( + {"action": "capture", "mode": "ax", "max_elements": 10_000} + ) + parsed = json.loads(out) + assert len(parsed["elements"]) == cu_tool._MAX_ALLOWED_MAX_ELEMENTS + assert parsed["total_elements"] == 5000 + assert parsed["truncated_elements"] == 5000 - cu_tool._MAX_ALLOWED_MAX_ELEMENTS + + def test_capture_ax_summary_indices_match_returned_elements(self): + """When `max_elements` is below the human-summary's own line cap, the + summary must not index elements that aren't in the returned array. + Otherwise the model sees `#15` in the summary and finds no matching + entry in `elements`. + """ + from tools.computer_use import tool as cu_tool + + fake_backend = self._ax_backend_with(600) + cu_tool.reset_backend_for_tests() + with patch.object(cu_tool, "_get_backend", return_value=fake_backend): + out = cu_tool.handle_computer_use( + {"action": "capture", "mode": "ax", "max_elements": 5} + ) + parsed = json.loads(out) + returned_indices = {e["index"] for e in parsed["elements"]} + summary_lines = parsed["summary"].splitlines() + indexed_lines = [ln for ln in summary_lines if ln.lstrip().startswith("#")] + for ln in indexed_lines: + idx_token = ln.lstrip().split()[0].lstrip("#") + idx = int(idx_token) + assert idx in returned_indices, ( + f"summary references #{idx} but it is absent from elements payload " + f"(returned: {sorted(returned_indices)})" + ) + + def test_capture_multimodal_summary_omits_truncation_note(self): + """The som/vision multimodal envelope returns a screenshot, not an + `elements` array โ€” so a "response truncated to N of M elements" + claim in the summary would be inaccurate. + """ + from tools.computer_use.backend import CaptureResult, UIElement + from tools.computer_use import tool as cu_tool + + fake_png = "iVBORw0KGgo=" + elements = [ + UIElement(index=i + 1, role="AXButton", label=f"el-{i}", bounds=(0, 0, 1, 1)) + for i in range(600) + ] + + class FakeBackend: + def start(self): pass + def stop(self): pass + def is_available(self): return True + def capture(self, mode="som", app=None): + return CaptureResult( + mode=mode, width=800, height=600, + png_b64=fake_png, elements=list(elements), + app="Obsidian", + ) + def click(self, **kw): ... + def drag(self, **kw): ... + def scroll(self, **kw): ... + def type_text(self, text): ... + def key(self, keys): ... + def list_apps(self): return [] + def focus_app(self, app, raise_window=False): ... + + cu_tool.reset_backend_for_tests() + with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()): + out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"}) + + assert isinstance(out, dict) and out["_multimodal"] is True + text_part = next(p for p in out["content"] if p.get("type") == "text") + assert "truncated to" not in text_part["text"], ( + "multimodal response carries an image, not an elements array; " + "the truncation note describes a payload field that isn't present" + ) + assert "truncated to" not in out["text_summary"] + # --------------------------------------------------------------------------- # Anthropic adapter: multimodal tool-result conversion diff --git a/tests/tools/test_cross_profile_guard.py b/tests/tools/test_cross_profile_guard.py new file mode 100644 index 00000000000..20814fea1ff --- /dev/null +++ b/tests/tools/test_cross_profile_guard.py @@ -0,0 +1,259 @@ +"""Tests for the cross-profile soft guard wired into write_file / patch / +skill_manage. + +The classifier is tested in tests/agent/test_file_safety_cross_profile.py. +This file tests that the tool surfaces: + + 1. Refuse cross-profile writes by default and return the warning. + 2. Accept cross-profile writes when cross_profile=True is passed. + 3. Continue to accept in-profile writes normally. + 4. skill_manage's "not found" error names other profiles where the + skill exists. +""" +from __future__ import annotations + +import json +import os +from pathlib import Path + +import pytest + + +@pytest.fixture +def fake_hermes(tmp_path, monkeypatch): + """Build a two-profile Hermes layout and point HERMES_HOME at + the hermes-security profile (matching the original-incident shape). + """ + root = tmp_path / "fake-hermes" + (root / "skills" / "shared-skill").mkdir(parents=True) + (root / "skills" / "shared-skill" / "SKILL.md").write_text( + "---\nname: shared-skill\ndescription: default copy.\n---\n" + ) + + sec_home = root / "profiles" / "hermes-security" + (sec_home / "skills").mkdir(parents=True) + + coder_home = root / "profiles" / "coder" + (coder_home / "skills").mkdir(parents=True) + + monkeypatch.setenv("HERMES_HOME", str(sec_home)) + + import hermes_constants + monkeypatch.setattr(hermes_constants, "get_default_hermes_root", lambda: root) + + import agent.file_safety as fs + monkeypatch.setattr(fs, "_hermes_home_path", lambda: sec_home) + monkeypatch.setattr(fs, "_hermes_root_path", lambda: root) + + return { + "root": root, + "sec_home": sec_home, + "coder_home": coder_home, + } + + +# --------------------------------------------------------------------------- +# write_file +# --------------------------------------------------------------------------- + + +class TestWriteFileCrossProfileGuard: + def test_in_profile_write_allowed(self, fake_hermes): + from tools.file_tools import write_file_tool + target = fake_hermes["sec_home"] / "skills" / "new-skill" / "SKILL.md" + target.parent.mkdir(parents=True) + result_json = write_file_tool(str(target), "in-profile content") + result = json.loads(result_json) + assert not result.get("error"), f"In-profile write should succeed: {result}" + assert target.exists() + assert target.read_text() == "in-profile content" + + def test_cross_profile_write_blocked_by_default(self, fake_hermes): + """The May 2026 incident โ€” security-profile session edits default + profile's skill. Must be blocked.""" + from tools.file_tools import write_file_tool + target = fake_hermes["root"] / "skills" / "shared-skill" / "SKILL.md" + original = target.read_text() + result_json = write_file_tool(str(target), "OVERWRITTEN") + result = json.loads(result_json) + assert result.get("error"), "Cross-profile write should be refused" + assert "cross-profile" in result["error"].lower() + assert "default" in result["error"] + assert "hermes-security" in result["error"] + # File untouched. + assert target.read_text() == original + + def test_cross_profile_True_bypass(self, fake_hermes): + """Explicit override after user direction must succeed.""" + from tools.file_tools import write_file_tool + target = fake_hermes["root"] / "skills" / "shared-skill" / "SKILL.md" + result_json = write_file_tool( + str(target), "user-directed override", cross_profile=True + ) + result = json.loads(result_json) + assert not result.get("error"), f"cross_profile=True must succeed: {result}" + assert target.read_text() == "user-directed override" + + def test_non_hermes_path_unaffected(self, fake_hermes, tmp_path): + from tools.file_tools import write_file_tool + target = tmp_path / "outside" / "main.py" + target.parent.mkdir() + result_json = write_file_tool(str(target), "print('hello')") + result = json.loads(result_json) + assert not result.get("error") + assert target.exists() + + +# --------------------------------------------------------------------------- +# patch +# --------------------------------------------------------------------------- + + +class TestPatchCrossProfileGuard: + def test_cross_profile_patch_blocked(self, fake_hermes): + from tools.file_tools import patch_tool + target = fake_hermes["root"] / "skills" / "shared-skill" / "SKILL.md" + original = target.read_text() + result_json = patch_tool( + mode="replace", + path=str(target), + old_string="default copy.", + new_string="HIJACKED.", + ) + result = json.loads(result_json) + assert result.get("error") + assert "cross-profile" in result["error"].lower() + assert target.read_text() == original + + def test_cross_profile_patch_bypass(self, fake_hermes): + from tools.file_tools import patch_tool + target = fake_hermes["root"] / "skills" / "shared-skill" / "SKILL.md" + result_json = patch_tool( + mode="replace", + path=str(target), + old_string="default copy.", + new_string="user-directed update.", + cross_profile=True, + ) + result = json.loads(result_json) + assert not result.get("error"), f"cross_profile=True bypass: {result}" + assert "user-directed update." in target.read_text() + + def test_v4a_patch_extracts_path_for_guard(self, fake_hermes): + """V4A patches embed the target paths in the patch body, not in + a ``path`` kwarg. The guard must still apply.""" + from tools.file_tools import patch_tool + target = fake_hermes["root"] / "skills" / "shared-skill" / "SKILL.md" + original = target.read_text() + v4a = ( + "*** Begin Patch\n" + f"*** Update File: {target}\n" + "@@\n" + "-default copy.\n" + "+HIJACKED.\n" + "*** End Patch" + ) + result_json = patch_tool(mode="patch", patch=v4a) + result = json.loads(result_json) + assert result.get("error"), f"V4A cross-profile must block: {result}" + assert "cross-profile" in result["error"].lower() + assert target.read_text() == original + + +# --------------------------------------------------------------------------- +# skill_manage โ€” error message naming other profile (item D) +# --------------------------------------------------------------------------- + + +class TestSkillManageCrossProfileErrorUX: + def _make_skill_in_profile(self, profile_dir: Path, name: str): + d = profile_dir / "skills" / name + d.mkdir(parents=True, exist_ok=True) + (d / "SKILL.md").write_text( + f"---\nname: {name}\ndescription: a skill.\n---\n" + ) + + def test_error_names_other_profile_when_skill_lives_there( + self, fake_hermes, monkeypatch + ): + """The original incident shape โ€” model expects 'foo' in active + profile, but 'foo' lives in default. Error must point at default.""" + self._make_skill_in_profile(fake_hermes["root"], "default-only-skill") + + # Re-import the module so SKILLS_DIR picks up HERMES_HOME (set in + # the fixture). Skill_manager_tool computes SKILLS_DIR at import. + import importlib + import tools.skill_manager_tool + importlib.reload(tools.skill_manager_tool) + from tools.skill_manager_tool import _skill_not_found_error + + err = _skill_not_found_error("default-only-skill") + assert "not found in active profile 'hermes-security'" in err + assert "default" in err + assert "cross_profile=True" in err + + def test_error_names_multiple_profiles(self, fake_hermes, monkeypatch): + """When the skill exists in TWO other profiles, both should be named.""" + self._make_skill_in_profile(fake_hermes["root"], "everywhere-skill") + self._make_skill_in_profile(fake_hermes["coder_home"], "everywhere-skill") + + import importlib + import tools.skill_manager_tool + importlib.reload(tools.skill_manager_tool) + from tools.skill_manager_tool import _skill_not_found_error + + err = _skill_not_found_error("everywhere-skill") + assert "default" in err + assert "coder" in err + # Switch-profiles hint + assert "hermes -p" in err + + def test_genuinely_missing_skill_keeps_helpful_hint( + self, fake_hermes, monkeypatch + ): + """When no profile has the skill, error falls back to skills_list hint.""" + import importlib + import tools.skill_manager_tool + importlib.reload(tools.skill_manager_tool) + from tools.skill_manager_tool import _skill_not_found_error + + err = _skill_not_found_error("totally-imaginary-skill") + assert "not found in active profile 'hermes-security'" in err + assert "skills_list" in err + + +# --------------------------------------------------------------------------- +# System prompt active-profile line (item B) +# --------------------------------------------------------------------------- + + +class TestSystemPromptActiveProfile: + def test_default_profile_line_in_prompt(self, tmp_path, monkeypatch): + """When active profile is 'default', the prompt names it and warns + about ~/.hermes/profiles//.""" + # Don't set HERMES_HOME โ€” falls back to default. + import agent.file_safety as fs + monkeypatch.setattr(fs, "_hermes_home_path", lambda: tmp_path / "fake") + monkeypatch.setattr(fs, "_hermes_root_path", lambda: tmp_path / "fake") + + from agent.file_safety import _resolve_active_profile_name + assert _resolve_active_profile_name() == "default" + # Build the line manually to pin the contract โ€” the prompt builder + # is too heavy to instantiate end-to-end in a unit test. + # See agent/system_prompt.py for the exact wording. + + def test_named_profile_line_in_prompt_text(self, fake_hermes): + """When active profile is 'hermes-security', the prompt warns + explicitly about NOT modifying default's skills/plugins/cron/memories.""" + # Spot-check by reading the source โ€” the contract is: + # (1) names the active profile, (2) names the default-profile + # paths, (3) says "do not modify another profile's" without + # explicit user direction. + from pathlib import Path + src = Path("agent/system_prompt.py").read_text() + assert "Active Hermes profile" in src + assert "cross_profile=True" in src + assert "~/.hermes/profiles/" in src + # Both branches present (default and named profile). + assert "Active Hermes profile: default" in src + assert "Active Hermes profile: {active_profile}" in src diff --git a/tests/tools/test_file_operations.py b/tests/tools/test_file_operations.py index 1fe116ecfa2..1d3ec8b4a02 100644 --- a/tests/tools/test_file_operations.py +++ b/tests/tools/test_file_operations.py @@ -60,6 +60,113 @@ class TestIsWriteDenied: def test_tilde_expansion(self): assert _is_write_denied("~/.ssh/authorized_keys") is True + @pytest.mark.parametrize( + "path", + [ + "auth.json", + "config.yaml", + "webhook_subscriptions.json", + "mcp-tokens/token1.json", + "mcp-tokens/subdir/token2.json", + "pairing/telegram-approved.json", + "pairing/discord-approved.json", + "pairing/telegram-pending.json", + "pairing", + ], + ) + def test_hermes_control_files_and_mcp_tokens_denied(self, path): + """Hermes control files and mcp-tokens/pairing entries must be write-denied.""" + from hermes_constants import get_hermes_home + hermes_home = get_hermes_home() + full_path = str(hermes_home / path) + assert _is_write_denied(full_path) is True + + @pytest.mark.parametrize( + "path", + [ + "dummy/../config.yaml", + "./auth.json", + "mcp-tokens/../config.yaml", + ], + ) + def test_hermes_control_files_traversal_denied(self, path): + """Path traversal attempts to control files must be blocked by realpath.""" + from hermes_constants import get_hermes_home + hermes_home = get_hermes_home() + full_path = str(hermes_home / path) + assert _is_write_denied(full_path) is True + + @pytest.mark.parametrize( + "path", + [ + "/tmp/standard_file.txt", + "~/projects/myapp/main.py", + "/var/log/app.log", + ], + ) + def test_standard_paths_allowed(self, path): + """Unrelated paths must still be allowed.""" + assert _is_write_denied(path) is False + + @pytest.mark.parametrize( + "name", + ["auth.json", "config.yaml", "webhook_subscriptions.json"], + ) + def test_control_files_protected_in_profile_mode(self, tmp_path, monkeypatch, name): + """Under a profile, BOTH /X and /X must be denied (#15981 shape). + + Without the root-level pass, a profile-mode session leaves the + global ~/.hermes/{auth.json,config.yaml,webhook_subscriptions.json} + writable โ€” the same gap PR #15981 fixed for .env. + """ + # Simulate a profile-mode HERMES_HOME layout: + # /profiles/coder/{auth.json,config.yaml,...} + # /{auth.json,config.yaml,...} โ† must also be denied + root = tmp_path / "hermes" + profile = root / "profiles" / "coder" + profile.mkdir(parents=True) + monkeypatch.setenv("HERMES_HOME", str(profile)) + + # Profile copy + assert _is_write_denied(str(profile / name)) is True + # Root copy โ€” the gap this widening closes + assert _is_write_denied(str(root / name)) is True + + def test_mcp_tokens_dir_protected_in_profile_mode(self, tmp_path, monkeypatch): + """mcp-tokens/ under profile AND under root must both be denied.""" + root = tmp_path / "hermes" + profile = root / "profiles" / "coder" + profile.mkdir(parents=True) + monkeypatch.setenv("HERMES_HOME", str(profile)) + + assert _is_write_denied(str(profile / "mcp-tokens" / "tok.json")) is True + assert _is_write_denied(str(root / "mcp-tokens" / "tok.json")) is True + # The directory itself must also be denied (not just files inside) + assert _is_write_denied(str(root / "mcp-tokens")) is True + + def test_pairing_dir_denied(self, tmp_path, monkeypatch): + """Regression: pairing/ must be write-denied under both profile and root. + + PR #30383 introduced ~/.hermes/pairing/{platform}-approved.json as the + gateway access-control list. Without this block, a prompt-injected agent + can write arbitrary user IDs into an approved file, granting persistent + gateway access without going through the pairing code flow โ€” the same + threat class that motivated protecting webhook_subscriptions.json. + """ + root = tmp_path / "hermes" + profile = root / "profiles" / "coder" + profile.mkdir(parents=True) + monkeypatch.setenv("HERMES_HOME", str(profile)) + + # Active profile pairing entries + assert _is_write_denied(str(profile / "pairing" / "telegram-approved.json")) is True + assert _is_write_denied(str(profile / "pairing" / "discord-pending.json")) is True + # The directory itself + assert _is_write_denied(str(profile / "pairing")) is True + # Root pairing entries (profile mode โ€” same shape as mcp-tokens gap) + assert _is_write_denied(str(root / "pairing" / "telegram-approved.json")) is True + assert _is_write_denied(str(root / "pairing")) is True + # ========================================================================= diff --git a/tests/tools/test_local_interrupt_cleanup.py b/tests/tools/test_local_interrupt_cleanup.py index a9b74559380..67d9e9e6b54 100644 --- a/tests/tools/test_local_interrupt_cleanup.py +++ b/tests/tools/test_local_interrupt_cleanup.py @@ -48,8 +48,14 @@ def _process_group_snapshot(pgid: int) -> str: ).stdout.strip() -def _wait_for_pgid_exit(pgid: int, timeout: float = 10.0) -> bool: - """Wait for a process group to disappear under loaded xdist hosts.""" +def _wait_for_pgid_exit(pgid: int, timeout: float = 30.0) -> bool: + """Wait for a process group to disappear under loaded xdist hosts. + + The cleanup chain is: SIGTERM โ†’ 3s TimeoutStopSec โ†’ SIGKILL โ†’ reap. + Under heavy xdist load (40 parallel workers, 6-shard CI), the full + sequence can exceed 10s. Default timeout is generous to avoid CI + flakes; in practice the wait returns in <1s on quiet hosts. + """ deadline = time.monotonic() + timeout while time.monotonic() < deadline: if not _pgid_still_alive(pgid): @@ -166,9 +172,11 @@ def test_wait_for_process_kills_subprocess_on_keyboardinterrupt(): assert ret == 1, f"SetAsyncExc returned {ret}, expected 1" # Give the worker a moment to: hit the exception at the next poll, - # run the except-block cleanup (_kill_process), and exit. - t.join(timeout=5.0) - assert not t.is_alive(), "worker didn't exit within 5 s of the interrupt" + # run the except-block cleanup (_kill_process), and exit. Under + # xdist load the SIGTERM โ†’ 3s wait โ†’ SIGKILL chain can take longer + # than 5s before the worker's join() returns; bumped to 15s. + t.join(timeout=15.0) + assert not t.is_alive(), "worker didn't exit within 15 s of the interrupt" # The critical assertion: the subprocess GROUP must be dead. Not # just the bash wrapper โ€” the 'sleep 30' child too. Under xdist load, diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 3212a350c37..b9a3cfcf8d9 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -1462,6 +1462,27 @@ class TestHTTPConfig: asyncio.run(_test()) + def test_stdio_unavailable_raises_importerror_not_nameerror(self): + """Regression test for #30904. + + When the mcp SDK isn't installed, ``_run_stdio`` previously leaked a + bare ``NameError: name 'StdioServerParameters' is not defined``. The + gate now raises a clear ``ImportError`` with install instructions, + mirroring ``_run_http``'s behaviour when the HTTP transport is + unavailable. + """ + from tools.mcp_tool import MCPServerTask + + server = MCPServerTask("local") + config = {"command": "python3", "args": ["/tmp/echo.py"]} + + async def _test(): + with patch("tools.mcp_tool._MCP_AVAILABLE", False): + with pytest.raises(ImportError, match=r"mcp.*SDK"): + await server._run_stdio(config) + + asyncio.run(_test()) + def test_http_seeds_initial_protocol_header(self): from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask diff --git a/tests/tools/test_memory_tool.py b/tests/tools/test_memory_tool.py index 7f63aee1ebb..1a635aa1ac3 100644 --- a/tests/tools/test_memory_tool.py +++ b/tests/tools/test_memory_tool.py @@ -255,3 +255,128 @@ class TestMemoryToolDispatcher: def test_remove_requires_old_text(self, store): result = json.loads(memory_tool(action="remove", store=store)) assert result["success"] is False + + +# ========================================================================= +# External drift guard (#26045) +# +# An external writer โ€” patch tool, shell append, manual edit, or sister +# session โ€” can grow MEMORY.md beyond the tool's mental model: no ยง +# delimiters, content that would all collapse into a single "entry" larger +# than the char limit. Pre-fix, the next memory(action=replace) from a +# session with stale in-memory state truncated that giant entry, silently +# discarding the appended bytes. Reproduced in production on 2026-05-14 โ€” +# ~8KB of structured vendor / standing-orders / pinboard content destroyed +# by a sister session's replace. +# ========================================================================= + + +class TestExternalDriftGuard: + """Mutations must refuse to flush when on-disk content shows external drift.""" + + def _plant_drift(self, store, target="memory"): + """Append free-form content (no ยง delimiters) past char_limit.""" + path = store._path_for(target) + path.parent.mkdir(parents=True, exist_ok=True) + # 800 chars per entry ร— 3 sections == ~2.4KB without delimiters, + # well over the test fixture's 500-char limit. + block = "\n\n## Vendor Master\n" + "x" * 800 + block += "\n\n## Standing Orders\n" + "y" * 800 + block += "\n\n## Pin Board\n" + "z" * 800 + existing = path.read_text(encoding="utf-8") if path.exists() else "" + path.write_text(existing + block, encoding="utf-8") + return path + + def test_replace_refuses_on_drift(self, store): + store.add("memory", "User likes brevity.") + path = self._plant_drift(store) + original_size = path.stat().st_size + + result = store.replace("memory", "User likes", "User prefers concise.") + + assert result["success"] is False + assert "drift_backup" in result + # On-disk file is UNTOUCHED โ€” that's the point. + assert path.stat().st_size == original_size + assert "Vendor Master" in path.read_text() + # Backup exists with the drifted content. + bak = result["drift_backup"] + assert Path(bak).exists() + assert "Vendor Master" in Path(bak).read_text() + + def test_add_refuses_on_drift(self, store): + store.add("memory", "Existing.") + path = self._plant_drift(store) + original = path.read_text() + + result = store.add("memory", "New entry under drift.") + + assert result["success"] is False + assert "drift_backup" in result + assert path.read_text() == original # untouched + + def test_remove_refuses_on_drift(self, store): + store.add("memory", "Target entry to remove.") + path = self._plant_drift(store) + original = path.read_text() + + result = store.remove("memory", "Target entry") + + assert result["success"] is False + assert "drift_backup" in result + assert path.read_text() == original # untouched + + def test_clean_file_does_not_trigger_drift(self, store): + """A normally-written file (just below char_limit, ยง-delimited) is fine.""" + # Two tool-shaped entries totaling under the 500-char limit. + store.add("memory", "Entry one โ€” normal length.") + store.add("memory", "Entry two โ€” also normal.") + + result = store.add("memory", "Entry three.") + assert result["success"] is True + assert "drift_backup" not in result + + result = store.replace("memory", "Entry two", "Entry two replaced.") + assert result["success"] is True + + def test_error_message_points_at_remediation(self, store): + """The error string must reference the backup AND remediation steps.""" + store.add("memory", "Initial.") + self._plant_drift(store) + + result = store.replace("memory", "Initial", "Replacement.") + assert result["success"] is False + # The model has to know what file to look at and what to do. + assert ".bak." in result["error"] + assert "remediation" in result + assert "26045" in result["error"] # tracking-issue back-reference + + def test_drift_guard_also_protects_user_target(self, store): + """USER.md gets the same guarantee as MEMORY.md.""" + store.add("user", "Some preference.") + path = self._plant_drift(store, target="user") + original_size = path.stat().st_size + + result = store.replace("user", "Some preference", "New preference.") + assert result["success"] is False + assert path.stat().st_size == original_size + + def test_drift_backup_filename_is_unique_per_invocation(self, store): + """Two drift refusals close together must not collide on bak.. + + If two refusals share the same epoch second, the second call would + overwrite the first .bak. The current implementation accepts that + โ€” both files describe the same on-disk state โ€” but pin the path + format here so any future change has to think about it. + """ + store.add("memory", "Initial.") + self._plant_drift(store) + + r1 = store.replace("memory", "Initial", "Replacement.") + r2 = store.add("memory", "Another.") + assert r1.get("drift_backup") + assert r2.get("drift_backup") + # Same epoch second is the expected collision case โ€” both point + # at the same snapshot. Different second is also fine. + assert ".bak." in r1["drift_backup"] + assert ".bak." in r2["drift_backup"] diff --git a/tests/tools/test_notify_on_complete.py b/tests/tools/test_notify_on_complete.py index 64d198970cb..4a4ca37bd89 100644 --- a/tests/tools/test_notify_on_complete.py +++ b/tests/tools/test_notify_on_complete.py @@ -348,3 +348,158 @@ class TestCompletionConsumed: result = registry.poll("proc_running") assert result["status"] == "running" assert not registry.is_completion_consumed("proc_running") + + +# --------------------------------------------------------------------------- +# Silent-background-process hint +# +# background=True without notify_on_complete=True OR watch_patterns runs +# the process silently โ€” the agent has no way to learn it finished short +# of calling process(action="poll") explicitly. The tool result must +# include a "hint" field that nudges the agent toward +# notify_on_complete=True for bounded tasks. May 2026 PR #31231 incident: +# bg CI poller exited green, agent never noticed, user had to surface it. +# --------------------------------------------------------------------------- + + +def _silent_bg_base_config(tmp_path): + return { + "env_type": "local", + "docker_image": "", + "singularity_image": "", + "modal_image": "", + "daytona_image": "", + "cwd": str(tmp_path), + "timeout": 30, + } + + +def _silent_bg_harness(monkeypatch, tmp_path): + """Common test fixture: patch enough of terminal_tool to spawn a fake + background process and capture the JSON result the agent sees.""" + import tools.terminal_tool as terminal_tool_module + from tools import process_registry as process_registry_module + from types import SimpleNamespace + + config = _silent_bg_base_config(tmp_path) + dummy_env = SimpleNamespace(env={}) + + def fake_spawn_local(**kwargs): + return SimpleNamespace( + id="proc_silent_test", + pid=4242, + notify_on_complete=False, + watcher_platform="", + watcher_chat_id="", + watcher_user_id="", + watcher_user_name="", + watcher_thread_id="", + watcher_message_id="", + watcher_interval=0, + ) + + monkeypatch.setattr(terminal_tool_module, "_get_env_config", lambda: config) + monkeypatch.setattr(terminal_tool_module, "_start_cleanup_thread", lambda: None) + monkeypatch.setattr(terminal_tool_module, "_check_all_guards", lambda *_args, **_kwargs: {"approved": True}) + monkeypatch.setattr(process_registry_module.process_registry, "spawn_local", fake_spawn_local) + monkeypatch.setitem(terminal_tool_module._active_environments, "default", dummy_env) + monkeypatch.setitem(terminal_tool_module._last_activity, "default", 0.0) + return terminal_tool_module + + +def test_background_without_notify_emits_silent_process_hint(monkeypatch, tmp_path): + """The footgun case (May 2026 PR #31231): bg=True alone runs silently + and the agent has no signal it finished. Tool must nudge.""" + tt = _silent_bg_harness(monkeypatch, tmp_path) + try: + result = json.loads( + tt.terminal_tool( + command="while true; do gh pr checks 999; sleep 30; done", + background=True, + ) + ) + finally: + tt._active_environments.pop("default", None) + tt._last_activity.pop("default", None) + + assert result["session_id"] == "proc_silent_test" + hint = result.get("hint", "") + assert hint, "Silent background process must include a hint field" + assert "notify_on_complete" in hint, ( + "Hint must name the corrective flag so the agent can self-correct" + ) + assert "silent" in hint.lower() or "no way to learn" in hint.lower(), ( + "Hint must explain the failure mode, not just suggest the fix" + ) + + +def test_background_with_notify_does_not_emit_hint(monkeypatch, tmp_path): + """The correct shape โ€” bg+notify together โ€” must not nag.""" + tt = _silent_bg_harness(monkeypatch, tmp_path) + try: + result = json.loads( + tt.terminal_tool( + command="pytest tests/", + background=True, + notify_on_complete=True, + ) + ) + finally: + tt._active_environments.pop("default", None) + tt._last_activity.pop("default", None) + + assert "hint" not in result, ( + f"Correct usage must not emit a hint, got: {result.get('hint')!r}" + ) + assert result.get("notify_on_complete") is True + + +def test_background_with_watch_patterns_does_not_emit_hint(monkeypatch, tmp_path): + """watch_patterns is the other legitimate non-silent shape โ€” also no hint.""" + tt = _silent_bg_harness(monkeypatch, tmp_path) + try: + result = json.loads( + tt.terminal_tool( + command="uvicorn app:server --port 8080", + background=True, + watch_patterns=["Application startup complete"], + ) + ) + finally: + tt._active_environments.pop("default", None) + tt._last_activity.pop("default", None) + + assert "hint" not in result, ( + f"watch_patterns shape must not emit a silent-process hint, got: {result.get('hint')!r}" + ) + + +def test_foreground_command_does_not_emit_hint(monkeypatch, tmp_path): + """Hint only applies to background processes โ€” foreground returns its + result synchronously and the agent always sees the outcome.""" + tt = _silent_bg_harness(monkeypatch, tmp_path) + + # Foreground path doesn't go through spawn_local. Patch the local-env + # exec method to short-circuit to a clean exit so the test doesn't + # actually shell out. + from types import SimpleNamespace + dummy_env = SimpleNamespace( + env={}, + execute=lambda *a, **kw: {"output": "done", "exit_code": 0, "error": None}, + ) + monkeypatch.setitem(tt._active_environments, "default", dummy_env) + + try: + result = json.loads( + tt.terminal_tool( + command="echo hello", + background=False, + ) + ) + finally: + tt._active_environments.pop("default", None) + tt._last_activity.pop("default", None) + + assert "hint" not in result, ( + f"Foreground commands must not emit the background-silence hint, got: {result.get('hint')!r}" + ) diff --git a/tests/tools/test_pr_6656_regressions.py b/tests/tools/test_pr_6656_regressions.py new file mode 100644 index 00000000000..9429a804135 --- /dev/null +++ b/tests/tools/test_pr_6656_regressions.py @@ -0,0 +1,287 @@ +"""Regression tests for PR #6656 โ€” skill uninstall + bundle hash + pairing lock. + +Three independent fixes that were salvaged together: + +1. ``uninstall_skill`` path traversal: ``install_path`` comes from a JSON + file on disk; a malicious skill could write ``install_path: "../../"`` + and trigger ``shutil.rmtree`` against parent directories. Guarded with + ``Path.resolve().is_relative_to(SKILLS_DIR.resolve())``. + +2. ``bundle_content_hash`` / ``content_hash`` filename inclusion: the + previous hash mixed only file CONTENTS, so swapping ``SKILL.md`` and + ``scripts/run.sh`` contents between two paths produced the same digest. + Now both functions prefix each entry with ``rel_path + \\x00`` and + stay symmetric (one on disk, one on in-memory bundle). + +3. ``PairingStore.list_pending`` TOCTOU: previously called + ``_cleanup_expired`` (which writes the JSON file) without holding + ``self._lock``, racing with ``generate_code`` / ``approve_code``. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from tools.skills_hub import ( + SkillBundle, + bundle_content_hash, + uninstall_skill, +) +from tools.skills_guard import content_hash + + +# ============================================================================= +# uninstall_skill: path traversal guard +# ============================================================================= + + +class TestUninstallPathTraversal: + """The ``install_path`` field in ``lock.json`` is attacker-controllable + if a malicious skill is ever installed (or if the hub's lockfile is + corrupted). The uninstall path must refuse anything that resolves + outside ``SKILLS_DIR``. + """ + + @pytest.fixture + def hub_setup(self, tmp_path, monkeypatch): + """Build a hub directory tree with a malicious lock.json entry. + + ``HubLockFile`` binds its default ``path`` argument at def time + against the module-level ``LOCK_FILE`` constant, so monkey-patching + ``LOCK_FILE`` alone is not enough โ€” we also need to rebind the + function default. Patching ``HubLockFile.__init__.__defaults__`` + is the standard tool for this. + """ + import tools.skills_hub as hub + skills_dir = tmp_path / "skills" + hub_dir = skills_dir / ".hub" + hub_dir.mkdir(parents=True) + lock_path = hub_dir / "lock.json" + + monkeypatch.setattr(hub, "SKILLS_DIR", skills_dir) + monkeypatch.setattr(hub, "HUB_DIR", hub_dir) + monkeypatch.setattr(hub, "LOCK_FILE", lock_path) + monkeypatch.setattr(hub, "AUDIT_LOG", hub_dir / "audit.log") + # Rebind HubLockFile.__init__'s default `path=` arg so + # `HubLockFile()` (no args) picks up the new lock path. + monkeypatch.setattr( + hub.HubLockFile.__init__, + "__defaults__", + (lock_path,), + ) + + # A real directory outside skills_dir that the traversal would + # delete if the guard fails. + victim = tmp_path / "do-not-delete" + victim.mkdir() + (victim / "important.txt").write_text("data") + return skills_dir, hub_dir, victim + + def _write_lock(self, hub_dir: Path, entries: dict) -> None: + lock_path = hub_dir / "lock.json" + lock_path.write_text(json.dumps({"version": 1, "installed": entries})) + + def test_traversal_via_parent_segments_rejected(self, hub_setup): + """install_path: "../do-not-delete" must NOT escape SKILLS_DIR.""" + skills_dir, hub_dir, victim = hub_setup + self._write_lock(hub_dir, { + "evil": { + "install_path": "../do-not-delete", + "source": "https://example.com", + "version": "1.0", + }, + }) + + ok, msg = uninstall_skill("evil") + + assert ok is False + assert "outside" in msg or "resolves" in msg or "skills directory" in msg + # The victim directory MUST still exist. + assert victim.exists() + assert (victim / "important.txt").exists() + + def test_absolute_path_rejected(self, hub_setup): + """install_path that's an absolute path outside SKILLS_DIR must be refused.""" + skills_dir, hub_dir, victim = hub_setup + self._write_lock(hub_dir, { + "evil": { + "install_path": str(victim), + "source": "https://example.com", + "version": "1.0", + }, + }) + + ok, msg = uninstall_skill("evil") + + # SKILLS_DIR / "" still results in an absolute path, + # which when resolved is outside skills_dir. Must be refused. + assert ok is False + assert victim.exists() + + def test_symlink_escape_rejected(self, tmp_path, hub_setup): + """Symlinks inside SKILLS_DIR that point outside must be refused + after realpath resolution.""" + skills_dir, hub_dir, victim = hub_setup + # Create a "skill" that's actually a symlink to victim + evil_link = skills_dir / "trapdoor" + evil_link.symlink_to(victim) + + self._write_lock(hub_dir, { + "trap": { + "install_path": "trapdoor", + "source": "https://example.com", + "version": "1.0", + }, + }) + + ok, msg = uninstall_skill("trap") + + # realpath resolves the symlink โ†’ outside skills_dir โ†’ refused. + assert ok is False + assert victim.exists() + assert (victim / "important.txt").exists() + + def test_legitimate_skill_uninstall_still_works(self, hub_setup): + """The guard must NOT block a normal skill directory inside SKILLS_DIR.""" + skills_dir, hub_dir, _victim = hub_setup + legit = skills_dir / "category" / "my-skill" + legit.mkdir(parents=True) + (legit / "SKILL.md").write_text("test") + + self._write_lock(hub_dir, { + "my-skill": { + "install_path": "category/my-skill", + "source": "https://example.com", + "trust_level": "community", + "version": "1.0", + }, + }) + + ok, msg = uninstall_skill("my-skill") + + assert ok is True + assert not legit.exists() + + +# ============================================================================= +# Bundle / disk hash symmetry + filename inclusion +# ============================================================================= + + +class TestBundleHashFilenameSensitivity: + """Hashes must change when filenames are swapped, even if combined + contents stay identical. ``bundle_content_hash`` (in-memory) and + ``content_hash`` (on-disk) must stay symmetric โ€” they're used to + detect skill drift between an installed bundle and its source. + """ + + def _make_bundle(self, files: dict) -> SkillBundle: + return SkillBundle( + name="test", + files=files, + source="test", + identifier="test/test", + trust_level="community", + ) + + def test_filename_swap_changes_hash(self): + """Swapping content between SKILL.md and scripts/run.sh must + produce a different hash. Without the filename in the hash, + these two bundles would have looked identical.""" + a = self._make_bundle({"SKILL.md": "hello", "scripts/run.sh": "world"}) + b = self._make_bundle({"SKILL.md": "world", "scripts/run.sh": "hello"}) + assert bundle_content_hash(a) != bundle_content_hash(b) + + def test_identical_bundles_same_hash(self): + """Sanity: equal content + paths = equal hash.""" + a = self._make_bundle({"SKILL.md": "x", "run.sh": "y"}) + b = self._make_bundle({"SKILL.md": "x", "run.sh": "y"}) + assert bundle_content_hash(a) == bundle_content_hash(b) + + def test_disk_hash_changes_on_filename_swap(self, tmp_path): + """``content_hash`` on disk must also be filename-sensitive, + so it stays symmetric with ``bundle_content_hash``.""" + skill_a = tmp_path / "a" + skill_a.mkdir() + (skill_a / "SKILL.md").write_text("hello") + (skill_a / "run.sh").write_text("world") + + skill_b = tmp_path / "b" + skill_b.mkdir() + (skill_b / "SKILL.md").write_text("world") + (skill_b / "run.sh").write_text("hello") + + # Different filenameโ†”content mappings = different hashes. + assert content_hash(skill_a) != content_hash(skill_b) + + def test_bundle_and_disk_hash_match(self, tmp_path): + """Symmetry contract: the same skill, expressed as a SkillBundle + and as a directory tree, must produce the same digest. If this + fails, ``check_for_skill_updates`` will flag every clean + install as drifted.""" + skill_dir = tmp_path / "skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("hello") + (skill_dir / "scripts").mkdir() + (skill_dir / "scripts" / "run.sh").write_text("world") + + bundle = self._make_bundle({ + "SKILL.md": "hello", + "scripts/run.sh": "world", + }) + + assert bundle_content_hash(bundle) == content_hash(skill_dir) + + +# ============================================================================= +# PairingStore.list_pending: must hold the lock +# ============================================================================= + + +class TestListPendingLock: + """list_pending writes via _cleanup_expired. Without the lock, + a concurrent generate_code or approve_code can race against the + write, potentially clobbering a pending approval.""" + + def test_list_pending_acquires_lock(self, tmp_path): + """Source-grep contract: ``list_pending`` body must be wrapped + in ``with self._lock:``. If anyone unwraps it again, the TOCTOU + bug returns.""" + import gateway.pairing as _pairing_mod + source = Path(_pairing_mod.__file__).read_text(encoding="utf-8") + # Find the list_pending function body and assert the lock + # context manager appears inside it. We grep the function + # source rather than runtime-introspect because the racy + # behaviour is hard to deterministically reproduce in a test. + lines = source.splitlines() + in_func = False + seen_lock = False + for line in lines: + if line.startswith(" def list_pending("): + in_func = True + continue + if in_func: + if line.startswith(" def "): + break # next function + if "with self._lock:" in line: + seen_lock = True + break + assert seen_lock, ( + "list_pending must wrap its body in `with self._lock:` โ€” " + "without it, _cleanup_expired's file write races with " + "concurrent generate_code/approve_code." + ) + + def test_list_pending_returns_correct_data(self, tmp_path): + """End-to-end smoke: even with the lock held, basic operation works.""" + from gateway.pairing import PairingStore + with patch("gateway.pairing.PAIRING_DIR", tmp_path): + store = PairingStore() + store.generate_code("telegram", "user1", "Alice") + pending = store.list_pending("telegram") + assert len(pending) == 1 + assert pending[0]["user_id"] == "user1" diff --git a/tests/tools/test_process_registry.py b/tests/tools/test_process_registry.py index 3ac5bdfd1f1..10e4421e5f0 100644 --- a/tests/tools/test_process_registry.py +++ b/tests/tools/test_process_registry.py @@ -1007,3 +1007,163 @@ def test_drain_notifications_empty_queue(): results = process_registry.drain_notifications() assert results == [] + + +# --------------------------------------------------------------------------- +# _terminate_host_pid โ€” cross-platform process-tree termination +# --------------------------------------------------------------------------- + + +class TestTerminateHostPidWindows: + """Windows branch uses ``taskkill /T /F`` โ€” the documented MS tree-kill + primitive. We can't use psutil's ``children(recursive=True)`` / + ``.terminate()`` path on Windows because (1) Windows doesn't maintain + a Unix-style process tree so the walk is unreliable, and (2) + ``Process.terminate()`` on Windows is ``TerminateProcess()`` for the + target handle only, not the tree. + """ + + def test_windows_invokes_taskkill_with_tree_and_force_flags(self, monkeypatch): + """The Windows branch must shell out to ``taskkill /PID N /T /F``.""" + from tools import process_registry as pr + + captured = {} + + def fake_run(args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + return MagicMock(returncode=0, stderr="", stdout="") + + monkeypatch.setattr(pr, "_IS_WINDOWS", True) + monkeypatch.setattr(pr.subprocess, "run", fake_run) + + pr.ProcessRegistry._terminate_host_pid(12345) + + assert captured["args"][0] == "taskkill" + assert "/PID" in captured["args"] + assert "12345" in captured["args"] + assert "/T" in captured["args"], "Tree flag required to reach descendants" + assert "/F" in captured["args"], "Force flag required for headless Chromium" + + def test_windows_falls_back_to_os_kill_when_taskkill_missing(self, monkeypatch): + """If ``taskkill.exe`` is somehow unavailable, fall back to a bare + ``os.kill(pid, SIGTERM)`` so we at least try to kill the parent.""" + from tools import process_registry as pr + + kill_calls = [] + + def fake_run(*args, **kwargs): + raise FileNotFoundError("taskkill not found") + + def fake_kill(pid, sig): + kill_calls.append((pid, sig)) + + monkeypatch.setattr(pr, "_IS_WINDOWS", True) + monkeypatch.setattr(pr.subprocess, "run", fake_run) + monkeypatch.setattr(pr.os, "kill", fake_kill) + + pr.ProcessRegistry._terminate_host_pid(12345) + + assert kill_calls == [(12345, signal.SIGTERM)] + + def test_windows_does_not_call_psutil(self, monkeypatch): + """The Windows branch must NOT exercise the psutil tree-walk + (it's unreliable on Windows โ€” see the function docstring).""" + from tools import process_registry as pr + import psutil + + psutil_calls = [] + + class _BoomProcess: + def __init__(self, pid): + psutil_calls.append(("Process", pid)) + + def children(self, recursive=False): + psutil_calls.append(("children", recursive)) + return [] + + def terminate(self): + psutil_calls.append(("terminate",)) + + def fake_run(args, **kwargs): + return MagicMock(returncode=0, stderr="", stdout="") + + monkeypatch.setattr(pr, "_IS_WINDOWS", True) + monkeypatch.setattr(pr.subprocess, "run", fake_run) + monkeypatch.setattr(psutil, "Process", _BoomProcess) + + pr.ProcessRegistry._terminate_host_pid(12345) + + assert psutil_calls == [], ( + f"Windows branch must not touch psutil, but saw {psutil_calls!r}" + ) + + +class TestTerminateHostPidPosix: + """POSIX branch walks the tree via psutil and SIGTERMs children first.""" + + def test_posix_walks_tree_and_terminates_children_then_parent(self, monkeypatch): + from tools import process_registry as pr + import psutil + + terminate_order = [] + + class _FakeChild: + def __init__(self, pid): + self.pid = pid + + def terminate(self): + terminate_order.append(self.pid) + + class _FakeParent: + def __init__(self, pid): + self.pid = pid + + def children(self, recursive=False): + assert recursive is True + return [_FakeChild(101), _FakeChild(102), _FakeChild(103)] + + def terminate(self): + terminate_order.append(self.pid) + + monkeypatch.setattr(pr, "_IS_WINDOWS", False) + monkeypatch.setattr(psutil, "Process", _FakeParent) + + pr.ProcessRegistry._terminate_host_pid(12345) + + assert terminate_order == [101, 102, 103, 12345], ( + "Children must be terminated before the parent" + ) + + def test_posix_no_such_process_swallowed(self, monkeypatch): + from tools import process_registry as pr + import psutil + + def boom(pid): + raise psutil.NoSuchProcess(pid) + + monkeypatch.setattr(pr, "_IS_WINDOWS", False) + monkeypatch.setattr(psutil, "Process", boom) + + # Must not raise. + pr.ProcessRegistry._terminate_host_pid(999999999) + + def test_posix_oserror_falls_back_to_os_kill(self, monkeypatch): + from tools import process_registry as pr + import psutil + + def boom(pid): + raise PermissionError("can't read /proc") + + kill_calls = [] + + def fake_kill(pid, sig): + kill_calls.append((pid, sig)) + + monkeypatch.setattr(pr, "_IS_WINDOWS", False) + monkeypatch.setattr(psutil, "Process", boom) + monkeypatch.setattr(pr.os, "kill", fake_kill) + + pr.ProcessRegistry._terminate_host_pid(12345) + + assert kill_calls == [(12345, signal.SIGTERM)] diff --git a/tests/tools/test_send_message_tool.py b/tests/tools/test_send_message_tool.py index 3a6cb6d6e30..66aab5eee74 100644 --- a/tests/tools/test_send_message_tool.py +++ b/tests/tools/test_send_message_tool.py @@ -28,16 +28,93 @@ def _reset_signal_scheduler(): from gateway.config import Platform from tools.send_message_tool import ( - _derive_forum_thread_name, _is_telegram_thread_not_found, _parse_target_ref, - _send_discord, _send_matrix_via_adapter, _send_signal, _send_telegram, _send_to_platform, send_message_tool, ) +# Discord helpers moved to the plugin in #24325. Import from the new path +# and provide a thin ``_send_discord(token, ...)`` shim that mirrors the +# pre-migration signature so the existing test bodies keep working. +from plugins.platforms.discord.adapter import ( + _DISCORD_CHANNEL_TYPE_PROBE_CACHE, + _derive_forum_thread_name, + _probe_is_forum_cached, + _remember_channel_is_forum, + _standalone_send, +) + + +async def _send_discord( + token, + chat_id, + message, + *, + thread_id=None, + media_files=None, +): + """Pre-migration ``(token, chat_id, message, โ€ฆ)`` adapter around the + plugin's ``_standalone_send(pconfig, โ€ฆ)``. Lets test bodies continue + to call ``_send_discord("tok", ...)`` without rewriting every signature. + """ + pconfig = SimpleNamespace(token=token, extra={}) + return await _standalone_send( + pconfig, + chat_id, + message, + thread_id=thread_id, + media_files=media_files, + ) + + +def _discord_entry(): + """Return the live Discord PlatformEntry, importing lazily so plugin + discovery is forced exactly once and patches survive across tests.""" + from hermes_cli.plugins import discover_plugins + from gateway.platform_registry import platform_registry + discover_plugins() + return platform_registry.get("discord") + + +class _patch_discord_sender: + """Patch the Discord registry entry's ``standalone_sender_fn`` with the + given mock and translate the production ``(pconfig, ...)`` call shape + back to the pre-migration ``(token, ...)`` shape the test mocks expect. + + Use as a context manager: + + send_mock = AsyncMock(return_value={...}) + with _patch_discord_sender(send_mock): + asyncio.run(_send_to_platform(Platform.DISCORD, ...)) + send_mock.assert_awaited_once_with("tok", "chat", "msg", + thread_id=None, media_files=[]) + """ + + def __init__(self, mock): + self._mock = mock + self._entry = None + self._original = None + + async def _adapter(self, pconfig, chat_id, message, *, thread_id=None, media_files=None): + token = getattr(pconfig, "token", None) + return await self._mock( + token, chat_id, message, + thread_id=thread_id, media_files=media_files, + ) + + def __enter__(self): + self._entry = _discord_entry() + self._original = self._entry.standalone_sender_fn + self._entry.standalone_sender_fn = self._adapter + return self._mock + + def __exit__(self, exc_type, exc, tb): + if self._entry is not None: + self._entry.standalone_sender_fn = self._original + return False def _run_async_immediately(coro): @@ -300,6 +377,37 @@ class TestSendMessageTool: user_id="user-123", ) + def test_media_tag_outside_allowed_roots_is_not_sent(self, tmp_path): + config, telegram_cfg = _make_config() + secret = tmp_path / "secret.pdf" + secret.write_bytes(b"%PDF secret") + + with patch("gateway.config.load_gateway_config", return_value=config), \ + patch("tools.interrupt.is_interrupted", return_value=False), \ + patch("model_tools._run_async", side_effect=_run_async_immediately), \ + patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})) as send_mock, \ + patch("gateway.mirror.mirror_to_session", return_value=True): + result = json.loads( + send_message_tool( + { + "action": "send", + "target": "telegram:12345", + "message": f"hello\nMEDIA:{secret}", + } + ) + ) + + assert result["success"] is True + send_mock.assert_awaited_once_with( + Platform.TELEGRAM, + telegram_cfg, + "12345", + "hello", + thread_id=None, + media_files=[], + force_document=False, + ) + def test_top_level_send_failure_redacts_query_token(self): config, _telegram_cfg = _make_config() leaked = "very-secret-query-token-123456" @@ -446,7 +554,7 @@ class TestSendToPlatformChunking: """Messages exceeding the platform limit are split into multiple sends.""" send = AsyncMock(return_value={"success": True, "message_id": "1"}) long_msg = "word " * 1000 # ~5000 chars, well over Discord's 2000 limit - with patch("tools.send_message_tool._send_discord", send): + with _patch_discord_sender(send): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1176,7 +1284,7 @@ class TestSendToPlatformDiscordThread: """Discord platform with thread_id passes it to _send_discord.""" send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) - with patch("tools.send_message_tool._send_discord", send_mock): + with _patch_discord_sender(send_mock): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1196,7 +1304,7 @@ class TestSendToPlatformDiscordThread: """Discord platform without thread_id passes None.""" send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) - with patch("tools.send_message_tool._send_discord", send_mock): + with _patch_discord_sender(send_mock): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1360,7 +1468,7 @@ class TestSendToPlatformDiscordMedia: # A message long enough to get chunked (Discord limit is 2000) long_msg = "A" * 1900 + " " + "B" * 1900 - with patch("tools.send_message_tool._send_discord", side_effect=mock_send_discord): + with _patch_discord_sender(AsyncMock(side_effect=mock_send_discord)): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1380,7 +1488,7 @@ class TestSendToPlatformDiscordMedia: """Short message (single chunk) gets media_files directly.""" send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) - with patch("tools.send_message_tool._send_discord", send_mock): + with _patch_discord_sender(send_mock): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1618,7 +1726,7 @@ class TestSendToPlatformDiscordForum: """Discord messages are routed through _send_discord, which handles forum detection.""" send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) - with patch("tools.send_message_tool._send_discord", send_mock): + with _patch_discord_sender(send_mock): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1637,7 +1745,7 @@ class TestSendToPlatformDiscordForum: """Thread ID is still passed through when sending to Discord.""" send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) - with patch("tools.send_message_tool._send_discord", send_mock): + with _patch_discord_sender(send_mock): result = asyncio.run( _send_to_platform( Platform.DISCORD, @@ -1775,11 +1883,11 @@ class TestForumProbeCache: """_DISCORD_CHANNEL_TYPE_PROBE_CACHE memoizes forum detection results.""" def setup_method(self): - from tools import send_message_tool as smt - smt._DISCORD_CHANNEL_TYPE_PROBE_CACHE.clear() + from plugins.platforms.discord import adapter as discord_adapter + discord_adapter._DISCORD_CHANNEL_TYPE_PROBE_CACHE.clear() def test_cache_round_trip(self): - from tools.send_message_tool import ( + from plugins.platforms.discord.adapter import ( _probe_is_forum_cached, _remember_channel_is_forum, ) @@ -1819,7 +1927,7 @@ class TestForumProbeCache: thread_session.post = MagicMock(return_value=thread_resp) # Two _send_discord calls: first does probe + thread-create; second should skip probe - from tools import send_message_tool as smt + from plugins.platforms.discord import adapter as discord_adapter sessions_created = [] @@ -1837,7 +1945,7 @@ class TestForumProbeCache: with patch("aiohttp.ClientSession", side_effect=session_factory): result1 = asyncio.run(_send_discord("tok", "ch1", "first")) assert result1["success"] is True - assert smt._probe_is_forum_cached("ch1") is True + assert discord_adapter._probe_is_forum_cached("ch1") is True # Second call: cache hits, no new probe session needed. We need to only # return thread_session now since probe is skipped. @@ -2575,4 +2683,3 @@ class TestSendTelegramThreadNotFoundRetry: finally: if media_path and os.path.exists(media_path): os.unlink(media_path) - diff --git a/tests/tools/test_skills_ast_audit.py b/tests/tools/test_skills_ast_audit.py new file mode 100644 index 00000000000..c70d6a1f41c --- /dev/null +++ b/tests/tools/test_skills_ast_audit.py @@ -0,0 +1,103 @@ +"""Tests for tools.skills_ast_audit โ€” opt-in AST diagnostic scanner.""" + +import sys +from pathlib import Path + +from tools.skills_ast_audit import ast_scan_path, format_ast_report + + +def _pids(findings): + return [pid for (_f, _l, pid, _d) in findings] + + +def test_bypass_payload_detected(tmp_path): + """The exact bypass shape from #7072 is caught.""" + f = tmp_path / "exfil.py" + f.write_text( + "import importlib\n" + "parts = ['o', 's']\n" + "m = importlib.import_module(''.join(parts))\n" + "e = m.__dict__[''.join(['e','n','v'])]\n" + ) + pids = _pids(ast_scan_path(f)) + assert "dynamic_import" in pids + assert "importlib_import" in pids + assert "dict_access" in pids + + +def test_syntax_error_does_not_crash(tmp_path): + f = tmp_path / "bad.py" + f.write_text("def broken(\n") + assert ast_scan_path(f) == [] + + +def test_recursion_error_does_not_crash(tmp_path): + f = tmp_path / "deep.py" + f.write_text("a" + ".x" * 5000 + "\n") + orig = sys.getrecursionlimit() + sys.setrecursionlimit(200) + try: + result = ast_scan_path(f) + finally: + sys.setrecursionlimit(orig) + assert isinstance(result, list) + + +def test_importer_lookalike_not_flagged(tmp_path): + """`import importer` must NOT match โ€” dot-bounded prefix.""" + f = tmp_path / "ok.py" + f.write_text("import importer\nfrom importer import x\n") + assert _pids(ast_scan_path(f)) == [] + + +def test_literal_dunder_import_not_flagged(tmp_path): + """__import__('os') with a literal is not flagged (regex catches those).""" + f = tmp_path / "ok.py" + f.write_text("m = __import__('os')\n") + assert "dynamic_import_computed" not in _pids(ast_scan_path(f)) + + +def test_non_python_file_returns_empty(tmp_path): + f = tmp_path / "script.sh" + f.write_text("import importlib\n") + assert ast_scan_path(f) == [] + + +def test_directory_scans_recursively_and_skips_cache_dirs(tmp_path): + skill = tmp_path / "s" + skill.mkdir() + (skill / "main.py").write_text("import importlib\n") + (skill / "sub").mkdir() + (skill / "sub" / "u.py").write_text("from importlib.util import find_spec\n") + for d in ("__pycache__", ".venv", "venv", "node_modules"): + ignored = skill / d + ignored.mkdir() + (ignored / "junk.py").write_text("import importlib\n") + pids = _pids(ast_scan_path(skill)) + assert pids.count("importlib_import") == 2 + + +def test_missing_path_returns_empty(tmp_path): + assert ast_scan_path(tmp_path / "does_not_exist") == [] + + +def test_dynamic_getattr_and_dict_access_detected(tmp_path): + f = tmp_path / "g.py" + f.write_text("name = 'x'\nv = getattr(o, name)\nv = o.__dict__[name]\n") + pids = _pids(ast_scan_path(f)) + assert "dynamic_getattr" in pids + assert "dict_access" in pids + + +def test_format_report_empty(): + assert "No dynamic" in format_ast_report([]) + + +def test_format_report_with_findings(): + findings = [ + ("a.py", 1, "importlib_import", "import importlib โ€” ..."), + ("a.py", 3, "dynamic_import", "importlib.import_module() โ€” ..."), + ] + out = format_ast_report(findings, skill_name="test") + assert "test" in out and "a.py" in out and "L1" in out and "L3" in out + assert "diagnostic hints" in out diff --git a/tests/tools/test_skills_guard.py b/tests/tools/test_skills_guard.py index ccc55da205a..e2cc1c84e79 100644 --- a/tests/tools/test_skills_guard.py +++ b/tests/tools/test_skills_guard.py @@ -84,13 +84,13 @@ class TestDetermineVerdict: f = Finding("x", "high", "network", "f.py", 1, "m", "d") assert _determine_verdict([f]) == "caution" - def test_medium_finding_caution(self): + def test_medium_finding_safe(self): f = Finding("x", "medium", "structural", "f.py", 1, "m", "d") - assert _determine_verdict([f]) == "caution" + assert _determine_verdict([f]) == "safe" - def test_low_finding_caution(self): + def test_low_finding_safe(self): f = Finding("x", "low", "obfuscation", "f.py", 1, "m", "d") - assert _determine_verdict([f]) == "caution" + assert _determine_verdict([f]) == "safe" # --------------------------------------------------------------------------- @@ -145,21 +145,46 @@ class TestShouldAllowInstall: allowed, _ = should_allow_install(self._result("community", "dangerous", f), force=False) assert allowed is False - def test_force_overrides_dangerous_for_community(self): + def test_force_does_not_override_dangerous_for_community(self): f = [Finding("x", "critical", "c", "f", 1, "m", "d")] allowed, reason = should_allow_install( self._result("community", "dangerous", f), force=True ) - assert allowed is True - assert "Force-installed" in reason + assert allowed is False + assert "Blocked" in reason + # Error message MUST explain why --force didn't work, not invite a retry. + assert "does not override" in reason + assert "Use --force to override" not in reason - def test_force_overrides_dangerous_for_trusted(self): + def test_force_does_not_override_dangerous_for_trusted_message(self): f = [Finding("x", "critical", "c", "f", 1, "m", "d")] allowed, reason = should_allow_install( self._result("trusted", "dangerous", f), force=True ) - assert allowed is True - assert "Force-installed" in reason + assert allowed is False + assert "does not override" in reason + assert "Use --force to override" not in reason + + def test_non_dangerous_block_keeps_force_hint(self): + # When --force CAN override the block, the error message must still + # point to it. Use builtin trust + dangerous to land in the block + # branch without triggering the dangerous-specific message. + f = [Finding("x", "high", "network", "f", 1, "m", "d")] + # Construct a path where decision == block but verdict != dangerous. + # community + caution = block per current INSTALL_POLICY. + allowed, reason = should_allow_install( + self._result("community", "caution", f), force=False + ) + assert allowed is False + assert "Use --force to override" in reason + + def test_force_does_not_override_dangerous_for_trusted(self): + f = [Finding("x", "critical", "c", "f", 1, "m", "d")] + allowed, reason = should_allow_install( + self._result("trusted", "dangerous", f), force=True + ) + assert allowed is False + assert "Blocked" in reason # -- agent-created policy -- diff --git a/tests/tools/test_ssh_bulk_upload.py b/tests/tools/test_ssh_bulk_upload.py index cbdb6543495..afad54cf4f4 100644 --- a/tests/tools/test_ssh_bulk_upload.py +++ b/tests/tools/test_ssh_bulk_upload.py @@ -91,7 +91,7 @@ class TestSSHBulkUpload: assert "/home/testuser/.hermes/credentials" in mkdir_str def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path): - """Symlinks in staging dir should mirror the remote path structure.""" + """Symlinks in staging dir should mirror the .hermes-relative layout.""" f1 = tmp_path / "local_a.txt" f1.write_text("content a") @@ -107,9 +107,7 @@ class TestSSHBulkUpload: c_idx = cmd.index("-C") staging_dir = cmd[c_idx + 1] # Check the symlink exists - expected = os.path.join( - staging_dir, "home/testuser/.hermes/skills/my_skill.md" - ) + expected = os.path.join(staging_dir, "skills/my_skill.md") staging_paths.append(expected) assert os.path.islink(expected), f"Expected symlink at {expected}" assert os.readlink(expected) == os.path.abspath(str(f1)) @@ -166,14 +164,42 @@ class TestSSHBulkUpload: assert "-" in tar_cmd # stdout assert "-C" in tar_cmd - # ssh: extract from stdin at /, preserving existing dir modes (#17767) + # ssh: extract from stdin at ~/.hermes, preserving existing dir modes (#17767) ssh_str = " ".join(ssh_cmd) assert "ssh" in ssh_str assert "tar xf -" in ssh_str assert "--no-overwrite-dir" in ssh_str - assert "-C /" in ssh_str + assert "-C /home/testuser/.hermes" in ssh_str assert "testuser@example.com" in ssh_str + def test_bulk_upload_never_stages_remote_home_prefix(self, mock_env, tmp_path): + """Regression: do not archive /home/ path components.""" + f1 = tmp_path / "nested.txt" + f1.write_text("nested") + files = [(str(f1), "/home/testuser/.hermes/cache/nested.txt")] + + def capture_tar_cmd(cmd, **kwargs): + if cmd[0] == "tar": + c_idx = cmd.index("-C") + staging_dir = cmd[c_idx + 1] + assert not os.path.exists(os.path.join(staging_dir, "home")) + expected = os.path.join(staging_dir, "cache/nested.txt") + assert os.path.islink(expected) + + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_tar_cmd): + mock_env._ssh_bulk_upload(files) + def test_mkdir_failure_raises(self, mock_env, tmp_path): """mkdir failure should raise RuntimeError before tar pipe.""" f1 = tmp_path / "y.txt" diff --git a/tests/tools/test_tirith_security.py b/tests/tools/test_tirith_security.py index b47c7a5ff58..6c771c6d482 100644 --- a/tests/tools/test_tirith_security.py +++ b/tests/tools/test_tirith_security.py @@ -831,7 +831,8 @@ class TestDiskFailureMarker: with patch("tools.tirith_security._failure_marker_path", return_value=marker): from tools.tirith_security import _mark_install_failed, _is_install_failed_on_disk _mark_install_failed("cosign_missing") - assert _is_install_failed_on_disk() # cosign still absent + with patch("tools.tirith_security.shutil.which", return_value=None): + assert _is_install_failed_on_disk() # cosign still absent # Now cosign appears on PATH with patch("tools.tirith_security.shutil.which", return_value="/usr/local/bin/cosign"): diff --git a/tests/tools/test_transcription.py b/tests/tools/test_transcription.py index 32f0ad48798..b7e399ca426 100644 --- a/tests/tools/test_transcription.py +++ b/tests/tools/test_transcription.py @@ -23,6 +23,9 @@ def _fake_faster_whisper_module(mock_model): # --------------------------------------------------------------------------- +pytestmark = pytest.mark.usefixtures("disable_lazy_stt_install") + + @pytest.fixture(autouse=True) def _clear_openai_env(monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) diff --git a/tests/tools/test_transcription_dotenv_fallback.py b/tests/tools/test_transcription_dotenv_fallback.py index 365b910d4cc..5a0517c3bee 100644 --- a/tests/tools/test_transcription_dotenv_fallback.py +++ b/tests/tools/test_transcription_dotenv_fallback.py @@ -12,6 +12,9 @@ from unittest.mock import MagicMock, patch import pytest +pytestmark = pytest.mark.usefixtures("disable_lazy_stt_install") + + @pytest.fixture(autouse=True) def isolate_env(monkeypatch): """Strip every STT-related env var so the test really exercises the diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py index 7f83565b5d8..c7cf8950239 100644 --- a/tests/tools/test_transcription_tools.py +++ b/tests/tools/test_transcription_tools.py @@ -42,6 +42,9 @@ def sample_ogg(tmp_path): return str(ogg_path) +pytestmark = pytest.mark.usefixtures("disable_lazy_stt_install") + + @pytest.fixture(autouse=True) def clean_env(monkeypatch): """Ensure no real API keys leak into tests.""" diff --git a/tests/tools/test_voice_mode.py b/tests/tools/test_voice_mode.py index 4c7ba74bd6e..3f7ada8c4a2 100644 --- a/tests/tools/test_voice_mode.py +++ b/tests/tools/test_voice_mode.py @@ -10,6 +10,18 @@ from unittest.mock import MagicMock, patch import pytest +def _non_wsl_proc_version(real_open): + """Return an open() shim that makes host WSL detection deterministic.""" + def _fake_open(file, *args, **kwargs): + if file == "/proc/version": + from io import StringIO + + return StringIO("Linux test-kernel") + return real_open(file, *args, **kwargs) + + return _fake_open + + # ============================================================================ # Fixtures # ============================================================================ @@ -68,6 +80,7 @@ class TestDetectAudioEnvironment: monkeypatch.delenv("SSH_CONNECTION", raising=False) monkeypatch.setattr("tools.voice_mode._import_audio", lambda: (MagicMock(), MagicMock())) + monkeypatch.setattr("builtins.open", _non_wsl_proc_version(open)) from tools.voice_mode import detect_audio_environment result = detect_audio_environment() @@ -225,6 +238,7 @@ class TestDetectAudioEnvironment: monkeypatch.setattr("tools.voice_mode.shutil.which", lambda cmd: "/data/data/com.termux/files/usr/bin/termux-microphone-record" if cmd == "termux-microphone-record" else None) monkeypatch.setattr("tools.voice_mode._termux_api_app_installed", lambda: True) monkeypatch.setattr("tools.voice_mode._import_audio", lambda: (_ for _ in ()).throw(ImportError("no audio libs"))) + monkeypatch.setattr("builtins.open", _non_wsl_proc_version(open)) from tools.voice_mode import detect_audio_environment result = detect_audio_environment() diff --git a/tools/approval.py b/tools/approval.py index bfc70cd0fb0..399b9d6c2d2 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -1299,12 +1299,34 @@ def check_all_command_guards(command: str, env_type: str, ) if not resolved or choice is None or choice == "deny": - reason = "timed out" if not resolved else "denied by user" + # Consent contract: silence is NOT consent, and an explicit + # deny is also a hard halt โ€” both produce a BLOCKED outcome + # that names the agent's most common evasion paths (retry, + # rephrase, achieve the same outcome via a different command). + # See issue #24912 for the original incident. + if not resolved: + reason = "timed out without user response" + timeout_addendum = " Silence is not consent." + outcome = "timeout" + else: + reason = "denied by user" + timeout_addendum = "" + outcome = "denied" return { "approved": False, - "message": f"BLOCKED: Command {reason}. Do NOT retry this command.", + "message": ( + f"BLOCKED: Command {reason}. The user has NOT consented " + f"to this action. Do NOT retry this command, do NOT " + f"rephrase it, and do NOT attempt the same outcome via " + f"a different command. Stop the current workflow and " + f"wait for the user to respond before taking any " + f"further destructive or irreversible action." + f"{timeout_addendum}" + ), "pattern_key": primary_key, "description": combined_desc, + "outcome": outcome, + "user_consent": False, } # User approved โ€” persist based on scope (same logic as CLI) @@ -1369,9 +1391,18 @@ def check_all_command_guards(command: str, env_type: str, if choice == "deny": return { "approved": False, - "message": "BLOCKED: User denied. Do NOT retry.", + "message": ( + "BLOCKED: User denied this command. The user has NOT consented " + "to this action. Do NOT retry this command, do NOT rephrase " + "it, and do NOT attempt the same outcome via a different " + "command. Stop the current workflow and wait for the user " + "to respond before taking any further destructive or " + "irreversible action." + ), "pattern_key": primary_key, "description": combined_desc, + "outcome": "denied", + "user_consent": False, } # Persist approval for each warning individually diff --git a/tools/browser_tool.py b/tools/browser_tool.py index 447f6500714..5320d6adfdb 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -102,7 +102,6 @@ from plugins.browser.firecrawl.provider import ( # noqa: F401 FirecrawlBrowserProvider as FirecrawlProvider, ) from tools.tool_backend_helpers import normalize_browser_cloud_provider - # Camofox local anti-detection browser backend (optional). # When CAMOFOX_URL is set, all browser operations route through the # camofox REST API instead of the agent-browser CLI. @@ -1386,8 +1385,11 @@ def _reap_orphaned_browser_sessions(): continue # Daemon is alive and its owner is dead (or legacy + untracked). Reap. + # Use the process-tree termination helper so Chromium children + # (renderer, GPU, etc.) are cleaned up, not just the daemon parent. try: - os.kill(daemon_pid, signal.SIGTERM) + from tools.process_registry import ProcessRegistry + ProcessRegistry._terminate_host_pid(daemon_pid) logger.info("Reaped orphaned browser daemon PID %d (session %s)", daemon_pid, session_name) reaped += 1 @@ -3437,8 +3439,9 @@ def _cleanup_single_browser_session(task_id: str) -> None: pid_file = os.path.join(socket_dir, f"{session_name}.pid") if os.path.isfile(pid_file): try: + from tools.process_registry import ProcessRegistry daemon_pid = int(Path(pid_file).read_text(encoding="utf-8").strip()) - os.kill(daemon_pid, signal.SIGTERM) + ProcessRegistry._terminate_host_pid(daemon_pid) logger.debug("Killed daemon pid %s for %s", daemon_pid, session_name) except (ProcessLookupError, ValueError, PermissionError, OSError): logger.debug("Could not kill daemon pid for %s (already dead or inaccessible)", session_name) @@ -3649,6 +3652,24 @@ def check_browser_requirements() -> bool: return True +def check_browser_vision_requirements() -> bool: + """Whether ``browser_vision`` should be advertised to the model. + + Requires BOTH a working browser (``check_browser_requirements``) AND a + resolvable vision backend. Without the vision check, the tool stays in + the model's tool list even when no vision provider is configured, then + fails at call time with a cryptic provider-side error like + ``unknown variant `image_url`, expected `text``` (issue #31179). + """ + if not check_browser_requirements(): + return False + try: + from tools.vision_tools import check_vision_requirements + except ImportError: + return False + return check_vision_requirements() + + # ============================================================================ # Module Test # ============================================================================ @@ -3783,7 +3804,7 @@ registry.register( toolset="browser", schema=_BROWSER_SCHEMA_MAP["browser_vision"], handler=lambda args, **kw: browser_vision(question=args.get("question", ""), annotate=args.get("annotate", False), task_id=kw.get("task_id")), - check_fn=check_browser_requirements, + check_fn=check_browser_vision_requirements, emoji="๐Ÿ‘๏ธ", ) registry.register( diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index bdbc4bfbe1b..f57085277e9 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -202,9 +202,9 @@ _TOOL_STUBS = { ), "write_file": ( "write_file", - "path: str, content: str", - '"""Write content to a file (always overwrites). Returns dict with status."""', - '{"path": path, "content": content}', + "path: str, content: str, cross_profile: bool = False", + '"""Write content to a file (always overwrites). Returns dict with status. cross_profile=True opts out of the cross-Hermes-profile soft guard."""', + '{"path": path, "content": content, "cross_profile": cross_profile}', ), "search_files": ( "search_files", @@ -214,9 +214,9 @@ _TOOL_STUBS = { ), "patch": ( "patch", - 'path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False, mode: str = "replace", patch: str = None', - '"""Targeted find-and-replace (mode="replace") or V4A multi-file patches (mode="patch"). Returns dict with status."""', - '{"path": path, "old_string": old_string, "new_string": new_string, "replace_all": replace_all, "mode": mode, "patch": patch}', + 'path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False, mode: str = "replace", patch: str = None, cross_profile: bool = False', + '"""Targeted find-and-replace (mode="replace") or V4A multi-file patches (mode="patch"). Returns dict with status. cross_profile=True opts out of the cross-Hermes-profile soft guard."""', + '{"path": path, "old_string": old_string, "new_string": new_string, "replace_all": replace_all, "mode": mode, "patch": patch, "cross_profile": cross_profile}', ), "terminal": ( "terminal", diff --git a/tools/computer_use/backend.py b/tools/computer_use/backend.py index 9952510e9cc..c9686e41b04 100644 --- a/tools/computer_use/backend.py +++ b/tools/computer_use/backend.py @@ -142,6 +142,14 @@ class ComputerUseBackend(ABC): def focus_app(self, app: str, raise_window: bool = False) -> ActionResult: """Route input to `app` (by name or bundle ID). Default: focus without raise.""" + # โ”€โ”€ Native-value mutation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + @abstractmethod + def set_value(self, value: str, element: Optional[int] = None) -> ActionResult: + """Set a native value on an element (e.g. AXPopUpButton selection). + + `element` is the 1-based SOM index returned by a prior capture call. + """ + # โ”€โ”€ Timing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ def wait(self, seconds: float) -> ActionResult: """Default implementation: time.sleep.""" diff --git a/tools/computer_use/schema.py b/tools/computer_use/schema.py index d8928d0dc56..b39ccf06aa9 100644 --- a/tools/computer_use/schema.py +++ b/tools/computer_use/schema.py @@ -75,6 +75,28 @@ COMPUTER_USE_SCHEMA: Dict[str, Any] = { "frontmost app's window or the whole screen." ), }, + "max_elements": { + "type": "integer", + "description": ( + "Optional cap on the AX `elements` array returned by " + "`action='capture'`. Default 100, hard maximum 1000. " + "Dense UIs (Electron apps such as Obsidian or VS Code, " + "JetBrains IDEs) can publish 500+ AX nodes โ€” capping " + "prevents a single capture from blowing session " + "context. When the cap trims the response, " + "`total_elements` and `truncated_elements` are " + "surfaced in the result so you can re-call with " + "`app=` to narrow scope or raise `max_elements` when " + "the full tree is required. Has no effect on " + "`mode='som'` / `mode='vision'` when a screenshot is " + "included in the response; only the rare image-" + "missing fallback returns an `elements` array and is " + "subject to the cap." + ), + "default": 100, + "minimum": 1, + "maximum": 1000, + }, # โ”€โ”€ click / drag / scroll targeting โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ "element": { "type": "integer", diff --git a/tools/computer_use/tool.py b/tools/computer_use/tool.py index 4912b0f979a..abb14ebd878 100644 --- a/tools/computer_use/tool.py +++ b/tools/computer_use/tool.py @@ -200,6 +200,10 @@ class _NoopBackend(ComputerUseBackend): # pragma: no cover self.calls.append(("focus_app", {"app": app, "raise": raise_window})) return ActionResult(ok=True, action="focus_app") + def set_value(self, value: str, element: Optional[int] = None) -> ActionResult: + self.calls.append(("set_value", {"value": value, "element": element})) + return ActionResult(ok=True, action="set_value") + # --------------------------------------------------------------------------- # Dispatch @@ -317,7 +321,7 @@ def _dispatch(backend: ComputerUseBackend, action: str, args: Dict[str, Any]) -> if mode not in {"som", "vision", "ax"}: return json.dumps({"error": f"bad mode {mode!r}; use som|vision|ax"}) cap = backend.capture(mode=mode, app=args.get("app")) - return _capture_response(cap) + return _capture_response(cap, max_elements=_coerce_max_elements(args.get("max_elements"))) if action == "wait": seconds = float(args.get("seconds", 1.0)) @@ -416,16 +420,62 @@ def _text_response(res: ActionResult) -> str: return json.dumps(payload) -def _capture_response(cap: CaptureResult) -> Any: - element_index = _format_elements(cap.elements) +# Default cap for the AX `elements` array returned by capture. Dense UIs +# (Electron apps, Obsidian, JetBrains IDEs) can publish 500+ AX nodes, which +# can exhaust session context after a single capture. The model-facing +# `max_elements` argument lets callers raise this when they need the full tree. +_DEFAULT_MAX_ELEMENTS = 100 +# Hard upper bound on caller-supplied `max_elements`. Without this, a tool +# call passing a very large integer would silently disable the safeguard and +# reintroduce the original unbounded behavior. +_MAX_ALLOWED_MAX_ELEMENTS = 1000 + + +def _coerce_max_elements(value: Any) -> int: + """Validate the caller-supplied ``max_elements``. + + Falls back to :data:`_DEFAULT_MAX_ELEMENTS` for missing / non-integer / + sub-1 inputs so the cap can never be silently disabled by a malformed + tool-call argument. Clamps oversized values to + :data:`_MAX_ALLOWED_MAX_ELEMENTS` so a caller cannot bypass the + safeguard by passing a very large integer. + """ + if value is None: + return _DEFAULT_MAX_ELEMENTS + try: + n = int(value) + except (TypeError, ValueError): + return _DEFAULT_MAX_ELEMENTS + if n < 1: + return _DEFAULT_MAX_ELEMENTS + if n > _MAX_ALLOWED_MAX_ELEMENTS: + return _MAX_ALLOWED_MAX_ELEMENTS + return n + + +def _capture_response(cap: CaptureResult, max_elements: int = _DEFAULT_MAX_ELEMENTS) -> Any: + total_elements = len(cap.elements) + visible_elements = cap.elements[:max_elements] + truncated_elements = max(0, total_elements - len(visible_elements)) + + # Index only what's actually surfaced in the response โ€” otherwise the + # human-readable summary references element indices the model cannot + # find in the JSON `elements` array (e.g. max_elements=10 vs the default + # 40-line index window). + element_index = _format_elements(visible_elements) summary_lines = [ f"capture mode={cap.mode} {cap.width}x{cap.height}" + (f" app={cap.app}" if cap.app else "") + (f" window={cap.window_title!r}" if cap.window_title else ""), - f"{len(cap.elements)} interactable element(s):", + f"{total_elements} interactable element(s):", ] if element_index: summary_lines.extend(element_index) + # Multimodal and AX paths both reference `summary`; build it once up-front + # so the aux-vision routing branch (which fires before either path is + # selected) has a valid value to hand to _route_capture_through_aux_vision. + # The AX path appends the "truncated to N of M" note to summary_lines + # below and rebuilds; the multimodal path keeps this version untouched. summary = "\n".join(summary_lines) if cap.png_b64 and cap.mode != "ax": @@ -449,6 +499,9 @@ def _capture_response(cap: CaptureResult) -> Any: # JPEG: base64 starts with /9j/ PNG: starts with iVBOR _b64_prefix = cap.png_b64[:8] _mime = "image/jpeg" if _b64_prefix.startswith("/9j/") else "image/png" + # The multimodal response carries the screenshot, not the AX + # elements array, so a "response truncated to N of M elements" + # note would be inaccurate โ€” skip it on this branch. return { "_multimodal": True, "content": [ @@ -458,18 +511,29 @@ def _capture_response(cap: CaptureResult) -> Any: ], "text_summary": summary, "meta": {"mode": cap.mode, "width": cap.width, "height": cap.height, - "elements": len(cap.elements), "png_bytes": cap.png_bytes_len}, + "elements": total_elements, "png_bytes": cap.png_bytes_len}, } - # AX-only (or image missing): text path. - return json.dumps({ + # AX-only (or image-missing fallback): text path actually carries the + # `elements` array, so the truncation note applies here. + if truncated_elements: + summary_lines.append( + f" (response truncated to {len(visible_elements)} of {total_elements} elements; " + f"raise max_elements or pass app= to narrow)" + ) + summary = "\n".join(summary_lines) + payload: Dict[str, Any] = { "mode": cap.mode, "width": cap.width, "height": cap.height, "app": cap.app, "window_title": cap.window_title, - "elements": [_element_to_dict(e) for e in cap.elements], + "elements": [_element_to_dict(e) for e in visible_elements], + "total_elements": total_elements, "summary": summary, - }) + } + if truncated_elements: + payload["truncated_elements"] = truncated_elements + return json.dumps(payload) # --------------------------------------------------------------------------- @@ -611,6 +675,11 @@ def _maybe_follow_capture( ) -> Any: if not do_capture: return _text_response(res) + # Skip the follow-up capture when the action itself failed: showing a + # normal-looking screenshot after a failure misleads the model into thinking + # the action succeeded. Return the error text instead. + if not res.ok: + return _text_response(res) try: # Preserve the app context established by the preceding capture/focus_app so # that capture_after=True re-captures the same app rather than the frontmost diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 1f1afb48440..8924d76895f 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -169,6 +169,7 @@ class SSHEnvironment(BaseEnvironment): if not files: return + base = f"{self._remote_home}/.hermes" parents = unique_parent_dirs(files) if parents: cmd = self._build_ssh_command() @@ -180,7 +181,19 @@ class SSHEnvironment(BaseEnvironment): # Symlink staging avoids fragile GNU tar --transform rules. with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging: for host_path, remote_path in files: - staged = os.path.join(staging, remote_path.lstrip("/")) + try: + rel_remote = os.path.relpath(remote_path, base) + except ValueError as exc: + raise RuntimeError( + f"remote path {remote_path!r} is not under sync base {base!r}" + ) from exc + + if rel_remote == "." or rel_remote.startswith("../"): + raise RuntimeError( + f"remote path {remote_path!r} escapes sync base {base!r}" + ) + + staged = os.path.join(staging, rel_remote) os.makedirs(os.path.dirname(staged), exist_ok=True) os.symlink(os.path.abspath(host_path), staged) @@ -190,7 +203,7 @@ class SSHEnvironment(BaseEnvironment): # existing directories (e.g. /home/) with the staging # directory's mode. Without this, a umask 002 produces 0775 # dirs which breaks sshd StrictModes (refuses authorized_keys). - ssh_cmd.append("tar xf - --no-overwrite-dir -C /") + ssh_cmd.append(f"tar xf - --no-overwrite-dir -C {shlex.quote(base)}") tar_proc = subprocess.Popen( tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE diff --git a/tools/fal_common.py b/tools/fal_common.py new file mode 100644 index 00000000000..27636f90388 --- /dev/null +++ b/tools/fal_common.py @@ -0,0 +1,163 @@ +"""Shared FAL.ai SDK plumbing. + +Holds the stateless atoms that every FAL-backed tool needs: + +* :func:`import_fal_client` โ€” lazy import + ``lazy_deps`` integration so + ``fal_client`` isn't pulled at cold start (it added ~64 ms per CLI + invocation when imported eagerly). +* :class:`_ManagedFalSyncClient` โ€” wrapper that drives a Nous-managed + fal-queue gateway through the standard ``fal_client.SyncClient`` + primitives. +* :func:`_normalize_fal_queue_url_format`, :func:`_extract_http_status` + โ€” small helpers used by both the managed client wrapper and + ``_submit_fal_request``. + +Stateful pieces (cache globals, ``_managed_fal_client*`` selectors, +``_submit_fal_request``) intentionally stay on +:mod:`tools.image_generation_tool`. That module is the patch target for +existing test suites (``tests/tools/test_image_generation.py``, +``tests/tools/test_managed_media_gateways.py``) and for the +``plugins/image_gen/fal/`` plugin's ``_it`` indirection โ€” moving the +caches here would silently defeat ``monkeypatch.setattr(image_tool, +"_managed_fal_client", None)`` because the lookups would go against +``fal_common``'s namespace instead. See the per-rule walkthrough at +issue #26241 for details. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Union +from urllib.parse import urlencode + + +def import_fal_client() -> Any: + """Import ``fal_client`` (via ``lazy_deps`` when available) and return + the module reference. + + Callers are responsible for caching the result on their own module + global โ€” keeping per-module globals lets tests monkey-patch the + target module's ``fal_client`` attribute and have the patched value + stick for that module's call sites. + + Raises :class:`ImportError` if the package is genuinely unavailable. + """ + try: + from tools.lazy_deps import ensure as _lazy_ensure + _lazy_ensure("image.fal", prompt=False) + except ImportError: + pass + except Exception as exc: # noqa: BLE001 โ€” lazy_deps surfaces install hints + raise ImportError(str(exc)) + import fal_client # type: ignore # noqa: WPS433 โ€” intentionally lazy + return fal_client + + +def _normalize_fal_queue_url_format(queue_run_origin: str) -> str: + normalized_origin = str(queue_run_origin or "").strip().rstrip("/") + if not normalized_origin: + raise ValueError("Managed FAL queue origin is required") + return f"{normalized_origin}/" + + +def _extract_http_status(exc: BaseException) -> Optional[int]: + """Return an HTTP status code from httpx/fal exceptions, else None. + + Defensive across exception shapes โ€” httpx.HTTPStatusError exposes + ``.response.status_code`` while fal_client wrappers may expose + ``.status_code`` directly. + """ + response = getattr(exc, "response", None) + if response is not None: + status = getattr(response, "status_code", None) + if isinstance(status, int): + return status + status = getattr(exc, "status_code", None) + if isinstance(status, int): + return status + return None + + +class _ManagedFalSyncClient: + """Small per-instance wrapper around ``fal_client.SyncClient`` for + managed queue hosts. + + The wrapper carries its own ``fal_client`` module reference instead + of reaching into a module global, so callers stay in control of + which module's ``fal_client`` is in scope (matters for the test + patches that swap the legacy module's ``fal_client`` attribute). + """ + + def __init__(self, fal_client: Any, *, key: str, queue_run_origin: str): + sync_client_class = getattr(fal_client, "SyncClient", None) + if sync_client_class is None: + raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode") + + client_module = getattr(fal_client, "client", None) + if client_module is None: + raise RuntimeError("fal_client.client is required for managed FAL gateway mode") + + self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin) + self._sync_client = sync_client_class(key=key) + self._http_client = getattr(self._sync_client, "_client", None) + self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None) + self._raise_for_status = getattr(client_module, "_raise_for_status", None) + self._request_handle_class = getattr(client_module, "SyncRequestHandle", None) + self._add_hint_header = getattr(client_module, "add_hint_header", None) + self._add_priority_header = getattr(client_module, "add_priority_header", None) + self._add_timeout_header = getattr(client_module, "add_timeout_header", None) + + if self._http_client is None: + raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode") + if self._maybe_retry_request is None or self._raise_for_status is None: + raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode") + if self._request_handle_class is None: + raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode") + + def submit( + self, + application: str, + arguments: Dict[str, Any], + *, + path: str = "", + hint: Optional[str] = None, + webhook_url: Optional[str] = None, + priority: Any = None, + headers: Optional[Dict[str, str]] = None, + start_timeout: Optional[Union[int, float]] = None, + ): + url = self._queue_url_format + application + if path: + url += "/" + path.lstrip("/") + if webhook_url is not None: + url += "?" + urlencode({"fal_webhook": webhook_url}) + + request_headers = dict(headers or {}) + if hint is not None and self._add_hint_header is not None: + self._add_hint_header(hint, request_headers) + if priority is not None: + if self._add_priority_header is None: + raise RuntimeError("fal_client.client.add_priority_header is required for priority requests") + self._add_priority_header(priority, request_headers) + if start_timeout is not None: + if self._add_timeout_header is None: + raise RuntimeError("fal_client.client.add_timeout_header is required for timeout requests") + self._add_timeout_header(start_timeout, request_headers) + + response = self._maybe_retry_request( + self._http_client, + "POST", + url, + json=arguments, + timeout=getattr(self._sync_client, "default_timeout", 120.0), + headers=request_headers, + ) + self._raise_for_status(response) + + data = response.json() + return self._request_handle_class( + request_id=data["request_id"], + response_url=data["response_url"], + status_url=data["status_url"], + cancel_url=data["cancel_url"], + client=self._http_client, + ) diff --git a/tools/file_tools.py b/tools/file_tools.py index 2cedc4bcd5f..a5be71a8bfe 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -174,6 +174,37 @@ def _check_sensitive_path(filepath: str, task_id: str = "default") -> str | None return None +def _check_cross_profile_path(filepath: str, task_id: str = "default") -> str | None: + """Return a cross-profile warning string when ``filepath`` lands in + another Hermes profile's skills/plugins/cron/memories directory. + + Returns ``None`` when the write is in-scope (same profile) or outside + Hermes scope entirely. Soft guard โ€” the agent can override by passing + ``cross_profile=True`` to its write tool after explicit user direction. + + Defense-in-depth, NOT a security boundary โ€” the terminal tool runs + as the same OS user and can write any of these paths directly. + See ``agent/file_safety.classify_cross_profile_target`` for the + detection rules. + """ + try: + from agent.file_safety import get_cross_profile_warning + except Exception: + # Fail open on import error โ€” the existing sensitive-path guard + # plus the write_denied list still apply. + return None + + # Resolve via the task's cwd so a relative ``skills/foo/SKILL.md`` + # in a session that cd'd into ``~/.hermes/profiles/other/`` is + # classified against the right base. + try: + resolved = str(_resolve_path_for_task(filepath, task_id)) + except (OSError, ValueError): + resolved = filepath + + return get_cross_profile_warning(resolved) + + def _is_expected_write_exception(exc: Exception) -> bool: """Return True for expected write denials that should not hit error logs.""" if isinstance(exc, PermissionError): @@ -474,8 +505,13 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = }) # โ”€โ”€ Hermes internal path guard โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - # Prevent prompt injection via catalog or hub metadata files. - block_error = get_read_block_error(path) + # Prevent prompt injection via catalog or hub metadata files, + # and block credential stores under HERMES_HOME. Pass the + # already-resolved path so a relative-path read against + # TERMINAL_CWD == HERMES_HOME (e.g. "auth.json") still hits the + # denylist โ€” get_read_block_error's own resolve() runs against + # the Python process cwd, which can differ. + block_error = get_read_block_error(str(_resolved)) if block_error: return json.dumps({"error": block_error}) @@ -790,11 +826,23 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None: return None -def write_file_tool(path: str, content: str, task_id: str = "default") -> str: - """Write content to a file.""" +def write_file_tool(path: str, content: str, task_id: str = "default", + cross_profile: bool = False) -> str: + """Write content to a file. + + ``cross_profile`` opts out of the soft cross-Hermes-profile guard. The + guard fires only on writes that land in another profile's + skills/plugins/cron/memories directory; everything else is unaffected. + Pass ``True`` after explicit user direction โ€” same shape as ``force`` + on the terminal tool. + """ sensitive_err = _check_sensitive_path(path, task_id) if sensitive_err: return tool_error(sensitive_err) + if not cross_profile: + cross_warning = _check_cross_profile_path(path, task_id) + if cross_warning: + return tool_error(cross_warning) if _is_internal_file_status_text(content): return tool_error( "Refusing to write internal read_file status text as file content. " @@ -849,8 +897,13 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str: def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False, patch: str = None, - task_id: str = "default") -> str: - """Patch a file using replace mode or V4A patch format.""" + task_id: str = "default", cross_profile: bool = False) -> str: + """Patch a file using replace mode or V4A patch format. + + ``cross_profile`` opts out of the soft cross-Hermes-profile guard for + targets under another profile's skills/plugins/cron/memories + directory. Same shape as ``write_file``'s flag. + """ # Check sensitive paths for both replace (explicit path) and V4A patch (extract paths) _paths_to_check = [] if path: @@ -863,6 +916,10 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, sensitive_err = _check_sensitive_path(_p, task_id) if sensitive_err: return tool_error(sensitive_err) + if not cross_profile: + cross_warning = _check_cross_profile_path(_p, task_id) + if cross_warning: + return tool_error(cross_warning) try: # Resolve paths for locking. Ordered + deduplicated so concurrent # callers lock in the same order โ€” prevents deadlock on overlapping @@ -1047,7 +1104,12 @@ WRITE_FILE_SCHEMA = { "type": "object", "properties": { "path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"}, - "content": {"type": "string", "description": "Complete content to write to the file"} + "content": {"type": "string", "description": "Complete content to write to the file"}, + "cross_profile": { + "type": "boolean", + "description": "Opt out of the cross-profile soft guard. Defaults to false. Set true ONLY after explicit user direction to edit another Hermes profile's skills/plugins/cron/memories โ€” by default these writes are blocked with a warning because they affect a different profile than the one this session is running under.", + "default": False, + }, }, "required": ["path", "content"] } @@ -1094,6 +1156,11 @@ PATCH_SCHEMA = { "type": "string", "description": "REQUIRED when mode='patch'. V4A format patch content. Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch", }, + "cross_profile": { + "type": "boolean", + "description": "Opt out of the cross-profile soft guard. Defaults to false. Set true ONLY after explicit user direction to edit another Hermes profile's skills/plugins/cron/memories.", + "default": False, + }, }, "required": ["mode"], }, @@ -1144,7 +1211,10 @@ def _handle_write_file(args, **kw): f"write_file: 'content' must be a string, got " f"{type(args['content']).__name__}." ) - return write_file_tool(path=args["path"], content=args["content"], task_id=tid) + return write_file_tool( + path=args["path"], content=args["content"], task_id=tid, + cross_profile=bool(args.get("cross_profile", False)), + ) def _handle_patch(args, **kw): @@ -1152,7 +1222,9 @@ def _handle_patch(args, **kw): return patch_tool( mode=args.get("mode", "replace"), path=args.get("path"), old_string=args.get("old_string"), new_string=args.get("new_string"), - replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid) + replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid, + cross_profile=bool(args.get("cross_profile", False)), + ) def _handle_search_files(args, **kw): diff --git a/tools/image_generation_tool.py b/tools/image_generation_tool.py index 3d171f093c9..584f5e9fa1c 100644 --- a/tools/image_generation_tool.py +++ b/tools/image_generation_tool.py @@ -26,8 +26,7 @@ import os import datetime import threading import uuid -from typing import Any, Dict, Optional, Union -from urllib.parse import urlencode +from typing import Any, Dict, Optional # fal_client is imported lazily โ€” see _load_fal_client(). Pulling it # eagerly added ~64 ms to every CLI cold start because @@ -52,19 +51,17 @@ def _load_fal_client() -> Any: global fal_client if fal_client is not None: return fal_client - try: - from tools.lazy_deps import ensure as _lazy_ensure - _lazy_ensure("image.fal", prompt=False) - except ImportError: - pass - except Exception as e: - raise ImportError(str(e)) - import fal_client as _fal_client # noqa: F811 โ€” module-global rebind - fal_client = _fal_client + from tools.fal_common import import_fal_client + fal_client = import_fal_client() return fal_client from tools.debug_helpers import DebugSession +from tools.fal_common import ( + _ManagedFalSyncClient, + _extract_http_status, + _normalize_fal_queue_url_format, # noqa: F401 โ€” re-exported for tests +) from tools.managed_tool_gateway import resolve_managed_tool_gateway from tools.tool_backend_helpers import ( fal_key_is_configured, @@ -360,95 +357,6 @@ def _resolve_managed_fal_gateway(): return resolve_managed_tool_gateway("fal-queue") -def _normalize_fal_queue_url_format(queue_run_origin: str) -> str: - normalized_origin = str(queue_run_origin or "").strip().rstrip("/") - if not normalized_origin: - raise ValueError("Managed FAL queue origin is required") - return f"{normalized_origin}/" - - -class _ManagedFalSyncClient: - """Small per-instance wrapper around fal_client.SyncClient for managed queue hosts.""" - - def __init__(self, *, key: str, queue_run_origin: str): - # Trigger the lazy import on first construction. Idempotent โ€” the - # placeholder is overwritten with the real module on first call. - _load_fal_client() - sync_client_class = getattr(fal_client, "SyncClient", None) - if sync_client_class is None: - raise RuntimeError("fal_client.SyncClient is required for managed FAL gateway mode") - - client_module = getattr(fal_client, "client", None) - if client_module is None: - raise RuntimeError("fal_client.client is required for managed FAL gateway mode") - - self._queue_url_format = _normalize_fal_queue_url_format(queue_run_origin) - self._sync_client = sync_client_class(key=key) - self._http_client = getattr(self._sync_client, "_client", None) - self._maybe_retry_request = getattr(client_module, "_maybe_retry_request", None) - self._raise_for_status = getattr(client_module, "_raise_for_status", None) - self._request_handle_class = getattr(client_module, "SyncRequestHandle", None) - self._add_hint_header = getattr(client_module, "add_hint_header", None) - self._add_priority_header = getattr(client_module, "add_priority_header", None) - self._add_timeout_header = getattr(client_module, "add_timeout_header", None) - - if self._http_client is None: - raise RuntimeError("fal_client.SyncClient._client is required for managed FAL gateway mode") - if self._maybe_retry_request is None or self._raise_for_status is None: - raise RuntimeError("fal_client.client request helpers are required for managed FAL gateway mode") - if self._request_handle_class is None: - raise RuntimeError("fal_client.client.SyncRequestHandle is required for managed FAL gateway mode") - - def submit( - self, - application: str, - arguments: Dict[str, Any], - *, - path: str = "", - hint: Optional[str] = None, - webhook_url: Optional[str] = None, - priority: Any = None, - headers: Optional[Dict[str, str]] = None, - start_timeout: Optional[Union[int, float]] = None, - ): - url = self._queue_url_format + application - if path: - url += "/" + path.lstrip("/") - if webhook_url is not None: - url += "?" + urlencode({"fal_webhook": webhook_url}) - - request_headers = dict(headers or {}) - if hint is not None and self._add_hint_header is not None: - self._add_hint_header(hint, request_headers) - if priority is not None: - if self._add_priority_header is None: - raise RuntimeError("fal_client.client.add_priority_header is required for priority requests") - self._add_priority_header(priority, request_headers) - if start_timeout is not None: - if self._add_timeout_header is None: - raise RuntimeError("fal_client.client.add_timeout_header is required for timeout requests") - self._add_timeout_header(start_timeout, request_headers) - - response = self._maybe_retry_request( - self._http_client, - "POST", - url, - json=arguments, - timeout=getattr(self._sync_client, "default_timeout", 120.0), - headers=request_headers, - ) - self._raise_for_status(response) - - data = response.json() - return self._request_handle_class( - request_id=data["request_id"], - response_url=data["response_url"], - status_url=data["status_url"], - cancel_url=data["cancel_url"], - client=self._http_client, - ) - - def _get_managed_fal_client(managed_gateway): """Reuse the managed FAL client so its internal httpx.Client is not leaked per call.""" global _managed_fal_client, _managed_fal_client_config @@ -461,7 +369,11 @@ def _get_managed_fal_client(managed_gateway): if _managed_fal_client is not None and _managed_fal_client_config == client_config: return _managed_fal_client + # Resolve fal_client on the legacy module โ€” preserves the test + # pattern of monkey-patching ``image_generation_tool.fal_client``. + _load_fal_client() _managed_fal_client = _ManagedFalSyncClient( + fal_client, key=managed_gateway.nous_user_token, queue_run_origin=managed_gateway.gateway_origin, ) @@ -502,24 +414,6 @@ def _submit_fal_request(model: str, arguments: Dict[str, Any]): raise -def _extract_http_status(exc: BaseException) -> Optional[int]: - """Return an HTTP status code from httpx/fal exceptions, else None. - - Defensive across exception shapes โ€” httpx.HTTPStatusError exposes - ``.response.status_code`` while fal_client wrappers may expose - ``.status_code`` directly. - """ - response = getattr(exc, "response", None) - if response is not None: - status = getattr(response, "status_code", None) - if isinstance(status, int): - return status - status = getattr(exc, "status_code", None) - if isinstance(status, int): - return status - return None - - # --------------------------------------------------------------------------- # Model resolution + payload construction # --------------------------------------------------------------------------- @@ -973,9 +867,12 @@ def _read_configured_image_provider(): """Return the value of ``image_gen.provider`` from config.yaml, or None. We only consult the plugin registry when this is explicitly set โ€” an - unset value keeps users on the legacy in-tree FAL path even when other + unset value keeps users on the in-tree FAL fallback even when other providers happen to be registered (e.g. a user has OPENAI_API_KEY set - for other features but never asked for OpenAI image gen). + for other features but never asked for OpenAI image gen). ``"fal"`` + explicitly routes through ``plugins/image_gen/fal/`` (which delegates + back into this module's pipeline via call-time indirection โ€” see + issue #26241). """ try: from hermes_cli.config import load_config @@ -994,15 +891,16 @@ def _dispatch_to_plugin_provider(prompt: str, aspect_ratio: str): """Route the call to a plugin-registered provider when one is selected. Returns a JSON string on dispatch, or ``None`` to fall through to the - built-in FAL path. + in-tree FAL fallback in ``image_generate_tool``. - Dispatch only fires when ``image_gen.provider`` is explicitly set AND - it does not point to ``fal`` (FAL still lives in-tree in this PR; - a later PR ports it into ``plugins/image_gen/fal/``). Any other value - that matches a registered plugin provider wins. + Dispatch fires when ``image_gen.provider`` is explicitly set โ€” including + ``"fal"`` itself, which now resolves to the + ``plugins/image_gen/fal/`` plugin (the plugin re-enters this module's + pipeline via ``_it`` indirection so behavior is identical to the + direct call, just routed through the registry). """ configured = _read_configured_image_provider() - if not configured or configured == "fal": + if not configured: return None # Also read configured model so we can pass it to the plugin diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index e50efc05a0c..75c1c5e8633 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1255,6 +1255,15 @@ class MCPServerTask: async def _run_stdio(self, config: dict): """Run the server using stdio transport.""" + if not _MCP_AVAILABLE: + raise ImportError( + f"MCP server '{self.name}' requires the 'mcp' Python SDK, but " + "it is not installed. Install with:\n" + " pip install 'hermes-agent[mcp]'\n" + "or (full install):\n" + " pip install 'hermes-agent[all]'" + ) + command = config.get("command") args = config.get("args", []) user_env = config.get("env") diff --git a/tools/memory_tool.py b/tools/memory_tool.py index 78d3a154933..97ea5ae7cf5 100644 --- a/tools/memory_tool.py +++ b/tools/memory_tool.py @@ -28,6 +28,7 @@ import logging import os import re import tempfile +import time from contextlib import contextmanager from pathlib import Path from hermes_constants import get_hermes_home @@ -104,6 +105,36 @@ def _scan_memory_content(content: str) -> Optional[str]: return None +def _drift_error(path: "Path", bak_path: str) -> Dict[str, Any]: + """Build the error dict returned when external drift is detected. + + The on-disk memory file contains content that wouldn't round-trip + through the tool's parser/serializer โ€” flushing would discard the + appended/edited content from a patch tool, shell append, manual edit, + or sister-session write. We refuse the mutation, point the operator at + the .bak. snapshot we took, and tell them what to do next. + """ + return { + "success": False, + "error": ( + f"Refusing to write {path.name}: file on disk has content that " + f"wouldn't round-trip through the memory tool (likely added by " + f"the patch tool, a shell append, a manual edit, or a " + f"concurrent session). A snapshot was saved to {bak_path}. " + f"Resolve the drift first โ€” either rewrite the file as a clean " + f"ยง-delimited list of entries, or move the extra content out โ€” " + f"then retry. This guard exists to prevent silent data loss " + f"(issue #26045)." + ), + "drift_backup": bak_path, + "remediation": ( + "Open the .bak file, integrate the missing entries into the " + "memory tool one at a time via memory(action=add, content=...), " + "then remove or rewrite the original file to a clean state." + ), + } + + class MemoryStore: """ Bounded curated memory with file persistence. One instance per AIAgent. @@ -185,14 +216,23 @@ class MemoryStore: return mem_dir / "USER.md" return mem_dir / "MEMORY.md" - def _reload_target(self, target: str): + def _reload_target(self, target: str) -> Optional[str]: """Re-read entries from disk into in-memory state. Called under file lock to get the latest state before mutating. + Returns the backup path if external drift was detected (the on-disk + file contains content that wouldn't round-trip through our + parser/serializer, OR an entry larger than the store's char limit). + When drift is detected the caller must abort the mutation โ€” + flushing would discard the un-roundtrippable content. + Returns None on clean reload. """ - fresh = self._read_file(self._path_for(target)) + path = self._path_for(target) + bak = self._detect_external_drift(target) + fresh = self._read_file(path) fresh = list(dict.fromkeys(fresh)) # deduplicate self._set_entries(target, fresh) + return bak def save_to_disk(self, target: str): """Persist entries to the appropriate file. Called after every mutation.""" @@ -233,8 +273,13 @@ class MemoryStore: return {"success": False, "error": scan_error} with self._file_lock(self._path_for(target)): - # Re-read from disk under lock to pick up writes from other sessions - self._reload_target(target) + # Re-read from disk under lock to pick up writes from other sessions. + # If external drift was detected, the file was backed up to .bak. + # โ€” refuse the mutation so we don't clobber the un-roundtrippable + # content the patch tool / shell append / sister session wrote. + bak = self._reload_target(target) + if bak: + return _drift_error(self._path_for(target), bak) entries = self._entries_for(target) limit = self._char_limit(target) @@ -281,7 +326,9 @@ class MemoryStore: return {"success": False, "error": scan_error} with self._file_lock(self._path_for(target)): - self._reload_target(target) + bak = self._reload_target(target) + if bak: + return _drift_error(self._path_for(target), bak) entries = self._entries_for(target) matches = [(i, e) for i, e in enumerate(entries) if old_text in e] @@ -331,7 +378,9 @@ class MemoryStore: return {"success": False, "error": "old_text cannot be empty."} with self._file_lock(self._path_for(target)): - self._reload_target(target) + bak = self._reload_target(target) + if bak: + return _drift_error(self._path_for(target), bak) entries = self._entries_for(target) matches = [(i, e) for i, e in enumerate(entries) if old_text in e] @@ -430,6 +479,61 @@ class MemoryStore: entries = [e.strip() for e in raw.split(ENTRY_DELIMITER)] return [e for e in entries if e] + def _detect_external_drift(self, target: str) -> Optional[str]: + """Return a backup-path string if on-disk content shows external drift. + + The memory file is supposed to be a list of small entries the tool + wrote, joined by ยง. Detect drift via two signals: + + 1. Round-trip mismatch โ€” re-parsing and re-serializing the file + doesn't produce identical bytes (rare; would catch oddly-encoded + delimiters). + 2. Entry-size overflow โ€” any single parsed entry exceeds the + store's whole-file char limit. The tool budgets the ENTIRE store + against that limit; no single tool-written entry can exceed it. + When we see one entry larger than the limit, an external writer + (patch tool, shell append, manual edit, sister session) appended + free-form content into what the tool will treat as one entry. + Flushing would then truncate that entry to the model's new + content, discarding the appended bytes โ€” issue #26045. + + Returns the absolute path of the .bak file when drift was found and + backed up; returns None when the file looks tool-shaped. + + Note: this is an INSTANCE method (not static) because we need the + per-target char_limit for signal #2. + """ + path = self._path_for(target) + if not path.exists(): + return None + try: + raw = path.read_text(encoding="utf-8") + except (OSError, IOError): + return None + if not raw.strip(): + return None + + parsed = [e.strip() for e in raw.split(ENTRY_DELIMITER) if e.strip()] + roundtrip = ENTRY_DELIMITER.join(parsed) + + char_limit = self._char_limit(target) + max_entry_len = max((len(e) for e in parsed), default=0) + + drift_detected = (raw.strip() != roundtrip) or (max_entry_len > char_limit) + if not drift_detected: + return None + + # Drift confirmed โ€” snapshot the file so the operator can recover + # whatever the external writer added, then return the .bak path so + # the caller can refuse the mutation. + ts = int(time.time()) + bak_path = path.with_suffix(path.suffix + f".bak.{ts}") + try: + bak_path.write_text(raw, encoding="utf-8") + except (OSError, IOError): + return str(bak_path) + " (BACKUP FAILED โ€” file unchanged on disk)" + return str(bak_path) + @staticmethod def _write_file(path: Path, entries: List[str]): """Write entries to a memory file using atomic temp-file + rename. diff --git a/tools/process_registry.py b/tools/process_registry.py index 771ebf0b474..38c35b3c5a0 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -434,9 +434,50 @@ class ProcessRegistry: @staticmethod def _terminate_host_pid(pid: int) -> None: - """Terminate a host-visible PID without requiring the original process handle.""" + """Terminate a host-visible PID and its descendants. + + POSIX: walks the process tree with ``psutil`` and SIGTERMs + children before the parent so subprocess trees (e.g. Chromium + renderers/GPU helpers spawned by an ``agent-browser`` daemon) + don't get reparented to init and survive cleanup. + + Windows: shells out to ``taskkill /PID /T /F``. This is + the documented Microsoft primitive for tree-kill and matches the + existing convention in ``gateway.status.terminate_pid``. We can't + reuse the POSIX psutil path on Windows because: + + 1. Windows doesn't maintain a Unix-style process tree โ€” + ``psutil.Process.children(recursive=True)`` walks PPID + links that go stale when intermediate processes exit, so + enumeration is best-effort and misses orphaned descendants. + 2. ``psutil.Process.terminate()`` on Windows is + ``TerminateProcess()`` which kills only the target handle + and is a hard kill โ€” there is no Windows equivalent of a + SIGTERM that cascades through a process group. (See the + warning in ``gateway/status.py::terminate_pid``: "os.kill + with SIGTERM is not equivalent to a tree-killing hard stop" + on Windows.) Headless Chromium has no GUI window, so the + softer ``taskkill /T`` without ``/F`` won't reach it either. + + ``psutil`` is a hard dependency (see ``pyproject.toml``); the + bare-``os.kill`` fallback covers OSError / PermissionError on + POSIX and a missing ``taskkill.exe`` on Windows (effectively + unreachable on real Windows installs, but cheap insurance). + """ if _IS_WINDOWS: - os.kill(pid, signal.SIGTERM) + try: + subprocess.run( + ["taskkill", "/PID", str(pid), "/T", "/F"], + capture_output=True, + text=True, + timeout=10, + creationflags=windows_hide_flags(), + ) + except (FileNotFoundError, subprocess.TimeoutExpired, OSError): + try: + os.kill(pid, signal.SIGTERM) + except (OSError, ProcessLookupError, PermissionError): + pass return import psutil diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 284eaab56a1..0f83e40c3c9 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -139,7 +139,7 @@ SEND_MESSAGE_SCHEMA = { }, "message": { "type": "string", - "description": "The message text to send. To send an image or file, include MEDIA: (e.g. 'MEDIA:/tmp/hermes/cache/img_xxx.jpg') in the message โ€” the platform will deliver it as a native media attachment." + "description": "The message text to send. To send an image or file, include MEDIA: for a file under a Hermes media cache or HERMES_MEDIA_ALLOW_DIRS โ€” the platform will deliver it as a native media attachment." } }, "required": [] @@ -251,6 +251,7 @@ def _handle_send(args): force_document_attachments = "[[as_document]]" in message media_files, cleaned_message = BasePlatformAdapter.extract_media(message) + media_files = BasePlatformAdapter.filter_media_delivery_paths(media_files) mirror_text = cleaned_message.strip() or _describe_media_for_mirror(media_files) used_home_channel = False @@ -563,7 +564,6 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, """ from gateway.config import Platform from gateway.platforms.base import BasePlatformAdapter, utf16_len - from gateway.platforms.discord import DiscordAdapter from gateway.platforms.slack import SlackAdapter # Telegram adapter import is optional (requires python-telegram-bot) @@ -589,10 +589,10 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, except Exception: logger.debug("Failed to apply Slack mrkdwn formatting in _send_to_platform", exc_info=True) - # Platform message length limits (from adapter class attributes) + # Platform message length limits (from adapter class attributes for + # built-in platforms; from PlatformEntry.max_message_length for plugins). _MAX_LENGTHS = { Platform.TELEGRAM: TelegramAdapter.MAX_MESSAGE_LENGTH if _telegram_available else 4096, - Platform.DISCORD: DiscordAdapter.MAX_MESSAGE_LENGTH, Platform.SLACK: SlackAdapter.MAX_MESSAGE_LENGTH, } if _feishu_available: @@ -642,17 +642,27 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, if platform == Platform.WEIXIN: return await _send_weixin(pconfig, chat_id, message, media_files=media_files) - # --- Discord: special handling for media attachments --- + # --- Discord: chunked delivery via the registry's standalone_sender_fn. + # The plugin's ``_standalone_send`` (registered in + # plugins/platforms/discord/adapter.py) handles forum channels, threads, + # and multipart media uploads. ``_send_via_adapter`` tries the live + # in-process adapter first via ``adapter.send()``, but Discord's elif + # historically went straight to the HTTP path; we preserve that by + # explicitly invoking the registry hook here so behavior is unchanged. if platform == Platform.DISCORD: + from gateway.platform_registry import platform_registry + entry = platform_registry.get("discord") + if entry is None or entry.standalone_sender_fn is None: + return {"error": "Discord plugin not registered or missing standalone_sender_fn"} last_result = None for i, chunk in enumerate(chunks): is_last = (i == len(chunks) - 1) - result = await _send_discord( - pconfig.token, + result = await entry.standalone_sender_fn( + pconfig, chat_id, chunk, - media_files=media_files if is_last else [], thread_id=thread_id, + media_files=media_files if is_last else [], ) if isinstance(result, dict) and result.get("error"): return result @@ -1026,227 +1036,6 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No return _error(f"Telegram send failed: {e}") -def _derive_forum_thread_name(message: str) -> str: - """Derive a thread name from the first line of the message, capped at 100 chars.""" - first_line = message.strip().split("\n", 1)[0].strip() - # Strip common markdown heading prefixes - first_line = first_line.lstrip("#").strip() - if not first_line: - first_line = "New Post" - return first_line[:100] - - -# Process-local cache for Discord channel-type probes. Avoids re-probing the -# same channel on every send when the directory cache has no entry (e.g. fresh -# install, or channel created after the last directory build). -_DISCORD_CHANNEL_TYPE_PROBE_CACHE: Dict[str, bool] = {} - - -def _remember_channel_is_forum(chat_id: str, is_forum: bool) -> None: - _DISCORD_CHANNEL_TYPE_PROBE_CACHE[str(chat_id)] = bool(is_forum) - - -def _probe_is_forum_cached(chat_id: str) -> Optional[bool]: - return _DISCORD_CHANNEL_TYPE_PROBE_CACHE.get(str(chat_id)) - - -async def _send_discord(token, chat_id, message, thread_id=None, media_files=None): - """Send a single message via Discord REST API (no websocket client needed). - - Chunking is handled by _send_to_platform() before this is called. - - When thread_id is provided, the message is sent directly to that thread - via the /channels/{thread_id}/messages endpoint. - - Media files are uploaded one-by-one via multipart/form-data after the - text message is sent (same pattern as Telegram). - - Forum channels (type 15) reject POST /messages โ€” a thread post is created - automatically via POST /channels/{id}/threads. Media files are uploaded - as multipart attachments on the starter message of the new thread. - - Channel type is resolved from the channel directory first, then a - process-local probe cache, and only as a last resort with a live - GET /channels/{id} probe (whose result is memoized). - """ - try: - import aiohttp - except ImportError: - return {"error": "aiohttp not installed. Run: pip install aiohttp"} - try: - from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp - _proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY") - _sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy) - auth_headers = {"Authorization": f"Bot {token}"} - json_headers = {**auth_headers, "Content-Type": "application/json"} - media_files = media_files or [] - last_data = None - warnings = [] - - # Thread endpoint: Discord threads are channels; send directly to the thread ID. - if thread_id: - url = f"https://discord.com/api/v10/channels/{thread_id}/messages" - else: - # Check if the target channel is a forum channel (type 15). - # Forum channels reject POST /messages โ€” create a thread post instead. - # Three-layer detection: directory cache โ†’ process-local probe - # cache โ†’ GET /channels/{id} probe (with result memoized). - _channel_type = None - try: - from gateway.channel_directory import lookup_channel_type - _channel_type = lookup_channel_type("discord", chat_id) - except Exception: - pass - - if _channel_type == "forum": - is_forum = True - elif _channel_type is not None: - is_forum = False - else: - cached = _probe_is_forum_cached(chat_id) - if cached is not None: - is_forum = cached - else: - is_forum = False - try: - info_url = f"https://discord.com/api/v10/channels/{chat_id}" - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15), **_sess_kw) as info_sess: - async with info_sess.get(info_url, headers=json_headers, **_req_kw) as info_resp: - if info_resp.status == 200: - info = await info_resp.json() - is_forum = info.get("type") == 15 - _remember_channel_is_forum(chat_id, is_forum) - except Exception: - logger.debug("Failed to probe channel type for %s", chat_id, exc_info=True) - - if is_forum: - thread_name = _derive_forum_thread_name(message) - thread_url = f"https://discord.com/api/v10/channels/{chat_id}/threads" - - # Filter to readable media files up front so we can pick the - # right code path (JSON vs multipart) before opening a session. - valid_media = [] - for media_path, _is_voice in media_files: - if not os.path.exists(media_path): - warning = f"Media file not found, skipping: {media_path}" - logger.warning(warning) - warnings.append(warning) - continue - valid_media.append(media_path) - - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=60), **_sess_kw) as session: - if valid_media: - # Multipart: payload_json + files[N] creates a forum - # thread with the starter message plus attachments in - # a single API call. - attachments_meta = [ - {"id": str(idx), "filename": os.path.basename(path)} - for idx, path in enumerate(valid_media) - ] - starter_message = {"content": message, "attachments": attachments_meta} - payload_json = json.dumps({"name": thread_name, "message": starter_message}) - - form = aiohttp.FormData() - form.add_field("payload_json", payload_json, content_type="application/json") - - # Buffer file bytes up front โ€” aiohttp's FormData can - # read lazily and we don't want handles closing under - # it on retry. - try: - for idx, media_path in enumerate(valid_media): - with open(media_path, "rb") as fh: - form.add_field( - f"files[{idx}]", - fh.read(), - filename=os.path.basename(media_path), - ) - async with session.post(thread_url, headers=auth_headers, data=form, **_req_kw) as resp: - if resp.status not in {200, 201}: - body = await resp.text() - return _error(f"Discord forum thread creation error ({resp.status}): {body}") - data = await resp.json() - except Exception as e: - return _error(_sanitize_error_text(f"Discord forum thread upload failed: {e}")) - else: - # No media โ€” simple JSON POST creates the thread with - # just the text starter. - async with session.post( - thread_url, - headers=json_headers, - json={ - "name": thread_name, - "message": {"content": message}, - }, - **_req_kw, - ) as resp: - if resp.status not in {200, 201}: - body = await resp.text() - return _error(f"Discord forum thread creation error ({resp.status}): {body}") - data = await resp.json() - - thread_id_created = data.get("id") - starter_msg_id = (data.get("message") or {}).get("id", thread_id_created) - result = { - "success": True, - "platform": "discord", - "chat_id": chat_id, - "thread_id": thread_id_created, - "message_id": starter_msg_id, - } - if warnings: - result["warnings"] = warnings - return result - - url = f"https://discord.com/api/v10/channels/{chat_id}/messages" - - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session: - # Send text message (skip if empty and media is present) - if message.strip() or not media_files: - async with session.post(url, headers=json_headers, json={"content": message}, **_req_kw) as resp: - if resp.status not in {200, 201}: - body = await resp.text() - return _error(f"Discord API error ({resp.status}): {body}") - last_data = await resp.json() - - # Send each media file as a separate multipart upload - for media_path, _is_voice in media_files: - if not os.path.exists(media_path): - warning = f"Media file not found, skipping: {media_path}" - logger.warning(warning) - warnings.append(warning) - continue - try: - form = aiohttp.FormData() - filename = os.path.basename(media_path) - with open(media_path, "rb") as f: - form.add_field("files[0]", f, filename=filename) - async with session.post(url, headers=auth_headers, data=form, **_req_kw) as resp: - if resp.status not in {200, 201}: - body = await resp.text() - warning = _sanitize_error_text(f"Failed to send media {media_path}: Discord API error ({resp.status}): {body}") - logger.error(warning) - warnings.append(warning) - continue - last_data = await resp.json() - except Exception as e: - warning = _sanitize_error_text(f"Failed to send media {media_path}: {e}") - logger.error(warning) - warnings.append(warning) - - if last_data is None: - error = "No deliverable text or media remained after processing" - if warnings: - return {"error": error, "warnings": warnings} - return {"error": error} - - result = {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": last_data.get("id")} - if warnings: - result["warnings"] = warnings - return result - except Exception as e: - return _error(f"Discord send failed: {e}") - - async def _send_slack(token, chat_id, message): """Send via Slack Web API.""" try: diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 547167a6623..4ce5f06e4c9 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -40,7 +40,7 @@ import shutil import tempfile from pathlib import Path from hermes_constants import get_hermes_home, display_hermes_home -from typing import Dict, Any, Optional, Tuple +from typing import Dict, Any, List, Optional, Tuple from utils import atomic_replace, is_truthy_value from hermes_cli.config import cfg_get @@ -295,6 +295,109 @@ def _find_skill(name: str) -> Optional[Dict[str, Any]]: return None +def _find_skill_in_other_profiles(name: str) -> List[Tuple[str, Path]]: + """Look for ``name`` under SKILL.md across OTHER Hermes profiles. + + Returns a list of ``(profile_name, skill_dir)`` pairs. Used to make + the "Skill X not found" error explain when the user is editing the + wrong profile. Empty list when no other profile has the skill (or + when profile discovery fails โ€” fail-quiet, the caller falls back to + the plain "not found" error). + """ + matches: List[Tuple[str, Path]] = [] + try: + from hermes_constants import get_default_hermes_root + from agent.skill_utils import is_excluded_skill_path + except Exception: + return matches + + try: + root = get_default_hermes_root() + except Exception: + return matches + + # Collect (profile_name, skills_dir) for every profile EXCEPT the + # one whose SKILLS_DIR we already searched in _find_skill(). + active_dir = SKILLS_DIR.resolve() if SKILLS_DIR.exists() else SKILLS_DIR + candidates: List[Tuple[str, Path]] = [] + + # Default profile (~/.hermes/skills) โ€” only consider when active is non-default. + default_skills = root / "skills" + try: + if default_skills.resolve() != active_dir: + candidates.append(("default", default_skills)) + except (OSError, RuntimeError): + pass + + # All named profiles (~/.hermes/profiles/*/skills) + profiles_root = root / "profiles" + if profiles_root.is_dir(): + try: + for entry in profiles_root.iterdir(): + if not entry.is_dir(): + continue + pskills = entry / "skills" + try: + if pskills.resolve() == active_dir: + continue + except (OSError, RuntimeError): + continue + candidates.append((entry.name, pskills)) + except OSError: + pass + + for profile_name, skills_dir in candidates: + if not skills_dir.is_dir(): + continue + try: + for skill_md in skills_dir.rglob("SKILL.md"): + if is_excluded_skill_path(skill_md): + continue + if skill_md.parent.name == name: + matches.append((profile_name, skill_md.parent)) + break # one match per profile is enough + except OSError: + continue + return matches + + +def _skill_not_found_error(name: str, suffix: str = "") -> str: + """Build a "skill not found" error that names other profiles holding + the same skill, so the agent can recognize a profile-scoping mistake. + + ``suffix`` is appended after the cross-profile hint if present + (e.g. ``" Create it first with action='create'."``). + """ + from agent.file_safety import _resolve_active_profile_name + active = _resolve_active_profile_name() + base = f"Skill '{name}' not found in active profile '{active}'." + + others = _find_skill_in_other_profiles(name) + if others: + if len(others) == 1: + other_profile, other_path = others[0] + base += ( + f" A skill by that name exists in profile " + f"'{other_profile}' ({other_path}). To edit a skill in " + f"another profile, switch profiles (`hermes -p " + f"{other_profile}`) or operate via explicit file tools " + f"with ``cross_profile=True``." + ) + else: + names = ", ".join(f"'{p}'" for p, _ in others) + base += ( + f" Skills by that name exist in other profiles: {names}. " + f"Switch profiles (`hermes -p `) to edit there, or " + f"operate via explicit file tools with ``cross_profile=True``." + ) + else: + base += " Use skills_list() to see available skills." + + if suffix: + base += suffix + return base + + def _validate_file_path(file_path: str) -> Optional[str]: """ Validate a file path for write_file/remove_file. @@ -439,7 +542,7 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]: existing = _find_skill(name) if not existing: - return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."} + return {"success": False, "error": _skill_not_found_error(name)} skill_md = existing["path"] / "SKILL.md" # Back up original content for rollback @@ -479,7 +582,7 @@ def _patch_skill( existing = _find_skill(name) if not existing: - return {"success": False, "error": f"Skill '{name}' not found."} + return {"success": False, "error": _skill_not_found_error(name)} skill_dir = existing["path"] @@ -568,7 +671,7 @@ def _delete_skill(name: str, absorbed_into: Optional[str] = None) -> Dict[str, A """ existing = _find_skill(name) if not existing: - return {"success": False, "error": f"Skill '{name}' not found."} + return {"success": False, "error": _skill_not_found_error(name)} pinned_err = _pinned_guard(name) if pinned_err: @@ -637,7 +740,7 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]: existing = _find_skill(name) if not existing: - return {"success": False, "error": f"Skill '{name}' not found. Create it first with action='create'."} + return {"success": False, "error": _skill_not_found_error(name, " Create it first with action='create'.")} target, err = _resolve_skill_target(existing["path"], file_path) if err: @@ -671,7 +774,7 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]: existing = _find_skill(name) if not existing: - return {"success": False, "error": f"Skill '{name}' not found."} + return {"success": False, "error": _skill_not_found_error(name)} skill_dir = existing["path"] diff --git a/tools/skills_ast_audit.py b/tools/skills_ast_audit.py new file mode 100644 index 00000000000..e127556c1d9 --- /dev/null +++ b/tools/skills_ast_audit.py @@ -0,0 +1,133 @@ +""" +AST-level deep audit for skill Python files โ€” opt-in diagnostic, not a security gate. + +Per SECURITY.md ยง2.4, Skills Guard is in-process heuristics ("useful โ€” not +boundaries"). This module is a separate opt-in diagnostic that flags dynamic +import / dynamic attribute access patterns operators may want to eyeball when +reviewing third-party skill code. Every pattern flagged here has legitimate +uses; findings are hints for human review, not verdicts. + +CLI: ``hermes skills audit --deep`` +""" + +from __future__ import annotations + +import ast +from pathlib import Path +from typing import List, Tuple + +# (file, line, pattern_id, description) +Finding = Tuple[str, int, str, str] + +_IGNORED_DIRS = {"__pycache__", ".venv", "venv", "node_modules"} + + +def _scan_source(content: str, rel_path: str) -> List[Finding]: + try: + tree = ast.parse(content) + except (SyntaxError, ValueError, RecursionError): + return [] + + findings: List[Finding] = [] + + class V(ast.NodeVisitor): + def visit_Call(self, node): + f = node.func + # importlib.import_module(...) + if isinstance(f, ast.Attribute) and f.attr == "import_module": + findings.append((rel_path, node.lineno, "dynamic_import", + "importlib.import_module() โ€” loads arbitrary modules at runtime")) + # __import__() + elif isinstance(f, ast.Name) and f.id == "__import__": + if node.args and not isinstance(node.args[0], ast.Constant): + findings.append((rel_path, node.lineno, "dynamic_import_computed", + "__import__ with non-literal module name")) + # getattr(obj, ) + elif isinstance(f, ast.Name) and f.id == "getattr": + if len(node.args) >= 2 and not isinstance(node.args[1], ast.Constant): + findings.append((rel_path, node.lineno, "dynamic_getattr", + "getattr with non-literal attribute name")) + self.generic_visit(node) + + def visit_Subscript(self, node): + # obj.__dict__[] + if (isinstance(node.value, ast.Attribute) + and node.value.attr == "__dict__" + and not isinstance(node.slice, ast.Constant)): + findings.append((rel_path, node.lineno, "dict_access", + "__dict__[] โ€” dynamic attribute access")) + self.generic_visit(node) + + def visit_Import(self, node): + for a in node.names: + if a.name == "importlib" or a.name.startswith("importlib."): + findings.append((rel_path, node.lineno, "importlib_import", + f"import {a.name} โ€” enables dynamic module loading")) + self.generic_visit(node) + + def visit_ImportFrom(self, node): + m = node.module or "" + if m == "importlib" or m.startswith("importlib."): + findings.append((rel_path, node.lineno, "importlib_import", + f"from {m} import ... โ€” enables dynamic module loading")) + self.generic_visit(node) + + try: + V().visit(tree) + except (RecursionError, ValueError, RuntimeError): + # Hostile/pathological input: return what we collected so far. + pass + + return findings + + +def ast_scan_path(path: Path) -> List[Finding]: + """Scan a single .py file or recursively scan all .py under a directory. + + Returns a list of (file, line, pattern_id, description) tuples. Empty for + non-Python paths, missing paths, or paths with no matching patterns. + """ + if path.is_file(): + if path.suffix.lower() != ".py": + return [] + try: + content = path.read_text(encoding="utf-8", errors="replace") + except OSError: + return [] + return _scan_source(content, path.name) + + if not path.is_dir(): + return [] + + out: List[Finding] = [] + for py in sorted(path.rglob("*.py")): + if set(py.parent.parts) & _IGNORED_DIRS: + continue + try: + content = py.read_text(encoding="utf-8", errors="replace") + except OSError: + continue + try: + rel = py.relative_to(path).as_posix() + except ValueError: + rel = py.name + out.extend(_scan_source(content, rel)) + return out + + +def format_ast_report(findings: List[Finding], skill_name: str = "") -> str: + """Plain-text report (Rich-markup-free) grouped by file.""" + header = f"AST deep scan: {skill_name}" if skill_name else "AST deep scan" + if not findings: + return f"{header}\n No dynamic import/access patterns detected." + + lines = [header, f" {len(findings)} finding(s):"] + current = None + for f, line, pid, desc in sorted(findings): + if f != current: + current = f + lines.append(f" {f}") + lines.append(f" L{line} {pid} โ€” {desc}") + lines.append("") + lines.append(" Note: diagnostic hints for human review, not security verdicts.") + return "\n".join(lines) diff --git a/tools/skills_guard.py b/tools/skills_guard.py index 1610c3225cb..28d29daa5c6 100644 --- a/tools/skills_guard.py +++ b/tools/skills_guard.py @@ -661,7 +661,7 @@ def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[bool, if decision == "allow": return True, f"Allowed ({result.trust_level} source, {result.verdict} verdict)" - if force: + if force and not (result.verdict == "dangerous" and result.trust_level in ("community", "trusted")): return True, ( f"Force-installed despite {result.verdict} verdict " f"({len(result.findings)} findings)" @@ -674,6 +674,13 @@ def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[bool, f"{len(result.findings)} findings)" ) + # Dangerous verdicts cannot be overridden by --force (community/trusted); + # other blocks can. + if result.verdict == "dangerous" and result.trust_level in ("community", "trusted"): + return False, ( + f"Blocked ({result.trust_level} source + dangerous verdict, " + f"{len(result.findings)} findings). --force does not override a dangerous verdict." + ) return False, ( f"Blocked ({result.trust_level} source + {result.verdict} verdict, " f"{len(result.findings)} findings). Use --force to override." @@ -717,12 +724,24 @@ def format_scan_report(result: ScanResult) -> str: def content_hash(skill_path: Path) -> str: - """Compute a SHA-256 hash of all files in a skill directory for integrity tracking.""" + """Compute a SHA-256 hash of all files in a skill directory for integrity tracking. + + File paths (relative to ``skill_path``) are mixed into the hash alongside + file contents so that swapping the contents of two files in a skill + changes the hash. This must stay symmetric with + ``tools.skills_hub.bundle_content_hash`` โ€” both functions need to + produce the same digest for the same skill (one operates on disk, + one on an in-memory bundle), so any change to the hash shape MUST + land in both places at once. + """ h = hashlib.sha256() if skill_path.is_dir(): for f in sorted(skill_path.rglob("*")): if f.is_file(): try: + rel = f.relative_to(skill_path).as_posix() + h.update(rel.encode("utf-8")) + h.update(b"\x00") h.update(f.read_bytes()) except OSError: continue @@ -920,7 +939,8 @@ def _determine_verdict(findings: List[Finding]) -> str: return "dangerous" if has_high: return "caution" - return "caution" + # medium/low findings alone are informational, not blocking + return "safe" def _build_summary(name: str, source: str, trust: str, verdict: str, findings: List[Finding]) -> str: diff --git a/tools/skills_hub.py b/tools/skills_hub.py index 79be8dc34fc..35a6749cd5d 100644 --- a/tools/skills_hub.py +++ b/tools/skills_hub.py @@ -3000,6 +3000,13 @@ def uninstall_skill(skill_name: str) -> Tuple[bool, str]: return False, f"'{skill_name}' is not a hub-installed skill (may be a builtin)" install_path = SKILLS_DIR / entry["install_path"] + # Prevent path traversal from poisoned lock.json entries + try: + resolved = install_path.resolve() + if not resolved.is_relative_to(SKILLS_DIR.resolve()): + return False, f"Refusing to remove '{entry['install_path']}': resolves outside skills directory" + except (ValueError, OSError): + return False, f"Refusing to remove '{entry['install_path']}': path resolution failed" if install_path.exists(): shutil.rmtree(install_path) @@ -3013,6 +3020,10 @@ def bundle_content_hash(bundle: SkillBundle) -> str: """Compute a deterministic hash for an in-memory skill bundle.""" h = hashlib.sha256() for rel_path in sorted(bundle.files): + # Include the path so swapping file contents between two paths + # changes the hash (avoids filename-swap evading update detection). + h.update(rel_path.encode("utf-8")) + h.update(b"\x00") content = bundle.files[rel_path] if isinstance(content, bytes): h.update(content) diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 387e27881ad..f7a0e14bc88 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -904,9 +904,9 @@ Do NOT use echo/cat heredoc to create files โ€” use write_file instead. Reserve terminal for: builds, installs, git, processes, scripts, network, package managers, and anything that needs a shell. Foreground (default): Commands return INSTANTLY when done, even if the timeout is high. Set timeout=300 for long builds/scripts โ€” you'll still get the result in seconds if it's fast. Prefer foreground for short commands. -Background: Set background=true to get a session_id. Two patterns: - (1) Long-lived processes that never exit (servers, watchers). - (2) Long-running tasks with notify_on_complete=true โ€” you can keep working on other things and the system auto-notifies you when the task finishes. Great for test suites, builds, deployments, or anything that takes more than a minute. +Background: Set background=true to get a session_id. Almost always pair with notify_on_complete=true โ€” bg without notify runs SILENTLY and you have no way to learn it finished short of calling process(action='poll') yourself. Two legitimate uses: + (1) Long-lived processes that never exit (servers, watchers, daemons) โ€” silent is correct, there's no exit to notify on. + (2) Long-running bounded tasks (tests, builds, deploys, CI pollers, batch jobs) โ€” MUST set notify_on_complete=true. Without it you'll either forget to poll or sit blocked waiting for the user to surface the result. For servers/watchers, do NOT use shell-level background wrappers (nohup/disown/setsid/trailing '&') in foreground mode. Use background=true so Hermes can track lifecycle and output. After starting a server, verify readiness with a health check or log signal, then run tests in a separate terminal() call. Avoid blind sleep loops. Use process(action="poll") for progress checks, process(action="wait") to block until done. @@ -1959,6 +1959,32 @@ def terminal_tool( if pty_disabled_reason: result_data["pty_note"] = pty_disabled_reason + # Nudge: background=True without notify_on_complete=True OR + # watch_patterns is a silent process. The agent has NO way to + # learn it finished short of calling process(action="poll"/"wait") + # explicitly. That's correct only for genuine long-lived + # processes that never exit (servers, watchers). For every + # bounded task (tests, builds, CI pollers, deploys, batch + # jobs) the agent almost certainly wanted notification and + # forgot the flag. May 2026 PR #31231 incident: bg CI poller + # ran fine, exited green, agent never noticed โ€” user had to + # surface the result. Cheap nudge here costs ~one read for + # server cases (false positive) and prevents silent + # blindness for bounded-task cases (false negative). + if background and not notify_on_complete and not watch_patterns: + result_data["hint"] = ( + "background=true without notify_on_complete=true means " + "this process runs SILENTLY โ€” you will not be told when " + "it exits. If this is a bounded task (test suite, build, " + "CI poller, deploy, anything with a defined end), you " + "almost certainly wanted notify_on_complete=true so the " + "system pings you on exit. Re-launch with " + "notify_on_complete=true, or call process(action='poll') " + "/ process(action='wait') yourself to learn the outcome. " + "Only ignore this hint for genuine long-lived processes " + "that never exit (servers, watchers, daemons)." + ) + # Populate routing metadata on the session so that # watch-pattern and completion notifications can be # routed back to the correct chat/thread. @@ -2322,7 +2348,7 @@ TERMINAL_SCHEMA = { }, "background": { "type": "boolean", - "description": "Run the command in the background. Two patterns: (1) Long-lived processes that never exit (servers, watchers). (2) Long-running tasks paired with notify_on_complete=true โ€” you can keep working and get notified when the task finishes. For short commands, prefer foreground with a generous timeout instead.", + "description": "Run the command in the background. Almost always pair with notify_on_complete=true โ€” without it, the process runs silently and you'll have no way to learn it finished short of calling process(action='poll') yourself (easy to forget, leading to silent blindness on long jobs). Two legitimate patterns: (1) Long-lived processes that never exit (servers, watchers, daemons) โ€” these stay silent because there's no exit to notify on. (2) Long-running bounded tasks (tests, builds, deploys, CI pollers, batch jobs) โ€” these MUST set notify_on_complete=true. For short commands, prefer foreground with a generous timeout instead.", "default": False }, "timeout": { diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index d741530d358..a9af32023f3 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -197,6 +197,26 @@ def _normalize_local_command_model(model_name: Optional[str]) -> str: return _normalize_local_model(model_name) +def _try_lazy_install_stt() -> bool: + """Attempt to lazy-install faster-whisper and return True on success. + + The module-level ``_HAS_FASTER_WHISPER`` flag is set at import time and + cached. If the package wasn't installed at startup, calling ``ensure()`` + installs it. This function re-checks dynamically after installation so + the provider can use it immediately without a process restart. + """ + try: + from tools.lazy_deps import ensure + ensure("stt.faster_whisper") + # Re-check dynamically after install + import importlib.util as _iu + if _iu.find_spec("faster_whisper"): + return True + except Exception as exc: + logger.debug("Lazy install of faster-whisper failed: %s", exc) + return False + + def _get_provider(stt_config: dict) -> str: """Determine which STT provider to use. @@ -218,6 +238,9 @@ def _get_provider(stt_config: dict) -> str: return "local" if _has_local_command(): return "local_command" + # Try lazy-install before giving up + if _try_lazy_install_stt(): + return "local" logger.warning( "STT provider 'local' configured but unavailable " "(install faster-whisper or set HERMES_LOCAL_STT_COMMAND)" @@ -285,6 +308,9 @@ def _get_provider(stt_config: dict) -> str: return "local" if _has_local_command(): return "local_command" + # Try lazy-install before falling through to cloud providers + if _try_lazy_install_stt(): + return "local" if _HAS_OPENAI and get_env_value("GROQ_API_KEY"): logger.info("No local STT available, using Groq Whisper API") return "groq" @@ -403,7 +429,8 @@ def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]: global _local_model, _local_model_name if not _HAS_FASTER_WHISPER: - return {"success": False, "transcript": "", "error": "faster-whisper not installed"} + if not _try_lazy_install_stt(): + return {"success": False, "transcript": "", "error": "faster-whisper not installed"} try: # Lazy-load the model (downloads on first use, ~150 MB for 'base') diff --git a/tools/vision_tools.py b/tools/vision_tools.py index 912777e2e25..38d19919488 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -914,11 +914,26 @@ async def vision_analyze_tool( def check_vision_requirements() -> bool: - """Check if the configured runtime vision path can resolve a client.""" + """Check if the configured runtime vision path can resolve a client. + + Mirrors the fallback chain that ``call_llm(task="vision")`` actually uses + at runtime: first the explicit ``auxiliary.vision.provider`` (if any), + and if that fails, the auto chain (main provider โ†’ openrouter โ†’ nous). + Without the auto-fallback step the tool would disappear from the model's + tool list whenever the explicit provider name was unresolvable, even + when the auto chain would have served the request (issue #31179). + """ try: from agent.auxiliary_client import resolve_vision_provider_client - + except ImportError: + return False + try: _provider, client, _model = resolve_vision_provider_client() + if client is not None: + return True + # Same fallback to "auto" that call_llm performs when the configured + # provider can't be resolved. + _provider, client, _model = resolve_vision_provider_client(provider="auto") return client is not None except Exception: return False diff --git a/tools/yuanbao_tools.py b/tools/yuanbao_tools.py index 6466458d34f..46f635c9829 100644 --- a/tools/yuanbao_tools.py +++ b/tools/yuanbao_tools.py @@ -472,6 +472,7 @@ async def _handle_yb_send_dm(args, **kw): embedded_media, message = BasePlatformAdapter.extract_media(message) if embedded_media: media_files.extend(embedded_media) + media_files = BasePlatformAdapter.filter_media_delivery_paths(media_files) return tool_result(await send_dm( group_code=group_code, name=args.get("name", ""), diff --git a/toolsets.py b/toolsets.py index 5de07e4c7a1..bab7677887a 100644 --- a/toolsets.py +++ b/toolsets.py @@ -72,6 +72,16 @@ _HERMES_CORE_TOOLS = [ "computer_use", ] +# Webhook events may originate from untrusted third-party content (for example, +# public PR titles/comments). Keep the default webhook toolset intentionally +# constrained to avoid local file/system execution by prompt injection. +_HERMES_WEBHOOK_SAFE_TOOLS = [ + "web_search", + "web_extract", + "vision_analyze", + "clarify", +] + # Core toolset definitions # These can include individual tools or reference other toolsets @@ -523,7 +533,7 @@ TOOLSETS = { "hermes-webhook": { "description": "Webhook toolset - receive and process external webhook events", - "tools": _HERMES_CORE_TOOLS, + "tools": _HERMES_WEBHOOK_SAFE_TOOLS, "includes": [] }, diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 921853a34c5..dc13969f1be 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -1061,6 +1061,10 @@ def _session_tool_progress_mode(sid: str) -> str: return str(_sessions.get(sid, {}).get("tool_progress_mode", "all") or "all") +def _session_verbose(sid: str) -> bool: + return _session_tool_progress_mode(sid) == "verbose" + + def _tool_progress_enabled(sid: str) -> bool: return _session_tool_progress_mode(sid) != "off" @@ -1492,6 +1496,74 @@ def _tool_ctx(name: str, args: dict) -> str: return "" +_TUI_VERBOSE_TEXT_MAX_CHARS = 16_000 +_TUI_VERBOSE_TEXT_MAX_LINES = 240 + + +def _cap_tui_verbose_text(text: str) -> str: + if ( + len(text) <= _TUI_VERBOSE_TEXT_MAX_CHARS + and text.count("\n") < _TUI_VERBOSE_TEXT_MAX_LINES + ): + return text + + idx = len(text) + start = 0 + for _ in range(_TUI_VERBOSE_TEXT_MAX_LINES): + idx = text.rfind("\n", 0, idx) + if idx < 0: + start = 0 + break + start = idx + 1 + + line_start = start + start = max(line_start, len(text) - _TUI_VERBOSE_TEXT_MAX_CHARS) + if start > line_start: + next_break = text.find("\n", start) + if 0 <= next_break < len(text) - 1: + start = next_break + 1 + + tail = text[start:].lstrip() + omitted_chars = max(0, len(text) - len(tail)) + omitted_lines = text[:start].count("\n") + if omitted_lines: + label = ( + "[showing verbose tail; omitted " + f"{omitted_lines} lines / {omitted_chars} chars]\n" + ) + else: + label = f"[showing verbose tail; omitted {omitted_chars} chars]\n" + return f"{label}{tail}" + + +def _redact_tui_verbose_text(text: str) -> str: + try: + from agent.redact import redact_sensitive_text + + redacted = redact_sensitive_text(str(text), force=True) + except Exception: + return "" + return _cap_tui_verbose_text(redacted) + + +def _tool_args_text(args: dict) -> str: + try: + raw = json.dumps(args or {}, indent=2, ensure_ascii=False, default=str) + except Exception: + raw = str(args or {}) + return _redact_tui_verbose_text(raw) + + +def _tool_result_text(result: object) -> str: + try: + from agent.tool_dispatch_helpers import _multimodal_text_summary + + raw = _multimodal_text_summary(result) + except Exception: + raw = str(result) + return _redact_tui_verbose_text(raw) + + def _fmt_tool_duration(seconds: float | None) -> str: if seconds is None: return "" @@ -1553,13 +1625,18 @@ def _on_tool_start(sid: str, tool_call_id: str, name: str, args: dict): pass session.setdefault("tool_started_at", {})[tool_call_id] = time.time() if _tool_progress_enabled(sid): + payload = { + "tool_id": tool_call_id, + "name": name, + "context": _tool_ctx(name, args), + } + if _session_verbose(sid): + args_text = _tool_args_text(args) + if args_text: + payload["args_text"] = args_text # tool.complete is the source of truth for todos (full list from the # tool result). args.todos here may be a partial merge update. - _emit( - "tool.start", - sid, - {"tool_id": tool_call_id, "name": name, "context": _tool_ctx(name, args)}, - ) + _emit("tool.start", sid, payload) def _on_tool_complete(sid: str, tool_call_id: str, name: str, args: dict, result: str): @@ -1576,6 +1653,10 @@ def _on_tool_complete(sid: str, tool_call_id: str, name: str, args: dict, result summary = _tool_summary(name, result, duration_s) if summary: payload["summary"] = summary + if _session_verbose(sid): + result_text = _tool_result_text(result) + if result_text: + payload["result_text"] = result_text if name == "todo": try: data = json.loads(result) @@ -1615,7 +1696,10 @@ def _on_tool_progress( _emit("tool.progress", sid, {"name": name, "preview": preview or ""}) return if event_type == "reasoning.available" and preview: - _emit("reasoning.available", sid, {"text": str(preview)}) + payload: dict[str, object] = {"text": str(preview)} + if _session_verbose(sid): + payload["verbose"] = True + _emit("reasoning.available", sid, payload) return if event_type.startswith("subagent."): payload = { @@ -1691,7 +1775,11 @@ def _agent_cbs(sid: str) -> dict: "tool_gen_callback": lambda name: _tool_progress_enabled(sid) and _emit("tool.generating", sid, {"name": name}), "thinking_callback": lambda text: _emit("thinking.delta", sid, {"text": text}), - "reasoning_callback": lambda text: _emit("reasoning.delta", sid, {"text": text}), + "reasoning_callback": lambda text: _emit( + "reasoning.delta", + sid, + {"text": text, **({"verbose": True} if _session_verbose(sid) else {})}, + ), "status_callback": lambda kind, text=None: _status_update( sid, str(kind), None if text is None else str(text) ), @@ -1945,7 +2033,11 @@ def _make_agent(sid: str, key: str, session_id: str | None = None): acp_args=runtime.get("args"), credential_pool=runtime.get("credential_pool"), quiet_mode=True, - verbose_logging=_load_tool_progress_mode() == "verbose", + # verbose_logging controls DEBUG-level agent logging; it is intentionally + # independent of tool_progress_mode (which only controls per-tool + # display detail). See cli.py PR (decoupling fix) for the matching + # change on the classic CLI side. + verbose_logging=False, reasoning_config=_load_reasoning_config(), service_tier=_load_service_tier(), enabled_toolsets=_load_enabled_toolsets(), @@ -3262,6 +3354,8 @@ def _run_prompt_submit(rid, sid: str, session: dict, text: Any) -> None: _read_main_model(), _cfg, ) + if getattr(agent, "api_mode", "") == "codex_app_server": + _mode = "text" except Exception as _img_exc: print( f"[tui_gateway] image_routing decision failed, defaulting to text: {_img_exc}", @@ -5292,7 +5386,12 @@ def _(rid, params: dict) -> dict: items = [ { "text": c.text, - "display": c.display or c.text, + # prompt_toolkit gives us FormattedText (a list of (style, + # text) tuples) for display/display_meta. Serialize both as + # plain strings โ€” the TUI's CompletionItem.display contract + # is a string, and sending the raw list trips Ink's row + # layout into 1-char truncation of the next column. + "display": to_plain_text(c.display) if c.display else c.text, "meta": to_plain_text(c.display_meta) if c.display_meta else "", } for c in completer.get_completions(doc, None) @@ -5781,6 +5880,9 @@ def _(rid, params: dict) -> dict: except Exception as e: logger.warning("voice: stop_continuous failed during toggle off: %s", e) + # Clear TTS so it can be toggled independently after voice is off. + os.environ["HERMES_VOICE_TTS"] = "0" + return _ok( rid, { diff --git a/ui-tui/packages/hermes-ink/src/ink/components/ScrollBox.tsx b/ui-tui/packages/hermes-ink/src/ink/components/ScrollBox.tsx index 15e896cb9c5..4f2604be0ec 100644 --- a/ui-tui/packages/hermes-ink/src/ink/components/ScrollBox.tsx +++ b/ui-tui/packages/hermes-ink/src/ink/components/ScrollBox.tsx @@ -48,10 +48,10 @@ export type ScrollBoxHandle = { */ isSticky: () => boolean /** - * Subscribe to imperative scroll changes (scrollTo/scrollBy/scrollToBottom). - * Does NOT fire for stickyScroll updates done by the Ink renderer โ€” those - * happen during Ink's render phase after React has committed. Callers that - * care about the sticky case should treat "at bottom" as a fallback. + * Subscribe to scroll viewport changes. Fires for imperative scroll changes + * (scrollTo/scrollBy/scrollToBottom) and for renderer-computed scroll bounds + * changes such as content growth or terminal resize. Callers use this to + * keep virtualized ranges aligned with the currently visible viewport. */ subscribe: (listener: () => void) => () => void /** diff --git a/ui-tui/packages/hermes-ink/src/ink/log-update.test.ts b/ui-tui/packages/hermes-ink/src/ink/log-update.test.ts index a11a028e771..c0935587d0f 100644 --- a/ui-tui/packages/hermes-ink/src/ink/log-update.test.ts +++ b/ui-tui/packages/hermes-ink/src/ink/log-update.test.ts @@ -42,7 +42,8 @@ const stdoutOnly = (diff: ReturnType) => .map(p => (p as { type: 'stdout'; content: string }).content) .join('') -const hasDecstbm = (text: string) => /\x1b\[\d+;\d+r/.test(text) +const ESC = '\u001b' +const hasDecstbm = (text: string) => new RegExp(`${ESC}\\[\\d+;\\d+r`).test(text) describe('LogUpdate.render diff contract', () => { it('emits only changed cells when most rows match', () => { @@ -87,6 +88,25 @@ describe('LogUpdate.render diff contract', () => { expect(stdoutOnly(diff)).toContain('shorterrownow') }) + it('height growth emits a clearTerminal patch before repainting', () => { + const w = 20 + const prevH = 3 + const nextH = 6 + + const prev = mkScreen(w, prevH) + paint(prev, 0, 'old rows') + + const next = mkScreen(w, nextH) + paint(next, 0, 'new rows') + next.damage = { x: 0, y: 0, width: w, height: nextH } + + const log = new LogUpdate({ isTTY: true, stylePool }) + const diff = log.render(mkFrame(prev, w, prevH), mkFrame(next, w, nextH), true, false) + + expect(diff.some(p => p.type === 'clearTerminal')).toBe(true) + expect(stdoutOnly(diff)).toContain('newrows') + }) + it('drift repro: identical prev/next emits no heal, even when the physical terminal is stale', () => { // Load-bearing theory for the rapid-resize scattered-letter bug: if the // physical terminal has stale cells that prev.screen doesn't know about @@ -167,10 +187,12 @@ describe('LogUpdate.render diff contract', () => { paint(next, 1, 'row one') const prevFrame = mkFrame(prev, w, h) + const nextFrame: Frame = { ...mkFrame(next, w, h), scrollHint: { top: 1, bottom: 4, delta: 1 } } + const log = new LogUpdate({ isTTY: true, stylePool }) const diff = log.render(prevFrame, nextFrame, true, true) @@ -187,10 +209,12 @@ describe('LogUpdate.render diff contract', () => { paint(next, 1, 'row one') const prevFrame = mkFrame(prev, w, h) + const nextFrame: Frame = { ...mkFrame(next, w, h), scrollHint: { top: 1, bottom: 5, delta: 1 } } + const log = new LogUpdate({ isTTY: true, stylePool }) const diff = log.render(prevFrame, nextFrame, true, true) diff --git a/ui-tui/packages/hermes-ink/src/ink/log-update.ts b/ui-tui/packages/hermes-ink/src/ink/log-update.ts index 0f36d4641e7..a428060b97d 100644 --- a/ui-tui/packages/hermes-ink/src/ink/log-update.ts +++ b/ui-tui/packages/hermes-ink/src/ink/log-update.ts @@ -141,14 +141,12 @@ export class LogUpdate { const startTime = performance.now() const stylePool = this.options.stylePool - // Since we assume the cursor is at the bottom on the screen, we only need - // to clear when the viewport gets shorter (i.e. the cursor position drifts) - // or when it gets thinner (and text wraps). We _could_ figure out how to - // not reset here but that would involve predicting the current layout - // _after_ the viewport change which means calcuating text wrapping. - // Resizing is a rare enough event that it's not practically a big issue. + // Terminal hosts can reflow/preserve old cells on any resize, including + // height-only growth. A partial diff can then leave stale transcript rows + // or cut off bordered content even when our virtual scrollTop is correct. + // Resizing is rare enough that a full repaint is the safer tradeoff. if ( - next.viewport.height < prev.viewport.height || + next.viewport.height !== prev.viewport.height || (prev.viewport.width !== 0 && next.viewport.width !== prev.viewport.width) ) { return fullResetSequence_CAUSES_FLICKER(next, 'resize', stylePool) diff --git a/ui-tui/packages/hermes-ink/src/ink/render-node-to-output.ts b/ui-tui/packages/hermes-ink/src/ink/render-node-to-output.ts index a31753c722a..5fee72cccaf 100644 --- a/ui-tui/packages/hermes-ink/src/ink/render-node-to-output.ts +++ b/ui-tui/packages/hermes-ink/src/ink/render-node-to-output.ts @@ -706,12 +706,22 @@ function renderNodeToOutput( const content = node.childNodes.find(c => (c as DOMElement).yogaNode) as DOMElement | undefined const contentYoga = content?.yogaNode - // scrollHeight is the intrinsic height of the content wrapper. - // Do NOT add getComputedTop() โ€” that's the wrapper's offset - // within the viewport (equal to the scroll container's - // paddingTop), and innerHeight already subtracts padding, so - // including it double-counts padding and inflates maxScroll. - const scrollHeight = contentYoga?.getComputedHeight() ?? 0 + // scrollHeight is the intrinsic height of the content wrapper, but + // after terminal resizes Yoga can leave tall descendants overflowing + // that wrapper. Use the deepest direct child bottom so sticky-bottom + // math can still reach the real final rendered row. + let scrollHeight = Math.ceil(contentYoga?.getComputedHeight() ?? 0) + + if (content) { + for (const child of content.childNodes) { + const childYoga = (child as DOMElement).yogaNode + + if (childYoga) { + scrollHeight = Math.max(scrollHeight, Math.ceil(childYoga.getComputedTop() + childYoga.getComputedHeight())) + } + } + } + // Capture previous scroll bounds BEFORE overwriting โ€” the at-bottom // follow check compares against last frame's max. const prevScrollHeight = node.scrollHeight ?? scrollHeight @@ -862,7 +872,12 @@ function renderNodeToOutput( scrollDrainNode = node } - if ((node.scrollTop ?? 0) !== scrollTopBeforeFollow || node.stickyScroll !== stickyBeforeFollow) { + if ( + (node.scrollTop ?? 0) !== scrollTopBeforeFollow || + node.stickyScroll !== stickyBeforeFollow || + scrollHeight !== prevScrollHeight || + innerHeight !== prevInnerHeight + ) { node.notifyScrollChange?.() } @@ -891,7 +906,14 @@ function renderNodeToOutput( const regionTop = Math.floor(y + contentYoga.getComputedTop()) const regionBottom = regionTop + innerHeight - 1 - if (cached?.y === y && cached.height === height && innerHeight > 0 && Math.abs(delta) < innerHeight) { + if ( + cached?.x === x && + cached.y === y && + cached.width === width && + cached.height === height && + innerHeight > 0 && + Math.abs(delta) < innerHeight + ) { hint = { top: regionTop, bottom: regionBottom, delta } scrollHint = hint } else { diff --git a/ui-tui/src/__tests__/createGatewayEventHandler.test.ts b/ui-tui/src/__tests__/createGatewayEventHandler.test.ts index 417b8c41b93..0a3e4227396 100644 --- a/ui-tui/src/__tests__/createGatewayEventHandler.test.ts +++ b/ui-tui/src/__tests__/createGatewayEventHandler.test.ts @@ -139,6 +139,7 @@ describe('createGatewayEventHandler', () => { const verdict = 'โœ“ Goal achieved: long judge reason goes only in transcript, not merged with cwd label.' vi.useFakeTimers() + try { onEvent({ payload: { kind: 'goal', text: verdict }, @@ -303,14 +304,40 @@ describe('createGatewayEventHandler', () => { vi.useFakeTimers() const appended: Msg[] = [] const streamed = 'short streamed reasoning' + const onEvent = createGatewayEventHandler(buildCtx(appended)) - createGatewayEventHandler(buildCtx(appended))({ payload: { text: streamed }, type: 'thinking.delta' } as any) - vi.runOnlyPendingTimers() + try { + onEvent({ payload: {}, type: 'message.start' } as any) + onEvent({ payload: { text: streamed }, type: 'thinking.delta' } as any) + vi.runOnlyPendingTimers() - expect(getTurnState().reasoning).toBe(streamed) - expect(getTurnState().reasoningActive).toBe(true) - expect(getTurnState().reasoningTokens).toBe(estimateTokensRough(streamed)) - vi.useRealTimers() + expect(getTurnState().reasoning).toBe(streamed) + expect(getTurnState().reasoningActive).toBe(true) + expect(getTurnState().reasoningTokens).toBe(estimateTokensRough(streamed)) + } finally { + vi.useRealTimers() + } + }) + + it('ignores late thinking.delta after the turn has already completed', () => { + vi.useFakeTimers() + const appended: Msg[] = [] + const onEvent = createGatewayEventHandler(buildCtx(appended)) + + try { + onEvent({ payload: {}, type: 'message.start' } as any) + onEvent({ payload: { text: 'final answer' }, type: 'message.complete' } as any) + expect(getUiState().busy).toBe(false) + expect(getUiState().status).toBe('ready') + + onEvent({ payload: { text: 'thinking...' }, type: 'thinking.delta' } as any) + vi.runOnlyPendingTimers() + + expect(getUiState().status).toBe('ready') + expect(getTurnState().reasoning).toBe('') + } finally { + vi.useRealTimers() + } }) it('preserves streamed reasoning as one completed thinking panel after segment flushes', () => { @@ -342,6 +369,25 @@ describe('createGatewayEventHandler', () => { expect(appended[appended.length - 1]).toMatchObject({ role: 'assistant', text: 'final answer' }) }) + it('shows verbose reasoning even when normal reasoning display is off', () => { + vi.useFakeTimers() + patchUiState({ showReasoning: false }) + const appended: Msg[] = [] + const streamed = 'verbose-only reasoning' + + try { + const onEvent = createGatewayEventHandler(buildCtx(appended)) + + onEvent({ payload: { text: streamed, verbose: true }, type: 'reasoning.delta' } as any) + vi.runOnlyPendingTimers() + + expect(turnController.reasoningText).toBe(streamed) + expect(getTurnState().reasoning).toBe(streamed) + } finally { + vi.useRealTimers() + } + }) + it('ignores fallback reasoning.available when streamed reasoning already exists', () => { const appended: Msg[] = [] const streamed = 'short streamed reasoning' @@ -485,6 +531,25 @@ describe('createGatewayEventHandler', () => { expect(appended[3]?.text).not.toContain('```diff') }) + it('keeps verbose result text on inline_diff tool completions', () => { + const appended: Msg[] = [] + const onEvent = createGatewayEventHandler(buildCtx(appended)) + const diff = '--- a/foo.ts\n+++ b/foo.ts\n@@\n-old\n+new' + + onEvent({ + payload: { args_text: '{ "path": "foo.ts" }', context: 'foo.ts', name: 'patch', tool_id: 'tool-1' }, + type: 'tool.start' + } as any) + onEvent({ + payload: { inline_diff: diff, result_text: 'patched result', tool_id: 'tool-1' }, + type: 'tool.complete' + } as any) + + expect(turnController.segmentMessages[0]).toMatchObject({ kind: 'diff' }) + expect(turnController.segmentMessages[0]?.tools?.[0]).toContain('Args:\n{ "path": "foo.ts" }') + expect(turnController.segmentMessages[0]?.tools?.[0]).toContain('Result:\npatched result') + }) + it('keeps full final responses from duplicating flushed pre-diff narration', () => { const appended: Msg[] = [] const onEvent = createGatewayEventHandler(buildCtx(appended)) diff --git a/ui-tui/src/__tests__/createSlashHandler.test.ts b/ui-tui/src/__tests__/createSlashHandler.test.ts index 952f34fc38b..e1251a4af9f 100644 --- a/ui-tui/src/__tests__/createSlashHandler.test.ts +++ b/ui-tui/src/__tests__/createSlashHandler.test.ts @@ -222,6 +222,21 @@ describe('createSlashHandler', () => { expect(ctx.gateway.rpc).not.toHaveBeenCalled() }) + it('keeps visible scrollback when branching a TUI session', async () => { + patchUiState({ sid: 'sid-parent' }) + const rpc = vi.fn(() => Promise.resolve({ session_id: 'sid-branch', title: 'branch title' })) + const ctx = buildCtx({ gateway: { ...buildGateway(), rpc } }) + + expect(createSlashHandler(ctx)('/branch branch title')).toBe(true) + + expect(rpc).toHaveBeenCalledWith('session.branch', { name: 'branch title', session_id: 'sid-parent' }) + await vi.waitFor(() => { + expect(getUiState().sid).toBe('sid-branch') + expect(ctx.transcript.sys).toHaveBeenCalledWith('branched โ†’ branch title') + }) + expect(ctx.transcript.setHistoryItems).not.toHaveBeenCalled() + }) + it('reloads skills in the live gateway and refreshes the catalog', async () => { const rpc = vi.fn((method: string) => { if (method === 'skills.reload') { diff --git a/ui-tui/src/__tests__/gatewayClient.test.ts b/ui-tui/src/__tests__/gatewayClient.test.ts index eac96c20780..f1228e56fbe 100644 --- a/ui-tui/src/__tests__/gatewayClient.test.ts +++ b/ui-tui/src/__tests__/gatewayClient.test.ts @@ -34,6 +34,7 @@ class FakeWebSocket { options !== null && 'once' in options && Boolean((options as { once?: unknown }).once) + const entries = this.listeners.get(type) ?? [] entries.push({ callback, once }) @@ -84,6 +85,7 @@ class FakeWebSocket { for (const entry of entries) { entry.callback(event) + if (entry.once) { this.removeEventListener(type, entry.callback) } @@ -170,6 +172,7 @@ describe('GatewayClient websocket attach mode', () => { method: 'event', params: { type: 'tool.start', payload: { tool_id: 't1' } } }) + gatewaySocket.message(eventFrame) expect(seen).toContain('tool.start') @@ -193,6 +196,8 @@ describe('GatewayClient websocket attach mode', () => { gatewaySocket.close(1011) expect(exits).toEqual([1011]) + expect(gw.getLogTail(20)).toContain('[lifecycle] websocket close code=1011') + expect(gw.getLogTail(20)).toContain('[lifecycle] transport exit code=1011') }) it('rejects pending RPCs with websocket wording when the attached socket closes', async () => { @@ -226,9 +231,10 @@ describe('GatewayClient websocket attach mode', () => { const req = gw.request('session.create', {}) await vi.waitFor(() => expect(gatewaySocket.sent.length).toBeGreaterThan(0)) - gw.kill() + gw.kill('test.shutdown') await expect(req).rejects.toThrow(/gateway closed/) + expect(gw.getLogTail(20)).toContain('[lifecycle] GatewayClient.kill reason=test.shutdown') }) it('reattaches when HERMES_TUI_GATEWAY_URL rotates between requests', async () => { @@ -279,6 +285,7 @@ describe('GatewayClient websocket attach mode', () => { gw.drain() expect(stderrLines.length).toBeGreaterThan(0) + for (const line of stderrLines) { expect(line).not.toContain('hunter2') expect(line).not.toContain('channel=secret') @@ -370,6 +377,7 @@ describe('GatewayClient websocket attach mode', () => { gw.drain() expect(stderrLines.length).toBeGreaterThan(0) + for (const line of stderrLines) { expect(line).not.toContain('alice') expect(line).not.toContain('hunter2') diff --git a/ui-tui/src/__tests__/prompt.test.ts b/ui-tui/src/__tests__/prompt.test.ts index 7b923c79a40..68c57354783 100644 --- a/ui-tui/src/__tests__/prompt.test.ts +++ b/ui-tui/src/__tests__/prompt.test.ts @@ -16,4 +16,16 @@ describe('composerPromptText', () => { expect(composerPromptText('โฏ', 'custom')).toBe('โฏ') expect(composerPromptText('โฏ')).toBe('โฏ') }) + + it('uses a Termux-safe ASCII prompt marker in normal mode', () => { + expect(composerPromptText('โฏ', 'coder', false, true, 50)).toBe('>') + }) + + it('keeps profile prefix suppressed on narrow Termux widths', () => { + expect(composerPromptText('โฏ', 'upstr', false, true, 72)).toBe('>') + }) + + it('allows profile prefix on very wide Termux panes', () => { + expect(composerPromptText('โฏ', 'upstr', false, true, 120)).toBe('upstr >') + }) }) diff --git a/ui-tui/src/__tests__/termuxComposerLayout.test.ts b/ui-tui/src/__tests__/termuxComposerLayout.test.ts new file mode 100644 index 00000000000..e845ef89c3f --- /dev/null +++ b/ui-tui/src/__tests__/termuxComposerLayout.test.ts @@ -0,0 +1,40 @@ +import { describe, expect, it } from 'vitest' + +import { stableComposerColumns, transcriptBodyWidth } from '../lib/inputMetrics.js' +import { composerPromptText } from '../lib/prompt.js' + +describe('Termux composer prompt + width guards', () => { + it('uses a single-cell ASCII prompt marker in Termux mode', () => { + expect(composerPromptText('โฏ', 'coder', false, true, 50)).toBe('>') + }) + + it('suppresses profile prefixes on narrow Termux panes', () => { + expect(composerPromptText('โฏ', 'upstr', false, true, 72)).toBe('>') + }) + + it('keeps profile context on very wide Termux panes', () => { + expect(composerPromptText('โฏ', 'upstr', false, true, 120)).toBe('upstr >') + }) + + it('reserves fewer columns for gutter on narrow Termux widths', () => { + // 32 columns after prompt: desktop reserves 2 for transcript scrollbar, + // Termux keeps those 2 columns for the active composer. + expect(stableComposerColumns(40, 8, false)).toBe(28) + expect(stableComposerColumns(40, 8, true)).toBe(30) + + // With ample room, Termux still reserves the gutter for alignment. + expect(stableComposerColumns(60, 8, true)).toBe(48) + }) + + it('never over-allocates transcript body width on narrow panes', () => { + // Old behavior hard-minned to 20 columns and overflowed narrow layouts. + expect(transcriptBodyWidth(24, 'assistant', '>', true)).toBe(19) + expect(transcriptBodyWidth(24, 'user', 'upstr >', true)).toBe(14) + expect(transcriptBodyWidth(10, 'user', '>', true)).toBeGreaterThanOrEqual(1) + }) + + it('keeps legacy desktop floor outside Termux mode', () => { + expect(transcriptBodyWidth(24, 'assistant', '>')).toBe(20) + expect(transcriptBodyWidth(24, 'user', 'upstr >')).toBe(20) + }) +}) diff --git a/ui-tui/src/__tests__/text.test.ts b/ui-tui/src/__tests__/text.test.ts index 306324d353d..6fd250b5bee 100644 --- a/ui-tui/src/__tests__/text.test.ts +++ b/ui-tui/src/__tests__/text.test.ts @@ -3,6 +3,7 @@ import { describe, expect, it } from 'vitest' import { boundedLiveRenderText, buildToolTrailLine, + buildVerboseToolTrailLine, edgePreview, estimateRows, estimateTokensRough, @@ -12,8 +13,8 @@ import { lastCotTrailIndex, parseToolTrailResultLine, pasteTokenLabel, - sanitizeAnsiForRender, sameToolTrailGroup, + sanitizeAnsiForRender, splitToolDuration, stripAnsi, thinkingPreview @@ -37,6 +38,39 @@ describe('buildToolTrailLine', () => { }) }) +describe('buildVerboseToolTrailLine', () => { + it('preserves multiline args and result details', () => { + const line = buildVerboseToolTrailLine( + 'terminal', + 'npm test', + false, + 1.25, + '{\n "cmd": "npm test"\n}', + 'first line\nsecond :: line' + ) + + expect(line).toContain('Args:\n{') + expect(line).toContain('Result:\nfirst line\nsecond :: line') + expect(parseToolTrailResultLine(line)).toEqual({ + call: 'Terminal("npm test") (1.3s)', + detail: 'Args:\n{\n "cmd": "npm test"\n}\nResult:\nfirst line\nsecond :: line', + mark: 'โœ“' + }) + }) + + it('labels verbose failures as errors', () => { + const line = buildVerboseToolTrailLine('terminal', 'npm test', true, 0.5, undefined, 'command failed') + + expect(line).toContain('Error:\ncommand failed') + expect(line).not.toContain('Result:\ncommand failed') + expect(parseToolTrailResultLine(line)).toEqual({ + call: 'Terminal("npm test") (0.5s)', + detail: 'Error:\ncommand failed', + mark: 'โœ—' + }) + }) +}) + describe('lastCotTrailIndex', () => { it('finds last non-result line', () => { expect(lastCotTrailIndex(['a โœ“', 'thinkingโ€ฆ'])).toBe(1) diff --git a/ui-tui/src/__tests__/textInputBurstInput.test.ts b/ui-tui/src/__tests__/textInputBurstInput.test.ts new file mode 100644 index 00000000000..1fdd5246614 --- /dev/null +++ b/ui-tui/src/__tests__/textInputBurstInput.test.ts @@ -0,0 +1,40 @@ +import { describe, expect, it } from 'vitest' + +import { applyPrintableInsert, shouldRouteMultiCharInputAsPaste } from '../components/textInput.js' + +describe('applyPrintableInsert', () => { + it('applies non-bracketed multi-character bursts immediately', () => { + const burst = applyPrintableInsert('abc', 3, 'xxxxx') + + const repeated = [...'xxxxx'].reduce( + (state, ch) => applyPrintableInsert(state.value, state.cursor, ch)!, + { cursor: 3, value: 'abc' } + ) + + expect(burst).toEqual({ cursor: 8, value: 'abcxxxxx' }) + expect(burst).toEqual(repeated) + }) + + it('replaces the selected range for burst input', () => { + expect(applyPrintableInsert('abZZef', 4, 'cd', { end: 4, start: 2 })).toEqual({ + cursor: 4, + value: 'abcdef' + }) + }) + + it('rejects control or escape-bearing input', () => { + expect(applyPrintableInsert('abc', 3, '\x1b[200~pasted')).toBeNull() + expect(applyPrintableInsert('abc', 3, '\t')).toBeNull() + }) +}) + +describe('shouldRouteMultiCharInputAsPaste', () => { + it('keeps newline-bearing chunks on the paste path', () => { + expect(shouldRouteMultiCharInputAsPaste('hello\nworld')).toBe(true) + expect(shouldRouteMultiCharInputAsPaste('hello\r\nworld'.replace(/\r\n/g, '\n'))).toBe(true) + }) + + it('treats repeated printable key bursts as immediate input', () => { + expect(shouldRouteMultiCharInputAsPaste('xxxxx')).toBe(false) + }) +}) diff --git a/ui-tui/src/__tests__/textInputFastEcho.test.ts b/ui-tui/src/__tests__/textInputFastEcho.test.ts index 83b5c511940..6221314a062 100644 --- a/ui-tui/src/__tests__/textInputFastEcho.test.ts +++ b/ui-tui/src/__tests__/textInputFastEcho.test.ts @@ -178,7 +178,22 @@ describe('supportsFastEchoTerminal', () => { expect(supportsFastEchoTerminal({ TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv)).toBe(false) }) - it('keeps fast-echo enabled in VS Code and unknown terminals', () => { + it('disables fast-echo by default in Termux mode', () => { + expect( + supportsFastEchoTerminal({ TERMUX_VERSION: '0.118.0', PREFIX: '/data/data/com.termux/files/usr' } as NodeJS.ProcessEnv) + ).toBe(false) + }) + + it('allows explicit Termux fast-echo opt-in via env override', () => { + expect( + supportsFastEchoTerminal({ + HERMES_TUI_TERMUX_FAST_ECHO: '1', + TERMUX_VERSION: '0.118.0' + } as NodeJS.ProcessEnv) + ).toBe(true) + }) + + it('keeps fast-echo enabled in VS Code and unknown non-Termux terminals', () => { expect(supportsFastEchoTerminal({ TERM_PROGRAM: 'vscode' } as NodeJS.ProcessEnv)).toBe(true) expect(supportsFastEchoTerminal({ TERM: 'xterm-256color' } as NodeJS.ProcessEnv)).toBe(true) }) diff --git a/ui-tui/src/__tests__/virtualHistoryOffsetCache.test.ts b/ui-tui/src/__tests__/virtualHistoryOffsetCache.test.ts index 5a3e8cd0976..a98b43972e6 100644 --- a/ui-tui/src/__tests__/virtualHistoryOffsetCache.test.ts +++ b/ui-tui/src/__tests__/virtualHistoryOffsetCache.test.ts @@ -4,10 +4,11 @@ import { Box, renderSync, ScrollBox, type ScrollBoxHandle, Text } from '@hermes/ import React, { useLayoutEffect, useRef } from 'react' import { describe, expect, it } from 'vitest' -import { useVirtualHistory } from '../hooks/useVirtualHistory.js' +import { useVirtualHistory, virtualHistorySnapshotKey } from '../hooks/useVirtualHistory.js' interface Item { height: number + heightAfterResize?: number key: string } @@ -49,13 +50,28 @@ const viewportIsMounted = (items: readonly Item[], virtualHistory: ReturnType= span.top && bottom <= span.bottom } -function Harness({ expose, items }: { expose: React.MutableRefObject; items: readonly Item[] }) { +const itemHeightForColumns = (item: Item | undefined, columns: number) => + columns >= 80 ? (item?.heightAfterResize ?? item?.height ?? 1) : (item?.height ?? 1) + +function Harness({ + columns = 80, + expose, + height = 10, + items, + maxMounted = 16 +}: { + columns?: number + expose: React.MutableRefObject + height?: number + items: readonly Item[] + maxMounted?: number +}) { const scrollRef = useRef(null) - const virtualHistory = useVirtualHistory(scrollRef, items, 80, { + const virtualHistory = useVirtualHistory(scrollRef, items, columns, { coldStartCount: 16, - estimateHeight: index => items[index]?.height ?? 1, - maxMounted: 16, + estimateHeight: index => itemHeightForColumns(items[index], columns), + maxMounted, overscan: 2 }) @@ -65,7 +81,7 @@ function Harness({ expose, items }: { expose: React.MutableRefObject React.createElement( Box, - { height: item.height, key: item.key, ref: virtualHistory.measureRef(item.key) }, + { + height: itemHeightForColumns(item, columns), + key: item.key, + ref: virtualHistory.measureRef(item.key) + }, React.createElement(Text, null, item.key) ) ), @@ -85,6 +105,113 @@ function Harness({ expose, items }: { expose: React.MutableRefObject { + it('includes viewport height in the external-store snapshot key', () => { + const base = { + getPendingDelta: () => 0, + getScrollTop: () => 20, + isSticky: () => false + } + + const short = virtualHistorySnapshotKey({ + ...base, + getViewportHeight: () => 5 + } as ScrollBoxHandle) + + const tall = virtualHistorySnapshotKey({ + ...base, + getViewportHeight: () => 25 + } as ScrollBoxHandle) + + expect(short).not.toBe(tall) + }) + + it('remounts enough tail rows after the scroll viewport grows', async () => { + const items = Array.from({ length: 100 }, (_, index) => ({ height: 1, key: `item-${index}` })) + const expose = { current: null as Exposed | null } + const streams = makeStreams() + + const instance = renderSync(React.createElement(Harness, { expose, height: 4, items, maxMounted: 80 }), { + patchConsole: false, + stderr: streams.stderr as NodeJS.WriteStream, + stdin: streams.stdin as NodeJS.ReadStream, + stdout: streams.stdout as NodeJS.WriteStream + }) + + try { + await delay(20) + instance.rerender(React.createElement(Harness, { expose, height: 9, items, maxMounted: 80 })) + await delay(80) + + expect(viewportIsMounted(items, expose.current!.virtualHistory, expose.current!.scroll!)).toBe(true) + } finally { + instance.unmount() + instance.cleanup() + } + }) + + it('recomputes tail coverage when wrapped rows shrink after a width resize', async () => { + const items = Array.from({ length: 100 }, (_, index) => ({ + height: 4, + heightAfterResize: 1, + key: `item-${index}` + })) + + const expose = { current: null as Exposed | null } + const streams = makeStreams() + + const instance = renderSync( + React.createElement(Harness, { columns: 40, expose, height: 10, items, maxMounted: 80 }), + { + patchConsole: false, + stderr: streams.stderr as NodeJS.WriteStream, + stdin: streams.stdin as NodeJS.ReadStream, + stdout: streams.stdout as NodeJS.WriteStream + } + ) + + try { + await delay(20) + instance.rerender(React.createElement(Harness, { columns: 80, expose, height: 10, items, maxMounted: 80 })) + await delay(80) + + const resizedItems = items.map(item => ({ height: item.heightAfterResize!, key: item.key })) + + expect(viewportIsMounted(resizedItems, expose.current!.virtualHistory, expose.current!.scroll!)).toBe(true) + } finally { + instance.unmount() + instance.cleanup() + } + }) + + it('keeps sticky scroll at the bottom when one tall tail row resizes', async () => { + const items = [{ height: 90, heightAfterResize: 50, key: 'tail' }] + const expose = { current: null as Exposed | null } + const streams = makeStreams() + + const instance = renderSync( + React.createElement(Harness, { columns: 70, expose, height: 18, items, maxMounted: 80 }), + { + patchConsole: false, + stderr: streams.stderr as NodeJS.WriteStream, + stdin: streams.stdin as NodeJS.ReadStream, + stdout: streams.stdout as NodeJS.WriteStream + } + ) + + try { + await delay(20) + instance.rerender(React.createElement(Harness, { columns: 120, expose, height: 36, items, maxMounted: 80 })) + await delay(80) + + const scroll = expose.current!.scroll! + + expect(scroll.getScrollTop()).toBe(scroll.getScrollHeight() - scroll.getViewportHeight()) + } finally { + instance.unmount() + instance.cleanup() + } + }) + it('recomputes offsets after a mounted row height changes', async () => { const tall = [ { height: 6, key: 'a' }, diff --git a/ui-tui/src/app/createGatewayEventHandler.ts b/ui-tui/src/app/createGatewayEventHandler.ts index 267334bfd72..26d6cfacd0c 100644 --- a/ui-tui/src/app/createGatewayEventHandler.ts +++ b/ui-tui/src/app/createGatewayEventHandler.ts @@ -1,6 +1,6 @@ import { STARTUP_IMAGE, STARTUP_QUERY } from '../config/env.js' import { STREAM_BATCH_MS } from '../config/timing.js' -import { SETUP_REQUIRED_TITLE, buildSetupRequiredSections } from '../content/setup.js' +import { buildSetupRequiredSections, SETUP_REQUIRED_TITLE } from '../content/setup.js' import type { CommandsCatalogResponse, ConfigFullResponse, @@ -313,6 +313,10 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: } case 'thinking.delta': { + if (!getUiState().busy) { + return + } + const text = ev.payload?.text if (text !== undefined) { @@ -340,6 +344,7 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: if (p.kind === 'goal') { sys(p.text) + const brief = p.text.startsWith('โœ“') ? 'โœ“ goal complete' : p.text.startsWith('โ†ป') @@ -347,8 +352,10 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: : p.text.startsWith('โธ') ? 'โธ goal paused' : 'ready' + setStatus(brief) restoreStatusAfter(6000) + return } @@ -356,6 +363,7 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: if (p.kind === 'compressing') { sys(p.text) + return } @@ -491,13 +499,13 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: case 'reasoning.delta': if (ev.payload?.text) { - turnController.recordReasoningDelta(ev.payload.text) + turnController.recordReasoningDelta(ev.payload.text, Boolean(ev.payload.verbose)) } return case 'reasoning.available': - turnController.recordReasoningAvailable(String(ev.payload?.text ?? '')) + turnController.recordReasoningAvailable(String(ev.payload?.text ?? ''), Boolean(ev.payload?.verbose)) return @@ -517,20 +525,28 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: case 'tool.start': turnController.recordTodos(ev.payload.todos) - turnController.recordToolStart(ev.payload.tool_id, ev.payload.name ?? 'tool', ev.payload.context ?? '') + turnController.recordToolStart( + ev.payload.tool_id, + ev.payload.name ?? 'tool', + ev.payload.context ?? '', + ev.payload.args_text ? stripAnsi(String(ev.payload.args_text)) : undefined + ) return case 'tool.complete': { const inlineDiffText = ev.payload.inline_diff && getUiState().inlineDiffs ? stripAnsi(String(ev.payload.inline_diff)).trim() : '' + const resultText = ev.payload.result_text ? stripAnsi(String(ev.payload.result_text)) : undefined + if (inlineDiffText) { turnController.recordInlineDiffToolComplete( inlineDiffText, ev.payload.tool_id, ev.payload.name, ev.payload.error, - ev.payload.duration_s + ev.payload.duration_s, + resultText ) } else { turnController.recordToolComplete( @@ -539,7 +555,8 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: ev.payload.error, ev.payload.summary, ev.payload.duration_s, - ev.payload.todos + ev.payload.todos, + resultText ) } @@ -581,7 +598,6 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: sys(`[bg ${ev.payload.task_id}] ${ev.payload.text}`) return - case 'review.summary': { // Self-improvement background review emitted a persistent summary // of what it saved to memory/skills. Surface it as a system line @@ -589,6 +605,7 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: // flash. Python-side already formats it as "๐Ÿ’พ Self-improvement // review: โ€ฆ". const text = String(ev.payload?.text ?? '').trim() + if (text) { sys(text) } diff --git a/ui-tui/src/app/interfaces.ts b/ui-tui/src/app/interfaces.ts index b71e34188ef..cb2788bbf4f 100644 --- a/ui-tui/src/app/interfaces.ts +++ b/ui-tui/src/app/interfaces.ts @@ -216,6 +216,7 @@ export interface InputHandlerContext { setProcessing: StateSetter setRecording: StateSetter setVoiceEnabled: StateSetter + setVoiceTts: StateSetter } wheelStep: number } @@ -254,6 +255,7 @@ export interface GatewayEventHandlerContext { setProcessing: StateSetter setRecording: StateSetter setVoiceEnabled: StateSetter + setVoiceTts: StateSetter } } @@ -296,6 +298,7 @@ export interface SlashHandlerContext { voice: { setVoiceEnabled: StateSetter setVoiceRecordKey: (v: ParsedVoiceRecordKey) => void + setVoiceTts: StateSetter } } diff --git a/ui-tui/src/app/slash/commands/session.ts b/ui-tui/src/app/slash/commands/session.ts index 466505d8ceb..fb990ef11be 100644 --- a/ui-tui/src/app/slash/commands/session.ts +++ b/ui-tui/src/app/slash/commands/session.ts @@ -212,7 +212,6 @@ export const sessionCommands: SlashCommand[] = [ void ctx.session.closeSession(prevSid) patchUiState({ sid: r.session_id }) ctx.session.setSessionStartedAt(Date.now()) - ctx.transcript.setHistoryItems([]) ctx.transcript.sys(`branched โ†’ ${r.title ?? ''}`) }) ) @@ -233,6 +232,7 @@ export const sessionCommands: SlashCommand[] = [ ctx.gateway.rpc('voice.toggle', { action }).then( ctx.guarded(r => { ctx.voice.setVoiceEnabled(!!r.enabled) + ctx.voice.setVoiceTts(!!r.tts) // Render the configured record key (config.yaml ``voice.record_key``) // instead of hardcoded "Ctrl+B" โ€” the gateway response carries the diff --git a/ui-tui/src/app/turnController.ts b/ui-tui/src/app/turnController.ts index b9e0aa04c19..4e22d3312cd 100644 --- a/ui-tui/src/app/turnController.ts +++ b/ui-tui/src/app/turnController.ts @@ -11,6 +11,7 @@ import { hasReasoningTag, splitReasoning } from '../lib/reasoning.js' import { boundedLiveRenderText, buildToolTrailLine, + buildVerboseToolTrailLine, estimateTokensRough, isTransientTrailLine, sameToolTrailGroup, @@ -542,8 +543,8 @@ class TurnController { } } - recordReasoningAvailable(text: string) { - if (this.interrupted || !getUiState().showReasoning) { + recordReasoningAvailable(text: string, force = false) { + if (this.interrupted || (!force && !getUiState().showReasoning)) { return } @@ -560,8 +561,8 @@ class TurnController { this.pulseReasoningStreaming() } - recordReasoningDelta(text: string) { - if (this.interrupted || !getUiState().showReasoning) { + recordReasoningDelta(text: string, force = false) { + if (this.interrupted || (!force && !getUiState().showReasoning)) { return } @@ -587,14 +588,15 @@ class TurnController { error?: string, summary?: string, duration?: number, - todos?: unknown + todos?: unknown, + resultText?: string ) { if (this.interrupted) { return } this.recordTodos(todos) - const line = this.completeTool(toolId, fallbackName, error, summary, duration) + const line = this.completeTool(toolId, fallbackName, error, summary, duration, resultText) this.pendingSegmentTools = [...this.pendingSegmentTools, line] this.flushPendingToolsIntoLastSegment() @@ -606,30 +608,42 @@ class TurnController { toolId: string, fallbackName?: string, error?: string, - duration?: number + duration?: number, + resultText?: string ) { if (this.interrupted) { return } this.flushStreamingSegment() - this.pushInlineDiffSegment(diffText, [this.completeTool(toolId, fallbackName, error, '', duration)]) + this.pushInlineDiffSegment(diffText, [this.completeTool(toolId, fallbackName, error, '', duration, resultText)]) this.publishToolState() } - private completeTool(toolId: string, fallbackName?: string, error?: string, summary?: string, duration?: number) { + private completeTool( + toolId: string, + fallbackName?: string, + error?: string, + summary?: string, + duration?: number, + resultText?: string + ) { const done = this.activeTools.find(tool => tool.id === toolId) const name = done?.name ?? fallbackName ?? 'tool' const label = toolTrailLabel(name) const fallbackDuration = done?.startedAt ? (Date.now() - done.startedAt) / 1000 : undefined - const line = buildToolTrailLine( - name, - done?.context || '', - Boolean(error), - error || summary || '', - duration ?? fallbackDuration - ) + const line = + done?.verboseArgs || resultText + ? buildVerboseToolTrailLine( + name, + done?.context || '', + Boolean(error), + duration ?? fallbackDuration, + done?.verboseArgs, + error || resultText || summary || '' + ) + : buildToolTrailLine(name, done?.context || '', Boolean(error), error || summary || '', duration ?? fallbackDuration) this.activeTools = this.activeTools.filter(tool => tool.id !== toolId) @@ -675,7 +689,7 @@ class TurnController { }, STREAM_BATCH_MS) } - recordToolStart(toolId: string, name: string, context: string) { + recordToolStart(toolId: string, name: string, context: string, verboseArgs?: string) { if (this.interrupted) { return } @@ -688,7 +702,7 @@ class TurnController { const sample = `${name} ${context}`.trim() this.toolTokenAcc += sample ? estimateTokensRough(sample) : 0 - this.activeTools = [...this.activeTools, { context, id: toolId, name, startedAt: Date.now() }] + this.activeTools = [...this.activeTools, { context, id: toolId, name, startedAt: Date.now(), verboseArgs }] patchTurnState({ toolTokens: this.toolTokenAcc, tools: this.activeTools }) } diff --git a/ui-tui/src/app/useMainApp.ts b/ui-tui/src/app/useMainApp.ts index 7996c7b910b..4d7ab8926ba 100644 --- a/ui-tui/src/app/useMainApp.ts +++ b/ui-tui/src/app/useMainApp.ts @@ -1,4 +1,4 @@ -import { useApp, useHasSelection, useSelection, useStdout, useTerminalTitle, type ScrollBoxHandle } from '@hermes/ink' +import { type ScrollBoxHandle, useApp, useHasSelection, useSelection, useStdout, useTerminalTitle } from '@hermes/ink' import { useStore } from '@nanostores/react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' @@ -102,6 +102,7 @@ export function useMainApp(gw: GatewayClient) { const [stickyPrompt, setStickyPrompt] = useState('') const [catalog, setCatalog] = useState(null) const [voiceEnabled, setVoiceEnabled] = useState(false) + const [voiceTts, setVoiceTts] = useState(false) const [voiceRecording, setVoiceRecording] = useState(false) const [voiceProcessing, setVoiceProcessing] = useState(false) const [voiceRecordKey, setVoiceRecordKey] = useState(DEFAULT_VOICE_RECORD_KEY) @@ -233,9 +234,15 @@ export function useMainApp(gw: GatewayClient) { return next }, []) + // Wrapped row heights are width-dependent. Cached layout outlives a resize + // and lands sticky-scroll at the stale max, cutting off the tail. The + // hook's "scale heights by oldCols/newCols" path is too approximate for + // mixed markdown โ€” we deliberately remount every row so yoga re-measures + // off live geometry. Cost: per-row local state (e.g. systemOpen toggles) + // resets on resize; small UX hit for a hard correctness win. const virtualRows = useMemo( - () => historyItems.map((msg, index) => ({ index, key: messageId(msg), msg })), - [historyItems, messageId] + () => historyItems.map((msg, index) => ({ index, key: `${messageId(msg)}:c${cols}`, msg })), + [cols, historyItems, messageId] ) const detailsLayoutKey = useMemo(() => { @@ -365,7 +372,7 @@ export function useMainApp(gw: GatewayClient) { const gateway = useMemo(() => ({ gw, rpc }), [gw, rpc]) const die = useCallback(() => { - gw.kill() + gw.kill('app.die') exit() // Ink's exit() calls unmount() which resets terminal modes but does NOT // call process.exit(). Without an explicit exit the Node process stays @@ -377,7 +384,7 @@ export function useMainApp(gw: GatewayClient) { }, [exit, gw]) const dieWithCode = useCallback((code: number) => { - gw.kill() + gw.kill(`app.dieWithCode:${code}`) exit() process.exit(code) }, [exit, gw]) @@ -424,10 +431,20 @@ export function useMainApp(gw: GatewayClient) { let timer: ReturnType | undefined + // Resize reflows wrapped lines; if the user is still pinned to the tail + // we need to re-snap once React has remeasured. virtualRows is keyed on + // cols so every column change forces a fresh measurement pass before + // this timer fires. Re-check isSticky() inside the timeout โ€” a manual + // scroll during the 100ms window otherwise yanks the user back to tail. const onResize = () => { clearTimeout(timer) timer = setTimeout(() => { timer = undefined + + if (scrollRef.current?.isSticky()) { + scrollRef.current.scrollToBottom() + } + void rpc('terminal.resize', { cols: stdout.columns ?? 80, session_id: ui.sid }) }, 100) } @@ -555,7 +572,8 @@ export function useMainApp(gw: GatewayClient) { recording: voiceRecording, setProcessing: setVoiceProcessing, setRecording: setVoiceRecording, - setVoiceEnabled + setVoiceEnabled, + setVoiceTts }, wheelStep: WHEEL_SCROLL_STEP }) @@ -579,7 +597,8 @@ export function useMainApp(gw: GatewayClient) { voice: { setProcessing: setVoiceProcessing, setRecording: setVoiceRecording, - setVoiceEnabled + setVoiceEnabled, + setVoiceTts } }), [ @@ -736,10 +755,13 @@ export function useMainApp(gw: GatewayClient) { const anyPanelVisible = SECTION_NAMES.some( s => sectionMode(s, ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' ) + const thinkingPanelVisible = sectionMode('thinking', ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' + const toolsPanelVisible = sectionMode('tools', ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' + const activityPanelVisible = sectionMode('activity', ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' @@ -827,7 +849,7 @@ export function useMainApp(gw: GatewayClient) { turnStartedAt: ui.sid ? turnStartedAt : null, // CLI parity: the classic prompt_toolkit status bar shows a red dot // on REC (cli.py:_get_voice_status_fragments line 2344). - voiceLabel: voiceRecording ? 'โ— REC' : voiceProcessing ? 'โ—‰ STT' : `voice ${voiceEnabled ? 'on' : 'off'}` + voiceLabel: voiceRecording ? 'โ— REC' : voiceProcessing ? 'โ—‰ STT' : `voice ${voiceEnabled ? 'on' : 'off'}${voiceTts ? ' [tts]' : ''}` }), [ cwd, @@ -839,7 +861,8 @@ export function useMainApp(gw: GatewayClient) { ui, voiceEnabled, voiceProcessing, - voiceRecording + voiceRecording, + voiceTts ] ) diff --git a/ui-tui/src/banner.ts b/ui-tui/src/banner.ts index 80da8f43d70..748e5a452bc 100644 --- a/ui-tui/src/banner.ts +++ b/ui-tui/src/banner.ts @@ -79,8 +79,8 @@ const colorize = (art: string[], gradient: readonly number[], c: ThemeColors): L return art.map((text, i) => [p[gradient[i]!] ?? c.muted, text]) } -export const LOGO_WIDTH = 98 -export const CADUCEUS_WIDTH = 30 +export const LOGO_WIDTH = Math.max(...LOGO_ART.map(line => line.length)) +export const CADUCEUS_WIDTH = Math.max(...CADUCEUS_ART.map(line => line.length)) export const logo = (c: ThemeColors, customLogo?: string): Line[] => customLogo ? parseRichMarkup(customLogo) : colorize(LOGO_ART, LOGO_GRADIENT, c) diff --git a/ui-tui/src/components/appLayout.tsx b/ui-tui/src/components/appLayout.tsx index a4b6963cb5a..8b69b9e4425 100644 --- a/ui-tui/src/components/appLayout.tsx +++ b/ui-tui/src/components/appLayout.tsx @@ -6,7 +6,7 @@ import { useGateway } from '../app/gatewayContext.js' import type { AppLayoutProps } from '../app/interfaces.js' import { $isBlocked, $overlayState, patchOverlayState } from '../app/overlayStore.js' import { $uiState } from '../app/uiStore.js' -import { INLINE_MODE, SHOW_FPS } from '../config/env.js' +import { INLINE_MODE, SHOW_FPS, TERMUX_TUI_MODE } from '../config/env.js' import { PLACEHOLDER } from '../content/placeholders.js' import { COMPOSER_PROMPT_GAP_WIDTH, @@ -112,9 +112,9 @@ const TranscriptPane = memo(function TranscriptPane({ {row.msg.kind === 'intro' ? ( - + - {row.msg.info && } + {row.msg.info && } ) : row.msg.kind === 'panel' && row.msg.panelData ? ( @@ -169,10 +169,10 @@ const ComposerPane = memo(function ComposerPane({ const ui = useStore($uiState) const isBlocked = useStore($isBlocked) const sh = (composer.inputBuf[0] ?? composer.input).startsWith('!') - const promptText = composerPromptText(ui.theme.brand.prompt, ui.info?.profile_name, sh) + const promptText = composerPromptText(ui.theme.brand.prompt, ui.info?.profile_name, sh, TERMUX_TUI_MODE, composer.cols) const promptWidth = composerPromptWidth(promptText) const promptBlank = ' '.repeat(promptWidth) - const inputColumns = stableComposerColumns(composer.cols, promptWidth) + const inputColumns = stableComposerColumns(composer.cols, promptWidth, TERMUX_TUI_MODE) const inputHeight = inputVisualHeight(composer.input, inputColumns) const inputMouseRef = useRef(null) diff --git a/ui-tui/src/components/appOverlays.tsx b/ui-tui/src/components/appOverlays.tsx index c12624a4bf8..7fd14563a99 100644 --- a/ui-tui/src/components/appOverlays.tsx +++ b/ui-tui/src/components/appOverlays.tsx @@ -187,10 +187,15 @@ export function FloatingOverlays({ key={`${start + i}:${item.text}:${item.display}:${item.meta ?? ''}`} width="100%" > - - {' '} - {item.display} - + {/* flexShrink=0 โ€” when meta overflows the row, Ink/Yoga + otherwise shaves the last char off the display column + (e.g. /goal renders as /goa). */} + + + {' '} + {item.display} + + {item.meta ? ( + {lines.map(([c, text], i) => ( - + {text} ))} - + ) } -export function Banner({ t }: { t: Theme }) { - const cols = useStdout().stdout?.columns ?? 80 +// Responsive Banner: full art โ†’ compact rule โ†’ text โ†’ hidden. +// +// Terminals can't scale glyphs, so "responsive" means picking a layout that +// fits the available columns. Thresholds are picked so each tier reads +// comfortably without forcing wrap or truncation drift on box-drawing edges. +const TAG_FULL = 'Nous Research ยท Messenger of the Digital Gods' +const TAG_MID = 'Messenger of the Digital Gods' +const TAG_TINY = 'Nous Research' +const HIDE_BELOW = 34 +const COMPACT_FROM = 58 + +const clip = (s: string, w: number) => + w <= 0 ? '' : s.length > w ? `${s.slice(0, Math.max(0, w - 1))}โ€ฆ` : s + +const centerIn = (s: string, w: number) => { + const f = clip(s, w) + const slack = Math.max(0, w - f.length) + const left = slack >> 1 + + return `${' '.repeat(left)}${f}${' '.repeat(slack - left)}` +} + +const ruleIn = (label: string, w: number) => { + const f = clip(label, Math.max(1, w - 4)) + const slack = Math.max(0, w - f.length - 2) + const left = slack >> 1 + + return `${'โ”€'.repeat(left)} ${f} ${'โ”€'.repeat(slack - left)}` +} + +function CompactBanner({ cols, t }: { cols: number; t: Theme }) { + // -4 keeps a margin so exact-edge rows don't trip terminal pending-wrap. + const w = Math.max(28, cols - 4) + + return ( + + {ruleIn(t.brand.name, w)} + {centerIn(TAG_FULL, w)} + {'โ”€'.repeat(w)} + + ) +} + +export function Banner({ maxWidth, t }: { maxWidth?: number; t: Theme }) { + const term = useStdout().stdout?.columns ?? 80 + const cols = Math.max(1, Math.min(term, maxWidth ?? term)) + + if (cols < HIDE_BELOW) { + return null + } + const logoLines = logo(t.color, t.bannerLogo || undefined) + const logoW = t.bannerLogo ? artWidth(logoLines) : LOGO_WIDTH + + if (cols >= logoW + 2) { + return ( + + + + {t.brand.icon} {TAG_FULL} + + + ) + } + + if (cols >= COMPACT_FROM) { + return + } + + const name = cols >= 52 ? t.brand.name : (t.brand.name.split(' ')[0] ?? t.brand.name) + const tag = cols >= 64 ? TAG_FULL : cols >= 46 ? TAG_MID : TAG_TINY return ( - {cols >= (t.bannerLogo ? artWidth(logoLines) : LOGO_WIDTH) ? ( - - ) : ( - - {t.brand.icon} NOUS HERMES - - )} - - {t.brand.icon} Nous Research ยท Messenger of the Digital Gods + {t.brand.icon} {name} + {t.brand.icon} {tag} ) } @@ -96,8 +157,9 @@ function CollapseToggle({ const SKILLS_MAX = 8 const TOOLSETS_MAX = 8 -export function SessionPanel({ info, sid, t }: SessionPanelProps) { - const cols = useStdout().stdout?.columns ?? 100 +export function SessionPanel({ info, maxWidth, sid, t }: SessionPanelProps) { + const term = useStdout().stdout?.columns ?? 100 + const cols = Math.max(20, Math.min(term, maxWidth ?? term)) const heroLines = caduceus(t.color, t.bannerHero || undefined) const leftW = Math.min((artWidth(heroLines) || CADUCEUS_WIDTH) + 4, Math.floor(cols * 0.4)) const wide = cols >= 90 && leftW + 40 < cols @@ -241,13 +303,33 @@ export function SessionPanel({ info, sid, t }: SessionPanelProps) { )} - - - {t.brand.name} - {info.version ? ` v${info.version}` : ''} - {info.release_date ? ` (${info.release_date})` : ''} - - + {wide ? ( + + + {t.brand.name} + {info.version ? ` v${info.version}` : ''} + {info.release_date ? ` (${info.release_date})` : ''} + + + ) : ( + // Narrow layout hides the hero column; surface model/cwd/session + // here so they aren't lost. + + + {info.model.split('/').pop()} + ยท Nous Research + + + {info.cwd || process.cwd()} + + {sid && ( + + Session: + {sid} + + )} + + )} {/* โ”€โ”€ Tools (expanded by default) โ”€โ”€ */} @@ -378,6 +460,7 @@ interface PanelProps { interface SessionPanelProps { info: SessionInfo + maxWidth?: number sid?: string | null t: Theme } diff --git a/ui-tui/src/components/messageLine.tsx b/ui-tui/src/components/messageLine.tsx index d44e29c1206..4d1481373ab 100644 --- a/ui-tui/src/components/messageLine.tsx +++ b/ui-tui/src/components/messageLine.tsx @@ -1,6 +1,7 @@ import { Ansi, Box, NoSelect, Text } from '@hermes/ink' import { memo, useState } from 'react' +import { TERMUX_TUI_MODE } from '../config/env.js' import { LONG_MSG } from '../config/limits.js' import { sectionMode } from '../domain/details.js' import { userDisplay } from '../domain/messages.js' @@ -139,7 +140,7 @@ export const MessageLine = memo(function MessageLine({ } if (msg.role === 'assistant') { - const bodyWidth = transcriptBodyWidth(cols, msg.role, t.brand.prompt) + const bodyWidth = transcriptBodyWidth(cols, msg.role, t.brand.prompt, TERMUX_TUI_MODE) return isStreaming ? ( // Incremental markdown: split at the last stable block boundary so @@ -201,7 +202,7 @@ export const MessageLine = memo(function MessageLine({ - {content} + {content} ) diff --git a/ui-tui/src/components/textInput.tsx b/ui-tui/src/components/textInput.tsx index 92082280a04..2e117a0a007 100644 --- a/ui-tui/src/components/textInput.tsx +++ b/ui-tui/src/components/textInput.tsx @@ -13,6 +13,7 @@ import { isVoiceToggleKey, type ParsedVoiceRecordKey } from '../lib/platform.js' +import { isTermuxTuiMode } from '../lib/termux.js' type InkExt = typeof Ink & { stringWidth: (s: string) => number @@ -33,6 +34,7 @@ const DIM_OFF = `${ESC}[22m` const FWD_DEL_RE = new RegExp(`${ESC}\\[3(?:[~$^]|;)`) const PRINTABLE = /^[ -~\u00a0-\uffff]+$/ const BRACKET_PASTE = new RegExp(`${ESC}?\\[20[01]~`, 'g') +const FRAME_BATCH_MS = 16 const MULTI_CLICK_MS = 500 const invert = (s: string) => INV + s + INV_OFF @@ -90,6 +92,36 @@ function snapPos(s: string, p: number) { return last } +export interface TextInsertResult { + cursor: number + value: string +} + +export function applyPrintableInsert( + value: string, + cursor: number, + text: string, + range?: { end: number; start: number } | null +): null | TextInsertResult { + if (!PRINTABLE.test(text)) { + return null + } + + if (range) { + return { + cursor: range.start + text.length, + value: value.slice(0, range.start) + text + value.slice(range.end) + } + } + + return { + cursor: cursor + text.length, + value: value.slice(0, cursor) + text + value.slice(cursor) + } +} + +export const shouldRouteMultiCharInputAsPaste = (text: string): boolean => text.includes('\n') + function prevPos(s: string, p: number) { const pos = snapPos(s, p) let prev = 0 @@ -298,7 +330,24 @@ export function canFastBackspaceShape(current: string, cursor: number, columns?: export function supportsFastEchoTerminal(env: NodeJS.ProcessEnv = process.env): boolean { // Terminal.app still shows paint/cursor artifacts under the fast-echo // bypass path. Fall back to the normal Ink render path there. - return (env.TERM_PROGRAM ?? '').trim() !== 'Apple_Terminal' + if ((env.TERM_PROGRAM ?? '').trim() === 'Apple_Terminal') { + return false + } + + // Termux terminals are especially sensitive to bypass-path cursor drift and + // stale paints at soft-wrap boundaries on tall/narrow viewports. Keep this + // off by default in Termux mode; allow explicit opt-in for local debugging. + if (isTermuxTuiMode(env)) { + const override = String(env.HERMES_TUI_TERMUX_FAST_ECHO ?? '').trim().toLowerCase() + + if (override) { + return /^(?:1|true|yes|on)$/i.test(override) + } + + return false + } + + return true } function renderWithCursor(value: string, cursor: number) { @@ -383,10 +432,7 @@ export function TextInput({ const selRef = useRef(null) const vRef = useRef(value) const self = useRef(false) - const pasteBuf = useRef('') - const pasteEnd = useRef(null) - const pasteTimer = useRef | null>(null) - const pastePos = useRef(0) + const keyBurstTimer = useRef | null>(null) const editVersionRef = useRef(0) const parentChangeTimer = useRef | null>(null) const pendingParentValue = useRef(null) @@ -519,8 +565,8 @@ export function TextInput({ useEffect( () => () => { - if (pasteTimer.current) { - clearTimeout(pasteTimer.current) + if (keyBurstTimer.current) { + clearTimeout(keyBurstTimer.current) } if (parentChangeTimer.current) { @@ -556,7 +602,7 @@ export function TextInput({ return } - parentChangeTimer.current = setTimeout(flushParentChange, 16) + parentChangeTimer.current = setTimeout(flushParentChange, FRAME_BATCH_MS) } const cancelLocalRender = () => { @@ -574,7 +620,7 @@ export function TextInput({ localRenderTimer.current = setTimeout(() => { localRenderTimer.current = null setCur(curRef.current) - }, 16) + }, FRAME_BATCH_MS) } const canFastEchoBase = () => supportsFastEchoTerminal() && focus && termFocus && !selected && !mask && !!stdout?.isTTY @@ -678,21 +724,26 @@ export function TextInput({ return !!h } - const flushPaste = () => { - const text = pasteBuf.current - const at = pastePos.current - const end = pasteEnd.current ?? at - pasteBuf.current = '' - pasteEnd.current = null - pasteTimer.current = null + const flushKeyBurst = () => { + if (keyBurstTimer.current) { + clearTimeout(keyBurstTimer.current) + keyBurstTimer.current = null + } - if (!text) { + flushParentChange() + } + + const scheduleKeyBurstCommit = (next: string, nextCur: number) => { + commit(next, nextCur, true, false, false) + + if (keyBurstTimer.current) { return } - if (!emitPaste({ cursor: at, text, value: vRef.current }) && PRINTABLE.test(text)) { - commit(vRef.current.slice(0, at) + text + vRef.current.slice(end), at + text.length) - } + keyBurstTimer.current = setTimeout(() => { + keyBurstTimer.current = null + flushParentChange() + }, FRAME_BATCH_MS) } const clearSel = () => { @@ -833,6 +884,8 @@ export function TextInput({ // follow-up on #19835). The pass-through predicate is a no-op for // ordinary typing and plain paste when voice is unbound to 'v'. if (shouldPassThroughToGlobalHandler(inp, k, voiceRecordKey)) { + flushKeyBurst() + return } @@ -842,6 +895,8 @@ export function TextInput({ eventRaw === '\x16' || (isMac && isActionMod(k) && inp.toLowerCase() === 'v') ) { + flushKeyBurst() + if (cbPaste.current) { return void emitPaste({ cursor: curRef.current, hotkey: true, text: '', value: vRef.current }) } @@ -858,6 +913,8 @@ export function TextInput({ } if (isMac && isActionMod(k) && inp.toLowerCase() === 'c') { + flushKeyBurst() + const range = selRange() if (range) { @@ -870,6 +927,8 @@ export function TextInput({ } if (k.upArrow || k.downArrow) { + flushKeyBurst() + const next = lineNav(vRef.current, curRef.current, k.upArrow ? -1 : 1) if (next !== null) { @@ -882,11 +941,11 @@ export function TextInput({ } if (k.return) { + flushKeyBurst() + if (k.shift || k.ctrl || (isMac ? isActionMod(k) : k.meta)) { - flushParentChange() commit(ins(vRef.current, curRef.current, '\n'), curRef.current + 1) } else { - flushParentChange() cbSubmit.current?.(vRef.current) } @@ -904,6 +963,11 @@ export function TextInput({ const actionDeleteWord = (mod && inp === 'w') || isMacActionFallback(k, inp, 'w') const range = selRange() const delFwd = k.delete || fwdDel.current + const isPrintableInput = (event.keypress.isPasted || inp.length > 0) && PRINTABLE.test(inp.replace(BRACKET_PASTE, '')) + + if (!isPrintableInput) { + flushKeyBurst() + } if (mod && inp === 'z') { return swap(undo, redo) @@ -1033,31 +1097,44 @@ export function TextInput({ } if (text.length > 1 || text.includes('\n')) { - if (!pasteBuf.current) { - pastePos.current = range ? range.start : c - pasteEnd.current = range ? range.end : pastePos.current + if (shouldRouteMultiCharInputAsPaste(text)) { + flushKeyBurst() + + if (!emitPaste({ cursor: c, text, value: v })) { + commit(ins(v, c, text), c + text.length) + } + + return } - pasteBuf.current += text + const inserted = applyPrintableInsert(v, c, text, range) - if (pasteTimer.current) { - clearTimeout(pasteTimer.current) + if (!inserted) { + return } - pasteTimer.current = setTimeout(flushPaste, 50) + v = inserted.value + c = inserted.cursor + scheduleKeyBurstCommit(v, c) return } - if (PRINTABLE.test(text)) { + { + const inserted = applyPrintableInsert(v, c, text, range) + + if (!inserted) { + return + } + if (range) { - v = v.slice(0, range.start) + text + v.slice(range.end) - c = range.start + text.length + v = inserted.value + c = inserted.cursor } else { const simpleAppend = canFastAppend(v, c, text) - v = v.slice(0, c) + text + v.slice(c) - c += text.length + v = inserted.value + c = inserted.cursor if (simpleAppend) { stdout!.write(text) @@ -1074,8 +1151,6 @@ export function TextInput({ return } } - } else { - return } } else { return @@ -1108,11 +1183,13 @@ export function TextInput({ if (e.button === 2) { e.stopImmediatePropagation?.() const decision = decideRightClickAction(vRef.current, selRange()) + if (decision.action === 'copy') { void writeClipboardText(decision.text) return } + emitPaste({ cursor: curRef.current, hotkey: true, text: '', value: vRef.current }) return @@ -1205,10 +1282,12 @@ export function decideRightClickAction( ): RightClickDecision { if (range && range.end > range.start) { const text = value.slice(range.start, range.end) + if (text) { return { action: 'copy', text } } } + return { action: 'paste' } } diff --git a/ui-tui/src/components/thinking.tsx b/ui-tui/src/components/thinking.tsx index 6908795f621..0d9ecee87c9 100644 --- a/ui-tui/src/components/thinking.tsx +++ b/ui-tui/src/components/thinking.tsx @@ -856,7 +856,16 @@ export const ToolTrail = memo(function ToolTrail({ color: t.color.text, key: tool.id, label, - details: [], + details: tool.verboseArgs + ? [ + { + color: t.color.muted, + content: `Args:\n${boundedLiveRenderText(tool.verboseArgs)}`, + dimColor: true, + key: `${tool.id}-args` + } + ] + : [], content: ( <> {label} diff --git a/ui-tui/src/entry.tsx b/ui-tui/src/entry.tsx index 690caf0cc95..effde40fef9 100644 --- a/ui-tui/src/entry.tsx +++ b/ui-tui/src/entry.tsx @@ -43,23 +43,24 @@ setupGracefulExit({ () => { resetTerminalModes() - return gw.kill() + return gw.kill('graceful-exit-cleanup') } ], onError: (scope, err) => { - const message = err instanceof Error ? `${err.name}: ${err.message}` : String(err) + const message = err instanceof Error ? `${err.name}: ${err.message}\n${err.stack ?? ''}` : String(err) - process.stderr.write(`hermes-tui ${scope}: ${message.slice(0, 2000)}\n`) + process.stderr.write(`hermes-tui lifecycle ${scope}: ${message.slice(0, 2000)}\n`) }, onSignal: signal => { resetTerminalModes() - process.stderr.write(`hermes-tui: received ${signal}\n`) + process.stderr.write(`hermes-tui lifecycle: received ${signal}\n`) } }) const stopMemoryMonitor = startMemoryMonitor({ onCritical: (snap, dump) => { resetTerminalModes() + process.stderr.write(`hermes-tui lifecycle: memory critical exit heap=${formatBytes(snap.heapUsed)} rss=${formatBytes(snap.rss)}\n`) process.stderr.write(dumpNotice(snap, dump)) process.stderr.write('hermes-tui: exiting to avoid OOM; restart to recover\n') process.exit(137) diff --git a/ui-tui/src/gatewayClient.ts b/ui-tui/src/gatewayClient.ts index 9590b386aa6..f3121152c90 100644 --- a/ui-tui/src/gatewayClient.ts +++ b/ui-tui/src/gatewayClient.ts @@ -21,6 +21,14 @@ const WS_CLOSED = 3 const truncateLine = (line: string) => line.length > MAX_LOG_LINE_BYTES ? `${line.slice(0, MAX_LOG_LINE_BYTES)}โ€ฆ [truncated ${line.length} bytes]` : line +const describeChild = (proc: ChildProcess | null) => { + if (!proc) { + return 'pid=none' + } + + return `pid=${proc.pid ?? 'unknown'} killed=${proc.killed} exitCode=${proc.exitCode ?? 'null'} signal=${proc.signalCode ?? 'null'}` +} + const resolveGatewayAttachUrl = () => { const raw = process.env.HERMES_TUI_GATEWAY_URL?.trim() @@ -85,7 +93,7 @@ const asWireText = (raw: unknown): string | null => { // otherwise-malformed URLs that the WHATWG `URL` parser can't accept. // Used by the `redactUrl` fallback so embedded credentials are // scrubbed from log lines even when the URL is unparseable. -const _USERINFO_FALLBACK_RE = /^([a-z][a-z0-9+.\-]*:\/\/)[^/?#@]*@/i +const _USERINFO_FALLBACK_RE = /^([a-z][a-z0-9+.-]*:\/\/)[^/?#@]*@/i // Connection URLs (gateway, sidecar) often carry bearer tokens in the query // string. We surface them in user-facing log lines and the @@ -191,6 +199,7 @@ export class GatewayClient extends EventEmitter { const ws = this.ws this.ws = null this.wsConnectPromise = null + try { ws?.close() } catch { @@ -239,6 +248,7 @@ export class GatewayClient extends EventEmitter { private handleTransportExit(code: null | number, reason?: string) { this.clearReadyTimer() this.closeSidecarSocket() + this.pushLog(`[lifecycle] transport exit code=${code ?? 'null'} reason=${reason ?? 'none'}`) this.rejectPending(new Error(reason || `gateway exited${code === null ? '' : ` (${code})`}`)) if (this.subscribed) { @@ -257,6 +267,7 @@ export class GatewayClient extends EventEmitter { if (typeof WebSocket === 'undefined') { this.pushLog(`[sidecar] WebSocket unavailable; skipping mirror to ${redactUrl(this.sidecarUrl)}`) + return } @@ -324,6 +335,7 @@ export class GatewayClient extends EventEmitter { env.PYTHONPATH = pyPath ? `${root}${delimiter}${pyPath}` : root this.startReadyTimer(python, cwd) this.proc = spawn(python, ['-m', 'tui_gateway.entry'], { cwd, env, stdio: ['pipe', 'pipe', 'pipe'] }) + this.pushLog(`[lifecycle] spawned gateway child ${describeChild(this.proc)} python=${python} cwd=${cwd}`) this.stdoutRl = createInterface({ input: this.proc.stdout! }) this.stdoutRl.on('line', raw => { @@ -353,11 +365,14 @@ export class GatewayClient extends EventEmitter { this.proc.on('error', err => { // Skip stale errors on an already-replaced child. if (this.proc !== ownedProc) { + this.pushLog(`[lifecycle] stale child error ignored ${describeChild(ownedProc)} message=${err.message}`) + return } const line = `[spawn] ${err.message}` + this.pushLog(`[lifecycle] child error ${describeChild(ownedProc)} message=${err.message}`) this.pushLog(line) this.publish({ type: 'gateway.stderr', payload: { line } }) // Detach the reference up front so the late `exit` event for @@ -369,14 +384,19 @@ export class GatewayClient extends EventEmitter { this.proc = null this.handleTransportExit(1, `gateway error: ${err.message}`) }) - this.proc.on('exit', code => { + this.proc.on('exit', (code, signal) => { // start() can replace `this.proc` while an old child is still // tearing down. Skip stale exits so we don't clear the new // startup timer or reject newly-issued pending requests. if (this.proc !== ownedProc) { + this.pushLog( + `[lifecycle] stale child exit ignored ${describeChild(ownedProc)} code=${code ?? 'null'} signal=${signal ?? 'null'}` + ) + return } + this.pushLog(`[lifecycle] child exit ${describeChild(ownedProc)} code=${code ?? 'null'} signal=${signal ?? 'null'}`) this.handleTransportExit(code) }) } @@ -400,6 +420,7 @@ export class GatewayClient extends EventEmitter { let settled = false this.ws = ws + const connectPromise = new Promise((resolve, reject) => { ws.addEventListener( 'open', @@ -454,9 +475,12 @@ export class GatewayClient extends EventEmitter { // new ready timer or reject the new pending requests on behalf // of a stale socket. if (this.ws !== ws) { + this.pushLog(`[lifecycle] stale websocket close ignored code=${ev.code}`) + return } + this.pushLog(`[lifecycle] websocket close code=${ev.code}`) this.ws = null this.wsConnectPromise = null this.handleTransportExit(ev.code, `gateway websocket closed${ev.code ? ` (${ev.code})` : ''}`) @@ -483,14 +507,17 @@ export class GatewayClient extends EventEmitter { this.resetStartupState() if (this.proc && !this.proc.killed && this.proc.exitCode === null) { + this.pushLog(`[lifecycle] replacing live gateway child ${describeChild(this.proc)}`) this.proc.kill() } + this.proc = null this.closeGatewaySocket() this.closeSidecarSocket() if (attachUrl) { this.startAttachedGateway(attachUrl) + return } @@ -686,8 +713,11 @@ export class GatewayClient extends EventEmitter { }) } - kill() { - this.proc?.kill() + kill(reason = 'requested') { + const proc = this.proc + const killed = proc?.kill() + + this.pushLog(`[lifecycle] GatewayClient.kill reason=${reason} ${describeChild(proc)} killResult=${killed ?? 'none'}`) this.closeGatewaySocket() this.closeSidecarSocket() this.clearReadyTimer() diff --git a/ui-tui/src/gatewayTypes.ts b/ui-tui/src/gatewayTypes.ts index ab85c39fbdd..9de1c85112d 100644 --- a/ui-tui/src/gatewayTypes.ts +++ b/ui-tui/src/gatewayTypes.ts @@ -477,11 +477,11 @@ export type GatewayEvent = type: 'gateway.start_timeout' } | { payload?: { preview?: string }; session_id?: string; type: 'gateway.protocol_error' } - | { payload?: { text?: string }; session_id?: string; type: 'reasoning.delta' | 'reasoning.available' } + | { payload?: { text?: string; verbose?: boolean }; session_id?: string; type: 'reasoning.delta' | 'reasoning.available' } | { payload: { name?: string; preview?: string }; session_id?: string; type: 'tool.progress' } | { payload: { name?: string }; session_id?: string; type: 'tool.generating' } | { - payload: { context?: string; name?: string; tool_id: string; todos?: unknown[] } + payload: { args_text?: string; context?: string; name?: string; tool_id: string; todos?: unknown[] } session_id?: string type: 'tool.start' } @@ -491,6 +491,7 @@ export type GatewayEvent = error?: string inline_diff?: string name?: string + result_text?: string summary?: string tool_id: string todos?: unknown[] diff --git a/ui-tui/src/hooks/useVirtualHistory.ts b/ui-tui/src/hooks/useVirtualHistory.ts index ef96ae1078c..592d20e9a07 100644 --- a/ui-tui/src/hooks/useVirtualHistory.ts +++ b/ui-tui/src/hooks/useVirtualHistory.ts @@ -51,6 +51,18 @@ const SLIDE_STEP = 12 const NOOP = () => {} +export const virtualHistorySnapshotKey = (s?: ScrollBoxHandle | null): string => { + if (!s) { + return 'none' + } + + const target = s.getScrollTop() + s.getPendingDelta() + const bin = Math.floor(target / QUANTUM) + const viewportHeight = Math.max(0, s.getViewportHeight()) + + return `${s.isSticky() ? ~bin : bin}:${viewportHeight}` +} + const upperBound = (arr: ArrayLike, target: number, length = arr.length) => { let lo = 0 let hi = length @@ -174,11 +186,9 @@ export function useVirtualHistory( }, [scrollRef]) // Quantized snapshot: same-bin scrolls (most wheel ticks) produce the same - // number โ†’ React.Object.is short-circuits the commit entirely. sticky state - // is folded in via the sign bit so stickyโ†’broken transitions also trigger. - // Uses the TARGET (committed + pendingDelta), not committed scrollTop, so - // scrollBy notifications immediately remount for the destination before - // Ink's drain frames need the children. + // key โ†’ React.Object.is short-circuits the commit entirely. The key includes + // sticky state, target scroll position, and viewport height so resize-only + // changes still recompute the mounted transcript window. const subscribe = useCallback( (cb: () => void) => (hasScrollRef ? scrollRef.current?.subscribe(cb) : null) ?? NOOP, [hasScrollRef, scrollRef] @@ -186,19 +196,8 @@ export function useVirtualHistory( useSyncExternalStore( subscribe, - () => { - const s = scrollRef.current - - if (!s) { - return NaN - } - - const target = s.getScrollTop() + s.getPendingDelta() - const bin = Math.floor(target / QUANTUM) - - return s.isSticky() ? ~bin : bin - }, - () => NaN + () => virtualHistorySnapshotKey(scrollRef.current), + () => 'none' ) useEffect(() => { @@ -249,8 +248,26 @@ export function useVirtualHistory( // During a freeze, drop the frozen range if items shrank past its start // (/clear, compaction) โ€” clamping would collapse to an empty mount and // flash blank. Fall through to the normal path in that case. - const frozenRange = - freezeRenders.current > 0 && prevRange.current && prevRange.current[0] < n ? prevRange.current : null + const frozenRangeCandidate = + freezeRenders.current > 0 && prevRange.current && prevRange.current[0] < n + ? ([prevRange.current[0], Math.min(prevRange.current[1], n)] as const) + : null + + // Width grows can shrink wrapped rows enough that the old tail window no + // longer covers the viewport. In that case freezing preserves stale spacers + // and visually cuts off the last message, so recompute immediately. + const frozenRange = (() => { + if (!frozenRangeCandidate || vp <= 0) { + return frozenRangeCandidate + } + + const visibleTop = sticky && !recentManual ? Math.max(0, total - vp) : target + const visibleBottom = visibleTop + vp + const rangeTop = offsets[frozenRangeCandidate[0]] ?? 0 + const rangeBottom = offsets[frozenRangeCandidate[1]] ?? total + + return rangeTop <= visibleTop && rangeBottom >= visibleBottom ? frozenRangeCandidate : null + })() let start = 0 let end = n @@ -465,6 +482,7 @@ export function useVirtualHistory( if (skipMeasurement.current) { skipMeasurement.current = false + bumpMeasuredHeightVersion(n => n + 1) } else { for (let i = effStart; i < effEnd; i++) { const k = items[i]?.key diff --git a/ui-tui/src/lib/inputMetrics.ts b/ui-tui/src/lib/inputMetrics.ts index 4c624da167a..5311e8e888b 100644 --- a/ui-tui/src/lib/inputMetrics.ts +++ b/ui-tui/src/lib/inputMetrics.ts @@ -61,6 +61,7 @@ function visualLines(value: string, cols: number): VisualLine[] { } lineStart = originalIdx + continue } @@ -177,14 +178,26 @@ export function transcriptGutterWidth(role: Role, userPrompt: string) { return role === 'user' ? composerPromptWidth(userPrompt) : 3 } -export function transcriptBodyWidth(totalCols: number, role: Role, userPrompt: string) { - return Math.max(20, totalCols - transcriptGutterWidth(role, userPrompt) - 2) +export function transcriptBodyWidth(totalCols: number, role: Role, userPrompt: string, termuxMode = false) { + const horizontalReserve = termuxMode ? 2 : 4 + const available = Math.max(1, totalCols - transcriptGutterWidth(role, userPrompt) - horizontalReserve) + + if (termuxMode) { + // On narrow / unusual aspect-ratio mobile panes, forcing a wide minimum + // width causes right-edge clipping and chopped words. + return available + } + + return Math.max(20, available) } -export function stableComposerColumns(totalCols: number, promptWidth: number) { +export function stableComposerColumns(totalCols: number, promptWidth: number, termuxMode = false) { // Physical render/wrap width. Always reserve outer composer padding and // prompt prefix. Only reserve the transcript scrollbar gutter when the // terminal is wide enough; on narrow panes, preserving input columns beats // keeping gutters visually aligned. - return Math.max(1, totalCols - promptWidth - 2 - (totalCols - promptWidth >= 24 ? 2 : 0)) + const afterPrompt = totalCols - promptWidth + const reserveScrollbar = afterPrompt >= (termuxMode ? 36 : 24) ? 2 : 0 + + return Math.max(1, totalCols - promptWidth - 2 - reserveScrollbar) } diff --git a/ui-tui/src/lib/prompt.ts b/ui-tui/src/lib/prompt.ts index 15607b61362..10961b90312 100644 --- a/ui-tui/src/lib/prompt.ts +++ b/ui-tui/src/lib/prompt.ts @@ -1,8 +1,32 @@ -export function composerPromptText(prompt: string, profileName?: null | string, shellMode = false): string { +const TERMUX_SAFE_PROMPT = '>' + +export function composerPromptText( + prompt: string, + profileName?: null | string, + shellMode = false, + termuxMode = false, + totalCols?: number +): string { if (shellMode) { return '$' } + if (termuxMode) { + // Termux fonts/terminal backends can render decorative prompt glyphs with + // ambiguous width; keep the live composer marker strictly single-cell ASCII + // so we never leave stale arrow artifacts while typing. + const basePrompt = TERMUX_SAFE_PROMPT + + // On very wide panes we can still include profile context. On narrow/mobile + // panes this burns precious columns and increases wrap/clipping risk. + const wideEnoughForProfile = typeof totalCols === 'number' ? totalCols >= 90 : false + if (wideEnoughForProfile && profileName && !['default', 'custom'].includes(profileName)) { + return `${profileName} ${basePrompt}` + } + + return basePrompt + } + if (profileName && !['default', 'custom'].includes(profileName)) { return `${profileName} ${prompt}` } diff --git a/ui-tui/src/lib/text.ts b/ui-tui/src/lib/text.ts index 5b52c236719..2b1ae33c592 100644 --- a/ui-tui/src/lib/text.ts +++ b/ui-tui/src/lib/text.ts @@ -212,6 +212,28 @@ export const buildToolTrailLine = ( return `${formatToolCall(name, context)}${took}${detail ? ` :: ${detail}` : ''} ${error ? 'โœ—' : 'โœ“'}` } +const verboseToolBlock = (label: string, text?: string) => { + const body = (text ?? '').trim() + + return body ? `${label}:\n${boundedLiveRenderText(body)}` : '' +} + +export const buildVerboseToolTrailLine = ( + name: string, + context: string, + error?: boolean, + duration?: number, + argsText?: string, + resultText?: string +) => { + const detail = [verboseToolBlock('Args', argsText), verboseToolBlock(error ? 'Error' : 'Result', resultText)] + .filter(Boolean) + .join('\n') + const took = duration !== undefined ? ` (${duration.toFixed(1)}s)` : '' + + return `${formatToolCall(name, context)}${took}${detail ? ` :: ${detail}` : ''} ${error ? 'โœ—' : 'โœ“'}` +} + export const isToolTrailResultLine = (line: string) => line.endsWith(' โœ“') || line.endsWith(' โœ—') export const parseToolTrailResultLine = (line: string) => { @@ -221,10 +243,10 @@ export const parseToolTrailResultLine = (line: string) => { const mark = line.endsWith(' โœ—') ? 'โœ—' : 'โœ“' const body = line.slice(0, -2) - const [call, detail] = body.split(' :: ', 2) + const sep = body.indexOf(' :: ') - if (detail != null) { - return { call, detail, mark } + if (sep >= 0) { + return { call: body.slice(0, sep), detail: body.slice(sep + 4), mark } } const legacy = body.indexOf(': ') diff --git a/ui-tui/src/lib/virtualHeights.ts b/ui-tui/src/lib/virtualHeights.ts index 0e58b814d12..874f8a1b8dc 100644 --- a/ui-tui/src/lib/virtualHeights.ts +++ b/ui-tui/src/lib/virtualHeights.ts @@ -1,5 +1,6 @@ import type { Msg } from '../types.js' +import { TERMUX_TUI_MODE } from '../config/env.js' import { transcriptBodyWidth } from './inputMetrics.js' const hashText = (text: string) => { @@ -96,7 +97,7 @@ export const estimatedMsgHeight = ( return Math.max(2, msg.todos.length + 2) } - const bodyWidth = transcriptBodyWidth(cols, msg.role, userPrompt) + const bodyWidth = transcriptBodyWidth(cols, msg.role, userPrompt, TERMUX_TUI_MODE) const text = msg.text let h = wrappedLines(text || ' ', bodyWidth) diff --git a/ui-tui/src/types.ts b/ui-tui/src/types.ts index f0651bef9c5..0bfab6c271d 100644 --- a/ui-tui/src/types.ts +++ b/ui-tui/src/types.ts @@ -2,6 +2,7 @@ export interface ActiveTool { context?: string id: string name: string + verboseArgs?: string startedAt?: number } diff --git a/web/README.md b/web/README.md index d8127f96e03..c9581635b2f 100644 --- a/web/README.md +++ b/web/README.md @@ -17,9 +17,14 @@ python -m hermes_cli.main web --no-open # In another terminal, start the Vite dev server (with HMR + API proxy) cd web/ +npm install npm run dev ``` +Open the **Vite URL** printed in the terminal (usually `http://localhost:5173`). That is the live-reload UI. + +`hermes dashboard` on port 9119 serves the **built** bundle from `hermes_cli/web_dist/`, not the Vite dev server โ€” changes in `web/src/` will not appear there until you run `npm run build` and restart the dashboard (or use `web --no-open` + Vite as above). + The Vite dev server proxies `/api` requests to `http://127.0.0.1:9119` (the FastAPI backend). ## Build @@ -46,3 +51,54 @@ src/ โ”œโ”€โ”€ main.tsx # React entry point โ””โ”€โ”€ index.css # Tailwind imports and theme variables ``` + +## Typography & contrast rules + +Read before adding or editing UI styles. These rules keep the dashboard legible across all built-in themes and stop drift back into the patterns the design system was just refactored out of. + +### Text size floor + +- **Minimum body size: `text-xs` (12px / 0.75rem).** Do not use arbitrary `text-[0.6rem]`, `text-[0.65rem]`, `text-[9px]`, `text-[10px]`, or `text-[11px]` on copy, hints, labels, counts, or badges. Use the standard scale: `text-xs`, `text-sm`, `text-base`. +- Smaller sizes are only acceptable on **decorative overlays** (chart stripes, empty-state icons) โ€” never on text the user is meant to read. + +### Opacity floor on text + +- **Never apply opacity below 0.7 to text.** No `opacity-30`, `opacity-50`, `opacity-60` on ``s, `

`s, labels, etc. +- **Do not stack opacity tokens.** Patterns like `text-muted-foreground/60`, `text-midground/70`, `text-foreground/50` create unpredictable WCAG failures because the parent token already has alpha. +- Use the **semantic text tokens** from `@nous-research/ui`'s `globals.css`: + - `text-text-primary` โ€” default body text. + - `text-text-secondary` โ€” subtitles, meta, inactive nav. + - `text-text-tertiary` โ€” small chrome labels, counts, footnotes. + - `text-text-disabled` โ€” disabled states. + - `text-text-on-accent` โ€” text on filled accent surfaces. + +### Brand uppercase via `text-display`, not raw `uppercase` + +- The dashboard preserves the Nous brand uppercase aesthetic, but it is **opt-in per element, not global**. +- Apply uppercase via the DS utility `text-display` on **brand chrome only** โ€” page titles, nav section headings, badges, brand wordmark. DS components (`Button`, `Badge`, `Tabs`, `Segmented`, etc.) already self-apply `text-display`. +- **Do not introduce new `uppercase`** (the literal Tailwind class) in `hermes-agent/web/src`. Prefer `text-display` for new brand chrome. Legacy `uppercase` call sites (e.g. `components/ui/label.tsx`, `card.tsx`) remain until migrated. +- The app shell no longer forces uppercase globally, so blanket `normal-case` opt-outs are unnecessary. Use `normal-case` only where a DS component applies `text-display` but the label should stay sentence case โ€” e.g. dynamic user content (model slugs, theme names) **or** fixed UI copy that is not brand chrome (EnvPage โ€œnot configuredโ€ toggle, sidebar โ€œNew chatโ€). + +### Fonts + +Typography is **opt-in per surface**, not global on layout shells โ€” the app shell and page header keep their original theme/expanded fonts; Mondwest applies only where explicitly set. + +| Tier | Classes | Use for | +|------|---------|---------| +| Brand chrome | `font-mondwest text-display` (or `themedChrome`) | Sidebar nav, card section headers (`CardTitle`), Segmented filter buttons, filter panel headings | +| Themed body | `font-mondwest normal-case` (or `themedBody`) | Card content (`Card`, `CardDescription`), session/platform rows, analytics tables โ€” **scoped to the component** | +| Page chrome | `font-expanded` | Page header h1 (`PageHeaderProvider`) โ€” sentence case, not `text-display` | +| Wordmark | `Typography` + size/tracking only | Sidebar/mobile โ€œHermes Agentโ€ โ€” mixed case, no Mondwest, no `text-display` | +| Technical | `font-mono-ui` / `font-mono` / `font-courier` | Model slugs, env keys, schedules, YAML, repo URLs | + +- Do **not** put `themedBody` or `themedFont` on `

`, `App`, or other layout wrappers โ€” it overrides component-scoped styles. +- **`Card`** applies `themedBody`; **`CardTitle`** uses `text-display` (uppercase chrome); **`CardDescription`** uses `themedBody`. +- **`NouiTypography`** defaults to `font-sans` unless a font prop is passed. +- Do **not** use raw `font-sans` or `font-display` (theme sans variable) on new dashboard UI โ€” prefer Mondwest tiers above where brand-appropriate. + +### Color tokens + +- Prefer **semantic tokens** (`text-text-*`, `bg-card`, `border-border`, `text-foreground`, `text-destructive`, `text-success`, `text-warning`) over raw layer references (`text-midground`, `text-foreground`). +- `text-muted-foreground` is now wired to `--color-text-secondary`, so existing call sites stay correct, but new code should prefer the semantic name. +- When you genuinely need a non-token color (icon de-emphasis on a chart, terminal foreground via inline style), keep alpha at `โ‰ฅ 0.7` for any text. + diff --git a/web/package-lock.json b/web/package-lock.json index 034d48a1f89..caf43731a17 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -8,7 +8,7 @@ "name": "web", "version": "0.0.0", "dependencies": { - "@nous-research/ui": "^0.14.2", + "@nous-research/ui": "0.16.0", "@observablehq/plot": "^0.6.17", "@react-three/fiber": "^9.6.0", "@tailwindcss/vite": "^4.2.1", @@ -77,6 +77,7 @@ "integrity": "sha512-CGOfOJqWjg2qW/Mb6zNsDm+u5vFQ8DxXfbM09z69p5Z6+mE1ikP2jUXw+j42Pf1XTYED2Rni5f95npYeuwMDQA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/code-frame": "^7.29.0", "@babel/generator": "^7.29.0", @@ -1079,9 +1080,9 @@ } }, "node_modules/@nous-research/ui": { - "version": "0.14.2", - "resolved": "https://registry.npmjs.org/@nous-research/ui/-/ui-0.14.2.tgz", - "integrity": "sha512-H3cMt2e0IpmcTNOmR6zVX+8ja48w4X4F/IFXhWCpaoVs8zKVRN12Ryb4RnX/ac8IrbUu6UsIds7ZtmXxPHcfdQ==", + "version": "0.16.0", + "resolved": "https://registry.npmjs.org/@nous-research/ui/-/ui-0.16.0.tgz", + "integrity": "sha512-JvSwf9vBOCEEGDSOYIRn/F/JJSBDh9DvGU3s3OFbX6K1otnSK7s47cZdgvfBoEPmeKFom2fWQDDqfzLV+eR7Qg==", "license": "MIT", "dependencies": { "@nanostores/react": "^1.1.0", @@ -1127,6 +1128,7 @@ "resolved": "https://registry.npmjs.org/@observablehq/plot/-/plot-0.6.17.tgz", "integrity": "sha512-/qaXP/7mc4MUS0s4cPPFASDRjtsWp85/TbfsciqDgU1HwYixbSbbytNuInD8AcTYC3xaxACgVX06agdfQy9W+g==", "license": "ISC", + "peer": true, "dependencies": { "d3": "^7.9.0", "interval-tree-1d": "^1.0.0", @@ -1865,6 +1867,7 @@ "resolved": "https://registry.npmjs.org/@react-three/fiber/-/fiber-9.6.0.tgz", "integrity": "sha512-90abYK2q5/qDM+GACs9zRvc5KhEEpEWqWlHSd64zTPNxg+9wCJvTfyD9x2so7hlQhjRYO1Fa6flR3BC/kpTFkA==", "license": "MIT", + "peer": true, "dependencies": { "@babel/runtime": "^7.17.8", "@types/webxr": "*", @@ -2570,6 +2573,7 @@ "integrity": "sha512-A1sre26ke7HDIuY/M23nd9gfB+nrmhtYyMINbjI1zHJxYteKR6qSMX56FsmjMcDb3SMcjJg5BiRRgOCC/yBD0g==", "devOptional": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -2579,6 +2583,7 @@ "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.14.tgz", "integrity": "sha512-ilcTH/UniCkMdtexkoCN0bI7pMcJDvmQFPvuPvmEaYA/NSfFTAgdUSLAoVjaRJm7+6PvcM+q1zYOwS4wTYMF9w==", "license": "MIT", + "peer": true, "dependencies": { "csstype": "^3.2.2" } @@ -2589,6 +2594,7 @@ "integrity": "sha512-jp2L/eY6fn+KgVVQAOqYItbF0VY/YApe5Mz2F0aykSO8gx31bYCZyvSeYxCHKvzHG5eZjc+zyaS5BrBWya2+kQ==", "devOptional": true, "license": "MIT", + "peer": true, "peerDependencies": { "@types/react": "^19.2.0" } @@ -2653,6 +2659,7 @@ "integrity": "sha512-HDQH9O/47Dxi1ceDhBXdaldtf/WV9yRYMjbjCuNk3qnaTD564qwv61Y7+gTxwxRKzSrgO5uhtw584igXVuuZkA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.59.1", "@typescript-eslint/types": "8.59.1", @@ -2981,6 +2988,7 @@ "integrity": "sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -3133,6 +3141,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.10.12", "caniuse-lite": "^1.0.30001782", @@ -3640,6 +3649,7 @@ "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "license": "ISC", + "peer": true, "engines": { "node": ">=12" } @@ -3959,6 +3969,7 @@ "integrity": "sha512-XoMjdBOwe/esVgEvLmNsD3IRHkm7fbKIUGvrleloJXUZgDHig2IPWNniv+GwjyJXzuNqVjlr5+4yVUZjycJwfQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -4269,13 +4280,13 @@ } }, "node_modules/framer-motion": { - "version": "12.39.0", - "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-12.39.0.tgz", - "integrity": "sha512-+vnLfzrv0MzjLzNl+nvNvR7jdg3q4cxxjz/YvzfifHl0TREtL00cs1RoMTxs+1PzLiEqZGV6gYsBY0oEAYZ24w==", + "version": "12.38.0", + "resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-12.38.0.tgz", + "integrity": "sha512-rFYkY/pigbcswl1XQSb7q424kSTQ8q6eAC+YUsSKooHQYuLdzdHjrt6uxUC+PRAO++q5IS7+TamgIw1AphxR+g==", "license": "MIT", "dependencies": { - "motion-dom": "^12.39.0", - "motion-utils": "^12.39.0", + "motion-dom": "^12.38.0", + "motion-utils": "^12.36.0", "tslib": "^2.4.0" }, "peerDependencies": { @@ -4364,7 +4375,8 @@ "version": "3.15.0", "resolved": "https://registry.npmjs.org/gsap/-/gsap-3.15.0.tgz", "integrity": "sha512-dMW4CWBTUK1AEEDeZc1g4xpPGIrSf9fJF960qbTZmN/QwZIWY5wgliS6JWl9/25fpTGJrMRtSjGtOmPnfjZB+A==", - "license": "Standard 'no charge' license: https://gsap.com/standard-license." + "license": "Standard 'no charge' license: https://gsap.com/standard-license.", + "peer": true }, "node_modules/has-flag": { "version": "4.0.0", @@ -4679,6 +4691,7 @@ "resolved": "https://registry.npmjs.org/leva/-/leva-0.10.1.tgz", "integrity": "sha512-BcjnfUX8jpmwZUz2L7AfBtF9vn4ggTH33hmeufDULbP3YgNZ/C+ss/oO3stbrqRQyaOmRwy70y7BGTGO81S3rA==", "license": "MIT", + "peer": true, "dependencies": { "@radix-ui/react-portal": "^1.1.4", "@radix-ui/react-tooltip": "^1.1.8", @@ -5082,12 +5095,13 @@ } }, "node_modules/motion": { - "version": "12.39.0", - "resolved": "https://registry.npmjs.org/motion/-/motion-12.39.0.tgz", - "integrity": "sha512-H4a+Ze+a9j+/NTla5ezfb/g9vmIOxC+viDj++NGDZyTZkdRKjiOz3kSv6TalRWM8ZmD2y/CfC6TkQc97ybyqSA==", + "version": "12.38.0", + "resolved": "https://registry.npmjs.org/motion/-/motion-12.38.0.tgz", + "integrity": "sha512-uYfXzeHlgThchzwz5Te47dlv5JOUC7OB4rjJ/7XTUgtBZD8CchMN8qEJ4ZVsUmTyYA44zjV0fBwsiktRuFnn+w==", "license": "MIT", + "peer": true, "dependencies": { - "framer-motion": "^12.39.0", + "framer-motion": "^12.38.0", "tslib": "^2.4.0" }, "peerDependencies": { @@ -5108,18 +5122,18 @@ } }, "node_modules/motion-dom": { - "version": "12.39.0", - "resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-12.39.0.tgz", - "integrity": "sha512-Xn7aAcGDhco/JZTXOub64UmaYn73C6J1Po7Fk+8EvkJsNGTqfhon6UJY53vJKXW5v5Zl8HrYsVxv6oPXeGoGLQ==", + "version": "12.38.0", + "resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-12.38.0.tgz", + "integrity": "sha512-pdkHLD8QYRp8VfiNLb8xIBJis1byQ9gPT3Jnh2jqfFtAsWUA3dEepDlsWe/xMpO8McV+VdpKVcp+E+TGJEtOoA==", "license": "MIT", "dependencies": { - "motion-utils": "^12.39.0" + "motion-utils": "^12.36.0" } }, "node_modules/motion-utils": { - "version": "12.39.0", - "resolved": "https://registry.npmjs.org/motion-utils/-/motion-utils-12.39.0.tgz", - "integrity": "sha512-8nadJAJjTtqRkmRF36FoJTrywK9nnFmnPwnSMyxaOCU7GDjN9RTMJIxx9De8ErM+vpPhMccr/6fo5WciyQLnMQ==", + "version": "12.36.0", + "resolved": "https://registry.npmjs.org/motion-utils/-/motion-utils-12.36.0.tgz", + "integrity": "sha512-eHWisygbiwVvf6PZ1vhaHCLamvkSbPIeAYxWUuL3a2PD/TROgE7FvfHWTIH4vMl798QLfMw15nRqIaRDXTlYRg==", "license": "MIT" }, "node_modules/ms": { @@ -5158,6 +5172,7 @@ } ], "license": "MIT", + "peer": true, "engines": { "node": "^20.0.0 || >=22.0.0" } @@ -5285,6 +5300,7 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -5356,6 +5372,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.5.tgz", "integrity": "sha512-llUJLzz1zTUBrskt2pwZgLq59AemifIftw4aB7JxOqf1HY2FDaGDxgwpAPVzHU1kdWabH7FauP4i1oEeer2WCA==", "license": "MIT", + "peer": true, "engines": { "node": ">=0.10.0" } @@ -5375,6 +5392,7 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.5.tgz", "integrity": "sha512-J5bAZz+DXMMwW/wV3xzKke59Af6CHY7G4uYLN1OvBcKEsWOs4pQExj86BBKamxl/Ik5bx9whOrvBlSDfWzgSag==", "license": "MIT", + "peer": true, "dependencies": { "scheduler": "^0.27.0" }, @@ -5735,7 +5753,8 @@ "version": "0.180.0", "resolved": "https://registry.npmjs.org/three/-/three-0.180.0.tgz", "integrity": "sha512-o+qycAMZrh+TsE01GqWUxUIKR1AL0S8pq7zDkYOQw8GqfX8b8VoCKYUoHbhiX5j+7hr8XsuHDVU6+gkQJQKg9w==", - "license": "MIT" + "license": "MIT", + "peer": true }, "node_modules/tinyglobby": { "version": "0.2.16", @@ -5800,6 +5819,7 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -5898,6 +5918,7 @@ "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", "license": "MIT", + "peer": true, "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } @@ -5913,6 +5934,7 @@ "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz", "integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==", "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.27.0", "fdir": "^6.5.0", @@ -6034,6 +6056,7 @@ "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", "dev": true, "license": "MIT", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/web/package.json b/web/package.json index 7c4c60bfc68..49880e04b67 100644 --- a/web/package.json +++ b/web/package.json @@ -10,7 +10,7 @@ "preview": "vite preview" }, "dependencies": { - "@nous-research/ui": "^0.14.2", + "@nous-research/ui": "0.16.0", "@observablehq/plot": "^0.6.17", "@react-three/fiber": "^9.6.0", "@tailwindcss/vite": "^4.2.1", diff --git a/web/src/App.tsx b/web/src/App.tsx index 987252ce0bb..aeac02ae789 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -326,7 +326,9 @@ export default function App() { api .getConfig() .then((cfg) => { - const dash = (cfg?.dashboard ?? {}) as { show_token_analytics?: unknown }; + const dash = (cfg?.dashboard ?? {}) as { + show_token_analytics?: unknown; + }; setShowTokenAnalytics(dash.show_token_analytics === true); }) .catch(() => setShowTokenAnalytics(false)); @@ -366,7 +368,9 @@ export default function App() { const base = embeddedChat ? [CHAT_NAV_ITEM, ...BUILTIN_NAV_REST] : BUILTIN_NAV_REST; - return showTokenAnalytics ? base : base.filter((n) => n.path !== "/analytics"); + return showTokenAnalytics + ? base + : base.filter((n) => n.path !== "/analytics"); }, [embeddedChat, showTokenAnalytics]); const sidebarNav = useMemo( @@ -416,7 +420,7 @@ export default function App() { return (
@@ -442,7 +446,7 @@ export default function App() { aria-label={t.app.openNavigation} aria-expanded={mobileOpen} aria-controls="app-sidebar" - className="text-midground/70 hover:text-midground" + className="text-text-secondary hover:text-midground" > @@ -498,7 +502,7 @@ export default function App() { Hermes @@ -512,7 +516,7 @@ export default function App() { size="icon" onClick={closeMobile} aria-label={t.app.closeNavigation} - className="lg:hidden text-midground/70 hover:text-midground" + className="lg:hidden text-text-secondary hover:text-midground" > @@ -542,7 +546,7 @@ export default function App() { @@ -671,10 +675,12 @@ function SidebarNavLink({ closeMobile, item, t }: SidebarNavLinkProps) { cn( "group relative flex items-center gap-3", "px-5 py-2.5", - "font-mondwest text-[0.8rem] tracking-[0.12em]", + "font-mondwest text-display uppercase text-sm tracking-[0.12em]", "whitespace-nowrap transition-colors cursor-pointer", "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-midground", - isActive ? "text-midground" : "opacity-60 hover:opacity-100", + isActive + ? "text-midground" + : "text-text-secondary hover:text-midground", ) } style={{ @@ -746,7 +752,7 @@ function SidebarSystemActions({ onNavigate }: { onNavigate: () => void }) { {t.app.system} @@ -772,12 +778,12 @@ function SidebarSystemActions({ onNavigate }: { onNavigate: () => void }) { active={busy} className={cn( "gap-3 px-5 py-1.5 whitespace-nowrap", - "font-mondwest text-[0.75rem] tracking-[0.1em]", - "transition-opacity", + "font-mondwest text-display text-xs tracking-[0.1em]", + "transition-colors", busy - ? "text-midground opacity-100" - : "opacity-60 hover:opacity-100", - "disabled:opacity-30", + ? "text-midground" + : "text-text-secondary hover:text-midground", + "disabled:text-text-disabled", )} > {isPending ? ( diff --git a/web/src/components/AutoField.tsx b/web/src/components/AutoField.tsx index 0f96d420425..4e3451c10fd 100644 --- a/web/src/components/AutoField.tsx +++ b/web/src/components/AutoField.tsx @@ -11,8 +11,8 @@ function FieldHint({ schema, schemaKey }: { schema: Record; sch return (
- {keyPath && {keyPath}} - {description && {description}} + {keyPath && {keyPath}} + {description && {description}}
); } diff --git a/web/src/components/BottomPickSheet.tsx b/web/src/components/BottomPickSheet.tsx index 1490f4090c8..38cae8daa00 100644 --- a/web/src/components/BottomPickSheet.tsx +++ b/web/src/components/BottomPickSheet.tsx @@ -7,7 +7,7 @@ import { } from "react"; import { createPortal } from "react-dom"; import { Typography } from "@/components/NouiTypography"; -import { cn } from "@/lib/utils"; +import { cn, themedBody } from "@/lib/utils"; const CLOSE_DRAG_MIN_PX = 72; const CLOSE_DRAG_RATIO = 0.18; @@ -168,6 +168,7 @@ export function BottomPickSheet({ aria-modal="true" ref={sheetRef} className={cn( + themedBody, "relative flex max-h-[85dvh] min-h-0 flex-col rounded-t-xl border border-current/20", "bg-background-base/98 pb-[max(1rem,env(safe-area-inset-bottom))]", "shadow-[0_-12px_40px_-8px_rgba(0,0,0,0.55)] backdrop-blur-md", @@ -200,7 +201,7 @@ export function BottomPickSheet({ {title} diff --git a/web/src/components/ChatSidebar.tsx b/web/src/components/ChatSidebar.tsx index c311673fafc..a115d887ec3 100644 --- a/web/src/components/ChatSidebar.tsx +++ b/web/src/components/ChatSidebar.tsx @@ -304,13 +304,13 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) { return (