diff --git a/.env.example b/.env.example index 066e93f7c9..589978e6b5 100644 --- a/.env.example +++ b/.env.example @@ -398,3 +398,19 @@ IMAGE_TOOLS_DEBUG=false # Override STT provider endpoints (for proxies or self-hosted instances) # GROQ_BASE_URL=https://api.groq.com/openai/v1 # STT_OPENAI_BASE_URL=https://api.openai.com/v1 + +# ============================================================================= +# MICROSOFT TEAMS INTEGRATION +# ============================================================================= +# Register a Bot in Azure: https://dev.botframework.com/ → "Register a bot" +# Or use Azure Portal: Azure Active Directory → App registrations → New registration +# Then add the bot to Teams via the Bot Framework or App Studio. +# +# TEAMS_CLIENT_ID= # Azure AD App (client) ID +# TEAMS_CLIENT_SECRET= # Azure AD client secret value +# TEAMS_TENANT_ID= # Azure AD tenant ID (or "common" for multi-tenant) +# TEAMS_ALLOWED_USERS= # Comma-separated AAD object IDs or UPNs +# TEAMS_ALLOW_ALL_USERS=false # Set true to skip the allowlist +# TEAMS_HOME_CHANNEL= # Default channel/chat ID for cron delivery +# TEAMS_HOME_CHANNEL_NAME= # Display name for the home channel +# TEAMS_PORT=3978 # Webhook listen port (Bot Framework default) diff --git a/Dockerfile b/Dockerfile index d988ea6407..18177cc1ac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ ENV PLAYWRIGHT_BROWSERS_PATH=/opt/hermes/.playwright # that would otherwise accumulate when hermes runs as PID 1. See #15012. 
RUN apt-get update && \ apt-get install -y --no-install-recommends \ - build-essential nodejs npm python3 ripgrep ffmpeg gcc python3-dev libffi-dev procps git openssh-client docker-cli tini && \ + build-essential curl nodejs npm python3 ripgrep ffmpeg gcc python3-dev libffi-dev procps git openssh-client docker-cli tini && \ rm -rf /var/lib/apt/lists/* # Non-root user for runtime; UID can be overridden via HERMES_UID at runtime diff --git a/acp_adapter/server.py b/acp_adapter/server.py index b21e6fa3bf..64a31063eb 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -13,6 +13,7 @@ from typing import Any, Deque, Optional import acp from acp.schema import ( AgentCapabilities, + AgentMessageChunk, AuthenticateResponse, AvailableCommand, AvailableCommandsUpdate, @@ -45,6 +46,7 @@ from acp.schema import ( TextContentBlock, UnstructuredCommandInput, Usage, + UserMessageChunk, ) # AuthMethodAgent was renamed from AuthMethod in agent-client-protocol 0.9.0 @@ -377,6 +379,78 @@ class HermesACPAgent(acp.Agent): # ---- Session management ------------------------------------------------- + @staticmethod + def _history_message_text(message: dict[str, Any]) -> str: + """Extract displayable text from a persisted OpenAI-style message.""" + content = message.get("content") + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + elif item.get("type") == "text" and isinstance(item.get("content"), str): + parts.append(item["content"]) + elif isinstance(item, str): + parts.append(item) + return "\n".join(part.strip() for part in parts if part and part.strip()).strip() + return "" + + @staticmethod + def _history_message_update( + *, + role: str, + text: str, + ) -> UserMessageChunk | AgentMessageChunk | None: + """Build an ACP history replay update for a user/assistant message.""" + block = 
TextContentBlock(type="text", text=text) + if role == "user": + return UserMessageChunk( + session_update="user_message_chunk", + content=block, + ) + if role == "assistant": + return AgentMessageChunk( + session_update="agent_message_chunk", + content=block, + ) + return None + + async def _replay_session_history(self, state: SessionState) -> None: + """Send persisted user/assistant history to clients during session/load. + + Zed's ACP history UI calls ``session/load`` after the user picks an item + from the Agents sidebar. The agent must then replay the full conversation + as ``user_message_chunk`` / ``agent_message_chunk`` notifications; merely + restoring server-side state makes Hermes remember context, but leaves the + editor looking like a clean thread. + """ + if not self._conn or not state.history: + return + + for message in state.history: + role = str(message.get("role") or "") + if role not in {"user", "assistant"}: + continue + text = self._history_message_text(message) + if not text: + continue + update = self._history_message_update(role=role, text=text) + if update is None: + continue + try: + await self._conn.session_update(session_id=state.session_id, update=update) + except Exception: + logger.warning( + "Failed to replay ACP history for session %s", + state.session_id, + exc_info=True, + ) + return + async def new_session( self, cwd: str, @@ -405,6 +479,7 @@ class HermesACPAgent(acp.Agent): return None await self._register_session_mcp_servers(state, mcp_servers) logger.info("Loaded session %s", session_id) + await self._replay_session_history(state) self._schedule_available_commands_update(session_id) return LoadSessionResponse(models=self._build_model_state(state)) @@ -421,6 +496,7 @@ class HermesACPAgent(acp.Agent): state = self.session_manager.create_session(cwd=cwd) await self._register_session_mcp_servers(state, mcp_servers) logger.info("Resumed session %s", state.session_id) + await self._replay_session_history(state) 
self._schedule_available_commands_update(state.session_id) return ResumeSessionResponse(models=self._build_model_state(state)) diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index 0ac0359e0a..efee8f6bf1 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -20,7 +20,7 @@ from pathlib import Path from hermes_constants import get_hermes_home from typing import Any, Dict, List, Optional, Tuple -from utils import normalize_proxy_env_vars +from utils import base_url_host_matches, normalize_proxy_env_vars # NOTE: `import anthropic` is deliberately NOT at module top — the SDK pulls # ~220 ms of imports (anthropic.types, anthropic.lib.tools._beta_runner, etc.) @@ -257,11 +257,10 @@ _OAUTH_ONLY_BETAS = [ "oauth-2025-04-20", ] -# Claude Code version — sent on OAuth token-exchange / refresh requests -# (platform.claude.com/v1/oauth/token) as the client's user-agent. Anthropic's -# OAuth flow validates the UA and may reject requests with a version that's -# too old, so detecting dynamically keeps users on a current Claude Code -# install from hitting stale-version errors during login/refresh. +# Claude Code identity — required for OAuth requests to be routed correctly. +# Without these, Anthropic's infrastructure intermittently 500s OAuth traffic. +# The version must stay reasonably current — Anthropic rejects OAuth requests +# when the spoofed user-agent version is too far behind the actual release. _CLAUDE_CODE_VERSION_FALLBACK = "2.1.74" _claude_code_version_cache: Optional[str] = None @@ -269,9 +268,9 @@ _claude_code_version_cache: Optional[str] = None def _detect_claude_code_version() -> str: """Detect the installed Claude Code version, fall back to a static constant. - Used only by the OAuth token-exchange / refresh flow - (``platform.claude.com/v1/oauth/token``). The Messages API client no - longer sends a claude-cli user-agent. 
+ Anthropic's OAuth infrastructure validates the user-agent version and may + reject requests with a version that's too old. Detecting dynamically means + users who keep Claude Code updated never hit stale-version 400s. """ import subprocess as _sp @@ -291,13 +290,12 @@ def _detect_claude_code_version() -> str: return _CLAUDE_CODE_VERSION_FALLBACK -def _get_claude_code_version() -> str: - """Lazily detect the installed Claude Code version for OAuth flow headers. +_CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude." +_MCP_TOOL_PREFIX = "mcp_" - Used only on the OAuth token-exchange and refresh endpoints - (``platform.claude.com/v1/oauth/token``). The Messages API client does - not send a claude-cli user-agent. - """ + +def _get_claude_code_version() -> str: + """Lazily detect the installed Claude Code version when OAuth headers need it.""" global _claude_code_version_cache if _claude_code_version_cache is None: _claude_code_version_cache = _detect_claude_code_version() @@ -367,6 +365,88 @@ def _is_kimi_coding_endpoint(base_url: str | None) -> bool: return normalized.rstrip("/").lower().startswith("https://api.kimi.com/coding") +# Model-name prefixes that identify the Kimi / Moonshot family. Covers +# - official slugs: ``kimi-k2.5``, ``kimi_thinking``, ``moonshot-v1-8k`` +# - common release lines: ``k1.5-...``, ``k2-thinking``, ``k25-...``, ``k2.5-...`` +# Matched case-insensitively against the post-``normalize_model_name`` form, +# so a caller's ``provider/vendor/model`` slug is handled the same as a +# bare name. +_KIMI_FAMILY_MODEL_PREFIXES = ( + "kimi-", "kimi_", + "moonshot-", "moonshot_", + "k1.", "k1-", + "k2.", "k2-", + "k25", "k2.5", +) + + +def _model_name_is_kimi_family(model: str | None) -> bool: + if not isinstance(model, str): + return False + m = model.strip().lower() + if not m: + return False + # Strip vendor prefix (e.g. 
``moonshotai/kimi-k2.5`` → ``kimi-k2.5``) + if "/" in m: + m = m.rsplit("/", 1)[-1] + return m.startswith(_KIMI_FAMILY_MODEL_PREFIXES) + + +def _is_kimi_family_endpoint(base_url: str | None, model: str | None = None) -> bool: + """Return True for any Kimi / Moonshot Anthropic-Messages-speaking endpoint. + + Broader than ``_is_kimi_coding_endpoint`` — matches: + + - Kimi's official ``/coding`` URL (legacy check, preserved) + - Any ``api.kimi.com`` / ``moonshot.ai`` / ``moonshot.cn`` host + - Custom or proxied endpoints whose *model* name is in the Kimi / Moonshot + family (``kimi-*``, ``moonshot-*``, ``k1.*``, ``k2.*``, …). Users with + ``api_mode: anthropic_messages`` on a private gateway fronting Kimi + fall into this branch — the upstream still enforces Kimi's thinking + semantics (reasoning_content required on every replayed tool-call + message) regardless of the gateway's hostname. + + Used to decide whether to drop Anthropic's ``thinking`` kwarg and to + preserve unsigned reasoning_content-derived thinking blocks on replay. + See hermes-agent#13848, #17057. + """ + if _is_kimi_coding_endpoint(base_url): + return True + for _domain in ("api.kimi.com", "moonshot.ai", "moonshot.cn"): + if base_url_host_matches(base_url or "", _domain): + return True + if _model_name_is_kimi_family(model): + return True + return False + + +def _is_deepseek_anthropic_endpoint(base_url: str | None) -> bool: + """Return True for DeepSeek's Anthropic-compatible endpoint. + + DeepSeek's ``/anthropic`` route speaks the Anthropic Messages protocol + but, when thinking mode is enabled, requires the ``thinking`` blocks + from prior assistant turns to round-trip on subsequent requests — the + generic third-party path strips them and triggers HTTP 400:: + + The content[].thinking in the thinking mode must be passed back + to the API. 
+ + Per DeepSeek's published compatibility matrix the blocks are unsigned + (no Anthropic-proprietary signature, no ``redacted_thinking`` support), + so this endpoint is handled with the same strip-signed / keep-unsigned + policy used for Kimi's ``/coding`` endpoint. The match is pinned to + the ``/anthropic`` path so the OpenAI-compatible ``api.deepseek.com`` + base URL (which never reaches this adapter) is not misclassified. + See hermes-agent#16748. + """ + if not base_url_host_matches(base_url or "", "api.deepseek.com"): + return False + normalized = _normalize_base_url_text(base_url) + if not normalized: + return False + return "/anthropic" in normalized.rstrip("/").lower() + + def _requires_bearer_auth(base_url: str | None) -> bool: """Return True for Anthropic-compatible providers that require Bearer auth. @@ -381,7 +461,11 @@ def _requires_bearer_auth(base_url: str | None) -> bool: return normalized.startswith(("https://api.minimax.io/anthropic", "https://api.minimaxi.com/anthropic")) -def _common_betas_for_base_url(base_url: str | None) -> list[str]: +def _common_betas_for_base_url( + base_url: str | None, + *, + drop_context_1m_beta: bool = False, +) -> list[str]: """Return the beta headers that are safe for the configured endpoint. MiniMax's Anthropic-compatible endpoints (Bearer-auth) reject requests @@ -392,14 +476,30 @@ def _common_betas_for_base_url(base_url: str | None) -> list[str]: The ``context-1m-2025-08-07`` beta is also stripped for Bearer-auth endpoints — MiniMax hosts its own models, not Claude, so the header is irrelevant at best and risks request rejection at worst. + + ``drop_context_1m_beta=True`` additionally strips the 1M-context beta on + otherwise-unrelated endpoints. The OAuth retry path flips this flag after + a subscription rejects the beta with + "The long context beta is not yet available for this subscription" so + subsequent requests in the same session don't repeat the probe. 
See the + reactive recovery loop in ``run_agent.py`` and issue-comment history on + PR #17680 for the full rationale. """ if _requires_bearer_auth(base_url): _stripped = {_TOOL_STREAMING_BETA, _CONTEXT_1M_BETA} return [b for b in _COMMON_BETAS if b not in _stripped] + if drop_context_1m_beta: + return [b for b in _COMMON_BETAS if b != _CONTEXT_1M_BETA] return _COMMON_BETAS -def build_anthropic_client(api_key: str, base_url: str = None, timeout: float = None): +def build_anthropic_client( + api_key: str, + base_url: str = None, + timeout: float = None, + *, + drop_context_1m_beta: bool = False, +): """Create an Anthropic client, auto-detecting setup-tokens vs API keys. If *timeout* is provided it overrides the default 900s read timeout. The @@ -408,6 +508,12 @@ def build_anthropic_client(api_key: str, base_url: str = None, timeout: float = Anthropic-compatible providers respect the same knob as OpenAI-wire providers. + ``drop_context_1m_beta=True`` strips ``context-1m-2025-08-07`` from the + client-level ``anthropic-beta`` header. Used by the reactive OAuth retry + path in ``run_agent.py`` when a subscription rejects the beta; leave at + its default on fresh clients so 1M-capable subscriptions keep the + capability. + Returns an anthropic.Anthropic instance. 
""" _anthropic_sdk = _get_anthropic_sdk() @@ -437,7 +543,10 @@ def build_anthropic_client(api_key: str, base_url: str = None, timeout: float = kwargs["default_query"] = {"api-version": "2025-04-15"} else: kwargs["base_url"] = normalized_base_url - common_betas = _common_betas_for_base_url(normalized_base_url) + common_betas = _common_betas_for_base_url( + normalized_base_url, + drop_context_1m_beta=drop_context_1m_beta, + ) if _is_kimi_coding_endpoint(base_url): # Kimi's /coding endpoint requires User-Agent: claude-code/0.1.0 @@ -467,21 +576,15 @@ def build_anthropic_client(api_key: str, base_url: str = None, timeout: float = if common_betas: kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)} elif _is_oauth_token(api_key): - # OAuth access token / setup-token → Bearer auth + OAuth-only betas. - # The OAuth-specific beta headers are still required by Anthropic's - # OAuth-gated Messages API path; the Claude Code user-agent / x-app - # spoofing is deliberately NOT sent — Hermes identifies as itself. - # - # ``context-1m-2025-08-07`` is stripped here: Anthropic rejects - # OAuth requests that carry it with - # "This authentication style is incompatible with the long - # context beta header." - # Subscription-gated OAuth traffic gets the 200K default window. - oauth_safe_common = [b for b in common_betas if b != _CONTEXT_1M_BETA] - all_betas = oauth_safe_common + _OAUTH_ONLY_BETAS + # OAuth access token / setup-token → Bearer auth + Claude Code identity. + # Anthropic routes OAuth requests based on user-agent and headers; + # without Claude Code's fingerprint, requests get intermittent 500s. 
+ all_betas = common_betas + _OAUTH_ONLY_BETAS kwargs["auth_token"] = api_key kwargs["default_headers"] = { "anthropic-beta": ",".join(all_betas), + "user-agent": f"claude-cli/{_get_claude_code_version()} (external, cli)", + "x-app": "cli", } else: # Regular API key → x-api-key header + common betas @@ -825,45 +928,17 @@ def resolve_anthropic_token() -> Optional[str]: """Resolve an Anthropic token from all available sources. Priority: - 1. Hermes credential pool (``~/.hermes/auth.json`` → - ``credential_pool.anthropic``) — OAuth tokens minted by Hermes' - own PKCE login flow. Entries are auto-refreshed when near - expiry. Env-sourced pool entries (``source="env:..."``) are - skipped here so the env-var priority logic below still runs. - 2. ANTHROPIC_TOKEN env var (OAuth/setup token saved by Hermes) - 3. CLAUDE_CODE_OAUTH_TOKEN env var - 4. Claude Code credentials (~/.claude.json or ~/.claude/.credentials.json) + 1. ANTHROPIC_TOKEN env var (OAuth/setup token saved by Hermes) + 2. CLAUDE_CODE_OAUTH_TOKEN env var + 3. Claude Code credentials (~/.claude.json or ~/.claude/.credentials.json) — with automatic refresh if expired and a refresh token is available - 5. ANTHROPIC_API_KEY env var (regular API key, or legacy fallback) + 4. ANTHROPIC_API_KEY env var (regular API key, or legacy fallback) Returns the token string or None. """ - # 1. Hermes credential pool — the live source of truth for tokens - # minted via ``hermes login anthropic`` / the dashboard PKCE flow. - # ``select()`` picks the best available entry and refreshes it if - # it's near expiry, so callers always get a fresh token. - # - # Skip env-sourced pool entries (``env:ANTHROPIC_TOKEN``, etc.) — - # those are passthroughs of the env var, and the env-var branches - # below have richer priority logic (``_prefer_refreshable_claude_code_token``) - # that can upgrade a static env OAuth token to a refreshed - # Claude Code token. Letting the pool win here would short-circuit - # that upgrade. 
- try: - from agent.credential_pool import load_pool - pool = load_pool("anthropic") - entry = pool.select() - if entry and entry.access_token and not entry.source.startswith("env:"): - return entry.access_token - except Exception as exc: - # Pool lookup is best-effort — fall through to env/file sources - # if anything goes wrong (e.g. auth.json corruption during a - # concurrent write). - logger.debug("Credential-pool lookup failed for anthropic: %s", exc) - creds = read_claude_code_credentials() - # 2. Hermes-managed OAuth/setup token env var + # 1. Hermes-managed OAuth/setup token env var token = os.getenv("ANTHROPIC_TOKEN", "").strip() if token: preferred = _prefer_refreshable_claude_code_token(token, creds) @@ -871,7 +946,7 @@ def resolve_anthropic_token() -> Optional[str]: return preferred return token - # 3. CLAUDE_CODE_OAUTH_TOKEN (used by Claude Code for setup-tokens) + # 2. CLAUDE_CODE_OAUTH_TOKEN (used by Claude Code for setup-tokens) cc_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip() if cc_token: preferred = _prefer_refreshable_claude_code_token(cc_token, creds) @@ -879,12 +954,12 @@ def resolve_anthropic_token() -> Optional[str]: return preferred return cc_token - # 4. Claude Code credential file + # 3. Claude Code credential file resolved_claude_token = _resolve_claude_code_token_from_credentials(creds) if resolved_claude_token: return resolved_claude_token - # 5. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY. + # 4. Regular API key, or a legacy OAuth token saved in ANTHROPIC_API_KEY. # This remains as a compatibility fallback for pre-migration Hermes configs. api_key = os.getenv("ANTHROPIC_API_KEY", "").strip() if api_key: @@ -1112,9 +1187,12 @@ def normalize_model_name(model: str, preserve_dots: bool = False) -> str: # These must not be converted to hyphens. See issue #12295. 
if _is_bedrock_model_id(model): return model - # OpenRouter uses dots for version separators (claude-opus-4.6), - # Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens. - model = model.replace(".", "-") + # Only convert dots to hyphens for Anthropic/Claude models. + # Non-Anthropic models (gpt-5.4, gemini-2.5, etc.) use dots + # as part of their canonical names. See issue #17171. + _lower = model.lower() + if _lower.startswith("claude-") or _lower.startswith("anthropic/"): + model = model.replace(".", "-") return model @@ -1301,6 +1379,7 @@ def _convert_content_to_anthropic(content: Any) -> Any: def convert_messages_to_anthropic( messages: List[Dict], base_url: str | None = None, + model: str | None = None, ) -> Tuple[Optional[Any], List[Dict]]: """Convert OpenAI-format messages to Anthropic format. @@ -1312,6 +1391,12 @@ def convert_messages_to_anthropic( endpoint, all thinking block signatures are stripped. Signatures are Anthropic-proprietary — third-party endpoints cannot validate them and will reject them with HTTP 400 "Invalid signature in thinking block". + + When *model* is provided and matches the Kimi / Moonshot family (or + *base_url* is a Kimi / Moonshot host), unsigned thinking blocks + synthesised from ``reasoning_content`` are preserved on replayed + assistant tool-call messages — Kimi requires the field to exist, even + if empty. """ system = None result = [] @@ -1540,7 +1625,16 @@ def convert_messages_to_anthropic( # cache markers can interfere with signature validation. _THINKING_TYPES = frozenset(("thinking", "redacted_thinking")) _is_third_party = _is_third_party_anthropic_endpoint(base_url) - _is_kimi = _is_kimi_coding_endpoint(base_url) + # Kimi /coding and DeepSeek /anthropic share a contract: both speak the + # Anthropic Messages protocol upstream but require that thinking blocks + # synthesised from reasoning_content round-trip on subsequent turns when + # thinking is enabled. 
Signed Anthropic blocks still have to be stripped + # (neither endpoint can validate Anthropic's signatures); unsigned blocks + # are preserved. See hermes-agent#13848 (Kimi) and #16748 (DeepSeek). + _preserve_unsigned_thinking = ( + _is_kimi_family_endpoint(base_url, model) + or _is_deepseek_anthropic_endpoint(base_url) + ) last_assistant_idx = None for i in range(len(result) - 1, -1, -1): @@ -1552,22 +1646,22 @@ def convert_messages_to_anthropic( if m.get("role") != "assistant" or not isinstance(m.get("content"), list): continue - if _is_kimi: - # Kimi's /coding endpoint enables thinking server-side and - # requires unsigned thinking blocks on replayed assistant - # tool-call messages. Strip signed Anthropic blocks (Kimi - # can't validate signatures) but preserve the unsigned ones - # we synthesised from reasoning_content above. + if _preserve_unsigned_thinking: + # Kimi's /coding and DeepSeek's /anthropic endpoints both enable + # thinking server-side and require unsigned thinking blocks on + # replayed assistant tool-call messages. Strip signed Anthropic + # blocks (neither upstream can validate Anthropic signatures) but + # preserve the unsigned ones we synthesised from reasoning_content. new_content = [] for b in m["content"]: if not isinstance(b, dict) or b.get("type") not in _THINKING_TYPES: new_content.append(b) continue if b.get("signature") or b.get("data"): - # Anthropic-signed block — Kimi can't validate, strip + # Anthropic-signed block — upstream can't validate, strip continue # Unsigned thinking (synthesised from reasoning_content) — - # keep it: Kimi needs it for message-history validation. + # keep it: the upstream needs it for message-history validation. 
new_content.append(b) m["content"] = new_content or [{"type": "text", "text": "(empty)"}] elif _is_third_party or idx != last_assistant_idx: @@ -1624,6 +1718,7 @@ def build_anthropic_kwargs( context_length: Optional[int] = None, base_url: str | None = None, fast_mode: bool = False, + drop_context_1m_beta: bool = False, ) -> Dict[str, Any]: """Build kwargs for anthropic.messages.create(). @@ -1649,10 +1744,8 @@ def build_anthropic_kwargs( "max_tokens too large given prompt" errors and retry with a smaller cap (see parse_available_output_tokens_from_error + _ephemeral_max_output_tokens). - When *is_oauth* is True, enables the OAuth-only beta headers required by - Anthropic's subscription-gated Messages endpoint (fast-mode branch only; - the default headers are set by build_anthropic_client). No system-prompt - or tool-name rewriting is performed — Hermes identifies as itself. + When *is_oauth* is True, applies Claude Code compatibility transforms: + system prompt prefix, tool name prefixing, and prompt sanitization. When *preserve_dots* is True, model name dots are not converted to hyphens (for Alibaba/DashScope anthropic-compatible endpoints: qwen3.5-plus). @@ -1665,7 +1758,9 @@ def build_anthropic_kwargs( Currently only supported on native Anthropic endpoints (not third-party compatible ones). 
""" - system, anthropic_messages = convert_messages_to_anthropic(messages, base_url=base_url) + system, anthropic_messages = convert_messages_to_anthropic( + messages, base_url=base_url, model=model + ) anthropic_tools = convert_tools_to_anthropic(tools) if tools else [] model = normalize_model_name(model, preserve_dots=preserve_dots) @@ -1685,11 +1780,45 @@ def build_anthropic_kwargs( if context_length and effective_max_tokens > context_length: effective_max_tokens = max(context_length - 1, 1) - # OAuth requests go through Anthropic's subscription-gated Messages - # endpoint but otherwise send the real Hermes system prompt and real - # Hermes tool names — the only OAuth-specific wire differences are - # Bearer auth and the _OAUTH_ONLY_BETAS header (applied in - # build_anthropic_client and the fast-mode branch below). + # ── OAuth: Claude Code identity ────────────────────────────────── + if is_oauth: + # 1. Prepend Claude Code system prompt identity + cc_block = {"type": "text", "text": _CLAUDE_CODE_SYSTEM_PREFIX} + if isinstance(system, list): + system = [cc_block] + system + elif isinstance(system, str) and system: + system = [cc_block, {"type": "text", "text": system}] + else: + system = [cc_block] + + # 2. Sanitize system prompt — replace product name references + # to avoid Anthropic's server-side content filters. + for block in system: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + text = text.replace("Hermes Agent", "Claude Code") + text = text.replace("Hermes agent", "Claude Code") + text = text.replace("hermes-agent", "claude-code") + text = text.replace("Nous Research", "Anthropic") + block["text"] = text + + # 3. Prefix tool names with mcp_ (Claude Code convention) + if anthropic_tools: + for tool in anthropic_tools: + if "name" in tool: + tool["name"] = _MCP_TOOL_PREFIX + tool["name"] + + # 4. 
Prefix tool names in message history (tool_use and tool_result blocks) + for msg in anthropic_messages: + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict): + if block.get("type") == "tool_use" and "name" in block: + if not block["name"].startswith(_MCP_TOOL_PREFIX): + block["name"] = _MCP_TOOL_PREFIX + block["name"] + elif block.get("type") == "tool_result" and "tool_use_id" in block: + pass # tool_result uses ID, not name kwargs: Dict[str, Any] = { "model": model, @@ -1737,7 +1866,7 @@ def build_anthropic_kwargs( # silently hides reasoning text that Hermes surfaces in its CLI. We # request "summarized" so the reasoning blocks stay populated — matching # 4.6 behavior and preserving the activity-feed UX during long tool runs. - _is_kimi_coding = _is_kimi_coding_endpoint(base_url) + _is_kimi_coding = _is_kimi_family_endpoint(base_url, model) if reasoning_config and isinstance(reasoning_config, dict) and not _is_kimi_coding: if reasoning_config.get("enabled") is not False and "haiku" not in model.lower(): effort = str(reasoning_config.get("effort", "medium")).lower() @@ -1778,11 +1907,11 @@ def build_anthropic_kwargs( kwargs.setdefault("extra_body", {})["speed"] = "fast" # Build extra_headers with ALL applicable betas (the per-request # extra_headers override the client-level anthropic-beta header). - betas = list(_common_betas_for_base_url(base_url)) + betas = list(_common_betas_for_base_url( + base_url, + drop_context_1m_beta=drop_context_1m_beta, + )) if is_oauth: - # Strip context-1m — incompatible with OAuth auth. See matching - # comment in build_anthropic_client(). 
- betas = [b for b in betas if b != _CONTEXT_1M_BETA] betas.extend(_OAUTH_ONLY_BETAS) betas.append(_FAST_MODE_BETA) kwargs["extra_headers"] = {"anthropic-beta": ",".join(betas)} diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index a472ddbcfc..5d957ca869 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -5,11 +5,11 @@ session search, web extraction, vision analysis, browser vision) picks up the best available backend without duplicating fallback logic. Resolution order for text tasks (auto mode): - 1. OpenRouter (OPENROUTER_API_KEY) - 2. Nous Portal (~/.hermes/auth.json active provider) - 3. Custom endpoint (config.yaml model.base_url + OPENAI_API_KEY) - 4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex, - wrapped to look like a chat.completions client) + 1. User's main provider + main model (used regardless of provider type — + aggregators, direct API-key providers, native Anthropic, Codex, etc.) + 2. OpenRouter (OPENROUTER_API_KEY) + 3. Nous Portal (~/.hermes/auth.json active provider) + 4. Custom endpoint (config.yaml model.base_url + OPENAI_API_KEY) 5. Native Anthropic 6. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN) 7. None @@ -18,10 +18,16 @@ Resolution order for vision/multimodal tasks (auto mode): 1. Selected main provider, if it is one of the supported vision backends below 2. OpenRouter 3. Nous Portal - 4. Codex OAuth (gpt-5.3-codex supports vision via Responses API) - 5. Native Anthropic - 6. Custom endpoint (for local vision models: Qwen-VL, LLaVA, Pixtral, etc.) - 7. None + 4. Native Anthropic + 5. Custom endpoint (for local vision models: Qwen-VL, LLaVA, Pixtral, etc.) + 6. None + +Codex OAuth (ChatGPT-account auth) is intentionally NOT in either +fallback chain: OpenAI gates this endpoint behind an undocumented, +shifting model allow-list, so "just try Codex with a hardcoded model" +rots on its own. 
Codex is used only when the user's main provider *is* +openai-codex (Step 1 above) or when a caller explicitly requests it with +a model (auxiliary.{task}.provider + auxiliary.{task}.model). Per-task overrides are configured in config.yaml under the ``auxiliary:`` section (e.g. ``auxiliary.vision.provider``, ``auxiliary.compression.model``). @@ -101,6 +107,14 @@ from utils import base_url_host_matches, base_url_hostname, normalize_proxy_env_ logger = logging.getLogger(__name__) +def _safe_isinstance(obj: Any, maybe_type: Any) -> bool: + """Return False instead of raising when a patched symbol is not a type.""" + try: + return isinstance(obj, maybe_type) + except TypeError: + return False + + def _extract_url_query_params(url: str): """Extract query params from URL, return (clean_url, default_query dict or None).""" parsed = urlparse(url) @@ -210,6 +224,7 @@ _API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = { "kimi-coding-cn": "kimi-k2-turbo-preview", "gmi": "google/gemini-3.1-flash-lite-preview", "minimax": "MiniMax-M2.7", + "minimax-oauth": "MiniMax-M2.7-highspeed", "minimax-cn": "MiniMax-M2.7", "anthropic": "claude-haiku-4-5-20251001", "ai-gateway": "google/gemini-3-flash", @@ -229,6 +244,21 @@ _PROVIDER_VISION_MODELS: Dict[str, str] = { "zai": "glm-5v-turbo", } +# Providers whose endpoint does not accept image input, even though the +# provider's broader ecosystem has vision models available elsewhere. When +# `auxiliary.vision.provider: auto` sees one of these as the main provider, +# it must skip straight to the aggregator chain instead of returning a client +# that will 404 on every vision request. +# +# kimi-coding / kimi-coding-cn: the Kimi Coding Plan routes through +# api.kimi.com/coding (Anthropic Messages wire) which Kimi's own docs +# describe as having no image_in capability. Vision lives on the separate +# Kimi Platform (api.moonshot.ai, OpenAI-wire, pay-as-you-go). See #17076. 
+_PROVIDERS_WITHOUT_VISION: frozenset = frozenset({ + "kimi-coding", + "kimi-coding-cn", +}) + # OpenRouter app attribution headers _OR_HEADERS = { "HTTP-Referer": "https://hermes-agent.nousresearch.com", @@ -261,12 +291,14 @@ _NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1" _ANTHROPIC_DEFAULT_BASE_URL = "https://api.anthropic.com" _AUTH_JSON_PATH = get_hermes_home() / "auth.json" -# Codex fallback: uses the Responses API (the only endpoint the Codex -# OAuth token can access) with a fast model for auxiliary tasks. -# ChatGPT-backed Codex accounts currently reject gpt-5.3-codex for these -# auxiliary flows, while gpt-5.2-codex remains broadly available and supports -# vision via Responses. -_CODEX_AUX_MODEL = "gpt-5.2-codex" +# Codex OAuth endpoint used when a caller explicitly requests +# provider="openai-codex". There is deliberately no hardcoded default +# model: the set of models OpenAI accepts on this endpoint for +# ChatGPT-account auth is an undocumented, shifting allow-list, and +# pinning one here has drifted silently twice (gpt-5.3-codex → gpt-5.2-codex +# → gpt-5.4 over 6 weeks in early 2026). Callers must pass the model +# they want explicitly (from config.yaml model.model, auxiliary..model, +# or the user's active Codex model selection). _CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex" @@ -713,7 +745,9 @@ class _AnthropicCompletionsAdapter: response = self._client.messages.create(**anthropic_kwargs) _transport = get_transport("anthropic_messages") - _nr = _transport.normalize_response(response) + _nr = _transport.normalize_response( + response, strip_tool_prefix=self._is_oauth + ) # ToolCall already duck-types as OpenAI shape (.type, .function.name, # .function.arguments) via properties, so no wrapping needed. @@ -843,20 +877,20 @@ def _maybe_wrap_anthropic( - The ``anthropic`` SDK is not installed (falls back to OpenAI wire). """ # Already wrapped — don't double-wrap. 
- if isinstance(client_obj, AnthropicAuxiliaryClient): + if _safe_isinstance(client_obj, AnthropicAuxiliaryClient): return client_obj # Other specialized adapters we should never re-dispatch. - if isinstance(client_obj, CodexAuxiliaryClient): + if _safe_isinstance(client_obj, CodexAuxiliaryClient): return client_obj try: from agent.gemini_native_adapter import GeminiNativeClient - if isinstance(client_obj, GeminiNativeClient): + if _safe_isinstance(client_obj, GeminiNativeClient): return client_obj except ImportError: pass try: from agent.copilot_acp_client import CopilotACPClient - if isinstance(client_obj, CopilotACPClient): + if _safe_isinstance(client_obj, CopilotACPClient): return client_obj except ImportError: pass @@ -1394,7 +1428,23 @@ def _try_custom_endpoint() -> Tuple[Optional[Any], Optional[str]]: return _fallback_client, model -def _try_codex() -> Tuple[Optional[Any], Optional[str]]: +def _build_codex_client(model: str) -> Tuple[Optional[Any], Optional[str]]: + """Build a CodexAuxiliaryClient for an explicitly-requested model. + + There is no auto-selection of the Codex model: the ChatGPT-account + Codex endpoint's accepted model list is an undocumented, drifting + allow-list, so any hardcoded default we pick goes stale. The caller + is responsible for passing the model (e.g. from the user's own + ``model.model`` or ``auxiliary..model`` config). + + Returns (None, None) when no Codex OAuth token is available. + """ + if not model: + logger.warning( + "Auxiliary client: openai-codex requested without a model; " + "pass model explicitly (auxiliary..model in config.yaml)." 
+ ) + return None, None pool_present, entry = _select_pool_entry("openai-codex") if pool_present: codex_token = _pool_runtime_api_key(entry) @@ -1410,13 +1460,13 @@ def _try_codex() -> Tuple[Optional[Any], Optional[str]]: if not codex_token: return None, None base_url = _CODEX_AUX_BASE_URL - logger.debug("Auxiliary client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL) + logger.debug("Auxiliary client: Codex OAuth (%s via Responses API)", model) real_client = OpenAI( api_key=codex_token, base_url=base_url, default_headers=_codex_cloudflare_headers(codex_token), ) - return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL + return CodexAuxiliaryClient(real_client, model), model def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]: @@ -1471,7 +1521,6 @@ _AUTO_PROVIDER_LABELS = { "_try_openrouter": "openrouter", "_try_nous": "nous", "_try_custom_endpoint": "local/custom", - "_try_codex": "openai-codex", "_resolve_api_key_provider": "api-key", } @@ -1498,12 +1547,18 @@ def _get_provider_chain() -> List[tuple]: Built at call time (not module level) so that test patches on the ``_try_*`` functions are picked up correctly. + + NOTE: ``openai-codex`` is deliberately NOT in this chain. The + ChatGPT-account Codex endpoint only accepts a shifting, undocumented + allow-list of model IDs, so falling back to it with a guessed model + fails more often than not. Codex is used only when the user's main + provider *is* openai-codex (see Step 1 of ``_resolve_auto``) or when + a caller explicitly requests it with a model. 
""" return [ ("openrouter", _try_openrouter), ("nous", _try_nous), ("local/custom", _try_custom_endpoint), - ("openai-codex", _try_codex), ("api-key", _resolve_api_key_provider), ] @@ -2019,6 +2074,13 @@ def resolve_provider_client( # ── OpenAI Codex (OAuth → Responses API) ───────────────────────── if provider == "openai-codex": + if not model: + logger.warning( + "resolve_provider_client: openai-codex requested without a " + "model; pass model explicitly (e.g. model.model in config.yaml " + "or auxiliary..model for per-task aux routing)." + ) + return None, None if raw_codex: # Return the raw OpenAI client for callers that need direct # access to responses.stream() (e.g., the main agent loop). @@ -2027,7 +2089,7 @@ def resolve_provider_client( logger.warning("resolve_provider_client: openai-codex requested " "but no Codex OAuth token found (run: hermes model)") return None, None - final_model = _normalize_resolved_model(model or _CODEX_AUX_MODEL, provider) + final_model = _normalize_resolved_model(model, provider) raw_client = OpenAI( api_key=codex_token, base_url=_CODEX_AUX_BASE_URL, @@ -2035,7 +2097,7 @@ def resolve_provider_client( ) return (raw_client, final_model) # Standard path: wrap in CodexAuxiliaryClient adapter - client, default = _try_codex() + client, default = _build_codex_client(model) if client is None: logger.warning("resolve_provider_client: openai-codex requested " "but no Codex OAuth token found (run: hermes model)") @@ -2078,9 +2140,9 @@ def resolve_provider_client( client = _wrap_if_needed(client, final_model, custom_base, custom_key) return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode else (client, final_model)) - # Try custom first, then codex, then API-key providers - for try_fn in (_try_custom_endpoint, _try_codex, - _resolve_api_key_provider): + # Try custom first, then API-key providers (Codex excluded here: + # falling through to Codex with no model is a stale-constant trap). 
+ for try_fn in (_try_custom_endpoint, _resolve_api_key_provider): client, default = try_fn() if client is not None: final_model = _normalize_resolved_model(model or default, provider) @@ -2427,7 +2489,10 @@ def _resolve_strict_vision_backend( if provider == "nous": return _try_nous(vision=True) if provider == "openai-codex": - return _try_codex() + # Route through resolve_provider_client so the caller's explicit + # model is used. There is no safe default Codex model (shifting + # allow-list); callers must specify via auxiliary..model. + return resolve_provider_client("openai-codex", model, is_vision=True) if provider == "anthropic": return _try_anthropic() if provider == "custom": @@ -2532,6 +2597,19 @@ def resolve_vision_provider_client( main_provider, default_model or resolved_model or main_model, ) return _finalize(main_provider, sync_client, default_model) + elif main_provider in _PROVIDERS_WITHOUT_VISION: + # Kimi Coding Plan's /coding endpoint (Anthropic Messages wire) + # does not accept image input — Kimi's own docs say "Current + # model does not support image input, switch to a model with + # image_in capability" and vision lives on the separate Kimi + # Platform (api.moonshot.ai). Skip the main provider and fall + # through to the aggregator chain instead of returning a + # client that will 404 on every vision request (#17076). + logger.debug( + "Vision auto-detect: skipping main provider %s (no " + "vision support) — falling through to aggregator chain", + main_provider, + ) else: rpc_client, rpc_model = resolve_provider_client( main_provider, vision_model, @@ -3013,7 +3091,7 @@ def _get_task_extra_body(task: str) -> Dict[str, Any]: # Providers that use Anthropic-compatible endpoints (via OpenAI SDK wrapper). # Their image content blocks must use Anthropic format, not OpenAI format. 
-_ANTHROPIC_COMPAT_PROVIDERS = frozenset({"minimax", "minimax-cn"}) +_ANTHROPIC_COMPAT_PROVIDERS = frozenset({"minimax", "minimax-oauth", "minimax-cn"}) def _is_anthropic_compat_endpoint(provider: str, base_url: str) -> bool: diff --git a/agent/copilot_acp_client.py b/agent/copilot_acp_client.py index 94d40d2d97..027defa22b 100644 --- a/agent/copilot_acp_client.py +++ b/agent/copilot_acp_client.py @@ -608,7 +608,7 @@ class CopilotACPClient: end = start + limit if isinstance(limit, int) and limit > 0 else None content = "".join(lines[start:end]) if content: - content = redact_sensitive_text(content) + content = redact_sensitive_text(content, force=True) response = { "jsonrpc": "2.0", "id": message_id, diff --git a/agent/credential_pool.py b/agent/credential_pool.py index d11b0186e4..004b574988 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -1299,6 +1299,48 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup except Exception as exc: logger.debug("Qwen OAuth token seed failed: %s", exc) + elif provider == "minimax-oauth": + # MiniMax OAuth tokens live in ~/.hermes/auth.json providers.minimax-oauth. + # Seed the pool so `/auth list` reflects the logged-in state and the + # standard `hermes auth remove minimax-oauth ` flow works. + # Use refresh_if_expiring=False equivalent: resolve_minimax_oauth_runtime_credentials + # always refreshes on expiry, so instead read raw state here to avoid + # surprise network calls during provider discovery. 
+ try: + from hermes_cli.auth import get_provider_auth_state + state = get_provider_auth_state("minimax-oauth") + if state and state.get("access_token"): + source_name = "oauth" + if not _is_suppressed(provider, source_name): + active_sources.add(source_name) + expires_at_ms = None + try: + from datetime import datetime as _dt + raw = state.get("expires_at", "") + if raw: + expires_at_ms = int(_dt.fromisoformat(raw).timestamp() * 1000) + except Exception: + expires_at_ms = None + base_url = str(state.get("inference_base_url", "") or "").rstrip("/") + changed |= _upsert_entry( + entries, + provider, + source_name, + { + "source": source_name, + "auth_type": AUTH_TYPE_OAUTH, + "access_token": state["access_token"], + "refresh_token": state.get("refresh_token"), + "expires_at_ms": expires_at_ms, + "base_url": base_url, + "label": state.get("label", "") or label_from_token( + state.get("access_token", ""), source_name + ), + }, + ) + except Exception as exc: + logger.debug("MiniMax OAuth token seed failed: %s", exc) + elif provider == "openai-codex": # Respect user suppression — `hermes auth remove openai-codex` marks # the device_code source as suppressed so it won't be re-seeded from diff --git a/agent/credential_sources.py b/agent/credential_sources.py index dce6293526..7420491924 100644 --- a/agent/credential_sources.py +++ b/agent/credential_sources.py @@ -252,6 +252,19 @@ def _remove_nous_device_code(provider: str, removed) -> RemovalResult: return result +def _remove_minimax_oauth(provider: str, removed) -> RemovalResult: + """MiniMax OAuth lives in auth.json providers.minimax-oauth — clear it. + + Same pattern as Nous: single-source OAuth state with refresh tokens. + Suppression of the `oauth` source ensures the pool reseed path + (_seed_from_singletons) doesn't instantly undo the removal. 
+ """ + result = RemovalResult() + if _clear_auth_store_provider(provider): + result.cleaned.append(f"Cleared {provider} OAuth tokens from auth store") + return result + + def _remove_codex_device_code(provider: str, removed) -> RemovalResult: """Codex tokens live in TWO places: our auth store AND ~/.codex/auth.json. @@ -389,6 +402,11 @@ def _register_all_sources() -> None: remove_fn=_remove_qwen_cli, description="~/.qwen/oauth_creds.json", )) + register(RemovalStep( + provider="minimax-oauth", source_id="oauth", + remove_fn=_remove_minimax_oauth, + description="auth.json providers.minimax-oauth", + )) register(RemovalStep( provider="*", source_id="config:", match_fn=lambda src: src.startswith("config:") or src == "model_config", diff --git a/agent/curator.py b/agent/curator.py index 6858830aac..044f9904c1 100644 --- a/agent/curator.py +++ b/agent/curator.py @@ -28,7 +28,7 @@ import tempfile import threading from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, Set from hermes_constants import get_hermes_home from tools import skill_usage @@ -354,6 +354,218 @@ CURATOR_REVIEW_PROMPT = ( ) +# --------------------------------------------------------------------------- +# Per-run reports — {YYYYMMDD-HHMMSS}/run.json + REPORT.md under logs/curator/ +# --------------------------------------------------------------------------- + +def _reports_root() -> Path: + """Directory where curator run reports are written. + + Lives under the profile-aware logs dir (``~/.hermes/logs/curator/``) + alongside ``agent.log`` and ``gateway.log`` so it's found by anyone + looking for operational telemetry, not mixed in with the user's + authored skill data in ``~/.hermes/skills/``. 
+ """ + return get_hermes_home() / "logs" / "curator" + + +def _write_run_report( + *, + started_at: datetime, + elapsed_seconds: float, + auto_counts: Dict[str, int], + auto_summary: str, + before_report: List[Dict[str, Any]], + before_names: Set[str], + after_report: List[Dict[str, Any]], + llm_meta: Dict[str, Any], +) -> Optional[Path]: + """Write run.json + REPORT.md under logs/curator/{YYYYMMDD-HHMMSS}/. + + Returns the report directory path on success, None if the write + couldn't happen (caller logs and continues — reporting is best-effort). + """ + root = _reports_root() + try: + root.mkdir(parents=True, exist_ok=True) + except Exception as e: + logger.debug("Curator report dir create failed: %s", e) + return None + + stamp = started_at.strftime("%Y%m%d-%H%M%S") + run_dir = root / stamp + # If we crash-reran within the same second, append a disambiguator + suffix = 1 + while run_dir.exists(): + suffix += 1 + run_dir = root / f"{stamp}-{suffix}" + try: + run_dir.mkdir(parents=True, exist_ok=False) + except Exception as e: + logger.debug("Curator run dir create failed: %s", e) + return None + + # Diff before/after + after_by_name = {r.get("name"): r for r in after_report if isinstance(r, dict)} + after_names = set(after_by_name.keys()) + removed = sorted(before_names - after_names) # archived during this run + added = sorted(after_names - before_names) # new skills this run + before_by_name = {r.get("name"): r for r in before_report if isinstance(r, dict)} + + # State transitions between the two snapshots (e.g. 
active -> stale) + transitions: List[Dict[str, str]] = [] + for name in sorted(after_names & before_names): + s_before = (before_by_name.get(name) or {}).get("state") + s_after = (after_by_name.get(name) or {}).get("state") + if s_before and s_after and s_before != s_after: + transitions.append({"name": name, "from": s_before, "to": s_after}) + + # Classify LLM tool calls + tc_counts: Dict[str, int] = {} + for tc in llm_meta.get("tool_calls", []) or []: + name = tc.get("name", "unknown") + tc_counts[name] = tc_counts.get(name, 0) + 1 + + payload = { + "started_at": started_at.isoformat(), + "duration_seconds": round(elapsed_seconds, 2), + "model": llm_meta.get("model", ""), + "provider": llm_meta.get("provider", ""), + "auto_transitions": auto_counts, + "counts": { + "before": len(before_names), + "after": len(after_names), + "delta": len(after_names) - len(before_names), + "archived_this_run": len(removed), + "added_this_run": len(added), + "state_transitions": len(transitions), + "tool_calls_total": sum(tc_counts.values()), + }, + "tool_call_counts": tc_counts, + "archived": removed, + "added": added, + "state_transitions": transitions, + "llm_final": llm_meta.get("final", ""), + "llm_summary": llm_meta.get("summary", ""), + "llm_error": llm_meta.get("error"), + "tool_calls": llm_meta.get("tool_calls", []), + } + + # run.json — machine-readable, full fidelity + try: + (run_dir / "run.json").write_text( + json.dumps(payload, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + except Exception as e: + logger.debug("Curator run.json write failed: %s", e) + + # REPORT.md — human-readable + try: + md = _render_report_markdown(payload) + (run_dir / "REPORT.md").write_text(md, encoding="utf-8") + except Exception as e: + logger.debug("Curator REPORT.md write failed: %s", e) + + return run_dir + + +def _render_report_markdown(p: Dict[str, Any]) -> str: + """Render the human-readable report.""" + lines: List[str] = [] + started = p.get("started_at", "") + 
duration = p.get("duration_seconds", 0) or 0 + mins, secs = divmod(int(duration), 60) + dur_label = f"{mins}m {secs}s" if mins else f"{secs}s" + + lines.append(f"# Curator run — {started}\n") + model = p.get("model") or "(not resolved)" + prov = p.get("provider") or "(not resolved)" + counts = p.get("counts") or {} + lines.append( + f"Model: `{model}` via `{prov}` · Duration: {dur_label} · " + f"Agent-created skills: {counts.get('before', 0)} → {counts.get('after', 0)} " + f"({counts.get('delta', 0):+d})\n" + ) + + error = p.get("llm_error") + if error: + lines.append(f"> ⚠ LLM pass error: `{error}`\n") + + # Auto-transitions (pure, no LLM) + auto = p.get("auto_transitions") or {} + lines.append("## Auto-transitions (pure, no LLM)\n") + lines.append(f"- checked: {auto.get('checked', 0)}") + lines.append(f"- marked stale: {auto.get('marked_stale', 0)}") + lines.append(f"- archived: {auto.get('archived', 0)}") + lines.append(f"- reactivated: {auto.get('reactivated', 0)}") + lines.append("") + + # LLM pass numbers + tc_counts = p.get("tool_call_counts") or {} + lines.append("## LLM consolidation pass\n") + lines.append(f"- tool calls: **{counts.get('tool_calls_total', 0)}** " + f"(by name: {', '.join(f'{k}={v}' for k, v in sorted(tc_counts.items())) or 'none'})") + lines.append(f"- archived this run: **{counts.get('archived_this_run', 0)}**") + lines.append(f"- new skills this run: **{counts.get('added_this_run', 0)}**") + lines.append(f"- state transitions (active ↔ stale ↔ archived): " + f"**{counts.get('state_transitions', 0)}**") + lines.append("") + + # Archived list + archived = p.get("archived") or [] + if archived: + lines.append(f"### Skills archived ({len(archived)})\n") + lines.append("_Archived skills are at `~/.hermes/skills/.archive/`. 
" + "Restore any via `hermes curator restore `._\n") + # Show first 50 inline, note truncation after that + SHOW = 50 + for n in archived[:SHOW]: + lines.append(f"- `{n}`") + if len(archived) > SHOW: + lines.append(f"- … and {len(archived) - SHOW} more (see `run.json` for the full list)") + lines.append("") + + # Added list + added = p.get("added") or [] + if added: + lines.append(f"### New skills this run ({len(added)})\n") + lines.append("_Usually these are new class-level umbrellas created via `skill_manage action=create`._\n") + for n in added: + lines.append(f"- `{n}`") + lines.append("") + + # State transitions + trans = p.get("state_transitions") or [] + if trans: + lines.append(f"### State transitions ({len(trans)})\n") + for t in trans: + lines.append(f"- `{t.get('name')}`: {t.get('from')} → {t.get('to')}") + lines.append("") + + # Full LLM final response + final = (p.get("llm_final") or "").strip() + if final: + lines.append("## LLM final summary\n") + lines.append(final) + lines.append("") + elif not error: + llm_sum = p.get("llm_summary") or "" + if llm_sum: + lines.append("## LLM summary\n") + lines.append(llm_sum) + lines.append("") + + # Recovery footer + lines.append("## Recovery\n") + lines.append("- Restore an archived skill: `hermes curator restore `") + lines.append("- All archives live under `~/.hermes/skills/.archive/` and are recoverable by `mv`") + lines.append("- See `run.json` in this directory for the full machine-readable record.") + lines.append("") + + return "\n".join(lines) + + # --------------------------------------------------------------------------- # Orchestrator — spawn a forked AIAgent for the LLM review pass # --------------------------------------------------------------------------- @@ -415,22 +627,72 @@ def run_curator_review( def _llm_pass(): nonlocal auto_summary + # Snapshot skill state BEFORE the LLM pass so the report can diff. 
+ try: + before_report = skill_usage.agent_created_report() + except Exception: + before_report = [] + before_names = {r.get("name") for r in before_report if isinstance(r, dict)} + + llm_meta: Dict[str, Any] = {} try: candidate_list = _render_candidate_list() if "No agent-created skills" in candidate_list: final_summary = f"auto: {auto_summary}; llm: skipped (no candidates)" + llm_meta = { + "final": "", + "summary": "skipped (no candidates)", + "model": "", + "provider": "", + "tool_calls": [], + "error": None, + } else: prompt = f"{CURATOR_REVIEW_PROMPT}\n\n{candidate_list}" - llm_summary = _run_llm_review(prompt) - final_summary = f"auto: {auto_summary}; llm: {llm_summary}" + llm_meta = _run_llm_review(prompt) + final_summary = ( + f"auto: {auto_summary}; llm: {llm_meta.get('summary', 'no change')}" + ) except Exception as e: logger.debug("Curator LLM pass failed: %s", e, exc_info=True) final_summary = f"auto: {auto_summary}; llm: error ({e})" + llm_meta = { + "final": "", + "summary": f"error ({e})", + "model": "", + "provider": "", + "tool_calls": [], + "error": str(e), + } elapsed = (datetime.now(timezone.utc) - start).total_seconds() state2 = load_state() state2["last_run_duration_seconds"] = elapsed state2["last_run_summary"] = final_summary + + # Write the per-run report. Runs in a best-effort try so a + # reporting bug never breaks the curator itself. Report path is + # recorded in state so `hermes curator status` can point at it. 
+ try: + after_report = skill_usage.agent_created_report() + except Exception: + after_report = [] + try: + report_path = _write_run_report( + started_at=start, + elapsed_seconds=elapsed, + auto_counts=counts, + auto_summary=auto_summary, + before_report=before_report, + before_names=before_names, + after_report=after_report, + llm_meta=llm_meta, + ) + if report_path is not None: + state2["last_report_path"] = str(report_path) + except Exception as e: + logger.debug("Curator report write failed: %s", e, exc_info=True) + save_state(state2) if on_summary: @@ -452,14 +714,77 @@ def run_curator_review( } -def _run_llm_review(prompt: str) -> str: - """Spawn an AIAgent fork to run the curator review prompt. Returns a short - summary of what the model said in its final response.""" +def _resolve_review_model(cfg: Dict[str, Any]) -> tuple[str, str]: + """Pick (provider, model) for the curator review fork. + + Curator is a regular auxiliary task slot — ``auxiliary.curator.{provider,model}`` + — so it participates in the canonical aux-model plumbing (``hermes model`` → + auxiliary picker, the dashboard Models tab, ``auxiliary.curator.{timeout, + base_url,api_key,extra_body}``). ``provider: "auto"`` with an empty model + means "use the main chat model" — same default as every other aux task. + + Legacy fallback: users who configured ``curator.auxiliary.{provider,model}`` + under the previous one-off schema still work. Precedence: + 1. ``auxiliary.curator.{provider,model}`` when both are set non-auto + 2. Legacy ``curator.auxiliary.{provider,model}`` when both are set + 3. Main ``model.{provider,default/model}`` pair + """ + _main = cfg.get("model", {}) if isinstance(cfg.get("model"), dict) else {} + _main_provider = _main.get("provider") or "auto" + _main_model = _main.get("default") or _main.get("model") or "" + + # 1. 
Canonical aux task slot + _aux = cfg.get("auxiliary", {}) if isinstance(cfg.get("auxiliary"), dict) else {} + _cur_task = _aux.get("curator", {}) if isinstance(_aux.get("curator"), dict) else {} + _task_provider = (_cur_task.get("provider") or "").strip() or None + _task_model = (_cur_task.get("model") or "").strip() or None + if _task_provider and _task_provider != "auto" and _task_model: + return _task_provider, _task_model + + # 2. Legacy curator.auxiliary.{provider,model} (deprecated, pre-unification) + _cur = cfg.get("curator", {}) if isinstance(cfg.get("curator"), dict) else {} + _legacy = _cur.get("auxiliary", {}) if isinstance(_cur.get("auxiliary"), dict) else {} + _legacy_provider = _legacy.get("provider") or None + _legacy_model = _legacy.get("model") or None + if _legacy_provider and _legacy_model: + logger.info( + "curator: using deprecated curator.auxiliary.{provider,model} " + "config — please migrate to auxiliary.curator.{provider,model}" + ) + return _legacy_provider, _legacy_model + + # 3. Fall through to the main chat model + return _main_provider, _main_model + + +def _run_llm_review(prompt: str) -> Dict[str, Any]: + """Spawn an AIAgent fork to run the curator review prompt. + + Returns a dict with: + - final: full (untruncated) final response from the reviewer + - summary: short summary suitable for state file (240-char cap) + - model, provider: what the fork actually ran on + - tool_calls: list of {name, arguments} for every tool call made during + the pass (arguments may be truncated for readability) + - error: set if the pass failed mid-run; final/summary may still be empty + + Never raises; callers get a structured failure instead. 
+ """ import contextlib + result_meta: Dict[str, Any] = { + "final": "", + "summary": "", + "model": "", + "provider": "", + "tool_calls": [], + "error": None, + } try: from run_agent import AIAgent except Exception as e: - return f"AIAgent import failed: {e}" + result_meta["error"] = f"AIAgent import failed: {e}" + result_meta["summary"] = result_meta["error"] + return result_meta # Resolve provider + model the same way the CLI does, so the curator # fork inherits the user's active main config rather than falling @@ -467,6 +792,11 @@ def _run_llm_review(prompt: str) -> str: # "No models provided"). AIAgent() without explicit provider/model # arguments hits an auto-resolution path that fails for OAuth-only # providers and for pool-backed credentials. + # + # `_resolve_review_model()` honors `auxiliary.curator.{provider,model}` + # (canonical aux-task slot, wired through `hermes model` → auxiliary + # picker and the dashboard Models tab), with a legacy fallback to + # `curator.auxiliary.{provider,model}`. See docs/user-guide/features/curator.md. 
_api_key = None _base_url = None _api_mode = None @@ -476,9 +806,7 @@ def _run_llm_review(prompt: str) -> str: from hermes_cli.config import load_config from hermes_cli.runtime_provider import resolve_runtime_provider _cfg = load_config() - _m = _cfg.get("model", {}) if isinstance(_cfg.get("model"), dict) else {} - _provider = _m.get("provider") or "auto" - _model_name = _m.get("default") or _m.get("model") or "" + _provider, _model_name = _resolve_review_model(_cfg) _rp = resolve_runtime_provider( requested=_provider, target_model=_model_name ) @@ -489,6 +817,9 @@ def _run_llm_review(prompt: str) -> str: except Exception as e: logger.debug("Curator provider resolution failed: %s", e, exc_info=True) + result_meta["model"] = _model_name + result_meta["provider"] = _resolved_provider or "" + review_agent = None try: review_agent = AIAgent( @@ -520,20 +851,43 @@ def _run_llm_review(prompt: str) -> str: with open(os.devnull, "w") as _devnull, \ contextlib.redirect_stdout(_devnull), \ contextlib.redirect_stderr(_devnull): - result = review_agent.run_conversation(user_message=prompt) + conv_result = review_agent.run_conversation(user_message=prompt) final = "" - if isinstance(result, dict): - final = str(result.get("final_response") or "").strip() - return (final[:240] + "…") if len(final) > 240 else (final or "no change") + if isinstance(conv_result, dict): + final = str(conv_result.get("final_response") or "").strip() + result_meta["final"] = final + result_meta["summary"] = (final[:240] + "…") if len(final) > 240 else (final or "no change") + + # Collect tool calls for the report. Walk the forked agent's + # session messages and extract every tool_call made during the + # pass. Truncate argument payloads so a giant skill_manage create + # doesn't blow up the report. 
+ _calls: List[Dict[str, Any]] = [] + for msg in getattr(review_agent, "_session_messages", []) or []: + if not isinstance(msg, dict): + continue + tcs = msg.get("tool_calls") or [] + for tc in tcs: + if not isinstance(tc, dict): + continue + fn = tc.get("function") or {} + name = fn.get("name") or "" + args_raw = fn.get("arguments") or "" + if isinstance(args_raw, str) and len(args_raw) > 400: + args_raw = args_raw[:400] + "…" + _calls.append({"name": name, "arguments": args_raw}) + result_meta["tool_calls"] = _calls except Exception as e: - return f"error: {e}" + result_meta["error"] = f"error: {e}" + result_meta["summary"] = result_meta["error"] finally: if review_agent is not None: try: review_agent.close() except Exception: pass + return result_meta # --------------------------------------------------------------------------- diff --git a/agent/error_classifier.py b/agent/error_classifier.py index 511ab353c0..86e99ec1ac 100644 --- a/agent/error_classifier.py +++ b/agent/error_classifier.py @@ -54,6 +54,7 @@ class FailoverReason(enum.Enum): # Provider-specific thinking_signature = "thinking_signature" # Anthropic thinking block sig invalid long_context_tier = "long_context_tier" # Anthropic "extra usage" tier gate + oauth_long_context_beta_forbidden = "oauth_long_context_beta_forbidden" # Anthropic OAuth subscription rejects 1M context beta — disable beta and retry # Catch-all unknown = "unknown" # Unclassifiable — retry with backoff @@ -450,6 +451,25 @@ def classify_api_error( should_compress=True, ) + # Anthropic OAuth subscription rejects the 1M-context beta header. + # Observed error body: "The long context beta is not yet available for + # this subscription." Returned as HTTP 400 from native Anthropic when + # the subscription doesn't include 1M context, even though the request + # carries ``anthropic-beta: context-1m-2025-08-07``. The recovery path + # in run_agent.py rebuilds the Anthropic client with the beta stripped + # and retries once. 
Pattern is narrow enough that it won't collide with + # the 429 tier-gate pattern above (different status, different phrase). + if ( + status_code == 400 + and "long context beta" in error_msg + and "not yet available" in error_msg + ): + return _result( + FailoverReason.oauth_long_context_beta_forbidden, + retryable=True, + should_compress=False, + ) + # ── 2. HTTP status code classification ────────────────────────── if status_code is not None: diff --git a/agent/memory_manager.py b/agent/memory_manager.py index 2831eb7bf8..ea9b7425fc 100644 --- a/agent/memory_manager.py +++ b/agent/memory_manager.py @@ -402,6 +402,41 @@ class MemoryManager: provider.name, e, ) + def on_session_switch( + self, + new_session_id: str, + *, + parent_session_id: str = "", + reset: bool = False, + **kwargs, + ) -> None: + """Notify all providers that the agent's session_id has rotated. + + Fires on ``/resume``, ``/branch``, ``/reset``, ``/new``, and + context compression — any path that reassigns + ``AIAgent.session_id`` without tearing the provider down. + + Providers keep running; they only need to refresh cached + per-session state so subsequent writes land in the correct + session's record. See ``MemoryProvider.on_session_switch`` for + the full contract. + """ + if not new_session_id: + return + for provider in self._providers: + try: + provider.on_session_switch( + new_session_id, + parent_session_id=parent_session_id, + reset=reset, + **kwargs, + ) + except Exception as e: + logger.debug( + "Memory provider '%s' on_session_switch failed: %s", + provider.name, e, + ) + def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str: """Notify all providers before context compression. 
diff --git a/agent/memory_provider.py b/agent/memory_provider.py index 535338f4ee..1c8dbaf682 100644 --- a/agent/memory_provider.py +++ b/agent/memory_provider.py @@ -25,6 +25,7 @@ Lifecycle (called by MemoryManager, wired in run_agent.py): Optional hooks (override to opt in): on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context on_session_end(messages) — end-of-session extraction + on_session_switch(new_session_id, **kwargs) — mid-process session_id rotation on_pre_compress(messages) -> str — extract before context compression on_memory_write(action, target, content, metadata=None) — mirror built-in memory writes on_delegation(task, result, **kwargs) — parent-side observation of subagent work @@ -160,6 +161,45 @@ class MemoryProvider(ABC): (CLI exit, /reset, gateway session expiry). """ + def on_session_switch( + self, + new_session_id: str, + *, + parent_session_id: str = "", + reset: bool = False, + **kwargs, + ) -> None: + """Called when the agent switches session_id mid-process. + + Fires on ``/resume``, ``/branch``, ``/reset``, ``/new`` (CLI), the + gateway equivalents, and context compression — any path that + reassigns ``AIAgent.session_id`` without tearing the provider down. + + Providers that cache per-session state in ``initialize()`` + (``_session_id``, ``_document_id``, accumulated turn buffers, + counters) should update or reset that state here so subsequent + writes land in the correct session's record. + + Parameters + ---------- + new_session_id: + The session_id the agent just switched to. + parent_session_id: + The previous session_id, if meaningful — set for ``/branch`` + (fork lineage), context compression (continuation lineage), + and ``/resume`` (the session we're leaving). Empty string + when no lineage applies. + reset: + ``True`` when this is a genuinely new conversation, not a + resumption of an existing one. Fired by ``/reset`` / ``/new``. 
+ Providers should flush accumulated per-session buffers + (``_session_turns``, ``_turn_counter``, etc.) when this is + set. ``False`` for ``/resume`` / ``/branch`` / compression + where the logical conversation continues under the new id. + + Default is no-op for backward compatibility. + """ + def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str: """Called before context compression discards old messages. diff --git a/agent/model_metadata.py b/agent/model_metadata.py index afd8bee192..cca842f6b0 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -46,7 +46,7 @@ def _resolve_requests_verify() -> bool | str: # are preserved so the full model name reaches cache lookups and server queries. _PROVIDER_PREFIXES: frozenset[str] = frozenset({ "openrouter", "nous", "openai-codex", "copilot", "copilot-acp", - "gemini", "ollama-cloud", "zai", "kimi-coding", "kimi-coding-cn", "stepfun", "minimax", "minimax-cn", "anthropic", "deepseek", + "gemini", "ollama-cloud", "zai", "kimi-coding", "kimi-coding-cn", "stepfun", "minimax", "minimax-oauth", "minimax-cn", "anthropic", "deepseek", "opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba", "qwen-oauth", "xiaomi", diff --git a/agent/models_dev.py b/agent/models_dev.py index 236dd582f9..79cfa90ca9 100644 --- a/agent/models_dev.py +++ b/agent/models_dev.py @@ -149,6 +149,7 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = { "stepfun": "stepfun", "kimi-coding-cn": "kimi-for-coding", "minimax": "minimax", + "minimax-oauth": "minimax", "minimax-cn": "minimax-cn", "deepseek": "deepseek", "alibaba": "alibaba", diff --git a/agent/onboarding.py b/agent/onboarding.py index cf66bad108..220b1c6052 100644 --- a/agent/onboarding.py +++ b/agent/onboarding.py @@ -98,17 +98,19 @@ def tool_progress_hint_cli() -> str: def openclaw_residue_hint_cli() -> str: """Banner shown the first time Hermes starts and finds ``~/.openclaw/``. 
- OpenClaw-era config, memory, and skill paths in ``~/.openclaw/`` will - otherwise attract the agent (memory entries like ``~/.openclaw/config.yaml`` - get carried forward and the agent dutifully reads them). ``hermes claw - cleanup`` renames the directory so the agent stops finding it. + Points users at ``hermes claw migrate`` (non-destructive port of config, + memory, and skills) first. ``hermes claw cleanup`` is mentioned as the + follow-up step for users who have already migrated and want to archive + the old directory — with a warning that archiving breaks OpenClaw. """ return ( - "Heads up — an OpenClaw workspace was detected at ~/.openclaw/.\n" - "After migrating, the agent can still get confused and read that " - "directory's config/memory instead of Hermes's.\n" - "Run `hermes claw cleanup` to archive it (rename → .openclaw.pre-migration). " - "This tip only shows once; rerun it any time with `hermes claw cleanup`." + "A legacy OpenClaw directory was detected at ~/.openclaw/.\n" + "To port your config, memory, and skills over to Hermes, run " + "`hermes claw migrate`.\n" + "If you've already migrated and want to archive the old directory, " + "run `hermes claw cleanup` (renames it to ~/.openclaw.pre-migration — " + "OpenClaw will stop working after this).\n" + "This tip only shows once." ) diff --git a/agent/redact.py b/agent/redact.py index 0a66502c75..970ad5adfb 100644 --- a/agent/redact.py +++ b/agent/redact.py @@ -305,11 +305,13 @@ def _redact_form_body(text: str) -> str: return _redact_query_string(text.strip()) -def redact_sensitive_text(text: str) -> str: +def redact_sensitive_text(text: str, *, force: bool = False) -> str: """Apply all redaction patterns to a block of text. Safe to call on any string -- non-matching text passes through unchanged. Disabled by default — enable via security.redact_secrets: true in config.yaml. 
+ Set force=True for safety boundaries that must never return raw secrets + regardless of the user's global logging redaction preference. """ if text is None: return None @@ -317,7 +319,7 @@ def redact_sensitive_text(text: str) -> str: text = str(text) if not text: return text - if not _REDACT_ENABLED: + if not (force or _REDACT_ENABLED): return text # Known prefixes (sk-, ghp_, etc.) diff --git a/agent/skill_commands.py b/agent/skill_commands.py index 19c9b06c6c..fff29eff66 100644 --- a/agent/skill_commands.py +++ b/agent/skill_commands.py @@ -284,6 +284,71 @@ def get_skill_commands() -> Dict[str, Dict[str, Any]]: return _skill_commands +def reload_skills() -> Dict[str, Any]: + """Re-scan the skills directory and return a diff of what changed. + + Rescans ``~/.hermes/skills/`` and any ``skills.external_dirs`` so the + slash-command map (``agent.skill_commands._skill_commands``) reflects + skills added or removed on disk. + + This does NOT invalidate the skills system-prompt cache. Skills are + called by name via ``/skill-name``, ``skills_list``, or ``skill_view`` + — they don't need to be in the system prompt for the model to use them. + Keeping the prompt cache intact preserves prefix caching across the + reload, so a user invoking ``/reload-skills`` pays no cache-reset cost. + + Returns: + Dict with keys:: + + { + "added": [{"name": str, "description": str}, ...], + "removed": [{"name": str, "description": str}, ...], + "unchanged": [skill names present before and after], + "total": total skill count after rescan, + "commands": total /slash-skill count after rescan, + } + + ``description`` is the skill's full SKILL.md frontmatter + ``description:`` field — the same string the system prompt renders + as `` - name: description`` for pre-existing skills. + """ + # Snapshot pre-reload state (name -> description) from the current + # slash-command cache. 
Using dicts lets the post-rescan diff carry + # descriptions for newly-visible or just-removed skills without a + # second disk walk. + def _snapshot(cmds: Dict[str, Dict[str, Any]]) -> Dict[str, str]: + out: Dict[str, str] = {} + for slash_key, info in cmds.items(): + bare = slash_key.lstrip("/") + out[bare] = (info or {}).get("description") or "" + return out + + before = _snapshot(_skill_commands) + + # Rescan the skills dir. ``scan_skill_commands`` resets + # ``_skill_commands = {}`` internally and repopulates it. + new_commands = scan_skill_commands() + + after = _snapshot(new_commands) + + added_names = sorted(set(after) - set(before)) + removed_names = sorted(set(before) - set(after)) + unchanged = sorted(set(after) & set(before)) + + added = [{"name": n, "description": after[n]} for n in added_names] + # For removed skills, use the description we had cached pre-rescan + # (the skill file is gone so we can't re-read it). + removed = [{"name": n, "description": before[n]} for n in removed_names] + + return { + "added": added, + "removed": removed, + "unchanged": unchanged, + "total": len(after), + "commands": len(new_commands), + } + + def resolve_skill_command_key(command: str) -> Optional[str]: """Resolve a user-typed /command to its canonical skill_cmds key. 
diff --git a/agent/transports/anthropic.py b/agent/transports/anthropic.py index 5ecc8a29df..72024ac20f 100644 --- a/agent/transports/anthropic.py +++ b/agent/transports/anthropic.py @@ -58,6 +58,7 @@ class AnthropicTransport(ProviderTransport): context_length: int | None base_url: str | None fast_mode: bool + drop_context_1m_beta: bool """ from agent.anthropic_adapter import build_anthropic_kwargs @@ -73,6 +74,7 @@ class AnthropicTransport(ProviderTransport): context_length=params.get("context_length"), base_url=params.get("base_url"), fast_mode=params.get("fast_mode", False), + drop_context_1m_beta=params.get("drop_context_1m_beta", False), ) def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse: @@ -85,6 +87,9 @@ class AnthropicTransport(ProviderTransport): from agent.anthropic_adapter import _to_plain_data from agent.transports.types import ToolCall + strip_tool_prefix = kwargs.get("strip_tool_prefix", False) + _MCP_PREFIX = "mcp_" + text_parts = [] reasoning_parts = [] reasoning_details = [] @@ -99,10 +104,13 @@ class AnthropicTransport(ProviderTransport): if isinstance(block_dict, dict): reasoning_details.append(block_dict) elif block.type == "tool_use": + name = block.name + if strip_tool_prefix and name.startswith(_MCP_PREFIX): + name = name[len(_MCP_PREFIX):] tool_calls.append( ToolCall( id=block.id, - name=block.name, + name=name, arguments=json.dumps(block.input), ) ) diff --git a/agent/transports/chat_completions.py b/agent/transports/chat_completions.py index 6206e325cf..c43611f85f 100644 --- a/agent/transports/chat_completions.py +++ b/agent/transports/chat_completions.py @@ -20,12 +20,7 @@ from agent.transports.types import NormalizedResponse, ToolCall, Usage def _build_gemini_thinking_config(model: str, reasoning_config: dict | None) -> dict | None: - """Translate Hermes/OpenRouter-style reasoning config to Gemini thinkingConfig. - - Gemini native/cloud-code adapters do not read ``extra_body.reasoning``. 
- They only inspect ``extra_body.thinking_config`` / ``thinkingConfig`` and - then request thought parts with ``includeThoughts`` enabled. - """ + """Translate Hermes/OpenRouter-style reasoning config to Gemini thinkingConfig.""" if reasoning_config is None or not isinstance(reasoning_config, dict): return None @@ -71,6 +66,30 @@ def _build_gemini_thinking_config(model: str, reasoning_config: dict | None) -> return thinking_config +def _snake_case_gemini_thinking_config(config: dict | None) -> dict | None: + """Convert Gemini thinking config keys to the OpenAI-compat field names.""" + if not isinstance(config, dict) or not config: + return None + + translated: Dict[str, Any] = {} + if isinstance(config.get("includeThoughts"), bool): + translated["include_thoughts"] = config["includeThoughts"] + if isinstance(config.get("thinkingLevel"), str) and config["thinkingLevel"].strip(): + translated["thinking_level"] = config["thinkingLevel"].strip().lower() + if isinstance(config.get("thinkingBudget"), (int, float)): + translated["thinking_budget"] = int(config["thinkingBudget"]) + return translated or None + + +def _is_gemini_openai_compat_base_url(base_url: Any) -> bool: + normalized = str(base_url or "").strip().rstrip("/").lower() + if not normalized: + return False + if "generativelanguage.googleapis.com" not in normalized: + return False + return normalized.endswith("/openai") + + class ChatCompletionsTransport(ProviderTransport): """Transport for api_mode='chat_completions'. 
@@ -309,6 +328,7 @@ class ChatCompletionsTransport(ProviderTransport): is_nous = params.get("is_nous", False) is_github_models = params.get("is_github_models", False) provider_name = str(params.get("provider_name") or "").strip().lower() + base_url = params.get("base_url") provider_prefs = params.get("provider_preferences") if provider_prefs and is_openrouter: @@ -362,7 +382,19 @@ class ChatCompletionsTransport(ProviderTransport): if is_qwen: extra_body["vl_high_resolution_images"] = True - if provider_name in {"gemini", "google-gemini-cli"}: + if provider_name == "gemini": + raw_thinking_config = _build_gemini_thinking_config(model, reasoning_config) + if _is_gemini_openai_compat_base_url(base_url): + thinking_config = _snake_case_gemini_thinking_config(raw_thinking_config) + if thinking_config: + openai_compat_extra = extra_body.get("extra_body", {}) + google_extra = openai_compat_extra.get("google", {}) + google_extra["thinking_config"] = thinking_config + openai_compat_extra["google"] = google_extra + extra_body["extra_body"] = openai_compat_extra + elif raw_thinking_config: + extra_body["thinking_config"] = raw_thinking_config + elif provider_name == "google-gemini-cli": thinking_config = _build_gemini_thinking_config(model, reasoning_config) if thinking_config: extra_body["thinking_config"] = thinking_config diff --git a/agent/usage_pricing.py b/agent/usage_pricing.py index 1dfe59ea32..746f962097 100644 --- a/agent/usage_pricing.py +++ b/agent/usage_pricing.py @@ -359,6 +359,25 @@ _OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = { source_url="https://aws.amazon.com/bedrock/pricing/", pricing_version="bedrock-pricing-2026-04", ), + # MiniMax + ( + "minimax", + "minimax-m2.7", + ): PricingEntry( + input_cost_per_million=Decimal("0.30"), + output_cost_per_million=Decimal("1.20"), + source="official_docs_snapshot", + pricing_version="minimax-pricing-2026-04", + ), + ( + "minimax-cn", + "minimax-m2.7", + ): PricingEntry( + 
input_cost_per_million=Decimal("0.30"), + output_cost_per_million=Decimal("1.20"), + source="official_docs_snapshot", + pricing_version="minimax-pricing-2026-04", + ), } @@ -400,6 +419,8 @@ def resolve_billing_route( return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") if provider_name == "openai": return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") + if provider_name in {"minimax", "minimax-cn"}: + return BillingRoute(provider=provider_name, model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") if provider_name in {"custom", "local"} or (base and "localhost" in base): return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown") return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown") diff --git a/cli-config.yaml.example b/cli-config.yaml.example index b2a6868604..e292498b0c 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -180,6 +180,11 @@ terminal: # lifetime_seconds: 300 # docker_image: "nikolaik/python-nodejs:python3.11-nodejs20" # docker_mount_cwd_to_workspace: true # Explicit opt-in: mount your launch cwd into /workspace +# # Optional: run the container as your host user's uid:gid so files written +# # into bind-mounted dirs are owned by you, not root. Drops SETUID/SETGID +# # caps too since no gosu privilege drop is needed. Leave off if your +# # chosen docker_image expects to start as root. +# docker_run_as_host_user: true # # Optional: explicitly forward selected env vars into Docker. # # These values come from your current shell first, then ~/.hermes/.env. # # Warning: anything forwarded here is visible to commands run in the container. 
@@ -565,7 +570,7 @@ agent: # - A preset like "hermes-cli" or "hermes-telegram" (curated tool set) # - A list of individual toolsets to compose your own (see list below) # -# Supported platform keys: cli, telegram, discord, whatsapp, slack, qqbot +# Supported platform keys: cli, telegram, discord, whatsapp, slack, qqbot, teams # # Examples: # @@ -595,6 +600,7 @@ agent: # signal: hermes-signal (same as telegram) # homeassistant: hermes-homeassistant (same as telegram) # qqbot: hermes-qqbot (same as telegram) +# teams: hermes-teams (same as telegram) # platform_toolsets: cli: [hermes-cli] @@ -606,6 +612,7 @@ platform_toolsets: homeassistant: [hermes-homeassistant] qqbot: [hermes-qqbot] yuanbao: [hermes-yuanbao] + teams: [hermes-teams] # ============================================================================= # Gateway Platform Settings diff --git a/cli.py b/cli.py index 12b73664ec..1d0285a574 100644 --- a/cli.py +++ b/cli.py @@ -497,18 +497,20 @@ def load_cli_config() -> Dict[str, Any]: "singularity_image": "TERMINAL_SINGULARITY_IMAGE", "modal_image": "TERMINAL_MODAL_IMAGE", "daytona_image": "TERMINAL_DAYTONA_IMAGE", + "vercel_runtime": "TERMINAL_VERCEL_RUNTIME", # SSH config "ssh_host": "TERMINAL_SSH_HOST", "ssh_user": "TERMINAL_SSH_USER", "ssh_port": "TERMINAL_SSH_PORT", "ssh_key": "TERMINAL_SSH_KEY", - # Container resource config (docker, singularity, modal, daytona -- ignored for local/ssh) + # Container resource config (docker, singularity, modal, daytona, vercel_sandbox -- ignored for local/ssh) "container_cpu": "TERMINAL_CONTAINER_CPU", "container_memory": "TERMINAL_CONTAINER_MEMORY", "container_disk": "TERMINAL_CONTAINER_DISK", "container_persistent": "TERMINAL_CONTAINER_PERSISTENT", "docker_volumes": "TERMINAL_DOCKER_VOLUMES", "docker_mount_cwd_to_workspace": "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE", + "docker_run_as_host_user": "TERMINAL_DOCKER_RUN_AS_HOST_USER", "sandbox_dir": "TERMINAL_SANDBOX_DIR", # Persistent shell (non-local backends) 
"persistent_shell": "TERMINAL_PERSISTENT_SHELL", @@ -3105,6 +3107,8 @@ class HermesCLI: return "Processing skills command..." if cmd_lower == "/reload-mcp": return "Reloading MCP servers..." + if cmd_lower == "/reload-skills" or cmd_lower == "/reload_skills": + return "Reloading skills..." if cmd_lower.startswith("/browser"): return "Configuring browser..." return "Processing command..." @@ -4808,6 +4812,22 @@ class HermesCLI: ) except Exception: pass + # Notify memory providers that session_id rotated to a fresh + # conversation. reset=True signals providers to flush accumulated + # per-session state (_session_turns, _turn_counter, _document_id). + # Fires BEFORE the plugin on_session_reset hook (shell hooks only + # see the new id; Python providers see the transition). See #6672. + try: + _mm = getattr(self.agent, "_memory_manager", None) + if _mm is not None: + _mm.on_session_switch( + self.session_id, + parent_session_id=old_session_id or "", + reset=True, + reason="new_session", + ) + except Exception: + pass self._notify_session_boundary("on_session_reset") if not silent: @@ -4860,6 +4880,7 @@ class HermesCLI: _cprint(" Already on that session.") return + old_session_id = self.session_id # End current session try: self._session_db.end_session(self.session_id, "resumed_other") @@ -4897,6 +4918,22 @@ class HermesCLI: if hasattr(self.agent, "_invalidate_system_prompt"): self.agent._invalidate_system_prompt() + # Notify memory providers that session_id rotated to a resumed + # session. reset=False — the provider's accumulated state is + # still valid; it just needs to target the new session_id for + # subsequent writes. See #6672. 
+ try: + _mm = getattr(self.agent, "_memory_manager", None) + if _mm is not None: + _mm.on_session_switch( + target_id, + parent_session_id=old_session_id or "", + reset=False, + reason="resume", + ) + except Exception: + pass + title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else "" msg_count = len([m for m in self.conversation_history if m.get("role") == "user"]) if self.conversation_history: @@ -5017,6 +5054,22 @@ class HermesCLI: if hasattr(self.agent, "_invalidate_system_prompt"): self.agent._invalidate_system_prompt() + # Notify memory providers that session_id forked to a new branch. + # reset=False — the branched session carries the transcript + # forward, so provider state tracks the lineage. parent_session_id + # links the branch back to the original. See #6672. + try: + _mm = getattr(self.agent, "_memory_manager", None) + if _mm is not None: + _mm.on_session_switch( + new_session_id, + parent_session_id=parent_session_id or "", + reset=False, + reason="branch", + ) + except Exception: + pass + msg_count = len([m for m in self.conversation_history if m.get("role") == "user"]) _cprint( f" ⑂ Branched session \"{branch_title}\"" @@ -6233,8 +6286,13 @@ class HermesCLI: count = reload_env() print(f" Reloaded .env ({count} var(s) updated)") elif canonical == "reload-mcp": + # Interactive reload: confirm first (unless the user has opted out). + # The auto-reload path (file watcher) calls _reload_mcp directly + # without this confirmation. + self._confirm_and_reload_mcp(cmd_original) + elif canonical == "reload-skills": with self._busy_command(self._slow_command_status(cmd_original)): - self._reload_mcp() + self._reload_skills() elif canonical == "browser": self._handle_browser_command(cmd_original) elif canonical == "plugins": @@ -7361,6 +7419,77 @@ class HermesCLI: if _reload_thread.is_alive(): print(" ⚠️ MCP reload timed out (30s). 
Some servers may not have reconnected.") + def _confirm_and_reload_mcp(self, cmd_original: str = "") -> None: + """Interactive /reload-mcp — confirm with the user, then reload. + + Reloading MCP tools invalidates the provider prompt cache for the + active session (tool schemas are baked into the system prompt). + The next message re-sends full input tokens — can be expensive on + long-context or high-reasoning models. + + Three options: Approve Once, Always Approve (persists + ``approvals.mcp_reload_confirm: false`` so future reloads run + without this prompt), Cancel. Gated by + ``approvals.mcp_reload_confirm`` — default on. + """ + # Gate check — respects prior "Always Approve" clicks. + try: + cfg = load_cli_config() + approvals = cfg.get("approvals") if isinstance(cfg, dict) else None + confirm_required = True + if isinstance(approvals, dict): + confirm_required = bool(approvals.get("mcp_reload_confirm", True)) + except Exception: + confirm_required = True + + if not confirm_required: + with self._busy_command(self._slow_command_status(cmd_original)): + self._reload_mcp() + return + + # Render warning + prompt. Use a single-line prompt so the user + # sees the warning as output and types a response into the composer. + print() + print("⚠️ /reload-mcp — Prompt cache invalidation warning") + print() + print(" Reloading MCP servers rebuilds the tool set for this session and") + print(" invalidates the provider prompt cache. 
The next message will") + print(" re-send full input tokens (can be expensive on long-context or") + print(" high-reasoning models).") + print() + print(" [1] Approve Once — reload now") + print(" [2] Always Approve — reload now and silence this prompt permanently") + print(" [3] Cancel — leave MCP tools unchanged") + print() + raw = self._prompt_text_input("Choice [1/2/3]: ") + if raw is None: + print("🟡 /reload-mcp cancelled (no input).") + return + choice_raw = raw.strip().lower() + if choice_raw in ("1", "once", "approve", "yes", "y", "ok"): + choice = "once" + elif choice_raw in ("2", "always", "remember"): + choice = "always" + elif choice_raw in ("3", "cancel", "nevermind", "no", "n", ""): + choice = "cancel" + else: + print(f"🟡 Unrecognized choice '{raw}'. /reload-mcp cancelled.") + return + + if choice == "cancel": + print("🟡 /reload-mcp cancelled. MCP tools unchanged.") + return + + if choice == "always": + if save_config_value("approvals.mcp_reload_confirm", False): + print("🔒 Future /reload-mcp calls will run without confirmation.") + print(" Re-enable via `approvals.mcp_reload_confirm: true` in config.yaml.") + else: + print("⚠️ Couldn't persist opt-out — reloading once.") + + with self._busy_command(self._slow_command_status(cmd_original)): + self._reload_mcp() + def _reload_mcp(self): """Reload MCP servers: disconnect all, re-read config.yaml, reconnect. @@ -7446,6 +7575,78 @@ class HermesCLI: except Exception as e: print(f" ❌ MCP reload failed: {e}") + def _reload_skills(self) -> None: + """Reload skills: rescan ~/.hermes/skills/ and queue a note for the + next user turn. + + Skills don't need to live in the system prompt for the model to use + them (they're invoked via ``/skill-name``, ``skills_list``, or + ``skill_view`` at runtime), so this does NOT clear the prompt cache. 
+ It rescans the slash-command map, prints the diff for the user, and + — if any skills were added or removed — queues a one-shot note that + gets prepended to the next user message. This preserves message + alternation (no phantom user turn injected out of band) and keeps + prompt caching intact. + """ + try: + from agent.skill_commands import reload_skills + + if not self._command_running: + print("🔄 Reloading skills...") + + result = reload_skills() + added = result.get("added", []) # [{"name", "description"}, ...] + removed = result.get("removed", []) # [{"name", "description"}, ...] + total = result.get("total", 0) + + if not added and not removed: + print(" No new skills detected.") + print(f" 📚 {total} skill(s) available") + return + + def _fmt_line(item: dict) -> str: + nm = item.get("name", "") + desc = item.get("description", "") + return f" - {nm}: {desc}" if desc else f" - {nm}" + + if added: + print(" ➕ Added Skills:") + for item in added: + print(f" {_fmt_line(item)}") + if removed: + print(" ➖ Removed Skills:") + for item in removed: + print(f" {_fmt_line(item)}") + print(f" 📚 {total} skill(s) available") + + # Queue a one-shot note for the NEXT user turn. The CLI's agent + # loop prepends ``_pending_skills_reload_note`` (if set) to the + # API-call-local message at ~L8770, then clears it — same + # pattern as ``_pending_model_switch_note``. Nothing is written + # to conversation_history here, so message alternation stays + # intact and no out-of-band user turn is persisted. + # + # Format matches how the system prompt renders pre-existing + # skills (`` - name: description``) so the model reads the + # diff in the same shape as its original skill catalog. 
+ sections = ["[USER INITIATED SKILLS RELOAD:"] + if added: + sections.append("") + sections.append("Added Skills:") + for item in added: + sections.append(_fmt_line(item)) + if removed: + sections.append("") + sections.append("Removed Skills:") + for item in removed: + sections.append(_fmt_line(item)) + sections.append("") + sections.append("Use skills_list to see the updated catalog.]") + self._pending_skills_reload_note = "\n".join(sections) + + except Exception as e: + print(f" ❌ Skills reload failed: {e}") + # ==================================================================== # Tool-call generation indicator (shown during streaming) # ==================================================================== @@ -8654,6 +8855,13 @@ class HermesCLI: if _msn: agent_message = _msn + "\n\n" + agent_message self._pending_model_switch_note = None + # Prepend pending /reload-skills note so the model sees which + # skills were added/removed before handling this turn. Same + # one-shot queue pattern as the model-switch note above. + _srn = getattr(self, '_pending_skills_reload_note', None) + if _srn: + agent_message = _srn + "\n\n" + agent_message + self._pending_skills_reload_note = None try: result = self.agent.run_conversation( user_message=agent_message, diff --git a/cron/jobs.py b/cron/jobs.py index 5bdb1122fe..6376260828 100644 --- a/cron/jobs.py +++ b/cron/jobs.py @@ -313,13 +313,21 @@ def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None elif schedule["kind"] == "cron": if not HAS_CRONITER: logger.warning( - "Cannot compute next run for cron schedule %r: 'croniter' " - "is not installed. Install the 'cron' extra (pip install " - "'hermes-agent[cron]') to re-enable recurring cron jobs.", + "Cannot compute next run for cron schedule %r: 'croniter' is " + "not installed. 
croniter is a core dependency as of v0.9.x; " + "reinstall hermes-agent or run 'pip install croniter' in your " + "runtime env.", schedule.get("expr"), ) return None - cron = croniter(schedule["expr"], now) + # Use last_run_at as the croniter base when available, consistent + # with interval jobs. This ensures that after a crash/restart, + # the next run is anchored to the actual last execution time + # rather than to an arbitrary restart time. + base_time = now + if last_run_at: + base_time = _ensure_aware(datetime.fromisoformat(last_run_at)) + cron = croniter(schedule["expr"], base_time) next_run = cron.get_next(datetime) return next_run.isoformat() diff --git a/cron/scheduler.py b/cron/scheduler.py index 3b38a20336..08d73c1beb 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -233,12 +233,32 @@ def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[d } +def _normalize_deliver_value(deliver) -> str: + """Normalize a stored/submitted ``deliver`` value to its canonical string form. + + The contract is that ``deliver`` is a string (``"local"``, ``"origin"``, + ``"telegram"``, ``"telegram:-1001:17"``, or comma-separated combinations). + Historically some callers — MCP clients passing an array, direct edits of + ``jobs.json``, or stale code paths — have stored a list/tuple like + ``["telegram"]``. ``str(["telegram"])`` would serialize to the literal + string ``"['telegram']"``, which is not a known platform and fails + resolution silently. Flatten lists/tuples into a comma-separated string + so both forms work. Returns ``"local"`` for anything falsy. 
+ """ + if deliver is None or deliver == "": + return "local" + if isinstance(deliver, (list, tuple)): + parts = [str(p).strip() for p in deliver if str(p).strip()] + return ",".join(parts) if parts else "local" + return str(deliver) + + def _resolve_delivery_targets(job: dict) -> List[dict]: """Resolve all concrete auto-delivery targets for a cron job (supports comma-separated deliver).""" - deliver = job.get("deliver", "local") + deliver = _normalize_deliver_value(job.get("deliver", "local")) if deliver == "local": return [] - parts = [p.strip() for p in str(deliver).split(",") if p.strip()] + parts = [p.strip() for p in deliver.split(",") if p.strip()] seen = set() targets = [] for part in parts: @@ -257,13 +277,21 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]: return targets[0] if targets else None -# Media extension sets — keep in sync with gateway/platforms/base.py:_process_message_background -_AUDIO_EXTS = frozenset({'.ogg', '.opus', '.mp3', '.wav', '.m4a'}) +# Media extension sets — audio routing is centralized in gateway.platforms.base +# via should_send_media_as_audio() so Telegram-specific rules stay in one place. _VIDEO_EXTS = frozenset({'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'}) _IMAGE_EXTS = frozenset({'.jpg', '.jpeg', '.png', '.webp', '.gif'}) -def _send_media_via_adapter(adapter, chat_id: str, media_files: list, metadata: dict | None, loop, job: dict) -> None: +def _send_media_via_adapter( + adapter, + chat_id: str, + media_files: list, + metadata: dict | None, + loop, + job: dict, + platform=None, +) -> None: """Send extracted MEDIA files as native platform attachments via a live adapter. 
Routes each file to the appropriate adapter method (send_voice, send_image_file, @@ -272,10 +300,13 @@ def _send_media_via_adapter(adapter, chat_id: str, media_files: list, metadata: """ from pathlib import Path + from gateway.platforms.base import should_send_media_as_audio + for media_path, _is_voice in media_files: try: ext = Path(media_path).suffix.lower() - if ext in _AUDIO_EXTS: + route_platform = platform if platform is not None else getattr(adapter, "platform", None) + if should_send_media_as_audio(route_platform, ext, is_voice=_is_voice): coro = adapter.send_voice(chat_id=chat_id, audio_path=media_path, metadata=metadata) elif ext in _VIDEO_EXTS: coro = adapter.send_video(chat_id=chat_id, video_path=media_path, metadata=metadata) @@ -321,27 +352,6 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option from tools.send_message_tool import _send_to_platform from gateway.config import load_gateway_config, Platform - platform_map = { - "telegram": Platform.TELEGRAM, - "discord": Platform.DISCORD, - "slack": Platform.SLACK, - "whatsapp": Platform.WHATSAPP, - "signal": Platform.SIGNAL, - "matrix": Platform.MATRIX, - "mattermost": Platform.MATTERMOST, - "homeassistant": Platform.HOMEASSISTANT, - "dingtalk": Platform.DINGTALK, - "feishu": Platform.FEISHU, - "wecom": Platform.WECOM, - "wecom_callback": Platform.WECOM_CALLBACK, - "weixin": Platform.WEIXIN, - "email": Platform.EMAIL, - "sms": Platform.SMS, - "bluebubbles": Platform.BLUEBUBBLES, - "qqbot": Platform.QQBOT, - "yuanbao": Platform.YUANBAO, - } - # Optionally wrap the content with a header/footer so the user knows this # is a cron delivery. Wrapping is on by default; set cron.wrap_response: false # in config.yaml for clean output. 
@@ -398,13 +408,23 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option job["id"], platform_name, chat_id, thread_id, ) - platform = platform_map.get(platform_name.lower()) - if not platform: + # Built-in names resolve to their enum member; plugin platform names + # create dynamic members via Platform._missing_(). + try: + platform = Platform(platform_name.lower()) + except (ValueError, KeyError): msg = f"unknown platform '{platform_name}'" logger.warning("Job '%s': %s", job["id"], msg) delivery_errors.append(msg) continue + pconfig = config.platforms.get(platform) + if not pconfig or not pconfig.enabled: + msg = f"platform '{platform_name}' not configured/enabled" + logger.warning("Job '%s': %s", job["id"], msg) + delivery_errors.append(msg) + continue + # Prefer the live adapter when the gateway is running — this supports E2EE # rooms (e.g. Matrix) where the standalone HTTP path cannot encrypt. runtime_adapter = (adapters or {}).get(platform) @@ -435,7 +455,15 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option # Send extracted media files as native attachments via the live adapter if adapter_ok and media_files: - _send_media_via_adapter(runtime_adapter, chat_id, media_files, send_metadata, loop, job) + _send_media_via_adapter( + runtime_adapter, + chat_id, + media_files, + send_metadata, + loop, + job, + platform=platform, + ) if adapter_ok: logger.info("Job '%s': delivered to %s:%s via live adapter", job["id"], platform_name, chat_id) @@ -447,13 +475,6 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option ) if not delivered: - pconfig = config.platforms.get(platform) - if not pconfig or not pconfig.enabled: - msg = f"platform '{platform_name}' not configured/enabled" - logger.warning("Job '%s': %s", job["id"], msg) - delivery_errors.append(msg) - continue - # Standalone path: run the async send in a fresh event loop (safe from any thread) coro = 
_send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files) try: @@ -840,6 +861,13 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: chat_id=str(origin["chat_id"]) if origin else "", chat_name=origin.get("chat_name", "") if origin else "", ) + _cron_delivery_vars = ( + "HERMES_CRON_AUTO_DELIVER_PLATFORM", + "HERMES_CRON_AUTO_DELIVER_CHAT_ID", + "HERMES_CRON_AUTO_DELIVER_THREAD_ID", + ) + for _var_name in _cron_delivery_vars: + _VAR_MAP[_var_name].set("") # Per-job working directory. When set (and validated at create/update # time), we point TERMINAL_CWD at it so: @@ -878,8 +906,11 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: if delivery_target: _VAR_MAP["HERMES_CRON_AUTO_DELIVER_PLATFORM"].set(delivery_target["platform"]) _VAR_MAP["HERMES_CRON_AUTO_DELIVER_CHAT_ID"].set(str(delivery_target["chat_id"])) - if delivery_target.get("thread_id") is not None: - _VAR_MAP["HERMES_CRON_AUTO_DELIVER_THREAD_ID"].set(str(delivery_target["thread_id"])) + _VAR_MAP["HERMES_CRON_AUTO_DELIVER_THREAD_ID"].set( + "" + if delivery_target.get("thread_id") is None + else str(delivery_target["thread_id"]) + ) model = job.get("model") or os.getenv("HERMES_MODEL") or "" @@ -1013,10 +1044,12 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: enabled_toolsets=_resolve_cron_enabled_toolsets(job, _cfg), disabled_toolsets=["cronjob", "messaging", "clarify"], quiet_mode=True, - # When a workdir is configured, inject AGENTS.md / CLAUDE.md / - # .cursorrules from that directory; otherwise preserve the old - # behaviour (don't inject SOUL.md/AGENTS.md from the scheduler cwd). + # Cron jobs should always inherit the user's SOUL.md identity from + # HERMES_HOME. When a workdir is configured, also inject project + # context files (AGENTS.md / CLAUDE.md / .cursorrules) from there. + # Without a workdir, keep cwd context discovery disabled. 
skip_context_files=not bool(_job_workdir), + load_soul_identity=True, skip_memory=True, # Cron system prompts would corrupt user representations platform="cron", session_id=_cron_session_id, @@ -1031,7 +1064,18 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: # # Uses the agent's built-in activity tracker (updated by # _touch_activity() on every tool call, API call, and stream delta). - _cron_timeout = float(os.getenv("HERMES_CRON_TIMEOUT", 600)) + _raw_cron_timeout = os.getenv("HERMES_CRON_TIMEOUT", "").strip() + if _raw_cron_timeout: + try: + _cron_timeout = float(_raw_cron_timeout) + except (ValueError, TypeError): + logger.warning( + "Invalid HERMES_CRON_TIMEOUT=%r; using default 600s", + _raw_cron_timeout, + ) + _cron_timeout = 600.0 + else: + _cron_timeout = 600.0 _cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None _POLL_INTERVAL = 5.0 _cron_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) @@ -1165,6 +1209,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: os.environ["TERMINAL_CWD"] = _prior_terminal_cwd # Clean up ContextVar session/delivery state for this job. clear_session_vars(_ctx_tokens) + for _var_name in _cron_delivery_vars: + _VAR_MAP[_var_name].set("") if _session_db: try: _session_db.end_session(_cron_session_id, "cron_complete") diff --git a/docker-compose.yml b/docker-compose.yml index a0fe1a100a..ecf59d40c3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -34,6 +34,13 @@ services: # uncomment BOTH lines (API_SERVER_KEY is mandatory for auth): # - API_SERVER_HOST=0.0.0.0 # - API_SERVER_KEY=${API_SERVER_KEY} + # Microsoft Teams — uncomment and fill in to enable Teams gateway. + # Register your bot at https://dev.botframework.com/ to get these values. 
+ # - TEAMS_CLIENT_ID=${TEAMS_CLIENT_ID} + # - TEAMS_CLIENT_SECRET=${TEAMS_CLIENT_SECRET} + # - TEAMS_TENANT_ID=${TEAMS_TENANT_ID} + # - TEAMS_ALLOWED_USERS=${TEAMS_ALLOWED_USERS} + # - TEAMS_PORT=3978 command: ["gateway", "run"] dashboard: diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index 94936ac9dd..ff4af85a89 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -86,6 +86,16 @@ async def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: continue platforms[plat_name] = _build_from_sessions(plat_name) + # Include plugin-registered platforms (dynamic enum members aren't in + # Platform.__members__, so the loop above misses them). + try: + from gateway.platform_registry import platform_registry + for entry in platform_registry.plugin_entries(): + if entry.name not in _SKIP_SESSION_DISCOVERY and entry.name not in platforms: + platforms[entry.name] = _build_from_sessions(entry.name) + except Exception: + pass + directory = { "updated_at": datetime.now().isoformat(), "platforms": platforms, diff --git a/gateway/config.py b/gateway/config.py index da9830fcf2..7d4d259ca3 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -13,7 +13,7 @@ import os import json from pathlib import Path from dataclasses import dataclass, field -from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional, Any, Callable from enum import Enum from hermes_cli.config import get_hermes_home @@ -45,8 +45,19 @@ def _normalize_unauthorized_dm_behavior(value: Any, default: str = "pair") -> st return default +# Module-level cache for bundled platform plugin names (lives outside the +# enum so it doesn't become an accidental enum member). +_Platform__bundled_plugin_names: Optional[set] = None + + class Platform(Enum): - """Supported messaging platforms.""" + """Supported messaging platforms. + + Built-in platforms have explicit members. 
Plugin platforms use dynamic + members created on-demand by ``_missing_()`` so that + ``Platform("irc")`` works without modifying this enum. Dynamic members + are cached in ``_value2member_map_`` for identity-stable comparisons. + """ LOCAL = "local" TELEGRAM = "telegram" DISCORD = "discord" @@ -68,6 +79,76 @@ class Platform(Enum): BLUEBUBBLES = "bluebubbles" QQBOT = "qqbot" YUANBAO = "yuanbao" + @classmethod + def _missing_(cls, value): + """Accept unknown platform names only for known plugin adapters. + + Creates a pseudo-member cached in ``_value2member_map_`` so that + ``Platform("irc") is Platform("irc")`` holds True (identity-stable). + Arbitrary strings are rejected to prevent enum pollution. + """ + if not isinstance(value, str) or not value.strip(): + return None + # Normalise to lowercase to avoid case mismatches in config + value = value.strip().lower() + # Check cache first (another call may have created it already) + if value in cls._value2member_map_: + return cls._value2member_map_[value] + + # Only create pseudo-members for bundled plugin platforms (discovered + # via filesystem scan) or runtime-registered plugin platforms. + global _Platform__bundled_plugin_names + if _Platform__bundled_plugin_names is None: + _Platform__bundled_plugin_names = cls._scan_bundled_plugin_platforms() + if value in _Platform__bundled_plugin_names: + pseudo = object.__new__(cls) + pseudo._value_ = value + pseudo._name_ = value.upper().replace("-", "_").replace(" ", "_") + cls._value2member_map_[value] = pseudo + cls._member_map_[pseudo._name_] = pseudo + return pseudo + + # Runtime-registered plugins (e.g. user-installed, discovered after + # the enum was defined). 
+ try: + from gateway.platform_registry import platform_registry + if platform_registry.is_registered(value): + pseudo = object.__new__(cls) + pseudo._value_ = value + pseudo._name_ = value.upper().replace("-", "_").replace(" ", "_") + cls._value2member_map_[value] = pseudo + cls._member_map_[pseudo._name_] = pseudo + return pseudo + except Exception: + pass + + return None + + @classmethod + def _scan_bundled_plugin_platforms(cls) -> set: + """Return names of bundled platform plugins under ``plugins/platforms/``.""" + names: set = set() + try: + platforms_dir = Path(__file__).parent.parent / "plugins" / "platforms" + if platforms_dir.is_dir(): + for child in platforms_dir.iterdir(): + if ( + child.is_dir() + and (child / "__init__.py").exists() + and ( + (child / "plugin.yaml").exists() + or (child / "plugin.yml").exists() + ) + ): + names.add(child.name.lower()) + except Exception: + pass + return names + + +# Snapshot of built-in platform values before any dynamic _missing_ lookups. +# Used to distinguish real platforms from arbitrary strings. +_BUILTIN_PLATFORM_VALUES = frozenset(m.value for m in Platform.__members__.values()) @dataclass @@ -231,6 +312,44 @@ class StreamingConfig: ) +# ----------------------------------------------------------------------------- +# Built-in platform connection checkers +# ----------------------------------------------------------------------------- +# Each callable receives a ``PlatformConfig`` and returns ``True`` when the +# platform is sufficiently configured to be considered "connected". Platforms +# that rely on the generic ``token or api_key`` check (Telegram, Discord, +# Slack, Matrix, Mattermost, HomeAssistant) do not need an entry here. 
+_PLATFORM_CONNECTED_CHECKERS: dict[Platform, Callable[[PlatformConfig], bool]] = { + Platform.WEIXIN: lambda cfg: bool( + cfg.extra.get("account_id") and (cfg.token or cfg.extra.get("token")) + ), + Platform.WHATSAPP: lambda cfg: True, # bridge handles auth + Platform.SIGNAL: lambda cfg: bool(cfg.extra.get("http_url")), + Platform.EMAIL: lambda cfg: bool(cfg.extra.get("address")), + Platform.SMS: lambda cfg: bool(os.getenv("TWILIO_ACCOUNT_SID")), + Platform.API_SERVER: lambda cfg: True, + Platform.WEBHOOK: lambda cfg: True, + Platform.FEISHU: lambda cfg: bool(cfg.extra.get("app_id")), + Platform.WECOM: lambda cfg: bool(cfg.extra.get("bot_id")), + Platform.WECOM_CALLBACK: lambda cfg: bool( + cfg.extra.get("corp_id") or cfg.extra.get("apps") + ), + Platform.BLUEBUBBLES: lambda cfg: bool( + cfg.extra.get("server_url") and cfg.extra.get("password") + ), + Platform.QQBOT: lambda cfg: bool( + cfg.extra.get("app_id") and cfg.extra.get("client_secret") + ), + Platform.YUANBAO: lambda cfg: bool( + cfg.extra.get("app_id") and cfg.extra.get("app_secret") + ), + Platform.DINGTALK: lambda cfg: bool( + (cfg.extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID")) + and (cfg.extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET")) + ), +} + + @dataclass class GatewayConfig: """ @@ -284,61 +403,43 @@ class GatewayConfig: for platform, config in self.platforms.items(): if not config.enabled: continue - # Weixin requires both a token and an account_id - if platform == Platform.WEIXIN: - if config.extra.get("account_id") and (config.token or config.extra.get("token")): - connected.append(platform) - continue - # Platforms that use token/api_key auth - if config.token or config.api_key: + if self._is_platform_connected(platform, config): connected.append(platform) - # WhatsApp uses enabled flag only (bridge handles auth) - elif platform == Platform.WHATSAPP: - connected.append(platform) - # Signal uses extra dict for config (http_url + account) - elif platform == 
Platform.SIGNAL and config.extra.get("http_url"): - connected.append(platform) - # Email uses extra dict for config (address + imap_host + smtp_host) - elif platform == Platform.EMAIL and config.extra.get("address"): - connected.append(platform) - # SMS uses api_key (Twilio auth token) — SID checked via env - elif platform == Platform.SMS and os.getenv("TWILIO_ACCOUNT_SID"): - connected.append(platform) - # API Server uses enabled flag only (no token needed) - elif platform == Platform.API_SERVER: - connected.append(platform) - # Webhook uses enabled flag only (secrets are per-route) - elif platform == Platform.WEBHOOK: - connected.append(platform) - # Feishu uses extra dict for app credentials - elif platform == Platform.FEISHU and config.extra.get("app_id"): - connected.append(platform) - # WeCom bot mode uses extra dict for bot credentials - elif platform == Platform.WECOM and config.extra.get("bot_id"): - connected.append(platform) - # WeCom callback mode uses corp_id or apps list - elif platform == Platform.WECOM_CALLBACK and ( - config.extra.get("corp_id") or config.extra.get("apps") - ): - connected.append(platform) - # BlueBubbles uses extra dict for local server config - elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"): - connected.append(platform) - # QQBot uses extra dict for app credentials - elif platform == Platform.QQBOT and config.extra.get("app_id") and config.extra.get("client_secret"): - connected.append(platform) - # Yuanbao uses extra dict for app credentials - elif platform == Platform.YUANBAO and config.extra.get("app_id") and config.extra.get("app_secret"): - connected.append(platform) - # DingTalk uses client_id/client_secret from config.extra or env vars - elif platform == Platform.DINGTALK and ( - config.extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID") - ) and ( - config.extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET") - ): - connected.append(platform) - 
return connected + + def _is_platform_connected(self, platform: Platform, config: PlatformConfig) -> bool: + """Check whether a single platform is sufficiently configured.""" + # Weixin requires both a token and an account_id (checked first so + # the generic token branch doesn't let it through without account_id). + if platform == Platform.WEIXIN: + return bool( + config.extra.get("account_id") + and (config.token or config.extra.get("token")) + ) + + # Generic token/api_key auth covers Telegram, Discord, Slack, etc. + if config.token or config.api_key: + return True + + # Platform-specific check + checker = _PLATFORM_CONNECTED_CHECKERS.get(platform) + if checker is not None: + return checker(config) + + # Plugin-registered platforms + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform.value) + if entry: + if entry.is_connected is not None: + return entry.is_connected(config) + if entry.validate_config is not None: + return entry.validate_config(config) + return True + except Exception: + pass # Registry not yet initialised during early import + + return False def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: """Get the home channel for a platform.""" @@ -714,11 +815,21 @@ def load_gateway_config() -> GatewayConfig: os.environ["TELEGRAM_REACTIONS"] = str(telegram_cfg["reactions"]).lower() if "proxy_url" in telegram_cfg and not os.getenv("TELEGRAM_PROXY"): os.environ["TELEGRAM_PROXY"] = str(telegram_cfg["proxy_url"]).strip() - if "group_allowed_chats" in telegram_cfg and not os.getenv("TELEGRAM_GROUP_ALLOWED_USERS"): - gac = telegram_cfg["group_allowed_chats"] - if isinstance(gac, list): - gac = ",".join(str(v) for v in gac) - os.environ["TELEGRAM_GROUP_ALLOWED_USERS"] = str(gac) + allowed_users = telegram_cfg.get("allow_from") + if allowed_users is not None and not os.getenv("TELEGRAM_ALLOWED_USERS"): + if isinstance(allowed_users, list): + allowed_users = ",".join(str(v) for v in 
allowed_users) + os.environ["TELEGRAM_ALLOWED_USERS"] = str(allowed_users) + group_allowed_users = telegram_cfg.get("group_allow_from") + if group_allowed_users is not None and not os.getenv("TELEGRAM_GROUP_ALLOWED_USERS"): + if isinstance(group_allowed_users, list): + group_allowed_users = ",".join(str(v) for v in group_allowed_users) + os.environ["TELEGRAM_GROUP_ALLOWED_USERS"] = str(group_allowed_users) + group_allowed_chats = telegram_cfg.get("group_allowed_chats") + if group_allowed_chats is not None and not os.getenv("TELEGRAM_GROUP_ALLOWED_CHATS"): + if isinstance(group_allowed_chats, list): + group_allowed_chats = ",".join(str(v) for v in group_allowed_chats) + os.environ["TELEGRAM_GROUP_ALLOWED_CHATS"] = str(group_allowed_chats) if "disable_link_previews" in telegram_cfg: plat_data = platforms_data.setdefault(Platform.TELEGRAM.value, {}) if not isinstance(plat_data, dict): @@ -1371,3 +1482,25 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.default_reset_policy.at_hour = int(reset_hour) except ValueError: pass + + # Registry-driven enable for plugin platforms. Built-ins have explicit + # blocks above; plugins expose check_fn() which is the single source of + # truth for "are my env vars set?". When it returns True, ensure the + # platform is enabled so start() will create its adapter. 
+ try: + from hermes_cli.plugins import discover_plugins + discover_plugins() # idempotent + from gateway.platform_registry import platform_registry + for entry in platform_registry.plugin_entries(): + try: + if not entry.check_fn(): + continue + except Exception as e: + logger.debug("check_fn for %s raised: %s", entry.name, e) + continue + platform = Platform(entry.name) + if platform not in config.platforms: + config.platforms[platform] = PlatformConfig() + config.platforms[platform].enabled = True + except Exception as e: + logger.debug("Plugin platform enable pass failed: %s", e) diff --git a/gateway/hooks.py b/gateway/hooks.py index f887cf5df0..5ab4511920 100644 --- a/gateway/hooks.py +++ b/gateway/hooks.py @@ -21,6 +21,7 @@ Errors in hooks are caught and logged but never block the main pipeline. import asyncio import importlib.util +import sys from typing import Any, Callable, Dict, List, Optional import yaml @@ -97,16 +98,28 @@ class HookRegistry: print(f"[hooks] Skipping {hook_name}: no events declared", flush=True) continue - # Dynamically load the handler module + # Dynamically load the handler module. + # Register in sys.modules BEFORE exec_module so Pydantic / + # dataclasses / typing introspection can resolve forward + # references (triggered by `from __future__ import annotations` + # in the handler). Without this, a handler that declares a + # Pydantic BaseModel for webhook/event payloads fails at first + # dispatch with "TypeAdapter ... is not fully defined". 
+ module_name = f"hermes_hook_{hook_name}" spec = importlib.util.spec_from_file_location( - f"hermes_hook_{hook_name}", handler_path + module_name, handler_path ) if spec is None or spec.loader is None: print(f"[hooks] Skipping {hook_name}: could not load handler.py", flush=True) continue module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop(module_name, None) + raise handle_fn = getattr(module, "handle", None) if handle_fn is None: diff --git a/gateway/platform_registry.py b/gateway/platform_registry.py new file mode 100644 index 0000000000..11303466da --- /dev/null +++ b/gateway/platform_registry.py @@ -0,0 +1,212 @@ +""" +Platform Adapter Registry + +Allows platform adapters (built-in and plugin) to self-register so the gateway +can discover and instantiate them without hardcoded if/elif chains. + +Built-in adapters continue to use the existing if/elif in _create_adapter() +for now. Plugin adapters register here via PluginContext.register_platform() +and are looked up first -- if nothing is found the gateway falls through to +the legacy code path. + +Usage (plugin side): + + from gateway.platform_registry import platform_registry, PlatformEntry + + platform_registry.register(PlatformEntry( + name="irc", + label="IRC", + adapter_factory=lambda cfg: IRCAdapter(cfg), + check_fn=check_requirements, + validate_config=lambda cfg: bool(cfg.extra.get("server")), + required_env=["IRC_SERVER"], + install_hint="pip install irc", + )) + +Usage (gateway side): + + adapter = platform_registry.create_adapter("irc", platform_config) +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class PlatformEntry: + """Metadata and factory for a single platform adapter.""" + + # Identifier used in config.yaml (e.g. "irc", "viber"). 
+ name: str + + # Human-readable label (e.g. "IRC", "Viber"). + label: str + + # Factory callable: receives a PlatformConfig, returns an adapter instance. + # Using a factory instead of a bare class lets plugins do custom init + # (e.g. passing extra kwargs, wrapping in try/except). + adapter_factory: Callable[[Any], Any] + + # Returns True when the platform's dependencies are available. + check_fn: Callable[[], bool] + + # Optional: given a PlatformConfig, is it properly configured? + # If None, the registry skips config validation and lets the adapter + # fail at connect() time with a descriptive error. + validate_config: Optional[Callable[[Any], bool]] = None + + # Optional: given a PlatformConfig, is the platform connected/enabled? + # Used by ``GatewayConfig.get_connected_platforms()`` and setup UI status. + # If None, falls back to ``validate_config`` or ``check_fn``. + is_connected: Optional[Callable[[Any], bool]] = None + + # Env vars this platform needs (for ``hermes setup`` display). + required_env: list = field(default_factory=list) + + # Hint shown when check_fn returns False. + install_hint: str = "" + + # Optional setup function for interactive configuration. + # Signature: () -> None (prompts user, saves env vars). + # If None, falls back to _setup_standard_platform (needs token_var + vars) + # or a generic "set these env vars" display. + setup_fn: Optional[Callable[[], None]] = None + + # "builtin" or "plugin" + source: str = "plugin" + + # Name of the plugin manifest that registered this entry (empty for + # built-ins). Used by ``hermes gateway setup`` to auto-enable the + # owning plugin when the user configures its platform. + plugin_name: str = "" + + # ── Auth env var names (for _is_user_authorized integration) ── + # E.g. "IRC_ALLOWED_USERS" — checked for comma-separated user IDs. + allowed_users_env: str = "" + # E.g. "IRC_ALLOW_ALL_USERS" — if truthy, all users authorized. 
+ allow_all_env: str = "" + + # ── Message limits ── + # Max message length for smart-chunking. 0 = no limit. + max_message_length: int = 0 + + # ── Privacy ── + # If True, session descriptions redact PII (phone numbers, etc.) + pii_safe: bool = False + + # ── Display ── + # Emoji for CLI/gateway display (e.g. "💬") + emoji: str = "🔌" + + # Whether this platform should appear in _UPDATE_ALLOWED_PLATFORMS + # (allows /update command from this platform). + allow_update_command: bool = True + + # ── LLM guidance ── + # Platform hint injected into the system prompt (e.g. "You are on IRC. + # Do not use markdown."). Empty string = no hint. + platform_hint: str = "" + + +class PlatformRegistry: + """Central registry of platform adapters. + + Thread-safe for reads (dict lookups are atomic under GIL). + Writes happen at startup during sequential discovery. + """ + + def __init__(self) -> None: + self._entries: dict[str, PlatformEntry] = {} + + def register(self, entry: PlatformEntry) -> None: + """Register a platform adapter entry. + + If an entry with the same name exists, it is replaced (last writer + wins -- this lets plugins override built-in adapters if desired). + """ + if entry.name in self._entries: + prev = self._entries[entry.name] + logger.info( + "Platform '%s' re-registered (was %s, now %s)", + entry.name, + prev.source, + entry.source, + ) + self._entries[entry.name] = entry + logger.debug("Registered platform adapter: %s (%s)", entry.name, entry.source) + + def unregister(self, name: str) -> bool: + """Remove a platform entry. 
Returns True if it existed.""" + return self._entries.pop(name, None) is not None + + def get(self, name: str) -> Optional[PlatformEntry]: + """Look up a platform entry by name.""" + return self._entries.get(name) + + def all_entries(self) -> list[PlatformEntry]: + """Return all registered platform entries.""" + return list(self._entries.values()) + + def plugin_entries(self) -> list[PlatformEntry]: + """Return only plugin-registered platform entries.""" + return [e for e in self._entries.values() if e.source == "plugin"] + + def is_registered(self, name: str) -> bool: + return name in self._entries + + def create_adapter(self, name: str, config: Any) -> Optional[Any]: + """Create an adapter instance for the given platform name. + + Returns None if: + - No entry registered for *name* + - check_fn() returns False (missing deps) + - validate_config() returns False (misconfigured) + - The factory raises an exception + """ + entry = self._entries.get(name) + if entry is None: + return None + + if not entry.check_fn(): + hint = f" ({entry.install_hint})" if entry.install_hint else "" + logger.warning( + "Platform '%s' requirements not met%s", + entry.label, + hint, + ) + return None + + if entry.validate_config is not None: + try: + if not entry.validate_config(config): + logger.warning( + "Platform '%s' config validation failed", + entry.label, + ) + return None + except Exception as e: + logger.warning( + "Platform '%s' config validation error: %s", + entry.label, + e, + ) + return None + + try: + adapter = entry.adapter_factory(config) + return adapter + except Exception as e: + logger.error( + "Failed to create adapter for platform '%s': %s", + entry.label, + e, + exc_info=True, + ) + return None + + +# Module-level singleton +platform_registry = PlatformRegistry() diff --git a/gateway/platforms/ADDING_A_PLATFORM.md b/gateway/platforms/ADDING_A_PLATFORM.md index f773f8c8f8..7fd28245b1 100644 --- a/gateway/platforms/ADDING_A_PLATFORM.md +++ 
b/gateway/platforms/ADDING_A_PLATFORM.md @@ -1,9 +1,30 @@ # Adding a New Messaging Platform -Checklist for integrating a new messaging platform into the Hermes gateway. -Use this as a reference when building a new adapter — every item here is a -real integration point that exists in the codebase. Missing any of them will -cause broken functionality, missing features, or inconsistent behavior. +There are two ways to add a platform to the Hermes gateway: + +## Plugin Path (Recommended for Community/Third-Party) + +Create a plugin directory in `~/.hermes/plugins/` with a `plugin.yaml` and +`adapter.py`. The adapter inherits from `BasePlatformAdapter` and registers +via `ctx.register_platform()` in the `register(ctx)` entry point. This +requires **zero changes to core Hermes code**. + +The plugin system automatically handles: adapter creation, config parsing, +user authorization, cron delivery, send_message routing, system prompt hints, +status display, gateway setup, and more. + +See `plugins/platforms/irc/` for a complete reference implementation, and +`website/docs/developer-guide/adding-platform-adapters.md` for the full +plugin guide with code examples. + +--- + +## Built-in Path (Core Contributors Only) + +Checklist for integrating a platform directly into the Hermes core. +Use this as a reference when building a built-in adapter — every item here +is a real integration point. Missing any of them will cause broken +functionality, missing features, or inconsistent behavior. 
--- diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index b7a6a09693..8c46cc6157 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -7,7 +7,9 @@ Exposes an HTTP server with endpoints: - GET /v1/responses/{response_id} — Retrieve a stored response - DELETE /v1/responses/{response_id} — Delete a stored response - GET /v1/models — lists hermes-agent as an available model +- GET /v1/capabilities — machine-readable API capabilities for external UIs - POST /v1/runs — start a run, returns run_id immediately (202) +- GET /v1/runs/{run_id} — retrieve current run status - GET /v1/runs/{run_id}/events — SSE stream of structured lifecycle events - POST /v1/runs/{run_id}/stop — interrupt a running agent - GET /health — health check @@ -590,6 +592,8 @@ class APIServerAdapter(BasePlatformAdapter): # Active run agent/task references for stop support self._active_run_agents: Dict[str, Any] = {} self._active_run_tasks: Dict[str, "asyncio.Task"] = {} + # Pollable run status for dashboards and external control-plane UIs. + self._run_statuses: Dict[str, Dict[str, Any]] = {} self._session_db: Optional[Any] = None # Lazy-init SessionDB for session continuity @staticmethod @@ -808,6 +812,51 @@ class APIServerAdapter(BasePlatformAdapter): ], }) + async def _handle_capabilities(self, request: "web.Request") -> "web.Response": + """GET /v1/capabilities — advertise the stable API surface. + + External UIs and orchestrators use this endpoint to discover the API + server's plugin-safe contract without scraping docs or assuming that + every Hermes version exposes the same endpoints. 
+ """ + auth_err = self._check_auth(request) + if auth_err: + return auth_err + + return web.json_response({ + "object": "hermes.api_server.capabilities", + "platform": "hermes-agent", + "model": self._model_name, + "auth": { + "type": "bearer", + "required": bool(self._api_key), + }, + "features": { + "chat_completions": True, + "chat_completions_streaming": True, + "responses_api": True, + "responses_streaming": True, + "run_submission": True, + "run_status": True, + "run_events_sse": True, + "run_stop": True, + "tool_progress_events": True, + "session_continuity_header": "X-Hermes-Session-Id", + "cors": bool(self._cors_origins), + }, + "endpoints": { + "health": {"method": "GET", "path": "/health"}, + "health_detailed": {"method": "GET", "path": "/health/detailed"}, + "models": {"method": "GET", "path": "/v1/models"}, + "chat_completions": {"method": "POST", "path": "/v1/chat/completions"}, + "responses": {"method": "POST", "path": "/v1/responses"}, + "runs": {"method": "POST", "path": "/v1/runs"}, + "run_status": {"method": "GET", "path": "/v1/runs/{run_id}"}, + "run_events": {"method": "GET", "path": "/v1/runs/{run_id}/events"}, + "run_stop": {"method": "POST", "path": "/v1/runs/{run_id}/stop"}, + }, + }) + async def _handle_chat_completions(self, request: "web.Request") -> "web.Response": """POST /v1/chat/completions — OpenAI Chat Completions format.""" auth_err = self._check_auth(request) @@ -932,39 +981,62 @@ class APIServerAdapter(BasePlatformAdapter): if delta is not None: _stream_q.put(delta) - def _on_tool_progress(event_type, name, preview, args, **kwargs): - """Send tool progress as a separate SSE event. + # Track which tool_call_ids we've emitted a "running" lifecycle + # event for, so a "completed" event without a matching "running" + # (e.g. internal/filtered tools) is silently dropped instead of + # producing an orphaned event clients can't correlate. 
+ _started_tool_call_ids: set[str] = set() - Previously, progress markers like ``⏰ list`` were injected - directly into ``delta.content``. OpenAI-compatible frontends - (Open WebUI, LobeChat, …) store ``delta.content`` verbatim as - the assistant message and send it back on subsequent requests. - After enough turns the model learns to *emit* the markers as - plain text instead of issuing real tool calls — silently - hallucinating tool results. See #6972. + def _on_tool_start(tool_call_id, function_name, function_args): + """Emit ``hermes.tool.progress`` with ``status: running``. - The fix: push a tagged tuple ``("__tool_progress__", payload)`` - onto the stream queue. The SSE writer emits it as a custom - ``event: hermes.tool.progress`` line that compliant frontends - can render for UX but will *not* persist into conversation - history. Clients that don't understand the custom event type - silently ignore it per the SSE specification. + Replaces the old ``tool_progress_callback("tool.started", + ...)`` emit so SSE consumers receive a single event per + tool start, carrying both the legacy ``tool``/``emoji``/ + ``label`` payload (for #6972 frontends) and the new + ``toolCallId``/``status`` correlation fields (#16588). + + Skips tools whose names start with ``_`` so internal + events (``_thinking``, …) stay off the wire — matching + the prior ``_on_tool_progress`` filter exactly. 
""" - if event_type != "tool.started": + if not tool_call_id or function_name.startswith("_"): return - if name.startswith("_"): - return - from agent.display import get_tool_emoji - emoji = get_tool_emoji(name) - label = preview or name + _started_tool_call_ids.add(tool_call_id) + from agent.display import build_tool_preview, get_tool_emoji + label = build_tool_preview(function_name, function_args) or function_name _stream_q.put(("__tool_progress__", { - "tool": name, - "emoji": emoji, + "tool": function_name, + "emoji": get_tool_emoji(function_name), "label": label, + "toolCallId": tool_call_id, + "status": "running", + })) + + def _on_tool_complete(tool_call_id, function_name, function_args, function_result): + """Emit the matching ``status: completed`` event. + + Dropped if the start was filtered (internal tool, missing + id, or never seen) so clients never get an orphaned + ``completed`` they can't correlate to a prior ``running``. + """ + if not tool_call_id or tool_call_id not in _started_tool_call_ids: + return + _started_tool_call_ids.discard(tool_call_id) + _stream_q.put(("__tool_progress__", { + "tool": function_name, + "toolCallId": tool_call_id, + "status": "completed", })) # Start agent in background. agent_ref is a mutable container # so the SSE writer can interrupt the agent on client disconnect. + # + # ``tool_progress_callback`` is intentionally not wired here: + # it would duplicate every emit because ``run_agent`` fires it + # side-by-side with ``tool_start_callback``/``tool_complete_callback``. + # The structured callbacks are strictly richer (they carry the + # tool_call id), so they own the chat-completions SSE channel. 
agent_ref = [None] agent_task = asyncio.ensure_future(self._run_agent( user_message=user_message, @@ -972,7 +1044,8 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt=system_prompt, session_id=session_id, stream_delta_callback=_on_delta, - tool_progress_callback=_on_tool_progress, + tool_start_callback=_on_tool_start, + tool_complete_callback=_on_tool_complete, agent_ref=agent_ref, )) @@ -1087,7 +1160,8 @@ class APIServerAdapter(BasePlatformAdapter): Tagged tuples ``("__tool_progress__", payload)`` are sent as a custom ``event: hermes.tool.progress`` SSE event so frontends can display them without storing the markers in - conversation history. See #6972. + conversation history. See #6972 for the original event, + #16588 for the ``toolCallId``/``status`` lifecycle fields. """ if isinstance(item, tuple) and len(item) == 2 and item[0] == "__tool_progress__": event_data = json.dumps(item[1]) @@ -2297,10 +2371,31 @@ class APIServerAdapter(BasePlatformAdapter): _MAX_CONCURRENT_RUNS = 10 # Prevent unbounded resource allocation _RUN_STREAM_TTL = 300 # seconds before orphaned runs are swept + _RUN_STATUS_TTL = 3600 # seconds to retain terminal run status for polling + + def _set_run_status(self, run_id: str, status: str, **fields: Any) -> Dict[str, Any]: + """Update pollable run status without exposing private agent objects.""" + now = time.time() + current = self._run_statuses.get(run_id, {}) + current.update({ + "object": "hermes.run", + "run_id": run_id, + "status": status, + "updated_at": now, + }) + current.setdefault("created_at", fields.pop("created_at", now)) + current.update(fields) + self._run_statuses[run_id] = current + return current def _make_run_event_callback(self, run_id: str, loop: "asyncio.AbstractEventLoop"): """Return a tool_progress_callback that pushes structured events to the run's SSE queue.""" def _push(event: Dict[str, Any]) -> None: + self._set_run_status( + run_id, + self._run_statuses.get(run_id, {}).get("status", 
"running"), + last_event=event.get("event"), + ) q = self._run_streams.get(run_id) if q is None: return @@ -2365,28 +2460,6 @@ class APIServerAdapter(BasePlatformAdapter): if not user_message: return web.json_response(_openai_error("No user message found in input"), status=400) - run_id = f"run_{uuid.uuid4().hex}" - loop = asyncio.get_running_loop() - q: "asyncio.Queue[Optional[Dict]]" = asyncio.Queue() - self._run_streams[run_id] = q - self._run_streams_created[run_id] = time.time() - - event_cb = self._make_run_event_callback(run_id, loop) - - # Also wire stream_delta_callback so message.delta events flow through - def _text_cb(delta: Optional[str]) -> None: - if delta is None: - return - try: - loop.call_soon_threadsafe(q.put_nowait, { - "event": "message.delta", - "run_id": run_id, - "timestamp": time.time(), - "delta": delta, - }) - except Exception: - pass - instructions = body.get("instructions") previous_response_id = body.get("previous_response_id") @@ -2434,11 +2507,42 @@ class APIServerAdapter(BasePlatformAdapter): ) conversation_history.append({"role": msg["role"], "content": str(content)}) + run_id = f"run_{uuid.uuid4().hex}" session_id = body.get("session_id") or stored_session_id or run_id ephemeral_system_prompt = instructions + loop = asyncio.get_running_loop() + q: "asyncio.Queue[Optional[Dict]]" = asyncio.Queue() + created_at = time.time() + self._run_streams[run_id] = q + self._run_streams_created[run_id] = created_at + + event_cb = self._make_run_event_callback(run_id, loop) + + # Also wire stream_delta_callback so message.delta events flow through. 
+ def _text_cb(delta: Optional[str]) -> None: + if delta is None: + return + try: + loop.call_soon_threadsafe(q.put_nowait, { + "event": "message.delta", + "run_id": run_id, + "timestamp": time.time(), + "delta": delta, + }) + except Exception: + pass + + self._set_run_status( + run_id, + "queued", + created_at=created_at, + session_id=session_id, + model=body.get("model", self._model_name), + ) async def _run_and_close(): try: + self._set_run_status(run_id, "running") agent = self._create_agent( ephemeral_system_prompt=ephemeral_system_prompt, session_id=session_id, @@ -2468,8 +2572,36 @@ class APIServerAdapter(BasePlatformAdapter): "output": final_response, "usage": usage, }) + self._set_run_status( + run_id, + "completed", + output=final_response, + usage=usage, + last_event="run.completed", + ) + except asyncio.CancelledError: + self._set_run_status( + run_id, + "cancelled", + last_event="run.cancelled", + ) + try: + q.put_nowait({ + "event": "run.cancelled", + "run_id": run_id, + "timestamp": time.time(), + }) + except Exception: + pass + raise except Exception as exc: logger.exception("[api_server] run %s failed", run_id) + self._set_run_status( + run_id, + "failed", + error=str(exc), + last_event="run.failed", + ) try: q.put_nowait({ "event": "run.failed", @@ -2499,6 +2631,21 @@ class APIServerAdapter(BasePlatformAdapter): return web.json_response({"run_id": run_id, "status": "started"}, status=202) + async def _handle_get_run(self, request: "web.Request") -> "web.Response": + """GET /v1/runs/{run_id} — return pollable run status for external UIs.""" + auth_err = self._check_auth(request) + if auth_err: + return auth_err + + run_id = request.match_info["run_id"] + status = self._run_statuses.get(run_id) + if status is None: + return web.json_response( + _openai_error(f"Run not found: {run_id}", code="run_not_found"), + status=404, + ) + return web.json_response(status) + async def _handle_run_events(self, request: "web.Request") -> "web.StreamResponse": 
"""GET /v1/runs/{run_id}/events — SSE stream of structured agent lifecycle events.""" auth_err = self._check_auth(request) @@ -2561,6 +2708,8 @@ class APIServerAdapter(BasePlatformAdapter): if agent is None and task is None: return web.json_response(_openai_error(f"Run not found: {run_id}", code="run_not_found"), status=404) + self._set_run_status(run_id, "stopping", last_event="run.stopping") + if agent is not None: try: agent.interrupt("Stop requested via API") @@ -2603,6 +2752,15 @@ class APIServerAdapter(BasePlatformAdapter): self._active_run_agents.pop(run_id, None) self._active_run_tasks.pop(run_id, None) + stale_statuses = [ + run_id + for run_id, status in list(self._run_statuses.items()) + if status.get("status") in {"completed", "failed", "cancelled"} + and now - float(status.get("updated_at", 0) or 0) > self._RUN_STATUS_TTL + ] + for run_id in stale_statuses: + self._run_statuses.pop(run_id, None) + # ------------------------------------------------------------------ # BasePlatformAdapter interface # ------------------------------------------------------------------ @@ -2621,6 +2779,7 @@ class APIServerAdapter(BasePlatformAdapter): self._app.router.add_get("/health/detailed", self._handle_health_detailed) self._app.router.add_get("/v1/health", self._handle_health) self._app.router.add_get("/v1/models", self._handle_models) + self._app.router.add_get("/v1/capabilities", self._handle_capabilities) self._app.router.add_post("/v1/chat/completions", self._handle_chat_completions) self._app.router.add_post("/v1/responses", self._handle_responses) self._app.router.add_get("/v1/responses/{response_id}", self._handle_get_response) @@ -2636,6 +2795,7 @@ class APIServerAdapter(BasePlatformAdapter): self._app.router.add_post("/api/jobs/{job_id}/run", self._handle_run_job) # Structured event streaming self._app.router.add_post("/v1/runs", self._handle_runs) + self._app.router.add_get("/v1/runs/{run_id}", self._handle_get_run) 
self._app.router.add_get("/v1/runs/{run_id}/events", self._handle_run_events) self._app.router.add_post("/v1/runs/{run_id}/stop", self._handle_stop_run) # Start background sweep to clean up orphaned (unconsumed) run streams diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index a06b6fa711..da992792e3 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -23,6 +23,45 @@ from utils import normalize_proxy_url logger = logging.getLogger(__name__) +# Audio file extensions Hermes recognizes for native audio delivery. +# Kept in sync with tools/send_message_tool.py and cron/scheduler.py via +# should_send_media_as_audio() below. +_AUDIO_EXTS = frozenset({'.ogg', '.opus', '.mp3', '.wav', '.m4a', '.flac'}) +# Telegram's Bot API sendAudio only accepts MP3 / M4A. Other audio +# formats either need to go through sendVoice (Opus/OGG) or must be +# delivered as a regular document. +_TELEGRAM_AUDIO_ATTACHMENT_EXTS = frozenset({'.mp3', '.m4a'}) +_TELEGRAM_VOICE_EXTS = frozenset({'.ogg', '.opus'}) + + +def _platform_name(platform) -> str: + """Normalize a Platform enum / raw string into a lowercase name.""" + value = getattr(platform, "value", platform) + return str(value or "").lower() + + +def should_send_media_as_audio(platform, ext: str, is_voice: bool = False) -> bool: + """Return True when a media file should use the platform's audio sender. + + Other platforms: every recognized audio extension routes through the + audio sender. + + Telegram: the Bot API only accepts MP3/M4A for sendAudio and + Opus/OGG for sendVoice. Opus/OGG is only routed as audio when the + caller flagged ``is_voice=True`` (so we don't turn a regular audio + attachment into a voice bubble just because the file happens to be + Opus). Everything else falls through to document delivery by + returning ``False``. 
+ """ + normalized_ext = (ext or "").lower() + if normalized_ext not in _AUDIO_EXTS: + return False + if _platform_name(platform) == "telegram": + if normalized_ext in _TELEGRAM_VOICE_EXTS: + return is_voice + return normalized_ext in _TELEGRAM_AUDIO_ATTACHMENT_EXTS + return True + def utf16_len(s: str) -> int: """Count UTF-16 code units in *s*. @@ -1415,6 +1454,41 @@ class BasePlatformAdapter(ABC): """ return False + async def send_slash_confirm( + self, + chat_id: str, + title: str, + message: str, + session_key: str, + confirm_id: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a three-option slash-command confirmation prompt. + + Used by the gateway's generic slash-confirm primitive (see + ``GatewayRunner._request_slash_confirm``) for commands that have a + non-destructive but expensive side effect the user should explicitly + acknowledge — the current caller is ``/reload-mcp``, which + invalidates the provider prompt cache. + + Platforms with inline-button support (Telegram, Discord, Slack, + Matrix, Feishu) should override this to render three buttons: + Approve Once / Always Approve / Cancel. Button callbacks MUST be + routed back through the gateway by calling + ``GatewayRunner._resolve_slash_confirm(confirm_id, choice)`` where + ``choice`` is ``"once"`` / ``"always"`` / ``"cancel"``. + + Platforms without button UIs leave this as the default and fall + through to the gateway's text fallback (which sends ``message`` as + plain text and intercepts the next ``/approve`` / ``/always`` / + ``/cancel`` reply). + + ``confirm_id`` is a short string generated by the gateway; the + adapter stores it alongside any platform-specific state needed to + route the callback (e.g. Telegram's ``_approval_state`` dict). + """ + return SendResult(success=False, error="Not supported") + async def send_typing(self, chat_id: str, metadata=None) -> None: """ Send a typing indicator. 
@@ -1640,7 +1714,7 @@ class BasePlatformAdapter(ABC): # Extract MEDIA: tags, allowing optional whitespace after the colon # and quoted/backticked paths for LLM-formatted outputs. media_pattern = re.compile( - r'''[`"']?MEDIA:\s*(?P`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|(?:~/|/)\S+(?:[^\S\n]+\S+)*?\.(?:png|jpe?g|gif|webp|mp4|mov|avi|mkv|webm|ogg|opus|mp3|wav|m4a|epub|pdf|zip|rar|7z|docx?|xlsx?|pptx?|txt|csv|apk|ipa)(?=[\s`"',;:)\]}]|$)|\S+)[`"']?''' + r'''[`"']?MEDIA:\s*(?P`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|(?:~/|/)\S+(?:[^\S\n]+\S+)*?\.(?:png|jpe?g|gif|webp|mp4|mov|avi|mkv|webm|ogg|opus|mp3|wav|m4a|flac|epub|pdf|zip|rar|7z|docx?|xlsx?|pptx?|txt|csv|apk|ipa)(?=[\s`"',;:)\]}]|$)|\S+)[`"']?''' ) for match in media_pattern.finditer(content): path = match.group("path").strip() @@ -1780,11 +1854,19 @@ class BasePlatformAdapter(ABC): if stop_event is None: await asyncio.sleep(interval) continue - try: - await asyncio.wait_for(stop_event.wait(), timeout=interval) - except asyncio.TimeoutError: - continue - return + loop = asyncio.get_running_loop() + deadline = loop.time() + interval + while not stop_event.is_set(): + remaining = deadline - loop.time() + if remaining <= 0: + break + # Poll instead of wait_for(stop_event.wait()). Cancelling + # wait_for while it owns the inner Event.wait task can leave + # shutdown paths stuck awaiting the typing task on Python + # 3.11/pytest-asyncio; sleep cancellation is immediate. + await asyncio.sleep(min(0.25, remaining)) + if stop_event.is_set(): + return except asyncio.CancelledError: pass # Normal cancellation when handler completes finally: @@ -2117,6 +2199,12 @@ class BasePlatformAdapter(ABC): ``release_guard=False`` keeps the adapter-level session guard in place so reset-like commands can finish atomically before follow-up messages are allowed to start a fresh background task. + + Bounded by a 5s timeout so a wedged finally block in the cancelled + task (typing-task cleanup, on_processing_complete hook, etc.) 
can't + stall the calling dispatch coroutine — particularly under pytest- + asyncio where the event loop's cancellation-propagation semantics + differ subtly from a bare ``asyncio.run`` harness. """ task = self._session_tasks.pop(session_key, None) if task is not None and not task.done(): @@ -2128,9 +2216,15 @@ class BasePlatformAdapter(ABC): self._expected_cancelled_tasks.add(task) task.cancel() try: - await task + await asyncio.wait_for(asyncio.shield(task), timeout=5.0) except asyncio.CancelledError: pass + except asyncio.TimeoutError: + logger.warning( + "[%s] Cancelled task for %s did not exit within 5s; " + "unblocking dispatch and letting the task unwind in the background", + self.name, session_key, + ) except Exception: logger.debug( "[%s] Session cancellation raised while unwinding %s", @@ -2382,6 +2476,16 @@ class BasePlatformAdapter(ABC): **_keep_typing_kwargs, ) ) + + async def _stop_typing_task() -> None: + typing_task.cancel() + try: + await asyncio.wait_for(asyncio.shield(typing_task), timeout=0.5) + except (asyncio.CancelledError, asyncio.TimeoutError): + # Cancellation cleanup must not block adapter shutdown. The + # typing task is already cancelled; if the parent task is also + # cancelling, let this message-processing task unwind now. 
+ pass try: await self._run_processing_hook("on_processing_start", event) @@ -2514,7 +2618,6 @@ class BasePlatformAdapter(ABC): logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True) # Send extracted media files — route by file type - _AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'} _VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'} _IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'} @@ -2523,7 +2626,7 @@ class BasePlatformAdapter(ABC): await asyncio.sleep(human_delay) try: ext = Path(media_path).suffix.lower() - if ext in _AUDIO_EXTS: + if should_send_media_as_audio(self.platform, ext, is_voice=is_voice): media_result = await self.send_voice( chat_id=event.source.chat_id, audio_path=media_path, @@ -2604,11 +2707,7 @@ class BasePlatformAdapter(ABC): _active = self._active_sessions.get(session_key) if _active is not None: _active.clear() - typing_task.cancel() - try: - await typing_task - except asyncio.CancelledError: - pass + await _stop_typing_task() # Process pending message in new background task await self._process_message_background(pending_event, session_key) return # Already cleaned up @@ -2656,11 +2755,7 @@ class BasePlatformAdapter(ABC): except Exception: pass # Stop typing indicator - typing_task.cancel() - try: - await typing_task - except asyncio.CancelledError: - pass + await _stop_typing_task() # Also cancel any platform-level persistent typing tasks (e.g. Discord) # that may have been recreated by _keep_typing after the last stop_typing() try: @@ -2713,6 +2808,11 @@ class BasePlatformAdapter(ABC): Used during gateway shutdown/replacement so active sessions from the old process do not keep running after adapters are being torn down. + + Each cancelled task is awaited with a 5s bound so a wedged finally + (typing-task cleanup, on_processing_complete hook) can't stall the + whole shutdown path. Stragglers are released from our tracking and + allowed to finish unwinding on their own. 
""" # Loop until no new tasks appear. Without this, a message # arriving during the `await asyncio.gather` below would spawn @@ -2731,7 +2831,21 @@ class BasePlatformAdapter(ABC): for task in tasks: self._expected_cancelled_tasks.add(task) task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + try: + await asyncio.wait_for( + asyncio.gather( + *(asyncio.shield(t) for t in tasks), + return_exceptions=True, + ), + timeout=5.0, + ) + except asyncio.TimeoutError: + logger.warning( + "[%s] %d background task(s) did not exit within 5s; " + "releasing tracking and letting them unwind in the background", + self.name, len([t for t in tasks if not t.done()]), + ) + break # Loop: late-arrival tasks spawned during the gather above # will be in self._background_tasks now. Re-check. self._background_tasks.clear() diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 2e08d77c68..fb691ec535 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -2270,6 +2270,10 @@ class DiscordAdapter(BasePlatformAdapter): async def slash_reload_mcp(interaction: discord.Interaction): await self._run_simple_slash(interaction, "/reload-mcp") + @tree.command(name="reload-skills", description="Re-scan ~/.hermes/skills/ for new or removed skills") + async def slash_reload_skills(interaction: discord.Interaction): + await self._run_simple_slash(interaction, "/reload-skills") + @tree.command(name="voice", description="Toggle voice reply mode") @discord.app_commands.describe(mode="Voice mode: on, off, tts, channel, leave, or status") @discord.app_commands.choices(mode=[ @@ -2906,6 +2910,43 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: return SendResult(success=False, error=str(e)) + async def send_slash_confirm( + self, chat_id: str, title: str, message: str, session_key: str, + confirm_id: str, metadata: Optional[dict] = None, + ) -> SendResult: + """Send a three-button slash-command confirmation prompt.""" + if not 
self._client or not DISCORD_AVAILABLE: + return SendResult(success=False, error="Not connected") + + try: + target_id = chat_id + if metadata and metadata.get("thread_id"): + target_id = metadata["thread_id"] + + channel = self._client.get_channel(int(target_id)) + if not channel: + channel = await self._client.fetch_channel(int(target_id)) + + # Embed description limit is 4096; message usually fits easily. + max_desc = 4088 + body = message if len(message) <= max_desc else message[: max_desc - 3] + "..." + embed = discord.Embed( + title=title or "Confirm", + description=body, + color=discord.Color.orange(), + ) + + view = SlashConfirmView( + session_key=session_key, + confirm_id=confirm_id, + allowed_user_ids=self._allowed_user_ids, + ) + + msg = await channel.send(embed=embed, view=view) + return SendResult(success=True, message_id=str(msg.id)) + except Exception as e: + return SendResult(success=False, error=str(e)) + async def send_update_prompt( self, chat_id: str, prompt: str, default: str = "", session_key: str = "", @@ -3639,6 +3680,103 @@ if DISCORD_AVAILABLE: for child in self.children: child.disabled = True + class SlashConfirmView(discord.ui.View): + """Three-button view for generic slash-command confirmations. + + Used by ``/reload-mcp`` and any future slash command routed through + ``GatewayRunner._request_slash_confirm``. Buttons map to the + gateway's three choices: + + * "Approve Once" → ``choice="once"`` + * "Always Approve" → ``choice="always"`` + * "Cancel" → ``choice="cancel"`` + + Clicking calls the module-level + ``tools.slash_confirm.resolve(session_key, confirm_id, choice)`` + which runs the handler the runner stored for this ``session_key``. + Only users in the adapter's allowlist can click. Times out after + 5 minutes (matches the gateway primitive's timeout). 
+ """ + + def __init__(self, session_key: str, confirm_id: str, allowed_user_ids: set): + super().__init__(timeout=300) + self.session_key = session_key + self.confirm_id = confirm_id + self.allowed_user_ids = allowed_user_ids + self.resolved = False + + def _check_auth(self, interaction: discord.Interaction) -> bool: + if not self.allowed_user_ids: + return True + return str(interaction.user.id) in self.allowed_user_ids + + async def _resolve( + self, interaction: discord.Interaction, choice: str, + color: discord.Color, label: str, + ): + if self.resolved: + await interaction.response.send_message( + "This prompt has already been resolved~", ephemeral=True, + ) + return + if not self._check_auth(interaction): + await interaction.response.send_message( + "You're not authorized to answer this prompt~", ephemeral=True, + ) + return + + self.resolved = True + + embed = interaction.message.embeds[0] if interaction.message.embeds else None + if embed: + embed.color = color + embed.set_footer(text=f"{label} by {interaction.user.display_name}") + + for child in self.children: + child.disabled = True + + await interaction.response.edit_message(embed=embed, view=self) + + # Resolve via the module-level primitive. If the handler + # returns a follow-up message, post it in the same channel. 
+ try: + from tools import slash_confirm as _slash_confirm_mod + result_text = await _slash_confirm_mod.resolve( + self.session_key, self.confirm_id, choice, + ) + if result_text: + await interaction.followup.send(result_text) + logger.info( + "Discord button resolved slash-confirm for session %s " + "(choice=%s, user=%s)", + self.session_key, choice, interaction.user.display_name, + ) + except Exception as exc: + logger.error("Discord slash-confirm resolve failed: %s", exc, exc_info=True) + + @discord.ui.button(label="Approve Once", style=discord.ButtonStyle.green) + async def approve_once( + self, interaction: discord.Interaction, button: discord.ui.Button, + ): + await self._resolve(interaction, "once", discord.Color.green(), "Approved once") + + @discord.ui.button(label="Always Approve", style=discord.ButtonStyle.blurple) + async def approve_always( + self, interaction: discord.Interaction, button: discord.ui.Button, + ): + await self._resolve(interaction, "always", discord.Color.purple(), "Always approved") + + @discord.ui.button(label="Cancel", style=discord.ButtonStyle.red) + async def cancel( + self, interaction: discord.Interaction, button: discord.ui.Button, + ): + await self._resolve(interaction, "cancel", discord.Color.greyple(), "Cancelled") + + async def on_timeout(self): + self.resolved = True + for child in self.children: + child.disabled = True + class UpdatePromptView(discord.ui.View): """Interactive Yes/No buttons for ``hermes update`` prompts. diff --git a/gateway/platforms/qqbot/adapter.py b/gateway/platforms/qqbot/adapter.py index 28a297944d..10e1f62e72 100644 --- a/gateway/platforms/qqbot/adapter.py +++ b/gateway/platforms/qqbot/adapter.py @@ -976,6 +976,18 @@ class QQAdapter(BasePlatformAdapter): if not channel_id: return + # Apply group_policy ACL — guild channels are group-like contexts. + # Without this check any member of any guild the bot is in could + # bypass the configured allowlist. 
+ guild_id = str(d.get("guild_id", "")) + author_id = str(author.get("id", "")) + if not self._is_group_allowed(guild_id or channel_id, author_id): + logger.debug( + "[%s] Guild message blocked by ACL: channel=%s user=%s", + self._log_tag, channel_id, author_id, + ) + return + member = d.get("member") if isinstance(d.get("member"), dict) else {} nick = str(member.get("nick", "")) or str(author.get("username", "")) @@ -1032,6 +1044,17 @@ class QQAdapter(BasePlatformAdapter): if not guild_id: return + # Apply dm_policy ACL — guild DMs were previously unauthenticated. + # Without this check any member of any guild the bot is in could + # bypass the configured allowlist via direct messages. + author_id = str(author.get("id", "")) + if not self._is_dm_allowed(author_id): + logger.debug( + "[%s] Guild DM blocked by ACL: guild=%s user=%s", + self._log_tag, guild_id, author_id, + ) + return + text = content att_result = await self._process_attachments(d.get("attachments")) image_urls = att_result["image_urls"] diff --git a/gateway/platforms/signal.py b/gateway/platforms/signal.py index 9a0a6256a4..3dd1e349bd 100644 --- a/gateway/platforms/signal.py +++ b/gateway/platforms/signal.py @@ -31,6 +31,7 @@ from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + ProcessingOutcome, SendResult, cache_image_from_bytes, cache_audio_from_bytes, @@ -162,6 +163,10 @@ class SignalAdapter(BasePlatformAdapter): """Signal messenger adapter using signal-cli HTTP daemon.""" platform = Platform.SIGNAL + # Signal has no real edit API for already-sent messages. Mark it explicitly + # so streaming suppresses the visible cursor instead of leaving a stale tofu + # square behind in chat clients when edit attempts fail. 
+ SUPPORTS_MESSAGE_EDITING = False def __init__(self, config: PlatformConfig): super().__init__(config, Platform.SIGNAL) @@ -488,6 +493,11 @@ class SignalAdapter(BasePlatformAdapter): if text and mentions: text = _render_mentions(text, mentions) + # Extract quote (reply-to) context from Signal dataMessage + quote_data = data_message.get("quote") or {} + reply_to_id = str(quote_data.get("id")) if quote_data.get("id") else None + reply_to_text = quote_data.get("text") + # Process attachments attachments_data = data_message.get("attachments", []) media_urls = [] @@ -541,7 +551,9 @@ class SignalAdapter(BasePlatformAdapter): else: timestamp = datetime.now(tz=timezone.utc) - # Build and dispatch event + # Build and dispatch event. + # Store raw envelope data in raw_message so on_processing_start/complete + # can extract targetAuthor + targetTimestamp for sendReaction. event = MessageEvent( source=source, text=text or "", @@ -549,6 +561,9 @@ class SignalAdapter(BasePlatformAdapter): media_urls=media_urls, media_types=media_types, timestamp=timestamp, + raw_message={"sender": sender, "timestamp_ms": ts_ms}, + reply_to_message_id=reply_to_id, + reply_to_text=reply_to_text, ) logger.debug("Signal: message from %s in %s: %s", @@ -707,6 +722,159 @@ class SignalAdapter(BasePlatformAdapter): logger.debug("Signal RPC %s failed: %s", method, e) return None + # ------------------------------------------------------------------ + # Formatting — markdown → Signal body ranges + # ------------------------------------------------------------------ + + @staticmethod + def _markdown_to_signal(text: str) -> tuple: + """Convert markdown to plain text + Signal textStyles list. + + Signal doesn't render markdown. Instead it uses ``bodyRanges`` + (exposed by signal-cli as ``textStyle`` / ``textStyles`` params) + with the format ``start:length:STYLE``. + + Positions are measured in **UTF-16 code units** (not Python code + points) because that's what the Signal protocol uses. 
+ + Supported styles: BOLD, ITALIC, STRIKETHROUGH, MONOSPACE. + (Signal's SPOILER style is not currently mapped — no standard + markdown syntax for it; would need ``||spoiler||`` parsing.) + + Returns ``(plain_text, styles_list)`` where *styles_list* may be + empty if there's nothing to format. + """ + import re + + def _utf16_len(s: str) -> int: + """Length of *s* in UTF-16 code units.""" + return len(s.encode("utf-16-le")) // 2 + + # Pre-process: normalize whitespace before any position tracking + # so later operations don't invalidate recorded offsets. + text = re.sub(r"\n{3,}", "\n\n", text) + text = text.strip() + + styles: list = [] + + # --- Phase 1: fenced code blocks ```...``` → MONOSPACE --- + _CB = re.compile(r"```[a-zA-Z0-9_+-]*\n?(.*?)```", re.DOTALL) + while m := _CB.search(text): + inner = m.group(1).rstrip("\n") + start = m.start() + text = text[: m.start()] + inner + text[m.end() :] + styles.append((start, len(inner), "MONOSPACE")) + + # --- Phase 2: heading markers # Foo → Foo (BOLD) --- + _HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE) + new_text = "" + last_end = 0 + for m in _HEADING.finditer(text): + new_text += text[last_end : m.start()] + last_end = m.end() + eol = text.find("\n", m.end()) + if eol == -1: + eol = len(text) + heading_text = text[m.end() : eol] + start = len(new_text) + new_text += heading_text + styles.append((start, len(heading_text), "BOLD")) + last_end = eol + new_text += text[last_end:] + text = new_text + + # --- Phase 3: inline patterns (single-pass to avoid offset drift) --- + # The old code processed each pattern sequentially, stripping markers + # and recording positions per-pass. Later passes shifted text without + # adjusting earlier positions → bold/italic landed mid-word. + # + # Fix: collect ALL non-overlapping matches first, then strip every + # marker in one pass so positions are computed against the final text. 
+ _PATTERNS = [ + (re.compile(r"\*\*(.+?)\*\*", re.DOTALL), "BOLD"), + (re.compile(r"__(.+?)__", re.DOTALL), "BOLD"), + (re.compile(r"~~(.+?)~~", re.DOTALL), "STRIKETHROUGH"), + (re.compile(r"`(.+?)`"), "MONOSPACE"), + (re.compile(r"(? os for os, oe in occupied): + all_matches.append((ms, me, m.start(1), m.end(1), style)) + occupied.append((ms, me)) + all_matches.sort() + + # Build removal list so we can adjust Phase 1/2 styles. + # Each match removes its prefix markers (start..g1_start) and + # suffix markers (g1_end..end). + removals: list = [] # (position, length) sorted + for ms, me, g1s, g1e, _ in all_matches: + if g1s > ms: + removals.append((ms, g1s - ms)) + if me > g1e: + removals.append((g1e, me - g1e)) + removals.sort() + + # Adjust Phase 1/2 styles for characters about to be removed. + def _adj(pos: int) -> int: + shift = 0 + for rp, rl in removals: + if rp < pos: + shift += min(rl, pos - rp) + else: + break + return pos - shift + + adjusted_prior: list = [] + for s, l, st in styles: + ns = _adj(s) + ne = _adj(s + l) + if ne > ns: + adjusted_prior.append((ns, ne - ns, st)) + + # Strip all inline markers in one pass → positions are correct. 
+ result = "" + last_end = 0 + inline_styles: list = [] + for ms, me, g1s, g1e, sty in all_matches: + result += text[last_end:ms] + pos = len(result) + inner = text[g1s:g1e] + result += inner + inline_styles.append((pos, len(inner), sty)) + last_end = me + result += text[last_end:] + text = result + + styles = adjusted_prior + inline_styles + + # Convert code-point offsets → UTF-16 code-unit offsets + style_strings = [] + for cp_start, cp_len, stype in sorted(styles): + # Safety: skip any out-of-bounds styles + if cp_start < 0 or cp_start + cp_len > len(text): + continue + u16_start = _utf16_len(text[:cp_start]) + u16_len = _utf16_len(text[cp_start : cp_start + cp_len]) + style_strings.append(f"{u16_start}:{u16_len}:{stype}") + + return text, style_strings + + def format_message(self, content: str) -> str: + """Strip markdown for plain-text fallback (used by base class). + + The actual rich formatting happens in send() via _markdown_to_signal(). + """ + # This is only called if someone uses the base-class send path. + # Our send() override bypasses this entirely. 
+ return content + # ------------------------------------------------------------------ # Sending # ------------------------------------------------------------------ @@ -718,14 +886,22 @@ class SignalAdapter(BasePlatformAdapter): reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: - """Send a text message.""" + """Send a text message with native Signal formatting.""" await self._stop_typing_indicator(chat_id) + plain_text, text_styles = self._markdown_to_signal(content) + params: Dict[str, Any] = { "account": self.account, - "message": content, + "message": plain_text, } + if text_styles: + if len(text_styles) == 1: + params["textStyle"] = text_styles[0] + else: + params["textStyles"] = text_styles + if chat_id.startswith("group:"): params["groupId"] = chat_id[6:] else: @@ -735,11 +911,10 @@ class SignalAdapter(BasePlatformAdapter): if result is not None: self._track_sent_timestamp(result) - # Use the timestamp from the RPC result as a pseudo message_id. - # Signal doesn't have real message IDs, but the stream consumer - # needs a truthy value to follow its edit→fallback path correctly. - _msg_id = str(result.get("timestamp", "")) if isinstance(result, dict) else None - return SendResult(success=True, message_id=_msg_id or None) + # Signal has no editable message identifier. Returning None keeps the + # stream consumer on the non-edit fallback path instead of pretending + # future edits can remove an in-progress cursor from the chat thread. 
+ return SendResult(success=True, message_id=None) return SendResult(success=False, error="RPC send failed") def _track_sent_timestamp(self, rpc_result) -> None: @@ -963,6 +1138,110 @@ class SignalAdapter(BasePlatformAdapter): _keep_typing finally block to clean up platform-level typing tasks.""" await self._stop_typing_indicator(chat_id) + # ------------------------------------------------------------------ + # Reactions + # ------------------------------------------------------------------ + + async def send_reaction( + self, + chat_id: str, + emoji: str, + target_author: str, + target_timestamp: int, + ) -> bool: + """Send a reaction emoji to a specific message via signal-cli RPC. + + Args: + chat_id: The chat (phone number or "group:") + emoji: Reaction emoji string (e.g. "👀", "✅") + target_author: Phone number / UUID of the message author + target_timestamp: Signal timestamp (ms) of the message to react to + """ + params: Dict[str, Any] = { + "account": self.account, + "emoji": emoji, + "targetAuthor": target_author, + "targetTimestamp": target_timestamp, + } + + if chat_id.startswith("group:"): + params["groupId"] = chat_id[6:] + else: + params["recipient"] = [chat_id] + + result = await self._rpc("sendReaction", params) + if result is not None: + return True + logger.debug("Signal: sendReaction failed (chat=%s, emoji=%s)", chat_id[:20], emoji) + return False + + async def remove_reaction( + self, + chat_id: str, + target_author: str, + target_timestamp: int, + ) -> bool: + """Remove a reaction by sending an empty-string emoji.""" + params: Dict[str, Any] = { + "account": self.account, + "emoji": "", + "targetAuthor": target_author, + "targetTimestamp": target_timestamp, + "remove": True, + } + + if chat_id.startswith("group:"): + params["groupId"] = chat_id[6:] + else: + params["recipient"] = [chat_id] + + result = await self._rpc("sendReaction", params) + return result is not None + + # ------------------------------------------------------------------ + # 
Processing Lifecycle Hooks (reactions as progress indicators) + # ------------------------------------------------------------------ + + def _extract_reaction_target(self, event: MessageEvent) -> Optional[tuple]: + """Extract (target_author, target_timestamp) from a MessageEvent. + + Returns None if the event doesn't carry the raw Signal envelope data + needed for sendReaction. + """ + raw = event.raw_message + if not isinstance(raw, dict): + return None + author = raw.get("sender") + ts = raw.get("timestamp_ms") + if not author or not ts: + return None + return (author, ts) + + async def on_processing_start(self, event: MessageEvent) -> None: + """React with 👀 when processing begins.""" + target = self._extract_reaction_target(event) + if target: + await self.send_reaction(event.source.chat_id, "👀", *target) + + async def on_processing_complete(self, event: MessageEvent, outcome: "ProcessingOutcome") -> None: + """Swap the 👀 reaction for ✅ (success) or ❌ (failure). + + On CANCELLED we leave the 👀 in place — no terminal outcome means + the reaction should keep reflecting "in progress" (matches Telegram). 
+ """ + if outcome == ProcessingOutcome.CANCELLED: + return + target = self._extract_reaction_target(event) + if not target: + return + chat_id = event.source.chat_id + # Remove the in-progress reaction, then add the final one + await self.remove_reaction(chat_id, *target) + if outcome == ProcessingOutcome.SUCCESS: + await self.send_reaction(chat_id, "✅", *target) + elif outcome == ProcessingOutcome.FAILURE: + await self.send_reaction(chat_id, "❌", *target) + # ------------------------------------------------------------------ # Chat Info # ------------------------------------------------------------------ diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index ea75130a9a..e18594f564 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -514,6 +514,15 @@ class SlackAdapter(BasePlatformAdapter): ): self._app.action(_action_id)(self._handle_approval_action) + # Register Block Kit action handlers for slash-confirm buttons + # (generic three-option prompts; see tools/slash_confirm.py). + for _action_id in ( + "hermes_confirm_once", + "hermes_confirm_always", + "hermes_confirm_cancel", + ): + self._app.action(_action_id)(self._handle_slash_confirm_action) + # Start Socket Mode handler in background self._handler = AsyncSocketModeHandler(self._app, app_token, proxy=proxy_url) _apply_slack_proxy(self._handler.client, proxy_url) @@ -1931,6 +1940,168 @@ class SlackAdapter(BasePlatformAdapter): logger.error("[Slack] send_exec_approval failed: %s", e, exc_info=True) return SendResult(success=False, error=str(e)) + async def send_slash_confirm( + self, chat_id: str, title: str, message: str, session_key: str, + confirm_id: str, metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a Block Kit three-option slash-command confirmation prompt.""" + if not self._app: + return SendResult(success=False, error="Not connected") + + try: + body = message[:2900] + "..." 
if len(message) > 2900 else message + thread_ts = self._resolve_thread_ts(None, metadata) + # Encode session_key and confirm_id into the button value so the + # callback handler can resolve without extra bookkeeping. + value = f"{session_key}|{confirm_id}" + + blocks = [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"*{title or 'Confirm'}*\n\n{body}", + }, + }, + { + "type": "actions", + "elements": [ + { + "type": "button", + "text": {"type": "plain_text", "text": "Approve Once"}, + "style": "primary", + "action_id": "hermes_confirm_once", + "value": value, + }, + { + "type": "button", + "text": {"type": "plain_text", "text": "Always Approve"}, + "action_id": "hermes_confirm_always", + "value": value, + }, + { + "type": "button", + "text": {"type": "plain_text", "text": "Cancel"}, + "style": "danger", + "action_id": "hermes_confirm_cancel", + "value": value, + }, + ], + }, + ] + + kwargs: Dict[str, Any] = { + "channel": chat_id, + "text": f"{title or 'Confirm'}: {body[:100]}", + "blocks": blocks, + } + if thread_ts: + kwargs["thread_ts"] = thread_ts + + result = await self._get_client(chat_id).chat_postMessage(**kwargs) + return SendResult(success=True, message_id=result.get("ts", ""), raw_response=result) + except Exception as e: + logger.error("[Slack] send_slash_confirm failed: %s", e, exc_info=True) + return SendResult(success=False, error=str(e)) + + async def _handle_slash_confirm_action(self, ack, body, action) -> None: + """Handle a slash-confirm button click from Block Kit.""" + await ack() + + action_id = action.get("action_id", "") + value = action.get("value", "") + message = body.get("message", {}) + msg_ts = message.get("ts", "") + channel_id = body.get("channel", {}).get("id", "") + user_name = body.get("user", {}).get("name", "unknown") + user_id = body.get("user", {}).get("id", "") + + # Authorization — reuse the exec-approval allowlist. 
+ allowed_csv = os.getenv("SLACK_ALLOWED_USERS", "").strip() + if allowed_csv: + allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()} + if "*" not in allowed_ids and user_id not in allowed_ids: + logger.warning( + "[Slack] Unauthorized slash-confirm click by %s (%s) — ignoring", + user_name, user_id, + ) + return + + # Parse session_key|confirm_id back out + if "|" not in value: + logger.warning("[Slack] Malformed slash-confirm value: %s", value) + return + session_key, confirm_id = value.split("|", 1) + + choice_map = { + "hermes_confirm_once": "once", + "hermes_confirm_always": "always", + "hermes_confirm_cancel": "cancel", + } + choice = choice_map.get(action_id, "cancel") + + label_map = { + "once": f"✅ Approved once by {user_name}", + "always": f"🔒 Always approved by {user_name}", + "cancel": f"❌ Cancelled by {user_name}", + } + decision_text = label_map.get(choice, f"Resolved by {user_name}") + + # Pull original prompt body out of the section block so we can show + # the decision inline without losing context. + original_text = "" + for block in message.get("blocks", []): + if block.get("type") == "section": + original_text = block.get("text", {}).get("text", "") + break + + updated_blocks = [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": original_text or "Confirmation prompt", + }, + }, + { + "type": "context", + "elements": [ + {"type": "mrkdwn", "text": decision_text}, + ], + }, + ] + + try: + await self._get_client(channel_id).chat_update( + channel=channel_id, + ts=msg_ts, + text=decision_text, + blocks=updated_blocks, + ) + except Exception as e: + logger.warning("[Slack] Failed to update slash-confirm message: %s", e) + + # Resolve via the module-level primitive and post any follow-up. 
+ try: + from tools import slash_confirm as _slash_confirm_mod + result_text = await _slash_confirm_mod.resolve(session_key, confirm_id, choice) + if result_text: + post_kwargs: Dict[str, Any] = { + "channel": channel_id, + "text": result_text, + } + # Inherit the thread so the reply stays in the same place. + thread_ts = message.get("thread_ts") or msg_ts + if thread_ts: + post_kwargs["thread_ts"] = thread_ts + await self._get_client(channel_id).chat_postMessage(**post_kwargs) + logger.info( + "Slack button resolved slash-confirm for session %s (choice=%s, user=%s)", + session_key, choice, user_name, + ) + except Exception as exc: + logger.error("Failed to resolve slash-confirm from Slack button: %s", exc, exc_info=True) + async def _handle_approval_action(self, ack, body, action) -> None: """Handle an approval button click from Block Kit.""" await ack() diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 09a70ccf51..b58ca45ec9 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -237,14 +237,14 @@ def _wrap_markdown_tables(text: str) -> str: class TelegramAdapter(BasePlatformAdapter): """ Telegram bot adapter. - + Handles: - Receiving messages from users and groups - Sending responses with Telegram markdown - Forum topics (thread_id support) - Media messages """ - + # Telegram message limits MAX_MESSAGE_LENGTH = 4096 # Threshold for detecting Telegram client-side message splits. 
@@ -252,7 +252,7 @@ class TelegramAdapter(BasePlatformAdapter): _SPLIT_THRESHOLD = 4000 MEDIA_GROUP_WAIT_SECONDS = 0.8 _GENERAL_TOPIC_THREAD_ID = "1" - + def __init__(self, config: PlatformConfig): super().__init__(config, Platform.TELEGRAM) self._app: Optional[Application] = None @@ -286,6 +286,9 @@ class TelegramAdapter(BasePlatformAdapter): self._model_picker_state: Dict[str, dict] = {} # Approval button state: message_id → session_key self._approval_state: Dict[int, str] = {} + # Slash-confirm button state: confirm_id → session_key (for /reload-mcp + # and any other slash-confirm prompts; see GatewayRunner._request_slash_confirm). + self._slash_confirm_state: Dict[str, str] = {} @staticmethod def _is_callback_user_authorized(user_id: str) -> bool: @@ -994,7 +997,7 @@ class TelegramAdapter(BasePlatformAdapter): self._set_fatal_error("telegram_connect_error", message, retryable=True) logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True) return False - + async def disconnect(self) -> None: """Stop polling/webhook, cancel pending album flushes, and disconnect.""" pending_media_group_tasks = list(self._media_group_tasks.values()) @@ -1411,6 +1414,48 @@ class TelegramAdapter(BasePlatformAdapter): logger.warning("[%s] send_exec_approval failed: %s", self.name, e) return SendResult(success=False, error=str(e)) + async def send_slash_confirm( + self, chat_id: str, title: str, message: str, session_key: str, + confirm_id: str, metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Render a three-button slash-command confirmation prompt.""" + if not self._bot: + return SendResult(success=False, error="Not connected") + + try: + # Message body: render as plain text (message already contains + # markdown formatting from the gateway primitive). + preview = message if len(message) <= 3800 else message[:3800] + "..." 
+ + keyboard = InlineKeyboardMarkup([ + [ + InlineKeyboardButton("✅ Approve Once", callback_data=f"sc:once:{confirm_id}"), + InlineKeyboardButton("🔒 Always Approve", callback_data=f"sc:always:{confirm_id}"), + ], + [ + InlineKeyboardButton("❌ Cancel", callback_data=f"sc:cancel:{confirm_id}"), + ], + ]) + + thread_id = self._metadata_thread_id(metadata) + kwargs: Dict[str, Any] = { + "chat_id": int(chat_id), + "text": preview, + "parse_mode": ParseMode.MARKDOWN, + "reply_markup": keyboard, + **self._link_preview_kwargs(), + } + message_thread_id = self._message_thread_id_for_send(thread_id) + if message_thread_id is not None: + kwargs["message_thread_id"] = message_thread_id + + msg = await self._bot.send_message(**kwargs) + self._slash_confirm_state[confirm_id] = session_key + return SendResult(success=True, message_id=str(msg.message_id)) + except Exception as e: + logger.warning("[%s] send_slash_confirm failed: %s", self.name, e) + return SendResult(success=False, error=str(e)) + async def send_model_picker( self, chat_id: str, @@ -1779,6 +1824,68 @@ class TelegramAdapter(BasePlatformAdapter): logger.error("Failed to resolve gateway approval from Telegram button: %s", exc) return + # --- Slash-confirm callbacks (sc:choice:confirm_id) --- + if data.startswith("sc:"): + parts = data.split(":", 2) + if len(parts) == 3: + choice = parts[1] # once, always, cancel + confirm_id = parts[2] + + caller_id = str(getattr(query.from_user, "id", "")) + if not self._is_callback_user_authorized(caller_id): + await query.answer(text="⛔ You are not authorized to answer this prompt.") + return + + session_key = self._slash_confirm_state.pop(confirm_id, None) + if not session_key: + await query.answer(text="This prompt has already been resolved.") + return + + label_map = { + "once": "✅ Approved once", + "always": "🔒 Always approve", + "cancel": "❌ Cancelled", + } + user_display = getattr(query.from_user, "first_name", "User") + label = label_map.get(choice, "Resolved") + + await 
query.answer(text=label) + + try: + await query.edit_message_text( + text=f"{label} by {user_display}", + parse_mode=ParseMode.MARKDOWN, + reply_markup=None, + ) + except Exception: + pass + + # Resolve via the module-level primitive. The runner stored + # a handler keyed by session_key; we run it on the event + # loop and (if it returns a string) send it as a follow-up + # message in the same chat. + try: + from tools import slash_confirm as _slash_confirm_mod + result_text = await _slash_confirm_mod.resolve( + session_key, confirm_id, choice, + ) + if result_text and query.message: + # Inherit the prompt message's thread so the reply + # lands in the same supergroup topic / reply chain. + thread_id = getattr(query.message, "message_thread_id", None) + send_kwargs: Dict[str, Any] = { + "chat_id": int(query.message.chat_id), + "text": result_text, + "parse_mode": ParseMode.MARKDOWN, + **self._link_preview_kwargs(), + } + if thread_id is not None: + send_kwargs["message_thread_id"] = thread_id + await self._bot.send_message(**send_kwargs) + except Exception as exc: + logger.error("[%s] slash-confirm callback failed: %s", self.name, exc, exc_info=True) + return + # --- Update prompt callbacks --- if not data.startswith("update_prompt:"): return @@ -1844,8 +1951,9 @@ class TelegramAdapter(BasePlatformAdapter): return SendResult(success=False, error=self._missing_media_path_error("Audio", audio_path)) with open(audio_path, "rb") as audio_file: - # .ogg files -> send as voice (round playable bubble) - if audio_path.endswith((".ogg", ".opus")): + ext = os.path.splitext(audio_path)[1].lower() + # .ogg / .opus files -> send as voice (round playable bubble) + if ext in (".ogg", ".opus"): _voice_thread = self._metadata_thread_id(metadata) msg = await self._bot.send_voice( chat_id=int(chat_id), @@ -1854,8 +1962,8 @@ class TelegramAdapter(BasePlatformAdapter): reply_to_message_id=int(reply_to) if reply_to else None, 
message_thread_id=self._message_thread_id_for_send(_voice_thread), ) - else: - # .mp3 and others -> send as audio file + elif ext in (".mp3", ".m4a"): + # Telegram's Bot API sendAudio only accepts MP3 / M4A. _audio_thread = self._metadata_thread_id(metadata) msg = await self._bot.send_audio( chat_id=int(chat_id), @@ -1864,6 +1972,16 @@ class TelegramAdapter(BasePlatformAdapter): reply_to_message_id=int(reply_to) if reply_to else None, message_thread_id=self._message_thread_id_for_send(_audio_thread), ) + else: + # Formats Telegram can't play natively (.wav, .flac, ...) + # — fall back to document delivery instead of raising. + return await self.send_document( + chat_id=chat_id, + file_path=audio_path, + caption=caption, + reply_to=reply_to, + metadata=metadata, + ) return SendResult(success=True, message_id=str(msg.message_id)) except Exception as e: logger.error( @@ -1873,7 +1991,7 @@ class TelegramAdapter(BasePlatformAdapter): exc_info=True, ) return await super().send_voice(chat_id, audio_path, caption, reply_to) - + async def send_image_file( self, chat_id: str, @@ -2040,7 +2158,7 @@ class TelegramAdapter(BasePlatformAdapter): ) # Final fallback: send URL as text return await super().send_image(chat_id, image_url, caption, reply_to) - + async def send_animation( self, chat_id: str, @@ -2102,7 +2220,7 @@ class TelegramAdapter(BasePlatformAdapter): e, exc_info=True, ) - + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: """Get information about a Telegram chat.""" if not self._bot: @@ -2136,7 +2254,7 @@ class TelegramAdapter(BasePlatformAdapter): exc_info=True, ) return {"name": str(chat_id), "type": "dm", "error": str(e)} - + def format_message(self, content: str) -> str: """ Convert standard markdown to Telegram MarkdownV2 format. 
@@ -2308,7 +2426,7 @@ class TelegramAdapter(BasePlatformAdapter): text = ''.join(_safe_parts) return text - + # ── Group mention gating ────────────────────────────────────────────── def _telegram_require_mention(self) -> bool: @@ -2523,7 +2641,7 @@ class TelegramAdapter(BasePlatformAdapter): event = self._build_message_event(update.message, MessageType.TEXT, update_id=update.update_id) event.text = self._clean_bot_trigger_text(event.text) self._enqueue_text_event(event) - + async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle incoming command messages.""" if not update.message or not update.message.text: @@ -2533,7 +2651,7 @@ class TelegramAdapter(BasePlatformAdapter): event = self._build_message_event(update.message, MessageType.COMMAND, update_id=update.update_id) await self.handle_message(event) - + async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle incoming location/venue pin messages.""" if not update.message: @@ -2891,7 +3009,7 @@ class TelegramAdapter(BasePlatformAdapter): return await self.handle_message(event) - + async def _queue_media_group_event(self, media_group_id: str, event: MessageEvent) -> None: """Buffer Telegram media-group items so albums arrive as one logical event. 
diff --git a/gateway/platforms/webhook.py b/gateway/platforms/webhook.py index e3a736a451..34e2dfa2c5 100644 --- a/gateway/platforms/webhook.py +++ b/gateway/platforms/webhook.py @@ -202,26 +202,22 @@ class WebhookAdapter(BasePlatformAdapter): if deliver_type == "github_comment": return await self._deliver_github_comment(content, delivery) - # Cross-platform delivery — any platform with a gateway adapter - if self.gateway_runner and deliver_type in ( - "telegram", - "discord", - "slack", - "signal", - "sms", - "whatsapp", - "matrix", - "mattermost", - "homeassistant", - "email", - "dingtalk", - "feishu", - "wecom", - "wecom_callback", - "weixin", - "bluebubbles", - "qqbot", - ): + # Cross-platform delivery — any platform with a gateway adapter. + # Check both built-in names and plugin-registered platforms. + _BUILTIN_DELIVER_PLATFORMS = { + "telegram", "discord", "slack", "signal", "sms", "whatsapp", + "matrix", "mattermost", "homeassistant", "email", "dingtalk", + "feishu", "wecom", "wecom_callback", "weixin", "bluebubbles", + "qqbot", "yuanbao", + } + _is_known_platform = deliver_type in _BUILTIN_DELIVER_PLATFORMS + if not _is_known_platform: + try: + from gateway.platform_registry import platform_registry + _is_known_platform = platform_registry.is_registered(deliver_type) + except Exception: + pass + if self.gateway_runner and _is_known_platform: return await self._deliver_cross_platform( deliver_type, content, delivery ) diff --git a/gateway/platforms/weixin.py b/gateway/platforms/weixin.py index 426d6e27ee..72b7d2a4df 100644 --- a/gateway/platforms/weixin.py +++ b/gateway/platforms/weixin.py @@ -92,6 +92,18 @@ SESSION_EXPIRED_ERRCODE = -14 RATE_LIMIT_ERRCODE = -2 # iLink frequency limit — backoff and retry MESSAGE_DEDUP_TTL_SECONDS = 300 + +def _is_stale_session_ret( + ret: "Optional[int]", errcode: "Optional[int]", errmsg: "Optional[str]", +) -> bool: + """True when iLink returns ret=-2 / errcode=-2 with 'unknown error', + which is a stale-session signal 
(same as errcode=-14) rather than + a genuine rate limit.""" + if ret != RATE_LIMIT_ERRCODE and errcode != RATE_LIMIT_ERRCODE: + return False + return (errmsg or "").lower() == "unknown error" + + MEDIA_IMAGE = 1 MEDIA_VIDEO = 2 MEDIA_FILE = 3 @@ -1210,6 +1222,17 @@ class WeixinAdapter(BasePlatformAdapter): self._mark_connected() _LIVE_ADAPTERS[self._token] = self logger.info("[%s] Connected account=%s base=%s", self.name, _safe_id(self._account_id), self._base_url) + if self._group_policy != "disabled": + logger.warning( + "[%s] WEIXIN_GROUP_POLICY=%s is set, but QR-login connects an iLink bot " + "identity (e.g. ...@im.bot) which typically cannot be invited into ordinary " + "WeChat groups. iLink usually does not deliver ordinary-group events for " + "these accounts, so group messages may never reach Hermes regardless of this " + "policy. If group delivery doesn't work, the limitation is on the iLink side, " + "not in Hermes.", + self.name, + self._group_policy, + ) return True async def disconnect(self) -> None: @@ -1254,7 +1277,8 @@ class WeixinAdapter(BasePlatformAdapter): ret = response.get("ret", 0) errcode = response.get("errcode", 0) if ret not in (0, None) or errcode not in (0, None): - if ret == SESSION_EXPIRED_ERRCODE or errcode == SESSION_EXPIRED_ERRCODE: + if (ret == SESSION_EXPIRED_ERRCODE or errcode == SESSION_EXPIRED_ERRCODE + or _is_stale_session_ret(ret, errcode, response.get("errmsg"))): logger.error("[%s] Session expired; pausing for 10 minutes", self.name) await asyncio.sleep(600) consecutive_failures = 0 @@ -1519,6 +1543,7 @@ class WeixinAdapter(BasePlatformAdapter): is_session_expired = ( ret == SESSION_EXPIRED_ERRCODE or errcode == SESSION_EXPIRED_ERRCODE + or _is_stale_session_ret(ret, errcode, resp.get("errmsg")) ) # Session expired — strip token and retry once if is_session_expired and not retried_without_token and context_token: @@ -1595,7 +1620,7 @@ class WeixinAdapter(BasePlatformAdapter): _, image_cleaned = 
self.extract_images(cleaned_content) local_files, final_content = self.extract_local_files(image_cleaned) - _AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"} + _AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a", ".flac"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".3gp"} _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"} diff --git a/gateway/run.py b/gateway/run.py index 7cf2d5901a..19dc5eae74 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -38,6 +38,7 @@ from typing import Dict, Optional, Any, List # gateway is a long-running daemon, so its boot cost matters less than # preserving the established test-patch surface. from agent.account_usage import fetch_account_usage, render_account_usage_lines +from hermes_cli.config import cfg_get # --- Agent cache tuning --------------------------------------------------- # Bounds the per-session AIAgent cache to prevent unbounded growth in @@ -46,6 +47,7 @@ from agent.account_usage import fetch_account_usage, render_account_usage_lines # from _enforce_agent_cache_cap() and _session_expiry_watcher() below. _AGENT_CACHE_MAX_SIZE = 128 _AGENT_CACHE_IDLE_TTL_SECS = 3600.0 # evict agents idle for >1h +_PLATFORM_CONNECT_TIMEOUT_SECS_DEFAULT = 30.0 # Only auto-continue interrupted gateway turns while the interruption is fresh. # Stale tool-tail/resume markers can otherwise revive an unrelated old task # after a gateway restart when the user's next message starts new work. 
@@ -265,6 +267,7 @@ if _config_path.exists(): "singularity_image": "TERMINAL_SINGULARITY_IMAGE", "modal_image": "TERMINAL_MODAL_IMAGE", "daytona_image": "TERMINAL_DAYTONA_IMAGE", + "vercel_runtime": "TERMINAL_VERCEL_RUNTIME", "ssh_host": "TERMINAL_SSH_HOST", "ssh_user": "TERMINAL_SSH_USER", "ssh_port": "TERMINAL_SSH_PORT", @@ -274,6 +277,8 @@ if _config_path.exists(): "container_disk": "TERMINAL_CONTAINER_DISK", "container_persistent": "TERMINAL_CONTAINER_PERSISTENT", "docker_volumes": "TERMINAL_DOCKER_VOLUMES", + "docker_mount_cwd_to_workspace": "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE", + "docker_run_as_host_user": "TERMINAL_DOCKER_RUN_AS_HOST_USER", "sandbox_dir": "TERMINAL_SANDBOX_DIR", "persistent_shell": "TERMINAL_PERSISTENT_SHELL", } @@ -416,6 +421,7 @@ if not _configured_cwd or _configured_cwd in (".", "auto", "cwd"): from gateway.config import ( Platform, + _BUILTIN_PLATFORM_VALUES, GatewayConfig, load_gateway_config, ) @@ -777,6 +783,13 @@ def _format_gateway_process_notification(evt: dict) -> "str | None": return None +# Module-level weak reference to the active GatewayRunner instance. +# Used by tools (e.g. send_message) that need to route through a live +# adapter for plugin platforms. Set in GatewayRunner.__init__(). +import weakref as _weakref +_gateway_runner_ref: _weakref.ref = lambda: None + + class GatewayRunner: """ Main gateway controller. @@ -799,11 +812,13 @@ class GatewayRunner: _stop_task: Optional[asyncio.Task] = None _session_model_overrides: Dict[str, Dict[str, str]] = {} _session_reasoning_overrides: Dict[str, Dict[str, Any]] = {} - + def __init__(self, config: Optional[GatewayConfig] = None): + global _gateway_runner_ref self.config = config or load_gateway_config() self.adapters: Dict[Platform, BasePlatformAdapter] = {} self._warn_if_docker_media_delivery_is_risky() + _gateway_runner_ref = _weakref.ref(self) # Load ephemeral config from config.yaml / env vars. # Both are injected at API-call time only and never persisted. 
@@ -887,6 +902,14 @@ class GatewayRunner: # Key: session_key, Value: True when a prompt is waiting for user input. self._update_prompt_pending: Dict[str, bool] = {} + # Slash-confirm state lives in tools.slash_confirm (module-level), + # so platform adapters can resolve callbacks without a backref to + # this runner. Keep a local counter for confirm_id generation so + # IDs stay compact (button callback_data has a 64-byte cap on + # some platforms). + import itertools as _itertools + self._slash_confirm_counter = _itertools.count(1) + # Persistent Honcho managers keyed by gateway session key. # This preserves write_frequency="session" semantics across short-lived # per-message AIAgent instances. @@ -1157,6 +1180,33 @@ class GatewayRunner: e, ) + def _platform_connect_timeout_secs(self) -> float: + """Return the per-platform connect timeout used during startup/retry.""" + raw = os.getenv("HERMES_GATEWAY_PLATFORM_CONNECT_TIMEOUT", "").strip() + if raw: + try: + timeout = float(raw) + except ValueError: + logger.warning( + "Ignoring invalid HERMES_GATEWAY_PLATFORM_CONNECT_TIMEOUT=%r", + raw, + ) + else: + return max(0.0, timeout) + return _PLATFORM_CONNECT_TIMEOUT_SECS_DEFAULT + + async def _connect_adapter_with_timeout(self, adapter, platform) -> bool: + """Connect an adapter without allowing one platform to block others.""" + timeout = self._platform_connect_timeout_secs() + if timeout <= 0: + return await adapter.connect() + try: + return await asyncio.wait_for(adapter.connect(), timeout=timeout) + except asyncio.TimeoutError as exc: + raise TimeoutError( + f"{platform.value} connect timed out after {timeout:g}s" + ) from exc + @property def should_exit_cleanly(self) -> bool: return self._exit_cleanly @@ -1494,7 +1544,7 @@ class GatewayRunner: ) except Exception: pass - + @staticmethod def _load_prefill_messages() -> List[Dict[str, Any]]: """Load ephemeral prefill messages from config or env var. 
@@ -1549,7 +1599,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - return (cfg.get("agent", {}).get("system_prompt", "") or "").strip() + return (cfg_get(cfg, "agent", "system_prompt", default="") or "").strip() except Exception: pass return "" @@ -1570,7 +1620,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - effort = str(cfg.get("agent", {}).get("reasoning_effort", "") or "").strip() + effort = str(cfg_get(cfg, "agent", "reasoning_effort", default="") or "").strip() except Exception: pass result = parse_reasoning_effort(effort) @@ -1653,7 +1703,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - raw = str(cfg.get("agent", {}).get("service_tier", "") or "").strip() + raw = str(cfg_get(cfg, "agent", "service_tier", default="") or "").strip() except Exception: pass @@ -1674,7 +1724,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - return bool(cfg.get("display", {}).get("show_reasoning", False)) + return bool(cfg_get(cfg, "display", "show_reasoning", default=False)) except Exception: pass return False @@ -1690,7 +1740,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - mode = str(cfg.get("display", {}).get("busy_input_mode", "") or "").strip().lower() + mode = str(cfg_get(cfg, "display", "busy_input_mode", default="") or "").strip().lower() except Exception: pass if mode == "queue": @@ -1710,7 +1760,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - raw = str(cfg.get("agent", {}).get("restart_drain_timeout", "") or "").strip() + raw = str(cfg_get(cfg, "agent", "restart_drain_timeout", default="") or "").strip() except Exception: pass 
value = parse_restart_drain_timeout(raw) @@ -1743,7 +1793,7 @@ class GatewayRunner: if cfg_path.exists(): with open(cfg_path, encoding="utf-8") as _f: cfg = _y.safe_load(_f) or {} - raw = cfg.get("display", {}).get("background_process_notifications") + raw = cfg_get(cfg, "display", "background_process_notifications") if raw is False: mode = "off" elif raw not in (None, ""): @@ -2308,37 +2358,61 @@ class GatewayRunner: pass # Warn if no user allowlists are configured and open access is not opted in + _builtin_allowed_vars = ( + "TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS", + "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", + "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS", + "TELEGRAM_GROUP_ALLOWED_USERS", + "TELEGRAM_GROUP_ALLOWED_CHATS", + "EMAIL_ALLOWED_USERS", + "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS", + "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", + "FEISHU_ALLOWED_USERS", + "WECOM_ALLOWED_USERS", + "WECOM_CALLBACK_ALLOWED_USERS", + "WEIXIN_ALLOWED_USERS", + "BLUEBUBBLES_ALLOWED_USERS", + "QQ_ALLOWED_USERS", + "YUANBAO_ALLOWED_USERS", + "GATEWAY_ALLOWED_USERS", + ) + _builtin_allow_all_vars = ( + "TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS", + "WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS", + "SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS", + "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", + "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", + "FEISHU_ALLOW_ALL_USERS", + "WECOM_ALLOW_ALL_USERS", + "WECOM_CALLBACK_ALLOW_ALL_USERS", + "WEIXIN_ALLOW_ALL_USERS", + "BLUEBUBBLES_ALLOW_ALL_USERS", + "QQ_ALLOW_ALL_USERS", + "YUANBAO_ALLOW_ALL_USERS", + ) + # Also pick up plugin-registered platforms — each entry can declare + # its own allowed_users_env / allow_all_env, so the warning stays + # accurate as plugins like IRC come online. 
+ _plugin_allowed_vars: tuple = () + _plugin_allow_all_vars: tuple = () + try: + from gateway.platform_registry import platform_registry + _plugin_allowed_vars = tuple( + e.allowed_users_env for e in platform_registry.plugin_entries() + if e.allowed_users_env + ) + _plugin_allow_all_vars = tuple( + e.allow_all_env for e in platform_registry.plugin_entries() + if e.allow_all_env + ) + except Exception: + pass _any_allowlist = any( - os.getenv(v) - for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS", - "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", - "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS", - "EMAIL_ALLOWED_USERS", - "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS", - "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", - "FEISHU_ALLOWED_USERS", - "WECOM_ALLOWED_USERS", - "WECOM_CALLBACK_ALLOWED_USERS", - "WEIXIN_ALLOWED_USERS", - "BLUEBUBBLES_ALLOWED_USERS", - "QQ_ALLOWED_USERS", - "YUANBAO_ALLOWED_USERS", - "GATEWAY_ALLOWED_USERS") + os.getenv(v) for v in _builtin_allowed_vars + _plugin_allowed_vars ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any( os.getenv(v, "").lower() in ("true", "1", "yes") - for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS", - "WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS", - "SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS", - "SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS", - "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", - "FEISHU_ALLOW_ALL_USERS", - "WECOM_ALLOW_ALL_USERS", - "WECOM_CALLBACK_ALLOW_ALL_USERS", - "WEIXIN_ALLOW_ALL_USERS", - "BLUEBUBBLES_ALLOW_ALL_USERS", - "QQ_ALLOW_ALL_USERS", - "YUANBAO_ALLOW_ALL_USERS") + for v in _builtin_allow_all_vars + _plugin_allow_all_vars ) if not _any_allowlist and not _allow_all: logger.warning( @@ -2441,7 +2515,17 @@ class GatewayRunner: adapter = self._create_adapter(platform, platform_config) if not adapter: - logger.warning("No adapter available for %s", platform.value) + # Distinguish between 
missing builtin deps and missing plugin + _pval = platform.value + _builtin_names = {m.value for m in Platform.__members__.values()} + if _pval not in _builtin_names: + logger.warning( + "No adapter for '%s' — is the plugin installed? " + "(platform is enabled in config.yaml but no plugin registered it)", + _pval, + ) + else: + logger.warning("No adapter available for %s", _pval) continue # Set up message + fatal error handlers @@ -2459,7 +2543,7 @@ class GatewayRunner: error_message=None, ) try: - success = await adapter.connect() + success = await self._connect_adapter_with_timeout(adapter, platform) if success: self.adapters[platform] = adapter self._sync_voice_mode_state_to_adapter(adapter) @@ -2629,7 +2713,7 @@ class GatewayRunner: logger.info("Press Ctrl+C to stop") return True - + async def _session_expiry_watcher(self, interval: int = 300): """Background task that finalizes expired sessions. @@ -2850,7 +2934,7 @@ class GatewayRunner: adapter.set_session_store(self.session_store) adapter.set_busy_session_handler(self._handle_active_session_busy_message) - success = await adapter.connect() + success = await self._connect_adapter_with_timeout(adapter, platform) if success: self.adapters[platform] = adapter self._sync_voice_mode_state_to_adapter(adapter) @@ -3170,17 +3254,21 @@ class GatewayRunner: self._stop_task = asyncio.create_task(_stop_impl()) await self._stop_task - + async def wait_for_shutdown(self) -> None: """Wait for shutdown signal.""" await self._shutdown_event.wait() - + def _create_adapter( self, platform: Platform, config: Any ) -> Optional[BasePlatformAdapter]: - """Create the appropriate adapter for a platform.""" + """Create the appropriate adapter for a platform. + + Checks the platform_registry first (plugin adapters), then falls + through to the built-in if/elif chain for core platforms. 
+ """ if hasattr(config, "extra") and isinstance(config.extra, dict): config.extra.setdefault( "group_sessions_per_user", @@ -3191,6 +3279,25 @@ class GatewayRunner: getattr(self.config, "thread_sessions_per_user", False), ) + # ── Plugin-registered platforms (checked first) ─────────────────── + try: + from gateway.platform_registry import platform_registry + if platform_registry.is_registered(platform.value): + adapter = platform_registry.create_adapter(platform.value, config) + if adapter is not None: + return adapter + # Registered but failed to instantiate — don't silently fall + # through to built-ins (there are none for plugin platforms). + logger.error( + "Platform '%s' is registered but adapter creation failed " + "(check dependencies and config)", + platform.value, + ) + return None + except Exception as e: + logger.debug("Platform registry lookup for '%s' failed: %s", platform.value, e) + # Fall through to built-in adapters below + if platform == Platform.TELEGRAM: from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements if not check_telegram_requirements(): @@ -3379,8 +3486,11 @@ class GatewayRunner: Platform.QQBOT: "QQ_ALLOWED_USERS", Platform.YUANBAO: "YUANBAO_ALLOWED_USERS", } - platform_group_env_map = { + platform_group_user_env_map = { Platform.TELEGRAM: "TELEGRAM_GROUP_ALLOWED_USERS", + } + platform_group_chat_env_map = { + Platform.TELEGRAM: "TELEGRAM_GROUP_ALLOWED_CHATS", Platform.QQBOT: "QQ_GROUP_ALLOWED_USERS", } platform_allow_all_map = { @@ -3403,6 +3513,19 @@ class GatewayRunner: Platform.YUANBAO: "YUANBAO_ALLOW_ALL_USERS", } + # Plugin platforms: check the registry for auth env var names + if source.platform not in platform_env_map: + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(source.platform.value) + if entry: + if entry.allowed_users_env: + platform_env_map[source.platform] = entry.allowed_users_env + if entry.allow_all_env: + 
platform_allow_all_map[source.platform] = entry.allow_all_env + except Exception: + pass + # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) platform_allow_all_var = platform_allow_all_map.get(source.platform, "") if platform_allow_all_var and os.getenv(platform_allow_all_var, "").lower() in ("true", "1", "yes"): @@ -3437,27 +3560,66 @@ class GatewayRunner: # Check platform-specific and global allowlists platform_allowlist = os.getenv(platform_env_map.get(source.platform, ""), "").strip() - group_allowlist = "" + group_user_allowlist = "" + group_chat_allowlist = "" if source.chat_type in {"group", "forum"}: - group_allowlist = os.getenv(platform_group_env_map.get(source.platform, ""), "").strip() + group_user_allowlist = os.getenv(platform_group_user_env_map.get(source.platform, ""), "").strip() + group_chat_allowlist = os.getenv(platform_group_chat_env_map.get(source.platform, ""), "").strip() global_allowlist = os.getenv("GATEWAY_ALLOWED_USERS", "").strip() - if not platform_allowlist and not group_allowlist and not global_allowlist: + if not platform_allowlist and not group_user_allowlist and not group_chat_allowlist and not global_allowlist: # No allowlists configured -- check global allow-all flag return os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") - # Some platforms authorize group traffic by chat ID rather than sender ID. - if group_allowlist and source.chat_type in {"group", "forum"} and source.chat_id: + # Telegram can optionally authorize group traffic by chat ID. + # Keep this separate from TELEGRAM_GROUP_ALLOWED_USERS, which gates + # the sender user ID for group/forum messages. 
+ if group_chat_allowlist and source.chat_type in {"group", "forum"} and source.chat_id: allowed_group_ids = { - chat_id.strip() for chat_id in group_allowlist.split(",") if chat_id.strip() + chat_id.strip() for chat_id in group_chat_allowlist.split(",") if chat_id.strip() } if "*" in allowed_group_ids or source.chat_id in allowed_group_ids: return True - # Check if user is in any allowlist + # Backward-compat shim for #15027: prior to PR #17686, + # TELEGRAM_GROUP_ALLOWED_USERS was (mis)used as a chat-ID allowlist. + # Values starting with "-" are Telegram chat IDs, not user IDs, so if + # users still have those in TELEGRAM_GROUP_ALLOWED_USERS we honor them + # as chat IDs and warn once. The correct var is now + # TELEGRAM_GROUP_ALLOWED_CHATS. + if ( + source.platform == Platform.TELEGRAM + and group_user_allowlist + and source.chat_type in {"group", "forum"} + and source.chat_id + ): + legacy_chat_ids = { + v.strip() + for v in group_user_allowlist.split(",") + if v.strip().startswith("-") + } + if legacy_chat_ids: + if not getattr(self, "_warned_telegram_group_users_legacy", False): + logger.warning( + "TELEGRAM_GROUP_ALLOWED_USERS contains chat-ID-shaped values " + "(%s). Treating them as chat IDs for backward compatibility. " + "Move chat IDs to TELEGRAM_GROUP_ALLOWED_CHATS — the _USERS var " + "is now for sender user IDs.", + ",".join(sorted(legacy_chat_ids)), + ) + self._warned_telegram_group_users_legacy = True + if source.chat_id in legacy_chat_ids: + return True + + # Check if user is in any allowlist. In group/forum chats, + # TELEGRAM_GROUP_ALLOWED_USERS is the scoped allowlist and should not + # imply DM access; TELEGRAM_ALLOWED_USERS remains the platform-wide + # allowlist and still works everywhere for backward compatibility. 
allowed_ids = set() if platform_allowlist: allowed_ids.update(uid.strip() for uid in platform_allowlist.split(",") if uid.strip()) + if group_user_allowlist: + allowed_ids.update(uid.strip() for uid in group_user_allowlist.split(",") if uid.strip()) if global_allowlist: allowed_ids.update(uid.strip() for uid in global_allowlist.split(",") if uid.strip()) @@ -3491,10 +3653,12 @@ class GatewayRunner: Resolution order: 1. Explicit per-platform ``unauthorized_dm_behavior`` in config — always wins. 2. Explicit global ``unauthorized_dm_behavior`` in config — wins when no per-platform. - 3. When an allowlist (``PLATFORM_ALLOWED_USERS`` or ``GATEWAY_ALLOWED_USERS``) is - configured, default to ``"ignore"`` — the allowlist signals that the owner has - deliberately restricted access; spamming unknown contacts with pairing codes - is both noisy and a potential info-leak. (#9337) + 3. When an allowlist (``PLATFORM_ALLOWED_USERS``, + ``PLATFORM_GROUP_ALLOWED_USERS`` / ``PLATFORM_GROUP_ALLOWED_CHATS``, + or ``GATEWAY_ALLOWED_USERS``) is configured, default to ``"ignore"`` — + the allowlist signals that the owner has deliberately restricted + access; spamming unknown contacts with pairing codes is both noisy + and a potential info-leak. (#9337) 4. No allowlist and no explicit config → ``"pair"`` (open-gateway default). 
""" config = getattr(self, "config", None) @@ -3533,14 +3697,24 @@ class GatewayRunner: Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS", Platform.QQBOT: "QQ_ALLOWED_USERS", } + platform_group_env_map = { + Platform.TELEGRAM: ( + "TELEGRAM_GROUP_ALLOWED_USERS", + "TELEGRAM_GROUP_ALLOWED_CHATS", + ), + Platform.QQBOT: ("QQ_GROUP_ALLOWED_USERS",), + } if os.getenv(platform_env_map.get(platform, ""), "").strip(): return "ignore" + for env_key in platform_group_env_map.get(platform, ()): + if os.getenv(env_key, "").strip(): + return "ignore" if os.getenv("GATEWAY_ALLOWED_USERS", "").strip(): return "ignore" return "pair" - + async def _handle_message(self, event: MessageEvent) -> Optional[str]: """ Handle an incoming message from any platform. @@ -3717,6 +3891,50 @@ class GatewayRunner: ) _update_prompts.pop(_quick_key, None) + # Intercept messages that are responses to a pending /reload-mcp + # (or future) slash-confirm prompt. Recognized confirm replies are + # /approve, /always, /cancel (plus short aliases). Anything else + # falls through to normal dispatch — a stale pending confirm does + # NOT block other commands. + # + # Important: if a dangerous-command approval is ALSO pending (agent + # blocked inside tools/approval.py), the tool approval takes + # precedence — /approve there unblocks the waiting tool thread. + # Slash-confirm only catches /approve when no tool approval is live. 
+ from tools import slash_confirm as _slash_confirm_mod + _pending_confirm = _slash_confirm_mod.get_pending(_quick_key) + _tool_approval_live = False + try: + from tools.approval import has_blocking_approval + _tool_approval_live = has_blocking_approval(_quick_key) + except Exception: + _tool_approval_live = False + if _pending_confirm and not _tool_approval_live: + _raw_reply = (event.text or "").strip() + _cmd_reply = event.get_command() + _confirm_choice = None + if _cmd_reply in ("approve", "yes", "ok", "confirm"): + _confirm_choice = "once" + elif _cmd_reply in ("always", "remember"): + _confirm_choice = "always" + elif _cmd_reply in ("cancel", "no", "deny", "nevermind"): + _confirm_choice = "cancel" + elif _raw_reply.lower() in ("approve", "approve once", "once"): + _confirm_choice = "once" + elif _raw_reply.lower() in ("always", "always approve"): + _confirm_choice = "always" + elif _raw_reply.lower() in ("cancel", "nevermind", "no"): + _confirm_choice = "cancel" + if _confirm_choice is not None: + _resolved = await _slash_confirm_mod.resolve( + _quick_key, _pending_confirm.get("confirm_id"), _confirm_choice, + ) + return _resolved or "" + # Stale pending + unrelated command: drop the pending state so + # the confirm doesn't block normal usage indefinitely. The user + # clearly moved on. + _slash_confirm_mod.clear_if_stale(_quick_key) + # PRIORITY handling when an agent is already running for this session. # Default behavior is to interrupt immediately so user text/stop messages # are handled with minimal latency. 
@@ -4187,6 +4405,9 @@ class GatewayRunner: if canonical == "reload-mcp": return await self._handle_reload_mcp_command(event) + if canonical == "reload-skills": + return await self._handle_reload_skills_command(event) + if canonical == "approve": return await self._handle_approve_command(event) @@ -5516,7 +5737,7 @@ class GatewayRunner: finally: # Restore session context variables to their pre-handler state self._clear_session_env(_session_env_tokens) - + def _format_session_info(self) -> str: """Resolve current model config and return a formatted info block. @@ -5713,7 +5934,7 @@ class GatewayRunner: if session_info: return f"{header}\n\n{session_info}{_tip_line}" return f"{header}{_tip_line}" - + async def _handle_profile_command(self, event: MessageEvent) -> str: """Handle /profile — show active profile name and home directory.""" from hermes_constants import display_hermes_home @@ -5862,7 +6083,7 @@ class GatewayRunner: lines.append("No active agents or running tasks.") return "\n".join(lines) - + async def _handle_stop_command(self, event: MessageEvent) -> str: """Handle /stop command - interrupt a running agent. @@ -6102,7 +6323,7 @@ class GatewayRunner: if page != requested_page: lines.append(f"_(Requested page {requested_page} was out of range, showing page {page}.)_") return "\n".join(lines) - + async def _handle_model_command(self, event: MessageEvent) -> Optional[str]: """Handle /model command — switch model for this session. 
@@ -6449,7 +6670,7 @@ class GatewayRunner: try: config = _load_gateway_config() - personalities = config.get("agent", {}).get("personalities", {}) if config else {} + personalities = cfg_get(config, "agent", "personalities", default={}) except Exception: config = {} personalities = {} @@ -6508,7 +6729,7 @@ class GatewayRunner: available = "`none`, " + ", ".join(f"`{n}`" for n in personalities) return f"Unknown personality: `{args}`\n\nAvailable: {available}" - + async def _handle_retry_command(self, event: MessageEvent) -> str: """Handle /retry command - re-send the last user message.""" source = event.source @@ -6544,7 +6765,7 @@ class GatewayRunner: # Let the normal message handler process it return await self._handle_message(retry_event) - + async def _handle_undo_command(self, event: MessageEvent) -> str: """Handle /undo command - remove the last user/assistant exchange.""" source = event.source @@ -6569,7 +6790,7 @@ class GatewayRunner: preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" - + async def _handle_set_home_command(self, event: MessageEvent) -> str: """Handle /sethome command -- set the current chat as the platform's home channel.""" source = event.source @@ -6590,7 +6811,7 @@ class GatewayRunner: f"✅ Home channel set to **{chat_name}** (ID: {chat_id}).\n" f"Cron jobs and cross-platform messages will be delivered here." 
) - + @staticmethod def _get_guild_id(event: MessageEvent) -> Optional[int]: """Extract Discord guild_id from the raw message object.""" @@ -6958,14 +7179,15 @@ class GatewayRunner: _thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None - _AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'} + from gateway.platforms.base import should_send_media_as_audio + _VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'} _IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'} for media_path, is_voice in media_files: try: ext = Path(media_path).suffix.lower() - if ext in _AUDIO_EXTS: + if should_send_media_as_audio(event.source.platform, ext, is_voice=is_voice): await adapter.send_voice( chat_id=event.source.chat_id, audio_path=media_path, @@ -7450,7 +7672,7 @@ class GatewayRunner: # --- check config gate ------------------------------------------------ try: user_config = _load_gateway_config() - gate_enabled = user_config.get("display", {}).get("tool_progress_command", False) + gate_enabled = cfg_get(user_config, "display", "tool_progress_command", default=False) except Exception: gate_enabled = False @@ -7814,6 +8036,13 @@ class GatewayRunner: return "Failed to switch session." self._clear_session_boundary_security_state(session_key) + # Evict any cached agent for this session so the next message + # rebuilds with the correct session_id end-to-end — mirrors + # /branch and /reset. Without this, the cached AIAgent (and its + # memory provider, which cached `_session_id` during initialize()) + # keeps writing into the wrong session's record. See #6672. 
+ self._evict_cached_agent(session_key) + # Get the title for confirmation title = self._session_db.get_session_title(target_id) or name @@ -8102,8 +8331,91 @@ class GatewayRunner: logger.error("Insights command error: %s", e, exc_info=True) return f"Error generating insights: {e}" - async def _handle_reload_mcp_command(self, event: MessageEvent) -> str: - """Handle /reload-mcp command -- disconnect and reconnect all MCP servers.""" + async def _handle_reload_mcp_command(self, event: MessageEvent) -> Optional[str]: + """Handle /reload-mcp — reconnect MCP servers and rebuild the cached agent. + + Reloading MCP tools invalidates the provider prompt cache for the + active session (tool schemas are baked into the system prompt). The + next message re-sends full input tokens, which is expensive on + long-context or high-reasoning models. + + To surface that cost, the command routes through the slash-confirm + primitive: users get an Approve Once / Always Approve / Cancel + prompt before the reload actually runs. "Always Approve" persists + ``approvals.mcp_reload_confirm: false`` so the prompt is silenced + for subsequent reloads in any session. + + Users can also skip the confirm by flipping the config key directly. + """ + source = event.source + session_key = self._session_key_for_source(source) + + # Read the gate fresh from disk so a prior "always" click takes + # effect on the next invocation without restarting the gateway. + user_config = self._read_user_config() + approvals = user_config.get("approvals") if isinstance(user_config, dict) else None + confirm_required = True + if isinstance(approvals, dict): + confirm_required = bool(approvals.get("mcp_reload_confirm", True)) + + if not confirm_required: + return await self._execute_mcp_reload(event) + + # Route through slash-confirm. 
The primitive sends the prompt and + # stores the resume handler; the button/text response triggers + # ``_resolve_slash_confirm`` which invokes the handler with the + # chosen outcome. + async def _on_confirm(choice: str) -> Optional[str]: + if choice == "cancel": + return "🟡 /reload-mcp cancelled. MCP tools unchanged." + if choice == "always": + # Persist the opt-out and run the reload. + try: + from cli import save_config_value + save_config_value("approvals.mcp_reload_confirm", False) + logger.info( + "User opted out of /reload-mcp confirmation (session=%s)", + session_key, + ) + except Exception as exc: + logger.warning("Failed to persist mcp_reload_confirm=false: %s", exc) + # once / always → run the reload + result = await self._execute_mcp_reload(event) + if choice == "always": + return ( + f"{result}\n\n" + "ℹ️ Future `/reload-mcp` calls will run without confirmation. " + "Re-enable via `approvals.mcp_reload_confirm: true` in config.yaml." + ) + return result + + prompt_message = ( + "⚠️ **Confirm /reload-mcp**\n\n" + "Reloading MCP servers rebuilds the tool set for this session " + "and **invalidates the provider prompt cache** — the next " + "message will re-send full input tokens. On long-context or " + "high-reasoning models this can be expensive.\n\n" + "Choose:\n" + "• **Approve Once** — reload now\n" + "• **Always Approve** — reload now and silence this prompt permanently\n" + "• **Cancel** — leave MCP tools unchanged\n\n" + "_Text fallback: reply `/approve`, `/always`, or `/cancel`._" + ) + return await self._request_slash_confirm( + event=event, + command="reload-mcp", + title="/reload-mcp", + message=prompt_message, + handler=_on_confirm, + ) + + async def _execute_mcp_reload(self, event: MessageEvent) -> str: + """Actually disconnect, reconnect, and notify MCP tool changes. 
+ + Split out from ``_handle_reload_mcp_command`` so the confirmation + wrapper can invoke the same path whether the user confirmed via + button, text reply, or has the confirm gate disabled. + """ loop = asyncio.get_running_loop() try: from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _servers, _lock @@ -8169,6 +8481,178 @@ class GatewayRunner: logger.warning("MCP reload failed: %s", e) return f"❌ MCP reload failed: {e}" + async def _handle_reload_skills_command(self, event: MessageEvent) -> str: + """Handle /reload-skills — rescan skills dir, queue a note for next turn. + + Skills don't need to be in the system prompt for the model to use + them (they're invoked via ``/skill-name``, ``skills_list``, or + ``skill_view`` at runtime), so this does NOT clear the prompt cache + — prefix caching stays intact. + + If any skills were added or removed, a one-shot note is queued on + ``self._pending_skills_reload_notes[session_key]``. The gateway + prepends it to the NEXT user message in this session (see the + consumer at ~L11025 in ``_run_agent_turn``), then clears it. Nothing + is written to the session transcript out-of-band, so message + alternation is preserved. + """ + loop = asyncio.get_running_loop() + try: + from agent.skill_commands import reload_skills + + result = await loop.run_in_executor(None, reload_skills) + added = result.get("added", []) # [{"name", "description"}, ...] + removed = result.get("removed", []) # [{"name", "description"}, ...] 
+ total = result.get("total", 0) + + lines = ["🔄 **Skills Reloaded**\n"] + if not added and not removed: + lines.append("No new skills detected.") + lines.append(f"\n📚 {total} skill(s) available") + return "\n".join(lines) + + def _fmt_line(item: dict) -> str: + nm = item.get("name", "") + desc = item.get("description", "") + return f" - {nm}: {desc}" if desc else f" - {nm}" + + if added: + lines.append("➕ **Added Skills:**") + for item in added: + lines.append(_fmt_line(item)) + if removed: + lines.append("➖ **Removed Skills:**") + for item in removed: + lines.append(_fmt_line(item)) + lines.append(f"\n📚 {total} skill(s) available") + + # Queue the one-shot note for the next user turn in this session. + # Format matches how the system prompt renders pre-existing + # skills (`` - name: description``) so the model reads the + # diff in the same shape as its original skill catalog. + sections = ["[USER INITIATED SKILLS RELOAD:"] + if added: + sections.append("") + sections.append("Added Skills:") + for item in added: + sections.append(_fmt_line(item)) + if removed: + sections.append("") + sections.append("Removed Skills:") + for item in removed: + sections.append(_fmt_line(item)) + sections.append("") + sections.append("Use skills_list to see the updated catalog.]") + note = "\n".join(sections) + + session_key = self._session_key_for_source(event.source) + if not hasattr(self, "_pending_skills_reload_notes"): + self._pending_skills_reload_notes = {} + if session_key: + self._pending_skills_reload_notes[session_key] = note + + return "\n".join(lines) + + except Exception as e: + logger.warning("Skills reload failed: %s", e) + return f"❌ Skills reload failed: {e}" + + # ------------------------------------------------------------------ + # Slash-command confirmation primitive (generic) + # ------------------------------------------------------------------ + # Used by slash commands that have a non-destructive but expensive + # side effect worth an explicit user 
confirmation (currently only + # /reload-mcp, which invalidates the prompt cache). Two delivery + # paths: + # 1. Button UI — adapters that override ``send_slash_confirm`` + # (Telegram, Discord, Slack, Matrix, Feishu) render three + # inline buttons. The adapter routes the button click back via + # ``tools.slash_confirm.resolve(session_key, confirm_id, choice)``. + # 2. Text fallback — adapters that don't override the hook get a + # plain text prompt. Users reply with /approve, /always, or + # /cancel; the early intercept in ``_handle_message`` matches + # those replies against ``tools.slash_confirm.get_pending()``. + + async def _request_slash_confirm( + self, + *, + event: MessageEvent, + command: str, + title: str, + message: str, + handler, + ) -> Optional[str]: + """Ask the user to confirm an expensive slash command. + + ``handler`` is an async callable ``handler(choice: str) -> str`` + where ``choice`` is ``"once"``, ``"always"``, or ``"cancel"``. + The handler runs on the event loop when the user responds; its + return value is sent back as a gateway message. + + Returns a short acknowledgment string to send immediately (before + the user's response). If buttons rendered successfully the ack + is ``None`` (buttons are self-explanatory); if we fell back to + text the message itself IS the ack. + """ + from tools import slash_confirm as _slash_confirm_mod + + source = event.source + session_key = self._session_key_for_source(source) + confirm_id = f"{next(self._slash_confirm_counter)}" + + # Register the pending confirm FIRST so a super-fast button click + # cannot race the send_slash_confirm return. 
+ _slash_confirm_mod.register(session_key, confirm_id, command, handler) + + adapter = self.adapters.get(source.platform) + metadata = self._thread_metadata_for_source(source) + + used_buttons = False + if adapter is not None: + try: + button_result = await adapter.send_slash_confirm( + chat_id=source.chat_id, + title=title, + message=message, + session_key=session_key, + confirm_id=confirm_id, + metadata=metadata, + ) + if button_result and getattr(button_result, "success", False): + used_buttons = True + except Exception as exc: + logger.debug( + "send_slash_confirm failed for %s on %s: %s", + command, source.platform, exc, + ) + + if used_buttons: + # Buttons rendered — no redundant text ack. + return None + # Text fallback — return the prompt message as the direct reply. + return message + + def _read_user_config(self) -> Dict[str, Any]: + """Read the user's raw config.yaml (cached) for gate lookups. + + Used by slash-confirm gates that must reflect on-disk state changes + (e.g. a prior "Always Approve" click) without a gateway restart. + """ + try: + from hermes_cli.config import load_config + cfg = load_config() + return cfg if isinstance(cfg, dict) else {} + except Exception: + return {} + + def _thread_metadata_for_source(self, source) -> Optional[Dict[str, Any]]: + """Build the metadata dict platforms need for thread-aware replies.""" + thread_id = getattr(source, "thread_id", None) + if thread_id is None: + return None + return {"thread_id": thread_id} + + # ------------------------------------------------------------------ # /approve & /deny — explicit dangerous-command approval # ------------------------------------------------------------------ @@ -8342,8 +8826,16 @@ class GatewayRunner: # Block non-messaging platforms (API server, webhooks, ACP) platform = event.source.platform - if platform not in self._UPDATE_ALLOWED_PLATFORMS: - return "✗ /update is only available from messaging platforms. Run `hermes update` from the terminal." 
+ _allowed = self._UPDATE_ALLOWED_PLATFORMS + # Plugin platforms with allow_update_command=True are also allowed + if platform not in _allowed: + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform.value) + if not entry or not entry.allow_update_command: + return "✗ /update is only available from messaging platforms. Run `hermes update` from the terminal." + except Exception: + return "✗ /update is only available from messaging platforms. Run `hermes update` from the terminal." if is_managed(): return f"✗ {format_managed_message('update Hermes Agent')}" @@ -9003,6 +9495,16 @@ class GatewayRunner: try: platform = Platform(platform_name) + # Reject arbitrary strings that create dynamic pseudo-members. + # Built-in platforms are always valid; plugin platforms must be + # registered in the platform registry. + if platform.value not in _BUILTIN_PLATFORM_VALUES: + try: + from gateway.platform_registry import platform_registry + if not platform_registry.is_registered(platform.value): + raise ValueError(platform_name) + except Exception: + raise ValueError(platform_name) except Exception: logger.warning( "Synthetic process event has invalid platform metadata: %r", @@ -9230,11 +9732,17 @@ class GatewayRunner: @classmethod def _extract_cache_busting_config(cls, user_config: dict | None) -> dict: - """Pull the subset of config values that must bust the agent cache. + """Pull values that must bust the cached agent. - Returns a flat dict keyed by 'section.key'. Missing keys and - non-dict sections yield None values, which still contribute to - the signature (so 'absent' vs 'present-and-null' differ). + Returns a flat dict keyed by 'section.key'. Missing config keys and + non-dict sections yield None values, which still contribute to the + signature (so 'absent' vs 'present-and-null' differ). + + The live tool registry generation is included too. 
MCP reloads and + dynamic MCP tool-list changes mutate the registry without necessarily + changing config.yaml. Cached AIAgent instances freeze their tool + schemas at construction time, so a registry generation change must + rebuild the agent before the next turn. """ out: Dict[str, Any] = {} cfg = user_config if isinstance(user_config, dict) else {} @@ -9244,6 +9752,12 @@ class GatewayRunner: out[f"{section}.{key}"] = section_val.get(key) else: out[f"{section}.{key}"] = None + try: + from tools.registry import registry + + out["tools.registry_generation"] = getattr(registry, "_generation", None) + except Exception: + out["tools.registry_generation"] = None return out @staticmethod @@ -10017,10 +10531,26 @@ class GatewayRunner: # Tool progress mode — resolved per-platform with env var fallback _resolved_tp = resolve_display_setting(user_config, platform_key, "tool_progress") + _env_tp = os.getenv("HERMES_TOOL_PROGRESS_MODE") + _display_cfg = display_config if isinstance(display_config, dict) else {} + _platforms_cfg = _display_cfg.get("platforms") or {} + _platform_cfg = _platforms_cfg.get(platform_key) or {} + _legacy_tp_overrides = _display_cfg.get("tool_progress_overrides") or {} + _tool_progress_configured = ( + "tool_progress" in _display_cfg + or ( + isinstance(_platform_cfg, dict) + and "tool_progress" in _platform_cfg + ) + or ( + isinstance(_legacy_tp_overrides, dict) + and platform_key in _legacy_tp_overrides + ) + ) progress_mode = ( - _resolved_tp - or os.getenv("HERMES_TOOL_PROGRESS_MODE") - or "all" + _env_tp + if _env_tp and not _tool_progress_configured + else (_resolved_tp or _env_tp or "all") ) # Disable tool progress for webhooks - they don't support message editing, # so each progress line would be sent as a separate message. 
@@ -10069,7 +10599,7 @@ class GatewayRunner: tool_progress_hint_gateway, ) _cfg = _load_gateway_config() - gate_on = bool(_cfg.get("display", {}).get("tool_progress_command", False)) + gate_on = bool(cfg_get(_cfg, "display", "tool_progress_command", default=False)) if gate_on and not is_seen(_cfg, TOOL_PROGRESS_FLAG): long_tool_hint_fired[0] = True progress_queue.put(tool_progress_hint_gateway()) @@ -10078,6 +10608,7 @@ class GatewayRunner: logger.debug("tool-progress onboarding hint failed: %s", _hint_err) return + # Only act on tool.started events (ignore tool.completed, reasoning.available, etc.) if event_type not in ("tool.started",): return @@ -10928,6 +11459,17 @@ class GatewayRunner: + message ) + # Consume one-shot /reload-skills note (if the user ran + # /reload-skills since their last turn in this session). Same + # queue pattern as CLI: prepend to the NEXT user message, then + # clear. Nothing was written to the transcript out-of-band, so + # message alternation stays intact. + _pending_notes = getattr(self, "_pending_skills_reload_notes", None) + if _pending_notes and session_key and session_key in _pending_notes: + _srn = _pending_notes.pop(session_key, None) + if _srn: + message = _srn + "\n\n" + message + _approval_session_key = session_key or "" _approval_session_token = set_current_session_key(_approval_session_key) register_gateway_notify(_approval_session_key, _approval_notify_sync) diff --git a/gateway/session.py b/gateway/session.py index 6f35b95865..557f026ff1 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -62,6 +62,7 @@ from .config import ( ) from .whatsapp_identity import ( canonical_whatsapp_identifier, + normalize_whatsapp_identifier, # noqa: F401 - re-exported for gateway.session callers ) from utils import atomic_replace @@ -234,7 +235,7 @@ def build_session_context_prompt( ) -> str: """ Build the dynamic system prompt section that tells the agent about its context. 
- + This is injected into the system prompt so the agent knows: - Where messages are coming from - What platforms are connected @@ -246,13 +247,23 @@ def build_session_context_prompt( Platforms like Discord are excluded because mentions need real IDs. Routing still uses the original values (they stay in SessionSource). """ - # Only apply redaction on platforms where IDs aren't needed for mentions - redact_pii = redact_pii and context.source.platform in _PII_SAFE_PLATFORMS + # Only apply redaction on platforms where IDs aren't needed for mentions. + # Check both the hardcoded set (builtins) and the plugin registry. + _is_pii_safe = context.source.platform in _PII_SAFE_PLATFORMS + if not _is_pii_safe: + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(context.source.platform.value) + if entry and entry.pii_safe: + _is_pii_safe = True + except Exception: + pass + redact_pii = redact_pii and _is_pii_safe lines = [ "## Current Session Context", "", ] - + # Source info platform_name = context.source.platform.value.title() if context.source.platform == Platform.LOCAL: @@ -277,7 +288,7 @@ def build_session_context_prompt( else: desc = src.description lines.append(f"**Source:** {platform_name} ({desc})") - + # Channel topic (if available - provides context about the channel's purpose) if context.source.chat_topic: lines.append(f"**Channel Topic:** {context.source.chat_topic}") @@ -302,7 +313,7 @@ def build_session_context_prompt( if redact_pii: uid = _hash_sender_id(uid) lines.append(f"**User ID:** {uid}") - + # Platform-specific behavioral notes if context.source.platform == Platform.SLACK: lines.append("") @@ -368,9 +379,9 @@ def build_session_context_prompt( for p in context.connected_platforms: if p != Platform.LOCAL: platforms_list.append(f"{p.value}: Connected ✓") - + lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}") - + # Home channels if context.home_channels: lines.append("") @@ -378,11 +389,11 @@ 
def build_session_context_prompt( for platform, home in context.home_channels.items(): hc_id = _hash_chat_id(home.chat_id) if redact_pii else home.chat_id lines.append(f" - {platform.value}: {home.name} (ID: {hc_id})") - + # Delivery options for scheduled tasks lines.append("") lines.append("**Delivery options for scheduled tasks:**") - + from hermes_constants import display_hermes_home # Origin delivery @@ -398,15 +409,15 @@ def build_session_context_prompt( lines.append( f"- `\"local\"` → Save to local files only ({display_hermes_home()}/cron/output/)" ) - + # Platform home channels for platform, home in context.home_channels.items(): lines.append(f"- `\"{platform.value}\"` → Home channel ({home.name})") - + # Note about explicit targeting lines.append("") lines.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*") - + return "\n".join(lines) diff --git a/hermes_cli/_parser.py b/hermes_cli/_parser.py new file mode 100644 index 0000000000..29ac96c97b --- /dev/null +++ b/hermes_cli/_parser.py @@ -0,0 +1,373 @@ +""" +Top-level argparse construction for the hermes CLI. + +Lives in its own module so other modules (e.g. ``relaunch.py``) can +introspect the parser to discover which flags exist without running the +``main`` fn. + +Only the top-level parser and the ``chat`` subparser live here. Every other +subparser (model, gateway, sessions, …) is built inline in ``main.py`` +because its dispatch is tightly coupled to module-level ``cmd_*`` functions. +""" + +import argparse + + +# `--profile` / `-p` is consumed by ``main._apply_profile_override`` before +# argparse runs (it sets ``HERMES_HOME`` and strips itself from ``sys.argv``), +# so it isn't on the parser. Listed here so all "carry over on relaunch" +# metadata lives in one file. 
+PRE_ARGPARSE_INHERITED_FLAGS: list[tuple[str, bool]] = [ + ("--profile", True), + ("-p", True), +] + + +def _inherited_flag(parser, *args, **kwargs): + """Register a flag that ``hermes_cli.relaunch`` should carry over when + the CLI re-execs itself (e.g. after ``sessions browse`` picks a session, + or after the setup wizard launches chat). + + Equivalent to ``parser.add_argument(...)`` plus tagging the resulting + Action with ``inherit_on_relaunch = True`` so the relaunch table builder + can find it via introspection. + """ + action = parser.add_argument(*args, **kwargs) + action.inherit_on_relaunch = True + return action + + +_EPILOGUE = """ +Examples: + hermes Start interactive chat + hermes chat -q "Hello" Single query mode + hermes -c Resume the most recent session + hermes -c "my project" Resume a session by name (latest in lineage) + hermes --resume Resume a specific session by ID + hermes setup Run setup wizard + hermes logout Clear stored authentication + hermes auth add Add a pooled credential + hermes auth list List pooled credentials + hermes auth remove

Remove pooled credential by index, id, or label + hermes auth reset Clear exhaustion status for a provider + hermes model Select default model + hermes fallback [list] Show fallback provider chain + hermes fallback add Add a fallback provider (same picker as `hermes model`) + hermes fallback remove Remove a fallback provider from the chain + hermes config View configuration + hermes config edit Edit config in $EDITOR + hermes config set model gpt-4 Set a config value + hermes gateway Run messaging gateway + hermes -s hermes-agent-dev,github-auth + hermes -w Start in isolated git worktree + hermes gateway install Install gateway background service + hermes sessions list List past sessions + hermes sessions browse Interactive session picker + hermes sessions rename ID T Rename/title a session + hermes logs View agent.log (last 50 lines) + hermes logs -f Follow agent.log in real time + hermes logs errors View errors.log + hermes logs --since 1h Lines from the last hour + hermes debug share Upload debug report for support + hermes update Update to latest version + +For more help on a command: + hermes --help +""" + + +def build_top_level_parser(): + """Build the top-level parser, the subparsers action, and the ``chat`` subparser. + + Returns ``(parser, subparsers, chat_parser)``. The caller wires + ``chat_parser.set_defaults(func=cmd_chat)`` and continues registering + other subparsers via ``subparsers.add_parser(...)``. + """ + parser = argparse.ArgumentParser( + prog="hermes", + description="Hermes Agent - AI assistant with tool-calling capabilities", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=_EPILOGUE, + ) + + parser.add_argument( + "--version", "-V", action="store_true", help="Show version and exit" + ) + parser.add_argument( + "-z", + "--oneshot", + metavar="PROMPT", + default=None, + help=( + "One-shot mode: send a single prompt and print ONLY the final " + "response text to stdout. 
No banner, no spinner, no tool " + "previews, no session_id line. Tools, memory, rules, and " + "AGENTS.md in the CWD are loaded as normal; approvals are " + "auto-bypassed. Intended for scripts / pipes." + ), + ) + # --model / --provider are accepted at the top level so they can pair + # with -z without needing the `chat` subcommand. If neither -z nor a + # subcommand consumes them, they fall through harmlessly as None. + # Mirrors `hermes chat --model ... --provider ...` semantics. + _inherited_flag( + parser, + "-m", + "--model", + default=None, + help=( + "Model override for this invocation (e.g. anthropic/claude-sonnet-4.6). " + "Applies to -z/--oneshot and --tui. Also settable via HERMES_INFERENCE_MODEL env var." + ), + ) + _inherited_flag( + parser, + "--provider", + default=None, + help=( + "Provider override for this invocation (e.g. openrouter, anthropic). " + "Applies to -z/--oneshot and --tui. Also settable via HERMES_INFERENCE_PROVIDER env var." + ), + ) + parser.add_argument( + "-t", + "--toolsets", + default=None, + help="Comma-separated toolsets to enable for this invocation. Applies to -z/--oneshot and --tui.", + ) + parser.add_argument( + "--resume", + "-r", + metavar="SESSION", + default=None, + help="Resume a previous session by ID or title", + ) + parser.add_argument( + "--continue", + "-c", + dest="continue_last", + nargs="?", + const=True, + default=None, + metavar="SESSION_NAME", + help="Resume a session by name, or the most recent if no name given", + ) + parser.add_argument( + "--worktree", + "-w", + action="store_true", + default=False, + help="Run in an isolated git worktree (for parallel agents)", + ) + _inherited_flag( + parser, + "--accept-hooks", + action="store_true", + default=False, + help=( + "Auto-approve any unseen shell hooks declared in config.yaml " + "without a TTY prompt. Equivalent to HERMES_ACCEPT_HOOKS=1 or " + "hooks_auto_accept: true in config.yaml. Use on CI / headless " + "runs that can't prompt." 
+ ), + ) + _inherited_flag( + parser, + "--skills", + "-s", + action="append", + default=None, + help="Preload one or more skills for the session (repeat flag or comma-separate)", + ) + _inherited_flag( + parser, + "--yolo", + action="store_true", + default=False, + help="Bypass all dangerous command approval prompts (use at your own risk)", + ) + _inherited_flag( + parser, + "--pass-session-id", + action="store_true", + default=False, + help="Include the session ID in the agent's system prompt", + ) + _inherited_flag( + parser, + "--ignore-user-config", + action="store_true", + default=False, + help="Ignore ~/.hermes/config.yaml and fall back to built-in defaults (credentials in .env are still loaded)", + ) + _inherited_flag( + parser, + "--ignore-rules", + action="store_true", + default=False, + help="Skip auto-injection of AGENTS.md, SOUL.md, .cursorrules, memory, and preloaded skills", + ) + _inherited_flag( + parser, + "--tui", + action="store_true", + default=False, + help="Launch the modern TUI instead of the classic REPL", + ) + _inherited_flag( + parser, + "--dev", + dest="tui_dev", + action="store_true", + default=False, + help="With --tui: run TypeScript sources via tsx (skip dist build)", + ) + + subparsers = parser.add_subparsers(dest="command", help="Command to run") + + # ========================================================================= + # chat command + # ========================================================================= + chat_parser = subparsers.add_parser( + "chat", + help="Interactive chat with the agent", + description="Start an interactive chat session with Hermes Agent", + ) + chat_parser.add_argument( + "-q", "--query", help="Single query (non-interactive mode)" + ) + chat_parser.add_argument( + "--image", help="Optional local image path to attach to a single query" + ) + _inherited_flag( + chat_parser, + "-m", "--model", help="Model to use (e.g., anthropic/claude-sonnet-4)", + ) + chat_parser.add_argument( + "-t", 
"--toolsets", help="Comma-separated toolsets to enable" + ) + _inherited_flag( + chat_parser, + "-s", + "--skills", + action="append", + default=argparse.SUPPRESS, + help="Preload one or more skills for the session (repeat flag or comma-separate)", + ) + _inherited_flag( + chat_parser, + "--provider", + # No `choices=` here: user-defined providers from config.yaml `providers:` + # are also valid values, and runtime resolution (resolve_runtime_provider) + # handles validation/error reporting consistently with the top-level + # `--provider` flag. + default=None, + help="Inference provider (default: auto). Built-in or a user-defined name from `providers:` in config.yaml.", + ) + chat_parser.add_argument( + "-v", "--verbose", action="store_true", help="Verbose output" + ) + chat_parser.add_argument( + "-Q", + "--quiet", + action="store_true", + help="Quiet mode for programmatic use: suppress banner, spinner, and tool previews. Only output the final response and session info.", + ) + chat_parser.add_argument( + "--resume", + "-r", + metavar="SESSION_ID", + default=argparse.SUPPRESS, + help="Resume a previous session by ID (shown on exit)", + ) + chat_parser.add_argument( + "--continue", + "-c", + dest="continue_last", + nargs="?", + const=True, + default=argparse.SUPPRESS, + metavar="SESSION_NAME", + help="Resume a session by name, or the most recent if no name given", + ) + chat_parser.add_argument( + "--worktree", + "-w", + action="store_true", + default=argparse.SUPPRESS, + help="Run in an isolated git worktree (for parallel agents on the same repo)", + ) + _inherited_flag( + chat_parser, + "--accept-hooks", + action="store_true", + default=argparse.SUPPRESS, + help=( + "Auto-approve any unseen shell hooks declared in config.yaml " + "without a TTY prompt (see also HERMES_ACCEPT_HOOKS env var and " + "hooks_auto_accept: in config.yaml)." 
+ ), + ) + chat_parser.add_argument( + "--checkpoints", + action="store_true", + default=False, + help="Enable filesystem checkpoints before destructive file operations (use /rollback to restore)", + ) + chat_parser.add_argument( + "--max-turns", + type=int, + default=None, + metavar="N", + help="Maximum tool-calling iterations per conversation turn (default: 90, or agent.max_turns in config)", + ) + _inherited_flag( + chat_parser, + "--yolo", + action="store_true", + default=argparse.SUPPRESS, + help="Bypass all dangerous command approval prompts (use at your own risk)", + ) + _inherited_flag( + chat_parser, + "--pass-session-id", + action="store_true", + default=argparse.SUPPRESS, + help="Include the session ID in the agent's system prompt", + ) + _inherited_flag( + chat_parser, + "--ignore-user-config", + action="store_true", + default=argparse.SUPPRESS, + help="Ignore ~/.hermes/config.yaml and fall back to built-in defaults (credentials in .env are still loaded). Useful for isolated CI runs, reproduction, and third-party integrations.", + ) + _inherited_flag( + chat_parser, + "--ignore-rules", + action="store_true", + default=argparse.SUPPRESS, + help="Skip auto-injection of AGENTS.md, SOUL.md, .cursorrules, memory, and preloaded skills. Combine with --ignore-user-config for a fully isolated run.", + ) + chat_parser.add_argument( + "--source", + default=None, + help="Session source tag for filtering (default: cli). 
Use 'tool' for third-party integrations that should not appear in user session lists.", + ) + _inherited_flag( + chat_parser, + "--tui", + action="store_true", + default=False, + help="Launch the modern TUI instead of the classic REPL", + ) + _inherited_flag( + chat_parser, + "--dev", + dest="tui_dev", + action="store_true", + default=False, + help="With --tui: run TypeScript sources via tsx (skip dist build)", + ) + + return parser, subparsers, chat_parser diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index c5ff23e586..7885e99d1e 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -72,6 +72,14 @@ DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" +MINIMAX_OAUTH_CLIENT_ID = "78257093-7e40-4613-99e0-527b14b39113" +MINIMAX_OAUTH_SCOPE = "group_id profile model.completion" +MINIMAX_OAUTH_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:user_code" +MINIMAX_OAUTH_GLOBAL_BASE = "https://api.minimax.io" +MINIMAX_OAUTH_CN_BASE = "https://api.minimaxi.com" +MINIMAX_OAUTH_GLOBAL_INFERENCE = "https://api.minimax.io/anthropic" +MINIMAX_OAUTH_CN_INFERENCE = "https://api.minimaxi.com/anthropic" +MINIMAX_OAUTH_REFRESH_SKEW_SECONDS = 60 DEFAULT_QWEN_BASE_URL = "https://portal.qwen.ai/v1" DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com" DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot" @@ -126,7 +134,7 @@ class ProviderConfig: """Describes a known inference provider.""" id: str name: str - auth_type: str # "oauth_device_code", "oauth_external", or "api_key" + auth_type: str # "oauth_device_code", "oauth_external", "oauth_minimax", or "api_key" portal_base_url: str = "" inference_base_url: str = "" client_id: str = "" @@ -255,6 +263,17 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { api_key_env_vars=("MINIMAX_API_KEY",), 
base_url_env_var="MINIMAX_BASE_URL", ), + "minimax-oauth": ProviderConfig( + id="minimax-oauth", + name="MiniMax (OAuth \u00b7 minimax.io)", + auth_type="oauth_minimax", + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + inference_base_url=MINIMAX_OAUTH_GLOBAL_INFERENCE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + scope=MINIMAX_OAUTH_SCOPE, + extra={"region": "global", "cn_portal_base_url": MINIMAX_OAUTH_CN_BASE, + "cn_inference_base_url": MINIMAX_OAUTH_CN_INFERENCE}, + ), "anthropic": ProviderConfig( id="anthropic", name="Anthropic", @@ -1153,6 +1172,7 @@ def resolve_provider( "arcee-ai": "arcee", "arceeai": "arcee", "gmi-cloud": "gmi", "gmicloud": "gmi", "minimax-china": "minimax-cn", "minimax_cn": "minimax-cn", + "minimax-portal": "minimax-oauth", "minimax-global": "minimax-oauth", "minimax_oauth": "minimax-oauth", "alibaba_coding": "alibaba-coding-plan", "alibaba-coding": "alibaba-coding-plan", "alibaba_coding_plan": "alibaba-coding-plan", "claude": "anthropic", "claude-code": "anthropic", @@ -4116,6 +4136,326 @@ def _codex_device_code_login() -> Dict[str, Any]: } +# ==================== MiniMax Portal OAuth ==================== + +def _minimax_pkce_pair() -> tuple: + """Generate (code_verifier, code_challenge_S256, state) for MiniMax OAuth.""" + import secrets + verifier = secrets.token_urlsafe(64)[:96] + challenge = base64.urlsafe_b64encode( + hashlib.sha256(verifier.encode()).digest() + ).decode().rstrip("=") + state = secrets.token_urlsafe(16) + return verifier, challenge, state + + +def _minimax_request_user_code( + client: httpx.Client, *, portal_base_url: str, client_id: str, + code_challenge: str, state: str, +) -> Dict[str, Any]: + response = client.post( + f"{portal_base_url}/oauth/code", + data={ + "response_type": "code", + "client_id": client_id, + "scope": MINIMAX_OAUTH_SCOPE, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": state, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": 
"application/json", + "x-request-id": str(uuid.uuid4()), + }, + ) + if response.status_code != 200: + raise AuthError( + f"MiniMax OAuth authorization failed: {response.text or response.reason_phrase}", + provider="minimax-oauth", code="authorization_failed", + ) + payload = response.json() + for field in ("user_code", "verification_uri", "expired_in"): + if field not in payload: + raise AuthError( + f"MiniMax OAuth response missing field: {field}", + provider="minimax-oauth", code="authorization_incomplete", + ) + if payload.get("state") != state: + raise AuthError( + "MiniMax OAuth state mismatch (possible CSRF).", + provider="minimax-oauth", code="state_mismatch", + ) + return payload + + +def _minimax_poll_token( + client: httpx.Client, *, portal_base_url: str, client_id: str, + user_code: str, code_verifier: str, expired_in: int, interval_ms: Optional[int], +) -> Dict[str, Any]: + # OpenClaw treats expired_in as a unix-ms timestamp (Date.now() < expireTimeMs). + # Defensive parsing: if it's small enough to be a duration, treat as seconds. + import time as _time + now_ms = int(_time.time() * 1000) + if expired_in > now_ms // 2: + # Looks like a unix-ms timestamp. + deadline = expired_in / 1000.0 + else: + # Treat as duration in seconds from now. 
+ deadline = _time.time() + max(1, expired_in) + interval = max(2.0, (interval_ms or 2000) / 1000.0) + + while _time.time() < deadline: + response = client.post( + f"{portal_base_url}/oauth/token", + data={ + "grant_type": MINIMAX_OAUTH_GRANT_TYPE, + "client_id": client_id, + "user_code": user_code, + "code_verifier": code_verifier, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }, + ) + try: + payload = response.json() if response.text else {} + except Exception: + payload = {} + + if response.status_code != 200: + msg = (payload.get("base_resp", {}) or {}).get("status_msg") or response.text + raise AuthError( + f"MiniMax OAuth error: {msg or 'unknown'}", + provider="minimax-oauth", code="token_exchange_failed", + ) + + status = payload.get("status") + if status == "error": + raise AuthError( + "MiniMax OAuth reported an error. Please try again later.", + provider="minimax-oauth", code="authorization_denied", + ) + if status == "success": + if not all(payload.get(k) for k in ("access_token", "refresh_token", "expired_in")): + raise AuthError( + "MiniMax OAuth success payload missing required token fields.", + provider="minimax-oauth", code="token_incomplete", + ) + return payload + # "pending" or any other status -> keep polling + _time.sleep(interval) + + raise AuthError( + "MiniMax OAuth timed out before authorization completed.", + provider="minimax-oauth", code="timeout", + ) + + +def _minimax_save_auth_state(auth_state: Dict[str, Any]) -> None: + """Persist MiniMax OAuth state to Hermes auth store (~/.hermes/auth.json).""" + with _auth_store_lock(): + auth_store = _load_auth_store() + _save_provider_state(auth_store, "minimax-oauth", auth_state) + _save_auth_store(auth_store) + + +def _minimax_oauth_login( + *, region: str = "global", open_browser: bool = True, + timeout_seconds: float = 15.0, +) -> Dict[str, Any]: + """Run MiniMax OAuth flow, persist tokens, return auth state dict.""" + pconfig = 
PROVIDER_REGISTRY["minimax-oauth"] + if region == "cn": + portal_base_url = pconfig.extra["cn_portal_base_url"] + inference_base_url = pconfig.extra["cn_inference_base_url"] + else: + portal_base_url = pconfig.portal_base_url + inference_base_url = pconfig.inference_base_url + + verifier, challenge, state = _minimax_pkce_pair() + + if _is_remote_session(): + open_browser = False + + print(f"Starting Hermes login via MiniMax ({region}) OAuth...") + print(f"Portal: {portal_base_url}") + + with httpx.Client(timeout=httpx.Timeout(timeout_seconds), + headers={"Accept": "application/json"}) as client: + code_data = _minimax_request_user_code( + client, portal_base_url=portal_base_url, + client_id=pconfig.client_id, + code_challenge=challenge, state=state, + ) + verification_url = str(code_data["verification_uri"]) + user_code = str(code_data["user_code"]) + + print() + print("To continue:") + print(f" 1. Open: {verification_url}") + print(f" 2. If prompted, enter code: {user_code}") + if open_browser: + if webbrowser.open(verification_url): + print(" (Opened browser for verification)") + else: + print(" Could not open browser automatically -- use the URL above.") + + interval_raw = code_data.get("interval") + interval_ms = int(interval_raw) if interval_raw is not None else None + print("Waiting for approval...") + + token_data = _minimax_poll_token( + client, portal_base_url=portal_base_url, + client_id=pconfig.client_id, + user_code=user_code, code_verifier=verifier, + expired_in=int(code_data["expired_in"]), + interval_ms=interval_ms, + ) + + now = datetime.now(timezone.utc) + expires_in_s = int(token_data["expired_in"]) + expires_at = now.timestamp() + expires_in_s + + auth_state = { + "provider": "minimax-oauth", + "region": region, + "portal_base_url": portal_base_url, + "inference_base_url": inference_base_url, + "client_id": pconfig.client_id, + "scope": MINIMAX_OAUTH_SCOPE, + "token_type": token_data.get("token_type", "Bearer"), + "access_token": 
token_data["access_token"], + "refresh_token": token_data["refresh_token"], + "resource_url": token_data.get("resource_url"), + "obtained_at": now.isoformat(), + "expires_at": datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(), + "expires_in": expires_in_s, + } + + _minimax_save_auth_state(auth_state) + print("\u2713 MiniMax OAuth login successful.") + if msg := token_data.get("notification_message"): + print(f"Note from MiniMax: {msg}") + return auth_state + + +def _refresh_minimax_oauth_state( + state: Dict[str, Any], *, timeout_seconds: float = 15.0, + force: bool = False, +) -> Dict[str, Any]: + """Refresh MiniMax OAuth access token if close to expiry (or forced).""" + if not state.get("refresh_token"): + raise AuthError( + "MiniMax OAuth state has no refresh_token; please re-login.", + provider="minimax-oauth", code="no_refresh_token", relogin_required=True, + ) + try: + expires_at = datetime.fromisoformat(state.get("expires_at", "")).timestamp() + except Exception: + expires_at = 0.0 + now = time.time() + if not force and (expires_at - now) > MINIMAX_OAUTH_REFRESH_SKEW_SECONDS: + return state + + portal_base_url = state["portal_base_url"] + with httpx.Client(timeout=httpx.Timeout(timeout_seconds)) as client: + response = client.post( + f"{portal_base_url}/oauth/token", + data={ + "grant_type": "refresh_token", + "client_id": state["client_id"], + "refresh_token": state["refresh_token"], + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + }, + ) + if response.status_code != 200: + body = response.text.lower() + relogin = any(m in body for m in + ("invalid_grant", "refresh_token_reused", "invalid_refresh_token")) + raise AuthError( + f"MiniMax OAuth refresh failed: {response.text or response.reason_phrase}", + provider="minimax-oauth", code="refresh_failed", + relogin_required=relogin, + ) + payload = response.json() + if payload.get("status") != "success": + raise AuthError( + "MiniMax OAuth 
refresh did not return success.", + provider="minimax-oauth", code="refresh_failed", + relogin_required=True, + ) + now_dt = datetime.now(timezone.utc) + expires_in_s = int(payload["expired_in"]) + new_state = dict(state) + new_state.update({ + "access_token": payload["access_token"], + "refresh_token": payload.get("refresh_token", state["refresh_token"]), + "obtained_at": now_dt.isoformat(), + "expires_at": datetime.fromtimestamp(now_dt.timestamp() + expires_in_s, + tz=timezone.utc).isoformat(), + "expires_in": expires_in_s, + }) + _minimax_save_auth_state(new_state) + return new_state + + +def resolve_minimax_oauth_runtime_credentials( + *, min_token_ttl_seconds: int = MINIMAX_OAUTH_REFRESH_SKEW_SECONDS, +) -> Dict[str, Any]: + """Return {provider, api_key, base_url, source} for minimax-oauth.""" + state = get_provider_auth_state("minimax-oauth") + if not state or not state.get("access_token"): + raise AuthError( + "Not logged into MiniMax OAuth. Run `hermes model` and select " + "MiniMax (OAuth).", + provider="minimax-oauth", code="not_logged_in", relogin_required=True, + ) + state = _refresh_minimax_oauth_state(state) + return { + "provider": "minimax-oauth", + "api_key": state["access_token"], + "base_url": state["inference_base_url"].rstrip("/"), + "source": "oauth", + } + + +def get_minimax_oauth_auth_status() -> Dict[str, Any]: + """Return auth status dict for MiniMax OAuth provider.""" + state = get_provider_auth_state("minimax-oauth") + if not state or not state.get("access_token"): + return {"logged_in": False, "provider": "minimax-oauth"} + try: + expires_at = datetime.fromisoformat(state.get("expires_at", "")).timestamp() + token_valid = (expires_at - time.time()) > 0 + except Exception: + token_valid = bool(state.get("access_token")) + return { + "logged_in": token_valid, + "provider": "minimax-oauth", + "region": state.get("region", "global"), + "expires_at": state.get("expires_at"), + } + + +def _login_minimax_oauth(args, pconfig: ProviderConfig) -> 
None: + """CLI entry for MiniMax OAuth login.""" + region = getattr(args, "region", None) or "global" + open_browser = not getattr(args, "no_browser", False) + timeout = getattr(args, "timeout", None) or 15.0 + try: + _minimax_oauth_login( + region=region, open_browser=open_browser, timeout_seconds=timeout, + ) + except AuthError as exc: + print(format_auth_error(exc)) + raise SystemExit(1) + + def _nous_device_code_login( *, portal_base_url: Optional[str] = None, diff --git a/hermes_cli/auth_commands.py b/hermes_cli/auth_commands.py index 94ea2559c4..a9eb206647 100644 --- a/hermes_cli/auth_commands.py +++ b/hermes_cli/auth_commands.py @@ -33,7 +33,7 @@ from hermes_constants import OPENROUTER_BASE_URL # Providers that support OAuth login in addition to API keys. -_OAUTH_CAPABLE_PROVIDERS = {"anthropic", "nous", "openai-codex", "qwen-oauth", "google-gemini-cli"} +_OAUTH_CAPABLE_PROVIDERS = {"anthropic", "nous", "openai-codex", "qwen-oauth", "google-gemini-cli", "minimax-oauth"} def _get_custom_provider_names() -> list: @@ -170,7 +170,7 @@ def auth_add_command(args) -> None: if provider.startswith(CUSTOM_POOL_PREFIX): requested_type = AUTH_TYPE_API_KEY else: - requested_type = AUTH_TYPE_OAUTH if provider in {"anthropic", "nous", "openai-codex", "qwen-oauth", "google-gemini-cli"} else AUTH_TYPE_API_KEY + requested_type = AUTH_TYPE_OAUTH if provider in {"anthropic", "nous", "openai-codex", "qwen-oauth", "google-gemini-cli", "minimax-oauth"} else AUTH_TYPE_API_KEY pool = load_pool(provider) @@ -333,6 +333,27 @@ def auth_add_command(args) -> None: print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"') return + if provider == "minimax-oauth": + from hermes_cli.auth import resolve_minimax_oauth_runtime_credentials + creds = resolve_minimax_oauth_runtime_credentials() + label = (getattr(args, "label", None) or "").strip() or label_from_token( + creds["api_key"], + _oauth_default_label(provider, len(pool.entries()) + 1), + ) + entry = 
PooledCredential( + provider=provider, + id=uuid.uuid4().hex[:6], + label=label, + auth_type=AUTH_TYPE_OAUTH, + priority=0, + source=f"{SOURCE_MANUAL}:minimax_oauth", + access_token=creds["api_key"], + base_url=creds.get("base_url"), + ) + pool.add_entry(entry) + print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"') + return + raise SystemExit(f"`hermes auth add {provider}` is not implemented for auth type {requested_type} yet.") diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index d46c853997..c8446f04d9 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -5,6 +5,7 @@ Pure display functions with no HermesCLI state dependency. import json import logging +import os import shutil import subprocess import threading @@ -122,35 +123,36 @@ def get_available_skills() -> Dict[str, List[str]]: # Cache update check results for 6 hours to avoid repeated git fetches _UPDATE_CHECK_CACHE_SECONDS = 6 * 3600 +# Sentinel returned when we know an update exists but can't count commits +# (e.g. nix-built hermes — no local git history to count against). +UPDATE_AVAILABLE_NO_COUNT = -1 -def check_for_updates() -> Optional[int]: - """Check how many commits behind origin/main the local repo is. +_UPSTREAM_REPO_URL = "https://github.com/NousResearch/hermes-agent.git" - Does a ``git fetch`` at most once every 6 hours (cached to - ``~/.hermes/.update_check``). Returns the number of commits behind, - or ``None`` if the check fails or isn't applicable. + +def _check_via_rev(local_rev: str) -> Optional[int]: + """Compare an embedded git revision to upstream main via ls-remote. + + Returns 0 if up-to-date, ``UPDATE_AVAILABLE_NO_COUNT`` if behind, + or ``None`` on failure. 
""" - hermes_home = get_hermes_home() - repo_dir = hermes_home / "hermes-agent" - cache_file = hermes_home / ".update_check" - - # Must be a git repo — fall back to project root for dev installs - if not (repo_dir / ".git").exists(): - repo_dir = Path(__file__).parent.parent.resolve() - if not (repo_dir / ".git").exists(): - return None - - # Read cache - now = time.time() try: - if cache_file.exists(): - cached = json.loads(cache_file.read_text()) - if now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS: - return cached.get("behind") + result = subprocess.run( + ["git", "ls-remote", _UPSTREAM_REPO_URL, "refs/heads/main"], + capture_output=True, text=True, timeout=10, + ) except Exception: - pass + return None + if result.returncode != 0 or not result.stdout: + return None + upstream_rev = result.stdout.split()[0] + if not upstream_rev: + return None + return 0 if upstream_rev == local_rev else UPDATE_AVAILABLE_NO_COUNT - # Fetch latest refs (fast — only downloads ref metadata, no files) + +def _check_via_local_git(repo_dir: Path) -> Optional[int]: + """Count commits behind origin/main in a local checkout.""" try: subprocess.run( ["git", "fetch", "origin", "--quiet"], @@ -160,7 +162,6 @@ def check_for_updates() -> Optional[int]: except Exception: pass # Offline or timeout — use stale refs, that's fine - # Count commits behind try: result = subprocess.run( ["git", "rev-list", "--count", "HEAD..origin/main"], @@ -168,15 +169,52 @@ def check_for_updates() -> Optional[int]: cwd=str(repo_dir), ) if result.returncode == 0: - behind = int(result.stdout.strip()) - else: - behind = None + return int(result.stdout.strip()) except Exception: - behind = None + pass + return None - # Write cache + +def check_for_updates() -> Optional[int]: + """Check whether a Hermes update is available. + + Two paths: if ``HERMES_REVISION`` is set (nix builds embed it), compare + it to upstream main via ``git ls-remote``. 
Otherwise look for a local + git checkout and count commits behind ``origin/main``. + + Returns the number of commits behind, ``UPDATE_AVAILABLE_NO_COUNT`` (-1) + if behind but the count is unknown, ``0`` if up-to-date, or ``None`` if + the check failed or doesn't apply. Cached for 6 hours. + """ + hermes_home = get_hermes_home() + cache_file = hermes_home / ".update_check" + embedded_rev = os.environ.get("HERMES_REVISION") or None + + # Read cache — invalidate if the embedded rev has changed since last check + now = time.time() try: - cache_file.write_text(json.dumps({"ts": now, "behind": behind})) + if cache_file.exists(): + cached = json.loads(cache_file.read_text()) + if ( + now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS + and cached.get("rev") == embedded_rev + ): + return cached.get("behind") + except Exception: + pass + + if embedded_rev: + behind = _check_via_rev(embedded_rev) + else: + repo_dir = hermes_home / "hermes-agent" + if not (repo_dir / ".git").exists(): + repo_dir = Path(__file__).parent.parent.resolve() + if not (repo_dir / ".git").exists(): + return None + behind = _check_via_local_git(repo_dir) + + try: + cache_file.write_text(json.dumps({"ts": now, "behind": behind, "rev": embedded_rev})) except Exception: pass @@ -549,13 +587,23 @@ def build_welcome_banner(console: Console, model: str, cwd: str, # Update check — use prefetched result if available try: behind = get_update_result(timeout=0.5) - if behind and behind > 0: - from hermes_cli.config import recommended_update_command - commits_word = "commit" if behind == 1 else "commits" - right_lines.append( - f"[bold yellow]⚠ {behind} {commits_word} behind[/]" - f"[dim yellow] — run [bold]{recommended_update_command()}[/bold] to update[/]" - ) + if behind is not None and behind != 0: + from hermes_cli.config import get_managed_update_command, recommended_update_command + if behind > 0: + commits_word = "commit" if behind == 1 else "commits" + right_lines.append( + f"[bold yellow]⚠ 
{behind} {commits_word} behind[/]" + f"[dim yellow] — run [bold]{recommended_update_command()}[/bold] to update[/]" + ) + else: + # UPDATE_AVAILABLE_NO_COUNT: nix-built hermes; we know an update + # exists but not by how much, and we don't know how the user + # installed it (nix run, profile, system flake, home-manager). + managed_cmd = get_managed_update_command() + line = "[bold yellow]⚠ update available[/]" + if managed_cmd: + line += f"[dim yellow] — run [bold]{managed_cmd}[/bold][/]" + right_lines.append(line) except Exception: pass # Never break the banner over an update check diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 7e3e14c540..5ca562d87a 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -155,6 +155,8 @@ COMMAND_REGISTRY: list[CommandDef] = [ cli_only=True), CommandDef("reload-mcp", "Reload MCP servers from config", "Tools & Skills", aliases=("reload_mcp",)), + CommandDef("reload-skills", "Re-scan ~/.hermes/skills/ for newly installed or removed skills", + "Tools & Skills", aliases=("reload_skills",)), CommandDef("browser", "Connect browser tools to your live Chrome via CDP", "Tools & Skills", cli_only=True, args_hint="[connect|disconnect|status]", subcommands=("connect", "disconnect", "status")), diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 22ad4004f3..0c3f40ab67 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -73,6 +73,8 @@ _EXTRA_ENV_KEYS = frozenset({ "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", # legacy aliases (pre-rename, still read for back-compat) "QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS", "QQ_ALLOW_ALL_USERS", "QQ_MARKDOWN_SUPPORT", "QQ_STT_API_KEY", "QQ_STT_BASE_URL", "QQ_STT_MODEL", + "IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", + "IRC_USE_TLS", "IRC_SERVER_PASSWORD", "IRC_NICKSERV_PASSWORD", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", "MATTERMOST_HOME_CHANNEL", "MATTERMOST_HOME_CHANNEL_NAME", 
"MATTERMOST_REPLY_MODE", @@ -499,7 +501,8 @@ DEFAULT_CONFIG = { "singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20", "modal_image": "nikolaik/python-nodejs:python3.11-nodejs20", "daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20", - # Container resource limits (docker, singularity, modal, daytona — ignored for local/ssh) + "vercel_runtime": "node24", + # Container resource limits (docker, singularity, modal, daytona, vercel_sandbox — ignored for local/ssh) "container_cpu": 1, "container_memory": 5120, # MB (default 5GB) "container_disk": 51200, # MB (default 50GB) @@ -515,6 +518,16 @@ DEFAULT_CONFIG = { # Explicit opt-in: mount the host cwd into /workspace for Docker sessions. # Default off because passing host directories into a sandbox weakens isolation. "docker_mount_cwd_to_workspace": False, + # Explicit opt-in: run the Docker container as the host user's uid:gid + # (via `--user`). When enabled, files written into bind-mounted dirs + # (docker_volumes, the persistent workspace, or the auto-mounted cwd) + # are owned by your host user instead of root, which avoids needing + # `sudo chown` after container runs. Default off to preserve behavior + # for images whose entrypoints expect to start as root (e.g. the + # bundled Hermes image, which drops to the `hermes` user via gosu). + # When on, SETUID/SETGID caps are omitted from the container since + # no privilege drop is needed. + "docker_run_as_host_user": False, # Persistent shell — keep a long-lived bash shell across execute() calls # so cwd/env vars/shell variables survive between commands. # Enabled by default for non-local backends (SSH); local is always opt-in @@ -696,6 +709,19 @@ DEFAULT_CONFIG = { "timeout": 30, "extra_body": {}, }, + # Curator — skill-usage review fork. Timeout is generous because the + # review pass can take several minutes on reasoning models (umbrella + # building over hundreds of candidate skills). 
"auto" = use main chat + # model; override via `hermes model` → auxiliary → Curator to route + # to a cheaper aux model (e.g. openrouter google/gemini-3-flash-preview). + "curator": { + "provider": "auto", + "model": "", + "base_url": "", + "api_key": "", + "timeout": 600, + "extra_body": {}, + }, }, "display": { @@ -753,7 +779,7 @@ DEFAULT_CONFIG = { # limit (OpenAI 4096, xAI 15000, MiniMax 10000, ElevenLabs 5k-40k model-aware, # Gemini 5000, Edge 5000, Mistral 4000, NeuTTS/KittenTTS 2000). "tts": { - "provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai" | "xai" | "minimax" | "mistral" | "neutts" (local) + "provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai" | "xai" | "minimax" | "mistral" | "gemini" | "neutts" (local) | "kittentts" (local) | "piper" (local) "edge": { "voice": "en-US-AriaNeural", # Popular: AriaNeural, JennyNeural, AndrewNeural, BrianNeural, SoniaNeural @@ -783,6 +809,19 @@ DEFAULT_CONFIG = { "model": "neuphonic/neutts-air-q4-gguf", # HuggingFace model repo "device": "cpu", # cpu, cuda, or mps }, + "piper": { + # Voice name (e.g. "en_US-lessac-medium") downloaded on first + # use, OR an absolute path to a pre-downloaded .onnx file. + # Full voice list: https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/VOICES.md + "voice": "en_US-lessac-medium", + # "voices_dir": "", # Override voice cache dir; default = ~/.hermes/cache/piper-voices/ + # "use_cuda": False, # Requires onnxruntime-gpu + # "length_scale": 1.0, # 2.0 = twice as slow + # "noise_scale": 0.667, + # "noise_w_scale": 0.8, + # "volume": 1.0, + # "normalize_audio": True, + }, }, "stt": { @@ -936,12 +975,6 @@ DEFAULT_CONFIG = { # Archive a skill (move to skills/.archive/) after this many days # without use. Archived skills are recoverable — no auto-deletion. "archive_after_days": 90, - # Optional per-task override for the curator's aux model. Leave null - # to use Hermes' main auxiliary client resolution. 
- "auxiliary": { - "provider": None, - "model": None, - }, }, # Honcho AI-native memory -- reads ~/.honcho/config.json as single source of truth. @@ -1007,6 +1040,14 @@ DEFAULT_CONFIG = { "mode": "manual", "timeout": 60, "cron_mode": "deny", + # When true, /reload-mcp asks the user to confirm before rebuilding + # the MCP tool set for the active session. Reloading invalidates + # the provider prompt cache (tool schemas are baked into the system + # prompt), so the next message re-sends full input tokens — this can + # be expensive on long-context or high-reasoning models. Users click + # "Always Approve" to silence the prompt permanently; that flips + # this key to false. + "mcp_reload_confirm": True, }, # Permanently allowed dangerous command patterns (added via "always" approval) @@ -2065,6 +2106,43 @@ OPTIONAL_ENV_VARS = { "prompt": "QQ Sandbox Mode", "category": "messaging", }, + "IRC_SERVER": { + "description": "IRC server hostname (e.g. irc.libera.chat)", + "prompt": "IRC server", + "url": None, + "password": False, + "category": "messaging", + }, + "IRC_CHANNEL": { + "description": "IRC channel to join (e.g. #hermes)", + "prompt": "IRC channel", + "url": None, + "password": False, + "category": "messaging", + }, + "IRC_NICKNAME": { + "description": "Bot nickname on IRC (default: hermes-bot)", + "prompt": "IRC nickname", + "url": None, + "password": False, + "category": "messaging", + }, + "IRC_SERVER_PASSWORD": { + "description": "IRC server password (if required)", + "prompt": "IRC server password", + "url": None, + "password": True, + "category": "messaging", + "advanced": True, + }, + "IRC_NICKSERV_PASSWORD": { + "description": "NickServ password for nick identification", + "prompt": "NickServ password", + "url": None, + "password": True, + "category": "messaging", + "advanced": True, + }, "GATEWAY_ALLOW_ALL_USERS": { "description": "Allow all users to interact with messaging bots (true/false). 
Default: false.", "prompt": "Allow all users (true/false)", @@ -3477,6 +3555,52 @@ def _normalize_max_turns_config(config: Dict[str, Any]) -> Dict[str, Any]: return config +def cfg_get(cfg: Optional[Dict[str, Any]], *keys: str, default: Any = None) -> Any: + """Traverse nested dict keys safely, returning ``default`` on any miss. + + Canonical helper for the ``cfg.get("X", {}).get("Y", default)`` pattern + that appears 50+ times across the codebase. Handles three common gotchas + in one place: + + 1. Missing intermediate keys (returns ``default``, no KeyError). + 2. An intermediate value that's not a dict (e.g. a user wrote a string + where a section was expected). Returns ``default`` instead of + AttributeError on ``.get()``. + 3. ``cfg is None`` (callers sometimes pass ``load_config() or None``). + + Named ``cfg_get`` rather than ``cfg_path`` to avoid shadowing the + ubiquitous ``cfg_path = _hermes_home / "config.yaml"`` local variable + that appears in gateway/run.py, cron/scheduler.py, main.py, etc. + + Explicit ``None`` values are returned as-is (matches ``dict.get(key, + default)`` semantics — ``default`` is only returned when the key is + *absent*, not when it's present but set to ``None``). 
+ + Examples: + >>> cfg_get({"agent": {"reasoning_effort": "high"}}, "agent", "reasoning_effort") + 'high' + >>> cfg_get({}, "agent", "reasoning_effort", default="medium") + 'medium' + >>> cfg_get({"agent": "oops_a_string"}, "agent", "reasoning_effort", default="low") + 'low' + >>> cfg_get(None, "anything", default=42) + 42 + >>> cfg_get({"a": {"b": None}}, "a", "b", default="def") # explicit None preserved + >>> cfg_get({"a": {"b": False}}, "a", "b", default=True) # falsy values preserved + False + """ + if not isinstance(cfg, dict): + return default + node: Any = cfg + for key in keys: + if not isinstance(node, dict): + return default + if key not in node: + return default + node = node[key] + return node + + def read_raw_config() -> Dict[str, Any]: """Read ~/.hermes/config.yaml as-is, without merging defaults or migrating. @@ -4137,6 +4261,9 @@ def show_config(): print(f" Daytona image: {terminal.get('daytona_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}") daytona_key = get_env_value('DAYTONA_API_KEY') print(f" API key: {'configured' if daytona_key else '(not set)'}") + elif terminal.get('backend') == 'vercel_sandbox': + print(f" Vercel runtime: {terminal.get('vercel_runtime', 'node24')}") + print(f" Vercel auth: {'configured' if get_env_value('VERCEL_OIDC_TOKEN') or (get_env_value('VERCEL_TOKEN') and get_env_value('VERCEL_PROJECT_ID') and get_env_value('VERCEL_TEAM_ID')) else '(not set)'}") elif terminal.get('backend') == 'ssh': ssh_host = get_env_value('TERMINAL_SSH_HOST') ssh_user = get_env_value('TERMINAL_SSH_USER') @@ -4329,7 +4456,9 @@ def set_config_value(key: str, value: str): "terminal.singularity_image": "TERMINAL_SINGULARITY_IMAGE", "terminal.modal_image": "TERMINAL_MODAL_IMAGE", "terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE", + "terminal.vercel_runtime": "TERMINAL_VERCEL_RUNTIME", "terminal.docker_mount_cwd_to_workspace": "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE", + "terminal.docker_run_as_host_user": "TERMINAL_DOCKER_RUN_AS_HOST_USER", 
"terminal.cwd": "TERMINAL_CWD", "terminal.timeout": "TERMINAL_TIMEOUT", "terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR", diff --git a/hermes_cli/curator.py b/hermes_cli/curator.py index f580005794..a8bbcbafb0 100644 --- a/hermes_cli/curator.py +++ b/hermes_cli/curator.py @@ -55,6 +55,9 @@ def _cmd_status(args) -> int: print(f" runs: {runs}") print(f" last run: {_fmt_ts(last_run)}") print(f" last summary: {summary}") + _report = state.get("last_report_path") + if _report: + print(f" last report: {_report}") _ih = curator.get_interval_hours() _interval_label = ( f"{_ih // 24}d" if _ih % 24 == 0 and _ih >= 24 diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index dbba03fae6..f0822bdce8 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -8,6 +8,7 @@ import os import sys import subprocess import shutil +import importlib.util from pathlib import Path from hermes_cli.config import get_project_root, get_hermes_home, get_env_path @@ -30,6 +31,7 @@ load_dotenv(PROJECT_ROOT / ".env", override=False, encoding="utf-8") from hermes_cli.colors import Colors, color from hermes_cli.models import _HERMES_USER_AGENT +from hermes_cli.vercel_auth import describe_vercel_auth from hermes_constants import OPENROUTER_MODELS_URL from utils import base_url_host_matches @@ -76,6 +78,14 @@ def _system_package_install_cmd(pkg: str) -> str: return f"sudo apt install {pkg}" +def _safe_which(cmd: str) -> str | None: + """shutil.which wrapper resilient to platform monkeypatching in tests.""" + try: + return shutil.which(cmd) + except Exception: + return None + + def _termux_browser_setup_steps(node_installed: bool) -> list[str]: steps: list[str] = [] step = 1 @@ -537,6 +547,7 @@ def run_doctor(args): get_nous_auth_status, get_codex_auth_status, get_gemini_oauth_auth_status, + get_minimax_oauth_auth_status, ) nous_status = get_nous_auth_status() @@ -566,10 +577,17 @@ def run_doctor(args): check_ok("Google Gemini OAuth", f"(logged in{suffix})") else: check_warn("Google Gemini 
OAuth", "(not logged in)") + + minimax_status = get_minimax_oauth_auth_status() + if minimax_status.get("logged_in"): + region = minimax_status.get("region", "global") + check_ok("MiniMax OAuth", f"(logged in, region={region})") + else: + check_warn("MiniMax OAuth", "(not logged in)") except Exception as e: check_warn("Auth provider status", f"(could not check: {e})") - if shutil.which("codex"): + if _safe_which("codex"): check_ok("codex CLI") else: # Native OAuth uses Hermes' own device-code flow — the Codex CLI is @@ -787,13 +805,13 @@ def run_doctor(args): print(color("◆ External Tools", Colors.CYAN, Colors.BOLD)) # Git - if shutil.which("git"): + if _safe_which("git"): check_ok("git") else: check_warn("git not found", "(optional)") # ripgrep (optional, for faster file search) - if shutil.which("rg"): + if _safe_which("rg"): check_ok("ripgrep (rg)", "(faster file search)") else: check_warn("ripgrep (rg) not found", "(file search uses grep fallback)") @@ -802,7 +820,7 @@ def run_doctor(args): # Docker (optional) terminal_env = os.getenv("TERMINAL_ENV", "local") if terminal_env == "docker": - if shutil.which("docker"): + if _safe_which("docker"): # Check if docker daemon is running try: result = subprocess.run(["docker", "info"], capture_output=True, timeout=10) @@ -817,7 +835,7 @@ def run_doctor(args): check_fail("docker not found", "(required for TERMINAL_ENV=docker)") issues.append("Install Docker or change TERMINAL_ENV") else: - if shutil.which("docker"): + if _safe_which("docker"): check_ok("docker", "(optional)") else: if _is_termux(): @@ -863,8 +881,52 @@ def run_doctor(args): check_fail("daytona SDK not installed", "(pip install daytona)") issues.append("Install daytona SDK: pip install daytona") + # Vercel Sandbox (if using vercel_sandbox backend) + if terminal_env == "vercel_sandbox": + runtime = os.getenv("TERMINAL_VERCEL_RUNTIME", "node24").strip() or "node24" + from tools.terminal_tool import _SUPPORTED_VERCEL_RUNTIMES + if runtime in 
_SUPPORTED_VERCEL_RUNTIMES: + check_ok("Vercel runtime", f"({runtime})") + else: + supported = ", ".join(_SUPPORTED_VERCEL_RUNTIMES) + check_fail("Vercel runtime unsupported", f"({runtime}; use {supported})") + issues.append(f"Set TERMINAL_VERCEL_RUNTIME to one of: {supported}") + + disk = os.getenv("TERMINAL_CONTAINER_DISK", "51200").strip() + if disk in ("", "0", "51200"): + check_ok("Vercel disk setting", "(uses platform default)") + else: + check_fail("Vercel custom disk unsupported", "(reset terminal.container_disk to 51200)") + issues.append("Vercel Sandbox does not support custom container_disk; use the shared default 51200") + + if importlib.util.find_spec("vercel") is not None: + check_ok("vercel SDK", "(installed)") + else: + check_fail("vercel SDK not installed", "(pip install 'hermes-agent[vercel]')") + issues.append("Install the Vercel optional dependency: pip install 'hermes-agent[vercel]'") + + auth_status = describe_vercel_auth() + if auth_status.ok: + check_ok("Vercel auth", f"({auth_status.label})") + elif auth_status.label.startswith("partial"): + check_fail("Vercel auth incomplete", f"({auth_status.label})") + issues.append("Set VERCEL_TOKEN, VERCEL_PROJECT_ID, and VERCEL_TEAM_ID together") + else: + check_fail("Vercel auth not configured", f"({auth_status.label})") + issues.append( + "Configure Vercel Sandbox auth with VERCEL_TOKEN, VERCEL_PROJECT_ID, and VERCEL_TEAM_ID" + ) + for line in auth_status.detail_lines: + check_info(f"Vercel auth {line}") + + persistent = os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower() in ("1", "true", "yes", "on") + if persistent: + check_info("Vercel persistence: snapshot filesystem only; live processes do not survive sandbox recreation") + else: + check_info("Vercel persistence: ephemeral filesystem") + # Node.js + agent-browser (for browser automation tools) - if shutil.which("node"): + if _safe_which("node"): check_ok("Node.js") # Check if agent-browser is installed agent_browser_path = PROJECT_ROOT / 
"node_modules" / "agent-browser" @@ -890,7 +952,7 @@ def run_doctor(args): check_warn("Node.js not found", "(optional, needed for browser tools)") # npm audit for all Node.js packages - if shutil.which("npm"): + if _safe_which("npm"): npm_dirs = [ (PROJECT_ROOT, "Browser tools (agent-browser)"), (PROJECT_ROOT / "scripts" / "whatsapp-bridge", "WhatsApp bridge"), @@ -969,10 +1031,16 @@ def run_doctor(args): print(" Checking Anthropic API...", end="", flush=True) try: import httpx - from agent.anthropic_adapter import _is_oauth_token, _COMMON_BETAS, _OAUTH_ONLY_BETAS + from agent.anthropic_adapter import ( + _is_oauth_token, + _COMMON_BETAS, + _OAUTH_ONLY_BETAS, + _CONTEXT_1M_BETA, + ) headers = {"anthropic-version": "2023-06-01"} - if _is_oauth_token(anthropic_key): + is_oauth = _is_oauth_token(anthropic_key) + if is_oauth: headers["Authorization"] = f"Bearer {anthropic_key}" headers["anthropic-beta"] = ",".join(_COMMON_BETAS + _OAUTH_ONLY_BETAS) else: @@ -982,6 +1050,25 @@ def run_doctor(args): headers=headers, timeout=10 ) + # Reactive recovery: OAuth subscriptions that don't include 1M + # context reject the request with 400 "long context beta is not + # yet available for this subscription". Retry once with that + # beta stripped so the doctor check doesn't falsely report the + # Anthropic API as unreachable for those users. 
+ if ( + is_oauth + and response.status_code == 400 + and "long context beta" in response.text.lower() + and "not yet available" in response.text.lower() + ): + headers["anthropic-beta"] = ",".join( + [b for b in _COMMON_BETAS if b != _CONTEXT_1M_BETA] + list(_OAUTH_ONLY_BETAS) + ) + response = httpx.get( + "https://api.anthropic.com/v1/models", + headers=headers, + timeout=10, + ) if response.status_code == 200: print(f"\r {color('✓', Colors.GREEN)} Anthropic API ") elif response.status_code == 401: diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index a3896b5bbd..9670a1d83b 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -279,9 +279,11 @@ def _scan_gateway_pids(exclude_pids: set[int], all_profiles: bool = False) -> li ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True, + encoding="utf-8", + errors="ignore", timeout=10, ) - if result.returncode != 0: + if result.returncode != 0 or result.stdout is None: return [] current_cmd = "" for line in result.stdout.split("\n"): @@ -830,6 +832,22 @@ def _user_dbus_socket_path() -> Path: return Path(xdg) / "bus" +def _user_systemd_private_socket_path() -> Path: + """Return the per-user systemd private socket path (regardless of existence).""" + xdg = os.environ.get("XDG_RUNTIME_DIR") or f"/run/user/{os.getuid()}" + return Path(xdg) / "systemd" / "private" + + +def _user_systemd_socket_ready() -> bool: + """Return True when user-scope systemd has a reachable control socket. + + Some distros expose only the per-user systemd private socket even when the + D-Bus session bus socket is absent. ``systemctl --user`` can still work in + that configuration, so preflight checks must treat either socket as valid. + """ + return _user_dbus_socket_path().exists() or _user_systemd_private_socket_path().exists() + + def _ensure_user_systemd_env() -> None: """Ensure DBUS_SESSION_BUS_ADDRESS and XDG_RUNTIME_DIR are set for systemctl --user. 
@@ -853,28 +871,29 @@ def _ensure_user_systemd_env() -> None: def _wait_for_user_dbus_socket(timeout: float = 3.0) -> bool: - """Poll for the user D-Bus socket to appear, up to ``timeout`` seconds. + """Poll for the user systemd runtime socket(s), up to ``timeout`` seconds. - Linger-enabled user@.service can take a second or two to spawn the socket - after ``loginctl enable-linger`` runs. Returns True once the socket exists. + Linger-enabled user@.service can take a second or two to spawn its control + socket(s) after ``loginctl enable-linger`` runs. Returns True once either + the user D-Bus socket or the per-user systemd private socket exists. """ import time deadline = time.monotonic() + timeout while time.monotonic() < deadline: - if _user_dbus_socket_path().exists(): + if _user_systemd_socket_ready(): _ensure_user_systemd_env() return True time.sleep(0.2) - return _user_dbus_socket_path().exists() + return _user_systemd_socket_ready() def _preflight_user_systemd(*, auto_enable_linger: bool = True) -> None: - """Ensure ``systemctl --user`` will reach the user D-Bus session bus. + """Ensure ``systemctl --user`` will reach the user-scope systemd instance. - No-op when the bus socket is already there (the common case on desktops - and linger-enabled servers). On fresh SSH sessions where the socket is - missing: + No-op when the user D-Bus socket or per-user systemd private socket is + already there (the common case on desktops and linger-enabled servers). On + fresh SSH sessions where both are missing: * If linger is already enabled, wait briefly for user@.service to spawn the socket. @@ -888,8 +907,7 @@ def _preflight_user_systemd(*, auto_enable_linger: bool = True) -> None: systemd operations and surface the message to the user. 
""" _ensure_user_systemd_env() - bus_path = _user_dbus_socket_path() - if bus_path.exists(): + if _user_systemd_socket_ready(): return import getpass @@ -903,7 +921,7 @@ def _preflight_user_systemd(*, auto_enable_linger: bool = True) -> None: # Linger is on but socket still missing — unusual; fall through to error. _raise_user_systemd_unavailable( username, - reason="User D-Bus socket is missing even though linger is enabled.", + reason="User systemd control sockets are missing even though linger is enabled.", fix_hint=( f" systemctl start user@{os.getuid()}.service\n" " (may require sudo; try again after the command succeeds)" @@ -2743,15 +2761,77 @@ _PLATFORMS = [ ], }, ] +def _all_platforms() -> list[dict]: + """Return the full list of platforms for setup menus. + + Combines the built-in ``_PLATFORMS`` with plugin platforms registered via + ``platform_registry``. Plugins are discovered on first call so bundled + platforms (like IRC, which auto-load via ``kind: platform``) appear in + ``hermes setup gateway`` without needing the gateway to be running. + Built-ins keep their dict shape; plugin entries are adapted to the same + shape with ``_registry_entry`` holding the source. + """ + # Populate the registry so plugin platforms are visible. Idempotent. + # Bundled platform plugins (``kind: platform``) auto-load unconditionally, + # so every shipped messaging channel appears in the setup menu by default. + # User-installed platform plugins under ~/.hermes/plugins/ still require + # opt-in via ``plugins.enabled`` (untrusted code). 
+ try: + from hermes_cli.plugins import discover_plugins + discover_plugins() + except Exception as e: + logger.debug("plugin discovery failed during platform enumeration: %s", e) + + platforms = [dict(p) for p in _PLATFORMS] + by_key = {p["key"]: p for p in platforms} + + try: + from gateway.platform_registry import platform_registry + except Exception: + return platforms + + for entry in platform_registry.all_entries(): + if entry.name in by_key: + continue # built-in already covers it + platforms.append({ + "key": entry.name, + "label": entry.label, + "emoji": entry.emoji, + "token_var": entry.required_env[0] if entry.required_env else "", + "install_hint": entry.install_hint, + "_registry_entry": entry, + }) + return platforms def _platform_status(platform: dict) -> str: """Return a plain-text status string for a platform. Returns uncolored text so it can safely be embedded in - simple_term_menu items (ANSI codes break width calculation). + curses menu items (ANSI codes break width calculation). """ - token_var = platform["token_var"] + entry = platform.get("_registry_entry") + if entry is not None: + configured = False + # Prefer is_connected (checks both env and config.yaml) over + # check_fn (typically just dependency / env presence). 
+ if entry.is_connected is not None: + try: + from gateway.config import PlatformConfig + synthetic = PlatformConfig(enabled=True) + configured = bool(entry.is_connected(synthetic)) + except Exception: + configured = False + if not configured: + try: + configured = bool(entry.check_fn()) + except Exception: + configured = False + return "configured" if configured else "not configured" + + token_var = platform.get("token_var", "") + if not token_var: + return "not configured" val = get_env_value(token_var) if token_var == "WHATSAPP_ENABLED": if val and val.lower() == "true": @@ -3277,6 +3357,12 @@ def _setup_weixin(): print_warning(" Direct messages disabled.") print() + print_info(" Note: QR login connects an iLink bot identity (e.g. ...@im.bot), not a") + print_info(" scriptable personal WeChat account. Ordinary WeChat groups typically cannot") + print_info(" invite an @im.bot identity, and iLink does not deliver ordinary-group events") + print_info(" to most bot accounts. The settings below only apply when iLink actually") + print_info(" delivers group events for your account type — otherwise DM remains the only") + print_info(" working channel regardless of this choice.") group_choices = [ "Disable group chats (recommended)", "Allow all group chats", @@ -3290,12 +3376,12 @@ def _setup_weixin(): elif group_idx == 1: save_env_value("WEIXIN_GROUP_POLICY", "open") save_env_value("WEIXIN_GROUP_ALLOWED_USERS", "") - print_warning(" All group chats enabled.") + print_warning(" All group chats enabled (only takes effect if iLink delivers group events).") else: - allow_groups = prompt(" Allowed group chat IDs (comma-separated)", "", password=False).replace(" ", "") + allow_groups = prompt(" Allowed group chat IDs (comma-separated, not member user IDs)", "", password=False).replace(" ", "") save_env_value("WEIXIN_GROUP_POLICY", "allowlist") save_env_value("WEIXIN_GROUP_ALLOWED_USERS", allow_groups) - print_success(" Group allowlist saved.") + print_success(" Group 
allowlist saved (only takes effect if iLink delivers group events).") if user_id: print() @@ -3703,6 +3789,71 @@ def _setup_signal(): print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}") +def _builtin_setup_fn(key: str): + """Resolve the interactive setup function for a built-in platform key. + + Late-bound to avoid a circular import with ``hermes_cli.setup`` (which + imports from this module for the remaining bespoke flows). + """ + from hermes_cli import setup as _s + return { + "telegram": _s._setup_telegram, + "discord": _s._setup_discord, + "slack": _s._setup_slack, + "matrix": _s._setup_matrix, + "mattermost": _s._setup_mattermost, + "bluebubbles": _s._setup_bluebubbles, + "webhooks": _s._setup_webhooks, + "signal": _setup_signal, + "whatsapp": _setup_whatsapp, + "weixin": _setup_weixin, + "dingtalk": _setup_dingtalk, + "feishu": _setup_feishu, + "wecom": _setup_wecom, + "qqbot": _setup_qqbot, + }.get(key) +def _configure_platform(platform: dict) -> None: + """Run the interactive setup flow for a single platform. + + Dispatch order: + 1. Plugin-provided ``setup_fn`` on the registry entry. + 2. Built-in setup function matched by platform key. + 3. ``_setup_standard_platform`` when the entry has a ``vars`` schema. + 4. Env-var hint fallback for plugins that offer no setup helper. + + Bundled platform plugins (e.g. IRC) auto-load, so no plugin enable step + is needed here. User-installed platform plugins under ~/.hermes/plugins/ + must already be in ``plugins.enabled`` before they appear in this menu. + """ + entry = platform.get("_registry_entry") + + if entry is not None and entry.setup_fn is not None: + entry.setup_fn() + return + + fn = _builtin_setup_fn(platform["key"]) + if fn is not None: + fn() + return + + if platform.get("vars"): + _setup_standard_platform(platform) + return + + # Plugin with no setup helper — show env-var instructions. 
+ label = platform.get("label", platform["key"]) + emoji = platform.get("emoji", "🔌") + print() + print(color(f" ─── {emoji} {label} Setup ───", Colors.CYAN)) + required = entry.required_env if entry else [] + if required: + print_info(f" Set these env vars in ~/.hermes/.env: {', '.join(required)}") + else: + print_info(f" Configure {label} in config.yaml under gateway.platforms.{platform['key']}") + if platform.get("install_hint"): + print_info(f" {platform['install_hint']}") + + def gateway_setup(): """Interactive setup for messaging platforms + gateway service.""" if is_managed(): @@ -3755,42 +3906,36 @@ def gateway_setup(): print() print_header("Messaging Platforms") - menu_items = [] - for plat in _PLATFORMS: - status = _platform_status(plat) - menu_items.append(f"{plat['label']} ({status})") + platforms = _all_platforms() + + menu_items = [ + f"{p['emoji']} {p['label']} ({_platform_status(p)})" + for p in platforms + ] menu_items.append("Done") choice = prompt_choice("Select a platform to configure:", menu_items, len(menu_items) - 1) - - if choice == len(_PLATFORMS): + if choice == len(platforms): break - platform = _PLATFORMS[choice] - - if platform["key"] == "whatsapp": - _setup_whatsapp() - elif platform["key"] == "signal": - _setup_signal() - elif platform["key"] == "weixin": - _setup_weixin() - elif platform["key"] == "dingtalk": - _setup_dingtalk() - elif platform["key"] == "feishu": - _setup_feishu() - elif platform["key"] == "qqbot": - _setup_qqbot() - elif platform["key"] == "wecom": - _setup_wecom() - else: - _setup_standard_platform(platform) + _configure_platform(platforms[choice]) # ── Post-setup: offer to install/restart gateway ── + # Consider any platform (built-in or plugin) where the user has made + # meaningful progress. ``_platform_status`` already handles plugin + # entries via their check_fn and per-platform dual-states like + # WhatsApp's "enabled, not paired". 
+ def _is_progress(status: str) -> bool: + s = status.lower() + return not ( + s == "not configured" + or s.startswith("partially") + or s.startswith("plugin disabled") + ) + any_configured = any( - bool(get_env_value(p["token_var"])) - for p in _PLATFORMS - if p["key"] != "whatsapp" - ) or (get_env_value("WHATSAPP_ENABLED") or "").lower() == "true" + _is_progress(_platform_status(p)) for p in _all_platforms() + ) if any_configured: print() @@ -4228,4 +4373,4 @@ def _gateway_command_inner(args): if not supports_systemd_services() and not is_macos(): print("Legacy unit migration only applies to systemd-based Linux hosts.") return - remove_legacy_hermes_units(interactive=not yes, dry_run=dry_run) + remove_legacy_hermes_units(interactive=not yes, dry_run=dry_run) \ No newline at end of file diff --git a/hermes_cli/main.py b/hermes_cli/main.py index def2ef34ff..bdbf0390a6 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -114,6 +114,12 @@ def _apply_profile_override() -> None: consume = 1 break + # 1.5 If HERMES_HOME is already set and no explicit flag was given, trust it. + # This lets child processes (relaunch, subprocess) inherit the parent's + # profile choice without having to pass --profile again. + if profile_name is None and os.environ.get("HERMES_HOME"): + return + # 2. 
If no flag, check active_profile in the hermes root if profile_name is None: try: @@ -1094,11 +1100,36 @@ def _make_tui_argv(tui_dir: Path, tui_dev: bool) -> tuple[list[str], Path]: return [node, str(root / "dist" / "entry.js")], root +def _normalize_tui_toolsets(toolsets: object) -> list[str]: + """Normalize argparse/Fire-style toolset input for the TUI subprocess.""" + try: + from hermes_cli.oneshot import _normalize_toolsets + + return _normalize_toolsets(toolsets) or [] + except (AttributeError, ImportError): + if not toolsets: + return [] + + raw_items = [toolsets] if isinstance(toolsets, str) else toolsets + if not isinstance(raw_items, (list, tuple)): + raw_items = [raw_items] + + normalized: list[str] = [] + for item in raw_items: + if isinstance(item, str): + normalized.extend(part.strip() for part in item.split(",")) + else: + normalized.append(str(item).strip()) + + return [item for item in normalized if item] + + def _launch_tui( resume_session_id: Optional[str] = None, tui_dev: bool = False, model: Optional[str] = None, provider: Optional[str] = None, + toolsets: object = None, ): """Replace current process with the TUI.""" tui_dir = PROJECT_ROOT / "ui-tui" @@ -1123,6 +1154,9 @@ def _launch_tui( if provider: env["HERMES_TUI_PROVIDER"] = provider env["HERMES_INFERENCE_PROVIDER"] = provider + tui_toolsets = _normalize_tui_toolsets(toolsets) + if tui_toolsets: + env["HERMES_TUI_TOOLSETS"] = ",".join(tui_toolsets) # Guarantee an 8GB V8 heap + exposed GC for the TUI. Default node cap is # ~1.5–4GB depending on version and can fatal-OOM on long sessions with # large transcripts / reasoning blobs. 
Token-level merge: respect any @@ -1270,6 +1304,7 @@ def cmd_chat(args): tui_dev=getattr(args, "tui_dev", False), model=getattr(args, "model", None), provider=getattr(args, "provider", None), + toolsets=getattr(args, "toolsets", None), ) # Import and run the CLI @@ -1770,6 +1805,8 @@ def select_provider_and_model(args=None): _model_flow_openai_codex(config, current_model) elif selected_provider == "qwen-oauth": _model_flow_qwen_oauth(config, current_model) + elif selected_provider == "minimax-oauth": + _model_flow_minimax_oauth(config, current_model, args=args) elif selected_provider == "google-gemini-cli": _model_flow_google_gemini_cli(config, current_model) elif selected_provider == "copilot-acp": @@ -1890,6 +1927,7 @@ _AUX_TASKS: list[tuple[str, str, str]] = [ ("mcp", "MCP", "MCP tool reasoning"), ("title_generation", "Title generation", "session titles"), ("skills_hub", "Skills hub", "skills search/install"), + ("curator", "Curator", "skill-usage review pass"), ] @@ -2658,6 +2696,53 @@ def _model_flow_qwen_oauth(_config, current_model=""): print("No change.") +def _model_flow_minimax_oauth(config, current_model="", args=None): + """MiniMax OAuth provider: ensure logged in, then pick model.""" + from hermes_cli.auth import ( + get_provider_auth_state, + _prompt_model_selection, + _save_model_choice, + _update_config_for_provider, + resolve_minimax_oauth_runtime_credentials, + AuthError, + format_auth_error, + _login_minimax_oauth, + PROVIDER_REGISTRY, + ) + state = get_provider_auth_state("minimax-oauth") + if not state or not state.get("access_token"): + print("Not logged into MiniMax. 
Starting OAuth login...") + print() + try: + mock_args = argparse.Namespace( + region=getattr(args, "region", None) or "global", + no_browser=bool(getattr(args, "no_browser", False)), + timeout=getattr(args, "timeout", None) or 15.0, + ) + _login_minimax_oauth(mock_args, PROVIDER_REGISTRY["minimax-oauth"]) + except SystemExit: + print("Login cancelled or failed.") + return + except Exception as exc: + print(f"Login failed: {exc}") + return + + try: + creds = resolve_minimax_oauth_runtime_credentials() + except AuthError as exc: + print(format_auth_error(exc)) + return + + from hermes_cli.models import _PROVIDER_MODELS + model_ids = _PROVIDER_MODELS.get("minimax-oauth", []) + selected = _prompt_model_selection(model_ids, current_model) + if not selected: + return + _save_model_choice(selected) + _update_config_for_provider("minimax-oauth", creds["base_url"]) + print(f"\u2713 Using MiniMax model: {selected}") + + def _model_flow_google_gemini_cli(_config, current_model=""): """Google Gemini OAuth (PKCE) via Cloud Code Assist — supports free AND paid tiers. @@ -5251,8 +5336,8 @@ def _build_web_ui(web_dir: Path, *, fatal: bool = False) -> bool: return True -def _warn_stale_dashboard_processes() -> None: - """Warn about running dashboard processes that still hold pre-update code. +def _find_stale_dashboard_pids() -> list[int]: + """Return PIDs of ``hermes dashboard`` processes other than ourselves. ``hermes dashboard`` is a long-lived server process commonly started and forgotten. When ``hermes update`` replaces files on disk, the running @@ -5260,9 +5345,13 @@ def _warn_stale_dashboard_processes() -> None: disk is updated, causing a silent frontend/backend mismatch (e.g. new auth headers the old backend doesn't recognise → every API call 401s). - Unlike the gateway, the dashboard has no service manager (systemd / - launchd), so we can only warn — we don't auto-kill user-managed - background processes. 
+ The dashboard has no service manager (systemd / launchd), no PID file, + and we can't know the original launch args — so the only sane action + after an update is to kill the stale process and let the user restart + it. This helper is just the detection step; see + ``_kill_stale_dashboard_processes`` for the kill. + + Returns an empty list on any scan error (missing ps/wmic, timeout, etc.). """ patterns = [ "hermes dashboard", @@ -5274,13 +5363,21 @@ def _warn_stale_dashboard_processes() -> None: try: if sys.platform == "win32": + # wmic may emit text in the system code page (for example cp936 + # on zh-CN systems), not UTF-8. In text mode, subprocess output + # decoding depends on Python's configuration (locale-dependent + # by default, or UTF-8 in UTF-8 mode). The important protection + # here is errors="ignore": it prevents a reader-thread + # UnicodeDecodeError from leaving result.stdout=None and turning + # the later .split() into an AttributeError (#17049). result = subprocess.run( ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True, timeout=10, + encoding="utf-8", errors="ignore", ) - if result.returncode != 0: - return + if result.returncode != 0 or result.stdout is None: + return [] current_cmd = "" for line in result.stdout.split("\n"): line = line.strip() @@ -5306,7 +5403,7 @@ def _warn_stale_dashboard_processes() -> None: capture_output=True, text=True, timeout=10, ) if result.returncode == 0: - for line in result.stdout.split("\n"): + for line in getattr(result, "stdout", "").split("\n"): stripped = line.strip() if not stripped or "grep" in stripped: continue @@ -5322,20 +5419,112 @@ def _warn_stale_dashboard_processes() -> None: and pid != self_pid): dashboard_pids.append(pid) except (FileNotFoundError, subprocess.TimeoutExpired, OSError): - return + return [] - if not dashboard_pids: + return dashboard_pids + + +def _kill_stale_dashboard_processes( + reason: str = "the running backend no longer 
matches the updated frontend", +) -> None: + """Kill running ``hermes dashboard`` processes. + + Called at the end of ``hermes update`` (default ``reason``) and also + from ``hermes dashboard --stop`` (which overrides ``reason``). The + dashboard has no service manager, so after a code update the running + process is guaranteed to be serving stale Python against a + freshly-updated JS bundle. Leaving it alive produces silent + frontend/backend mismatches (new auth headers the old backend doesn't + recognise → every API call 401s). + + POSIX: SIGTERM, wait up to ~3s for graceful exit, SIGKILL any survivors. + Windows: ``taskkill /PID /F`` since there's no clean SIGTERM + equivalent for background console apps. + + The dashboard isn't auto-restarted because we don't know the original + launch args (--host, --port, --insecure, --tui, --no-open). The user + restarts it manually; a hint is printed. + """ + pids = _find_stale_dashboard_pids() + if not pids: return print() - print(f"⚠ {len(dashboard_pids)} dashboard process(es) still running " - f"with the previous version:") - for pid in dashboard_pids: - print(f" PID {pid}") - print(" The running backend may not match the updated frontend,") - print(" causing silent auth failures or empty data.") - print(" Restart them to pick up the changes:") - print(" kill && hermes dashboard --port ...") + print(f"⟲ Stopping {len(pids)} dashboard process(es) ({reason})") + + killed: list[int] = [] + failed: list[tuple[int, str]] = [] + + if sys.platform == "win32": + for pid in pids: + try: + result = subprocess.run( + ["taskkill", "/PID", str(pid), "/F"], + capture_output=True, text=True, timeout=10, + ) + if result.returncode == 0: + killed.append(pid) + else: + failed.append((pid, (result.stderr or result.stdout or "").strip())) + except (FileNotFoundError, subprocess.TimeoutExpired, OSError) as e: + failed.append((pid, str(e))) + else: + import signal as _signal + import time as _time + + # SIGTERM first — give each process a 
chance to shut down cleanly + # (uvicorn closes its socket, flushes logs, etc.). + for pid in pids: + try: + os.kill(pid, _signal.SIGTERM) + except ProcessLookupError: + # Already gone — count as killed. + killed.append(pid) + except (PermissionError, OSError) as e: + failed.append((pid, str(e))) + + # Poll for exit up to ~3s total. + deadline = _time.monotonic() + 3.0 + pending = [p for p in pids if p not in killed + and p not in {f[0] for f in failed}] + while pending and _time.monotonic() < deadline: + _time.sleep(0.1) + still_pending = [] + for pid in pending: + try: + os.kill(pid, 0) # probe + except ProcessLookupError: + killed.append(pid) + except (PermissionError, OSError): + # Can't probe — assume still there. + still_pending.append(pid) + else: + still_pending.append(pid) + pending = still_pending + + # SIGKILL any survivors. + for pid in pending: + try: + os.kill(pid, _signal.SIGKILL) + killed.append(pid) + except ProcessLookupError: + killed.append(pid) + except (PermissionError, OSError) as e: + failed.append((pid, str(e))) + + for pid in killed: + print(f" ✓ stopped PID {pid}") + for pid, reason in failed: + print(f" ✗ failed to stop PID {pid}: {reason}") + + if killed: + print(" Restart the dashboard when you're ready:") + print(" hermes dashboard --port ") + + +# Back-compat alias: some tests and any external callers may import the old +# warn-only name. The new behaviour (kill stale processes) replaces it. 
+_warn_stale_dashboard_processes = _kill_stale_dashboard_processes def _update_via_zip(args): @@ -5472,7 +5661,7 @@ def _update_via_zip(args): print() print("✓ Update complete!") - _warn_stale_dashboard_processes() + _kill_stale_dashboard_processes() def _stash_local_changes_if_needed(git_cmd: list[str], cwd: Path) -> Optional[str]: @@ -7289,9 +7478,12 @@ def _cmd_update_impl(args, gateway_mode: bool): except Exception as e: logger.debug("Legacy unit check during update failed: %s", e) - # Warn about stale dashboard processes — the dashboard has no - # service manager, so we can only tell the user to restart them. - _warn_stale_dashboard_processes() + # Kill stale dashboard processes — the dashboard has no service + # manager, so leaving it alive after a code update produces a + # silent frontend/backend mismatch. We can't auto-restart it + # (no saved launch args) but we can stop it, and a hint is + # printed for the user to re-launch. + _kill_stale_dashboard_processes() print() print("Tip: You can now select a provider and model:") @@ -7682,8 +7874,59 @@ def cmd_profile(args): sys.exit(1) +def _report_dashboard_status() -> int: + """Print ``hermes dashboard`` PIDs and return the count. + + Uses the same detection logic as ``_find_stale_dashboard_pids`` (the + current process is excluded, but since ``hermes dashboard --status`` + runs in a short-lived CLI process that never matches the pattern, + the exclusion is irrelevant here). + """ + pids = _find_stale_dashboard_pids() + if not pids: + print("No hermes dashboard processes running.") + return 0 + + print(f"{len(pids)} hermes dashboard process(es) running:") + for pid in pids: + # Best-effort: show the full cmdline so users can tell profiles apart. 
+ cmdline = "" + try: + if sys.platform != "win32": + cmdline_path = f"/proc/{pid}/cmdline" + if os.path.exists(cmdline_path): + with open(cmdline_path, "rb") as f: + cmdline = f.read().replace(b"\x00", b" ").decode( + "utf-8", errors="replace").strip() + except (OSError, ValueError): + pass + if cmdline: + print(f" PID {pid}: {cmdline}") + else: + print(f" PID {pid}") + return len(pids) + + def cmd_dashboard(args): - """Start the web UI server.""" + """Start the web UI server, or (with --stop/--status) manage running ones.""" + # --status: report running dashboards and exit, no deps needed. + if getattr(args, "status", False): + count = _report_dashboard_status() + sys.exit(0 if count == 0 else 0) # status is informational, always 0 + + # --stop: kill any running dashboards and exit, no deps needed. + if getattr(args, "stop", False): + pids = _find_stale_dashboard_pids() + if not pids: + print("No hermes dashboard processes running.") + sys.exit(0) + # Reuse the same SIGTERM-grace-SIGKILL path used after `hermes update`. + _kill_stale_dashboard_processes(reason="requested via --stop") + # _kill_stale_dashboard_processes prints outcomes itself. Exit 0 if + # we killed at least one, 1 if they were all unkillable. 
+ remaining = _find_stale_dashboard_pids() + sys.exit(1 if remaining else 0) + try: import fastapi # noqa: F401 import uvicorn # noqa: F401 @@ -7750,302 +7993,9 @@ def cmd_logs(args): def main(): """Main entry point for hermes CLI.""" - parser = argparse.ArgumentParser( - prog="hermes", - description="Hermes Agent - AI assistant with tool-calling capabilities", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - hermes Start interactive chat - hermes chat -q "Hello" Single query mode - hermes -c Resume the most recent session - hermes -c "my project" Resume a session by name (latest in lineage) - hermes --resume Resume a specific session by ID - hermes setup Run setup wizard - hermes logout Clear stored authentication - hermes auth add Add a pooled credential - hermes auth list List pooled credentials - hermes auth remove

Remove pooled credential by index, id, or label - hermes auth reset Clear exhaustion status for a provider - hermes model Select default model - hermes fallback [list] Show fallback provider chain - hermes fallback add Add a fallback provider (same picker as `hermes model`) - hermes fallback remove Remove a fallback provider from the chain - hermes config View configuration - hermes config edit Edit config in $EDITOR - hermes config set model gpt-4 Set a config value - hermes gateway Run messaging gateway - hermes -s hermes-agent-dev,github-auth - hermes -w Start in isolated git worktree - hermes gateway install Install gateway background service - hermes sessions list List past sessions - hermes sessions browse Interactive session picker - hermes sessions rename ID T Rename/title a session - hermes logs View agent.log (last 50 lines) - hermes logs -f Follow agent.log in real time - hermes logs errors View errors.log - hermes logs --since 1h Lines from the last hour - hermes debug share Upload debug report for support - hermes update Update to latest version + from hermes_cli._parser import build_top_level_parser -For more help on a command: - hermes --help -""", - ) - - parser.add_argument( - "--version", "-V", action="store_true", help="Show version and exit" - ) - parser.add_argument( - "-z", - "--oneshot", - metavar="PROMPT", - default=None, - help=( - "One-shot mode: send a single prompt and print ONLY the final " - "response text to stdout. No banner, no spinner, no tool " - "previews, no session_id line. Tools, memory, rules, and " - "AGENTS.md in the CWD are loaded as normal; approvals are " - "auto-bypassed. Intended for scripts / pipes." - ), - ) - # --model / --provider are accepted at the top level so they can pair - # with -z without needing the `chat` subcommand. If neither -z nor a - # subcommand consumes them, they fall through harmlessly as None. - # Mirrors `hermes chat --model ... --provider ...` semantics. 
- parser.add_argument( - "-m", - "--model", - default=None, - help=( - "Model override for this invocation (e.g. anthropic/claude-sonnet-4.6). " - "Applies to -z/--oneshot and --tui. Also settable via HERMES_INFERENCE_MODEL env var." - ), - ) - parser.add_argument( - "--provider", - default=None, - help=( - "Provider override for this invocation (e.g. openrouter, anthropic). " - "Applies to -z/--oneshot and --tui. Also settable via HERMES_INFERENCE_PROVIDER env var." - ), - ) - parser.add_argument( - "--resume", - "-r", - metavar="SESSION", - default=None, - help="Resume a previous session by ID or title", - ) - parser.add_argument( - "--continue", - "-c", - dest="continue_last", - nargs="?", - const=True, - default=None, - metavar="SESSION_NAME", - help="Resume a session by name, or the most recent if no name given", - ) - parser.add_argument( - "--worktree", - "-w", - action="store_true", - default=False, - help="Run in an isolated git worktree (for parallel agents)", - ) - parser.add_argument( - "--accept-hooks", - action="store_true", - default=False, - help=( - "Auto-approve any unseen shell hooks declared in config.yaml " - "without a TTY prompt. Equivalent to HERMES_ACCEPT_HOOKS=1 or " - "hooks_auto_accept: true in config.yaml. Use on CI / headless " - "runs that can't prompt." 
- ), - ) - parser.add_argument( - "--skills", - "-s", - action="append", - default=None, - help="Preload one or more skills for the session (repeat flag or comma-separate)", - ) - parser.add_argument( - "--yolo", - action="store_true", - default=False, - help="Bypass all dangerous command approval prompts (use at your own risk)", - ) - parser.add_argument( - "--pass-session-id", - action="store_true", - default=False, - help="Include the session ID in the agent's system prompt", - ) - parser.add_argument( - "--ignore-user-config", - action="store_true", - default=False, - help="Ignore ~/.hermes/config.yaml and fall back to built-in defaults (credentials in .env are still loaded)", - ) - parser.add_argument( - "--ignore-rules", - action="store_true", - default=False, - help="Skip auto-injection of AGENTS.md, SOUL.md, .cursorrules, memory, and preloaded skills", - ) - parser.add_argument( - "--tui", - action="store_true", - default=False, - help="Launch the modern TUI instead of the classic REPL", - ) - parser.add_argument( - "--dev", - dest="tui_dev", - action="store_true", - default=False, - help="With --tui: run TypeScript sources via tsx (skip dist build)", - ) - - subparsers = parser.add_subparsers(dest="command", help="Command to run") - - # ========================================================================= - # chat command - # ========================================================================= - chat_parser = subparsers.add_parser( - "chat", - help="Interactive chat with the agent", - description="Start an interactive chat session with Hermes Agent", - ) - chat_parser.add_argument( - "-q", "--query", help="Single query (non-interactive mode)" - ) - chat_parser.add_argument( - "--image", help="Optional local image path to attach to a single query" - ) - chat_parser.add_argument( - "-m", "--model", help="Model to use (e.g., anthropic/claude-sonnet-4)" - ) - chat_parser.add_argument( - "-t", "--toolsets", help="Comma-separated toolsets to enable" - ) 
- chat_parser.add_argument( - "-s", - "--skills", - action="append", - default=argparse.SUPPRESS, - help="Preload one or more skills for the session (repeat flag or comma-separate)", - ) - chat_parser.add_argument( - "--provider", - # No `choices=` here: user-defined providers from config.yaml `providers:` - # are also valid values, and runtime resolution (resolve_runtime_provider) - # handles validation/error reporting consistently with the top-level - # `--provider` flag. - default=None, - help="Inference provider (default: auto). Built-in or a user-defined name from `providers:` in config.yaml.", - ) - chat_parser.add_argument( - "-v", "--verbose", action="store_true", help="Verbose output" - ) - chat_parser.add_argument( - "-Q", - "--quiet", - action="store_true", - help="Quiet mode for programmatic use: suppress banner, spinner, and tool previews. Only output the final response and session info.", - ) - chat_parser.add_argument( - "--resume", - "-r", - metavar="SESSION_ID", - default=argparse.SUPPRESS, - help="Resume a previous session by ID (shown on exit)", - ) - chat_parser.add_argument( - "--continue", - "-c", - dest="continue_last", - nargs="?", - const=True, - default=argparse.SUPPRESS, - metavar="SESSION_NAME", - help="Resume a session by name, or the most recent if no name given", - ) - chat_parser.add_argument( - "--worktree", - "-w", - action="store_true", - default=argparse.SUPPRESS, - help="Run in an isolated git worktree (for parallel agents on the same repo)", - ) - chat_parser.add_argument( - "--accept-hooks", - action="store_true", - default=argparse.SUPPRESS, - help=( - "Auto-approve any unseen shell hooks declared in config.yaml " - "without a TTY prompt (see also HERMES_ACCEPT_HOOKS env var and " - "hooks_auto_accept: in config.yaml)." 
- ), - ) - chat_parser.add_argument( - "--checkpoints", - action="store_true", - default=False, - help="Enable filesystem checkpoints before destructive file operations (use /rollback to restore)", - ) - chat_parser.add_argument( - "--max-turns", - type=int, - default=None, - metavar="N", - help="Maximum tool-calling iterations per conversation turn (default: 90, or agent.max_turns in config)", - ) - chat_parser.add_argument( - "--yolo", - action="store_true", - default=argparse.SUPPRESS, - help="Bypass all dangerous command approval prompts (use at your own risk)", - ) - chat_parser.add_argument( - "--pass-session-id", - action="store_true", - default=argparse.SUPPRESS, - help="Include the session ID in the agent's system prompt", - ) - chat_parser.add_argument( - "--ignore-user-config", - action="store_true", - default=argparse.SUPPRESS, - help="Ignore ~/.hermes/config.yaml and fall back to built-in defaults (credentials in .env are still loaded). Useful for isolated CI runs, reproduction, and third-party integrations.", - ) - chat_parser.add_argument( - "--ignore-rules", - action="store_true", - default=argparse.SUPPRESS, - help="Skip auto-injection of AGENTS.md, SOUL.md, .cursorrules, memory, and preloaded skills. Combine with --ignore-user-config for a fully isolated run.", - ) - chat_parser.add_argument( - "--source", - default=None, - help="Session source tag for filtering (default: cli). 
Use 'tool' for third-party integrations that should not appear in user session lists.", - ) - chat_parser.add_argument( - "--tui", - action="store_true", - default=False, - help="Launch the modern TUI instead of the classic REPL", - ) - chat_parser.add_argument( - "--dev", - dest="tui_dev", - action="store_true", - default=False, - help="With --tui: run TypeScript sources via tsx (skip dist build)", - ) + parser, subparsers, chat_parser = build_top_level_parser() chat_parser.set_defaults(func=cmd_chat) # ========================================================================= @@ -9715,15 +9665,8 @@ Examples: # Launch hermes --resume by replacing the current process print(f"Resuming session: {selected_id}") - hermes_bin = shutil.which("hermes") - if hermes_bin: - os.execvp(hermes_bin, ["hermes", "--resume", selected_id]) - else: - # Fallback: re-invoke via python -m - os.execvp( - sys.executable, - [sys.executable, "-m", "hermes_cli.main", "--resume", selected_id], - ) + from hermes_cli.relaunch import relaunch + relaunch(["--resume", selected_id]) return # won't reach here after execvp elif action == "stats": @@ -10081,6 +10024,22 @@ Examples: "Alternatively set HERMES_DASHBOARD_TUI=1." ), ) + # Lifecycle flags — mutually exclusive with each other and with the + # start-a-server flags above (if both are passed, --stop / --status win + # because they exit before the server is started). The dashboard has + # no service manager and no PID file, so these scan the process table + # for `hermes dashboard` cmdlines and SIGTERM them directly — the same + # path `hermes update` uses to clean up stale dashboards. 
+ dashboard_parser.add_argument( + "--stop", + action="store_true", + help="Stop all running hermes dashboard processes and exit", + ) + dashboard_parser.add_argument( + "--status", + action="store_true", + help="List running hermes dashboard processes and exit", + ) dashboard_parser.set_defaults(func=cmd_dashboard) # ========================================================================= @@ -10270,6 +10229,7 @@ Examples: args.oneshot, model=getattr(args, "model", None), provider=getattr(args, "provider", None), + toolsets=getattr(args, "toolsets", None), )) # Handle top-level --resume / --continue as shortcut to chat diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index ae845b069b..0e01f558dd 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -16,6 +16,7 @@ import time from typing import Any, Dict, List, Optional, Tuple from hermes_cli.config import ( + cfg_get, load_config, save_config, get_env_value, @@ -716,7 +717,7 @@ def cmd_mcp_configure(args): # Update config config = load_config() - server_entry = config.get("mcp_servers", {}).get(name, {}) + server_entry = cfg_get(config, "mcp_servers", name, default={}) if len(chosen) == total: # All selected → remove include/exclude (register all) diff --git a/hermes_cli/model_normalize.py b/hermes_cli/model_normalize.py index 99e6c34e48..433e342796 100644 --- a/hermes_cli/model_normalize.py +++ b/hermes_cli/model_normalize.py @@ -96,6 +96,7 @@ _MATCHING_PREFIX_STRIP_PROVIDERS: frozenset[str] = frozenset({ "kimi-coding", "kimi-coding-cn", "minimax", + "minimax-oauth", "minimax-cn", "alibaba", "qwen-oauth", diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index 869d82bf6d..1d37900f3c 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -1018,6 +1018,37 @@ def list_authenticated_providers( results: List[dict] = [] seen_slugs: set = set() # lowercase-normalized to catch case variants (#9545) seen_mdev_ids: set = set() # prevent duplicate 
entries for aliases (e.g. kimi-coding + kimi-coding-cn) + # Effective base URLs of every built-in row we emit (normalized lower+rstrip). + # Section 4 uses this to hide ``custom_providers`` entries that point at the + # same endpoint as a built-in (e.g. a user-defined "my-dashscope" on + # https://coding-intl.dashscope.aliyuncs.com/v1 collides with the built-in + # alibaba-coding-plan row when DASHSCOPE_API_KEY is present). Fixes #16970. + _builtin_endpoints: set = set() + + def _norm_url(url: str) -> str: + return str(url or "").strip().rstrip("/").lower() + + def _record_builtin_endpoint(slug: str) -> None: + """Record the effective base URL for a built-in provider row. + + Prefers the live env-override (e.g. DASHSCOPE_BASE_URL) over the + static inference_base_url so the dedup matches what a user typing + that URL into custom_providers would actually hit.""" + try: + from hermes_cli.auth import PROVIDER_REGISTRY as _reg + except Exception: + return + pcfg = _reg.get(slug) + if not pcfg: + return + url = "" + if getattr(pcfg, "base_url_env_var", ""): + url = os.environ.get(pcfg.base_url_env_var, "") or "" + if not url: + url = getattr(pcfg, "inference_base_url", "") or "" + normed = _norm_url(url) + if normed: + _builtin_endpoints.add(normed) data = fetch_models_dev() @@ -1124,6 +1155,7 @@ def list_authenticated_providers( }) seen_slugs.add(slug.lower()) seen_mdev_ids.add(mdev_id) + _record_builtin_endpoint(slug) # --- 2. Check Hermes-only providers (nous, openai-codex, copilot, opencode-go) --- from hermes_cli.providers import HERMES_OVERLAYS @@ -1238,6 +1270,7 @@ def list_authenticated_providers( }) seen_slugs.add(pid.lower()) seen_slugs.add(hermes_slug.lower()) + _record_builtin_endpoint(hermes_slug) # --- 2b. 
Cross-check canonical provider list --- # Catches providers that are in CANONICAL_PROVIDERS but weren't found @@ -1317,6 +1350,7 @@ def list_authenticated_providers( "source": "canonical", }) seen_slugs.add(_cp.slug.lower()) + _record_builtin_endpoint(_cp.slug) # --- 3. User-defined endpoints from config --- # Track (name, base_url) of what section 3 emits so section 4 can skip @@ -1526,6 +1560,15 @@ def list_authenticated_providers( ) if _pair_key[0] and _pair_key[1] and _pair_key in _section3_emitted_pairs: continue + # Skip if a built-in row (sections 1/2/2b) already represents this + # endpoint. Fixes #16970: a user-defined "my-dashscope" pointing at + # https://coding-intl.dashscope.aliyuncs.com/v1 duplicates the + # built-in alibaba-coding-plan row whenever DASHSCOPE_API_KEY is + # set. The built-in row carries the curated model list, correct + # auth wiring, and canonical slug — keep it and hide the shadow. + _grp_url_norm = _pair_key[1] + if _grp_url_norm and _grp_url_norm in _builtin_endpoints: + continue results.append({ "slug": slug, "name": grp["name"], diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 852c097536..f5ca1a3b22 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -288,6 +288,10 @@ _PROVIDER_MODELS: dict[str, list[str]] = { "MiniMax-M2.1", "MiniMax-M2", ], + "minimax-oauth": [ + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + ], "minimax-cn": [ "MiniMax-M2.7", "MiniMax-M2.5", @@ -788,6 +792,7 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [ ProviderEntry("kimi-coding-cn", "Kimi / Moonshot (China)", "Kimi / Moonshot China (Moonshot CN direct API)"), ProviderEntry("stepfun", "StepFun Step Plan", "StepFun Step Plan (agent/coding models via Step Plan API)"), ProviderEntry("minimax", "MiniMax", "MiniMax (global direct API)"), + ProviderEntry("minimax-oauth", "MiniMax (OAuth)", "MiniMax via OAuth browser login (Coding Plan, minimax.io)"), ProviderEntry("minimax-cn", "MiniMax (China)", "MiniMax China (domestic direct API)"), 
ProviderEntry("alibaba", "Alibaba Cloud (DashScope)","Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"), ProviderEntry("ollama-cloud", "Ollama Cloud", "Ollama Cloud (cloud-hosted open models — ollama.com)"), @@ -831,6 +836,9 @@ _PROVIDER_ALIASES = { "gmicloud": "gmi", "minimax-china": "minimax-cn", "minimax_cn": "minimax-cn", + "minimax-portal": "minimax-oauth", + "minimax-global": "minimax-oauth", + "minimax_oauth": "minimax-oauth", "claude": "anthropic", "claude-code": "anthropic", "deep-seek": "deepseek", @@ -2026,28 +2034,56 @@ def _fetch_anthropic_models(timeout: float = 5.0) -> Optional[list[str]]: return None headers: dict[str, str] = {"anthropic-version": "2023-06-01"} - if _is_oauth_token(token): + is_oauth = _is_oauth_token(token) + if is_oauth: headers["Authorization"] = f"Bearer {token}" - from agent.anthropic_adapter import _COMMON_BETAS, _OAUTH_ONLY_BETAS + from agent.anthropic_adapter import _COMMON_BETAS, _OAUTH_ONLY_BETAS, _CONTEXT_1M_BETA headers["anthropic-beta"] = ",".join(_COMMON_BETAS + _OAUTH_ONLY_BETAS) else: headers["x-api-key"] = token - req = urllib.request.Request( - "https://api.anthropic.com/v1/models", - headers=headers, - ) - try: + def _do_request(h: dict[str, str]): + req = urllib.request.Request( + "https://api.anthropic.com/v1/models", + headers=h, + ) with urllib.request.urlopen(req, timeout=timeout) as resp: - data = json.loads(resp.read().decode()) - models = [m["id"] for m in data.get("data", []) if m.get("id")] - # Sort: latest/largest first (opus > sonnet > haiku, higher version first) - return sorted(models, key=lambda m: ( - "opus" not in m, # opus first - "sonnet" not in m, # then sonnet - "haiku" not in m, # then haiku - m, # alphabetical within tier - )) + return json.loads(resp.read().decode()) + + try: + try: + data = _do_request(headers) + except urllib.error.HTTPError as http_err: + # Reactive recovery for OAuth subscriptions that reject the 1M + # context beta with 400 "long context beta is not yet 
available + # for this subscription". Retry once without the beta; re-raise + # anything else so the outer except logs it. + if ( + is_oauth + and http_err.code == 400 + ): + try: + body_text = http_err.read().decode(errors="ignore").lower() + except Exception: + body_text = "" + if "long context beta" in body_text and "not yet available" in body_text: + headers["anthropic-beta"] = ",".join( + [b for b in _COMMON_BETAS if b != _CONTEXT_1M_BETA] + + list(_OAUTH_ONLY_BETAS) + ) + data = _do_request(headers) + else: + raise + else: + raise + models = [m["id"] for m in data.get("data", []) if m.get("id")] + # Sort: latest/largest first (opus > sonnet > haiku, higher version first) + return sorted(models, key=lambda m: ( + "opus" not in m, # opus first + "sonnet" not in m, # then sonnet + "haiku" not in m, # then haiku + m, # alphabetical within tier + )) except Exception as e: import logging logging.getLogger(__name__).debug("Failed to fetch Anthropic models: %s", e) diff --git a/hermes_cli/oneshot.py b/hermes_cli/oneshot.py index e1065b662e..ca30f07904 100644 --- a/hermes_cli/oneshot.py +++ b/hermes_cli/oneshot.py @@ -3,7 +3,8 @@ Bypasses cli.py entirely. No banner, no spinner, no session_id line, no stderr chatter. Just the agent's final text to stdout. -Toolsets = whatever the user has configured for "cli" in `hermes tools`. +Toolsets = explicit --toolsets when provided, otherwise whatever the user has +configured for "cli" in `hermes tools`. Rules / memory / AGENTS.md / preloaded skills = same as a normal chat turn. Approvals = auto-bypassed (HERMES_YOLO_MODE=1 is set for the call). Working directory = the user's CWD (AGENTS.md etc. resolve from there as usual). 
@@ -28,10 +29,103 @@ from contextlib import redirect_stderr, redirect_stdout from typing import Optional +def _normalize_toolsets(toolsets: object = None) -> list[str] | None: + if not toolsets: + return None + + raw_items = [toolsets] if isinstance(toolsets, str) else toolsets + if not isinstance(raw_items, (list, tuple)): + raw_items = [raw_items] + + normalized: list[str] = [] + for item in raw_items: + if isinstance(item, str): + normalized.extend(part.strip() for part in item.split(",")) + else: + normalized.append(str(item).strip()) + + return [item for item in normalized if item] or None + + +def _validate_explicit_toolsets(toolsets: object = None) -> tuple[list[str] | None, str | None]: + normalized = _normalize_toolsets(toolsets) + if normalized is None: + return None, None + + try: + from toolsets import validate_toolset + except Exception as exc: + return None, f"hermes -z: failed to validate --toolsets: {exc}\n" + + built_in = [name for name in normalized if validate_toolset(name)] + unresolved = [name for name in normalized if name not in built_in] + + if unresolved: + try: + from hermes_cli.plugins import discover_plugins + + discover_plugins() + plugin_valid = [name for name in unresolved if validate_toolset(name)] + except Exception: + plugin_valid = [] + + if plugin_valid: + built_in.extend(plugin_valid) + unresolved = [name for name in unresolved if name not in plugin_valid] + + if any(name in {"all", "*"} for name in built_in): + ignored = [name for name in normalized if name not in {"all", "*"}] + if ignored: + sys.stderr.write( + "hermes -z: --toolsets all enables every toolset; " + f"ignoring additional entries: {', '.join(ignored)}\n" + ) + return None, None + + mcp_names: set[str] = set() + mcp_disabled: set[str] = set() + if unresolved: + try: + from hermes_cli.config import read_raw_config + from hermes_cli.tools_config import _parse_enabled_flag + + cfg = read_raw_config() + mcp_servers = cfg.get("mcp_servers") if 
isinstance(cfg.get("mcp_servers"), dict) else {} + for name, server_cfg in mcp_servers.items(): + if not isinstance(server_cfg, dict): + continue + if _parse_enabled_flag(server_cfg.get("enabled", True), default=True): + mcp_names.add(str(name)) + else: + mcp_disabled.add(str(name)) + except Exception: + mcp_names = set() + mcp_disabled = set() + + mcp_valid = [name for name in unresolved if name in mcp_names] + disabled = [name for name in unresolved if name in mcp_disabled] + unknown = [name for name in unresolved if name not in mcp_names and name not in mcp_disabled] + valid = built_in + mcp_valid + + if unknown: + sys.stderr.write(f"hermes -z: ignoring unknown --toolsets entries: {', '.join(unknown)}\n") + if disabled: + sys.stderr.write( + "hermes -z: ignoring disabled MCP servers (set enabled: true in config.yaml to use): " + f"{', '.join(disabled)}\n" + ) + + if not valid: + return None, "hermes -z: --toolsets did not contain any valid toolsets.\n" + + return valid, None + + def run_oneshot( prompt: str, model: Optional[str] = None, provider: Optional[str] = None, + toolsets: object = None, ) -> int: """Execute a single prompt and print only the final content block. @@ -42,6 +136,7 @@ def run_oneshot( provider: Optional provider override. Falls back to HERMES_INFERENCE_PROVIDER env var, then config.yaml's model.provider, then "auto". + toolsets: Optional comma-separated string or iterable of toolsets. Returns the exit code. Caller should sys.exit() with the return. """ @@ -65,6 +160,12 @@ def run_oneshot( ) return 2 + explicit_toolsets, toolsets_error = _validate_explicit_toolsets(toolsets) + if toolsets_error: + sys.stderr.write(toolsets_error) + return 2 + use_config_toolsets = _normalize_toolsets(toolsets) is None + # Auto-approve any shell / tool approvals. Non-interactive by # definition — a prompt would hang forever. 
os.environ["HERMES_YOLO_MODE"] = "1" @@ -77,7 +178,13 @@ def run_oneshot( try: with redirect_stdout(devnull), redirect_stderr(devnull): - response = _run_agent(prompt, model=model, provider=provider) + response = _run_agent( + prompt, + model=model, + provider=provider, + toolsets=explicit_toolsets, + use_config_toolsets=use_config_toolsets, + ) finally: try: devnull.close() @@ -96,6 +203,8 @@ def _run_agent( prompt: str, model: Optional[str] = None, provider: Optional[str] = None, + toolsets: object = None, + use_config_toolsets: bool = True, ) -> str: """Build an AIAgent exactly like a normal CLI chat turn would, then run a single conversation. Returns the final response string.""" @@ -168,9 +277,12 @@ def _run_agent( explicit_base_url=explicit_base_url_from_alias, ) - # Pull in whatever toolsets the user has enabled for "cli". - # sorted() gives stable ordering; set→list for AIAgent's signature. - toolsets_list = sorted(_get_platform_tools(cfg, "cli")) + # Pull in explicit toolsets when provided; otherwise use whatever the user + # has enabled for "cli". sorted() gives stable ordering for config-derived + # sets; explicit values preserve user order. + toolsets_list = _normalize_toolsets(toolsets) + if toolsets_list is None and use_config_toolsets: + toolsets_list = sorted(_get_platform_tools(cfg, "cli")) agent = AIAgent( api_key=runtime.get("api_key"), diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py index bc609277c4..e341b734ee 100644 --- a/hermes_cli/platforms.py +++ b/hermes_cli/platforms.py @@ -44,6 +44,40 @@ PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([ def platform_label(key: str, default: str = "") -> str: - """Return the display label for a platform key, or *default*.""" + """Return the display label for a platform key, or *default*. + + Checks the static PLATFORMS dict first, then the plugin platform + registry for dynamically registered platforms. 
+ """ info = PLATFORMS.get(key) - return info.label if info is not None else default + if info is not None: + return info.label + # Check plugin registry + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(key) + if entry: + return f"{entry.emoji} {entry.label}" if entry.emoji else entry.label + except Exception: + pass + return default + + +def get_all_platforms() -> "OrderedDict[str, PlatformInfo]": + """Return PLATFORMS merged with any plugin-registered platforms. + + Plugin platforms are appended after builtins. This is the function + that tools_config and skills_config should use for platform menus. + """ + merged = OrderedDict(PLATFORMS) + try: + from gateway.platform_registry import platform_registry + for entry in platform_registry.plugin_entries(): + if entry.name not in merged: + merged[entry.name] = PlatformInfo( + label=f"{entry.emoji} {entry.label}" if entry.emoji else entry.label, + default_toolset=f"hermes-{entry.name}", + ) + except Exception: + pass + return merged diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 3d514ddc12..d7913eb9b5 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -37,6 +37,7 @@ import importlib import importlib.metadata import importlib.util import logging +import os import sys import types from dataclasses import dataclass, field @@ -45,6 +46,20 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union from hermes_constants import get_hermes_home from utils import env_var_enabled +from hermes_cli.config import cfg_get + + +def get_bundled_plugins_dir() -> Path: + """Locate the bundled ``plugins/`` directory. + + Honours ``HERMES_BUNDLED_PLUGINS`` (set by the Nix wrapper / packaged + installs) so read-only store paths are consulted first. Falls back to + the in-repo path used during development. 
+ """ + env_override = os.getenv("HERMES_BUNDLED_PLUGINS") + if env_override: + return Path(env_override) + return Path(__file__).resolve().parent.parent / "plugins" try: import yaml @@ -115,7 +130,7 @@ def _get_disabled_plugins() -> set: try: from hermes_cli.config import load_config config = load_config() - disabled = config.get("plugins", {}).get("disabled", []) + disabled = cfg_get(config, "plugins", "disabled", default=[]) return set(disabled) if isinstance(disabled, list) else set() except Exception: return set() @@ -155,7 +170,7 @@ def _get_enabled_plugins() -> Optional[set]: # Data classes # --------------------------------------------------------------------------- -_VALID_PLUGIN_KINDS: Set[str] = {"standalone", "backend", "exclusive"} +_VALID_PLUGIN_KINDS: Set[str] = {"standalone", "backend", "exclusive", "platform"} @dataclass @@ -181,6 +196,11 @@ class PluginManifest: # Selection via ``.provider`` config key; the # category's own discovery system handles loading and the # general scanner skips these. + # ``platform``: gateway messaging platform adapter (e.g. IRC). Bundled + # platform plugins auto-load so every shipped platform is + # available out of the box; user-installed platform plugins + # in ~/.hermes/plugins/ still gated by ``plugins.enabled`` + # (untrusted code). kind: str = "standalone" # Registry key — path-derived, used by ``plugins.enabled``/``disabled`` # lookups and by ``hermes plugins list``. For a flat plugin at @@ -444,6 +464,62 @@ class PluginContext: self.manifest.name, provider.name, ) + # -- platform adapter registration --------------------------------------- + + def register_platform( + self, + name: str, + label: str, + adapter_factory: Callable, + check_fn: Callable, + validate_config: Callable | None = None, + required_env: list | None = None, + install_hint: str = "", + **entry_kwargs: Any, + ) -> None: + """Register a gateway platform adapter. 
+ + The adapter_factory receives a ``PlatformConfig`` and returns a + ``BasePlatformAdapter`` subclass instance. The gateway calls + ``check_fn()`` before instantiation to verify dependencies. + + Extra keyword arguments are forwarded to ``PlatformEntry`` (e.g. + ``setup_fn``, ``emoji``, ``allowed_users_env``, ``platform_hint``). + Unknown keys raise TypeError from the dataclass constructor. + + Example:: + + ctx.register_platform( + name="irc", + label="IRC", + adapter_factory=lambda cfg: IRCAdapter(cfg), + check_fn=lambda: True, + emoji="💬", + setup_fn=irc_interactive_setup, + ) + """ + from gateway.platform_registry import platform_registry, PlatformEntry + + entry_kwargs.setdefault("plugin_name", self.manifest.name) + entry = PlatformEntry( + name=name, + label=label, + adapter_factory=adapter_factory, + check_fn=check_fn, + validate_config=validate_config, + required_env=required_env or [], + install_hint=install_hint, + source="plugin", + **entry_kwargs, + ) + platform_registry.register(entry) + self._manager._plugin_platform_names.add(name) + logger.debug( + "Plugin %s registered platform: %s", + self.manifest.name, + name, + ) + # -- hook registration -------------------------------------------------- def register_hook(self, hook_name: str, callback: Callable) -> None: @@ -522,6 +598,7 @@ class PluginManager: self._plugins: Dict[str, LoadedPlugin] = {} self._hooks: Dict[str, List[Callable]] = {} self._plugin_tool_names: Set[str] = set() + self._plugin_platform_names: Set[str] = set() self._cli_commands: Dict[str, dict] = {} self._context_engine = None # Set by a plugin via register_context_engine() self._plugin_commands: Dict[str, dict] = {} # Slash commands registered by plugins @@ -564,16 +641,19 @@ class PluginManager: # - category: ``plugins/image_gen/openai/plugin.yaml`` (backend) # # ``memory/`` and ``context_engine/`` are skipped at the top level — - # they have their own discovery systems. 
Porting those to the - # category-namespace ``kind: exclusive`` model is a future PR. - repo_plugins = Path(__file__).resolve().parent.parent / "plugins" + # they have their own discovery systems. ``platforms/`` is a category + # holding platform adapters (scanned one level deeper below). + repo_plugins = get_bundled_plugins_dir() manifests.extend( self._scan_directory( repo_plugins, source="bundled", - skip_names={"memory", "context_engine"}, + skip_names={"memory", "context_engine", "platforms"}, ) ) + manifests.extend( + self._scan_directory(repo_plugins / "platforms", source="bundled") + ) # 2. User plugins (~/.hermes/plugins/) user_dir = get_hermes_home() / "plugins" @@ -630,7 +710,11 @@ class PluginManager: # just work. Selection among them (e.g. which image_gen backend # services calls) is driven by ``.provider`` config, # enforced by the tool wrapper. - if manifest.kind == "backend" and manifest.source == "bundled": + # + # Bundled platform plugins (gateway adapters like IRC) auto-load + # for the same reason: every platform Hermes ships must be + # available out of the box without the user having to opt in. 
+ if manifest.source == "bundled" and manifest.kind in ("backend", "platform"): self._load_plugin(manifest) continue diff --git a/hermes_cli/plugins_cmd.py b/hermes_cli/plugins_cmd.py index 349a11de11..352dadd194 100644 --- a/hermes_cli/plugins_cmd.py +++ b/hermes_cli/plugins_cmd.py @@ -18,6 +18,7 @@ from pathlib import Path from typing import Optional from hermes_constants import get_hermes_home +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -519,7 +520,7 @@ def _get_disabled_set() -> set: try: from hermes_cli.config import load_config config = load_config() - disabled = config.get("plugins", {}).get("disabled", []) + disabled = cfg_get(config, "plugins", "disabled", default=[]) return set(disabled) if isinstance(disabled, list) else set() except Exception: return set() @@ -629,10 +630,9 @@ def _plugin_exists(name: str) -> bool: manifest = _read_manifest(child) if manifest.get("name") == name: return True - # Bundled: /plugins// - from pathlib import Path as _P - import hermes_cli - repo_plugins = _P(hermes_cli.__file__).resolve().parent.parent / "plugins" + # Bundled: /plugins// (or HERMES_BUNDLED_PLUGINS on Nix). 
+ from hermes_cli.plugins import get_bundled_plugins_dir + repo_plugins = get_bundled_plugins_dir() if repo_plugins.is_dir(): candidate = repo_plugins / name if candidate.is_dir() and ( @@ -659,8 +659,8 @@ def _discover_all_plugins() -> list: seen: dict = {} # name -> (name, version, description, source, path) # Bundled (/plugins//), excluding memory/ and context_engine/ - import hermes_cli - repo_plugins = Path(hermes_cli.__file__).resolve().parent.parent / "plugins" + from hermes_cli.plugins import get_bundled_plugins_dir + repo_plugins = get_bundled_plugins_dir() for base, source in ((repo_plugins, "bundled"), (_plugins_dir(), "user")): if not base.is_dir(): continue @@ -763,7 +763,7 @@ def _get_current_memory_provider() -> str: try: from hermes_cli.config import load_config config = load_config() - return config.get("memory", {}).get("provider", "") or "" + return cfg_get(config, "memory", "provider", default="") or "" except Exception: return "" @@ -773,7 +773,7 @@ def _get_current_context_engine() -> str: try: from hermes_cli.config import load_config config = load_config() - return config.get("context", {}).get("engine", "compressor") or "compressor" + return cfg_get(config, "context", "engine", default="compressor") or "compressor" except Exception: return "compressor" diff --git a/hermes_cli/profiles.py b/hermes_cli/profiles.py index 872d59563c..dd5fabcec4 100644 --- a/hermes_cli/profiles.py +++ b/hermes_cli/profiles.py @@ -71,6 +71,29 @@ _CLONE_ALL_STRIP = [ "processes.json", ] + +def _clone_all_copytree_ignore(source_dir: Path): + """Ignore ``profiles/`` at the root of *source_dir* only. + + ``~/.hermes`` contains ``profiles//`` for sibling named profiles. + ``shutil.copytree`` would otherwise duplicate that entire tree inside the + new profile (recursive ``.../profiles/.../profiles/...``). Export already + excludes ``profiles`` via ``_DEFAULT_EXPORT_EXCLUDE_ROOT`` — match that + behavior for ``--clone-all``. 
+ """ + source_resolved = source_dir.resolve() + + def _ignore(directory: str, names: List[str]) -> List[str]: + try: + if Path(directory).resolve() == source_resolved: + return [n for n in names if n == "profiles"] + except (OSError, ValueError): + pass + return [] + + return _ignore + + # Directories/files to exclude when exporting the default (~/.hermes) profile. # The default profile contains infrastructure (repo checkout, worktrees, DBs, # caches, binaries) that named profiles don't have. We exclude those so the @@ -425,8 +448,12 @@ def create_profile( ) if clone_all and source_dir: - # Full copy of source profile - shutil.copytree(source_dir, profile_dir) + # Full copy of source profile (exclude sibling ~/.hermes/profiles/) + shutil.copytree( + source_dir, + profile_dir, + ignore=_clone_all_copytree_ignore(source_dir), + ) # Strip runtime files for stale in _CLONE_ALL_STRIP: (profile_dir / stale).unlink(missing_ok=True) diff --git a/hermes_cli/providers.py b/hermes_cli/providers.py index 60f8dd8eaa..4909870954 100644 --- a/hermes_cli/providers.py +++ b/hermes_cli/providers.py @@ -111,6 +111,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = { transport="anthropic_messages", base_url_env_var="MINIMAX_BASE_URL", ), + "minimax-oauth": HermesOverlay( + transport="anthropic_messages", + auth_type="oauth_external", + base_url_override="https://api.minimax.io/anthropic", + ), "minimax-cn": HermesOverlay( transport="anthropic_messages", base_url_env_var="MINIMAX_CN_BASE_URL", diff --git a/hermes_cli/relaunch.py b/hermes_cli/relaunch.py new file mode 100644 index 0000000000..32a5dacd22 --- /dev/null +++ b/hermes_cli/relaunch.py @@ -0,0 +1,149 @@ +""" +Unified self-relaunch for Hermes CLI. + +Preserves critical flags (--tui, --dev, --profile, --model, etc.) across +process replacement so that ``hermes sessions browse`` or post-setup relaunch +doesn't silently drop the user's UI mode or other preferences. + +Also works when ``hermes`` is not on PATH (e.g. 
``nix run`` or ``python -m``). +""" + +import os +import shutil +import sys +from typing import Optional, Sequence + +from hermes_cli._parser import ( + PRE_ARGPARSE_INHERITED_FLAGS, + build_top_level_parser, +) + + +def _build_inherited_flag_table() -> list[tuple[str, bool]]: + """Build the ``(option_string, takes_value)`` table of flags that must + survive a self-relaunch, by introspecting the real parser used by + ``hermes`` itself. + + A flag participates if its argparse Action carries + ``inherit_on_relaunch = True`` — set by ``_parser._inherited_flag``. + """ + parser, _subparsers, chat_parser = build_top_level_parser() + + table: list[tuple[str, bool]] = [] + seen: set[tuple[str, bool]] = set() + for p in (parser, chat_parser): + for action in p._actions: + if not action.option_strings: + continue # positional / no flag form + if not getattr(action, "inherit_on_relaunch", False): + continue + takes_value = action.nargs != 0 # store_true/false set nargs=0 + for opt in action.option_strings: + key = (opt, takes_value) + if key not in seen: + seen.add(key) + table.append(key) + + table.extend(PRE_ARGPARSE_INHERITED_FLAGS) + return table + + +_INHERITED_FLAGS_TABLE = _build_inherited_flag_table() + + +def _extract_inherited_flags(argv: Sequence[str]) -> list[str]: + """Pull out flags that should carry over into a self-relaunched hermes.""" + flags: list[str] = [] + i = 0 + while i < len(argv): + arg = argv[i] + if "=" in arg: + key = arg.split("=", 1)[0] + for flag, _ in _INHERITED_FLAGS_TABLE: + if key == flag: + flags.append(arg) + break + i += 1 + continue + + for flag, takes_value in _INHERITED_FLAGS_TABLE: + if arg == flag: + flags.append(arg) + if takes_value and i + 1 < len(argv) and not argv[i + 1].startswith("-"): + flags.append(argv[i + 1]) + i += 1 + break + i += 1 + return flags + + +def resolve_hermes_bin() -> Optional[str]: + """Find the hermes entry point. + + Priority: + 1. ``sys.argv[0]`` if it resolves to a real executable. + 2. 
``shutil.which("hermes")`` on PATH. + 3. ``None`` → caller should fall back to ``python -m hermes_cli.main``. + """ + argv0 = sys.argv[0] + + # Absolute path to an executable (covers nix store, venv wrappers, etc.) + if os.path.isabs(argv0) and os.path.isfile(argv0) and os.access(argv0, os.X_OK): + return argv0 + + # Relative path — resolve against CWD + if not argv0.startswith("-") and os.path.isfile(argv0): + abs_path = os.path.abspath(argv0) + if os.access(abs_path, os.X_OK): + return abs_path + + # PATH lookup + path_bin = shutil.which("hermes") + if path_bin: + return path_bin + + return None + + +def build_relaunch_argv( + extra_args: Sequence[str], + *, + preserve_inherited: bool = True, + original_argv: Optional[Sequence[str]] = None, +) -> list[str]: + """Construct an argv list for replacing the current process with hermes. + + Args: + extra_args: Arguments to append (e.g. ``["--resume", id]``). + preserve_inherited: Whether to carry over UI / behaviour flags + tagged with ``inherit_on_relaunch`` in the parser. + original_argv: The original argv to scan for flags (defaults to + ``sys.argv[1:]``). 
+ """ + bin_path = resolve_hermes_bin() + + if bin_path: + argv = [bin_path] + else: + argv = [sys.executable, "-m", "hermes_cli.main"] + + src = list(original_argv) if original_argv is not None else list(sys.argv[1:]) + + if preserve_inherited: + argv.extend(_extract_inherited_flags(src)) + + argv.extend(extra_args) + return argv + + +def relaunch( + extra_args: Sequence[str], + *, + preserve_inherited: bool = True, + original_argv: Optional[Sequence[str]] = None, +) -> None: + """Replace the current process with a fresh hermes invocation.""" + new_argv = build_relaunch_argv( + extra_args, preserve_inherited=preserve_inherited, original_argv=original_argv + ) + os.execvp(new_argv[0], new_argv) \ No newline at end of file diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index e2883c883f..c46ebf3991 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -1070,6 +1070,20 @@ def resolve_runtime_provider( logger.info("Qwen OAuth credentials failed; " "falling through to next provider.") + if provider == "minimax-oauth": + pconfig = PROVIDER_REGISTRY.get(provider) + if pconfig and pconfig.auth_type == "oauth_minimax": + from hermes_cli.auth import resolve_minimax_oauth_runtime_credentials + creds = resolve_minimax_oauth_runtime_credentials() + return { + "provider": provider, + "api_mode": "anthropic_messages", + "base_url": creds["base_url"], + "api_key": creds["api_key"], + "source": creds.get("source", "oauth"), + "requested_provider": requested_provider, + } + if provider == "google-gemini-cli": try: creds = resolve_gemini_oauth_runtime_credentials() diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 011b4575e4..3933ad8494 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -12,6 +12,7 @@ Config files are stored in ~/.hermes/ for easy access. 
""" import importlib.util +import json import logging import os import shutil @@ -131,6 +132,7 @@ def _set_reasoning_effort(config: Dict[str, Any], effort: str) -> None: # Import config helpers from hermes_cli.config import ( + cfg_get, DEFAULT_CONFIG, get_hermes_home, get_config_path, @@ -138,6 +140,7 @@ from hermes_cli.config import ( load_config, save_config, save_env_value, + remove_env_value, get_env_value, ensure_hermes_home, ) @@ -441,7 +444,7 @@ def _print_setup_summary(config: dict, hermes_home): tool_status.append(("Image Generation", False, "FAL_KEY or OPENAI_API_KEY")) # TTS — show configured provider - tts_provider = config.get("tts", {}).get("provider", "edge") + tts_provider = cfg_get(config, "tts", "provider", default="edge") if subscription_features.tts.managed_by_nous: tool_status.append(("Text-to-Speech (OpenAI via Nous subscription)", True, None)) elif tts_provider == "elevenlabs" and get_env_value("ELEVENLABS_API_KEY"): @@ -480,7 +483,7 @@ def _print_setup_summary(config: dict, hermes_home): if subscription_features.modal.managed_by_nous: tool_status.append(("Modal Execution (Nous subscription)", True, None)) - elif config.get("terminal", {}).get("backend") == "modal": + elif cfg_get(config, "terminal", "backend") == "modal": if subscription_features.modal.direct_override: tool_status.append(("Modal Execution (direct Modal)", True, None)) else: @@ -654,6 +657,102 @@ def _prompt_container_resources(config: dict): pass +def _prompt_vercel_sandbox_settings(config: dict): + """Prompt for Vercel Sandbox settings without exposing unsupported disk sizing.""" + terminal = config.setdefault("terminal", {}) + + print() + print_info("Vercel Sandbox settings:") + print_info(" Filesystem persistence uses Vercel snapshots.") + print_info(" Snapshots restore files only; live processes do not continue after sandbox recreation.") + + from tools.terminal_tool import _SUPPORTED_VERCEL_RUNTIMES + + current_runtime = terminal.get("vercel_runtime") or "node24" + 
supported_label = ", ".join(_SUPPORTED_VERCEL_RUNTIMES) + runtime = prompt(f" Runtime ({supported_label})", current_runtime).strip() or current_runtime + if runtime not in _SUPPORTED_VERCEL_RUNTIMES: + print_warning(f"Unsupported Vercel runtime '{runtime}', keeping {current_runtime}.") + runtime = current_runtime if current_runtime in _SUPPORTED_VERCEL_RUNTIMES else "node24" + terminal["vercel_runtime"] = runtime + save_env_value("TERMINAL_VERCEL_RUNTIME", runtime) + + current_persist = terminal.get("container_persistent", True) + persist_label = "yes" if current_persist else "no" + terminal["container_persistent"] = prompt( + " Persist filesystem with snapshots? (yes/no)", persist_label + ).lower() in ("yes", "true", "y", "1") + + current_cpu = terminal.get("container_cpu", 1) + cpu_str = prompt(" CPU cores", str(current_cpu)) + try: + terminal["container_cpu"] = float(cpu_str) + except ValueError: + pass + + current_mem = terminal.get("container_memory", 5120) + mem_str = prompt(" Memory in MB (5120 = 5GB)", str(current_mem)) + try: + terminal["container_memory"] = int(mem_str) + except ValueError: + pass + + if terminal.get("container_disk", 51200) not in (0, 51200): + print_warning("Vercel Sandbox does not support custom disk sizing; resetting container_disk to 51200.") + terminal["container_disk"] = 51200 + + print() + print_info("Vercel authentication:") + print_info(" Use a long-lived Vercel access token plus project/team IDs.") + linked_project = _read_nearest_vercel_project() + if linked_project: + print_info(" Found defaults in nearest .vercel/project.json.") + + remove_env_value("VERCEL_OIDC_TOKEN") + token = prompt(" Vercel access token", get_env_value("VERCEL_TOKEN") or "", password=True) + project = prompt( + " Vercel project ID", + get_env_value("VERCEL_PROJECT_ID") or linked_project.get("projectId", ""), + ) + team = prompt( + " Vercel team ID", + get_env_value("VERCEL_TEAM_ID") or linked_project.get("orgId", ""), + ) + if token: + 
save_env_value("VERCEL_TOKEN", token) + if project: + save_env_value("VERCEL_PROJECT_ID", project) + if team: + save_env_value("VERCEL_TEAM_ID", team) + + +def _read_nearest_vercel_project(start: Path | None = None) -> dict[str, str]: + """Read project/team defaults from the nearest Vercel link file.""" + current = (start or Path.cwd()).resolve() + if current.is_file(): + current = current.parent + + for directory in (current, *current.parents): + project_file = directory / ".vercel" / "project.json" + if not project_file.exists(): + continue + try: + data = json.loads(project_file.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return {} + if not isinstance(data, dict): + return {} + return { + key: value + for key, value in { + "projectId": data.get("projectId"), + "orgId": data.get("orgId"), + }.items() + if isinstance(value, str) and value.strip() + } + return {} + + # Tool categories and provider config are now in tools_config.py (shared # between `hermes tools` and `hermes setup tools`). 
@@ -1179,7 +1278,7 @@ def setup_terminal_backend(config: dict): print_info(f" Guide: {_DOCS_BASE}/developer-guide/environments") print() - current_backend = config.get("terminal", {}).get("backend", "local") + current_backend = cfg_get(config, "terminal", "backend", default="local") is_linux = _platform.system() == "Linux" # Build backend choices with descriptions @@ -1189,11 +1288,12 @@ def setup_terminal_backend(config: dict): "Modal - serverless cloud sandbox", "SSH - run on a remote machine", "Daytona - persistent cloud development environment", + "Vercel Sandbox - cloud microVM with snapshot filesystem persistence", ] - idx_to_backend = {0: "local", 1: "docker", 2: "modal", 3: "ssh", 4: "daytona"} - backend_to_idx = {"local": 0, "docker": 1, "modal": 2, "ssh": 3, "daytona": 4} + idx_to_backend = {0: "local", 1: "docker", 2: "modal", 3: "ssh", 4: "daytona", 5: "vercel_sandbox"} + backend_to_idx = {"local": 0, "docker": 1, "modal": 2, "ssh": 3, "daytona": 4, "vercel_sandbox": 5} - next_idx = 5 + next_idx = 6 if is_linux: terminal_choices.append("Singularity/Apptainer - HPC-friendly container") idx_to_backend[next_idx] = "singularity" @@ -1228,7 +1328,7 @@ def setup_terminal_backend(config: dict): print_info( " the agent starts. CLI mode always starts in the current directory." 
) - current_cwd = config.get("terminal", {}).get("cwd", "") + current_cwd = cfg_get(config, "terminal", "cwd", default="") cwd = prompt(" Messaging working directory", current_cwd or str(Path.home())) if cwd: config["terminal"]["cwd"] = cwd @@ -1259,9 +1359,7 @@ def setup_terminal_backend(config: dict): print_info(f"Docker found: {docker_bin}") # Docker image - current_image = config.get("terminal", {}).get( - "docker_image", "nikolaik/python-nodejs:python3.11-nodejs20" - ) + current_image = cfg_get(config, "terminal", "docker_image", default="nikolaik/python-nodejs:python3.11-nodejs20") image = prompt(" Docker image", current_image) config["terminal"]["docker_image"] = image save_env_value("TERMINAL_DOCKER_IMAGE", image) @@ -1281,9 +1379,7 @@ def setup_terminal_backend(config: dict): else: print_info(f"Found: {sing_bin}") - current_image = config.get("terminal", {}).get( - "singularity_image", "docker://nikolaik/python-nodejs:python3.11-nodejs20" - ) + current_image = cfg_get(config, "terminal", "singularity_image", default="docker://nikolaik/python-nodejs:python3.11-nodejs20") image = prompt(" Container image", current_image) config["terminal"]["singularity_image"] = image save_env_value("TERMINAL_SINGULARITY_IMAGE", image) @@ -1302,7 +1398,7 @@ def setup_terminal_backend(config: dict): get_nous_subscription_features(config).nous_auth_present and is_managed_tool_gateway_ready("modal") ) - modal_mode = normalize_modal_mode(config.get("terminal", {}).get("modal_mode")) + modal_mode = normalize_modal_mode(cfg_get(config, "terminal", "modal_mode")) use_managed_modal = False if managed_modal_available: modal_choices = [ @@ -1439,15 +1535,46 @@ def setup_terminal_backend(config: dict): print_success(" Configured") # Daytona image - current_image = config.get("terminal", {}).get( - "daytona_image", "nikolaik/python-nodejs:python3.11-nodejs20" - ) + current_image = cfg_get(config, "terminal", "daytona_image", default="nikolaik/python-nodejs:python3.11-nodejs20") image = 
prompt(" Sandbox image", current_image) config["terminal"]["daytona_image"] = image save_env_value("TERMINAL_DAYTONA_IMAGE", image) _prompt_container_resources(config) + elif selected_backend == "vercel_sandbox": + print_success("Terminal backend: Vercel Sandbox") + print_info("Cloud microVM sandboxes with snapshot-backed filesystem persistence.") + print_info("Requires the optional SDK: pip install 'hermes-agent[vercel]'") + + try: + __import__("vercel") + except ImportError: + print_info("Installing vercel SDK...") + import subprocess + + uv_bin = shutil.which("uv") + if uv_bin: + result = subprocess.run( + [uv_bin, "pip", "install", "--python", sys.executable, "vercel"], + capture_output=True, + text=True, + ) + else: + result = subprocess.run( + [sys.executable, "-m", "pip", "install", "vercel"], + capture_output=True, + text=True, + ) + if result.returncode == 0: + print_success("vercel SDK installed") + else: + print_warning("Install failed — run manually: pip install 'hermes-agent[vercel]'") + if result.stderr: + print_info(f" Error: {result.stderr.strip().splitlines()[-1]}") + + _prompt_vercel_sandbox_settings(config) + elif selected_backend == "ssh": print_success("Terminal backend: SSH") print_info("Run commands on a remote machine via SSH.") @@ -1501,6 +1628,8 @@ def setup_terminal_backend(config: dict): save_env_value("TERMINAL_ENV", selected_backend) if selected_backend == "modal": save_env_value("TERMINAL_MODAL_MODE", config["terminal"].get("modal_mode", "auto")) + if selected_backend == "vercel_sandbox": + save_env_value("TERMINAL_VERCEL_RUNTIME", config["terminal"].get("vercel_runtime", "node24")) save_config(config) print() print_success(f"Terminal backend set to: {selected_backend}") @@ -1545,7 +1674,7 @@ def setup_agent_settings(config: dict): # ── Max Iterations ── current_max = get_env_value("HERMES_MAX_ITERATIONS") or str( - config.get("agent", {}).get("max_turns", 90) + cfg_get(config, "agent", "max_turns", default=90) ) print_info("Maximum 
tool-calling iterations per conversation.") print_info("Higher = more complex tasks, but costs more tokens.") @@ -1573,7 +1702,7 @@ def setup_agent_settings(config: dict): print_info(" all — Show every tool call with a short preview") print_info(" verbose — Full args, results, and debug logs") - current_mode = config.get("display", {}).get("tool_progress", "all") + current_mode = cfg_get(config, "display", "tool_progress", default="all") mode = prompt("Tool progress mode", current_mode) if mode.lower() in ("off", "new", "all", "verbose"): if "display" not in config: @@ -1593,7 +1722,7 @@ def setup_agent_settings(config: dict): config.setdefault("compression", {})["enabled"] = True - current_threshold = config.get("compression", {}).get("threshold", 0.50) + current_threshold = cfg_get(config, "compression", "threshold", default=0.50) threshold_str = prompt("Compression threshold (0.5-0.95)", str(current_threshold)) try: threshold = float(threshold_str) @@ -2075,80 +2204,7 @@ def _setup_mattermost(): home_channel = prompt("Home channel ID (leave empty to set later with /set-home)") if home_channel: save_env_value("MATTERMOST_HOME_CHANNEL", home_channel) - - -def _setup_whatsapp(): - """Configure WhatsApp bridge.""" - print_header("WhatsApp") - existing = get_env_value("WHATSAPP_ENABLED") - if existing: - print_info("WhatsApp: already enabled") - return - - print_info("WhatsApp connects via a built-in bridge (Baileys).") - print_info("Requires Node.js. 
Run 'hermes whatsapp' for guided setup.") - print() - if prompt_yes_no("Enable WhatsApp now?", True): - save_env_value("WHATSAPP_ENABLED", "true") - print_success("WhatsApp enabled") - print_info("Run 'hermes whatsapp' to choose your mode (separate bot number") - print_info("or personal self-chat) and pair via QR code.") - - -def _setup_weixin(): - """Configure Weixin (personal WeChat) via iLink Bot API QR login.""" - from hermes_cli.gateway import _setup_weixin as _gateway_setup_weixin - _gateway_setup_weixin() - - -def _setup_signal(): - """Configure Signal via gateway setup.""" - from hermes_cli.gateway import _setup_signal as _gateway_setup_signal - _gateway_setup_signal() - - -def _setup_email(): - """Configure Email via gateway setup.""" - from hermes_cli.gateway import _setup_email as _gateway_setup_email - _gateway_setup_email() - - -def _setup_sms(): - """Configure SMS (Twilio) via gateway setup.""" - from hermes_cli.gateway import _setup_sms as _gateway_setup_sms - _gateway_setup_sms() - - -def _setup_dingtalk(): - """Configure DingTalk via gateway setup.""" - from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk - _gateway_setup_dingtalk() - - -def _setup_feishu(): - """Configure Feishu / Lark via gateway setup.""" - from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu - _gateway_setup_feishu() - - -def _setup_yuanbao(): - """Configure Yuanbao via gateway setup.""" - from hermes_cli.gateway import _setup_yuanbao as _gateway_setup_yuanbao - _gateway_setup_yuanbao() - - -def _setup_wecom(): - """Configure WeCom (Enterprise WeChat) via gateway setup.""" - from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom - _gateway_setup_wecom() - - -def _setup_wecom_callback(): - """Configure WeCom Callback (self-built app) via gateway setup.""" - from hermes_cli.gateway import _setup_wecom_callback as _gw_setup - _gw_setup() - - + print_info(" Open config in your editor: hermes config edit") def _setup_bluebubbles(): 
@@ -2266,49 +2322,27 @@ def _setup_webhooks(): print_info(" https://hermes-agent.nousresearch.com/docs/user-guide/messaging/webhooks/#configuring-routes") print() print_info(" Open config in your editor: hermes config edit") - - -# Platform registry for the gateway checklist -_GATEWAY_PLATFORMS = [ - ("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram), - ("Discord", "DISCORD_BOT_TOKEN", _setup_discord), - ("Slack", "SLACK_BOT_TOKEN", _setup_slack), - ("Signal", "SIGNAL_HTTP_URL", _setup_signal), - ("Email", "EMAIL_ADDRESS", _setup_email), - ("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms), - ("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix), - ("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost), - ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), - ("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk), - ("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu), - ("Yuanbao", "YUANBAO_APP_ID", _setup_yuanbao), - ("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom), - ("WeCom Callback (Self-Built App)", "WECOM_CALLBACK_CORP_ID", _setup_wecom_callback), - ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin), - ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles), - ("QQ Bot", "QQ_APP_ID", _setup_qqbot), - ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks), -] + print_info(" Open config in your editor: hermes config edit") def setup_gateway(config: dict): """Configure messaging platform integrations.""" + from hermes_cli.gateway import _all_platforms, _platform_status, _configure_platform + print_header("Messaging Platforms") print_info("Connect to messaging platforms to chat with Hermes from anywhere.") print_info("Toggle with Space, confirm with Enter.") print() - # Build checklist items, pre-selecting already-configured platforms + platforms = _all_platforms() + + # Build checklist, pre-selecting already-configured platforms. 
items = [] pre_selected = [] - for i, (name, env_var, _func) in enumerate(_GATEWAY_PLATFORMS): - # Matrix has two possible env vars - is_configured = bool(get_env_value(env_var)) - if name == "Matrix" and not is_configured: - is_configured = bool(get_env_value("MATRIX_PASSWORD")) - label = f"{name} (configured)" if is_configured else name - items.append(label) - if is_configured: + for i, plat in enumerate(platforms): + status = _platform_status(plat) + items.append(f"{plat['emoji']} {plat['label']} ({status})") + if status == "configured": pre_selected.append(i) selected = prompt_checklist("Select platforms to configure:", items, pre_selected) @@ -2318,28 +2352,22 @@ def setup_gateway(config: dict): return for idx in selected: - name, _env_var, setup_func = _GATEWAY_PLATFORMS[idx] - setup_func() + _configure_platform(platforms[idx]) # ── Gateway Service Setup ── - any_messaging = ( - get_env_value("TELEGRAM_BOT_TOKEN") - or get_env_value("DISCORD_BOT_TOKEN") - or get_env_value("SLACK_BOT_TOKEN") - or get_env_value("SIGNAL_HTTP_URL") - or get_env_value("EMAIL_ADDRESS") - or get_env_value("TWILIO_ACCOUNT_SID") - or get_env_value("MATTERMOST_TOKEN") - or get_env_value("MATRIX_ACCESS_TOKEN") - or get_env_value("MATRIX_PASSWORD") - or get_env_value("WHATSAPP_ENABLED") - or get_env_value("DINGTALK_CLIENT_ID") - or get_env_value("FEISHU_APP_ID") - or get_env_value("WECOM_BOT_ID") - or get_env_value("WEIXIN_ACCOUNT_ID") - or get_env_value("BLUEBUBBLES_SERVER_URL") - or get_env_value("QQ_APP_ID") - or get_env_value("WEBHOOK_ENABLED") + # Count any platform (built-in or plugin) the user configured during this + # setup pass — reuses ``_platform_status`` so plugin platforms like IRC + # are picked up without another hard-coded env-var list. 
+ def _is_progress(status: str) -> bool: + s = status.lower() + return not ( + s == "not configured" + or s.startswith("partially") + or s.startswith("plugin disabled") + ) + + any_messaging = any( + _is_progress(_platform_status(p)) for p in _all_platforms() ) if any_messaging: print() @@ -2601,21 +2629,26 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str] return "configured" elif section_key == "terminal": - backend = config.get("terminal", {}).get("backend", "local") + backend = cfg_get(config, "terminal", "backend", default="local") return f"backend: {backend}" elif section_key == "agent": - max_turns = config.get("agent", {}).get("max_turns", 90) + max_turns = cfg_get(config, "agent", "max_turns", default=90) return f"max turns: {max_turns}" elif section_key == "gateway": - platforms = [ - _gateway_platform_short_label(label) - for label, env_var, _ in _GATEWAY_PLATFORMS - if get_env_value(env_var) + from hermes_cli.gateway import _all_platforms, _platform_status + # Count any non-empty status other than the "not configured" sentinel — + # platforms like WhatsApp ("enabled, not paired"), Matrix ("configured + # + E2EE"), and Signal ("partially configured") all indicate the user + # has already started setup and we shouldn't force the section to rerun. 
+ configured = [ + _gateway_platform_short_label(plat["label"]) + for plat in _all_platforms() + if _platform_status(plat) and _platform_status(plat) != "not configured" ] - if platforms: - return ", ".join(platforms) + if configured: + return ", ".join(configured) return None # No platforms configured — section must run elif section_key == "tools": @@ -3120,33 +3153,14 @@ def run_setup_wizard(args): _offer_launch_chat() -def _resolve_hermes_chat_argv() -> Optional[list[str]]: - """Resolve argv for launching ``hermes chat`` in a fresh process.""" - hermes_bin = shutil.which("hermes") - if hermes_bin: - return [hermes_bin, "chat"] - - try: - if importlib.util.find_spec("hermes_cli") is not None: - return [sys.executable, "-m", "hermes_cli.main", "chat"] - except Exception: - pass - - return None - - def _offer_launch_chat(): """Prompt the user to jump straight into chat after setup.""" print() if not prompt_yes_no("Launch hermes chat now?", True): return - chat_argv = _resolve_hermes_chat_argv() - if not chat_argv: - print_info("Could not relaunch Hermes automatically. 
Run 'hermes chat' manually.") - return - - os.execvp(chat_argv[0], chat_argv) + from hermes_cli.relaunch import relaunch + relaunch(["chat"]) def _run_first_time_quick_setup(config: dict, hermes_home, is_existing: bool): diff --git a/hermes_cli/skills_config.py b/hermes_cli/skills_config.py index 741a8b8341..8eaf64605a 100644 --- a/hermes_cli/skills_config.py +++ b/hermes_cli/skills_config.py @@ -13,7 +13,7 @@ Config stored in ~/.hermes/config.yaml under: """ from typing import List, Optional, Set -from hermes_cli.config import load_config, save_config +from hermes_cli.config import cfg_get, load_config, save_config from hermes_cli.colors import Colors, color from hermes_cli.platforms import PLATFORMS as _PLATFORMS @@ -30,7 +30,7 @@ def get_disabled_skills(config: dict, platform: Optional[str] = None) -> Set[str global_disabled = set(skills_cfg.get("disabled", [])) if platform is None: return global_disabled - platform_disabled = skills_cfg.get("platform_disabled", {}).get(platform) + platform_disabled = cfg_get(skills_cfg, "platform_disabled", platform) if platform_disabled is None: return global_disabled return set(platform_disabled) diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 31aa1d5c2c..fb2d010a4e 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -7,6 +7,7 @@ Shows the status of all Hermes Agent components. 
import os import sys import subprocess # noqa: F401 — re-exported for tests that monkeypatch status.subprocess to guard against regressions +import importlib.util from pathlib import Path PROJECT_ROOT = Path(__file__).parent.parent.resolve() @@ -17,6 +18,7 @@ from hermes_cli.config import get_env_path, get_env_value, get_hermes_home, load from hermes_cli.models import provider_label from hermes_cli.nous_subscription import get_nous_subscription_features from hermes_cli.runtime_provider import resolve_requested_provider +from hermes_cli.vercel_auth import describe_vercel_auth from hermes_constants import OPENROUTER_MODELS_URL from tools.tool_backend_helpers import managed_nous_tools_enabled @@ -89,12 +91,12 @@ def show_status(args): """Show status of all Hermes Agent components.""" show_all = getattr(args, 'all', False) deep = getattr(args, 'deep', False) - + print() print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN)) print(color("│ ⚕ Hermes Agent Status │", Colors.CYAN)) print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN)) - + # ========================================================================= # Environment # ========================================================================= @@ -102,7 +104,7 @@ def show_status(args): print(color("◆ Environment", Colors.CYAN, Colors.BOLD)) print(f" Project: {PROJECT_ROOT}") print(f" Python: {sys.version.split()[0]}") - + env_path = get_env_path() print(f" .env file: {check_mark(env_path.exists())} {'exists' if env_path.exists() else 'not found'}") @@ -113,13 +115,13 @@ def show_status(args): print(f" Model: {_configured_model_label(config)}") print(f" Provider: {_effective_provider_label()}") - + # ========================================================================= # API Keys # ========================================================================= print() print(color("◆ API Keys", Colors.CYAN, Colors.BOLD)) - + keys = { "OpenRouter": 
"OPENROUTER_API_KEY", "OpenAI": "OPENAI_API_KEY", @@ -138,7 +140,7 @@ def show_status(args): "ElevenLabs": "ELEVENLABS_API_KEY", "GitHub": "GITHUB_TOKEN", } - + for name, env_var in keys.items(): value = get_env_value(env_var) or "" has_key = bool(value) @@ -157,14 +159,21 @@ def show_status(args): print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD)) try: - from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status, get_qwen_auth_status + from hermes_cli.auth import ( + get_nous_auth_status, + get_codex_auth_status, + get_qwen_auth_status, + get_minimax_oauth_auth_status, + ) nous_status = get_nous_auth_status() codex_status = get_codex_auth_status() qwen_status = get_qwen_auth_status() + minimax_status = get_minimax_oauth_auth_status() except Exception: nous_status = {} codex_status = {} qwen_status = {} + minimax_status = {} nous_logged_in = bool(nous_status.get("logged_in")) nous_error = nous_status.get("error") @@ -217,6 +226,20 @@ def show_status(args): if qwen_status.get("error") and not qwen_logged_in: print(f" Error: {qwen_status.get('error')}") + minimax_logged_in = bool(minimax_status.get("logged_in")) + print( + f" {'MiniMax OAuth':<12} {check_mark(minimax_logged_in)} " + f"{'logged in' if minimax_logged_in else 'not logged in (run: hermes auth add minimax-oauth)'}" + ) + minimax_region = minimax_status.get("region") + if minimax_logged_in and minimax_region: + print(f" Region: {minimax_region}") + minimax_exp = minimax_status.get("expires_at") + if minimax_exp: + print(f" Access exp: {minimax_exp}") + if minimax_status.get("error") and not minimax_logged_in: + print(f" Error: {minimax_status.get('error')}") + # ========================================================================= # Nous Subscription Features # ========================================================================= @@ -299,18 +322,13 @@ def show_status(args): # ========================================================================= print() print(color("◆ 
Terminal Backend", Colors.CYAN, Colors.BOLD)) - + + terminal_cfg = config.get("terminal", {}) if isinstance(config.get("terminal"), dict) else {} terminal_env = os.getenv("TERMINAL_ENV", "") if not terminal_env: - # Fall back to config file value when env var isn't set - # (hermes status doesn't go through cli.py's config loading) - try: - _cfg = load_config() - terminal_env = _cfg.get("terminal", {}).get("backend", "local") - except Exception: - terminal_env = "local" + terminal_env = terminal_cfg.get("backend", "local") print(f" Backend: {terminal_env}") - + if terminal_env == "ssh": ssh_host = os.getenv("TERMINAL_SSH_HOST", "") ssh_user = os.getenv("TERMINAL_SSH_USER", "") @@ -322,16 +340,33 @@ def show_status(args): elif terminal_env == "daytona": daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20") print(f" Daytona Image: {daytona_image}") - + elif terminal_env == "vercel_sandbox": + runtime = os.getenv("TERMINAL_VERCEL_RUNTIME") or terminal_cfg.get("vercel_runtime") or "node24" + persist = os.getenv("TERMINAL_CONTAINER_PERSISTENT") + if persist is None: + persist_enabled = bool(terminal_cfg.get("container_persistent", True)) + else: + persist_enabled = persist.lower() in ("1", "true", "yes", "on") + auth_status = describe_vercel_auth() + sdk_ok = importlib.util.find_spec("vercel") is not None + sdk_label = "installed" if sdk_ok else "missing (install: pip install 'hermes-agent[vercel]')" + print(f" Runtime: {runtime}") + print(f" SDK: {check_mark(sdk_ok)} {sdk_label}") + print(f" Auth: {check_mark(auth_status.ok)} {auth_status.label}") + for line in auth_status.detail_lines: + print(f" Auth detail: {line}") + print(f" Persistence: {'snapshot filesystem' if persist_enabled else 'ephemeral filesystem'}") + print(" Processes: live processes do not survive cleanup, snapshots, or sandbox recreation") + sudo_password = os.getenv("SUDO_PASSWORD", "") print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if 
sudo_password else 'disabled'}") - + # ========================================================================= # Messaging Platforms # ========================================================================= print() print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD)) - + platforms = { "Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"), "Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"), @@ -349,7 +384,7 @@ def show_status(args): "QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"), "Yuanbao": ("YUANBAO_APP_ID", "YUANBAO_HOME_CHANNEL"), } - + for name, (token_var, home_var) in platforms.items(): token = os.getenv(token_var, "") has_token = bool(token) @@ -366,7 +401,18 @@ def show_status(args): status += f" (home: {home_channel})" print(f" {name:<12} {check_mark(has_token)} {status}") - + + # Plugin-registered platforms + try: + from gateway.platform_registry import platform_registry + for entry in platform_registry.plugin_entries(): + configured = entry.check_fn() + status_str = "configured" if configured else "not configured" + label = entry.label + print(f" {label:<12} {check_mark(configured)} {status_str} (plugin)") + except Exception: + pass + # ========================================================================= # Gateway Status # ========================================================================= @@ -402,13 +448,13 @@ def show_status(args): else: print(f" Status: {color('N/A', Colors.DIM)}") print(" Manager: (not supported on this platform)") - + # ========================================================================= # Cron Jobs # ========================================================================= print() print(color("◆ Scheduled Jobs", Colors.CYAN, Colors.BOLD)) - + jobs_file = get_hermes_home() / "cron" / "jobs.json" if jobs_file.exists(): import json @@ -422,13 +468,13 @@ def show_status(args): print(" Jobs: (error reading jobs file)") else: print(" Jobs: 0") - + # 
========================================================================= # Sessions # ========================================================================= print() print(color("◆ Sessions", Colors.CYAN, Colors.BOLD)) - + sessions_file = get_hermes_home() / "sessions" / "sessions.json" if sessions_file.exists(): import json @@ -440,7 +486,7 @@ def show_status(args): print(" Active: (error reading sessions file)") else: print(" Active: 0") - + # ========================================================================= # Deep checks # ========================================================================= @@ -476,7 +522,7 @@ def show_status(args): print(f" Port 18789: {'in use' if port_in_use else 'available'}") except OSError: pass - + print() print(color("─" * 60, Colors.DIM)) print(color(" Run 'hermes doctor' for detailed diagnostics", Colors.DIM)) diff --git a/hermes_cli/tips.py b/hermes_cli/tips.py index 8e07323b62..62fad2eb6a 100644 --- a/hermes_cli/tips.py +++ b/hermes_cli/tips.py @@ -100,6 +100,9 @@ TIPS = [ "hermes gateway install sets up Hermes as a system service (systemd/launchd).", "hermes memory setup lets you configure an external memory provider (Honcho, Mem0, etc.).", "hermes webhook subscribe creates event-driven webhook routes with HMAC validation.", + "Save money: hermes tools disables unused tools, hermes skills config trims skills down.", + "/reasoning low or /reasoning minimal cuts thinking depth below the default (medium) — faster, cheaper responses.", + "hermes models routes vision, compression, and aux tasks to cheaper models — cuts background token cost 85%+ without downgrading your main chat model.", # --- Configuration --- "Set display.bell_on_complete: true in config.yaml to hear a bell when long tasks finish.", diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index b4a19f0bc4..5edb227d95 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -18,6 +18,7 @@ from typing import Dict, List, 
Optional, Set from hermes_cli.config import ( + cfg_get, load_config, save_config, get_env_value, save_env_value, ) from hermes_cli.colors import Colors, color @@ -226,6 +227,14 @@ TOOL_CATEGORIES = { "tts_provider": "kittentts", "post_setup": "kittentts", }, + { + "name": "Piper", + "badge": "local · free", + "tag": "Local neural TTS, 44 languages (voices ~20-90MB)", + "env_vars": [], + "tts_provider": "piper", + "post_setup": "piper", + }, ], }, "web": { @@ -623,6 +632,33 @@ def _run_post_setup(post_setup_key: str): _print_warning(" kittentts install timed out (>5min)") _print_info(f" Run manually: python -m pip install -U '{wheel_url}' soundfile") + elif post_setup_key == "piper": + try: + __import__("piper") + _print_success(" piper-tts is already installed") + except ImportError: + import subprocess + _print_info(" Installing piper-tts (~14MB wheel, voices downloaded on first use)...") + try: + result = subprocess.run( + [sys.executable, "-m", "pip", "install", "-U", "piper-tts", "--quiet"], + capture_output=True, text=True, timeout=300, + ) + if result.returncode == 0: + _print_success(" piper-tts installed") + else: + _print_warning(" piper-tts install failed:") + _print_info(f" {result.stderr.strip()[:300]}") + _print_info(" Run manually: python -m pip install -U piper-tts") + return + except subprocess.TimeoutExpired: + _print_warning(" piper-tts install timed out (>5min)") + _print_info(" Run manually: python -m pip install -U piper-tts") + return + _print_info(" Default voice: en_US-lessac-medium (downloaded on first TTS call)") + _print_info(" Full voice list: https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/VOICES.md") + _print_info(" Switch voices by setting tts.piper.voice in ~/.hermes/config.yaml") + elif post_setup_key == "spotify": # Run the full `hermes auth spotify` flow — if the user has no # client_id yet, this drops them into the interactive wizard @@ -780,7 +816,12 @@ def _get_platform_tools( toolset_names = 
platform_toolsets.get(platform) if toolset_names is None or not isinstance(toolset_names, list): - default_ts = PLATFORMS[platform]["default_toolset"] + plat_info = PLATFORMS.get(platform) + if plat_info: + default_ts = plat_info["default_toolset"] + else: + # Plugin platform — derive toolset name from platform key + default_ts = f"hermes-{platform}" toolset_names = [default_ts] # YAML may parse bare numeric names (e.g. ``12306:``) as int. @@ -843,7 +884,9 @@ def _get_platform_tools( # checklist or in a user-saved config. Must run in BOTH branches — # otherwise saving via `hermes tools` (which flips has_explicit_config # to True) silently drops them. - platform_tool_universe = set(resolve_toolset(PLATFORMS[platform]["default_toolset"])) + _plat_info = PLATFORMS.get(platform) + _default_ts = _plat_info["default_toolset"] if _plat_info else f"hermes-{platform}" + platform_tool_universe = set(resolve_toolset(_default_ts)) configurable_tool_universe = set() for ck in configurable_keys: configurable_tool_universe.update(resolve_toolset(ck)) @@ -965,7 +1008,7 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[ platform_default_keys = {p["default_toolset"] for p in PLATFORMS.values()} # Get existing toolsets for this platform - existing_toolsets = config.get("platform_toolsets", {}).get(platform, []) + existing_toolsets = cfg_get(config, "platform_toolsets", platform, default=[]) if not isinstance(existing_toolsets, list): existing_toolsets = [] existing_toolsets = [str(ts) for ts in existing_toolsets] @@ -1352,23 +1395,23 @@ def _is_provider_active(provider: dict, config: dict) -> bool: if provider.get("tts_provider"): return ( feature.managed_by_nous - and config.get("tts", {}).get("provider") == provider["tts_provider"] + and cfg_get(config, "tts", "provider") == provider["tts_provider"] ) if "browser_provider" in provider: - current = config.get("browser", {}).get("cloud_provider") + current = cfg_get(config, "browser", 
"cloud_provider") return feature.managed_by_nous and provider["browser_provider"] == current if provider.get("web_backend"): - current = config.get("web", {}).get("backend") + current = cfg_get(config, "web", "backend") return feature.managed_by_nous and current == provider["web_backend"] return feature.managed_by_nous if provider.get("tts_provider"): - return config.get("tts", {}).get("provider") == provider["tts_provider"] + return cfg_get(config, "tts", "provider") == provider["tts_provider"] if "browser_provider" in provider: - current = config.get("browser", {}).get("cloud_provider") + current = cfg_get(config, "browser", "cloud_provider") return provider["browser_provider"] == current if provider.get("web_backend"): - current = config.get("web", {}).get("backend") + current = cfg_get(config, "web", "backend") return current == provider["web_backend"] if provider.get("imagegen_backend"): image_cfg = config.get("image_gen", {}) diff --git a/hermes_cli/vercel_auth.py b/hermes_cli/vercel_auth.py new file mode 100644 index 0000000000..4666d516e1 --- /dev/null +++ b/hermes_cli/vercel_auth.py @@ -0,0 +1,70 @@ +"""Helpers for reporting Vercel Sandbox authentication state.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + + +_TOKEN_TUPLE_VARS = ("VERCEL_TOKEN", "VERCEL_PROJECT_ID", "VERCEL_TEAM_ID") + + +@dataclass(frozen=True) +class VercelAuthStatus: + ok: bool + label: str + detail_lines: tuple[str, ...] 
+ + +def _present(name: str) -> bool: + return bool(os.getenv(name)) + + +def describe_vercel_auth() -> VercelAuthStatus: + """Return Vercel auth status without exposing secret values.""" + + has_oidc = _present("VERCEL_OIDC_TOKEN") + token_states = {name: _present(name) for name in _TOKEN_TUPLE_VARS} + present_token_vars = tuple(name for name, present in token_states.items() if present) + missing_token_vars = tuple(name for name, present in token_states.items() if not present) + + if has_oidc: + details = [ + "mode: OIDC", + "active env: VERCEL_OIDC_TOKEN", + "note: OIDC tokens are development-only; use access-token auth for deployments and long-running processes", + ] + if present_token_vars: + details.append(f"also present: {', '.join(present_token_vars)}") + return VercelAuthStatus(True, "OIDC token via VERCEL_OIDC_TOKEN", tuple(details)) + + if not missing_token_vars: + return VercelAuthStatus( + True, + "access token + project/team via VERCEL_TOKEN, VERCEL_PROJECT_ID, VERCEL_TEAM_ID", + ( + "mode: access token", + "active env: VERCEL_TOKEN, VERCEL_PROJECT_ID, VERCEL_TEAM_ID", + ), + ) + + if present_token_vars: + return VercelAuthStatus( + False, + f"partial access-token auth (missing {', '.join(missing_token_vars)})", + ( + "mode: incomplete access token", + f"present env: {', '.join(present_token_vars)}", + f"missing env: {', '.join(missing_token_vars)}", + "recommended: set VERCEL_TOKEN, VERCEL_PROJECT_ID, and VERCEL_TEAM_ID together", + ), + ) + + return VercelAuthStatus( + False, + "not configured", + ( + "recommended: set VERCEL_TOKEN, VERCEL_PROJECT_ID, and VERCEL_TEAM_ID", + "development-only alternative: set VERCEL_OIDC_TOKEN", + ), + ) diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index d8ba44d850..570a0a7a88 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -23,7 +23,7 @@ import time import urllib.parse import urllib.request from pathlib import Path -from typing import Any, Dict, List, Optional +from typing 
import Any, Dict, List, Optional, Tuple import yaml @@ -33,6 +33,7 @@ if str(PROJECT_ROOT) not in sys.path: from hermes_cli import __version__, __release_date__ from hermes_cli.config import ( + cfg_get, DEFAULT_CONFIG, OPTIONAL_ENV_VARS, get_config_path, @@ -252,7 +253,12 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = { "terminal.backend": { "type": "select", "description": "Terminal execution backend", - "options": ["local", "docker", "ssh", "modal", "daytona", "singularity"], + "options": ["local", "docker", "ssh", "modal", "daytona", "vercel_sandbox", "singularity"], + }, + "terminal.vercel_runtime": { + "type": "select", + "description": "Vercel Sandbox runtime", + "options": ["node24", "node22", "python3.13"], # sync with _SUPPORTED_VERCEL_RUNTIMES in terminal_tool.py }, "terminal.modal_mode": { "type": "select", @@ -338,6 +344,11 @@ _CATEGORY_MERGE: Dict[str, str] = { "human_delay": "display", "dashboard": "display", "code_execution": "agent", + "prompt_caching": "agent", + # Only `telegram.reactions` currently lives under telegram — fold it in + # with the other messaging-platform config (discord) so it isn't an + # orphan tab of one field. + "telegram": "discord", } # Display order for tabs — unlisted categories sort alphabetically after these. @@ -434,6 +445,20 @@ class EnvVarReveal(BaseModel): key: str +class ModelAssignment(BaseModel): + """Payload for POST /api/model/set — assign a provider/model to a slot. 
+ + scope="main" → writes model.provider + model.default + scope="auxiliary" → writes auxiliary..provider + auxiliary..model + scope="auxiliary" with task="" → applied to every auxiliary.* slot + scope="auxiliary" with task="__reset__" → resets every slot to provider="auto" + """ + scope: str + provider: str + model: str + task: str = "" + + _GATEWAY_HEALTH_URL = os.getenv("GATEWAY_HEALTH_URL") try: _GATEWAY_HEALTH_TIMEOUT = float(os.getenv("GATEWAY_HEALTH_TIMEOUT", "3")) @@ -910,6 +935,207 @@ def get_model_info(): return dict(_EMPTY_MODEL_INFO) +# --------------------------------------------------------------------------- +# Model assignment — pick provider+model for main slot or auxiliary slots. +# Mirrors the model.options JSON-RPC from tui_gateway but uses REST so the +# Models page (which has no chat PTY open) can drive it. +# --------------------------------------------------------------------------- + +# Canonical auxiliary task slots. Keep in sync with DEFAULT_CONFIG["auxiliary"] +# in hermes_cli/config.py — listed here for deterministic ordering in the UI. +_AUX_TASK_SLOTS: Tuple[str, ...] = ( + "vision", + "web_extract", + "compression", + "session_search", + "skills_hub", + "approval", + "mcp", + "title_generation", + "curator", +) + + +@app.get("/api/model/options") +def get_model_options(): + """Return authenticated providers + their curated model lists. + + REST equivalent of the ``model.options`` JSON-RPC on tui_gateway, so the + dashboard Models page can render the picker without a live chat session. + The response shape matches ``model.options`` 1:1 so ``ModelPickerDialog`` + can share the same types. 
+ """ + try: + from hermes_cli.model_switch import list_authenticated_providers + + cfg = load_config() + model_cfg = cfg.get("model", {}) + if isinstance(model_cfg, dict): + current_model = model_cfg.get("default", model_cfg.get("name", "")) or "" + current_provider = model_cfg.get("provider", "") or "" + current_base_url = model_cfg.get("base_url", "") or "" + else: + current_model = str(model_cfg) if model_cfg else "" + current_provider = "" + current_base_url = "" + + user_providers = cfg.get("providers") if isinstance(cfg.get("providers"), dict) else {} + custom_providers = ( + cfg.get("custom_providers") + if isinstance(cfg.get("custom_providers"), list) + else [] + ) + + providers = list_authenticated_providers( + current_provider=current_provider, + current_base_url=current_base_url, + current_model=current_model, + user_providers=user_providers, + custom_providers=custom_providers, + max_models=50, + ) + return { + "providers": providers, + "model": current_model, + "provider": current_provider, + } + except Exception: + _log.exception("GET /api/model/options failed") + raise HTTPException(status_code=500, detail="Failed to list model options") + + +@app.get("/api/model/auxiliary") +def get_auxiliary_models(): + """Return current auxiliary task assignments. + + Shape: + { + "tasks": [ + {"task": "vision", "provider": "auto", "model": "", "base_url": ""}, + ... 
+ ], + "main": {"provider": "openrouter", "model": "anthropic/claude-opus-4.7"}, + } + """ + try: + cfg = load_config() + aux_cfg = cfg.get("auxiliary", {}) + if not isinstance(aux_cfg, dict): + aux_cfg = {} + + tasks = [] + for slot in _AUX_TASK_SLOTS: + slot_cfg = aux_cfg.get(slot, {}) if isinstance(aux_cfg.get(slot), dict) else {} + tasks.append({ + "task": slot, + "provider": str(slot_cfg.get("provider", "auto") or "auto"), + "model": str(slot_cfg.get("model", "") or ""), + "base_url": str(slot_cfg.get("base_url", "") or ""), + }) + + model_cfg = cfg.get("model", {}) + if isinstance(model_cfg, dict): + main = { + "provider": str(model_cfg.get("provider", "") or ""), + "model": str(model_cfg.get("default", model_cfg.get("name", "")) or ""), + } + else: + main = {"provider": "", "model": str(model_cfg) if model_cfg else ""} + + return {"tasks": tasks, "main": main} + except Exception: + _log.exception("GET /api/model/auxiliary failed") + raise HTTPException(status_code=500, detail="Failed to read auxiliary config") + + +@app.post("/api/model/set") +async def set_model_assignment(body: ModelAssignment): + """Assign a model to the main slot or an auxiliary task slot. + + Writes to ``~/.hermes/config.yaml`` — applies to **new** sessions only. + The currently running chat PTY (if any) is not affected; use the + ``/model`` slash command inside a chat to hot-swap that specific session. 
+ """ + scope = (body.scope or "").strip().lower() + provider = (body.provider or "").strip() + model = (body.model or "").strip() + task = (body.task or "").strip().lower() + + if scope not in ("main", "auxiliary"): + raise HTTPException(status_code=400, detail="scope must be 'main' or 'auxiliary'") + + try: + cfg = load_config() + + if scope == "main": + if not provider or not model: + raise HTTPException(status_code=400, detail="provider and model required for main") + model_cfg = cfg.get("model", {}) + if not isinstance(model_cfg, dict): + model_cfg = {} + model_cfg["provider"] = provider + model_cfg["default"] = model + # Clear stale base_url so the resolver picks the provider's own default. + if "base_url" in model_cfg and model_cfg.get("base_url"): + model_cfg["base_url"] = "" + # Also clear hardcoded context_length override — new model may have + # a different context window. + if "context_length" in model_cfg: + model_cfg.pop("context_length", None) + cfg["model"] = model_cfg + save_config(cfg) + return {"ok": True, "scope": "main", "provider": provider, "model": model} + + # scope == "auxiliary" + aux = cfg.get("auxiliary") + if not isinstance(aux, dict): + aux = {} + + if task == "__reset__": + # Reset every slot to provider="auto", model="" — keeps other fields intact. 
+ for slot in _AUX_TASK_SLOTS: + slot_cfg = aux.get(slot) + if not isinstance(slot_cfg, dict): + slot_cfg = {} + slot_cfg["provider"] = "auto" + slot_cfg["model"] = "" + aux[slot] = slot_cfg + cfg["auxiliary"] = aux + save_config(cfg) + return {"ok": True, "scope": "auxiliary", "reset": True} + + if not provider: + raise HTTPException(status_code=400, detail="provider required for auxiliary") + + targets = [task] if task else list(_AUX_TASK_SLOTS) + for slot in targets: + if slot not in _AUX_TASK_SLOTS: + raise HTTPException(status_code=400, detail=f"unknown auxiliary task: {slot}") + slot_cfg = aux.get(slot) + if not isinstance(slot_cfg, dict): + slot_cfg = {} + slot_cfg["provider"] = provider + slot_cfg["model"] = model + aux[slot] = slot_cfg + + cfg["auxiliary"] = aux + save_config(cfg) + return { + "ok": True, + "scope": "auxiliary", + "tasks": targets, + "provider": provider, + "model": model, + } + except HTTPException: + raise + except Exception: + _log.exception("POST /api/model/set failed") + raise HTTPException(status_code=500, detail="Failed to save model assignment") + + + + def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: """Reverse _normalize_config_for_web before saving. @@ -1214,6 +1440,14 @@ _OAUTH_PROVIDER_CATALOG: tuple[Dict[str, Any], ...] 
= ( "docs_url": "https://github.com/QwenLM/qwen-code", "status_fn": None, # dispatched via auth.get_qwen_auth_status }, + { + "id": "minimax-oauth", + "name": "MiniMax (OAuth)", + "flow": "pkce", + "cli_command": "hermes auth add minimax-oauth", + "docs_url": "https://www.minimax.io", + "status_fn": None, # dispatched via auth.get_minimax_oauth_auth_status + }, ) @@ -1257,6 +1491,16 @@ def _resolve_provider_status(provider_id: str, status_fn) -> Dict[str, Any]: "expires_at": raw.get("expires_at"), "has_refresh_token": bool(raw.get("has_refresh_token")), } + if provider_id == "minimax-oauth": + raw = hauth.get_minimax_oauth_auth_status() + return { + "logged_in": bool(raw.get("logged_in")), + "source": "minimax_oauth", + "source_label": f"MiniMax ({raw.get('region', 'global')})", + "token_preview": None, + "expires_at": raw.get("expires_at"), + "has_refresh_token": True, + } except Exception as e: return {"logged_in": False, "error": str(e)} return {"logged_in": False} @@ -2245,12 +2489,13 @@ async def open_profile_terminal_endpoint(name: str): command = _profile_setup_command(name) if sys.platform.startswith("win"): - subprocess.Popen(["cmd.exe", "/k", command]) + subprocess.Popen(["cmd.exe", "/c", "start", "", command]) elif sys.platform == "darwin": + escaped = command.replace("\\", "\\\\").replace('"', '\\"') applescript = ( 'tell application "Terminal"\n' "activate\n" - f'do script "{command.replace("\\\\", "\\\\\\\\").replace(\'"\', \'\\\\"\')}"\n' + f'do script "{escaped}"\n' "end tell" ) subprocess.Popen(["osascript", "-e", applescript]) @@ -2517,6 +2762,99 @@ async def get_usage_analytics(days: int = 30): db.close() +@app.get("/api/analytics/models") +async def get_models_analytics(days: int = 30): + """Rich per-model analytics for the Models dashboard page. + + Returns token/cost/session breakdown per model plus capability metadata + from models.dev (context window, vision, tools, reasoning, etc.). 
+ """ + from hermes_state import SessionDB + + db = SessionDB() + try: + cutoff = time.time() - (days * 86400) + + cur = db._conn.execute(""" + SELECT model, + billing_provider, + SUM(input_tokens) as input_tokens, + SUM(output_tokens) as output_tokens, + SUM(cache_read_tokens) as cache_read_tokens, + SUM(reasoning_tokens) as reasoning_tokens, + COALESCE(SUM(estimated_cost_usd), 0) as estimated_cost, + COALESCE(SUM(actual_cost_usd), 0) as actual_cost, + COUNT(*) as sessions, + SUM(COALESCE(api_call_count, 0)) as api_calls, + SUM(tool_call_count) as tool_calls, + MAX(started_at) as last_used_at, + AVG(input_tokens + output_tokens) as avg_tokens_per_session + FROM sessions WHERE started_at > ? AND model IS NOT NULL AND model != '' + GROUP BY model, billing_provider + ORDER BY SUM(input_tokens) + SUM(output_tokens) DESC + """, (cutoff,)) + rows = [dict(r) for r in cur.fetchall()] + + models = [] + for row in rows: + provider = row.get("billing_provider") or "" + model_name = row["model"] + caps = {} + try: + from agent.models_dev import get_model_capabilities + mc = get_model_capabilities(provider=provider, model=model_name) + if mc is not None: + caps = { + "supports_tools": mc.supports_tools, + "supports_vision": mc.supports_vision, + "supports_reasoning": mc.supports_reasoning, + "context_window": mc.context_window, + "max_output_tokens": mc.max_output_tokens, + "model_family": mc.model_family, + } + except Exception: + pass + + models.append({ + "model": model_name, + "provider": provider, + "input_tokens": row["input_tokens"], + "output_tokens": row["output_tokens"], + "cache_read_tokens": row["cache_read_tokens"], + "reasoning_tokens": row["reasoning_tokens"], + "estimated_cost": row["estimated_cost"], + "actual_cost": row["actual_cost"], + "sessions": row["sessions"], + "api_calls": row["api_calls"], + "tool_calls": row["tool_calls"], + "last_used_at": row["last_used_at"], + "avg_tokens_per_session": row["avg_tokens_per_session"], + "capabilities": caps, + }) + 
+ totals_cur = db._conn.execute(""" + SELECT COUNT(DISTINCT model) as distinct_models, + SUM(input_tokens) as total_input, + SUM(output_tokens) as total_output, + SUM(cache_read_tokens) as total_cache_read, + SUM(reasoning_tokens) as total_reasoning, + COALESCE(SUM(estimated_cost_usd), 0) as total_estimated_cost, + COALESCE(SUM(actual_cost_usd), 0) as total_actual_cost, + COUNT(*) as total_sessions, + SUM(COALESCE(api_call_count, 0)) as total_api_calls + FROM sessions WHERE started_at > ? AND model IS NOT NULL AND model != '' + """, (cutoff,)) + totals = dict(totals_cur.fetchone()) + + return { + "models": models, + "totals": totals, + "period_days": days, + } + finally: + db.close() + + # --------------------------------------------------------------------------- # /api/pty — PTY-over-WebSocket bridge for the dashboard "Chat" tab. # @@ -3149,7 +3487,7 @@ async def get_dashboard_themes(): them without a stub. """ config = load_config() - active = config.get("dashboard", {}).get("theme", "default") + active = cfg_get(config, "dashboard", "theme", default="default") user_themes = _discover_user_themes() seen = set() themes = [] @@ -3199,10 +3537,12 @@ def _discover_dashboard_plugins() -> list: plugins = [] seen_names: set = set() + from hermes_cli.plugins import get_bundled_plugins_dir + bundled_root = get_bundled_plugins_dir() search_dirs = [ (get_hermes_home() / "plugins", "user"), - (PROJECT_ROOT / "plugins" / "memory", "bundled"), - (PROJECT_ROOT / "plugins", "bundled"), + (bundled_root / "memory", "bundled"), + (bundled_root, "bundled"), ] if os.environ.get("HERMES_ENABLE_PROJECT_PLUGINS"): search_dirs.append((Path.cwd() / ".hermes" / "plugins", "project")) @@ -3347,13 +3687,23 @@ def _mount_plugin_api_routes(): _log.warning("Plugin %s declares api=%s but file not found", plugin["name"], api_file_name) continue try: - spec = importlib.util.spec_from_file_location( - f"hermes_dashboard_plugin_{plugin['name']}", api_path, - ) + module_name = 
f"hermes_dashboard_plugin_{plugin['name']}" + spec = importlib.util.spec_from_file_location(module_name, api_path) if spec is None or spec.loader is None: continue mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) + # Register in sys.modules BEFORE exec_module so pydantic/FastAPI + # can resolve forward references (e.g. models defined in a file + # that uses `from __future__ import annotations`). Without this, + # TypeAdapter lazy-build fails at first request with + # "is not fully defined" because the module namespace isn't + # reachable by name for string-annotation resolution. + sys.modules[module_name] = mod + try: + spec.loader.exec_module(mod) + except Exception: + sys.modules.pop(module_name, None) + raise router = getattr(mod, "router", None) if router is None: _log.warning("Plugin %s api file has no 'router' attribute", plugin["name"]) diff --git a/hermes_cli/webhook.py b/hermes_cli/webhook.py index 0ec4c6784b..4b74204bcc 100644 --- a/hermes_cli/webhook.py +++ b/hermes_cli/webhook.py @@ -19,6 +19,7 @@ from typing import Dict from hermes_constants import display_hermes_home from utils import atomic_replace +from hermes_cli.config import cfg_get _SUBSCRIPTIONS_FILENAME = "webhook_subscriptions.json" @@ -60,7 +61,7 @@ def _get_webhook_config() -> dict: try: from hermes_cli.config import load_config cfg = load_config() - return cfg.get("platforms", {}).get("webhook", {}) + return cfg_get(cfg, "platforms", "webhook", default={}) except Exception: return {} diff --git a/model_tools.py b/model_tools.py index d85a8b8efd..25830d2e5a 100644 --- a/model_tools.py +++ b/model_tools.py @@ -107,17 +107,58 @@ def _run_async(coro): loop = None if loop and loop.is_running(): - # Inside an async context (gateway, RL env) — run in a fresh thread. 
+ # Inside an async context (gateway, RL env) — run in a fresh thread + # with its own event loop we own a reference to, so on timeout we + # can cancel the task inside that loop (ThreadPoolExecutor.cancel() + # only works on not-yet-started futures — it's a no-op on a running + # worker, which previously leaked the thread on every 300 s timeout). import concurrent.futures + + worker_loop: Optional[asyncio.AbstractEventLoop] = None + loop_ready = threading.Event() + + def _run_in_worker(): + nonlocal worker_loop + worker_loop = asyncio.new_event_loop() + loop_ready.set() + try: + asyncio.set_event_loop(worker_loop) + return worker_loop.run_until_complete(coro) + finally: + try: + # Cancel anything still pending (e.g. task cancelled + # externally via call_soon_threadsafe on timeout). + pending = asyncio.all_tasks(worker_loop) + for t in pending: + t.cancel() + if pending: + worker_loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + except Exception: + pass + worker_loop.close() + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) - future = pool.submit(asyncio.run, coro) + future = pool.submit(_run_in_worker) try: return future.result(timeout=300) except concurrent.futures.TimeoutError: - future.cancel() + # Cancel the coroutine inside its own loop so the worker thread + # can wind down instead of running forever. + if loop_ready.wait(timeout=1.0) and worker_loop is not None: + try: + for t in asyncio.all_tasks(worker_loop): + worker_loop.call_soon_threadsafe(t.cancel) + except RuntimeError: + # Loop already closed — nothing to cancel. + pass raise finally: - pool.shutdown(wait=False, cancel_futures=True) + # wait=False: don't block the caller on a stuck coroutine. We've + # already requested cancellation above; the worker will exit + # once the coroutine observes it (usually at the next await). 
+ pool.shutdown(wait=False) # If we're on a worker thread (e.g., parallel tool execution in # delegate_task), use a per-thread persistent loop. This avoids @@ -627,6 +668,13 @@ def handle_function_call( # Check plugin hooks for a block directive (unless caller already # checked — e.g. run_agent._invoke_tool passes skip=True to # avoid double-firing the hook). + # + # Single-fire contract: pre_tool_call fires exactly once per tool + # execution. get_pre_tool_call_block_message() internally calls + # invoke_hook("pre_tool_call", ...) and returns the first block + # directive (if any), so observer plugins see the hook on that same + # pass. When skip=True, the caller already fired it — do nothing + # here. if not skip_pre_tool_call_hook: block_message: Optional[str] = None try: @@ -643,21 +691,6 @@ def handle_function_call( if block_message is not None: return json.dumps({"error": block_message}, ensure_ascii=False) - else: - # Still fire the hook for observers — just don't check for blocking - # (the caller already did that). - try: - from hermes_cli.plugins import invoke_hook - invoke_hook( - "pre_tool_call", - tool_name=function_name, - args=function_args, - task_id=task_id or "", - session_id=session_id or "", - tool_call_id=tool_call_id or "", - ) - except Exception: - pass # Notify the read-loop tracker when a non-read/search tool runs, # so the *consecutive* counter resets (reads after other work are fine). 
@@ -737,7 +770,7 @@ def handle_function_call( except Exception as e: error_msg = f"Error executing {function_name}: {str(e)}" - logger.error(error_msg) + logger.exception(error_msg) return json.dumps({"error": error_msg}, ensure_ascii=False) diff --git a/nix/checks.nix b/nix/checks.nix index cf11082b98..bb8801a0b8 100644 --- a/nix/checks.nix +++ b/nix/checks.nix @@ -124,6 +124,26 @@ json.dump(sorted(leaf_paths(DEFAULT_CONFIG)), sys.stdout, indent=2) echo "ok" > $out/result ''; + # Verify bundled plugins (platforms, memory, context_engine) are present + bundled-plugins = pkgs.runCommand "hermes-bundled-plugins" { } '' + set -e + echo "=== Checking bundled plugins ===" + test -d ${hermes-agent}/share/hermes-agent/plugins || (echo "FAIL: plugins directory missing"; exit 1) + echo "PASS: plugins directory exists" + + test -f ${hermes-agent}/share/hermes-agent/plugins/platforms/irc/plugin.yaml || \ + (echo "FAIL: irc plugin manifest missing"; exit 1) + echo "PASS: irc plugin manifest present" + + grep -q "HERMES_BUNDLED_PLUGINS" ${hermes-agent}/bin/hermes || \ + (echo "FAIL: HERMES_BUNDLED_PLUGINS not in wrapper"; exit 1) + echo "PASS: HERMES_BUNDLED_PLUGINS set in wrapper" + + echo "=== All bundled plugins checks passed ===" + mkdir -p $out + echo "ok" > $out/result + ''; + # Verify bundled TUI is present and compiled bundled-tui = pkgs.runCommand "hermes-bundled-tui" { } '' set -e diff --git a/nix/hermes-agent.nix b/nix/hermes-agent.nix index 85ba71fb13..886bb1aadb 100644 --- a/nix/hermes-agent.nix +++ b/nix/hermes-agent.nix @@ -19,6 +19,10 @@ pyproject-nix, pyproject-build-systems, npm-lockfile-fix, + # Locked git revision of the flake source — embedded so banner.py can + # check for updates without needing a local .git directory. Null for + # impure / dirty builds where flakes can't determine a rev. + rev ? null, # Overridable parameters extraPythonPackages ? 
[ ], }: @@ -44,6 +48,14 @@ let filter = path: _type: !(lib.hasInfix "/index-cache/" path); }; + # Import bundled plugins (memory, context_engine, platforms/*). Keeping + # them out of the Python site-packages keeps import semantics identical + # to a dev checkout — the loader reads them from HERMES_BUNDLED_PLUGINS. + bundledPlugins = lib.cleanSourceWith { + src = ../plugins; + filter = path: _type: !(lib.hasInfix "/__pycache__/" path); + }; + runtimeDeps = [ nodejs_22 ripgrep @@ -84,6 +96,7 @@ stdenv.mkDerivation { mkdir -p $out/share/hermes-agent $out/bin cp -r ${bundledSkills} $out/share/hermes-agent/skills + cp -r ${bundledPlugins} $out/share/hermes-agent/plugins cp -r ${hermesWeb} $out/share/hermes-agent/web_dist mkdir -p $out/ui-tui @@ -94,10 +107,12 @@ stdenv.mkDerivation { makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \ --suffix PATH : "${runtimePath}" \ --set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills \ + --set HERMES_BUNDLED_PLUGINS $out/share/hermes-agent/plugins \ --set HERMES_WEB_DIST $out/share/hermes-agent/web_dist \ --set HERMES_TUI_DIR $out/ui-tui \ --set HERMES_PYTHON ${hermesVenv}/bin/python3 \ --set HERMES_NODE ${nodejs_22}/bin/node \ + ${lib.optionalString (rev != null) ''--set HERMES_REVISION ${rev} \''} ${lib.optionalString (extraPythonPackages != [ ]) ''--suffix PYTHONPATH : "${pythonPath}"''} '') [ diff --git a/nix/nixosModules.nix b/nix/nixosModules.nix index 863ebd6ed5..fbff28e18b 100644 --- a/nix/nixosModules.nix +++ b/nix/nixosModules.nix @@ -647,6 +647,16 @@ }]; } + # ── Assertions ───────────────────────────────────────────────────── + { + assertions = let + names = map lib.getName cfg.extraPlugins; + in [{ + assertion = (lib.length names) == (lib.length (lib.unique names)); + message = "services.hermes-agent.extraPlugins: duplicate plugin names detected: ${toString names}. 
If using fetchFromGitHub, set name = \"plugin-name\" to disambiguate."; + }]; + } + # ── Warnings ────────────────────────────────────────────────────── # ── Per-user profile for extraPackages ─────────────────────────── # Wire extraPackages into the hermes user's per-user profile so the @@ -730,12 +740,12 @@ # is disabled so the host CLI falls back to native execution. ${if cfg.container.enable then '' cat > ${cfg.stateDir}/.hermes/.container-mode <<'HERMES_CONTAINER_MODE_EOF' -# Written by NixOS activation script. Do not edit manually. -backend=${cfg.container.backend} -container_name=${containerName} -exec_user=${cfg.user} -hermes_bin=${containerDataDir}/current-package/bin/hermes -HERMES_CONTAINER_MODE_EOF + # Written by NixOS activation script. Do not edit manually. + backend=${cfg.container.backend} + container_name=${containerName} + exec_user=${cfg.user} + hermes_bin=${containerDataDir}/current-package/bin/hermes + HERMES_CONTAINER_MODE_EOF chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.container-mode chmod 0644 ${cfg.stateDir}/.hermes/.container-mode '' else '' @@ -796,8 +806,8 @@ HERMES_CONTAINER_MODE_EOF ENV_FILE="${cfg.stateDir}/.hermes/.env" install -o ${cfg.user} -g ${cfg.group} -m 0640 /dev/null "$ENV_FILE" cat > "$ENV_FILE" <<'HERMES_NIX_ENV_EOF' -${envFileContent} -HERMES_NIX_ENV_EOF + ${envFileContent} + HERMES_NIX_ENV_EOF ${lib.concatStringsSep "\n" (map (f: '' if [ -f "${f}" ]; then echo "" >> "$ENV_FILE" diff --git a/nix/overlays.nix b/nix/overlays.nix index 4d7bb2a121..474e57d852 100644 --- a/nix/overlays.nix +++ b/nix/overlays.nix @@ -5,6 +5,7 @@ hermes-agent = final.callPackage ./hermes-agent.nix { inherit (inputs) uv2nix pyproject-nix pyproject-build-systems; npm-lockfile-fix = inputs.npm-lockfile-fix.packages.${final.stdenv.hostPlatform.system}.default; + rev = inputs.self.rev or null; }; }; } diff --git a/nix/packages.nix b/nix/packages.nix index f27c43a75e..d95133d26a 100644 --- a/nix/packages.nix +++ b/nix/packages.nix @@ -7,6 
+7,9 @@ hermesAgent = pkgs.callPackage ./hermes-agent.nix { inherit (inputs) uv2nix pyproject-nix pyproject-build-systems; npm-lockfile-fix = inputs'.npm-lockfile-fix.packages.default; + # Only embed clean revs — dirtyRev doesn't represent any upstream + # commit, so comparing it would always claim "update available". + rev = inputs.self.rev or null; }; in { diff --git a/nix/tui.nix b/nix/tui.nix index 7453fa2673..4d27dde798 100644 --- a/nix/tui.nix +++ b/nix/tui.nix @@ -4,7 +4,7 @@ let src = ../ui-tui; npmDeps = pkgs.fetchNpmDeps { inherit src; - hash = "sha256-Chz+NW9NXqboXHOa6PKwf5bhAkkcFtKNhvKWwg2XSPc="; + hash = "sha256-a/HGI9OgVcTnZrMXA7xFMGnFoVxyHe95fulVz+WNYB0="; }; npm = hermesNpmLib.mkNpmPassthru { folder = "ui-tui"; attr = "tui"; pname = "hermes-tui"; }; diff --git a/nix/web.nix b/nix/web.nix index bff29983d6..7084a04c8e 100644 --- a/nix/web.nix +++ b/nix/web.nix @@ -4,7 +4,7 @@ let src = ../web; npmDeps = pkgs.fetchNpmDeps { inherit src; - hash = "sha256-+B2+Fe4djPzHHcUXRx+m0cuyaopAhW0PcHsMgYfV5VE="; + hash = "sha256-HWB1piIPglTXbzQHXFYHLgVZIbDb60esupXSQGa1+lI="; }; npm = hermesNpmLib.mkNpmPassthru { folder = "web"; attr = "web"; pname = "hermes-web"; }; diff --git a/plugins/hermes-achievements/LICENSE b/plugins/hermes-achievements/LICENSE new file mode 100644 index 0000000000..2312b92352 --- /dev/null +++ b/plugins/hermes-achievements/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Hermes Achievements contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or 
substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/plugins/hermes-achievements/README.md b/plugins/hermes-achievements/README.md new file mode 100644 index 0000000000..dd360197e8 --- /dev/null +++ b/plugins/hermes-achievements/README.md @@ -0,0 +1,148 @@ +# Hermes Achievements + +> **Bundled with Hermes Agent.** Originally authored by [@PCinkusz](https://github.com/PCinkusz) at https://github.com/PCinkusz/hermes-achievements — vendored into `plugins/hermes-achievements/` so it ships with the dashboard out-of-the-box and stays in lockstep with Hermes feature changes. Upstream repo remains the staging ground for new badges and UI iteration. +> +> When Hermes is installed via `pip install hermes-agent` or cloned from source, this plugin auto-registers as a dashboard tab on first `hermes dashboard` launch. No separate install step. See [Built-in Plugins → hermes-achievements](../../website/docs/user-guide/features/built-in-plugins.md) in the main docs. + +Achievement system for the Hermes Dashboard: collectible, tiered badges generated from real local Hermes session history. + +![Hermes Achievements dashboard](docs/assets/achievements-dashboard-hd.png) + +The screenshots use temporary demo tier data to show the full visual range. The plugin itself reads real local Hermes session history by default. + +> **Update notice (2026-04-29):** If you installed this plugin before today, update to the latest version. 
The achievements scan path was refactored for much faster warm loads (snapshot cache + incremental checkpoint scan). + +## What it does + +Hermes Achievements scans local Hermes sessions and unlocks badges based on real agent behavior: + +- autonomous tool chains +- debugging and recovery patterns +- vibe-coding file edits +- Hermes-native skills, memory, cron, and plugin usage +- web research and browser automation +- model/provider workflows +- lifestyle patterns such as weekend or night sessions + +Achievements have three visible states: + +- **Unlocked** — earned at least one tier +- **Discovered** — known achievement, progress visible, not earned yet +- **Secret** — hidden until Hermes detects the first related signal + +Most achievements level through: + +```text +Copper → Silver → Gold → Diamond → Olympian +``` + +Each card has a collapsible **What counts** section showing the exact tracked metric or requirement once the user wants details. + +Version `0.2.x` expands the catalog to 60+ achievements, including model/provider badges such as **Five-Model Flight**, **Provider Polyglot**, **Claude Confidant**, **Gemini Cartographer**, and **Open Weights Pilgrim**. 
+ +## Examples + +- Let Him Cook +- Toolchain Maxxer +- Red Text Connoisseur +- Port 3000 Is Taken +- This Was Supposed To Be Quick +- One More Small Change +- Skillsmith +- Memory Keeper +- Context Dragon +- Plugin Goblin +- Rabbit Hole Certified + +## Install + +Clone into your Hermes plugins directory: + +```bash +git clone https://github.com/PCinkusz/hermes-achievements ~/.hermes/plugins/hermes-achievements +``` + +For local development, keep the repo elsewhere and symlink it: + +```bash +git clone https://github.com/PCinkusz/hermes-achievements ~/hermes-achievements +ln -s ~/hermes-achievements ~/.hermes/plugins/hermes-achievements +``` + +Then rescan dashboard plugins: + +```bash +curl http://127.0.0.1:9119/api/dashboard/plugins/rescan +``` + +If backend API routes 404, restart `hermes dashboard`; plugin APIs are mounted at dashboard startup. + +## Updating + +If you installed with git: + +```bash +cd ~/.hermes/plugins/hermes-achievements +git pull --ff-only +curl http://127.0.0.1:9119/api/dashboard/plugins/rescan +``` + +If the update changes backend routes or `plugin_api.py`, restart `hermes dashboard` after pulling. + +As of 2026-04-29, updating is strongly recommended because scan performance changed significantly: +- removed duplicate `/overview` scan path +- added cached `/achievements` snapshot +- added incremental checkpoint reuse for unchanged sessions + +Achievement unlock state is stored locally in `state.json` and is not overwritten by git updates. New achievements are evaluated from your existing Hermes session history. Achievement IDs are stable and should not be renamed casually because they are the unlock-state keys. 
+ +Releases are tagged in git, for example: + +```bash +git fetch --tags +git checkout v0.2.0 +``` + +## Files + +```text +dashboard/ +├── manifest.json +├── plugin_api.py +└── dist/ + ├── index.js + └── style.css +``` + +## API + +Routes are mounted under: + +```text +/api/plugins/hermes-achievements/ +``` + +Endpoints: + +```text +GET /achievements +GET /scan-status +GET /recent-unlocks +GET /sessions/{session_id}/badges +POST /rescan +POST /reset-state +``` + +## Development + +Run checks: + +```bash +node --check dashboard/dist/index.js +python3 -m py_compile dashboard/plugin_api.py +python3 -m unittest tests/test_achievement_engine.py -v +``` + +## License + +MIT diff --git a/plugins/hermes-achievements/dashboard/dist/index.js b/plugins/hermes-achievements/dashboard/dist/index.js new file mode 100644 index 0000000000..56b9427e84 --- /dev/null +++ b/plugins/hermes-achievements/dashboard/dist/index.js @@ -0,0 +1,351 @@ +(function () { + "use strict"; + // hermes-achievements dashboard plugin + // Originally authored by @PCinkusz — https://github.com/PCinkusz/hermes-achievements (MIT). + // Bundled into hermes-agent. Upstream repo remains the staging ground for new + // badges and UI iteration; the in-progress scan banner below is a small addition + // layered on top of the original dist bundle. 
+ const SDK = window.__HERMES_PLUGIN_SDK__; + if (!SDK || !window.__HERMES_PLUGINS__) return; + + const React = SDK.React; + const hooks = SDK.hooks; + const C = SDK.components; + const cn = SDK.utils.cn; + + const LUCIDE = {"flame":"","avalanche":"\n ","nodes":"\n \n \n \n ","rocket":"\n \n \n ","branch":"\n \n \n ","daemon":"\n ","clock":"\n ","warning":"\n \n ","wine":"\n \n \n ","scroll":"\n \n \n ","plug":"\n \n \n \n \n ","lock":"\n \n ","package_skull":"\n \n \n \n ","restart":"\n \n \n ","key":"\n ","colon":"\n ","container":"\n \n \n \n ","melting_clock":"\n \n ","pencil":"\n ","blueprint":"\n \n \n \n ","pixel":"\n \n \n \n ","ship":"\n \n \n \n ","spark_cursor":"\n \n \n \n ","needle":"","hammer_scroll":"\n \n ","anvil":"\n \n \n \n ","crystal":"\n \n ","palace":"\n \n \n \n \n ","dragon":"","antenna":"\n \n \n \n \n \n ","puzzle":"","rewind":"\n ","spiral":"\n \n \n \n ","quote":"\n ","compass":"\n ","browser":"\n \n ","terminal":"\n ","wand":"\n \n \n \n \n \n \n ","folder":"\n \n ","eye":"\n ","wave":"","swap":"\n \n \n ","router":"\n \n \n \n \n ","codex":"\n \n ","prism":"\n \n ","marathon":"\n \n ","calendar":"\n \n \n \n \n \n \n \n \n ","moon":"","cache":"\n \n ","secret":"\n \n "}; + + const tierClass = function (tier) { + return tier ? 
"ha-tier-" + tier.toLowerCase() : "ha-tier-pending"; + }; + + async function api(path, options) { + const url = "/api/plugins/hermes-achievements" + path; + const res = await fetch(url, options || {}); + if (!res.ok) { + const text = await res.text().catch(function () { return res.statusText; }); + throw new Error(res.status + ": " + text); + } + const text = await res.text(); + try { + return JSON.parse(text); + } catch (_) { + return null; + } + } + + function AchievementIcon({ icon }) { + const svg = LUCIDE[icon] || LUCIDE.secret; + const ref = React.useRef(null); + React.useEffect(function () { + if (!ref.current) return; + const el = ref.current; + while (el.firstChild) el.removeChild(el.firstChild); + try { + const doc = new DOMParser().parseFromString( + "" + svg + "", + "image/svg+xml" + ); + if (!doc.querySelector("parsererror")) { + Array.from(doc.documentElement.childNodes).forEach(function (n) { + el.appendChild(document.importNode(n, true)); + }); + } + } catch (_) {} + }, [svg]); + return React.createElement("svg", { + ref: ref, + className: "ha-lucide", + viewBox: "0 0 24 24", + fill: "none", + stroke: "currentColor", + strokeWidth: 2, + strokeLinecap: "round", + strokeLinejoin: "round", + "aria-hidden": "true", + }); + } + + function StatCard(props) { + return React.createElement(C.Card, { className: "ha-stat" }, + React.createElement(C.CardContent, { className: "ha-stat-content" }, + React.createElement("div", { className: "ha-stat-label" }, props.label), + React.createElement("div", { className: "ha-stat-value" }, props.value), + props.hint && React.createElement("div", { className: "ha-stat-hint" }, props.hint) + ) + ); + } + + function TierLegend() { + return React.createElement("div", { className: "ha-tier-legend" }, + ["Copper", "Silver", "Gold", "Diamond", "Olympian"].map(function (tier, index, arr) { + return React.createElement(React.Fragment, { key: tier }, + React.createElement("span", { className: "ha-tier-step ha-tier-" + 
tier.toLowerCase() }, + React.createElement("i", null), + tier + ), + index < arr.length - 1 && React.createElement("span", { className: "ha-tier-arrow" }, "→") + ); + }) + ); + } + + + function LoadingSkeletonCard(props) { + return React.createElement(C.Card, { className: "ha-card ha-skeleton-card ha-tier-pending" }, + React.createElement(C.CardContent, { className: "ha-card-content" }, + React.createElement("div", { className: "ha-card-head" }, + React.createElement("div", { className: "ha-skeleton ha-skeleton-icon" }), + React.createElement("div", { className: "ha-skeleton-stack" }, + React.createElement("div", { className: "ha-skeleton ha-skeleton-title" }), + React.createElement("div", { className: "ha-skeleton ha-skeleton-meta" }) + ), + React.createElement("div", { className: "ha-badges" }, + React.createElement("div", { className: "ha-skeleton ha-skeleton-badge" }), + React.createElement("div", { className: "ha-skeleton ha-skeleton-badge ha-skeleton-badge-short" }) + ) + ), + React.createElement("div", { className: "ha-skeleton ha-skeleton-line" }), + React.createElement("div", { className: "ha-skeleton ha-skeleton-line ha-skeleton-line-short" }), + React.createElement("div", { className: "ha-skeleton ha-skeleton-criteria" }), + React.createElement("div", { className: "ha-evidence-slot" }, React.createElement("div", { className: "ha-skeleton ha-skeleton-evidence" })), + React.createElement("div", { className: "ha-progress-row" }, + React.createElement("div", { className: "ha-skeleton ha-skeleton-progress" }), + React.createElement("div", { className: "ha-skeleton ha-skeleton-progress-text" }) + ) + ) + ); + } + + function LoadingPage() { + return React.createElement("div", { className: "ha-page ha-page-loading" }, + React.createElement("section", { className: "ha-hero ha-loading-hero" }, + React.createElement("div", null, + React.createElement("div", { className: "ha-kicker" }, "Agentic Gamerscore"), + React.createElement("h1", null, "Hermes Achievements"), 
+ React.createElement("p", null, "Scanning Hermes session history. First scan can take 5–10 seconds on large histories.") + ), + React.createElement("div", { className: "ha-scan-status", role: "status", "aria-live": "polite" }, + React.createElement("span", { className: "ha-scan-pulse", "aria-hidden": "true" }), + React.createElement("div", null, + React.createElement("strong", null, "Building achievement profile…"), + React.createElement("p", null, "Reading sessions, tool calls, model metadata, and unlock state.") + ) + ) + ), + React.createElement("div", { className: "ha-stats" }, + ["Unlocked", "Discovered", "Secrets", "Highest tier", "Latest"].map(function (label) { + return React.createElement(C.Card, { key: label, className: "ha-stat ha-skeleton-stat" }, + React.createElement(C.CardContent, { className: "ha-stat-content" }, + React.createElement("div", { className: "ha-stat-label" }, label), + React.createElement("div", { className: "ha-skeleton ha-skeleton-stat-value" }), + React.createElement("div", { className: "ha-skeleton ha-skeleton-stat-hint" }) + ) + ); + }) + ), + React.createElement("section", { className: "ha-guide ha-loading-guide" }, + React.createElement("div", null, + React.createElement("strong", null, "Scan status"), + React.createElement("p", null, "Hermes is scanning local history once, then cards will appear automatically. 
Nothing is stuck if this takes a few seconds.") + ), + React.createElement("div", null, + React.createElement("strong", null, "What is scanned"), + React.createElement("p", null, "Sessions, tool calls, model metadata, errors, achievements, and local unlock state.") + ) + ), + React.createElement("section", { className: "ha-grid" }, [0, 1, 2, 3, 4, 5].map(function (i) { + return React.createElement(LoadingSkeletonCard, { key: i }); + })) + ); + } + + + function AchievementCard({ achievement }) { + const unlocked = achievement.unlocked; + const progress = achievement.progress || 0; + const pct = achievement.progress_pct || (unlocked ? 100 : 0); + const state = achievement.state || (unlocked ? "unlocked" : "discovered"); + const stateLabel = state === "unlocked" ? "Unlocked" : (state === "secret" ? "Secret" : "Discovered"); + const targetTier = achievement.next_tier || achievement.tier; + const tierLabel = achievement.tier ? achievement.tier : (targetTier ? "Target " + targetTier : (state === "secret" ? "Hidden" : (unlocked ? "Complete" : "Objective"))); + const progressText = state === "secret" ? "hidden" : (progress + (achievement.next_threshold ? 
" / " + achievement.next_threshold : "")); + return React.createElement(C.Card, { className: cn("ha-card", "ha-state-" + state, tierClass(achievement.tier || achievement.next_tier)) }, + React.createElement(C.CardContent, { className: "ha-card-content" }, + React.createElement("div", { className: "ha-card-head" }, + React.createElement("div", { className: "ha-icon" }, React.createElement(AchievementIcon, { icon: achievement.icon || "secret" })), + React.createElement("div", { className: "ha-card-title-wrap" }, + React.createElement("div", { className: "ha-card-title" }, achievement.name), + React.createElement("div", { className: "ha-card-category" }, achievement.category) + ), + React.createElement("div", { className: "ha-badges" }, + React.createElement("span", { className: "ha-state-badge" }, stateLabel), + React.createElement("span", { className: "ha-tier-badge" }, tierLabel) + ) + ), + React.createElement("p", { className: "ha-description" }, achievement.description), + achievement.criteria && React.createElement("details", { className: "ha-criteria" }, + React.createElement("summary", null, state === "secret" ? "How to reveal" : "What counts"), + React.createElement("p", null, achievement.criteria) + ), + React.createElement("div", { className: "ha-evidence-slot" }, + achievement.evidence ? React.createElement("div", { className: "ha-evidence" }, + React.createElement("span", { className: "ha-evidence-label" }, "Evidence"), + React.createElement("span", { className: "ha-evidence-title" }, achievement.evidence.title || achievement.evidence.session_id || "session") + ) : React.createElement("div", { className: "ha-evidence ha-evidence-empty", "aria-hidden": "true" }, "No evidence yet") + ), + React.createElement("div", { className: "ha-progress-row" }, + React.createElement("div", { className: "ha-progress-track" }, + React.createElement("div", { className: "ha-progress-fill", style: { width: Math.max(state === "secret" ? 
0 : 3, Math.min(100, pct)) + "%" } }) + ), + React.createElement("span", { className: "ha-progress-text" }, progressText) + ) + ) + ); + } + + function AchievementsPage() { + const [data, setData] = hooks.useState(null); + const [loading, setLoading] = hooks.useState(true); + const [error, setError] = hooks.useState(null); + const [category, setCategory] = hooks.useState("All"); + const [visibility, setVisibility] = hooks.useState("all"); + + function load() { + setLoading(true); + api("/achievements") + .then(function (payload) { setData(payload); setError((payload && payload.error) || null); }) + .catch(function (err) { setError(String(err)); }) + .finally(function () { setLoading(false); }); + } + // refresh() re-fetches without flipping the loading state — used by the + // auto-poller during an in-progress background scan so the page updates + // with growing unlock counts instead of flashing the loading skeleton. + function refresh() { + api("/achievements") + .then(function (payload) { setData(payload); setError((payload && payload.error) || null); }) + .catch(function (err) { setError(String(err)); }); + } + hooks.useEffect(load, []); + + // Auto-poll while the backend is still scanning. scan_meta.mode is + // "pending" on the very first request (no cache yet) and "in_progress" + // while the background thread is publishing partial snapshots. Once it + // flips to "full" or "incremental" the scan is done and we stop polling. 
+ const scanMode = (data && data.scan_meta && data.scan_meta.mode) || null; + const scanInFlight = scanMode === "pending" || scanMode === "in_progress"; + hooks.useEffect(function () { + if (!scanInFlight) return undefined; + const id = setInterval(refresh, 4000); + return function () { clearInterval(id); }; + }, [scanInFlight]); + + const achievements = (data && data.achievements) || []; + const categories = ["All"].concat(Array.from(new Set(achievements.map(function (a) { return a.category; })))); + const visible = achievements.filter(function (a) { + if (category !== "All" && a.category !== category) return false; + if (visibility === "unlocked" && a.state !== "unlocked") return false; + if (visibility === "discovered" && a.state !== "discovered") return false; + if (visibility === "secret" && a.state !== "secret") return false; + return true; + }); + const unlocked = achievements.filter(function (a) { return a.state === "unlocked"; }); + const discovered = achievements.filter(function (a) { return a.state === "discovered"; }); + const secret = achievements.filter(function (a) { return a.state === "secret"; }); + const latest = unlocked.slice().sort(function (a, b) { return (b.unlocked_at || 0) - (a.unlocked_at || 0); }).slice(0, 5); + const highest = ["Olympian", "Diamond", "Gold", "Silver", "Copper"].find(function (tier) { return unlocked.some(function (a) { return a.tier === tier; }); }) || "None yet"; + + // Build the in-progress scan banner once so the JSX below stays readable. + // Shows nothing when the scan is idle. When a scan is running it renders + // a pulsing status row with "X / Y sessions · Z%" and a filling bar, so + // the user gets continuous visual feedback during long cold scans on + // large session databases (can take several minutes on 8000+ sessions). 
+ let scanBanner = null; + if (scanInFlight) { + const meta = (data && data.scan_meta) || {}; + const scanned = Number(meta.sessions_scanned_so_far || meta.sessions_total || 0); + const total = Number(meta.sessions_expected_total || 0); + const pct = total > 0 ? Math.max(0, Math.min(100, Math.floor((scanned / total) * 100))) : 0; + const headline = scanMode === "pending" + ? "Starting achievement scan…" + : "Building achievement profile…"; + const detail = total > 0 + ? ("Scanned " + scanned.toLocaleString() + " of " + total.toLocaleString() + " sessions · " + pct + "%. Badges unlock as more history streams in.") + : "Reading sessions, tool calls, model metadata, and unlock state. Badges appear here as they unlock."; + scanBanner = React.createElement("section", { className: "ha-scan-banner", role: "status", "aria-live": "polite" }, + React.createElement("div", { className: "ha-scan-banner-head" }, + React.createElement("span", { className: "ha-scan-pulse", "aria-hidden": "true" }), + React.createElement("div", { className: "ha-scan-banner-text" }, + React.createElement("strong", null, headline), + React.createElement("p", null, detail) + ) + ), + total > 0 && React.createElement("div", { className: "ha-scan-progress-track", role: "progressbar", "aria-valuemin": 0, "aria-valuemax": 100, "aria-valuenow": pct }, + React.createElement("div", { className: "ha-scan-progress-fill", style: { width: pct + "%" } }) + ) + ); + } + + if (loading) { + return React.createElement(LoadingPage, null); + } + + return React.createElement("div", { className: "ha-page" }, + React.createElement("section", { className: "ha-hero" }, + React.createElement("div", null, + React.createElement("div", { className: "ha-kicker" }, "Agentic Gamerscore"), + React.createElement("h1", null, "Hermes Achievements"), + React.createElement("p", null, "Collectible Hermes badges earned from real session history. 
Known unfinished achievements are shown as Discovered; Secret achievements stay hidden until the first matching behavior appears.") + ), + React.createElement(C.Button, { onClick: load, className: "ha-refresh" }, "Rescan") + ), + scanBanner, + error && React.createElement(C.Card, { className: "ha-error" }, React.createElement(C.CardContent, null, String(error))), + React.createElement("div", { className: "ha-stats" }, + React.createElement(StatCard, { label: "Unlocked", value: (data ? data.unlocked_count : 0) + " / " + (data ? data.total_count : 0), hint: "earned badges" }), + React.createElement(StatCard, { label: "Discovered", value: discovered.length, hint: "known, not earned yet" }), + React.createElement(StatCard, { label: "Secrets", value: secret.length, hint: "hidden until first signal" }), + React.createElement(StatCard, { label: "Highest tier", value: highest, hint: "Copper → Silver → Gold → Diamond → Olympian" }), + React.createElement(StatCard, { label: "Latest", value: latest[0] ? latest[0].name : "None yet", hint: latest[0] ? latest[0].category : "run Hermes more" }) + ), + React.createElement("section", { className: "ha-guide" }, + React.createElement("div", null, + React.createElement("strong", null, "Tiers"), + React.createElement(TierLegend, null) + ), + React.createElement("div", null, + React.createElement("strong", null, "Secret achievements"), + React.createElement("p", null, "Secrets hide their exact trigger. Once Hermes sees a related signal, the card becomes Discovered and shows its requirement.") + ) + ), + React.createElement("div", { className: "ha-toolbar" }, + React.createElement("div", { className: "ha-pills" }, categories.map(function (cat) { + return React.createElement("button", { key: cat, onClick: function () { setCategory(cat); }, className: cat === category ? 
"active" : "" }, cat); + })), + React.createElement("div", { className: "ha-pills" }, ["all", "unlocked", "discovered", "secret"].map(function (v) { + return React.createElement("button", { key: v, onClick: function () { setVisibility(v); }, className: v === visibility ? "active" : "" }, v); + })) + ), + latest.length > 0 && React.createElement("section", { className: "ha-latest" }, + React.createElement("h2", null, "Recent unlocks"), + React.createElement("div", { className: "ha-latest-row" }, latest.map(function (a) { + return React.createElement("div", { key: a.id, className: cn("ha-chip", tierClass(a.tier)) }, + React.createElement("span", { className: "ha-chip-icon" }, React.createElement(AchievementIcon, { icon: a.icon || "secret" })), + a.name + ); + })) + ), + visibility === "secret" && visible.length === 0 && React.createElement(C.Card, { className: "ha-secret-empty" }, + React.createElement(C.CardContent, { className: "ha-secret-empty-content" }, + React.createElement("strong", null, "No hidden secrets left in this scan."), + React.createElement("p", null, "Clue: secrets usually start from unusual failure or power-user patterns — port conflicts, permission walls, missing env vars, YAML mistakes, Docker collisions, rollback/checkpoint use, cache hits, or tiny fixes after lots of red text.") + ) + ), + React.createElement("section", { className: "ha-grid" }, visible.map(function (a) { + return React.createElement(AchievementCard, { key: a.id, achievement: a }); + })) + ); + } + + window.__HERMES_PLUGINS__.register("hermes-achievements", AchievementsPage); +})(); diff --git a/plugins/hermes-achievements/dashboard/dist/style.css b/plugins/hermes-achievements/dashboard/dist/style.css new file mode 100644 index 0000000000..fc0e138f4e --- /dev/null +++ b/plugins/hermes-achievements/dashboard/dist/style.css @@ -0,0 +1,120 @@ +/* hermes-achievements dashboard styles + * Originally authored by @PCinkusz — https://github.com/PCinkusz/hermes-achievements (MIT). 
+ * Bundled into hermes-agent. The in-progress scan banner rules at the bottom + * (.ha-scan-banner*) are a small addition layered on top of the original bundle. + */ +.ha-page { display: flex; flex-direction: column; gap: 1rem; } +.ha-hero { position: relative; overflow: hidden; display: flex; align-items: flex-end; justify-content: space-between; gap: 1rem; border: 1px solid var(--color-border); background: radial-gradient(circle at 12% 0, rgba(103,232,249,.13), transparent 30%), linear-gradient(135deg, color-mix(in srgb, var(--color-card) 88%, transparent), color-mix(in srgb, var(--color-primary) 10%, transparent)); padding: 1.25rem; } +.ha-hero:before { content: ""; position: absolute; inset: auto -10% -80% -10%; height: 180%; pointer-events: none; background: radial-gradient(circle, rgba(242,201,76,.12), transparent 55%); } +.ha-hero h1 { position: relative; margin: 0; font-size: clamp(2rem, 4vw, 4.2rem); line-height: .9; letter-spacing: -0.06em; } +.ha-hero p { position: relative; max-width: 52rem; margin: .65rem 0 0; color: var(--color-muted-foreground); } +.ha-kicker { position: relative; color: var(--color-muted-foreground); text-transform: uppercase; letter-spacing: .18em; font-size: .72rem; font-family: var(--font-mono, ui-monospace, monospace); } +.ha-refresh { position: relative; white-space: nowrap; } +.ha-stats { display: grid; grid-template-columns: repeat(5, minmax(0, 1fr)); gap: .75rem; } +.ha-stat-content { padding: 1rem !important; } +.ha-stat-label { color: var(--color-muted-foreground); font-size: .75rem; text-transform: uppercase; letter-spacing: .12em; } +.ha-stat-value { margin-top: .35rem; font-size: 1.4rem; font-weight: 750; letter-spacing: -0.035em; } +.ha-stat-hint { margin-top: .2rem; color: var(--color-muted-foreground); font-size: .75rem; } +.ha-toolbar { display: flex; justify-content: space-between; gap: .75rem; align-items: center; flex-wrap: wrap; } +.ha-pills { display: flex; gap: .35rem; flex-wrap: wrap; } +.ha-pills button { 
border: 1px solid var(--color-border); background: color-mix(in srgb, var(--color-card) 72%, transparent); color: var(--color-muted-foreground); padding: .35rem .6rem; font-size: .78rem; cursor: pointer; } +.ha-pills button.active, .ha-pills button:hover { color: var(--color-foreground); border-color: var(--ha-tier, var(--color-ring)); background: color-mix(in srgb, var(--color-primary) 16%, var(--color-card)); } +.ha-grid { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: .9rem; } +.ha-card { --ha-tier: var(--color-border); position: relative; overflow: hidden; min-height: 214px; border: 1px solid color-mix(in srgb, var(--ha-tier) 46%, var(--color-border)); background: radial-gradient(circle at 2.6rem 2.2rem, color-mix(in srgb, var(--ha-tier) 16%, transparent), transparent 34%), linear-gradient(180deg, rgba(255,255,255,.04), transparent), color-mix(in srgb, var(--color-card) 92%, #000); transition: transform .16s ease, border-color .16s ease, opacity .16s ease, box-shadow .16s ease; } +.ha-card:hover { transform: translateY(-2px); border-color: var(--ha-tier); box-shadow: 0 0 0 1px color-mix(in srgb, var(--ha-tier) 16%, transparent); } +.ha-card-content { position: relative; z-index: 1; padding: 1rem !important; display: flex; flex-direction: column; gap: .75rem; height: 100%; } +.ha-card-head { display: grid; grid-template-columns: 3.1rem minmax(0, 1fr) auto; gap: .85rem; align-items: start; } +.ha-icon { display: grid; place-items: center; width: 2.9rem; height: 2.9rem; color: var(--ha-tier); } +.ha-lucide { width: 1.78rem; height: 1.78rem; stroke: currentColor; stroke-width: 2.15; filter: drop-shadow(0 0 8px color-mix(in srgb, var(--ha-tier) 24%, transparent)); } +.ha-card-title { font-weight: 780; line-height: 1.05; letter-spacing: -0.025em; } +.ha-card-category { margin-top: .28rem; color: var(--color-muted-foreground); font-size: .76rem; } +.ha-badges { display: flex; flex-direction: column; align-items: flex-end; gap: 
.25rem; } +.ha-tier-badge, .ha-state-badge { border: 1px solid var(--ha-tier); color: var(--ha-tier); background: color-mix(in srgb, var(--ha-tier) 10%, transparent); padding: .16rem .38rem; font-size: .67rem; text-transform: uppercase; letter-spacing: .08em; font-family: var(--font-mono, ui-monospace, monospace); } +.ha-description { margin: 0; color: var(--color-muted-foreground); font-size: .86rem; line-height: 1.45; min-height: 2.4em; } +.ha-criteria { border: 1px solid color-mix(in srgb, var(--ha-tier) 28%, var(--color-border)); background: color-mix(in srgb, var(--ha-tier) 5%, transparent); } +.ha-criteria summary { cursor: pointer; padding: .5rem .65rem; color: var(--ha-tier); text-transform: uppercase; letter-spacing: .1em; font-size: .66rem; font-family: var(--font-mono, ui-monospace, monospace); user-select: none; } +.ha-criteria summary:hover { background: color-mix(in srgb, var(--ha-tier) 8%, transparent); } +.ha-criteria p { margin: 0; border-top: 1px solid color-mix(in srgb, var(--ha-tier) 18%, var(--color-border)); padding: .55rem .65rem .65rem; color: color-mix(in srgb, var(--color-foreground) 78%, var(--color-muted-foreground)); font-size: .76rem; line-height: 1.38; } +.ha-progress-row { display: flex; align-items: center; gap: .55rem; margin-top: 0; } +.ha-progress-track { flex: 1; height: .48rem; border: 1px solid color-mix(in srgb, var(--ha-tier) 34%, var(--color-border)); background: rgba(0,0,0,.22); overflow: hidden; } +.ha-progress-fill { height: 100%; background: linear-gradient(90deg, var(--ha-tier), color-mix(in srgb, var(--ha-tier) 48%, white)); } +.ha-progress-text { min-width: 5.4rem; text-align: right; font-family: var(--font-mono, ui-monospace, monospace); color: var(--color-muted-foreground); font-size: .72rem; } +.ha-evidence-slot { min-height: 1.65rem; margin-top: auto; display: flex; align-items: flex-end; } +.ha-evidence { width: 100%; display: flex; align-items: center; gap: .4rem; color: var(--color-muted-foreground); 
font-size: .72rem; min-width: 0; } +.ha-evidence-label { text-transform: uppercase; letter-spacing: .09em; font-family: var(--font-mono, ui-monospace, monospace); flex: 0 0 auto; } +.ha-evidence-title { min-width: 0; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; color: color-mix(in srgb, var(--color-foreground) 84%, var(--color-muted-foreground)); } +.ha-evidence-empty { visibility: hidden; } +.ha-latest h2 { margin: 0 0 .5rem; font-size: 1rem; } +.ha-latest-row { display: flex; gap: .5rem; flex-wrap: wrap; } +.ha-chip { display: inline-flex; align-items: center; gap: .35rem; border: 1px solid var(--ha-tier); color: var(--ha-tier); background: color-mix(in srgb, var(--ha-tier) 10%, transparent); padding: .35rem .55rem; font-size: .8rem; } +.ha-chip-icon .ha-lucide { width: .95rem; height: .95rem; } +.ha-slot { border-style: dashed; } +.ha-slot-content { display: flex; gap: .6rem; align-items: center; padding: .65rem .8rem !important; font-size: .82rem; } +.ha-slot-star { color: #67e8f9; } +.ha-slot-muted { color: var(--color-muted-foreground); margin-left: auto; } +.ha-error { border-color: #ef4444; color: #fecaca; } +.ha-loading { color: var(--color-muted-foreground); font-family: var(--font-mono, ui-monospace, monospace); padding: 2rem; border: 1px dashed var(--color-border); } +.ha-guide { display: grid; grid-template-columns: minmax(0, 1.15fr) minmax(0, .85fr); gap: .75rem; } +.ha-guide > div { border: 1px solid var(--color-border); background: color-mix(in srgb, var(--color-card) 82%, transparent); padding: .85rem 1rem; } +.ha-guide strong { display: block; margin-bottom: .45rem; font-size: .78rem; text-transform: uppercase; letter-spacing: .12em; font-family: var(--font-mono, ui-monospace, monospace); } +.ha-guide p { margin: 0; color: var(--color-muted-foreground); font-size: .84rem; line-height: 1.45; } +.ha-tier-legend { display: flex; align-items: center; gap: .45rem; flex-wrap: wrap; } +.ha-tier-step { --ha-tier: var(--color-border); 
display: inline-flex; align-items: center; gap: .32rem; color: var(--ha-tier); border: 1px solid color-mix(in srgb, var(--ha-tier) 52%, var(--color-border)); background: color-mix(in srgb, var(--ha-tier) 8%, transparent); padding: .28rem .45rem; font-size: .72rem; font-family: var(--font-mono, ui-monospace, monospace); text-transform: uppercase; letter-spacing: .06em; } +.ha-tier-step i { width: .55rem; height: .55rem; background: var(--ha-tier); display: inline-block; } +.ha-tier-arrow { color: var(--color-muted-foreground); } +.ha-state-discovered { opacity: .92; } +.ha-state-discovered .ha-card-title { color: color-mix(in srgb, var(--color-foreground) 82%, var(--ha-tier)); } +.ha-state-secret { opacity: .5; filter: grayscale(.55); } +.ha-state-secret:after { content: ""; position: absolute; inset: 0; pointer-events: none; background: repeating-linear-gradient(-45deg, transparent 0 8px, rgba(255,255,255,.035) 8px 10px); } +.ha-tier-pending { --ha-tier: color-mix(in srgb, var(--color-muted-foreground) 64%, transparent); } +.ha-tier-copper { --ha-tier: #b87333; } +.ha-tier-silver { --ha-tier: #c0c7d2; } +.ha-tier-gold { --ha-tier: #f2c94c; box-shadow: 0 0 22px rgba(242,201,76,.08); } +.ha-tier-diamond { --ha-tier: #67e8f9; box-shadow: 0 0 24px rgba(103,232,249,.1); } +.ha-tier-olympian { --ha-tier: #c084fc; box-shadow: 0 0 34px rgba(192,132,252,.18), 0 0 12px rgba(242,201,76,.1); } +@media (max-width: 980px) { .ha-stats { grid-template-columns: repeat(2, minmax(0, 1fr)); } .ha-guide { grid-template-columns: 1fr; } } +@media (max-width: 800px) { .ha-stats { grid-template-columns: 1fr; } .ha-hero { flex-direction: column; align-items: stretch; } .ha-card-head { grid-template-columns: 3.1rem 1fr; } .ha-badges { grid-column: 1 / -1; align-items: flex-start; flex-direction: row; } } + +.ha-secret-empty-content { padding: 1rem !important; } +.ha-secret-empty strong { display: block; margin-bottom: .35rem; } +.ha-secret-empty p { margin: 0; color: 
var(--color-muted-foreground); font-size: .86rem; line-height: 1.45; } +.ha-page-loading { animation: ha-fade-in .18s ease-out; } +.ha-loading-hero { align-items: center; } +.ha-scan-status { position: relative; z-index: 1; display: flex; align-items: center; gap: .8rem; min-width: 18rem; border: 1px solid color-mix(in srgb, #67e8f9 35%, var(--color-border)); background: color-mix(in srgb, var(--color-card) 78%, transparent); padding: .8rem .95rem; color: var(--color-foreground); } +.ha-scan-status strong { display: block; font-size: .82rem; text-transform: uppercase; letter-spacing: .1em; font-family: var(--font-mono, ui-monospace, monospace); } +.ha-scan-status p { margin: .25rem 0 0; font-size: .78rem; line-height: 1.35; color: var(--color-muted-foreground); } +.ha-scan-pulse { width: .72rem; height: .72rem; flex: 0 0 auto; border-radius: 999px; background: #67e8f9; box-shadow: 0 0 0 0 rgba(103,232,249,.55); animation: ha-pulse 1.35s ease-out infinite; } +.ha-skeleton-card { pointer-events: none; } +.ha-skeleton { position: relative; overflow: hidden; border-radius: 0; background: color-mix(in srgb, var(--color-muted-foreground) 16%, transparent); } +.ha-skeleton:after { content: ""; position: absolute; inset: 0; transform: translateX(-100%); background: linear-gradient(90deg, transparent, rgba(255,255,255,.14), transparent); animation: ha-shimmer 1.35s infinite; } +.ha-skeleton-stack { display: flex; flex-direction: column; gap: .45rem; padding-top: .15rem; } +.ha-skeleton-icon { width: 2.9rem; height: 2.9rem; } +.ha-skeleton-title { width: 72%; height: .95rem; } +.ha-skeleton-meta { width: 45%; height: .65rem; } +.ha-skeleton-badge { width: 4.4rem; height: 1.05rem; } +.ha-skeleton-badge-short { width: 3.6rem; } +.ha-skeleton-line { height: .78rem; width: 92%; } +.ha-skeleton-line-short { width: 68%; } +.ha-skeleton-criteria { height: 2.2rem; width: 100%; border: 1px solid color-mix(in srgb, var(--color-muted-foreground) 18%, var(--color-border)); } 
+.ha-skeleton-evidence { width: 58%; height: .8rem; } +.ha-skeleton-progress { flex: 1; height: .48rem; } +.ha-skeleton-progress-text { width: 4.6rem; height: .75rem; } +.ha-skeleton-stat-value { width: 56%; height: 1.35rem; margin-top: .55rem; } +.ha-skeleton-stat-hint { width: 76%; height: .7rem; margin-top: .55rem; } +.ha-loading-guide p { color: var(--color-muted-foreground); } +@keyframes ha-shimmer { 100% { transform: translateX(100%); } } +@keyframes ha-pulse { 0% { box-shadow: 0 0 0 0 rgba(103,232,249,.48); } 70% { box-shadow: 0 0 0 .65rem rgba(103,232,249,0); } 100% { box-shadow: 0 0 0 0 rgba(103,232,249,0); } } +@keyframes ha-fade-in { from { opacity: 0; transform: translateY(3px); } to { opacity: 1; transform: translateY(0); } } +.ha-loading-hero p, .ha-scan-status p, .ha-loading-guide p { text-transform: none; letter-spacing: normal; } + +/* In-progress scan banner — shown on the main page while the background scan + * is still walking through session history, so the user sees continuous + * progress (X / Y sessions · Z%) instead of guessing whether anything is + * happening. Reuses .ha-scan-pulse + ha-pulse keyframes from the loading page. 
+ */ +.ha-scan-banner { display: flex; flex-direction: column; gap: .6rem; border: 1px solid color-mix(in srgb, #67e8f9 35%, var(--color-border)); background: color-mix(in srgb, var(--color-card) 78%, transparent); padding: .8rem .95rem; animation: ha-fade-in .18s ease-out; } +.ha-scan-banner-head { display: flex; align-items: center; gap: .8rem; } +.ha-scan-banner-text strong { display: block; font-size: .82rem; text-transform: uppercase; letter-spacing: .1em; font-family: var(--font-mono, ui-monospace, monospace); color: var(--color-foreground); } +.ha-scan-banner-text p { margin: .25rem 0 0; font-size: .78rem; line-height: 1.35; color: var(--color-muted-foreground); text-transform: none; letter-spacing: normal; } +.ha-scan-progress-track { height: .4rem; border: 1px solid color-mix(in srgb, #67e8f9 28%, var(--color-border)); background: rgba(0,0,0,.22); overflow: hidden; } +.ha-scan-progress-fill { height: 100%; background: linear-gradient(90deg, #67e8f9, color-mix(in srgb, #67e8f9 48%, white)); transition: width .4s ease-out; } diff --git a/plugins/hermes-achievements/dashboard/manifest.json b/plugins/hermes-achievements/dashboard/manifest.json new file mode 100644 index 0000000000..02c4050f34 --- /dev/null +++ b/plugins/hermes-achievements/dashboard/manifest.json @@ -0,0 +1,11 @@ +{ + "name": "hermes-achievements", + "label": "Achievements", + "description": "Steam-style achievements for vibe coding and agentic Hermes workflows.", + "icon": "Star", + "version": "0.3.1", + "tab": { "path": "/achievements", "position": "after:analytics" }, + "entry": "dist/index.js", + "css": "dist/style.css", + "api": "plugin_api.py" +} diff --git a/plugins/hermes-achievements/dashboard/plugin_api.py b/plugins/hermes-achievements/dashboard/plugin_api.py new file mode 100644 index 0000000000..678d49fb61 --- /dev/null +++ b/plugins/hermes-achievements/dashboard/plugin_api.py @@ -0,0 +1,1053 @@ +"""Hermes Achievements dashboard plugin backend. 
try:
    from fastapi import APIRouter
except Exception:  # Allows local unit tests without dashboard dependencies.
    # Deliberately broad: any failure while importing FastAPI falls back to the
    # stub below so this module can still be imported.
    class APIRouter:  # type: ignore
        """Stand-in router whose route decorators return the handler unchanged."""

        def get(self, *_args, **_kwargs):
            return lambda fn: fn

        def post(self, *_args, **_kwargs):
            return lambda fn: fn

router = APIRouter()

# Snapshot cache and scan bookkeeping. The code that reads and writes these
# lives further down the module, outside this block.
# NOTE(review): the names suggest a TTL-bounded snapshot cache guarded by
# _SCAN_LOCK -- confirm against the scan implementation.
SNAPSHOT_TTL_SECONDS = 120
_SCAN_LOCK = threading.Lock()
_SNAPSHOT_CACHE: Optional[Dict[str, Any]] = None
_SNAPSHOT_CACHE_AT = 0
_SCAN_STATUS: Dict[str, Any] = {
    "state": "idle",
    "started_at": None,
    "finished_at": None,
    "last_error": None,
    "last_duration_ms": None,
    "run_count": 0,
}

# Case-insensitive whole-word match for generic failure vocabulary.
ERROR_RE = re.compile(r"\b(error|failed|failure|traceback|exception|permission denied|not found|eaddrinuse|already in use|timed out|blocked)\b", re.I)
# One of the listed dev-server ports (optionally preceded by "port") followed
# later on the same line by a "busy" phrase, or a bare EADDRINUSE.
PORT_RE = re.compile(r"\b(port\s+)?(3000|5173|8000|8080|9119)\b.*\b(in use|already|taken|eaddrinuse)\b|\beaddrinuse\b", re.I)
# A package-manager name followed later on the same line by "install" or "add".
INSTALL_RE = re.compile(r"\b(npm|pnpm|yarn|pip|uv)\b.*\b(install|add)\b", re.I)
# Success vocabulary, including an `exit_code: 0` / `exit_code=0` style marker.
SUCCESS_RE = re.compile(r"\b(success|passed|built|compiled|done|exit_code[\"']?\s*[:=]\s*0|verified|ok)\b", re.I)
# Paths starting with /home/, ~, ./ or /mnt/ and ending in a known source
# extension. The trailing \b is required: without it the alternation stops at
# the first extension that is a prefix of the real one, so "./a.json" was
# captured as "./a.js", "./a.tsx" as "./a.ts", and "./a.pyc" as "./a.py".
FILE_RE = re.compile(r"(?:/home/|~/?|\./|/mnt/)[\w./-]+\.(?:py|js|ts|tsx|jsx|css|html|md|json|yaml|yml|svg|sql|sh)\b")

# Tier labels from lowest to highest; tiers() pairs them with thresholds by position.
TIER_NAMES = ["Copper", "Silver", "Gold", "Diamond", "Olympian"]


def tiers(values: List[int]) -> List[Dict[str, Any]]:
    """Pair each tier name with the threshold at the same position.

    Args:
        values: Thresholds ordered to line up with ``TIER_NAMES``.

    Returns:
        One ``{"name": ..., "threshold": ...}`` dict per paired entry. Pairing
        uses ``zip``, so it stops at the shorter input: surplus thresholds are
        ignored and a short list yields fewer tiers.
    """
    return [{"name": name, "threshold": threshold} for name, threshold in zip(TIER_NAMES, values)]


def req(metric: str, gte: int) -> Dict[str, Any]:
    """Build one requirement entry for a ``multi_condition`` achievement.

    Args:
        metric: Name of the metric the requirement refers to.
        gte: Value stored under the ``"gte"`` key.

    Returns:
        ``{"metric": metric, "gte": gte}``.
    """
    return {"metric": metric, "gte": gte}
"category": "Agent Autonomy", "kind": "best_session", "icon": "flame", "threshold_metric": "max_tool_calls_in_session", "tiers": tiers([200, 500, 1200, 3000, 8000])}, + {"id": "autonomous_avalanche", "name": "Autonomous Avalanche", "description": "Accumulate a lifetime avalanche of Hermes tool calls across sessions.", "category": "Agent Autonomy", "kind": "lifetime", "icon": "avalanche", "threshold_metric": "total_tool_calls", "tiers": tiers([1000, 3000, 8000, 20000, 50000])}, + {"id": "toolchain_maxxer", "name": "Toolchain Maxxer", "description": "Use a wide spread of distinct Hermes tools in one session.", "category": "Agent Autonomy", "kind": "best_session", "icon": "nodes", "threshold_metric": "max_distinct_tools_in_session", "tiers": tiers([18, 28, 45, 70, 100])}, + {"id": "full_send", "name": "Full Send", "description": "Terminal, files, and web/browser all get involved in one real run.", "category": "Agent Autonomy", "kind": "multi_condition", "icon": "rocket", "requirements": [req("max_terminal_calls_in_session", 180), req("max_file_tool_calls_in_session", 120), req("max_web_browser_calls_in_session", 60)]}, + {"id": "subagent_commander", "name": "Subagent Commander", "description": "Coordinate delegated agent work.", "category": "Agent Autonomy", "kind": "lifetime", "icon": "branch", "threshold_metric": "total_delegate_calls", "tiers": tiers([5, 40, 100, 1000, 5000])}, + {"id": "background_process_enjoyer", "name": "Background Process Enjoyer", "description": "Start or control enough long-running processes to deserve the title.", "category": "Agent Autonomy", "kind": "lifetime", "icon": "daemon", "threshold_metric": "total_process_calls", "tiers": tiers([300, 800, 2000, 6000, 15000])}, + {"id": "cron_necromancer", "name": "Cron Necromancer", "description": "Raise scheduled autonomous jobs from the dead.", "category": "Agent Autonomy", "kind": "lifetime", "icon": "clock", "threshold_metric": "total_cron_calls", "tiers": tiers([1000, 3000, 8000, 20000, 
50000])}, + + # Debugging Chaos — higher thresholds + multi-condition events + {"id": "red_text_connoisseur", "name": "Red Text Connoisseur", "description": "Encounter enough errors to develop a palate for red text.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "warning", "threshold_metric": "total_errors", "tiers": tiers([1500, 4000, 10000, 25000, 75000])}, + {"id": "stack_trace_sommelier", "name": "Stack Trace Sommelier", "description": "Taste tracebacks by the flight, not by the sip.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "wine", "threshold_metric": "traceback_events", "tiers": tiers([300, 1000, 3000, 8000, 20000])}, + {"id": "actually_read_the_logs", "name": "Actually Read The Logs", "description": "Inspect logs repeatedly instead of guessing.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "scroll", "threshold_metric": "log_read_events", "tiers": tiers([1000, 3000, 8000, 20000, 50000])}, + {"id": "port_3000_taken", "name": "Port 3000 Is Taken", "description": "Discover dev-server port conflict patterns enough times to become numb.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "plug", "secret": True, "threshold_metric": "port_conflict_events", "tiers": tiers([15, 40, 100, 300, 1000])}, + {"id": "permission_denied_any_percent", "name": "Permission Denied Any%", "description": "Speedrun into permission walls.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "lock", "secret": True, "threshold_metric": "permission_denied_events", "tiers": tiers([25, 75, 200, 600, 1500])}, + {"id": "dependency_hell_tourist", "name": "Dependency Hell Tourist", "description": "Package installs fail, then somehow life continues.", "category": "Debugging Chaos", "kind": "multi_condition", "icon": "package_skull", "requirements": [req("install_error_events", 25), req("install_success_events", 10)]}, + {"id": "the_fix_was_restarting", "name": "The Fix Was Restarting It", "description": "Restart after enough error 
clusters to call it a technique.", "category": "Debugging Chaos", "kind": "multi_condition", "icon": "restart", "requirements": [req("restart_after_error_events", 50), req("total_errors", 4000)]}, + {"id": "forgot_the_env_var", "name": "Forgot The Env Var", "description": "Auth or configuration failed because an environment variable was missing.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "key", "secret": True, "threshold_metric": "env_var_error_events", "tiers": tiers([5000, 15000, 40000, 100000, 250000])}, + {"id": "yaml_colon_incident", "name": "YAML Colon Incident", "description": "Configuration syntax bites back.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "colon", "secret": True, "threshold_metric": "yaml_error_events", "tiers": tiers([1000, 3000, 8000, 20000, 50000])}, + {"id": "docker_name_collision", "name": "Docker Name Collision", "description": "A container name already exists. Of course it does.", "category": "Debugging Chaos", "kind": "lifetime", "icon": "container", "secret": True, "threshold_metric": "docker_conflict_events", "tiers": tiers([75, 200, 600, 1500, 4000])}, + + # Vibe Coding + {"id": "supposed_to_be_quick", "name": "This Was Supposed To Be Quick", "description": "A tiny ask becomes an entire expedition.", "category": "Vibe Coding", "kind": "best_session", "icon": "melting_clock", "threshold_metric": "max_messages_in_session", "tiers": tiers([300, 600, 1200, 2500, 6000])}, + {"id": "one_more_small_change", "name": "One More Small Change", "description": "Make enough file edits in one session to invalidate the phrase small change.", "category": "Vibe Coding", "kind": "best_session", "icon": "pencil", "threshold_metric": "max_file_tool_calls_in_session", "tiers": tiers([150, 400, 1000, 3000, 8000])}, + {"id": "vibe_architect", "name": "Vibe Architect", "description": "Touch a broad surface area in one project session.", "category": "Vibe Coding", "kind": "best_session", "icon": "blueprint", 
"threshold_metric": "max_files_touched_in_session", "tiers": tiers([300, 700, 1500, 4000, 10000])}, + {"id": "pixel_goblin", "name": "Pixel Goblin", "description": "Do sustained frontend, CSS, SVG, or visual tuning.", "category": "Vibe Coding", "kind": "lifetime", "icon": "pixel", "threshold_metric": "frontend_activity_events", "tiers": tiers([20000, 50000, 120000, 300000, 800000])}, + {"id": "ship_first_ask_later", "name": "Ship First, Ask Later", "description": "Git activity after a serious tool chain.", "category": "Vibe Coding", "kind": "multi_condition", "icon": "ship", "requirements": [req("git_events", 50), req("max_tool_calls_in_session", 500)]}, + {"id": "css_exorcist", "name": "CSS Exorcist", "description": "Cast repeated styling demons out of the interface.", "category": "Vibe Coding", "kind": "lifetime", "icon": "spark_cursor", "threshold_metric": "css_activity_events", "tiers": tiers([10000, 30000, 80000, 200000, 500000])}, + {"id": "one_character_fix", "name": "One Character Fix", "description": "A tiny edit after a pile of errors. Painful. Beautiful.", "category": "Vibe Coding", "kind": "multi_condition", "icon": "needle", "secret": True, "requirements": [req("tiny_patch_after_errors_events", 5), req("total_errors", 4000)]}, + + # Hermes Native + {"id": "skillsmith", "name": "Skillsmith", "description": "Work with Hermes skills enough to leave fingerprints.", "category": "Hermes Native", "kind": "lifetime", "icon": "hammer_scroll", "threshold_metric": "skill_events", "tiers": tiers([5000, 15000, 40000, 100000, 250000])}, + {"id": "skill_issue_skill_created", "name": "Skill Issue? 
Skill Created.", "description": "Create or patch durable procedures instead of repeating yourself.", "category": "Hermes Native", "kind": "lifetime", "icon": "anvil", "threshold_metric": "skill_manage_events", "tiers": tiers([25, 75, 200, 600, 1500])}, + {"id": "memory_keeper", "name": "Memory Keeper", "description": "Persist durable knowledge with memory or Mnemosyne.", "category": "Hermes Native", "kind": "lifetime", "icon": "crystal", "threshold_metric": "memory_events", "tiers": tiers([100, 300, 1000, 3000, 8000])}, + {"id": "memory_palace", "name": "Memory Palace", "description": "Build a serious durable-memory trail.", "category": "Hermes Native", "kind": "lifetime", "icon": "palace", "threshold_metric": "memory_write_events", "tiers": tiers([100, 300, 1000, 3000, 8000])}, + {"id": "context_dragon", "name": "Context Dragon", "description": "Brush against compression, huge context, or token pressure repeatedly.", "category": "Hermes Native", "kind": "lifetime", "icon": "dragon", "threshold_metric": "context_events", "tiers": tiers([5000, 15000, 40000, 100000, 250000])}, + {"id": "gateway_dweller", "name": "Gateway Dweller", "description": "Live through gateway-connected Hermes workflows.", "category": "Hermes Native", "kind": "lifetime", "icon": "antenna", "threshold_metric": "gateway_events", "tiers": tiers([5000, 15000, 40000, 100000, 250000])}, + {"id": "plugin_goblin", "name": "Plugin Goblin", "description": "Use or develop plugins enough that the dashboard notices.", "category": "Hermes Native", "kind": "lifetime", "icon": "puzzle", "threshold_metric": "plugin_events", "tiers": tiers([1000, 3000, 8000, 20000, 50000])}, + {"id": "rollback_wizard", "name": "Rollback Wizard", "description": "Invoke rollback/checkpoint recovery magic.", "category": "Hermes Native", "kind": "lifetime", "icon": "rewind", "secret": True, "threshold_metric": "rollback_events", "tiers": tiers([500, 1500, 4000, 10000, 25000])}, + + # Research/Web + {"id": "rabbit_hole_certified", 
"name": "Rabbit Hole Certified", "description": "Search or extract enough web content to qualify as a research spiral.", "category": "Research/Web", "kind": "lifetime", "icon": "spiral", "threshold_metric": "total_web_calls", "tiers": tiers([400, 1200, 3000, 8000, 20000])}, + {"id": "citation_goblin", "name": "Citation Goblin", "description": "Extract enough web pages to become a tiny librarian.", "category": "Research/Web", "kind": "lifetime", "icon": "quote", "threshold_metric": "total_web_extract_calls", "tiers": tiers([100, 300, 1000, 3000, 8000])}, + {"id": "docs_archaeologist", "name": "Docs Archaeologist", "description": "Dig through documentation sources over and over.", "category": "Research/Web", "kind": "lifetime", "icon": "compass", "threshold_metric": "docs_activity_events", "tiers": tiers([5000, 15000, 40000, 100000, 250000])}, + {"id": "browser_possession", "name": "Browser Possession", "description": "Possess a browser through automation repeatedly.", "category": "Research/Web", "kind": "lifetime", "icon": "browser", "threshold_metric": "browser_calls", "tiers": tiers([75, 200, 600, 1500, 4000])}, + + # Tool Mastery + {"id": "terminal_goblin", "name": "Terminal Goblin", "description": "Spend serious time in shell-land.", "category": "Tool Mastery", "kind": "lifetime", "icon": "terminal", "threshold_metric": "total_terminal_calls", "tiers": tiers([750, 2000, 6000, 15000, 50000])}, + {"id": "patch_wizard", "name": "Patch Wizard", "description": "Bend files to your will with targeted patches.", "category": "Tool Mastery", "kind": "lifetime", "icon": "wand", "threshold_metric": "total_patch_calls", "tiers": tiers([250, 750, 2000, 6000, 15000])}, + {"id": "file_archaeologist", "name": "File Archaeologist", "description": "Dig through the filesystem with reads and searches.", "category": "Tool Mastery", "kind": "lifetime", "icon": "folder", "threshold_metric": "total_file_reads_searches", "tiers": tiers([750, 2000, 6000, 15000, 50000])}, + {"id": 
"image_whisperer", "name": "Image Whisperer", "description": "Use image generation or vision tools enough for visual work.", "category": "Tool Mastery", "kind": "lifetime", "icon": "eye", "threshold_metric": "image_vision_calls", "tiers": tiers([100, 300, 1000, 3000, 8000])}, + {"id": "voice_of_the_machine", "name": "Voice Of The Machine", "description": "Use text-to-speech or voice tooling repeatedly.", "category": "Tool Mastery", "kind": "lifetime", "icon": "wave", "threshold_metric": "tts_calls", "tiers": tiers([10, 30, 100, 300, 800])}, + + # Model Lore + {"id": "model_hopper", "name": "Model Hopper", "description": "Switch or inspect providers/models enough to count as a habit.", "category": "Model Lore", "kind": "lifetime", "icon": "swap", "threshold_metric": "model_events", "tiers": tiers([10000, 30000, 80000, 200000, 500000])}, + {"id": "openrouter_enjoyer", "name": "OpenRouter Enjoyer", "description": "Route model work through OpenRouter repeatedly.", "category": "Model Lore", "kind": "lifetime", "icon": "router", "threshold_metric": "openrouter_events", "tiers": tiers([250, 750, 2000, 6000, 15000])}, + {"id": "codex_conjurer", "name": "Codex Conjurer", "description": "Summon Codex-flavored assistance often enough for a ritual.", "category": "Model Lore", "kind": "lifetime", "icon": "codex", "threshold_metric": "codex_events", "tiers": tiers([500, 1500, 4000, 10000, 25000])}, + {"id": "multi_model_mage", "name": "Multi-Model Mage", "description": "Use a real spread of distinct model names across Hermes history.", "category": "Model Lore", "kind": "lifetime", "icon": "prism", "threshold_metric": "distinct_model_count", "tiers": tiers([10, 20, 40, 80, 160])}, + {"id": "five_model_flight", "name": "Five-Model Flight", "description": "Try at least five distinct LLMs instead of marrying the first model that answers.", "category": "Model Lore", "kind": "lifetime", "icon": "prism", "threshold_metric": "distinct_model_count", "tiers": tiers([5, 10, 20, 40, 80])}, 
+ {"id": "provider_polyglot", "name": "Provider Polyglot", "description": "Use models from multiple providers across Hermes history.", "category": "Model Lore", "kind": "lifetime", "icon": "swap", "threshold_metric": "distinct_provider_count", "tiers": tiers([2, 3, 5, 8, 12])}, + {"id": "model_sommelier", "name": "Model Sommelier", "description": "Taste enough model/provider conversations to develop preferences.", "category": "Model Lore", "kind": "lifetime", "icon": "wine", "threshold_metric": "model_events", "tiers": tiers([250, 750, 2000, 6000, 15000])}, + {"id": "claude_confidant", "name": "Claude Confidant", "description": "Bring Claude-flavored reasoning into the workflow repeatedly.", "category": "Model Lore", "kind": "lifetime", "icon": "quote", "threshold_metric": "claude_events", "tiers": tiers([50, 150, 500, 1500, 4000])}, + {"id": "gemini_cartographer", "name": "Gemini Cartographer", "description": "Map enough Gemini-related workflows to know the terrain.", "category": "Model Lore", "kind": "lifetime", "icon": "compass", "threshold_metric": "gemini_events", "tiers": tiers([50, 150, 500, 1500, 4000])}, + {"id": "open_weights_pilgrim", "name": "Open Weights Pilgrim", "description": "Actually chat with local/open-weight models through Hermes session metadata.", "category": "Model Lore", "kind": "lifetime", "icon": "terminal", "threshold_metric": "local_model_chat_sessions", "tiers": tiers([1, 3, 10, 30, 100])}, + + # Workflow Intelligence + {"id": "toolset_cartographer", "name": "Toolset Cartographer", "description": "Navigate Hermes toolsets deliberately instead of treating tools as a blur.", "category": "Hermes Native", "kind": "lifetime", "icon": "compass", "threshold_metric": "toolset_events", "tiers": tiers([20, 60, 200, 600, 1500])}, + {"id": "config_surgeon", "name": "Config Surgeon", "description": "Operate on real config files, manifests, env files, and dashboard settings without flinching.", "category": "Hermes Native", "kind": "lifetime", 
"icon": "key", "threshold_metric": "config_events", "tiers": tiers([100, 300, 1000, 3000, 10000])}, + {"id": "rebase_acrobat", "name": "Rebase Acrobat", "description": "Handle real git history surgery: rebase, conflict, merge, fetch, push.", "category": "Vibe Coding", "kind": "lifetime", "icon": "branch", "threshold_metric": "git_history_events", "tiers": tiers([10, 30, 100, 300, 800])}, + {"id": "test_suite_tamer", "name": "Test Suite Tamer", "description": "Run enough verification commands that green text becomes part of the ritual.", "category": "Tool Mastery", "kind": "lifetime", "icon": "daemon", "threshold_metric": "test_events", "tiers": tiers([100, 300, 800, 2400, 6000])}, + {"id": "screenshot_hunter", "name": "Screenshot Hunter", "description": "Capture, inspect, and polish visual proof instead of just claiming it works.", "category": "Tool Mastery", "kind": "lifetime", "icon": "eye", "threshold_metric": "screenshot_events", "tiers": tiers([50, 150, 500, 1500, 5000])}, + + # Lifestyle + {"id": "marathon_operator", "name": "Marathon Operator", "description": "Accumulate a serious number of Hermes sessions.", "category": "Lifestyle", "kind": "lifetime", "icon": "marathon", "threshold_metric": "session_count", "tiers": tiers([75, 200, 500, 1500, 5000])}, + {"id": "weekend_warrior", "name": "Weekend Warrior", "description": "Run Hermes on weekends enough times to make it a lifestyle.", "category": "Lifestyle", "kind": "lifetime", "icon": "calendar", "threshold_metric": "weekend_sessions", "tiers": tiers([25, 75, 200, 600, 1500])}, + {"id": "night_shift_operator", "name": "Night Shift Operator", "description": "Run sessions during gremlin hours repeatedly.", "category": "Lifestyle", "kind": "lifetime", "icon": "moon", "threshold_metric": "night_sessions", "tiers": tiers([25, 75, 200, 600, 1500])}, + {"id": "cache_hit_appreciator", "name": "Cache Hit Appreciator", "description": "Notice or benefit from prompt/cache behavior.", "category": "Lifestyle", "kind": 
def state_path() -> Path:
    """Return the path of the persisted unlock state (``state.json``)."""
    return Path.home() / ".hermes" / "plugins" / "hermes-achievements" / "state.json"


def snapshot_path() -> Path:
    """Return the path of the persisted scan snapshot (``scan_snapshot.json``)."""
    return Path.home() / ".hermes" / "plugins" / "hermes-achievements" / "scan_snapshot.json"


def checkpoint_path() -> Path:
    """Return the path of the per-session scan checkpoint (``scan_checkpoint.json``)."""
    return Path.home() / ".hermes" / "plugins" / "hermes-achievements" / "scan_checkpoint.json"


def load_state() -> Dict[str, Any]:
    """Load the unlock state from disk, falling back to an empty state.

    Returns:
        A dict that always carries a dict-valued ``"unlocks"`` entry. A
        missing, unreadable, or non-object file yields ``{"unlocks": {}}``,
        matching the shape validation done by ``load_snapshot`` and
        ``load_checkpoint``.
    """
    path = state_path()
    if not path.exists():
        return {"unlocks": {}}
    try:
        data = json.loads(path.read_text())
    except Exception:
        # Best effort: a corrupt or unreadable file is treated as "no state".
        return {"unlocks": {}}
    if not isinstance(data, dict):
        # Valid JSON but the wrong shape (e.g. ``[]`` or ``null``).
        return {"unlocks": {}}
    if not isinstance(data.get("unlocks"), dict):
        # Repair only the broken field so any other stored keys survive.
        data["unlocks"] = {}
    return data


def save_state(state: Dict[str, Any]) -> None:
    """Write *state* to ``state_path()`` as indented, key-sorted JSON."""
    path = state_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(state, indent=2, sort_keys=True))


def _json_safe(value: Any) -> Any:
    """Recursively convert tuples and sets inside *value* into JSON-friendly lists.

    Dict values and list/tuple items are converted recursively; sets become
    sorted lists so the serialized output is deterministic. Any other value
    is returned unchanged.
    """
    if isinstance(value, dict):
        return {k: _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    if isinstance(value, set):
        items = [_json_safe(v) for v in value]
        try:
            return sorted(items)
        except TypeError:
            # Mixed, non-comparable members (e.g. ``None`` next to ``str``):
            # fall back to a stable order instead of aborting the save.
            return sorted(items, key=lambda item: (type(item).__name__, repr(item)))
    return value


def load_snapshot() -> Optional[Dict[str, Any]]:
    """Load the persisted scan snapshot.

    Returns:
        The parsed dict, or ``None`` when the file is missing, unreadable,
        or does not contain a JSON object.
    """
    path = snapshot_path()
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text())
        if isinstance(data, dict):
            return data
    except Exception:
        return None
    return None


def save_snapshot(data: Dict[str, Any]) -> None:
    """Write *data* to ``snapshot_path()`` after making it JSON-serializable."""
    path = snapshot_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(_json_safe(data), indent=2, sort_keys=True))
def load_checkpoint() -> Dict[str, Any]:
    """Read the scan checkpoint, or return an empty one when it is unusable.

    Missing top-level fields are filled with defaults. The loaded data is
    only returned when its ``"sessions"`` entry is a dict; in every other
    case (missing file, parse error, wrong shape) an empty checkpoint is
    returned.
    """
    empty: Dict[str, Any] = {"schema_version": 1, "generated_at": 0, "sessions": {}}
    target = checkpoint_path()
    if not target.exists():
        return empty
    try:
        loaded = json.loads(target.read_text())
        if isinstance(loaded, dict):
            for field, fallback in (("schema_version", 1), ("generated_at", 0), ("sessions", {})):
                loaded.setdefault(field, fallback)
            if isinstance(loaded.get("sessions"), dict):
                return loaded
    except Exception:
        # Deliberately best effort: fall through to the empty checkpoint.
        pass
    return empty


def save_checkpoint(data: Dict[str, Any]) -> None:
    """Serialize *data* (made JSON-safe first) to ``checkpoint_path()``."""
    target = checkpoint_path()
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(_json_safe(data), indent=2, sort_keys=True)
    target.write_text(serialized)


def session_fingerprint(meta: Dict[str, Any]) -> Dict[str, Any]:
    """Build the change-detection fingerprint for one session's metadata.

    The fingerprint holds ``last_active``, ``started_at``, ``model`` and a
    display title (``title``, else ``preview``, else ``"Untitled"``).
    """
    fingerprint: Dict[str, Any] = {
        field: meta.get(field) for field in ("last_active", "started_at", "model")
    }
    fingerprint["title"] = meta.get("title") or meta.get("preview") or "Untitled"
    return fingerprint


def _cache_is_fresh(now: int) -> bool:
    """Return True when an in-memory snapshot exists and is within the TTL at *now*."""
    if _SNAPSHOT_CACHE is None:
        return False
    return (now - _SNAPSHOT_CACHE_AT) <= SNAPSHOT_TTL_SECONDS


def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], now: Optional[int] = None) -> bool:
    """Return True when *snapshot* is missing, undated, or older than the TTL.

    ``now`` defaults to the current time when it is ``None`` (or otherwise falsy).
    """
    if not isinstance(snapshot, dict):
        return True
    generated_at = int(snapshot.get("generated_at") or 0)
    reference = int(now or time.time())
    if generated_at <= 0:
        return True
    age = reference - generated_at
    return age > SNAPSHOT_TTL_SECONDS


def _scan_status_payload(now: Optional[int] = None) -> Dict[str, Any]:
    """Assemble the scan-status dict from ``_SCAN_STATUS`` and the cached snapshot."""
    reference = int(now or time.time())
    cached = _SNAPSHOT_CACHE if isinstance(_SNAPSHOT_CACHE, dict) else None
    generated_at = int(cached.get("generated_at") or 0) if cached else 0

    payload: Dict[str, Any] = {"state": _SCAN_STATUS.get("state", "idle")}
    for field in ("started_at", "finished_at", "last_error", "last_duration_ms"):
        payload[field] = _SCAN_STATUS.get(field)
    payload["run_count"] = _SCAN_STATUS.get("run_count", 0)
    payload["ttl_seconds"] = SNAPSHOT_TTL_SECONDS
    # A zero/absent timestamp is reported as "unknown" rather than 0.
    payload["snapshot_generated_at"] = generated_at or None
    payload["snapshot_age_seconds"] = (reference - generated_at) if generated_at else None
    payload["snapshot_stale"] = _is_snapshot_stale(cached, reference)
    return payload
def _tool_name_from_call(call: Any) -> Optional[str]:
    """Return the tool name of a tool-call dict, or ``None`` for non-dict input.

    A top-level ``"name"`` wins; otherwise the name nested under
    ``"function"`` is used.
    """
    if not isinstance(call, dict):
        return None
    direct = call.get("name")
    if direct:
        return direct
    nested = call.get("function") or {}
    return nested.get("name")


def _content(msg: Dict[str, Any]) -> str:
    """Return the message ``"content"`` as text.

    Strings pass through, ``None``/missing becomes ``""``, and anything else
    is JSON-encoded (falling back to ``str()`` when that fails).
    """
    body = msg.get("content")
    if isinstance(body, str):
        return body
    if body is None:
        return ""
    try:
        return json.dumps(body)
    except Exception:
        return str(body)


def _count_tool(tool_names: List[str], *needles: str) -> int:
    """Count the names that contain at least one needle (names are lower-cased first)."""
    hits = 0
    for raw in tool_names:
        candidate = raw.lower()
        if any(needle in candidate for needle in needles):
            hits += 1
    return hits


def model_provider(model_name: str) -> Optional[str]:
    """Infer a provider label from a model name.

    Resolution order: the segment before the first ``/``; else the first
    known provider keyword contained in the name (``gemini`` is reported as
    ``google``); else the leading token before ``:`` / ``-``. Empty names
    and the literal ``"none"`` give ``None``.
    """
    name = (model_name or "").strip().lower()
    if not name or name == "none":
        return None

    prefix, slash, _ = name.partition("/")
    if slash:
        return prefix

    known = (
        "openai", "anthropic", "google", "gemini", "mistral", "meta", "qwen",
        "deepseek", "xai", "nous", "ollama", "groq", "openrouter", "codex",
    )
    matched = next((provider for provider in known if provider in name), None)
    if matched is not None:
        return "google" if matched == "gemini" else matched

    return name.partition(":")[0].partition("-")[0]


def is_local_model_name(model_name: str) -> bool:
    """Return True when the model name contains a local/open-weight runtime marker."""
    name = (model_name or "").strip().lower()
    if not name or name == "none":
        return False
    markers = ("ollama", "llama.cpp", "localhost", "127.0.0.1", "local/", "local:", "gguf", "vllm-local")
    for marker in markers:
        if marker in name:
            return True
    return False
def analyze_messages(session_id: str, title: str, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Derive per-session achievement statistics from a session's messages.

    Tool usage is counted from ``tool_name`` fields and ``tool_calls``
    entries; the remaining ``*_events`` metrics are keyword/regex counts
    over the concatenated message text, so they are heuristic signals.

    Args:
        session_id: Identifier stored verbatim in the result.
        title: Session title; ``"Untitled session"`` is used when empty.
        messages: Message dicts; ``content``, ``role``, ``tool_name`` and
            ``tool_calls`` are the fields read here.

    Returns:
        A dict of counters plus ``tool_names`` (a set) and an empty
        ``model_names`` set for the caller to fill in.
    """
    tool_names: Set[str] = set()
    tool_sequence: List[str] = []
    files_touched: Set[str] = set()
    full_text_parts: List[str] = []
    error_count = 0

    for msg in messages:
        text = _content(msg)
        full_text_parts.append(text)
        if msg.get("tool_name"):
            name = str(msg["tool_name"])
            tool_names.add(name)
            # Tool result rows name the tool that already appeared in the assistant tool_calls.
            # Keep it for distinct-tool detection, but do not double-count it as a new call.
            if msg.get("role") != "tool":
                tool_sequence.append(name)
        for call in msg.get("tool_calls") or []:
            name = _tool_name_from_call(call)
            if name:
                tool_names.add(name)
                tool_sequence.append(name)
        if ERROR_RE.search(text):
            error_count += 1
        blob = text
        if msg.get("tool_calls"):
            # Include serialized call arguments so paths passed to tools are seen too.
            blob += " " + json.dumps(msg.get("tool_calls"), default=str)
        files_touched.update(FILE_RE.findall(blob))

    full_text = "\n".join(full_text_parts)
    lower = full_text.lower()

    def mentions(pattern: str) -> int:
        """Count case-insensitive matches of *pattern* in the whole session text."""
        return len(re.findall(pattern, full_text, re.I))

    # Whole-text checks that several metrics share; evaluate each once.
    has_error = bool(ERROR_RE.search(full_text))
    has_install = bool(INSTALL_RE.search(full_text))
    has_port_conflict = bool(PORT_RE.search(full_text))

    terminal_calls = _count_tool(tool_sequence, "terminal")
    web_calls = _count_tool(tool_sequence, "web_search", "web_extract")
    web_extract_calls = _count_tool(tool_sequence, "web_extract")
    browser_calls = _count_tool(tool_sequence, "browser")
    web_browser_calls = web_calls + browser_calls
    patch_calls = _count_tool(tool_sequence, "patch")
    file_reads_searches = _count_tool(tool_sequence, "read_file", "search_files")
    file_tool_calls = _count_tool(tool_sequence, "read_file", "write_file", "patch", "search_files")
    delegate_calls = _count_tool(tool_sequence, "delegate_task")
    process_calls = _count_tool(tool_sequence, "process") + mentions(r"background\s*=\s*true")
    cron_calls = _count_tool(tool_sequence, "cronjob")
    image_vision_calls = _count_tool(tool_sequence, "image", "vision")
    tts_calls = _count_tool(tool_sequence, "tts", "text_to_speech")
    skill_events = _count_tool(tool_sequence, "skill") + len(re.findall(r"\bskill", lower))
    skill_manage_events = _count_tool(tool_sequence, "skill_manage")
    memory_events = _count_tool(tool_sequence, "memory", "mnemosyne")
    memory_write_events = _count_tool(tool_sequence, "mnemosyne_remember", "memory")

    # The alternation is grouped so the word boundaries apply to every verb.
    # Ungrouped, ``\brestart|reload|kill|start\b`` matched "kill" inside
    # "skill" and "reload" inside "preloaded". Common inflections
    # (restarted, killing, reloads, ...) are still accepted.
    restart_re = r"\b(?:restart|reload|kill|start)(?:s|ed|ing)?\b"

    return {
        "session_id": session_id,
        "title": title or "Untitled session",
        "message_count": len(messages),
        "tool_call_count": len(tool_sequence),
        "tool_names": tool_names,
        "distinct_tool_count": len(tool_names),
        "error_count": error_count,
        "terminal_calls": terminal_calls,
        "web_calls": web_calls,
        "web_extract_calls": web_extract_calls,
        "browser_calls": browser_calls,
        "web_browser_calls": web_browser_calls,
        "patch_calls": patch_calls,
        "file_reads_searches": file_reads_searches,
        "file_tool_calls": file_tool_calls,
        "files_touched_count": len(files_touched),
        "delegate_calls": delegate_calls,
        "process_calls": process_calls,
        "cron_calls": cron_calls,
        "image_vision_calls": image_vision_calls,
        "tts_calls": tts_calls,
        "skill_events": skill_events,
        "skill_manage_events": skill_manage_events,
        "memory_events": memory_events,
        "memory_write_events": memory_write_events,
        "port_conflict": has_port_conflict,
        "port_conflict_events": 1 if has_port_conflict else 0,
        "traceback_events": mentions(r"traceback|exception"),
        "log_read_events": mentions(r"gateway\.log|errors\.log|agent\.log|/api/logs|\blogs\b"),
        "permission_denied_events": mentions(r"permission denied|eacces|operation not permitted"),
        "install_error_events": 1 if has_install and has_error else 0,
        "install_success_events": 1 if has_install and SUCCESS_RE.search(full_text) else 0,
        "restart_after_error_events": 1 if error_count and re.search(restart_re, full_text, re.I) else 0,
        "env_var_error_events": mentions(r"missing .*env|api key|environment variable|not configured|unauthorized|auth"),
        "yaml_error_events": mentions(r"yaml|yml|colon|parse error") if has_error else 0,
        "docker_conflict_events": mentions(r"docker.*(name|container).*already|container name conflict|Conflict\. The container"),
        "frontend_activity_events": mentions(r"\.(css|svg|tsx|jsx)|frontend|tailwind|react"),
        "css_activity_events": mentions(r"\.css|tailwind|style|className|visual"),
        "git_events": mentions(r"\bgit\s+(commit|push|merge|rebase|status|diff)"),
        "tiny_patch_after_errors_events": 1 if error_count >= 5 and re.search(r"one character|single character|typo", full_text, re.I) else 0,
        "context_events": mentions(r"compress|context window|token|cache"),
        "gateway_events": mentions(r"gateway|discord|telegram|slack|api_server"),
        "plugin_events": mentions(r"plugin|dashboard-plugins|__HERMES_PLUGIN|manifest\.json"),
        "rollback_events": mentions(r"rollback|checkpoint"),
        "docs_activity_events": mentions(r"docs|documentation|docusaurus|README"),
        "model_events": mentions(r"model|provider|openrouter|codex|gemini|claude|anthropic|openai|mistral|qwen|deepseek|llama|ollama|vllm|gguf"),
        "openrouter_events": mentions(r"openrouter"),
        "codex_events": mentions(r"codex"),
        "claude_events": mentions(r"claude|anthropic"),
        "gemini_events": mentions(r"gemini|google ai|google model"),
        "local_model_events": mentions(r"ollama|llama\.cpp|gguf|vllm|local model|open[- ]weight|open weights"),
        "toolset_events": mentions(r"toolset|enabled_toolsets|browser tool|terminal tool|file tool|web tool"),
        "config_events": mentions(r"config\.ya?ml|\b[a-z0-9_-]+config\.(?:js|ts|json|ya?ml)|\.env(?:\b|\.)|manifest\.json|settings\.json|pyproject\.toml|package\.json"),
        "git_history_events": mentions(r"\bgit\s+(rebase|merge|fetch|pull|push|tag|checkout)|merge conflict|conflict\s*\(|rebase --continue"),
        "test_events": mentions(r"pytest|unittest|vitest|playwright|npm test|pnpm test|node --check|py_compile|tests? passed|\bOK\b"),
        "screenshot_events": mentions(r"screenshot|playwright|vision_analyze|browser_vision|\.png|image data"),
        "release_events": mentions(r"\bgit\s+tag|release|version bump|changelog|publish|pushed? tag"),
        "cache_events": mentions(r"cache hit|prompt caching|cache_read"),
        "model_names": set(),
    }


def evaluate_tiered(definition: Dict[str, Any], aggregate: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate a tier-ladder achievement against the aggregate metrics.

    Args:
        definition: Achievement definition with ``threshold_metric`` and an
            optional ``tiers`` list of ``{"name", "threshold"}`` dicts.
        aggregate: Metric name -> value mapping; missing or falsy values
            count as 0.

    Returns:
        A dict with ``unlocked``, ``discovered``, ``state``, ``tier``,
        ``progress``, ``next_tier``, ``next_threshold`` and ``progress_pct``.
        ``progress_pct`` is the progress from the current tier toward the
        next one, capped at 99 until the top tier is reached (then 100).
    """
    metric = definition["threshold_metric"]
    progress = int(aggregate.get(metric, 0) or 0)
    tiers_list = sorted(definition.get("tiers", []), key=lambda t: t["threshold"])
    achieved = [t for t in tiers_list if progress >= t["threshold"]]
    next_tiers = [t for t in tiers_list if progress < t["threshold"]]
    tier = achieved[-1]["name"] if achieved else None
    next_tier = next_tiers[0]["name"] if next_tiers else None
    # With nothing left to reach, report the top threshold (or 1 for an empty ladder).
    next_threshold = next_tiers[0]["threshold"] if next_tiers else (tiers_list[-1]["threshold"] if tiers_list else 1)
    current_threshold = achieved[-1]["threshold"] if achieved else 0
    denom = max(1, next_threshold - current_threshold)
    pct = 100 if not next_tiers and achieved else max(0, min(99, math.floor(((progress - current_threshold) / denom) * 100)))
    unlocked = bool(achieved)
    discovered = bool(progress > 0)
    # Secret achievements stay hidden until the metric shows any progress.
    state = "unlocked" if unlocked else ("secret" if definition.get("secret") and not discovered else "discovered")
    return {
        "unlocked": unlocked,
        "discovered": discovered or not definition.get("secret"),
        "state": state,
        "tier": tier,
        "progress": progress,
        "next_tier": next_tier,
        "next_threshold": next_threshold,
        "progress_pct": pct,
    }
def evaluate_requirements(definition: Dict[str, Any], aggregate: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate an achievement that needs several metrics to reach a minimum.

    Each entry of ``definition["requirements"]`` names a ``metric`` and an
    optional ``gte`` threshold (default 1). The achievement unlocks when all
    requirements are met; ``progress`` is the mean per-requirement completion
    as a whole percentage, and ``progress_pct`` is capped at 99 until unlocked.
    """
    is_secret = definition.get("secret")
    requirements = definition.get("requirements", [])
    if not requirements:
        # Nothing to measure: report a locked, zero-progress achievement.
        return {
            "unlocked": False,
            "discovered": not is_secret,
            "state": "secret" if is_secret else "discovered",
            "tier": None,
            "progress": 0,
            "next_tier": None,
            "next_threshold": 1,
            "progress_pct": 0,
        }

    measured = [
        (int(aggregate.get(req["metric"], 0) or 0), int(req.get("gte", 1)))
        for req in requirements
    ]
    any_progress = any(value > 0 for value, _ in measured)
    complete = all(value >= needed for value, needed in measured)
    ratios = [min(1.0, value / max(1, needed)) for value, needed in measured]
    pct = math.floor((sum(ratios) / len(ratios)) * 100)

    if complete:
        state = "unlocked"
    elif is_secret and not any_progress:
        state = "secret"
    else:
        state = "discovered"

    return {
        "unlocked": complete,
        "discovered": any_progress or not is_secret,
        "state": state,
        "tier": None,
        "progress": pct,
        "next_tier": None,
        "next_threshold": 100,
        "progress_pct": 100 if complete else min(99, pct),
    }
def evaluate_boolean(definition: Dict[str, Any], aggregate: Dict[str, Any]) -> Dict[str, Any]:
    """Evaluate a simple on/off achievement keyed by ``definition["metric"]``."""
    # Backward-compatible helper for old tests/definitions. New catalog avoids simple booleans.
    unlocked = bool(aggregate.get(definition["metric"]))
    if unlocked:
        state, progress, pct = "unlocked", 1, 100
    else:
        state, progress, pct = "discovered", 0, 0
    return {
        "unlocked": unlocked,
        "discovered": True,
        "state": state,
        "tier": None,
        "progress": progress,
        "next_tier": None,
        "next_threshold": 1,
        "progress_pct": pct,
    }


# Human-readable phrase for each metric, used when rendering criteria text.
METRIC_LABELS = {
    "max_tool_calls_in_session": "tool calls in one session",
    "max_distinct_tools_in_session": "distinct Hermes tools used in one session",
    "max_terminal_calls_in_session": "terminal calls in one session",
    "max_file_tool_calls_in_session": "file/search/patch calls in one session",
    "max_web_browser_calls_in_session": "web search/extract or browser calls in one session",
    "max_messages_in_session": "messages in one session",
    "max_files_touched_in_session": "files touched in one session",
    "total_delegate_calls": "lifetime delegate_task calls",
    "total_process_calls": "lifetime background process operations",
    "total_cron_calls": "lifetime scheduled-job operations",
    "total_errors": "error/failed/traceback messages observed",
    "traceback_events": "traceback or exception mentions",
    "log_read_events": "log inspections",
    "port_conflict_events": "dev-server port conflict detections",
    "permission_denied_events": "permission-denied errors",
    "install_error_events": "package-install failures",
    "install_success_events": "successful package installs after package work",
    "restart_after_error_events": "restart/reload actions after error clusters",
    "env_var_error_events": "missing auth/config/environment-variable events",
    "yaml_error_events": "YAML/config parse incidents",
    "docker_conflict_events": "Docker/container-name conflicts",
    "frontend_activity_events": "frontend/CSS/SVG/React activity mentions",
    "css_activity_events": "CSS, styling, Tailwind, or className activity",
    "git_events": "git workflow commands",
    "tiny_patch_after_errors_events": "tiny typo-style fixes after error clusters",
    "skill_events": "Hermes skill mentions or tool use",
    "skill_manage_events": "skill_manage create/patch/delete operations",
    "memory_events": "memory or Mnemosyne tool events",
    "memory_write_events": "durable memory writes",
    "context_events": "context, compression, token, or cache-pressure mentions",
    "gateway_events": "gateway/API/chat-platform activity",
    "plugin_events": "dashboard plugin development or usage signals",
    "rollback_events": "rollback/checkpoint recovery mentions",
    "docs_activity_events": "documentation/README/docs activity",
    "model_events": "model/provider-related activity",
    "openrouter_events": "OpenRouter mentions",
    "codex_events": "Codex mentions",
    "cache_events": "prompt-cache/cache-hit mentions",
    "total_web_calls": "lifetime web_search/web_extract calls",
    "total_web_extract_calls": "lifetime web_extract calls",
    "browser_calls": "lifetime browser automation calls",
    "total_tool_calls": "lifetime Hermes tool calls",
    "total_terminal_calls": "lifetime terminal calls",
    "total_patch_calls": "lifetime targeted patch edits",
    "total_file_reads_searches": "lifetime read_file/search_files calls",
    "image_vision_calls": "image generation or vision tool calls",
    "tts_calls": "text-to-speech or voice tool calls",
    "distinct_model_count": "distinct model names seen in session metadata",
    "distinct_provider_count": "distinct model providers inferred from session metadata",
    "claude_events": "Claude/Anthropic model mentions",
    "gemini_events": "Gemini/Google model mentions",
    "local_model_events": "local/open-weight model mentions",
    "local_model_chat_sessions": "Hermes sessions whose model metadata is local/open-weight",
    "toolset_events": "toolset or tool-family mentions",
    "config_events": "configuration/environment/manifest activity",
    "git_history_events": "git history operations such as rebase, merge, fetch, push, or tag",
    "test_events": "test/check/verification command mentions",
    "screenshot_events": "screenshot, Playwright, PNG, or vision-inspection activity",
    "release_events": "release, version, publish, or git tag events",
    "session_count": "Hermes sessions",
    "weekend_sessions": "sessions started on weekends",
    "night_sessions": "sessions started late night or before dawn",
}


def metric_label(metric: str) -> str:
    """Return the display label for *metric*, defaulting to the name with underscores as spaces."""
    try:
        return METRIC_LABELS[metric]
    except KeyError:
        return metric.replace("_", " ")


def criteria_for(definition: Dict[str, Any]) -> str:
    """Describe, in one sentence, what unlocks the given achievement.

    Still-secret achievements get a generic teaser. Tiered achievements list
    their metric and tier ladder, requirement-based ones list each
    ``metric ≥ threshold`` pair, and anything else gets a generic sentence.
    """
    if definition.get("secret") and definition.get("state") == "secret":
        return "Secret: exact requirement hidden until Hermes sees the first matching signal. Keep using Hermes across debugging, tools, memory, skills, plugins, and model workflows to reveal it."

    if "threshold_metric" in definition:
        ordered_tiers = sorted(definition.get("tiers", []), key=lambda t: t["threshold"])
        if not ordered_tiers:
            return "Requirement: use Hermes in the matching workflow."
        metric = metric_label(definition["threshold_metric"])
        rungs = [f"{t['name']} {t['threshold']}" for t in ordered_tiers]
        ladder = ", ".join(rungs)
        return f"Requirement: {metric}. Tier ladder: {ladder}."

    requirements = definition.get("requirements") or []
    if not requirements:
        return "Requirement: complete the matching Hermes behavior."
    clauses = []
    for r in requirements:
        clauses.append(f"{metric_label(r['metric'])} ≥ {int(r.get('gte', 1))}")
    return "Requirement: " + "; ".join(clauses) + "."
def display_achievement(item: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *item* prepared for display, with ``criteria`` text added.

    Achievements whose ``state`` is ``"secret"`` additionally have their
    name, description and icon replaced with placeholders. The input dict is
    not modified.
    """
    clean = dict(item)
    if clean.get("state") == "secret":
        return {**clean, "name": "???", "description": "Secret achievement: hidden until Hermes detects the first relevant behavior in your session history.", "criteria": criteria_for(clean), "icon": "secret"}
    clean["criteria"] = criteria_for(clean)
    return clean


def scan_sessions(
    limit: Optional[int] = None,
    progress_callback: Optional[Any] = None,
    progress_every: int = 250,
) -> Dict[str, Any]:
    """Scan Hermes sessions and build per-session achievement stats.

    ``limit=None`` (the default) scans the ENTIRE session history. Prior
    versions capped this at 200, which silently reduced achievement totals
    to ~2% of history on long-running installs and made lifetime badges
    unreachable. SQLite's ``LIMIT -1`` means "unlimited"; we map ``None``
    and non-positive values to ``-1`` so callers get the full catalog.

    Warm scans stay cheap: the checkpoint stores per-session stats together
    with the session's fingerprint (``session_fingerprint``: last_active,
    started_at, model, title) and only sessions whose fingerprint changed
    are re-analyzed. Cold scans on large histories (thousands of sessions)
    take tens of seconds to several minutes; ``evaluate_all`` runs them on a
    background thread so the dashboard UI never blocks on the first request.

    ``progress_callback(partial_sessions, scanned_so_far, total)`` — when
    provided, fires every ``progress_every`` rows (but not on the final
    row) with a copy of the sessions analyzed so far and progress counters.
    Background scans use this to publish intermediate snapshots so a long
    cold scan surfaces badges incrementally on each dashboard refresh
    instead of going all-at-once at the end. Exceptions raised by the
    callback are swallowed.

    Returns a dict with ``sessions``, ``aggregate`` and ``scan_meta``. When
    ``hermes_state.SessionDB`` cannot be imported, an empty result carrying
    an ``error`` message is returned instead of raising.
    """
    try:
        from hermes_state import SessionDB
    except Exception as exc:
        return {"sessions": [], "aggregate": {}, "error": f"Could not import SessionDB: {exc}", "scan_meta": {"mode": "failed", "sessions_total": 0, "sessions_rescanned": 0, "sessions_reused": 0}}

    checkpoint = load_checkpoint()
    previous_sessions = checkpoint.get("sessions") if isinstance(checkpoint.get("sessions"), dict) else {}
    reused = 0
    rescanned = 0

    # SQLite treats LIMIT -1 as "no limit". Map None / <=0 to -1 so the
    # full session history flows through unless the caller explicitly
    # requests a small sample (e.g. a smoke test).
    db_limit = -1 if (limit is None or limit <= 0) else int(limit)

    db = SessionDB()
    try:
        sessions_meta = db.list_sessions_rich(limit=db_limit, include_children=True, project_compression_tips=False)
        total_sessions = len(sessions_meta)
        sessions: List[Dict[str, Any]] = []
        checkpoint_sessions: Dict[str, Any] = {}
        for idx, meta in enumerate(sessions_meta, start=1):
            sid = meta.get("id")
            if not sid:
                continue
            fp = session_fingerprint(meta)
            cached = previous_sessions.get(sid) if isinstance(previous_sessions, dict) else None
            cached_stats = cached.get("stats") if isinstance(cached, dict) else None
            cached_fp = cached.get("fingerprint") if isinstance(cached, dict) else None

            # Reuse checkpointed stats only when the fingerprint is unchanged.
            if isinstance(cached_stats, dict) and cached_fp == fp:
                stats = dict(cached_stats)
                reused += 1
            else:
                messages = db.get_messages(sid)
                stats = analyze_messages(sid, meta.get("title") or meta.get("preview") or "Untitled", messages)
                rescanned += 1

            # Metadata-derived fields are refreshed on every scan, cached or not.
            stats["session_id"] = sid
            stats["title"] = meta.get("title") or meta.get("preview") or stats.get("title") or "Untitled"
            stats["started_at"] = meta.get("started_at")
            stats["last_active"] = meta.get("last_active")
            stats["source"] = meta.get("source")
            if meta.get("model"):
                stats.setdefault("model_names", set())
                # Freshly analyzed stats hold a set; stats restored from the
                # JSON checkpoint hold a list (see ``_json_safe``).
                if isinstance(stats["model_names"], set):
                    stats["model_names"].add(str(meta.get("model")))
                elif isinstance(stats["model_names"], list):
                    if str(meta.get("model")) not in stats["model_names"]:
                        stats["model_names"].append(str(meta.get("model")))
                else:
                    stats["model_names"] = {str(meta.get("model"))}

            sessions.append(stats)
            checkpoint_sessions[sid] = {"fingerprint": fp, "stats": _json_safe(stats)}

            if progress_callback is not None and progress_every > 0 and (idx % progress_every == 0) and idx < total_sessions:
                try:
                    progress_callback(list(sessions), idx, total_sessions)
                except Exception:
                    # Progress callbacks are advisory — a broken publisher
                    # must never abort the scan itself.
                    pass

        # The checkpoint is rewritten from this scan's rows only, so entries
        # for sessions that were not returned this time are dropped.
        save_checkpoint({
            "schema_version": 1,
            "generated_at": int(time.time()),
            "sessions": checkpoint_sessions,
        })
    finally:
        # Close the DB handle when the object exposes ``close``.
        close = getattr(db, "close", None)
        if close:
            close()
    return {
        "sessions": sessions,
        "aggregate": aggregate_stats(sessions),
        "scan_meta": {
            "mode": "incremental" if reused > 0 else "full",
            "sessions_total": len(sessions),
            "sessions_rescanned": rescanned,
            "sessions_reused": reused,
            "sessions_scanned_so_far": len(sessions),
            "sessions_expected_total": total_sessions,
        },
    }
0, + "distinct_provider_count": 0, + "local_model_chat_sessions": 0, + "weekend_sessions": 0, + "night_sessions": 0, + } + sum_keys = [ + "traceback_events", "log_read_events", "port_conflict_events", "permission_denied_events", "install_error_events", "install_success_events", "restart_after_error_events", "env_var_error_events", "yaml_error_events", "docker_conflict_events", "frontend_activity_events", "css_activity_events", "git_events", "tiny_patch_after_errors_events", "skill_events", "skill_manage_events", "memory_events", "memory_write_events", "context_events", "gateway_events", "plugin_events", "rollback_events", "docs_activity_events", "model_events", "openrouter_events", "codex_events", "claude_events", "gemini_events", "local_model_events", "toolset_events", "config_events", "git_history_events", "test_events", "screenshot_events", "release_events", "cache_events", + ] + for key in sum_keys: + agg[key] = 0 + + model_names: Set[str] = set() + provider_names: Set[str] = set() + for s in sessions: + agg["max_tool_calls_in_session"] = max(agg["max_tool_calls_in_session"], s.get("tool_call_count", 0)) + agg["max_distinct_tools_in_session"] = max(agg["max_distinct_tools_in_session"], s.get("distinct_tool_count", 0)) + agg["max_messages_in_session"] = max(agg["max_messages_in_session"], s.get("message_count", 0)) + agg["max_terminal_calls_in_session"] = max(agg["max_terminal_calls_in_session"], s.get("terminal_calls", 0)) + agg["max_file_tool_calls_in_session"] = max(agg["max_file_tool_calls_in_session"], s.get("file_tool_calls", 0)) + agg["max_web_calls_in_session"] = max(agg["max_web_calls_in_session"], s.get("web_calls", 0)) + agg["max_web_browser_calls_in_session"] = max(agg["max_web_browser_calls_in_session"], s.get("web_browser_calls", 0)) + agg["max_files_touched_in_session"] = max(agg["max_files_touched_in_session"], s.get("files_touched_count", 0)) + agg["total_errors"] += s.get("error_count", 0) + agg["total_tool_calls"] += s.get("tool_call_count", 
0) + agg["total_terminal_calls"] += s.get("terminal_calls", 0) + agg["total_web_calls"] += s.get("web_calls", 0) + agg["total_web_extract_calls"] += s.get("web_extract_calls", 0) + agg["total_patch_calls"] += s.get("patch_calls", 0) + agg["total_file_reads_searches"] += s.get("file_reads_searches", 0) + agg["total_delegate_calls"] += s.get("delegate_calls", 0) + agg["total_process_calls"] += s.get("process_calls", 0) + agg["total_cron_calls"] += s.get("cron_calls", 0) + agg["browser_calls"] += s.get("browser_calls", 0) + agg["image_vision_calls"] += s.get("image_vision_calls", 0) + agg["tts_calls"] += s.get("tts_calls", 0) + for key in sum_keys: + agg[key] += s.get(key, 0) + model_names.update(s.get("model_names") or set()) + session_models = s.get("model_names") or set() + for model_name in session_models: + provider = model_provider(str(model_name)) + if provider: + provider_names.add(provider) + if any(is_local_model_name(str(model_name)) for model_name in session_models): + agg["local_model_chat_sessions"] += 1 + if s.get("started_at"): + try: + lt = time.localtime(float(s.get("started_at"))) + if lt.tm_wday >= 5: + agg["weekend_sessions"] += 1 + if lt.tm_hour < 6 or lt.tm_hour >= 23: + agg["night_sessions"] += 1 + except Exception: + pass + agg["distinct_model_count"] = len({m for m in model_names if m and m != "None"}) + agg["distinct_provider_count"] = len(provider_names) + return agg + + +def evaluate_definition(definition: Dict[str, Any], aggregate: Dict[str, Any]) -> Dict[str, Any]: + if "threshold_metric" in definition: + return evaluate_tiered(definition, aggregate) + if "requirements" in definition: + return evaluate_requirements(definition, aggregate) + return evaluate_boolean(definition, aggregate) + + +def evidence_for(definition: Dict[str, Any], sessions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if not sessions: + return None + metric = definition.get("threshold_metric") + metric_to_session_key = { + "max_tool_calls_in_session": 
def _compute_from_scan(scan: Dict[str, Any], *, is_partial: bool = False) -> Dict[str, Any]:
    """Turn a scan result into the full achievements payload.

    Shared by ``compute_all`` (finished scans) and by the background
    progress callback (in-flight snapshots). With ``is_partial=True`` the
    persisted unlock state is neither read nor written, so a half-finished
    scan can never stamp an "unlocked_at" time that the completed scan
    might have placed differently.
    """
    persist = not is_partial
    aggregate = scan.get("aggregate", {})
    sessions = scan.get("sessions", [])
    state = load_state() if persist else {"unlocks": {}}
    unlocks = state.setdefault("unlocks", {})
    now = int(time.time())

    evaluated = []
    for definition in ACHIEVEMENTS:
        result = evaluate_definition(definition, aggregate)
        unlock_id = definition["id"]
        is_unlocked = result["unlocked"]

        # First time this achievement is seen unlocked on a complete scan:
        # record when, at which tier, and which session proves it.
        if persist and is_unlocked and unlock_id not in unlocks:
            unlocks[unlock_id] = {
                "unlocked_at": now,
                "first_tier": result.get("tier"),
                "evidence": evidence_for(definition, sessions),
            }

        item = {**definition, **result}
        if is_unlocked:
            record = unlocks.get(unlock_id, {})
            item["unlocked_at"] = record.get("unlocked_at")
            item["evidence"] = record.get("evidence") or evidence_for(definition, sessions)
        evaluated.append(display_achievement(item))

    if persist:
        save_state(state)

    unlocked_count = sum(1 for a in evaluated if a["unlocked"])
    discovered_count = sum(1 for a in evaluated if a.get("state") == "discovered")
    secret_count = sum(1 for a in evaluated if a.get("state") == "secret")
    return {
        "achievements": evaluated,
        "sessions": sessions,
        "aggregate": aggregate,
        "scan_meta": scan.get("scan_meta", {}),
        "error": scan.get("error"),
        "unlocked_count": unlocked_count,
        "discovered_count": discovered_count,
        "secret_count": secret_count,
        "total_count": len(evaluated),
        "generated_at": now,
    }
def _run_scan_and_update_cache(publish_partial_snapshots: bool = True) -> None:
    """Execute a scan + snapshot update. Called synchronously or from a thread.

    When ``publish_partial_snapshots=True`` (the default for background
    scans), the scanner periodically publishes an in-progress snapshot to
    ``_SNAPSHOT_CACHE`` so each dashboard refresh during a long cold scan
    shows more progress. Synchronous /rescan callers pass ``False`` because
    they block on the full result anyway.

    The whole run happens under ``_SCAN_LOCK`` and its outcome is recorded
    in ``_SCAN_STATUS`` (``state``, ``started_at``, ``finished_at``,
    ``last_error``, ``last_duration_ms``, ``run_count``). Scan failures are
    captured in the status dict rather than raised.
    """
    global _SNAPSHOT_CACHE, _SNAPSHOT_CACHE_AT
    with _SCAN_LOCK:
        started = int(time.time())
        # Duration is measured on the monotonic clock: subtracting two
        # whole-second wall-clock stamps can only yield multiples of
        # 1000 ms (0 for a fast warm scan) and is skewed by clock changes.
        started_monotonic = time.monotonic()
        _SCAN_STATUS["state"] = "running"
        _SCAN_STATUS["started_at"] = started
        _SCAN_STATUS["last_error"] = None

        def _publish_partial(partial_sessions, scanned_so_far, total):
            global _SNAPSHOT_CACHE, _SNAPSHOT_CACHE_AT
            try:
                partial_scan = {
                    "sessions": partial_sessions,
                    "aggregate": aggregate_stats(partial_sessions),
                    "scan_meta": {
                        "mode": "in_progress",
                        "sessions_total": scanned_so_far,
                        "sessions_rescanned": 0,
                        "sessions_reused": 0,
                        "sessions_scanned_so_far": scanned_so_far,
                        "sessions_expected_total": total,
                    },
                }
                partial = _compute_from_scan(partial_scan, is_partial=True)
                # Keep the cache in the 'stale' TTL regime by NOT bumping
                # _SNAPSHOT_CACHE_AT to "now". The UI treats partial
                # results as stale so it keeps polling /scan-status and
                # sees the final snapshot when the scan finishes. In-flight
                # partials are visible but are never mistaken for finished.
                _SNAPSHOT_CACHE = _json_safe(partial)
                _SNAPSHOT_CACHE_AT = 0
            except Exception:
                # Intermediate publication is best-effort; don't kill the scan.
                pass

        callback = _publish_partial if publish_partial_snapshots else None
        try:
            computed = compute_all(progress_callback=callback)
            _SNAPSHOT_CACHE = _json_safe(computed)
            _SNAPSHOT_CACHE_AT = int(_SNAPSHOT_CACHE.get("generated_at") or int(time.time()))
            save_snapshot(_SNAPSHOT_CACHE)
            _SCAN_STATUS["state"] = "idle"
        except Exception as exc:
            # Surface the failure through /scan-status instead of raising,
            # so a background thread dies quietly and /rescan can fall
            # back to whatever snapshot is already cached.
            _SCAN_STATUS["state"] = "failed"
            _SCAN_STATUS["last_error"] = str(exc)
        finally:
            _SCAN_STATUS["finished_at"] = int(time.time())
            _SCAN_STATUS["last_duration_ms"] = int((time.monotonic() - started_monotonic) * 1000)
            _SCAN_STATUS["run_count"] = int(_SCAN_STATUS.get("run_count", 0)) + 1
def evaluate_all(force: bool = False) -> Dict[str, Any]:
    """Return the current achievements payload without blocking when possible.

    Resolution order:

    * Fresh in-memory cache (and not ``force``) -> returned as-is.
    * Nothing in memory -> try the on-disk snapshot first, so a freshly
      started process can serve data before any scan has run.
    * ``force=True`` (manual /rescan) -> scan synchronously, then return
      the cache, or a "pending" placeholder if the scan failed and there
      was no earlier snapshot.
    * Cached but stale -> kick a background rescan and return the stale
      data immediately.
    * No snapshot at all (first-ever run) -> kick a background scan and
      return an empty-but-valid "pending" payload for the UI to poll on.
    """
    global _SNAPSHOT_CACHE, _SNAPSHOT_CACHE_AT
    now = int(time.time())

    if not force and _cache_is_fresh(now):
        return _SNAPSHOT_CACHE or {}

    # Cold process: hydrate the in-memory cache from the persisted snapshot.
    if _SNAPSHOT_CACHE is None:
        on_disk = load_snapshot()
        if isinstance(on_disk, dict):
            stamp = int(on_disk.get("generated_at") or 0)
            _SNAPSHOT_CACHE, _SNAPSHOT_CACHE_AT = on_disk, (stamp or now)

    if force:
        # The caller is waiting for the final result, so partial
        # publishing would be pointless here.
        _run_scan_and_update_cache(publish_partial_snapshots=False)
        return _SNAPSHOT_CACHE if _SNAPSHOT_CACHE is not None else _build_pending_snapshot(now)

    if _SNAPSHOT_CACHE is None:
        # First-ever run: nothing to serve yet. The UI polls /scan-status
        # and re-fetches /achievements once the scan completes.
        _start_background_scan()
        return _build_pending_snapshot(now)

    # Serve what we have; refresh behind the scenes if it has expired.
    if not _cache_is_fresh(now):
        _start_background_scan()
    return _SNAPSHOT_CACHE
{"session_id": session_id, "badges": []} + aggregate = aggregate_stats([session]) + badges = [] + for definition in ACHIEVEMENTS: + result = evaluate_definition(definition, aggregate) + if result["unlocked"]: + badges.append(display_achievement({**definition, **result})) + return {"session_id": session_id, "badges": badges} + + +@router.post("/rescan") +async def rescan(): + return {"ok": True, **evaluate_all(force=True)} + + +@router.post("/reset-state") +async def reset_state(): + global _SNAPSHOT_CACHE, _SNAPSHOT_CACHE_AT + save_state({"unlocks": {}}) + _SNAPSHOT_CACHE = None + _SNAPSHOT_CACHE_AT = 0 + _SCAN_STATUS["state"] = "idle" + _SCAN_STATUS["started_at"] = None + _SCAN_STATUS["finished_at"] = None + _SCAN_STATUS["last_error"] = None + _SCAN_STATUS["last_duration_ms"] = None + try: + snapshot_path().unlink(missing_ok=True) + except Exception: + pass + try: + checkpoint_path().unlink(missing_ok=True) + except Exception: + pass + return {"ok": True} diff --git a/plugins/hermes-achievements/docs/achievements-performance-implementation-plan.md b/plugins/hermes-achievements/docs/achievements-performance-implementation-plan.md new file mode 100644 index 0000000000..76336b9d2a --- /dev/null +++ b/plugins/hermes-achievements/docs/achievements-performance-implementation-plan.md @@ -0,0 +1,157 @@ +# Hermes Achievements Performance Implementation Plan + +Status: Ready for execution after hackathon review window +Constraint: Plugin remains frozen until judging is complete +Decision: `/overview` and top-banner slots are out of scope and will be removed. + +--- + +## Phase 0 — Baseline & Safety (no behavior change) + +### Task 0.1: Add perf benchmark script (local) +Objective: Repro baseline before/after. + +Acceptance: +- Can print endpoint timings for `/achievements` (3 runs each, cold + warm). + +### Task 0.2: Define acceptance thresholds +Objective: Lock success criteria now. 
+ +Acceptance: +- Documented SLOs: + - `/achievements` p95 < 1s (cached) + - max active scan jobs = 1 + +--- + +## Phase 1 — Remove unused overview/slot surface (highest certainty) + +### Task 1.1: Remove `/overview` backend route +Objective: Eliminate duplicate heavy endpoint path. + +Acceptance: +- `plugin_api.py` no longer exposes `/overview`. + +### Task 1.2: Remove slot registration and SummarySlot frontend code +Objective: Remove cross-tab banner fetch behavior. + +Acceptance: +- No `registerSlot(..."sessions:top"...)` or `registerSlot(..."analytics:top"...)`. +- No frontend call to `api("/overview")`. + +### Task 1.3: Update plugin manifest +Objective: Reflect final UI scope. + +Acceptance: +- `manifest.json` removes `slots` declarations. +- Tab registration remains intact. + +--- + +## Phase 2 — Shared snapshot persistence + single-flight for `/achievements` + +### Task 2.1: Introduce snapshot store abstraction + on-disk persistence +Objective: Single source of truth for Achievements data that survives process restarts. + +Acceptance: +- One structure contains dataset consumed by `/achievements`. +- Repeated requests do not recompute when cache is fresh. +- Snapshot persisted at `~/.hermes/plugins/hermes-achievements/scan_snapshot.json`. + +### Task 2.2: Single-flight scan coordinator +Objective: Prevent concurrent recomputes. + +Acceptance: +- Simultaneous requests result in one compute run. + +### Task 2.3: Refactor `/achievements` to read snapshot +Objective: Remove direct repeated compute from request path. + +Acceptance: +- `/achievements` does not run independent full recompute per request when cache is valid. + +--- + +## Phase 3 — Stale-While-Revalidate + +### Task 3.1: TTL state (`FRESH`/`STALE`) +Objective: Serve immediately when stale, refresh in background. + +Acceptance: +- Cached response returned quickly even when expired. +- Refresh is asynchronous. 
+ +### Task 3.2: Add `scan-status` endpoint (optional) +Objective: Let UI/ops inspect scan state. + +Acceptance: +- Returns state, last success time, last duration, last error. + +### Task 3.3: Add metadata fields to `/achievements` +Objective: Improve transparency. + +Acceptance: +- Response includes `generated_at`, `is_stale`, maybe `scan_id`. + +--- + +## Phase 4 — Incremental Scanning (optional but recommended) + +### Task 4.1: Add per-session checkpoint file +Objective: Track session-level changes, not just global scan time. + +Acceptance: +- Checkpoint persisted at `~/.hermes/plugins/hermes-achievements/scan_checkpoint.json`. +- For each session: `session_id`, fingerprint (`updated_at`/message_count/hash), and cached contribution. + +### Task 4.2: Incremental aggregation +Objective: Recompute only changed/new sessions and reuse unchanged contributions. + +Acceptance: +- Typical refresh time drops materially below full scan. +- Aggregate rebuild uses: subtract old contribution + add new contribution for changed sessions. + +### Task 4.3: Full rebuild fallback +Objective: Preserve correctness. + +Acceptance: +- Manual full rescan always possible. +- Schema/version changes invalidate checkpoint and force full rebuild. + +--- + +## Test Plan + +1. Unit tests +- Snapshot lifecycle transitions +- Dedupe logic under parallel requests +- `/achievements` response compatibility + +2. Integration tests +- Opening Achievements repeatedly causes <=1 heavy scan while in-flight +- `/achievements` warm-cache load is fast +- manual rescan updates snapshot and timestamps + +3. Manual benchmarks +- Compare pre/post `/achievements` timings with same history dataset + +--- + +## Rollout Plan + +1. Release internal branch with Phase 1 (remove overview/slots). +2. Validate no UI regression in Achievements tab. +3. Add Phase 2 snapshot/dedupe. +4. Add Phase 3 stale-while-revalidate + status metadata. +5. Optional: incremental scanner. 
+ +Rollback: keep old compute path behind temporary feature flag for one release window. + +--- + +## Definition of Done + +- Achievements tab remains fully functional (counts, latest, tiers, cards, filters). +- No `/overview` endpoint or slot calls remain. +- Repeated Achievements loads feel immediate after warm cache. +- Metrics/unlocks remain unchanged versus baseline. diff --git a/plugins/hermes-achievements/docs/achievements-performance-implementation-spec.md b/plugins/hermes-achievements/docs/achievements-performance-implementation-spec.md new file mode 100644 index 0000000000..b6574d9831 --- /dev/null +++ b/plugins/hermes-achievements/docs/achievements-performance-implementation-spec.md @@ -0,0 +1,219 @@ +# Hermes Achievements Implementation Spec (Detailed) + +This document is implementation-facing detail to execute the performance refactor later. + +Decision scope: keep only Achievements tab flow; remove `/overview` + top-banner slot integration. + +--- + +## A) Current Behavior Summary + +- `evaluate_all()` performs: + - full `scan_sessions()` + - `SessionDB.list_sessions_rich(...)` + - `db.get_messages(session_id)` for each session + - text/tool regex analysis + aggregation + evaluation +- `/overview` and `/achievements` both currently call `evaluate_all()` directly. +- slot calls (`sessions:top`, `analytics:top`) currently invoke `/overview`. + +Consequence: repeated full recomputes and contention. + +--- + +## B) De-scope/Removal Changes + +1. Remove backend route: +- `GET /overview` + +2. Remove frontend slot usage: +- `SummarySlot` component +- `registerSlot("sessions:top")` +- `registerSlot("analytics:top")` + +3. Remove manifest slot declarations: +- `"slots": ["sessions:top", "analytics:top"]` + +4. 
Keep: +- tab route/page for Achievements +- `/achievements` endpoint and full tab rendering + +--- + +## C) Target Internal Interfaces + +### 1) `SnapshotStore` +Responsibilities: +- hold latest computed snapshot in memory +- persist/load snapshot from disk +- expose age and staleness checks + +Storage path: +- `~/.hermes/plugins/hermes-achievements/scan_snapshot.json` + +Methods (conceptual): +- `get()` -> snapshot | null +- `set(snapshot)` +- `is_stale(ttl_seconds)` + +### 2) `ScanCoordinator` +Responsibilities: +- single-flight guard for compute jobs +- track scan status + +Methods: +- `run_if_needed(force: bool = false)` +- `get_status()` + +State fields: +- `state`: `idle|running|failed` +- `started_at`, `finished_at` +- `last_error` +- `run_count` + +### 3) `build_snapshot()` +Responsibilities: +- execute current compute logic once +- on first run, perform full scan and materialize per-session contributions +- on subsequent runs, process only changed/new sessions via checkpoint fingerprints +- produce shape consumed by `/achievements` + +Output: +- `achievements` +- count fields +- optional `scan_meta` + +--- + +## D) Endpoint Behavior Matrix (No `/overview`) + +| Endpoint | Cache fresh | Cache stale | No cache | Force rescan | +|---|---|---|---|---| +| `/achievements` | return cached | return stale + trigger bg refresh | blocking bootstrap scan | n/a | +| `/rescan` | trigger refresh | trigger refresh | trigger refresh | yes | +| `/scan-status` | status only | status only | status only | status only | + +Notes: +- At most one scan run active. +- Other callers either await same run or receive stale snapshot according to policy. 
+ +--- + +## E) Data Shape (Proposed) + +```json +{ + "generated_at": 0, + "is_stale": false, + "scan_meta": { + "duration_ms": 0, + "sessions_scanned": 0, + "messages_scanned": 0, + "mode": "full", + "error": null + }, + "achievements": [], + "unlocked_count": 0, + "discovered_count": 0, + "secret_count": 0, + "total_count": 0, + "error": null +} +``` + +Compatibility guidance: +- Keep existing `/achievements` keys. +- Add metadata keys without breaking old callers. + +Checkpoint file (new): +- `~/.hermes/plugins/hermes-achievements/scan_checkpoint.json` + +Suggested checkpoint shape: +```json +{ + "schema_version": 1, + "generated_at": 0, + "sessions": { + "": { + "fingerprint": { + "updated_at": 0, + "message_count": 0, + "hash": "optional" + }, + "contribution": { + "metrics": {} + } + } + } +} +``` + +Notes: +- fingerprint mismatch => recompute that session contribution only. +- unchanged fingerprint => reuse stored contribution. + +--- + +## F) Concurrency Contract + +- Any request path that needs fresh data must pass through single-flight coordinator. +- If a scan is running: + - do not start second scan + - either await in-flight run (bounded) or serve stale snapshot immediately +- lock scope must include scan start/finish state transitions. + +--- + +## G) Error Handling Contract + +- If refresh fails and prior snapshot exists: + - return prior snapshot with `is_stale=true` and error metadata +- If refresh fails and no prior snapshot: + - return explicit error response (current behavior equivalent) +- `scan-status` should always return last known state/error. + +--- + +## H) Frontend Integration Contract + +- Achievements page: + - one fetch on mount to `/achievements` + - optional background refresh indicator if stale +- no top-banner slot integration +- avoid duplicate in-flight calls during fast navigation by cancellation/debounce. 
+ +--- + +## I) Validation Checklist + +- [ ] `/overview` route removed +- [ ] manifest has no `sessions:top`/`analytics:top` slots +- [ ] frontend has no `api("/overview")` calls +- [ ] repeated Achievements navigation does not create multiple heavy scans +- [ ] average warm load times meet SLOs +- [ ] unlock totals match pre-refactor baseline for same history +- [ ] no schema regression in `/achievements` response + +--- + +## J) Suggested File Placement for Future Work + +- backend changes: `dashboard/plugin_api.py` +- optional extraction: + - `dashboard/perf_snapshot.py` + - `dashboard/perf_scan_coordinator.py` +- frontend request hygiene: `dashboard/dist/index.js` (or source if available) +- plugin metadata: `dashboard/manifest.json` +- persisted runtime files: + - `~/.hermes/plugins/hermes-achievements/state.json` (existing unlock state) + - `~/.hermes/plugins/hermes-achievements/scan_snapshot.json` (new) + - `~/.hermes/plugins/hermes-achievements/scan_checkpoint.json` (new) + +--- + +## K) Post-Implementation Reporting Template + +Record: +- dataset size (sessions/messages/tool calls) +- pre/post `/achievements` timings (cold/warm) +- whether single-flight dedupe triggered under repeated tab open +- any behavioral diffs in unlock counts diff --git a/plugins/hermes-achievements/docs/achievements-performance-spec.md b/plugins/hermes-achievements/docs/achievements-performance-spec.md new file mode 100644 index 0000000000..1355246948 --- /dev/null +++ b/plugins/hermes-achievements/docs/achievements-performance-spec.md @@ -0,0 +1,174 @@ +# Hermes Achievements Performance Spec (Post-Hackathon) + +Status: Draft (no code changes yet) +Owner: hermes-achievements plugin +Scope: `dashboard/plugin_api.py` + `dashboard/dist/index.js` request behavior +Decision: **Drop `/overview` and top-banner slots**; keep only Achievements tab data path. 
+ +--- + +## 1) Problem Statement + +Current plugin endpoints `/achievements` and `/overview` both execute a full history recomputation (`evaluate_all()`), which performs a full SessionDB scan each request. + +Observed on this machine/repo: +- ~83 sessions +- ~7,125 messages +- ~3,623 tool calls +- `evaluate_all()` ~13–16s per call +- `/achievements` ~13–15s per call +- `/overview` ~12–15s per call +- Overlap between endpoints increases perceived wait. + +Given current product direction, `/overview` and cross-tab top-banner slots are not needed. + +--- + +## 2) Goals + +- Keep achievement correctness unchanged. +- Keep all Achievements-tab UX/data (unlocked/discovered/secrets/highest/latest/cards). +- Remove unused summary path (`/overview`) and slot wiring. +- Make Achievements tab faster by avoiding duplicate endpoint pathways. +- Ensure at most one heavy scan can run at a time. + +Non-goals (phase 1): +- Rewriting achievement rules. +- Changing badge semantics/states. + +--- + +## 3) Endpoint Semantics (Target) + +### `GET /api/plugins/hermes-achievements/achievements` +Single source endpoint for Achievements UI. +Returns full payload used by the tab: +- `achievements` +- `unlocked_count` +- `discovered_count` +- `secret_count` +- `total_count` +- `error` + +### `POST /api/plugins/hermes-achievements/rescan` (optional) +Manual refresh trigger. +Prefer async trigger + immediate status response. + +### `GET /api/plugins/hermes-achievements/scan-status` (optional new) +Reports scan state for UX/ops. 
+ +### Removed +- `GET /api/plugins/hermes-achievements/overview` + +--- + +## 4) UI Scope (Target) + +Keep: +- Achievements page/tab (`/achievements` in plugin tab manifest) +- All existing Achievements tab stats/cards/filters + +Remove: +- Top-banner summary slot components using `sessions:top` and `analytics:top` +- Any frontend call path to `/overview` + +--- + +## 5) Runtime State Machine (for `/achievements`) + +- `FRESH`: cached snapshot age <= TTL +- `STALE`: snapshot exists but expired +- `SCANNING`: background recompute running +- `FAILED`: last recompute failed, last good snapshot still served + +Rules: +1. FRESH -> serve immediately. +2. STALE + not scanning -> serve stale snapshot immediately and launch background refresh. +3. SCANNING -> do not start another scan; join single-flight in-flight job. +4. No snapshot yet -> allow one blocking bootstrap scan. + +--- + +## 6) Caching & Invalidation + +### Phase 1 +- In-memory cache + persisted snapshot file. +- TTL: 60–180 seconds (configurable). +- Single-flight dedupe for scan requests. +- Persist plugin data under: + - `~/.hermes/plugins/hermes-achievements/scan_snapshot.json` + +### Phase 2 +- Incremental scan checkpoints with per-session fingerprints. +- Persist checkpoint data under: + - `~/.hermes/plugins/hermes-achievements/scan_checkpoint.json` +- Checkpoint stores, per session: + - `session_id` + - fingerprint (`updated_at`, message_count, or hash) + - cached per-session contribution used for aggregate recomposition +- Scan policy: + - First run: full scan and materialize snapshot + checkpoint. + - Next runs: process only new/changed sessions, reuse unchanged contributions. +- Full rebuild only on: + - schema/version change + - checkpoint corruption + - explicit full rescan + +--- + +## 7) Frontend Contract + +- Achievements tab requests `/achievements` once on mount. +- No slot-based summary fetches. +- If response says `is_stale=true`, UI may display “Updating in background”. 
+- Avoid duplicate mount-triggered calls and cancel stale requests on navigation. + +--- + +## 8) SLO Targets + +- `/achievements` p95 < 1s (cached) +- Max concurrent heavy scans: 1 +- Background refresh should not block UI + +--- + +## 9) Observability Requirements + +Track: +- scan count +- scan duration avg/p95 +- dedupe hit count (joined in-flight scans) +- stale-served count +- failures + last error + +Expose minimal diagnostics in `/scan-status`. + +--- + +## 10) Backward Compatibility + +- Keep `/achievements` response shape backward-compatible. +- Removing `/overview` is acceptable because slot UI is intentionally removed. +- If temporary compatibility is needed, `/overview` can return static deprecation response for one release. + +--- + +## 11) Risks + +- Stale data confusion -> mitigate with `generated_at` and explicit refresh status. +- Cache invalidation bugs -> start with conservative TTL + manual rescan. +- Concurrency bugs -> protect scan section with lock/single-flight guard. +- Session mutation edge cases -> use per-session fingerprint invalidation (not global timestamp only). 
+ +--- + +## 12) Persistence Files (Explicit) + +Plugin state directory: +- `~/.hermes/plugins/hermes-achievements/` + +Files: +- `state.json` (existing): unlock tracking +- `scan_snapshot.json` (new): latest materialized achievements payload +- `scan_checkpoint.json` (new): per-session fingerprints + contributions for incremental refresh diff --git a/plugins/hermes-achievements/docs/assets/achievements-dashboard-hd.png b/plugins/hermes-achievements/docs/assets/achievements-dashboard-hd.png new file mode 100644 index 0000000000..2342f548e3 Binary files /dev/null and b/plugins/hermes-achievements/docs/assets/achievements-dashboard-hd.png differ diff --git a/plugins/hermes-achievements/docs/assets/achievements-tier-showcase-hd.png b/plugins/hermes-achievements/docs/assets/achievements-tier-showcase-hd.png new file mode 100644 index 0000000000..64dfc85c60 Binary files /dev/null and b/plugins/hermes-achievements/docs/assets/achievements-tier-showcase-hd.png differ diff --git a/plugins/hermes-achievements/tests/test_achievement_engine.py b/plugins/hermes-achievements/tests/test_achievement_engine.py new file mode 100644 index 0000000000..a941c8fd14 --- /dev/null +++ b/plugins/hermes-achievements/tests/test_achievement_engine.py @@ -0,0 +1,156 @@ +import importlib.util +import unittest +from pathlib import Path + +MODULE_PATH = Path(__file__).resolve().parents[1] / "dashboard" / "plugin_api.py" +spec = importlib.util.spec_from_file_location("plugin_api", MODULE_PATH) +plugin_api = importlib.util.module_from_spec(spec) +spec.loader.exec_module(plugin_api) + + +class AchievementEngineTests(unittest.TestCase): + def test_tool_call_stats_detect_tool_names_and_errors(self): + messages = [ + {"role": "assistant", "tool_calls": [{"function": {"name": "terminal"}}]}, + {"role": "tool", "tool_name": "terminal", "content": "Error: port 3000 already in use"}, + {"role": "assistant", "tool_calls": [{"function": {"name": "web_search"}}]}, + ] + + stats = 
plugin_api.analyze_messages("s1", "Fix dev server", messages) + + self.assertEqual(stats["tool_call_count"], 2) + self.assertEqual(stats["tool_names"], {"terminal", "web_search"}) + self.assertEqual(stats["error_count"], 1) + self.assertIs(stats["port_conflict"], True) + + def test_tiered_achievement_reaches_highest_matching_tier(self): + definition = { + "id": "let_him_cook", + "threshold_metric": "max_tool_calls_in_session", + "tiers": [ + {"name": "Copper", "threshold": 10}, + {"name": "Silver", "threshold": 25}, + {"name": "Gold", "threshold": 50}, + ], + } + aggregate = {"max_tool_calls_in_session": 28} + + result = plugin_api.evaluate_tiered(definition, aggregate) + + self.assertIs(result["unlocked"], True) + self.assertEqual(result["tier"], "Silver") + self.assertEqual(result["progress"], 28) + self.assertEqual(result["next_tier"], "Gold") + + def test_tiered_achievement_can_be_discovered_without_unlocking(self): + definition = { + "id": "terminal_goblin", + "threshold_metric": "total_terminal_calls", + "tiers": [{"name": "Copper", "threshold": 50}], + } + aggregate = {"total_terminal_calls": 12} + + result = plugin_api.evaluate_tiered(definition, aggregate) + + self.assertIs(result["unlocked"], False) + self.assertIs(result["discovered"], True) + self.assertEqual(result["state"], "discovered") + self.assertEqual(result["progress"], 12) + self.assertEqual(result["next_threshold"], 50) + + def test_secret_achievement_stays_hidden_without_progress(self): + definition = { + "id": "permission_denied_any_percent", + "name": "Permission Denied Any%", + "secret": True, + "requirements": [{"metric": "permission_denied_events", "gte": 3}], + } + aggregate = {"permission_denied_events": 0} + + result = plugin_api.evaluate_requirements(definition, aggregate) + display = plugin_api.display_achievement({**definition, **result}) + + self.assertEqual(result["state"], "secret") + self.assertEqual(display["name"], "???") + self.assertNotIn("Permission", 
display["description"]) + + def test_multi_condition_unlock_requires_all_requirements(self): + definition = { + "id": "full_send", + "requirements": [ + {"metric": "max_terminal_calls_in_session", "gte": 10}, + {"metric": "max_file_tool_calls_in_session", "gte": 5}, + {"metric": "max_web_calls_in_session", "gte": 2}, + ], + } + + partial = plugin_api.evaluate_requirements(definition, { + "max_terminal_calls_in_session": 12, + "max_file_tool_calls_in_session": 2, + "max_web_calls_in_session": 0, + }) + complete = plugin_api.evaluate_requirements(definition, { + "max_terminal_calls_in_session": 12, + "max_file_tool_calls_in_session": 6, + "max_web_calls_in_session": 2, + }) + + self.assertEqual(partial["state"], "discovered") + self.assertIs(partial["unlocked"], False) + self.assertLess(partial["progress_pct"], 100) + self.assertEqual(complete["state"], "unlocked") + self.assertIs(complete["unlocked"], True) + + def test_catalog_has_60_plus_unique_achievements(self): + ids = [achievement["id"] for achievement in plugin_api.ACHIEVEMENTS] + self.assertGreaterEqual(len(ids), 60) + self.assertEqual(len(ids), len(set(ids))) + + def test_model_provider_metrics_are_aggregated(self): + sessions = [ + {"model_names": {"openai/gpt-5", "anthropic/claude-sonnet-4"}}, + {"model_names": {"google/gemini-pro", "mistral/large"}}, + {"model_names": {"qwen/qwen3"}}, + ] + + aggregate = plugin_api.aggregate_stats(sessions) + + self.assertEqual(aggregate["distinct_model_count"], 5) + self.assertEqual(aggregate["distinct_provider_count"], 5) + result = plugin_api.evaluate_definition( + next(a for a in plugin_api.ACHIEVEMENTS if a["id"] == "five_model_flight"), + aggregate, + ) + self.assertEqual(result["state"], "unlocked") + self.assertEqual(result["tier"], "Copper") + + def test_removed_noisy_achievements_are_not_in_catalog(self): + ids = {achievement["id"] for achievement in plugin_api.ACHIEVEMENTS} + self.assertNotIn("fallback_pilot", ids) + self.assertNotIn("browser_sleuth", ids) + 
self.assertNotIn("release_ritualist", ids) + + def test_open_weights_pilgrim_counts_only_local_model_metadata(self): + aggregate_mentions_only = plugin_api.aggregate_stats([ + {"model_names": {"openai/gpt-5"}, "local_model_events": 999}, + ]) + aggregate_local_chat = plugin_api.aggregate_stats([ + {"model_names": {"openai/gpt-5"}}, + {"model_names": {"ollama/llama3"}}, + ]) + definition = next(a for a in plugin_api.ACHIEVEMENTS if a["id"] == "open_weights_pilgrim") + + self.assertEqual(aggregate_mentions_only["local_model_chat_sessions"], 0) + self.assertEqual(plugin_api.evaluate_definition(definition, aggregate_mentions_only)["state"], "discovered") + self.assertEqual(aggregate_local_chat["local_model_chat_sessions"], 1) + self.assertEqual(plugin_api.evaluate_definition(definition, aggregate_local_chat)["state"], "unlocked") + + def test_config_surgeon_ignores_generic_config_mentions(self): + stats = plugin_api.analyze_messages("s1", "Config talk", [{"content": "config config configuration not configured"}]) + self.assertEqual(stats["config_events"], 0) + stats = plugin_api.analyze_messages("s2", "Real config", [{"content": "edited config.yaml, manifest.json, and .env.local"}]) + self.assertGreaterEqual(stats["config_events"], 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/plugins/memory/__init__.py b/plugins/memory/__init__.py index 0ae65a25d5..0d714f64dd 100644 --- a/plugins/memory/__init__.py +++ b/plugins/memory/__init__.py @@ -27,6 +27,7 @@ import logging import sys from pathlib import Path from typing import List, Optional, Tuple +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -314,7 +315,7 @@ def _get_active_memory_provider() -> Optional[str]: try: from hermes_cli.config import load_config config = load_config() - return config.get("memory", {}).get("provider") or None + return cfg_get(config, "memory", "provider") or None except Exception: return None diff --git a/plugins/memory/hindsight/__init__.py 
b/plugins/memory/hindsight/__init__.py index 8bd45d8b3e..a280cbafd4 100644 --- a/plugins/memory/hindsight/__init__.py +++ b/plugins/memory/hindsight/__init__.py @@ -29,10 +29,12 @@ Or via $HERMES_HOME/hindsight/config.json (profile-scoped), falling back to from __future__ import annotations import asyncio +import atexit import importlib import json import logging import os +import queue import threading from datetime import datetime, timezone @@ -41,6 +43,7 @@ from typing import Any, Dict, List from agent.memory_provider import MemoryProvider from hermes_constants import get_hermes_home from tools.registry import tool_error +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -99,6 +102,10 @@ _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None = None _loop_lock = threading.Lock() +# Sentinel pushed to the per-provider retain queue to wake the writer for a +# clean exit. A unique object so it can never collide with a real job. +_WRITER_SENTINEL = object() + def _get_loop() -> asyncio.AbstractEventLoop: """Return a long-lived event loop running on a background thread.""" @@ -443,6 +450,16 @@ class HindsightMemoryProvider(MemoryProvider): self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread = None + # Single-writer model for retain. sync_turn() enqueues; the writer + # thread drains sequentially. Avoids spawning ad-hoc threads that + # can race the interpreter shutdown and emit "cannot schedule new + # futures after interpreter shutdown" / "Unclosed client session". + self._retain_queue: queue.Queue = queue.Queue() + self._writer_thread: threading.Thread | None = None + self._shutting_down = threading.Event() + self._atexit_registered = False + # Legacy alias — older tests/callers reference _sync_thread directly. + # Points at _writer_thread once the writer is running. 
self._sync_thread = None self._session_id = "" self._parent_session_id = "" @@ -817,6 +834,73 @@ class HindsightMemoryProvider(MemoryProvider): ) ) + def _ensure_writer(self) -> None: + """Lazy-start the single retain-writer thread. + + We don't start the writer in initialize() so providers that never + retain (e.g. tools-only mode) don't pay for an idle thread. + """ + thread = self._writer_thread + if thread is not None and thread.is_alive(): + return + # If the previous writer exited (e.g. after a prior shutdown), reset + # the flag so this fresh writer is allowed to drain new jobs. + self._shutting_down.clear() + thread = threading.Thread( + target=self._writer_loop, + daemon=True, + name="hindsight-writer", + ) + self._writer_thread = thread + # Keep the legacy _sync_thread alias pointing at the writer so any + # external code that joins _sync_thread keeps working. + self._sync_thread = thread + thread.start() + + def _writer_loop(self) -> None: + """Drain the retain queue serially. Exits on sentinel. + + Each job() is wrapped so a single failure can't kill the writer. + task_done() always fires so queue.join() works in tests. + """ + while True: + try: + job = self._retain_queue.get(timeout=1.0) + except queue.Empty: + if self._shutting_down.is_set(): + return + continue + try: + if job is _WRITER_SENTINEL: + return + try: + job() + except Exception as exc: + logger.warning("Hindsight retain failed: %s", exc, exc_info=True) + finally: + self._retain_queue.task_done() + + def _register_atexit(self) -> None: + """Register an idempotent atexit hook to drain the writer. + + Without this, a CLI exit that doesn't go through MemoryManager. + shutdown_all() would leave in-flight retain jobs racing interpreter + teardown, producing "cannot schedule new futures" warnings and + unclosed aiohttp sessions. 
+ """ + if self._atexit_registered: + return + self._atexit_registered = True + atexit.register(self._atexit_shutdown) + + def _atexit_shutdown(self) -> None: + if self._shutting_down.is_set(): + return + try: + self.shutdown() + except Exception as exc: + logger.debug("Hindsight atexit shutdown failed: %s", exc) + def _run_hindsight_operation(self, operation): """Run an async Hindsight client operation, retrying once after idle shutdown.""" client = self._get_client() @@ -913,7 +997,7 @@ class HindsightMemoryProvider(MemoryProvider): self._api_url = self._config.get("api_url") or os.environ.get("HINDSIGHT_API_URL", default_url) self._llm_base_url = self._config.get("llm_base_url", "") - banks = self._config.get("banks", {}).get("hermes", {}) + banks = cfg_get(self._config, "banks", "hermes", default={}) static_bank_id = self._config.get("bank_id") or banks.get("bankId", "hermes") self._bank_id_template = self._config.get("bank_id_template", "") or "" self._bank_id = _resolve_bank_id_template( @@ -1080,6 +1164,9 @@ class HindsightMemoryProvider(MemoryProvider): if not self._auto_recall: logger.debug("Prefetch: skipped (auto_recall disabled)") return + if self._shutting_down.is_set(): + logger.debug("Prefetch: skipped (shutting down)") + return # Truncate query to max chars if self._recall_max_input_chars and len(query) > self._recall_max_input_chars: query = query[:self._recall_max_input_chars] @@ -1188,13 +1275,19 @@ class HindsightMemoryProvider(MemoryProvider): return kwargs def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: - """Retain conversation turn in background (non-blocking). + """Enqueue a retain for the current turn. Non-blocking. - Respects retain_every_n_turns for batching. + The actual aretain_batch runs on a single long-lived writer thread + that drains an in-memory queue. 
Once shutdown() has been called, + further sync_turn() calls are dropped — this prevents post-exit + retains from reaching aiohttp after interpreter shutdown begins. """ if not self._auto_retain: logger.debug("sync_turn: skipped (auto_retain disabled)") return + if self._shutting_down.is_set(): + logger.debug("sync_turn: skipped (shutting down)") + return if session_id: self._session_id = str(session_id).strip() @@ -1219,37 +1312,42 @@ class HindsightMemoryProvider(MemoryProvider): if self._parent_session_id: lineage_tags.append(f"parent:{self._parent_session_id}") - def _sync(): - try: - item = self._build_retain_kwargs( - content, - context=self._retain_context, - metadata=self._build_metadata( - message_count=len(self._session_turns) * 2, - turn_index=self._turn_index, - ), - tags=lineage_tags or None, - ) - item.pop("bank_id", None) - item.pop("retain_async", None) - logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d", - self._bank_id, self._document_id, self._retain_async, len(content), len(self._session_turns)) - self._run_hindsight_operation( - lambda client: client.aretain_batch( - bank_id=self._bank_id, - items=[item], - document_id=self._document_id, - retain_async=self._retain_async, - ) - ) - logger.debug("Hindsight retain succeeded") - except Exception as e: - logger.warning("Hindsight sync failed: %s", e, exc_info=True) + # Snapshot the state needed for the retain. The writer may run after + # _session_turns / _turn_index are mutated by a later sync_turn(). 
+ metadata_snapshot = self._build_metadata( + message_count=len(self._session_turns) * 2, + turn_index=self._turn_index, + ) + num_turns = len(self._session_turns) + document_id = self._document_id + bank_id = self._bank_id + retain_async_flag = self._retain_async + retain_context = self._retain_context - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=5.0) - self._sync_thread = threading.Thread(target=_sync, daemon=True, name="hindsight-sync") - self._sync_thread.start() + def _do_retain() -> None: + item = self._build_retain_kwargs( + content, + context=retain_context, + metadata=metadata_snapshot, + tags=lineage_tags or None, + ) + item.pop("bank_id", None) + item.pop("retain_async", None) + logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d", + bank_id, document_id, retain_async_flag, len(content), num_turns) + self._run_hindsight_operation( + lambda client: client.aretain_batch( + bank_id=bank_id, + items=[item], + document_id=document_id, + retain_async=retain_async_flag, + ) + ) + logger.debug("Hindsight retain succeeded") + + self._ensure_writer() + self._register_atexit() + self._retain_queue.put(_do_retain) def get_tool_schemas(self) -> List[Dict[str, Any]]: if self._memory_mode == "context": @@ -1324,11 +1422,149 @@ class HindsightMemoryProvider(MemoryProvider): return tool_error(f"Unknown tool: {tool_name}") + def on_session_switch( + self, + new_session_id: str, + *, + parent_session_id: str = "", + reset: bool = False, + **kwargs, + ) -> None: + """Refresh cached per-session state when the agent rotates session_id. + + Fires on /resume, /branch, /reset, /new, and context compression. + Without this hook, initialize()-cached state (``_session_id``, + ``_document_id``, ``_session_turns``, ``_turn_counter``) would keep + pointing at the previous session and writes would land in the wrong + document. See hermes-agent#6672. 
+ + Always update ``_session_id`` so metadata and tags on subsequent + retains reflect the active session. Always mint a fresh + ``_document_id`` so the new session's retain doesn't overwrite the + old session's document on vectorize-io/hindsight#1303. Always clear + the accumulated batch buffers (``_session_turns``, ``_turn_counter``, + ``_turn_index``) — even for /resume and /branch, the new session's + batching must start from zero so an in-flight retain doesn't flush + under the wrong ``_document_id``. + + Before clearing, flush any buffered turns under the *old* + ``_document_id``. Users who set ``retain_every_n_turns > 1`` would + otherwise silently lose whatever's in ``_session_turns`` at the + moment of switch — the same data-loss class as the shutdown race, + just at a different lifecycle event. + + Also wait for any in-flight prefetch from the old session and drop + its cached result; otherwise the new session's first ``prefetch()`` + could read stale recall text from before the switch. + + ``parent_session_id`` is recorded for lineage tags on future retains. + ``reset`` is accepted but not needed for Hindsight's state model — + buffer clearing is correct for every session switch, not only /reset. + """ + new_id = str(new_session_id or "").strip() + if not new_id: + return + + # 1. Flush any buffered turns under the OLD identifiers. Snapshot + # everything before mutating self._* so metadata + tags + doc_id + # all reference the old session consistently. 
+ if self._session_turns: + old_turns = list(self._session_turns) + old_session_id = self._session_id + old_document_id = self._document_id + old_parent_session_id = self._parent_session_id + old_turn_index = self._turn_index + old_metadata = self._build_metadata( + message_count=len(old_turns) * 2, + turn_index=old_turn_index, + ) + old_lineage_tags: list[str] = [] + if old_session_id: + old_lineage_tags.append(f"session:{old_session_id}") + if old_parent_session_id: + old_lineage_tags.append(f"parent:{old_parent_session_id}") + old_content = "[" + ",".join(old_turns) + "]" + + def _flush(): + try: + item = self._build_retain_kwargs( + old_content, + context=self._retain_context, + metadata=old_metadata, + tags=old_lineage_tags or None, + ) + item.pop("bank_id", None) + item.pop("retain_async", None) + logger.debug( + "Hindsight flush-on-switch: bank=%s, doc=%s, num_turns=%d", + self._bank_id, old_document_id, len(old_turns), + ) + self._run_hindsight_operation( + lambda client: client.aretain_batch( + bank_id=self._bank_id, + items=[item], + document_id=old_document_id, + retain_async=self._retain_async, + ) + ) + except Exception as e: + logger.warning("Hindsight flush-on-switch failed: %s", e, exc_info=True) + + # Route the flush through the same writer queue sync_turn + # uses. That serializes it behind any still-queued retains + # from the old session (FIFO by document_id), avoids racing + # two threads on aretain_batch against the same document, and + # keeps shutdown's drain semantics intact. Skip enqueue if + # shutdown has already fired — the writer is draining/gone. + if not self._shutting_down.is_set(): + self._ensure_writer() + self._register_atexit() + self._retain_queue.put(_flush) + + # 2. Drain any in-flight prefetch from the old session and drop + # its cached result so the new session doesn't see stale recall. 
+ if self._prefetch_thread and self._prefetch_thread.is_alive(): + self._prefetch_thread.join(timeout=3.0) + with self._prefetch_lock: + self._prefetch_result = "" + + # 3. Now rotate to the new session. + if parent_session_id: + self._parent_session_id = str(parent_session_id).strip() + self._session_id = new_id + start_ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + self._document_id = f"{self._session_id}-{start_ts}" + self._session_turns = [] + self._turn_counter = 0 + self._turn_index = 0 + logger.debug( + "Hindsight on_session_switch: new_session=%s parent=%s reset=%s doc=%s", + self._session_id, self._parent_session_id, reset, self._document_id, + ) + def shutdown(self) -> None: - logger.debug("Hindsight shutdown: waiting for background threads") - for t in (self._prefetch_thread, self._sync_thread): - if t and t.is_alive(): - t.join(timeout=5.0) + logger.debug("Hindsight shutdown: stopping writer + waiting for background threads") + # Stop accepting new retain jobs first so anyone still calling + # sync_turn() during teardown is dropped, not enqueued. + self._shutting_down.set() + # Drain the writer: it will finish in-flight work, then exit on + # the sentinel. Bounded join keeps shutdown predictable even if + # the daemon is wedged. 
+ writer = self._writer_thread + if writer is not None and writer.is_alive(): + try: + self._retain_queue.put(_WRITER_SENTINEL) + except Exception: + pass + writer.join(timeout=10.0) + if writer.is_alive(): + logger.warning( + "Hindsight writer did not stop within 10s; " + "abandoning %d pending retain(s)", + self._retain_queue.qsize(), + ) + if self._prefetch_thread and self._prefetch_thread.is_alive(): + self._prefetch_thread.join(timeout=5.0) if self._client is not None: try: if self._mode == "local_embedded": diff --git a/plugins/memory/holographic/__init__.py b/plugins/memory/holographic/__init__.py index cd4ef07b44..dc9ee530c5 100644 --- a/plugins/memory/holographic/__init__.py +++ b/plugins/memory/holographic/__init__.py @@ -26,6 +26,7 @@ from agent.memory_provider import MemoryProvider from tools.registry import tool_error from .store import MemoryStore from .retrieval import FactRetriever +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -102,7 +103,7 @@ def _load_plugin_config() -> dict: import yaml with open(config_path) as f: all_config = yaml.safe_load(f) or {} - return all_config.get("plugins", {}).get("hermes-memory-store", {}) or {} + return cfg_get(all_config, "plugins", "hermes-memory-store", default={}) or {} except Exception: return {} diff --git a/plugins/memory/honcho/cli.py b/plugins/memory/honcho/cli.py index 8f354d2cdb..402389ab96 100644 --- a/plugins/memory/honcho/cli.py +++ b/plugins/memory/honcho/cli.py @@ -12,6 +12,7 @@ from pathlib import Path from hermes_constants import get_hermes_home from plugins.memory.honcho.client import resolve_active_host, resolve_config_path, HOST +from hermes_cli.config import cfg_get def clone_honcho_for_profile(profile_name: str) -> bool: @@ -106,7 +107,7 @@ def cmd_enable(args) -> None: # If this is a new profile host block with no settings, clone from default if not block.get("aiPeer"): - default_block = cfg.get("hosts", {}).get(HOST, {}) + default_block = cfg_get(cfg, 
"hosts", HOST, default={}) for key in ("recallMode", "writeFrequency", "sessionStrategy", "contextTokens", "dialecticReasoningLevel", "dialecticDynamic", "dialecticMaxChars", "messageMaxChars", "dialecticMaxInputChars", @@ -139,7 +140,7 @@ def cmd_disable(args) -> None: cfg = _read_config() host = _host_key() label = f"[{host}] " if host != "hermes" else "" - block = cfg.get("hosts", {}).get(host, {}) + block = cfg_get(cfg, "hosts", host, default={}) if not block or block.get("enabled") is False: print(f" {label}Honcho is already disabled.\n") @@ -212,7 +213,7 @@ def sync_honcho_profiles_quiet() -> int: if not cfg: return 0 - default_block = cfg.get("hosts", {}).get(HOST, {}) + default_block = cfg_get(cfg, "hosts", HOST, default={}) has_key = bool(cfg.get("apiKey") or os.environ.get("HONCHO_API_KEY")) if not default_block and not has_key: return 0 diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index f8687eb2bd..8ea4a4bedc 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -528,6 +528,46 @@ class OpenVikingMemoryProvider(MemoryProvider): # -- Tool implementations ------------------------------------------------ + @staticmethod + def _unwrap_result(resp: Any) -> Any: + """Return OpenViking payload body regardless of wrapped/unwrapped shape.""" + if isinstance(resp, dict) and "result" in resp: + return resp.get("result") + return resp + + @staticmethod + def _normalize_summary_uri(uri: str) -> str: + """Map pseudo summary files to their parent directory URI for L0/L1 reads.""" + if not uri: + return uri + for suffix in ("/.abstract.md", "/.overview.md", "/.read.md", "/.full.md"): + if uri.endswith(suffix): + return uri[: -len(suffix)] or "viking://" + return uri + + def _is_directory_uri(self, uri: str) -> bool | None: + """Probe fs/stat to decide if a URI is a directory. 
+ + Returns True/False when the server answers cleanly, and None when the + probe itself fails (network error, unexpected shape). Callers should + treat None as "unknown" and fall back to the exception-based path. + """ + try: + resp = self._client.get("/api/v1/fs/stat", params={"uri": uri}) + except Exception: + return None + result = self._unwrap_result(resp) + if isinstance(result, dict): + if "isDir" in result: + return bool(result.get("isDir")) + if "is_dir" in result: + return bool(result.get("is_dir")) + if result.get("type") == "dir": + return True + if result.get("type") == "file": + return False + return None + def _tool_search(self, args: dict) -> str: query = args.get("query", "") if not query: @@ -576,27 +616,72 @@ class OpenVikingMemoryProvider(MemoryProvider): return tool_error("uri is required") level = args.get("level", "overview") - # Map our level names to OpenViking GET endpoints - if level == "abstract": - resp = self._client.get("/api/v1/content/abstract", params={"uri": uri}) - elif level == "full": + + summary_level = level in ("abstract", "overview") + # OpenViking expects directory URIs for pseudo summary files + # (e.g. viking://user/hermes/.overview.md). + resolved_uri = self._normalize_summary_uri(uri) if summary_level else uri + used_fallback = False + + # abstract/overview endpoints are directory-only on OpenViking + # (v0.3.x returns 500/412 for file URIs). When the caller asks for a + # summary level on a non-pseudo URI, probe fs/stat first and route + # file URIs straight to /content/read instead of eating a failing + # round-trip. The pseudo-URI path already points at a directory, so + # skip the probe there. + if summary_level and resolved_uri == uri: + is_dir = self._is_directory_uri(uri) + if is_dir is False: + resolved_uri = uri + used_fallback = True + + # Map our level names to OpenViking GET endpoints. 
+ endpoint = "/api/v1/content/read" + if not used_fallback: + if level == "abstract": + endpoint = "/api/v1/content/abstract" + elif level == "overview": + endpoint = "/api/v1/content/overview" + + try: + resp = self._client.get(endpoint, params={"uri": resolved_uri}) + except Exception: + # OpenViking may return HTTP 500 for abstract/overview reads on normal + # file URIs (mem_*.md). For those, gracefully fallback to full read. + if not summary_level or resolved_uri != uri or used_fallback: + raise resp = self._client.get("/api/v1/content/read", params={"uri": uri}) - else: # overview - resp = self._client.get("/api/v1/content/overview", params={"uri": uri}) + used_fallback = True - result = resp.get("result", "") - # result is a plain string from the content endpoints - content = result if isinstance(result, str) else result.get("content", "") + result = self._unwrap_result(resp) + # Content endpoints may return either plain strings or objects. + if isinstance(result, str): + content = result + elif isinstance(result, dict): + content = result.get("content", "") or result.get("text", "") + else: + content = "" - # Truncate very long content to avoid flooding the context - if len(content) > 8000: - content = content[:8000] + "\n\n[... truncated, use a more specific URI or abstract level]" + # Truncate long content to avoid flooding context. + max_len = 8000 + if level == "overview": + max_len = 4000 + elif level == "abstract": + max_len = 1200 - return json.dumps({ + if len(content) > max_len: + content = content[:max_len] + "\n\n[... 
truncated, use a more specific URI or full level]" + + payload = { "uri": uri, + "resolved_uri": resolved_uri, "level": level, "content": content, - }, ensure_ascii=False) + } + if used_fallback: + payload["fallback"] = "content/read" + + return json.dumps(payload, ensure_ascii=False) def _tool_browse(self, args: dict) -> str: action = args.get("action", "list") @@ -606,19 +691,27 @@ class OpenVikingMemoryProvider(MemoryProvider): endpoint_map = {"tree": "/api/v1/fs/tree", "list": "/api/v1/fs/ls", "stat": "/api/v1/fs/stat"} endpoint = endpoint_map.get(action, "/api/v1/fs/ls") resp = self._client.get(endpoint, params={"uri": path}) - result = resp.get("result", {}) + result = self._unwrap_result(resp) # Format list/tree results for readability - if action in ("list", "tree") and isinstance(result, list): - entries = [] - for e in result[:50]: # cap at 50 entries - entries.append({ - "name": e.get("rel_path", e.get("name", "")), - "uri": e.get("uri", ""), - "type": "dir" if e.get("isDir") else "file", - "abstract": e.get("abstract", ""), - }) - return json.dumps({"path": path, "entries": entries}, ensure_ascii=False) + if action in ("list", "tree"): + raw_entries = result + if isinstance(result, dict): + raw_entries = result.get("entries") or result.get("items") or result.get("children") or [] + + if isinstance(raw_entries, list): + entries = [] + for e in raw_entries[:50]: # cap at 50 entries + uri = e.get("uri", "") + name = e.get("rel_path") or e.get("name") or (uri.rsplit("/", 1)[-1] if uri else "") + is_dir = bool(e.get("isDir") or e.get("is_dir") or e.get("type") == "dir") + entries.append({ + "name": name, + "uri": uri, + "type": "dir" if is_dir else "file", + "abstract": e.get("abstract", ""), + }) + return json.dumps({"path": path, "entries": entries}, ensure_ascii=False) return json.dumps(result, ensure_ascii=False) diff --git a/plugins/platforms/irc/__init__.py b/plugins/platforms/irc/__init__.py new file mode 100644 index 0000000000..d4f1d7bf0e --- 
/dev/null +++ b/plugins/platforms/irc/__init__.py @@ -0,0 +1,3 @@ +from .adapter import register + +__all__ = ["register"] diff --git a/plugins/platforms/irc/adapter.py b/plugins/platforms/irc/adapter.py new file mode 100644 index 0000000000..a9eea62ba2 --- /dev/null +++ b/plugins/platforms/irc/adapter.py @@ -0,0 +1,686 @@ +""" +IRC Platform Adapter for Hermes Agent. + +A plugin-based gateway adapter that connects to an IRC server and relays +messages to/from the Hermes agent. Zero external dependencies — uses +Python's stdlib asyncio for the IRC protocol. + +Configuration in config.yaml:: + + gateway: + platforms: + irc: + enabled: true + extra: + server: irc.libera.chat + port: 6697 + nickname: hermes-bot + channel: "#hermes" + use_tls: true + server_password: "" # optional server password + nickserv_password: "" # optional NickServ identification + allowed_users: [] # empty = allow all, or list of nicks + max_message_length: 450 # IRC line limit (safe default) + +Or via environment variables (overrides config.yaml): + IRC_SERVER, IRC_PORT, IRC_NICKNAME, IRC_CHANNEL, IRC_USE_TLS, + IRC_SERVER_PASSWORD, IRC_NICKSERV_PASSWORD +""" + +import asyncio +import logging +import os +import re +import ssl +import time +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Lazy import: BasePlatformAdapter and friends live in the main repo. +# We import at function/class level to avoid import errors when the plugin +# is discovered but the gateway hasn't been fully initialised yet. 
+# --------------------------------------------------------------------------- + +from gateway.platforms.base import ( + BasePlatformAdapter, + SendResult, + MessageEvent, + MessageType, +) +from gateway.session import SessionSource +from gateway.config import PlatformConfig, Platform + + +def _ensure_imports(): + """No-op — kept for backward compatibility with any call sites.""" + pass + + +# --------------------------------------------------------------------------- +# IRC protocol helpers +# --------------------------------------------------------------------------- + +def _parse_irc_message(raw: str) -> dict: + """Parse a raw IRC protocol line into components. + + Returns dict with keys: prefix, command, params. + """ + prefix = "" + trailing = "" + + if raw.startswith(":"): + try: + prefix, raw = raw[1:].split(" ", 1) + except ValueError: + prefix = raw[1:] + raw = "" + + if " :" in raw: + raw, trailing = raw.split(" :", 1) + + parts = raw.split() + command = parts[0] if parts else "" + params = parts[1:] if len(parts) > 1 else [] + if trailing: + params.append(trailing) + + return {"prefix": prefix, "command": command, "params": params} + + +def _extract_nick(prefix: str) -> str: + """Extract nickname from IRC prefix (nick!user@host).""" + return prefix.split("!")[0] if "!" in prefix else prefix + + +# --------------------------------------------------------------------------- +# IRC Adapter +# --------------------------------------------------------------------------- + +class IRCAdapter(BasePlatformAdapter): + """Async IRC adapter implementing the BasePlatformAdapter interface. + + This class is instantiated by the adapter_factory passed to + register_platform(). 
+ """ + + def __init__(self, config, **kwargs): + platform = Platform("irc") + super().__init__(config=config, platform=platform) + + extra = getattr(config, "extra", {}) or {} + + # Connection settings (env vars override config.yaml) + self.server = os.getenv("IRC_SERVER") or extra.get("server", "") + self.port = int(os.getenv("IRC_PORT") or extra.get("port", 6697)) + self.nickname = os.getenv("IRC_NICKNAME") or extra.get("nickname", "hermes-bot") + self.channel = os.getenv("IRC_CHANNEL") or extra.get("channel", "") + self.use_tls = ( + os.getenv("IRC_USE_TLS", "").lower() in ("1", "true", "yes") + if os.getenv("IRC_USE_TLS") + else extra.get("use_tls", True) + ) + self.server_password = os.getenv("IRC_SERVER_PASSWORD") or extra.get("server_password", "") + self.nickserv_password = os.getenv("IRC_NICKSERV_PASSWORD") or extra.get("nickserv_password", "") + + # Auth + self.allowed_users: list = extra.get("allowed_users", []) + # IRC nicks are case-insensitive — normalise for lookups + self._allowed_users_lower: set = {u.lower() for u in self.allowed_users if isinstance(u, str)} + + # IRC limits + max_msg = extra.get("max_message_length") + if max_msg is None: + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get("irc") + if entry and entry.max_message_length: + max_msg = entry.max_message_length + except Exception: + pass + self.max_message_length = int(max_msg or 450) + + # Runtime state + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._recv_task: Optional[asyncio.Task] = None + self._current_nick = self.nickname + self._registered = False # IRC registration complete + self._registration_event = asyncio.Event() + + @property + def name(self) -> str: + return "IRC" + + # ── Connection lifecycle ────────────────────────────────────────────── + + async def connect(self) -> bool: + """Connect to the IRC server, register, and join the channel.""" + if not 
self.server or not self.channel: + logger.error("IRC: server and channel must be configured") + self._set_fatal_error( + "config_missing", + "IRC_SERVER and IRC_CHANNEL must be set", + retryable=False, + ) + return False + + # Prevent two profiles from using the same IRC identity + try: + from gateway.status import acquire_scoped_lock, release_scoped_lock + lock_key = f"{self.server}:{self.nickname}" + if not acquire_scoped_lock("irc", lock_key): + logger.error("IRC: %s@%s already in use by another profile", self.nickname, self.server) + self._set_fatal_error("lock_conflict", "IRC identity in use by another profile", retryable=False) + return False + self._lock_key = lock_key + except ImportError: + self._lock_key = None # status module not available (e.g. tests) + + try: + ssl_ctx = None + if self.use_tls: + ssl_ctx = ssl.create_default_context() + + self._reader, self._writer = await asyncio.wait_for( + asyncio.open_connection(self.server, self.port, ssl=ssl_ctx), + timeout=30.0, + ) + except Exception as e: + logger.error("IRC: failed to connect to %s:%s — %s", self.server, self.port, e) + self._set_fatal_error("connect_failed", str(e), retryable=True) + return False + + # IRC registration sequence + if self.server_password: + await self._send_raw(f"PASS {self.server_password}") + await self._send_raw(f"NICK {self.nickname}") + await self._send_raw(f"USER {self.nickname} 0 * :Hermes Agent") + + # Start receive loop + self._recv_task = asyncio.create_task(self._receive_loop()) + + # Wait for registration (001 RPL_WELCOME) with timeout + try: + await asyncio.wait_for(self._registration_event.wait(), timeout=30.0) + except asyncio.TimeoutError: + logger.error("IRC: registration timed out") + await self.disconnect() + self._set_fatal_error("registration_timeout", "IRC server did not send RPL_WELCOME", retryable=True) + return False + + # NickServ identification + if self.nickserv_password: + await self._send_raw(f"PRIVMSG NickServ :IDENTIFY 
{self.nickserv_password}") + await asyncio.sleep(2) # Give NickServ time to process + + # Join channel + await self._send_raw(f"JOIN {self.channel}") + + self._mark_connected() + logger.info("IRC: connected to %s:%s as %s, joined %s", self.server, self.port, self._current_nick, self.channel) + return True + + async def disconnect(self) -> None: + """Quit and close the connection.""" + # Release the scoped lock so another profile can use this identity + if getattr(self, "_lock_key", None): + try: + from gateway.status import release_scoped_lock + release_scoped_lock("irc", self._lock_key) + except Exception: + pass + self._mark_disconnected() + if self._writer and not self._writer.is_closing(): + try: + await self._send_raw("QUIT :Hermes Agent shutting down") + await asyncio.sleep(0.5) + except Exception: + pass + try: + self._writer.close() + await self._writer.wait_closed() + except Exception: + pass + + if self._recv_task and not self._recv_task.done(): + self._recv_task.cancel() + try: + await self._recv_task + except asyncio.CancelledError: + pass + + self._reader = None + self._writer = None + self._registered = False + self._registration_event.clear() + + # ── Sending ─────────────────────────────────────────────────────────── + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + if not self._writer or self._writer.is_closing(): + return SendResult(success=False, error="Not connected") + + target = chat_id # channel name or nick for DMs + lines = self._split_message(content, target) + + for line in lines: + try: + await self._send_raw(f"PRIVMSG {target} :{line}") + # Basic rate limiting to avoid excess flood + await asyncio.sleep(0.3) + except Exception as e: + return SendResult(success=False, error=str(e)) + + return SendResult(success=True, message_id=str(int(time.time() * 1000))) + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """IRC has no 
typing indicator — no-op.""" + pass + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + is_channel = chat_id.startswith("#") or chat_id.startswith("&") + return { + "name": chat_id, + "type": "group" if is_channel else "dm", + } + + # ── Message splitting ───────────────────────────────────────────────── + + def _split_message(self, content: str, target: str) -> List[str]: + """Split a long message into IRC-safe chunks. + + IRC has a ~512 byte line limit. After accounting for protocol + overhead (``PRIVMSG :``), we split content into chunks. + """ + # Strip markdown formatting that doesn't render in IRC + content = self._strip_markdown(content) + + overhead = len(f"PRIVMSG {target} :".encode("utf-8")) + 2 # +2 for \r\n + max_bytes = 510 - overhead + user_limit = self.max_message_length + + lines: List[str] = [] + for paragraph in content.split("\n"): + if not paragraph.strip(): + continue + while True: + para_bytes = paragraph.encode("utf-8") + limit = min(user_limit, max_bytes) + if len(para_bytes) <= limit: + if paragraph.strip(): + lines.append(paragraph) + break + # Binary search for a safe character boundary <= limit + low, high = 1, len(paragraph) + best = 0 + while low <= high: + mid = (low + high) // 2 + if len(paragraph[:mid].encode("utf-8")) <= limit: + best = mid + low = mid + 1 + else: + high = mid - 1 + split_at = best + # Prefer a space boundary + space = paragraph.rfind(" ", 0, split_at) + if space > split_at // 3: + split_at = space + lines.append(paragraph[:split_at].rstrip()) + paragraph = paragraph[split_at:].lstrip() + + return lines if lines else [""] + + @staticmethod + def _strip_markdown(text: str) -> str: + """Convert basic markdown to plain text for IRC.""" + # Bold: **text** or __text__ → text + text = re.sub(r"\*\*(.+?)\*\*", r"\1", text) + text = re.sub(r"__(.+?)__", r"\1", text) + # Italic: *text* or _text_ → text + text = re.sub(r"\*(.+?)\*", r"\1", text) + text = re.sub(r"(? 
None: + """Send a raw IRC protocol line.""" + if not self._writer or self._writer.is_closing(): + return + encoded = (line + "\r\n").encode("utf-8") + self._writer.write(encoded) + await self._writer.drain() + + async def _receive_loop(self) -> None: + """Main receive loop — reads lines and dispatches them.""" + buffer = b"" + try: + while self._reader and not self._reader.at_eof(): + data = await self._reader.read(4096) + if not data: + break + buffer += data + while b"\r\n" in buffer: + line, buffer = buffer.split(b"\r\n", 1) + try: + decoded = line.decode("utf-8", errors="replace") + await self._handle_line(decoded) + except Exception as e: + logger.warning("IRC: error handling line: %s", e) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error("IRC: receive loop error: %s", e) + finally: + if self.is_connected: + logger.warning("IRC: connection lost, marking disconnected") + self._set_fatal_error("connection_lost", "IRC connection closed unexpectedly", retryable=True) + await self._notify_fatal_error() + + async def _handle_line(self, raw: str) -> None: + """Dispatch a single IRC protocol line.""" + msg = _parse_irc_message(raw) + command = msg["command"] + params = msg["params"] + + # PING/PONG keepalive + if command == "PING": + payload = params[0] if params else "" + await self._send_raw(f"PONG :{payload}") + return + + # RPL_WELCOME (001) — registration complete + if command == "001": + self._registered = True + self._registration_event.set() + if params: + # Server may confirm our nick in the first param + self._current_nick = params[0] + return + + # ERR_NICKNAMEINUSE (433) — nick collision during registration + if command == "433": + # Retry with incrementing suffix: hermes_, hermes_1, hermes_2... 
+ base = self.nickname.rstrip("_0123456789") + suffix_match = re.search(r"_(\d+)$", self._current_nick) + if suffix_match: + next_num = int(suffix_match.group(1)) + 1 + self._current_nick = f"{base}_{next_num}" + elif self._current_nick == self.nickname: + self._current_nick = self.nickname + "_" + else: + self._current_nick = self.nickname + "_1" + await self._send_raw(f"NICK {self._current_nick}") + return + + # PRIVMSG — incoming message (channel or DM) + if command == "PRIVMSG" and len(params) >= 2: + sender_nick = _extract_nick(msg["prefix"]) + target = params[0] + text = params[1] + + # Ignore our own messages + if sender_nick.lower() == self._current_nick.lower(): + return + + # CTCP ACTION (/me) — convert to text + if text.startswith("\x01ACTION ") and text.endswith("\x01"): + text = f"* {sender_nick} {text[8:-1]}" + + # Ignore other CTCP + if text.startswith("\x01"): + return + + # Determine if this is a channel message or DM + is_channel = target.startswith("#") or target.startswith("&") + chat_id = target if is_channel else sender_nick + chat_type = "group" if is_channel else "dm" + + # In channels, only respond if addressed (nick: or nick,) + if is_channel: + addressed = False + for prefix in (f"{self._current_nick}:", f"{self._current_nick},", + f"{self._current_nick} "): + if text.lower().startswith(prefix.lower()): + text = text[len(prefix):].strip() + addressed = True + break + if not addressed: + return # Ignore unaddressed channel messages + + # Auth check (case-insensitive) + if self._allowed_users_lower and sender_nick.lower() not in self._allowed_users_lower: + logger.debug("IRC: ignoring message from unauthorized user %s", sender_nick) + return + + await self._dispatch_message( + text=text, + chat_id=chat_id, + chat_type=chat_type, + user_id=sender_nick, + user_name=sender_nick, + ) + + # NICK — track our own nick changes + if command == "NICK" and _extract_nick(msg["prefix"]).lower() == self._current_nick.lower(): + if params: + 
self._current_nick = params[0] + + async def _dispatch_message( + self, + text: str, + chat_id: str, + chat_type: str, + user_id: str, + user_name: str, + ) -> None: + """Build a MessageEvent and hand it to the base class handler.""" + if not self._message_handler: + return + + source = self.build_source( + chat_id=chat_id, + chat_name=chat_id, + chat_type=chat_type, + user_id=user_id, + user_name=user_name, + ) + + event = MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=source, + message_id=str(int(time.time() * 1000)), + timestamp=__import__("datetime").datetime.now(), + ) + + await self.handle_message(event) + + +# --------------------------------------------------------------------------- +# Plugin registration +# --------------------------------------------------------------------------- + +def check_requirements() -> bool: + """Check if IRC is configured. + + Only requires the server and channel — no external pip packages needed. + """ + server = os.getenv("IRC_SERVER", "") + channel = os.getenv("IRC_CHANNEL", "") + # Also accept config.yaml-only configuration (no env vars). + # The gateway passes PlatformConfig; we just check env for the + # hermes setup / requirements check path. + return bool(server and channel) + + +def validate_config(config) -> bool: + """Validate that the platform config has enough info to connect.""" + extra = getattr(config, "extra", {}) or {} + server = os.getenv("IRC_SERVER") or extra.get("server", "") + channel = os.getenv("IRC_CHANNEL") or extra.get("channel", "") + return bool(server and channel) + + +def interactive_setup() -> None: + """Interactive `hermes gateway setup` flow for the IRC platform. + + Lazy-imports ``hermes_cli.setup`` helpers so the plugin stays importable + in non-CLI contexts (gateway runtime, tests). 
+ """ + from hermes_cli.setup import ( + prompt, + prompt_yes_no, + save_env_value, + get_env_value, + print_header, + print_info, + print_warning, + print_success, + ) + + print_header("IRC") + existing_server = get_env_value("IRC_SERVER") + if existing_server: + print_info(f"IRC: already configured (server: {existing_server})") + if not prompt_yes_no("Reconfigure IRC?", False): + return + + print_info("Connect Hermes to an IRC network. Uses Python stdlib — no extra packages needed.") + print_info(" Works with Libera.Chat, OFTC, your own ZNC/InspIRCd, etc.") + print() + + server = prompt("IRC server hostname (e.g. irc.libera.chat)", default=existing_server or "") + if not server: + print_warning("Server is required — skipping IRC setup") + return + save_env_value("IRC_SERVER", server.strip()) + + use_tls = prompt_yes_no("Use TLS (recommended)?", True) + save_env_value("IRC_USE_TLS", "true" if use_tls else "false") + + default_port = "6697" if use_tls else "6667" + port = prompt(f"Port (default {default_port})", default=get_env_value("IRC_PORT") or "") + if port: + try: + save_env_value("IRC_PORT", str(int(port))) + except ValueError: + print_warning(f"Invalid port — using default {default_port}") + elif get_env_value("IRC_PORT"): + # User cleared the prompt; drop the override so the default applies. + save_env_value("IRC_PORT", "") + + nickname = prompt( + "Bot nickname (e.g. hermes-bot)", + default=get_env_value("IRC_NICKNAME") or "", + ) + if not nickname: + print_warning("Nickname is required — skipping IRC setup") + return + save_env_value("IRC_NICKNAME", nickname.strip()) + + channel = prompt( + "Channel to join (e.g. 
#hermes — comma-separate for multiple)", + default=get_env_value("IRC_CHANNEL") or "", + ) + if not channel: + print_warning("Channel is required — skipping IRC setup") + return + save_env_value("IRC_CHANNEL", channel.strip()) + + print() + print_info("🔑 Optional authentication") + print_info(" Leave blank to skip.") + if prompt_yes_no("Configure a server password (PASS command)?", False): + server_password = prompt("Server password", password=True) + if server_password: + save_env_value("IRC_SERVER_PASSWORD", server_password) + + if prompt_yes_no("Identify with NickServ on connect?", False): + nickserv = prompt("NickServ password", password=True) + if nickserv: + save_env_value("IRC_NICKSERV_PASSWORD", nickserv) + + print() + print_info("🔒 Access control: restrict who can message the bot") + print_info(" IRC nicks are not authenticated — anyone can claim any nick.") + print_info(" For public channels, pair with NickServ-only mode on your network") + print_info(" if you want stronger identity guarantees.") + allow_all = prompt_yes_no("Allow all users in the channel to talk to the bot?", False) + if allow_all: + save_env_value("IRC_ALLOW_ALL_USERS", "true") + save_env_value("IRC_ALLOWED_USERS", "") + print_warning("⚠️ Open access — any nick in the channel can command the bot.") + else: + save_env_value("IRC_ALLOW_ALL_USERS", "false") + allowed = prompt( + "Allowed nicks (comma-separated, leave empty to deny everyone)", + default=get_env_value("IRC_ALLOWED_USERS") or "", + ) + if allowed: + save_env_value("IRC_ALLOWED_USERS", allowed.replace(" ", "")) + print_success("Allowlist configured") + else: + save_env_value("IRC_ALLOWED_USERS", "") + print_info("No nicks allowed — the bot will ignore all messages until you add nicks.") + + print() + print_success("IRC configuration saved to ~/.hermes/.env") + print_info("Restart the gateway for changes to take effect: hermes gateway restart") + + +def is_connected(config) -> bool: + """Check whether IRC is configured (env or 
config.yaml).""" + extra = getattr(config, "extra", {}) or {} + server = os.getenv("IRC_SERVER") or extra.get("server", "") + channel = os.getenv("IRC_CHANNEL") or extra.get("channel", "") + return bool(server and channel) + + +def register(ctx): + """Plugin entry point — called by the Hermes plugin system.""" + ctx.register_platform( + name="irc", + label="IRC", + adapter_factory=lambda cfg: IRCAdapter(cfg), + check_fn=check_requirements, + validate_config=validate_config, + is_connected=is_connected, + required_env=["IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"], + install_hint="No extra packages needed (stdlib only)", + setup_fn=interactive_setup, + # Auth env vars for _is_user_authorized() integration + allowed_users_env="IRC_ALLOWED_USERS", + allow_all_env="IRC_ALLOW_ALL_USERS", + # IRC line limit after protocol overhead + max_message_length=450, + # Display + emoji="💬", + # IRC doesn't have phone numbers to redact + pii_safe=False, + allow_update_command=True, + # LLM guidance + platform_hint=( + "You are chatting via IRC. IRC does not support markdown formatting " + "— use plain text only. Messages are limited to ~450 characters per " + "line (long messages are automatically split). In channels, users " + "address you by prefixing your nick. Keep responses concise and " + "conversational." + ), + ) diff --git a/plugins/platforms/irc/plugin.yaml b/plugins/platforms/irc/plugin.yaml new file mode 100644 index 0000000000..1e3d19f48c --- /dev/null +++ b/plugins/platforms/irc/plugin.yaml @@ -0,0 +1,13 @@ +name: irc-platform +kind: platform +version: 1.0.0 +description: > + IRC gateway adapter for Hermes Agent. + Connects to an IRC server and relays messages between an IRC channel + (or DMs) and the Hermes agent. No external dependencies — uses + Python's stdlib asyncio for the IRC protocol. 
+author: Nous Research +requires_env: + - IRC_SERVER + - IRC_CHANNEL + - IRC_NICKNAME diff --git a/plugins/platforms/teams/__init__.py b/plugins/platforms/teams/__init__.py new file mode 100644 index 0000000000..d4f1d7bf0e --- /dev/null +++ b/plugins/platforms/teams/__init__.py @@ -0,0 +1,3 @@ +from .adapter import register + +__all__ = ["register"] diff --git a/plugins/platforms/teams/adapter.py b/plugins/platforms/teams/adapter.py new file mode 100644 index 0000000000..b1769cf52c --- /dev/null +++ b/plugins/platforms/teams/adapter.py @@ -0,0 +1,685 @@ +""" +Microsoft Teams platform adapter for Hermes Agent. + +Uses the microsoft-teams-apps SDK for authentication and activity processing. +Runs an aiohttp webhook server to receive messages from Teams. +Proactive messaging (send, typing) uses the SDK's App.send() method. + +Requires: + pip install microsoft-teams-apps aiohttp + TEAMS_CLIENT_ID, TEAMS_CLIENT_SECRET, and TEAMS_TENANT_ID env vars + +Configuration in config.yaml: + platforms: + teams: + enabled: true + extra: + client_id: "your-client-id" # or TEAMS_CLIENT_ID env var + client_secret: "your-secret" # or TEAMS_CLIENT_SECRET env var + tenant_id: "your-tenant-id" # or TEAMS_TENANT_ID env var + port: 3978 # or TEAMS_PORT env var +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +from typing import Any, Dict, Optional + +try: + from aiohttp import web + + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + web = None # type: ignore[assignment] + +try: + from microsoft_teams.apps import App, ActivityContext + from microsoft_teams.api import MessageActivity, ConversationReference + from microsoft_teams.api.activities.typing import TypingActivityInput + from microsoft_teams.api.activities.invoke.adaptive_card import AdaptiveCardInvokeActivity + from microsoft_teams.api.models.adaptive_card import ( + AdaptiveCardActionCardResponse, + AdaptiveCardActionMessageResponse, + ) + from 
microsoft_teams.api.models.invoke_response import InvokeResponse, AdaptiveCardInvokeResponse + from microsoft_teams.apps.http.adapter import ( + HttpMethod, + HttpRequest, + HttpResponse, + HttpRouteHandler, + ) + from microsoft_teams.cards import AdaptiveCard, ExecuteAction, TextBlock + + TEAMS_SDK_AVAILABLE = True +except ImportError: + TEAMS_SDK_AVAILABLE = False + App = None # type: ignore[assignment,misc] + ActivityContext = None # type: ignore[assignment,misc] + MessageActivity = None # type: ignore[assignment,misc] + ConversationReference = None # type: ignore[assignment,misc] + TypingActivityInput = None # type: ignore[assignment,misc] + AdaptiveCardInvokeActivity = None # type: ignore[assignment,misc] + AdaptiveCardActionCardResponse = None # type: ignore[assignment,misc] + AdaptiveCardActionMessageResponse = None # type: ignore[assignment,misc] + AdaptiveCardInvokeResponse = None # type: ignore[assignment,misc,union-attr] + InvokeResponse = None # type: ignore[assignment,misc] + HttpMethod = str # type: ignore[assignment,misc] + HttpRequest = None # type: ignore[assignment,misc] + HttpResponse = None # type: ignore[assignment,misc] + HttpRouteHandler = None # type: ignore[assignment,misc] + AdaptiveCard = None # type: ignore[assignment,misc] + ExecuteAction = None # type: ignore[assignment,misc] + TextBlock = None # type: ignore[assignment,misc] + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.helpers import MessageDeduplicator +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, + cache_image_from_url, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_PORT = 3978 +_WEBHOOK_PATH = "/api/messages" + + +class _AiohttpBridgeAdapter: + """HttpServerAdapter that bridges the Teams SDK into an aiohttp server. + + Without a custom adapter, ``App()`` unconditionally imports fastapi/uvicorn + and allocates a ``FastAPI()`` instance. 
This bridge captures the SDK's + route registrations and wires them into our own aiohttp ``Application``. + """ + + def __init__(self, aiohttp_app: "web.Application"): + self._aiohttp_app = aiohttp_app + + def register_route(self, method: "HttpMethod", path: str, handler: "HttpRouteHandler") -> None: + """Register an SDK route handler as an aiohttp route.""" + + async def _aiohttp_handler(request: "web.Request") -> "web.Response": + body = await request.json() + headers = dict(request.headers) + result: "HttpResponse" = await handler(HttpRequest(body=body, headers=headers)) + status = result.get("status", 200) + resp_body = result.get("body") + if resp_body is not None: + return web.Response( + status=status, + body=json.dumps(resp_body), + content_type="application/json", + ) + return web.Response(status=status) + + self._aiohttp_app.router.add_route(method, path, _aiohttp_handler) + + def serve_static(self, path: str, directory: str) -> None: + pass + + async def start(self, port: int) -> None: + raise NotImplementedError("aiohttp server is managed by the adapter") + + async def stop(self) -> None: + pass + + +def check_requirements() -> bool: + """Return True when all Teams dependencies and credentials are present.""" + return TEAMS_SDK_AVAILABLE and AIOHTTP_AVAILABLE + + +def validate_config(config) -> bool: + """Return True when the config has the minimum required credentials.""" + extra = getattr(config, "extra", {}) or {} + client_id = os.getenv("TEAMS_CLIENT_ID") or extra.get("client_id", "") + client_secret = os.getenv("TEAMS_CLIENT_SECRET") or extra.get("client_secret", "") + tenant_id = os.getenv("TEAMS_TENANT_ID") or extra.get("tenant_id", "") + return bool(client_id and client_secret and tenant_id) + + +def is_connected(config) -> bool: + """Check whether Teams is configured (env or config.yaml).""" + return validate_config(config) + + +# Keep the old name as an alias so existing test imports don't break. 
+check_teams_requirements = check_requirements + + +class TeamsAdapter(BasePlatformAdapter): + """Microsoft Teams adapter using the microsoft-teams-apps SDK.""" + + MAX_MESSAGE_LENGTH = 28000 # Teams text message limit (~28 KB) + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform("teams")) + extra = config.extra or {} + self._client_id = extra.get("client_id") or os.getenv("TEAMS_CLIENT_ID", "") + self._client_secret = extra.get("client_secret") or os.getenv("TEAMS_CLIENT_SECRET", "") + self._tenant_id = extra.get("tenant_id") or os.getenv("TEAMS_TENANT_ID", "") + self._port = int(extra.get("port") or os.getenv("TEAMS_PORT", str(_DEFAULT_PORT))) + self._app: Optional["App"] = None + self._runner: Optional["web.AppRunner"] = None + self._dedup = MessageDeduplicator(max_size=1000) + # Maps chat_id → ConversationReference captured from incoming messages. + # Used to send cards with the correct conversation type (personal/group/channel). + self._conv_refs: Dict[str, Any] = {} + + async def connect(self) -> bool: + if not TEAMS_SDK_AVAILABLE: + self._set_fatal_error( + "MISSING_SDK", + "microsoft-teams-apps not installed. Run: pip install microsoft-teams-apps", + retryable=False, + ) + return False + + if not AIOHTTP_AVAILABLE: + self._set_fatal_error( + "MISSING_SDK", + "aiohttp not installed. 
Run: pip install aiohttp", + retryable=False, + ) + return False + + if not self._client_id or not self._client_secret or not self._tenant_id: + self._set_fatal_error( + "MISSING_CREDENTIALS", + "TEAMS_CLIENT_ID, TEAMS_CLIENT_SECRET, and TEAMS_TENANT_ID are all required", + retryable=False, + ) + return False + + try: + # Set up aiohttp app first — the bridge adapter wires SDK routes into it + aiohttp_app = web.Application() + aiohttp_app.router.add_get("/health", lambda _: web.Response(text="ok")) + + self._app = App( + client_id=self._client_id, + client_secret=self._client_secret, + tenant_id=self._tenant_id, + http_server_adapter=_AiohttpBridgeAdapter(aiohttp_app), + ) + + # Register message handler before initialize() + @self._app.on_message + async def _handle_message(ctx: ActivityContext[MessageActivity]): + await self._on_message(ctx) + + @self._app.on_card_action + async def _handle_card_action( + ctx: ActivityContext[AdaptiveCardInvokeActivity], + ) -> InvokeResponse[AdaptiveCardActionMessageResponse]: + return await self._on_card_action(ctx) + + # initialize() calls register_route() on the bridge, which adds + # POST /api/messages to aiohttp_app automatically + await self._app.initialize() + + self._runner = web.AppRunner(aiohttp_app) + await self._runner.setup() + site = web.TCPSite(self._runner, "0.0.0.0", self._port) + await site.start() + + self._running = True + self._mark_connected() + logger.info( + "[teams] Webhook server listening on 0.0.0.0:%d%s", + self._port, + _WEBHOOK_PATH, + ) + return True + + except Exception as e: + self._set_fatal_error( + "CONNECT_FAILED", + f"Teams connection failed: {e}", + retryable=True, + ) + logger.error("[teams] Failed to connect: %s", e) + return False + + async def disconnect(self) -> None: + self._running = False + if self._runner: + await self._runner.cleanup() + self._runner = None + self._app = None + self._mark_disconnected() + logger.info("[teams] Disconnected") + + async def _on_message(self, ctx: 
ActivityContext[MessageActivity]) -> None: + """Process an incoming Teams message and dispatch to the gateway.""" + activity = ctx.activity + + # Self-message filter + bot_id = self._app.id if self._app else None + if bot_id and getattr(activity.from_, "id", None) == bot_id: + return + + # Deduplication + msg_id = getattr(activity, "id", None) + if msg_id and self._dedup.is_duplicate(msg_id): + return + + # Cache the conversation reference for proactive sends (approval cards, etc.) + conv_id = getattr(activity.conversation, "id", None) + if conv_id: + self._conv_refs[conv_id] = ctx.conversation_ref + + # Extract text — strip bot @mentions + text = "" + if hasattr(activity, "text") and activity.text: + text = activity.text + # Strip BotName HTML tags that Teams prepends for @mentions + if "" in text: + import re + text = re.sub(r"[^<]*\s*", "", text).strip() + + # Determine chat type from conversation + conv = activity.conversation + conv_type = getattr(conv, "conversation_type", None) or "" + if conv_type == "personal": + chat_type = "dm" + elif conv_type == "groupChat": + chat_type = "group" + elif conv_type == "channel": + chat_type = "channel" + else: + chat_type = "dm" + + # Build source + from_account = activity.from_ + user_id = getattr(from_account, "aad_object_id", None) or getattr(from_account, "id", "") + user_name = getattr(from_account, "name", None) or "" + + source = self.build_source( + chat_id=conv.id, + chat_name=getattr(conv, "name", None) or "", + chat_type=chat_type, + user_id=str(user_id), + user_name=user_name, + guild_id=getattr(conv, "tenant_id", None) or self._tenant_id, + ) + + # Handle image attachments + media_urls = [] + media_types = [] + for att in getattr(activity, "attachments", None) or []: + content_url = getattr(att, "content_url", None) + content_type = getattr(att, "content_type", None) or "" + if content_url and content_type.startswith("image/"): + try: + cached = await cache_image_from_url(content_url) + if cached: + 
media_urls.append(cached) + media_types.append(content_type) + except Exception as e: + logger.warning("[teams] Failed to cache image attachment: %s", e) + + msg_type = MessageType.PHOTO if media_urls else MessageType.TEXT + + event = MessageEvent( + text=text, + source=source, + message_type=msg_type, + media_urls=media_urls, + media_types=media_types, + message_id=msg_id, + ) + await self.handle_message(event) + + async def _send_card(self, chat_id: str, card: "AdaptiveCard") -> "Any": + """Send an AdaptiveCard, using a stored ConversationReference when available.""" + from microsoft_teams.api import MessageActivityInput + + conv_ref = self._conv_refs.get(chat_id) + if conv_ref and self._app: + activity = MessageActivityInput().add_card(card) + return await self._app.activity_sender.send(activity, conv_ref) + elif self._app: + return await self._app.send(chat_id, card) + return None + + async def _on_card_action( + self, ctx: "ActivityContext[AdaptiveCardInvokeActivity]" + ) -> "InvokeResponse[AdaptiveCardActionMessageResponse]": + """Handle an Adaptive Card Action.Execute button click.""" + from tools.approval import resolve_gateway_approval, has_blocking_approval + + action = ctx.activity.value.action + data = action.data or {} + hermes_action = data.get("hermes_action", "") + session_key = data.get("session_key", "") + + if not hermes_action or not session_key: + return InvokeResponse( + status=200, + body=AdaptiveCardActionMessageResponse(value="Unknown action."), + ) + + # Only authorized users may click approval buttons. 
+ allowed_csv = os.getenv("TEAMS_ALLOWED_USERS", "").strip() + if allowed_csv: + from_account = ctx.activity.from_ + clicker_id = getattr(from_account, "aad_object_id", None) or getattr(from_account, "id", "") + allowed_ids = {uid.strip() for uid in allowed_csv.split(",") if uid.strip()} + if "*" not in allowed_ids and clicker_id not in allowed_ids: + logger.warning("[teams] Unauthorized card action by %s — ignoring", clicker_id) + return InvokeResponse( + status=200, + body=AdaptiveCardActionMessageResponse(value="⛔ Not authorized."), + ) + + choice_map = { + "approve_once": "once", + "approve_session": "session", + "approve_always": "always", + "deny": "deny", + } + choice = choice_map.get(hermes_action) + if not choice: + return InvokeResponse( + status=200, + body=AdaptiveCardActionMessageResponse(value="Unknown action."), + ) + + if not has_blocking_approval(session_key): + return InvokeResponse( + status=200, + body=AdaptiveCardActionCardResponse( + value=AdaptiveCard() + .with_version("1.4") + .with_body([TextBlock(text="⚠️ Approval already resolved or expired.", wrap=True)]) + ), + ) + + resolve_gateway_approval(session_key, choice) + + label_map = { + "once": "✅ Allowed (once)", + "session": "✅ Allowed (session)", + "always": "✅ Always allowed", + "deny": "❌ Denied", + } + cmd = data.get("cmd", "") + desc = data.get("desc", "") + body = [] + if cmd: + body.append(TextBlock(text="⚠️ Command Approval Required", wrap=True, weight="Bolder")) + body.append(TextBlock(text=f"```\n{cmd}\n```", wrap=True)) + if desc: + body.append(TextBlock(text=f"Reason: {desc}", wrap=True, isSubtle=True)) + body.append(TextBlock(text=label_map[choice], wrap=True, weight="Bolder")) + + return InvokeResponse( + status=200, + body=AdaptiveCardActionCardResponse( + value=AdaptiveCard().with_version("1.4").with_body(body) + ), + ) + + async def send_exec_approval( + self, + chat_id: str, + command: str, + session_key: str, + description: str = "dangerous command", + metadata: 
Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an Adaptive Card approval prompt with Allow/Deny buttons.""" + if not self._app: + return SendResult(success=False, error="Teams app not initialized") + + cmd_preview = command[:2000] + "..." if len(command) > 2000 else command + # Truncated for button data payload — just enough to reconstruct the card body. + btn_data_base = { + "session_key": session_key, + "cmd": command[:200] + "..." if len(command) > 200 else command, + "desc": description, + } + + card = ( + AdaptiveCard() + .with_version("1.4") + .with_body([ + TextBlock(text="⚠️ Command Approval Required", wrap=True, weight="Bolder"), + TextBlock(text=f"```\n{cmd_preview}\n```", wrap=True), + TextBlock(text=f"Reason: {description}", wrap=True, isSubtle=True), + ]) + .with_actions([ + ExecuteAction( + title="Allow Once", + verb="hermes_approve", + data={**btn_data_base, "hermes_action": "approve_once"}, + style="positive", + ), + ExecuteAction( + title="Allow Session", + verb="hermes_approve", + data={**btn_data_base, "hermes_action": "approve_session"}, + ), + ExecuteAction( + title="Always Allow", + verb="hermes_approve", + data={**btn_data_base, "hermes_action": "approve_always"}, + ), + ExecuteAction( + title="Deny", + verb="hermes_approve", + data={**btn_data_base, "hermes_action": "deny"}, + style="destructive", + ), + ]) + ) + + try: + result = await self._send_card(chat_id, card) + message_id = getattr(result, "id", None) if result else None + return SendResult(success=True, message_id=message_id) + except Exception as e: + logger.error("[teams] send_exec_approval failed: %s", e, exc_info=True) + return SendResult(success=False, error=str(e), retryable=True) + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + if not self._app: + return SendResult(success=False, error="Teams app not initialized") + + formatted = 
self.format_message(content) + chunks = self.truncate_message(formatted) + last_message_id = None + + for chunk in chunks: + try: + result = await self._app.send(chat_id, chunk) + last_message_id = getattr(result, "id", None) + except Exception as e: + return SendResult(success=False, error=str(e), retryable=True) + + return SendResult(success=True, message_id=last_message_id) + + async def send_typing(self, chat_id: str, metadata: Optional[Dict[str, Any]] = None) -> None: + if not self._app: + return + try: + await self._app.send(chat_id, TypingActivityInput()) + except Exception: + pass + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + if not self._app: + return SendResult(success=False, error="Teams app not initialized") + + try: + import base64 + import mimetypes + from microsoft_teams.api import Attachment, MessageActivityInput + + if image_url.startswith("http://") or image_url.startswith("https://"): + content_url = image_url + mime_type = "image/png" + else: + # Local path — encode as base64 data URI + path = image_url.removeprefix("file://") + mime_type = mimetypes.guess_type(path)[0] or "image/png" + with open(path, "rb") as f: + content_url = f"data:{mime_type};base64,{base64.b64encode(f.read()).decode()}" + + attachment = Attachment(content_type=mime_type, content_url=content_url) + activity = MessageActivityInput().add_attachments(attachment) + if caption: + activity = activity.add_text(caption) + + conv_ref = self._conv_refs.get(chat_id) + if conv_ref: + result = await self._app.activity_sender.send(activity, conv_ref) + else: + result = await self._app.send(chat_id, activity) + + return SendResult(success=True, message_id=getattr(result, "id", None)) + except Exception as e: + logger.error("[teams] send_image failed: %s", e, exc_info=True) + return SendResult(success=False, error=str(e), 
retryable=True) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, + ) -> SendResult: + return await self.send_image( + chat_id=chat_id, + image_url=image_path, + caption=caption, + reply_to=reply_to, + ) + + async def get_chat_info(self, chat_id: str) -> dict: + return {"name": chat_id, "type": "unknown", "chat_id": chat_id} + + +# ── Interactive setup ───────────────────────────────────────────────────────── + +def interactive_setup() -> None: + """Guide the user through Teams setup using the Teams CLI.""" + from hermes_cli.config import ( + get_env_value, + save_env_value, + prompt, + prompt_yes_no, + print_info, + print_success, + print_warning, + ) + + existing_id = get_env_value("TEAMS_CLIENT_ID") + if existing_id: + print_info(f"Teams: already configured (app ID: {existing_id})") + if not prompt_yes_no("Reconfigure Teams?", False): + return + + print_info("You'll need the Teams CLI. If you haven't already:") + print_info(" npm install -g @microsoft/teams.cli@preview") + print_info(" teams login") + print() + print_info("Then expose port 3978 publicly (devtunnel / ngrok / cloudflared),") + print_info("and create your bot:") + print_info(" teams app create --name \"Hermes\" --endpoint \"https:///api/messages\"") + print() + print_info("The CLI will print CLIENT_ID, CLIENT_SECRET, and TENANT_ID. 
Paste them below.") + print() + + client_id = prompt("Client ID", default=existing_id or "") + if not client_id: + print_warning("Client ID is required — skipping Teams setup") + return + save_env_value("TEAMS_CLIENT_ID", client_id.strip()) + + client_secret = prompt("Client secret", default=get_env_value("TEAMS_CLIENT_SECRET") or "", password=True) + if not client_secret: + print_warning("Client secret is required — skipping Teams setup") + return + save_env_value("TEAMS_CLIENT_SECRET", client_secret.strip()) + + tenant_id = prompt("Tenant ID", default=get_env_value("TEAMS_TENANT_ID") or "") + if not tenant_id: + print_warning("Tenant ID is required — skipping Teams setup") + return + save_env_value("TEAMS_TENANT_ID", tenant_id.strip()) + + print() + print_info("To find your AAD object ID for the allowlist: teams status --verbose") + if prompt_yes_no("Restrict access to specific users? (recommended)", True): + allowed = prompt( + "Allowed AAD object IDs (comma-separated)", + default=get_env_value("TEAMS_ALLOWED_USERS") or "", + ) + if allowed: + save_env_value("TEAMS_ALLOWED_USERS", allowed.replace(" ", "")) + print_success("Allowlist configured") + else: + save_env_value("TEAMS_ALLOWED_USERS", "") + else: + save_env_value("TEAMS_ALLOW_ALL_USERS", "true") + print_warning("⚠️ Open access — anyone who can message the bot can command it.") + + print() + print_success("Teams configuration saved to ~/.hermes/.env") + print_info("Install the app in Teams: teams app install --id ") + print_info("Restart the gateway: hermes gateway restart") + + +# ── Plugin entry point ──────────────────────────────────────────────────────── + +def register(ctx) -> None: + """Plugin entry point — called by the Hermes plugin system.""" + ctx.register_platform( + name="teams", + label="Microsoft Teams", + adapter_factory=lambda cfg: TeamsAdapter(cfg), + check_fn=check_requirements, + validate_config=validate_config, + is_connected=is_connected, + required_env=["TEAMS_CLIENT_ID", 
"TEAMS_CLIENT_SECRET", "TEAMS_TENANT_ID"], + install_hint="pip install microsoft-teams-apps aiohttp", + setup_fn=interactive_setup, + # Auth env vars for _is_user_authorized() integration + allowed_users_env="TEAMS_ALLOWED_USERS", + allow_all_env="TEAMS_ALLOW_ALL_USERS", + # Teams supports up to ~28 KB per message + max_message_length=28000, + # Display + emoji="💼", + allow_update_command=True, + # LLM guidance + platform_hint=( + "You are chatting via Microsoft Teams. Teams renders a subset of " + "markdown — bold (**text**), italic (*text*), and inline code " + "(`code`) work, but complex tables or raw HTML do not. Keep " + "responses clear and professional." + ), + ) diff --git a/plugins/platforms/teams/plugin.yaml b/plugins/platforms/teams/plugin.yaml new file mode 100644 index 0000000000..57f18adaa1 --- /dev/null +++ b/plugins/platforms/teams/plugin.yaml @@ -0,0 +1,13 @@ +name: teams-platform +kind: platform +version: 1.0.0 +description: > + Microsoft Teams gateway adapter for Hermes Agent. + Connects to Microsoft Teams via the Bot Framework and relays messages + between Teams chats (personal DMs, group chats, channel posts) and + the Hermes agent. Supports Adaptive Card approval prompts. +author: Aamir Jawaid +requires_env: + - TEAMS_CLIENT_ID + - TEAMS_CLIENT_SECRET + - TEAMS_TENANT_ID diff --git a/pyproject.toml b/pyproject.toml index 57a752877e..f1132e8d70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ dependencies = [ "firecrawl-py>=4.16.0,<5", "parallel-web>=0.4.2,<1", "fal-client>=0.13.1,<1", + # Cron scheduler (built-in feature — scheduled cron/interval jobs use croniter). 
+ "croniter>=6.0.0,<7", # Text-to-speech (Edge TTS is free, no API key needed) "edge-tts>=7.2.7,<8", # Skills Hub (GitHub App JWT auth — optional, only needed for bot identity) @@ -39,9 +41,10 @@ dependencies = [ [project.optional-dependencies] modal = ["modal>=1.0.0,<2"] daytona = ["daytona>=0.148.0,<1"] +vercel = ["vercel>=0.5.7,<0.6.0"] dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2", "ty>=0.0.1a29,<0.0.22", "ruff"] messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4", "qrcode>=7.0,<8"] -cron = ["croniter>=6.0.0,<7"] +cron = [] # croniter is now a core dependency; this extra kept for back-compat slack = ["slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"] matrix = ["mautrix[encryption]>=0.20,<1", "Markdown>=3.6,<4", "aiosqlite>=0.20", "asyncpg>=0.29", "aiohttp-socks>=0.10,<1"] cli = ["simple-term-menu>=1.0,<2"] @@ -100,6 +103,7 @@ yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git@bfb0c88 all = [ "hermes-agent[modal]", "hermes-agent[daytona]", + "hermes-agent[vercel]", "hermes-agent[messaging]", # matrix: python-olm (required by matrix-nio[e2e]) is upstream-broken on # modern macOS (archived libolm, C++ errors with Clang 21+). On Linux the diff --git a/run_agent.py b/run_agent.py index f5729dcd42..80738aab16 100644 --- a/run_agent.py +++ b/run_agent.py @@ -160,6 +160,7 @@ from agent.trajectory import ( save_trajectory as _save_trajectory_to_file, ) from utils import atomic_json_write, base_url_host_matches, base_url_hostname, env_var_enabled, normalize_proxy_url +from hermes_cli.config import cfg_get @@ -322,6 +323,12 @@ _PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"}) # Maximum number of concurrent worker threads for parallel tool execution. 
_MAX_TOOL_WORKERS = 8 +# Guard so the OpenRouter metadata pre-warm thread is only spawned once per +# process, not once per AIAgent instantiation. Without this, long-running +# gateway processes leak one OS thread per incoming message and eventually +# exhaust the system thread limit (RuntimeError: can't start new thread). +_openrouter_prewarm_done = threading.Event() + # Patterns that indicate a terminal command may modify/delete files. _DESTRUCTIVE_PATTERNS = re.compile( r"""(?:^|\s|&&|\|\||;|`)(?: @@ -925,6 +932,7 @@ class AIAgent: thread_id: str = None, gateway_session_key: str = None, skip_context_files: bool = False, + load_soul_identity: bool = False, skip_memory: bool = False, session_db=None, parent_session_id: str = None, @@ -976,6 +984,9 @@ class AIAgent: skip_context_files (bool): If True, skip auto-injection of SOUL.md, AGENTS.md, and .cursorrules into the system prompt. Use this for batch processing and data generation to avoid polluting trajectories with user-specific persona or project instructions. + load_soul_identity (bool): If True, still use ~/.hermes/SOUL.md as the primary + identity even when skip_context_files=True. Project context files from the cwd + remain skipped. """ _install_safe_stdio() @@ -1004,6 +1015,7 @@ class AIAgent: self._print_fn = None self.background_review_callback = None # Optional sync callback for gateway delivery self.skip_context_files = skip_context_files + self.load_soul_identity = load_soul_identity self.pass_session_id = pass_session_id self._credential_pool = credential_pool self.log_prefix_chars = log_prefix_chars @@ -1101,10 +1113,17 @@ class AIAgent: # Pre-warm OpenRouter model metadata cache in a background thread. # fetch_model_metadata() is cached for 1 hour; this avoids a blocking # HTTP request on the first API response when pricing is estimated. 
- if self.provider == "openrouter" or self._is_openrouter_url(): + # Use a process-level Event so this thread is only spawned once — a new + # AIAgent is created for every gateway request, so without the guard + # each message leaks one OS thread and the process eventually exhausts + # the system thread limit (RuntimeError: can't start new thread). + if (self.provider == "openrouter" or self._is_openrouter_url()) and \ + not _openrouter_prewarm_done.is_set(): + _openrouter_prewarm_done.set() threading.Thread( - target=lambda: fetch_model_metadata(), + target=fetch_model_metadata, daemon=True, + name="openrouter-prewarm", ).start() self.tool_progress_callback = tool_progress_callback @@ -1788,7 +1807,7 @@ class AIAgent: # compression model. Custom endpoints often cannot report this via # /models, so the startup feasibility check needs the config hint. try: - _aux_cfg = _agent_cfg.get("auxiliary", {}).get("compression", {}) + _aux_cfg = cfg_get(_agent_cfg, "auxiliary", "compression", default={}) except Exception: _aux_cfg = {} if isinstance(_aux_cfg, dict): @@ -1892,6 +1911,7 @@ class AIAgent: self._ensure_lmstudio_runtime_loaded(_config_context_length) + # Select context engine: config-driven (like memory providers). # 1. Check config.yaml context.engine setting # 2. Check plugins/context_engine// directory (repo-shipped) @@ -2132,7 +2152,7 @@ class AIAgent: # Context engine reset (works for both built-in compressor and plugins) if hasattr(self, "context_compressor") and self.context_compressor: self.context_compressor.on_session_reset() - + def _ensure_lmstudio_runtime_loaded(self, config_context_length: Optional[int] = None) -> None: """ Preload the LM Studio model with at least Hermes' minimum context. @@ -2813,6 +2833,24 @@ class AIAgent: # Third-party Anthropic-compatible gateway. 
return True, True + # MiniMax on its Anthropic-compatible endpoint serves its own + # model family (MiniMax-M2.7, M2.5, M2.1, M2) with documented + # cache_control support (0.1× read pricing, 5-minute TTL). The + # blanket is_claude gate above excludes these — opt them in + # explicitly via provider id or host match so users on + # provider=minimax / minimax-cn (or custom endpoints pointing at + # api.minimax.io/anthropic / api.minimaxi.com/anthropic) get the + # same cost reduction as Claude traffic. + # Docs: https://platform.minimax.io/docs/api-reference/anthropic-api-compatible-cache + if is_anthropic_wire: + is_minimax_provider = provider_lower in {"minimax", "minimax-cn"} + is_minimax_host = ( + base_url_host_matches(eff_base_url, "api.minimax.io") + or base_url_host_matches(eff_base_url, "api.minimaxi.com") + ) + if is_minimax_provider or is_minimax_host: + return True, True + # Qwen/Alibaba on OpenCode (Zen/Go) and native DashScope: OpenAI-wire # transport that accepts Anthropic-style cache_control markers and # rewards them with real cache hits. Without this branch @@ -2898,7 +2936,7 @@ class AIAgent: # Check if there's any non-whitespace content remaining return bool(cleaned.strip()) - + def _strip_think_blocks(self, content: str) -> str: """Remove reasoning/thinking blocks from content, returning only visible text. @@ -3116,8 +3154,8 @@ class AIAgent: marker in assistant_text for marker in workspace_markers ) return (user_targets_workspace or assistant_targets_workspace) and assistant_mentions_action - - + + def _extract_reasoning(self, assistant_message) -> Optional[str]: """ Extract reasoning/thinking content from an assistant message. @@ -3690,7 +3728,7 @@ class AIAgent: # Return everything up to (not including) the last assistant message return messages[:last_assistant_idx] - + def _format_tools_for_system_message(self) -> str: """ Format tool definitions for the system message in the trajectory format. 
@@ -3714,7 +3752,7 @@ class AIAgent: formatted_tools.append(formatted_tool) return json.dumps(formatted_tools, ensure_ascii=False) - + def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]: """ Convert internal message format to trajectory format for saving. @@ -3879,7 +3917,7 @@ class AIAgent: i += 1 return trajectory - + def _save_trajectory(self, messages: List[Dict[str, Any]], user_query: str, completed: bool): """ Save conversation trajectory to JSONL file. @@ -3894,7 +3932,7 @@ class AIAgent: trajectory = self._convert_to_trajectory_format(messages, user_query, completed) _save_trajectory_to_file(trajectory, self.model, completed) - + @staticmethod def _summarize_api_error(error: Exception) -> str: """Extract a human-readable one-liner from an API error. @@ -4206,7 +4244,7 @@ class AIAgent: except Exception as e: if self.verbose_logging: logging.warning(f"Failed to save session log: {e}") - + def interrupt(self, message: str = None) -> None: """ Request the agent to interrupt its current tool-calling loop. @@ -4274,7 +4312,7 @@ class AIAgent: logger.debug("Failed to propagate interrupt to child agent: %s", e) if not self.quiet_mode: print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else "")) - + def clear_interrupt(self) -> None: """Clear any pending interrupt request and the per-thread tool interrupt signal.""" self._interrupt_requested = False @@ -4495,7 +4533,7 @@ class AIAgent: ) except Exception: pass - + def commit_memory_session(self, messages: list = None) -> None: """Trigger end-of-session extraction without tearing providers down. Called when session_id rotates (e.g. 
/new, context compression); @@ -4546,8 +4584,14 @@ class AIAgent: if not (self._memory_manager and final_response and original_user_message): return try: - self._memory_manager.sync_all(original_user_message, final_response) - self._memory_manager.queue_prefetch_all(original_user_message) + self._memory_manager.sync_all( + original_user_message, final_response, + session_id=self.session_id or "", + ) + self._memory_manager.queue_prefetch_all( + original_user_message, + session_id=self.session_id or "", + ) except Exception: pass @@ -4685,7 +4729,7 @@ class AIAgent: if not self.quiet_mode: self._vprint(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history") _set_interrupt(False) - + @property def is_interrupted(self) -> bool: """Check if an interrupt has been requested.""" @@ -4717,9 +4761,11 @@ class AIAgent: # 6. Current date & time (frozen at build time) # 7. Platform-specific formatting hint - # Try SOUL.md as primary identity (unless context files are skipped) + # Try SOUL.md as primary identity unless the caller explicitly skipped it. + # Some execution modes (cron) still want HERMES_HOME persona while keeping + # cwd project instructions disabled. 
_soul_loaded = False - if not self.skip_context_files: + if self.load_soul_identity or not self.skip_context_files: _soul_content = load_soul_md() if _soul_content: prompt_parts = [_soul_content] @@ -4867,6 +4913,15 @@ class AIAgent: platform_key = (self.platform or "").lower().strip() if platform_key in PLATFORM_HINTS: prompt_parts.append(PLATFORM_HINTS[platform_key]) + elif platform_key: + # Check plugin registry for platform-specific LLM guidance + try: + from gateway.platform_registry import platform_registry + _entry = platform_registry.get(platform_key) + if _entry and _entry.platform_hint: + prompt_parts.append(_entry.platform_hint) + except Exception: + pass return "\n\n".join(p.strip() for p in prompt_parts if p.strip()) @@ -6155,7 +6210,12 @@ class AIAgent: correctly — rebuilding with the Bedrock SDK when provider is bedrock, rather than always falling back to build_anthropic_client() which requires a direct Anthropic API key. + + Honors ``self._oauth_1m_beta_disabled`` (set by the reactive recovery + path when an OAuth subscription rejects the 1M-context beta) so the + rebuilt client carries the reduced beta set. """ + _drop_1m = bool(getattr(self, "_oauth_1m_beta_disabled", False)) if getattr(self, "provider", None) == "bedrock": from agent.anthropic_adapter import build_anthropic_bedrock_client region = getattr(self, "_bedrock_region", "us-east-1") or "us-east-1" @@ -6166,6 +6226,7 @@ class AIAgent: self._anthropic_api_key, getattr(self, "_anthropic_base_url", None), timeout=get_provider_request_timeout(self.provider, self.model), + drop_context_1m_beta=_drop_1m, ) def _interruptible_api_call(self, api_kwargs: dict): @@ -6470,6 +6531,9 @@ class AIAgent: Falls back to _interruptible_api_call on provider errors indicating streaming is not supported. """ + if self._interrupt_requested: + raise InterruptedError("Agent interrupted before streaming API call") + if self.api_mode == "codex_responses": # Codex streams internally via _run_codex_stream. 
The main dispatch # in _interruptible_api_call already calls it; we just need to @@ -7128,6 +7192,12 @@ class AIAgent: # to non-streaming on the next attempt via _disable_streaming. result["error"] = e return + except InterruptedError as e: + # The interrupt may be noticed inside the worker thread before + # the polling loop sees it. Surface it through the normal result + # channel so callers never miss a fast pre-retry interrupt. + result["error"] = e + return finally: request_client = request_client_holder.get("client") if request_client is not None: @@ -8112,6 +8182,7 @@ class AIAgent: context_length=ctx_len, base_url=getattr(self, "_anthropic_base_url", None), fast_mode=(self.request_overrides or {}).get("speed") == "fast", + drop_context_1m_beta=bool(getattr(self, "_oauth_1m_beta_disabled", False)), ) # AWS Bedrock native Converse API — bypasses the OpenAI client entirely. @@ -8234,6 +8305,7 @@ class AIAgent: model=self.model, messages=_msgs_for_chat, tools=self.tools, + base_url=self.base_url, timeout=self._resolved_api_call_timeout(), max_tokens=self.max_tokens, ephemeral_max_output_tokens=_ephemeral_out, @@ -8919,6 +8991,23 @@ class AIAgent: except Exception as _ce_err: logger.debug("context engine on_session_start (compression): %s", _ce_err) + # Notify memory providers of the compression-driven session_id rotation + # so provider-cached per-session state (Hindsight's _document_id, + # accumulated turn buffers, counters) refreshes. reset=False because + # the logical conversation continues; only the id and DB row rolled + # over. See #6672. 
+ try: + _old_sid = locals().get("old_session_id") + if _old_sid and self._memory_manager: + self._memory_manager.on_session_switch( + self.session_id or "", + parent_session_id=_old_sid, + reset=False, + reason="compression", + ) + except Exception as _me_err: + logger.debug("memory manager on_session_switch (compression): %s", _me_err) + # Warn on repeated compressions (quality degrades with each pass) _cc = self.context_compressor.compression_count if _cc >= 2: @@ -9940,7 +10029,7 @@ class AIAgent: is_oauth=self._is_anthropic_oauth, preserve_dots=self._anthropic_preserve_dots()) summary_response = self._anthropic_messages_create(_ant_kw) - _summary_result = _tsum.normalize_response(summary_response) + _summary_result = _tsum.normalize_response(summary_response, strip_tool_prefix=self._is_anthropic_oauth) final_response = (_summary_result.content or "").strip() else: summary_response = self._ensure_primary_openai_client(reason="iteration_limit_summary").chat.completions.create(**summary_kwargs) @@ -9970,7 +10059,7 @@ class AIAgent: max_tokens=self.max_tokens, reasoning_config=self.reasoning_config, preserve_dots=self._anthropic_preserve_dots()) retry_response = self._anthropic_messages_create(_ant_kw2) - _retry_result = _tretry.normalize_response(retry_response) + _retry_result = _tretry.normalize_response(retry_response, strip_tool_prefix=self._is_anthropic_oauth) final_response = (_retry_result.content or "").strip() else: summary_kwargs = { @@ -10679,6 +10768,7 @@ class AIAgent: copilot_auth_retry_attempted=False thinking_sig_retry_attempted = False image_shrink_retry_attempted = False + oauth_1m_beta_retry_attempted = False has_retried_429 = False restart_with_compressed_messages = False restart_with_length_continuation = False @@ -11098,7 +11188,12 @@ class AIAgent: # would have been appended in the non-truncated path. 
_trunc_msg = None _trunc_transport = self._get_transport() - _trunc_result = _trunc_transport.normalize_response(response) + if self.api_mode == "anthropic_messages": + _trunc_result = _trunc_transport.normalize_response( + response, strip_tool_prefix=self._is_anthropic_oauth + ) + else: + _trunc_result = _trunc_transport.normalize_response(response) _trunc_msg = _trunc_result _trunc_content = getattr(_trunc_msg, "content", None) if _trunc_msg else None @@ -11630,6 +11725,36 @@ class AIAgent: "or shrink didn't reduce size; surfacing original error." ) + # Anthropic OAuth subscription rejected the 1M-context beta + # header ("long context beta is not yet available for this + # subscription"). Disable the beta for the rest of this + # session, rebuild the client, and retry once. 1M-capable + # subscriptions never hit this branch — they accept the + # beta and keep full 1M context. See PR #17680 for the + # original report (we chose reactive recovery over the + # proposed unconditional omit so capable subscriptions + # don't silently lose the capability). 
+ if ( + classified.reason == FailoverReason.oauth_long_context_beta_forbidden + and self.api_mode == "anthropic_messages" + and self._is_anthropic_oauth + and not oauth_1m_beta_retry_attempted + ): + oauth_1m_beta_retry_attempted = True + if not getattr(self, "_oauth_1m_beta_disabled", False): + self._oauth_1m_beta_disabled = True + try: + self._anthropic_client.close() + except Exception: + pass + self._rebuild_anthropic_client() + self._vprint( + f"{self.log_prefix}🔕 OAuth subscription doesn't support " + f"the 1M-context beta — disabled for this session and retrying...", + force=True, + ) + continue + if ( self.api_mode == "codex_responses" and self.provider == "openai-codex" @@ -12436,7 +12561,10 @@ class AIAgent: try: _transport = self._get_transport() - normalized = _transport.normalize_response(response) + _normalize_kwargs = {} + if self.api_mode == "anthropic_messages": + _normalize_kwargs["strip_tool_prefix"] = self._is_anthropic_oauth + normalized = _transport.normalize_response(response, **_normalize_kwargs) assistant_message = normalized finish_reason = normalized.finish_reason diff --git a/scripts/release.py b/scripts/release.py index d66b3b36d4..baeb7dbe1c 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -44,30 +44,48 @@ AUTHOR_MAP = { "qiyin.zuo@pcitc.com": "qiyin-code", "teknium@nousresearch.com": "teknium1", "127238744+teknium1@users.noreply.github.com": "teknium1", + "2093036+exiao@users.noreply.github.com": "exiao", + "rylen.anil@gmail.com": "rylena", + "14046872+tmimmanuel@users.noreply.github.com": "tmimmanuel", "revar@users.noreply.github.com": "revaraver", # Matrix parity salvage batch (April 2026) "sr@samirusani": "samrusani", "angelclaw@AngelMacBook.local": "angel12", "charles@cryptoassetrecovery.com": "charles-brooks", "heathley@Heathley-MacBook-Air.local": "heathley", + "vlad19@gmail.com": "dandaka", "adamrummer@gmail.com": "cyclingwithelephants", "nbot@liizfq.top": "liizfq", 
"274096618+hermes-agent-dhabibi@users.noreply.github.com": "dhabibi", "dejie.guo@gmail.com": "JayGwod", + # OpenViking viking_read salvage (April 2026) + "hitesh@gmail.com": "htsh", + "pty819@outlook.com": "pty819", + "pty819@users.noreply.github.com": "pty819", + "517024110@qq.com": "chennest", + "aamirjawaid@microsoft.com": "heyitsaamir", "johnnncenaaa77@gmail.com": "johnncenae", "thomasjhon6666@gmail.com": "ThomassJonax", "focusflow.app.help@gmail.com": "yes999zc", + "162235745+0z1-ghb@users.noreply.github.com": "0z1-ghb", "yes999zc@163.com": "yes999zc", "343873859@qq.com": "DrStrangerUJN", + "252818347@qq.com": "hejuntt1014", "uzmpsk.dilekakbas@gmail.com": "dlkakbs", "beliefanx@gmail.com": "BeliefanX", "jefferson@heimdallstrategy.com": "Mind-Dragon", + "44753291+Nanako0129@users.noreply.github.com": "Nanako0129", "steve.westerhouse@origami-analytics.com": "westers", + "yeyitech@users.noreply.github.com": "yeyitech", + "260878550+beenherebefore@users.noreply.github.com": "beenherebefore", + "79389617+txbxxx@users.noreply.github.com": "txbxxx", + "liuhao03@bilibili.com": "liuhao1024", "130918800+devorun@users.noreply.github.com": "devorun", "surat.s@itm.kmutnb.ac.th": "beesrsj2500", "beesr@bee.localdomain": "beesrsj2500", "mtf201013@gmail.com": "ma-pony", "sonoyuncudmr@gmail.com": "Sonoyunchu", + "43525405+yatesjalex@users.noreply.github.com": "yatesjalex", "maks.mir@yahoo.com": "say8hi", "27719690+Mirac1eSky@users.noreply.github.com": "Mirac1eSky", "web3blind@users.noreply.github.com": "web3blind", @@ -95,7 +113,9 @@ AUTHOR_MAP = { "82637225+kshitijk4poor@users.noreply.github.com": "kshitijk4poor", "keifergu@tencent.com": "keifergu", "kshitijk4poor@users.noreply.github.com": "kshitijk4poor", + "SHL0MS@users.noreply.github.com": "SHL0MS", "abner.the.foreman@agentmail.to": "Abnertheforeman", + "adam.manning@pro-serveinc.com": "amanning3390", "thomasgeorgevii09@gmail.com": "tochukwuada", "harryykyle1@gmail.com": "hharry11", "kshitijk4poor@gmail.com": 
"kshitijk4poor", @@ -120,6 +140,7 @@ AUTHOR_MAP = { "126368201+vilkasdev@users.noreply.github.com": "vilkasdev", "137614867+cutepawss@users.noreply.github.com": "cutepawss", "96793918+memosr@users.noreply.github.com": "memosr", + "mehmet.sr35@gmail.com": "memosr", "milkoor@users.noreply.github.com": "milkoor", "xuerui911@gmail.com": "Fatty911", "131039422+SHL0MS@users.noreply.github.com": "SHL0MS", @@ -259,6 +280,7 @@ AUTHOR_MAP = { "danielrpike9@gmail.com": "Bartok9", "skozyuk@cruxexperts.com": "CruxExperts", "154585401+LeonSGP43@users.noreply.github.com": "LeonSGP43", + "12250313+Kailigithub@users.noreply.github.com": "Kailigithub", "mgparkprint@gmail.com": "vlwkaos", "tranquil_flow@protonmail.com": "Tranquil-Flow", "LyleLengyel@gmail.com": "mcndjxlefnd", @@ -313,6 +335,7 @@ AUTHOR_MAP = { "dalvidjr2022@gmail.com": "Jr-kenny", "m@statecraft.systems": "mbierling", "balyan.sid@gmail.com": "alt-glitch", + "52913345+alt-glitch@users.noreply.github.com": "alt-glitch", "oluwadareab12@gmail.com": "bennytimz", "simon@simonmarcus.org": "simon-marcus", "xowiekk@gmail.com": "Xowiek", diff --git a/skills/creative/comfyui/SKILL.md b/skills/creative/comfyui/SKILL.md new file mode 100644 index 0000000000..4fbeb60357 --- /dev/null +++ b/skills/creative/comfyui/SKILL.md @@ -0,0 +1,606 @@ +--- +name: comfyui +description: "Generate images, video, and audio with ComfyUI — install, launch, manage nodes/models, run workflows with parameter injection. Uses the official comfy-cli for lifecycle and direct REST/WebSocket API for execution." +version: 5.0.0 +author: [kshitijk4poor, alt-glitch] +license: MIT +platforms: [macos, linux, windows] +compatibility: "Requires ComfyUI (local, Comfy Desktop, or Comfy Cloud) and comfy-cli (auto-installed via pipx/uvx by the setup script)." 
+prerequisites: + commands: ["python3"] +setup: + help: "Run scripts/hardware_check.py FIRST to decide local vs Comfy Cloud; then scripts/comfyui_setup.sh auto-installs locally (or use Cloud API key for platform.comfy.org)." +metadata: + hermes: + tags: + - comfyui + - image-generation + - stable-diffusion + - flux + - sd3 + - wan-video + - hunyuan-video + - creative + - generative-ai + - video-generation + related_skills: [stable-diffusion-image-generation, image_gen] + category: creative +--- + +# ComfyUI + +Generate images, video, audio, and 3D content through ComfyUI using the +official `comfy-cli` for setup/lifecycle and direct REST/WebSocket API +for workflow execution. + +## What's in this skill + +**Reference docs (`references/`):** + +- `official-cli.md` — every `comfy ...` command, with flags +- `rest-api.md` — REST + WebSocket endpoints (local + cloud), payload schemas +- `workflow-format.md` — API-format JSON, common node types, param mapping + +**Scripts (`scripts/`):** + +| Script | Purpose | +|--------|---------| +| `_common.py` | Shared HTTP, cloud routing, node catalogs (don't run directly) | +| `hardware_check.py` | Probe GPU/VRAM/disk → recommend local vs Comfy Cloud | +| `comfyui_setup.sh` | Hardware check + comfy-cli + ComfyUI install + launch + verify | +| `extract_schema.py` | Read a workflow → list controllable params + model deps | +| `check_deps.py` | Check workflow against running server → list missing nodes/models | +| `auto_fix_deps.py` | Run check_deps then `comfy node install` / `comfy model download` | +| `run_workflow.py` | Inject params, submit, monitor, download outputs (HTTP or WS) | +| `run_batch.py` | Submit a workflow N times with sweeps, parallel up to your tier | +| `ws_monitor.py` | Real-time WebSocket viewer for executing jobs (live progress) | +| `health_check.py` | Verification checklist runner — comfy-cli + server + models + smoke test | +| `fetch_logs.py` | Pull traceback / status messages for a given prompt_id | + 
+**Example workflows (`workflows/`):** SD 1.5, SDXL, Flux Dev, SDXL img2img, +SDXL inpaint, ESRGAN upscale, AnimateDiff video, Wan T2V. See +`workflows/README.md`. + +## When to Use + +- User asks to generate images with Stable Diffusion, SDXL, Flux, SD3, etc. +- User wants to run a specific ComfyUI workflow file +- User wants to chain generative steps (txt2img → upscale → face restore) +- User needs ControlNet, inpainting, img2img, or other advanced pipelines +- User asks to manage ComfyUI queue, check models, or install custom nodes +- User wants video/audio/3D generation via AnimateDiff, Hunyuan, Wan, AudioCraft, etc. + +## Architecture: Two Layers + +``` +┌─────────────────────────────────────────────────────┐ +│ Layer 1: comfy-cli (official lifecycle tool) │ +│ Setup, server lifecycle, custom nodes, models │ +│ → comfy install / launch / stop / node / model │ +└─────────────────────────┬───────────────────────────┘ + │ +┌─────────────────────────▼───────────────────────────┐ +│ Layer 2: REST/WebSocket API + skill scripts │ +│ Workflow execution, param injection, monitoring │ +│ POST /api/prompt, GET /api/view, WS /ws │ +│ → run_workflow.py, run_batch.py, ws_monitor.py │ +└─────────────────────────────────────────────────────┘ +``` + +**Why two layers?** The official CLI is excellent for installation and server +management but has minimal workflow execution support. The REST/WS API fills +that gap — the scripts handle param injection, execution monitoring, and +output download that the CLI doesn't do. + +## Quick Start + +### Detect environment + +```bash +# What's available? +command -v comfy >/dev/null 2>&1 && echo "comfy-cli: installed" +curl -s http://127.0.0.1:8188/system_stats 2>/dev/null && echo "server: running" + +# Can this machine run ComfyUI locally? (GPU/VRAM/disk check) +python3 scripts/hardware_check.py +``` + +If nothing is installed, see **Setup & Onboarding** below — but always run the +hardware check first. 
+ +### One-line health check + +```bash +python3 scripts/health_check.py +# → JSON: comfy_cli on PATH? server reachable? at least one checkpoint? smoke-test passes? +``` + +## Core Workflow + +### Step 1: Get a workflow JSON in API format + +Workflows must be in API format (each node has `class_type`). They come from: + +- ComfyUI web UI → **Workflow → Export (API)** (newer UI) or + the legacy "Save (API Format)" button (older UI) +- This skill's `workflows/` directory (ready-to-run examples) +- Community downloads (civitai, Reddit, Discord) — usually editor format, + must be loaded into ComfyUI then re-exported + +Editor format (top-level `nodes` and `links` arrays) is **not directly +executable**. The scripts detect this and tell you to re-export. + +### Step 2: See what's controllable + +```bash +python3 scripts/extract_schema.py workflow_api.json --summary-only +# → {"parameter_count": 12, "has_negative_prompt": true, "has_seed": true, ...} + +python3 scripts/extract_schema.py workflow_api.json +# → full schema with parameters, model deps, embedding refs +``` + +### Step 3: Run with parameters + +```bash +# Local (defaults to http://127.0.0.1:8188) +python3 scripts/run_workflow.py \ + --workflow workflow_api.json \ + --args '{"prompt": "a beautiful sunset over mountains", "seed": -1, "steps": 30}' \ + --output-dir ./outputs + +# Cloud (export API key once; uses correct /api routing automatically) +export COMFY_CLOUD_API_KEY="comfyui-..." 
+python3 scripts/run_workflow.py \ + --workflow workflow_api.json \ + --args '{"prompt": "..."}' \ + --host https://cloud.comfy.org \ + --output-dir ./outputs + +# Real-time progress via WebSocket (requires `pip install websocket-client`) +python3 scripts/run_workflow.py \ + --workflow flux_dev.json \ + --args '{"prompt": "..."}' \ + --ws + +# img2img / inpaint: pass --input-image to upload + reference automatically +python3 scripts/run_workflow.py \ + --workflow sdxl_img2img.json \ + --input-image image=./photo.png \ + --args '{"prompt": "make it watercolor", "denoise": 0.6}' + +# Batch / sweep: 8 random seeds, parallel up to cloud tier limit +python3 scripts/run_batch.py \ + --workflow sdxl.json \ + --args '{"prompt": "abstract"}' \ + --count 8 --randomize-seed --parallel 3 \ + --output-dir ./outputs/batch +``` + +`-1` for `seed` (or omitting it with `--randomize-seed`) generates a fresh +random seed per run. + +### Step 4: Present results + +The scripts emit JSON to stdout describing every output file: + +```json +{ + "status": "success", + "prompt_id": "abc-123", + "outputs": [ + {"file": "./outputs/sdxl_00001_.png", "node_id": "9", + "type": "image", "filename": "sdxl_00001_.png"} + ] +} +``` + +## Decision Tree + +| User says | Tool | Command | +|-----------|------|---------| +| **Lifecycle (use comfy-cli)** | | | +| "install ComfyUI" | comfy-cli | `bash scripts/comfyui_setup.sh` | +| "start ComfyUI" | comfy-cli | `comfy launch --background` | +| "stop ComfyUI" | comfy-cli | `comfy stop` | +| "install X node" | comfy-cli | `comfy node install <name>` | +| "download X model" | comfy-cli | `comfy model download --url <URL> --relative-path models/checkpoints` | +| "list installed models" | comfy-cli | `comfy model list` | +| "list installed nodes" | comfy-cli | `comfy node show installed` | +| **Execution (use scripts)** | | | +| "is everything ready?" | script | `health_check.py` (optionally with `--workflow X --smoke-test`) | +| "what can I change in this workflow?" 
| script | `extract_schema.py W.json` | +| "check if W's deps are met" | script | `check_deps.py W.json` | +| "fix missing deps" | script | `auto_fix_deps.py W.json` | +| "generate an image" | script | `run_workflow.py --workflow W --args '{...}'` | +| "use this image" (img2img) | script | `run_workflow.py --input-image image=./x.png ...` | +| "8 variations with random seeds" | script | `run_batch.py --count 8 --randomize-seed ...` | +| "show me live progress" | script | `ws_monitor.py --prompt-id <prompt_id>` | +| "fetch the error from job X" | script | `fetch_logs.py <prompt_id>` | +| **Direct REST** | | | +| "what's in the queue?" | REST | `curl http://HOST:8188/queue` (local) or `fetch_logs.py --tail-queue --host https://cloud.comfy.org` (cloud) | +| "cancel that" | REST | `curl -X POST http://HOST:8188/interrupt` | +| "free GPU memory" | REST | `curl -X POST http://HOST:8188/free` | + +## Setup & Onboarding + +When a user asks to set up ComfyUI, **the FIRST thing to do is ask whether +they want Comfy Cloud (hosted, zero install, API key) or Local (install +ComfyUI on their machine)**. Don't start running install commands or hardware +checks until they've answered. + +**Official docs:** https://docs.comfy.org/installation +**CLI docs:** https://docs.comfy.org/comfy-cli/getting-started +**Cloud docs:** https://docs.comfy.org/get_started/cloud +**Cloud API:** https://docs.comfy.org/development/cloud/overview + +### Step 0: Ask Local vs Cloud (ALWAYS FIRST) + +Suggested script: + +> "Do you want to run ComfyUI locally on your machine, or use Comfy Cloud? +> +> - **Comfy Cloud** — hosted on RTX 6000 Pro GPUs, all common models pre-installed, +> zero setup. Requires an API key (paid subscription required to actually run +> workflows; free tier is read-only). Best if you don't have a capable GPU. 
+> - **Local** — free, but your machine MUST meet the hardware requirements: +> - NVIDIA GPU with **≥6 GB VRAM** (≥8 GB for SDXL, ≥12 GB for Flux/video), OR +> - AMD GPU with ROCm support (Linux), OR +> - Apple Silicon Mac (M1+) with **≥16 GB unified memory** (≥32 GB recommended). +> - Intel Macs and machines with no GPU will NOT work — use Cloud instead. +> +> Which would you like?" + +Routing: + +- **Cloud** → skip to **Path A**. +- **Local** → run hardware check first, then pick a path from Paths B–E based on the verdict. +- **Unsure** → run the hardware check and let the verdict decide. + +### Step 1: Verify Hardware (ONLY if user chose local) + +```bash +python3 scripts/hardware_check.py --json +# Optional: also probe `torch` for actual CUDA/MPS: +python3 scripts/hardware_check.py --json --check-pytorch +``` + +| Verdict | Meaning | Action | +|------------|---------------------------------------------------------------|--------| +| `ok` | ≥8 GB VRAM (discrete) OR ≥32 GB unified (Apple Silicon) | Local install — use `comfy_cli_flag` from report | +| `marginal` | SD1.5 works; SDXL tight; Flux/video unlikely | Local OK for light workflows, else **Path A (Cloud)** | +| `cloud` | No usable GPU, <6 GB VRAM, <16 GB Apple unified, Intel Mac, Rosetta Python | **Switch to Cloud** unless user explicitly forces local | + +The script also surfaces `wsl: true` (WSL2 with NVIDIA passthrough) and +`rosetta: true` (x86_64 Python on Apple Silicon — must reinstall as ARM64). + +If verdict is `cloud` but the user wants local, do not proceed silently. +Show the `notes` array verbatim and ask whether they want to (a) switch to +Cloud or (b) force a local install (will OOM or be unusably slow on modern models). + +### Choosing an Installation Path + +Use the hardware check first. 
The table below is the fallback for when the +user has already told you their hardware: + +| Situation | Recommended Path | +|-----------|------------------| +| `verdict: cloud` from hardware check | **Path A: Comfy Cloud** | +| No GPU / want to try without commitment | **Path A: Comfy Cloud** | +| Windows + NVIDIA + non-technical | **Path B: ComfyUI Desktop** | +| Windows + NVIDIA + technical | **Path C: Portable** or **Path D: comfy-cli** | +| Linux + any GPU | **Path D: comfy-cli** (easiest) | +| macOS + Apple Silicon | **Path B: Desktop** or **Path D: comfy-cli** | +| Headless / server / CI / agents | **Path D: comfy-cli** | + +For the fully automated path (hardware check → install → launch → verify): + +```bash +bash scripts/comfyui_setup.sh +# Or with overrides: +bash scripts/comfyui_setup.sh --m-series --port=8190 --workspace=/data/comfy +``` + +It runs `hardware_check.py` internally, refuses to install locally when the +verdict is `cloud` (unless `--force-cloud-override`), picks the right +`comfy-cli` flag, and prefers `pipx`/`uvx` over global `pip` to avoid polluting +system Python. + +--- + +### Path A: Comfy Cloud (No Local Install) + +For users without a capable GPU or who want zero setup. Hosted on RTX 6000 Pro. + +**Docs:** https://docs.comfy.org/get_started/cloud + +1. Sign up at https://comfy.org/cloud +2. Generate an API key at https://platform.comfy.org/login +3. Set the key: + ```bash + export COMFY_CLOUD_API_KEY="comfyui-xxxxxxxxxxxx" + ``` +4. Run workflows: + ```bash + python3 scripts/run_workflow.py \ + --workflow workflows/flux_dev_txt2img.json \ + --args '{"prompt": "..."}' \ + --host https://cloud.comfy.org \ + --output-dir ./outputs + ``` + +**Pricing:** https://www.comfy.org/cloud/pricing +**Concurrent jobs:** Free/Standard 1, Creator 3, Pro 5. Free tier +**cannot run workflows via API** — only browse models. Paid subscription +required for `/api/prompt`, `/api/upload/*`, `/api/view`, etc. 
+ +--- + +### Path B: ComfyUI Desktop (Windows / macOS) + +One-click installer for non-technical users. Currently Beta. + +**Docs:** https://docs.comfy.org/installation/desktop +- **Windows (NVIDIA):** https://download.comfy.org/windows/nsis/x64 +- **macOS (Apple Silicon):** https://comfy.org + +Linux is **not supported** for Desktop — use Path D. + +--- + +### Path C: ComfyUI Portable (Windows Only) + +**Docs:** https://docs.comfy.org/installation/comfyui_portable_windows + +Download from https://github.com/comfyanonymous/ComfyUI/releases, extract, +run `run_nvidia_gpu.bat`. Update via `update/update_comfyui_stable.bat`. + +--- + +### Path D: comfy-cli (All Platforms — Recommended for Agents) + +The official CLI is the best path for headless/automated setups. + +**Docs:** https://docs.comfy.org/comfy-cli/getting-started + +#### Install comfy-cli + +```bash +# Recommended: +pipx install comfy-cli +# Or use uvx without installing: +uvx --from comfy-cli comfy --help +# Or (if pipx/uvx unavailable): +pip install --user comfy-cli +``` + +Disable analytics non-interactively: +```bash +comfy --skip-prompt tracking disable +``` + +#### Install ComfyUI + +```bash +comfy --skip-prompt install --nvidia # NVIDIA (CUDA) +comfy --skip-prompt install --amd # AMD (ROCm, Linux) +comfy --skip-prompt install --m-series # Apple Silicon (MPS) +comfy --skip-prompt install --cpu # CPU only (slow) +comfy --skip-prompt install --nvidia --fast-deps # uv-based dep resolution +``` + +Default location: `~/comfy/ComfyUI` (Linux), `~/Documents/comfy/ComfyUI` +(macOS/Win). Override with `comfy --workspace /custom/path install`. 
+ +#### Launch / verify + +```bash +comfy launch --background # background daemon on :8188 +comfy launch -- --listen 0.0.0.0 --port 8190 # LAN-accessible custom port +curl -s http://127.0.0.1:8188/system_stats # health check +``` + +--- + +### Path E: Manual Install (Advanced / Unsupported Hardware) + +For Ascend NPU, Cambricon MLU, Intel Arc, or other unsupported hardware. + +**Docs:** https://docs.comfy.org/installation/manual_install + +```bash +git clone https://github.com/comfyanonymous/ComfyUI.git +cd ComfyUI +pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130 +pip install -r requirements.txt +python main.py +``` + +--- + +### Post-Install: Download Models + +```bash +# SDXL (general purpose, ~6.5 GB) +comfy model download \ + --url "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors" \ + --relative-path models/checkpoints + +# SD 1.5 (lighter, ~4 GB, good for 6 GB cards) +comfy model download \ + --url "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors" \ + --relative-path models/checkpoints + +# Flux Dev fp8 (smaller variant, ~12 GB) +comfy model download \ + --url "https://huggingface.co/Comfy-Org/flux1-dev/resolve/main/flux1-dev-fp8.safetensors" \ + --relative-path models/checkpoints + +# CivitAI (set token first): +comfy model download \ + --url "https://civitai.com/api/download/models/128713" \ + --relative-path models/checkpoints \ + --set-civitai-api-token "YOUR_TOKEN" +``` + +List installed: `comfy model list`. 
+ +### Post-Install: Install Custom Nodes + +```bash +comfy node install comfyui-impact-pack # popular utility pack +comfy node install comfyui-animatediff-evolved # video generation +comfy node install comfyui-controlnet-aux # ControlNet preprocessors +comfy node install comfyui-essentials # common helpers +comfy node update all +comfy node install-deps --workflow=workflow.json # install everything a workflow needs +``` + +### Post-Install: Verify + +```bash +python3 scripts/health_check.py +# → comfy_cli on PATH? server reachable? checkpoints? smoke test? + +python3 scripts/check_deps.py my_workflow.json +# → are this workflow's nodes/models/embeddings installed? + +python3 scripts/run_workflow.py \ + --workflow workflows/sd15_txt2img.json \ + --args '{"prompt": "test", "steps": 4}' \ + --output-dir ./test-outputs +``` + +## Image Upload (img2img / Inpainting) + +The simplest way is to use `--input-image` with `run_workflow.py`: + +```bash +python3 scripts/run_workflow.py \ + --workflow workflows/sdxl_img2img.json \ + --input-image image=./photo.png \ + --args '{"prompt": "make it cyberpunk", "denoise": 0.6}' +``` + +The flag uploads `photo.png`, then injects its server-side filename into +whatever schema parameter is named `image`. 
For inpainting, pass both: + +```bash +python3 scripts/run_workflow.py \ + --workflow workflows/sdxl_inpaint.json \ + --input-image image=./photo.png \ + --input-image mask_image=./mask.png \ + --args '{"prompt": "fill with flowers"}' +``` + +Manual upload via REST: +```bash +curl -X POST "http://127.0.0.1:8188/upload/image" \ + -F "image=@photo.png" -F "type=input" -F "overwrite=true" +# Returns: {"name": "photo.png", "subfolder": "", "type": "input"} + +# Cloud equivalent: +curl -X POST "https://cloud.comfy.org/api/upload/image" \ + -H "X-API-Key: $COMFY_CLOUD_API_KEY" \ + -F "image=@photo.png" -F "type=input" -F "overwrite=true" +``` + +## Cloud Specifics + +- **Base URL:** `https://cloud.comfy.org` +- **Auth:** `X-API-Key` header (or `?token=KEY` for WebSocket) +- **API key:** set `$COMFY_CLOUD_API_KEY` once and the scripts pick it up automatically +- **Output download:** `/api/view` returns a 302 to a signed URL; the scripts + follow it and strip `X-API-Key` before fetching from the storage backend + (don't leak the API key to S3/CloudFront). +- **Endpoint differences from local ComfyUI:** + - `/api/object_info`, `/api/queue`, `/api/userdata` — **403 on free tier**; + paid only. + - `/history` is renamed to `/history_v2` on cloud (the scripts route + automatically). + - `/models/<folder>` is renamed to `/experiment/models/<folder>` on cloud + (the scripts route automatically). + - `clientId` in WebSocket is currently ignored — all connections for a + user receive the same broadcast. Filter by `prompt_id` client-side. + - `subfolder` is accepted on uploads but ignored — cloud has a flat namespace. +- **Concurrent jobs:** Free/Standard: 1, Creator: 3, Pro: 5. Extras queue + automatically. Use `run_batch.py --parallel N` to saturate your tier. 
+ +## Queue & System Management + +```bash +# Local +curl -s http://127.0.0.1:8188/queue | python3 -m json.tool +curl -X POST http://127.0.0.1:8188/queue -d '{"clear": true}' # cancel pending +curl -X POST http://127.0.0.1:8188/interrupt # cancel running +curl -X POST http://127.0.0.1:8188/free \ + -H "Content-Type: application/json" \ + -d '{"unload_models": true, "free_memory": true}' + +# Cloud — same paths under /api/, plus: +python3 scripts/fetch_logs.py --tail-queue --host https://cloud.comfy.org +``` + +## Pitfalls + +1. **API format required** — every script and the `/api/prompt` endpoint expect + API-format workflow JSON. The scripts detect editor format (top-level + `nodes` and `links` arrays) and tell you to re-export via + "Workflow → Export (API)" (newer UI) or "Save (API Format)" (older UI). + +2. **Server must be running** — all execution requires a live server. + `comfy launch --background` starts one. Verify with + `curl http://127.0.0.1:8188/system_stats`. + +3. **Model names are exact** — case-sensitive, includes file extension. + `check_deps.py` does fuzzy matching (with/without extension and folder + prefix), but the workflow itself must use the canonical name. Use + `comfy model list` to discover what's installed. + +4. **Missing custom nodes** — "class_type not found" means a required node + isn't installed. `check_deps.py` reports which package to install; + `auto_fix_deps.py` runs the install for you. + +5. **Working directory** — `comfy-cli` auto-detects the ComfyUI workspace. + If commands fail with "no workspace found", use + `comfy --workspace /path/to/ComfyUI <command>` or + `comfy set-default /path/to/ComfyUI`. + +6. **Cloud free-tier API limits** — `/api/prompt`, `/api/view`, `/api/upload/*`, + `/api/object_info` all return 403 on free accounts. `health_check.py` and + `check_deps.py` handle this gracefully and surface a clear message. + +7. 
**Timeout for video/audio workflows** — auto-detected when an output node + is `VHS_VideoCombine`, `SaveVideo`, etc.; the default jumps from 300 s to + 900 s. Override explicitly with `--timeout 1800`. + +8. **Path traversal in output filenames** — server-supplied filenames are + passed through `safe_path_join` to refuse anything escaping `--output-dir`. + Keep this protection on — workflows with custom save nodes can produce + arbitrary paths. + +9. **Workflow JSON is arbitrary code** — custom nodes run Python, so + submitting an unknown workflow has the same trust profile as `eval`. + Inspect workflows from untrusted sources before running. + +10. **Auto-randomized seed** — pass `seed: -1` in `--args` (or use + `--randomize-seed` and omit the seed) to get a fresh seed per run. + The actual seed is logged to stderr. + +11. **`tracking` prompt** — first run of `comfy` may prompt for analytics. + Use `comfy --skip-prompt tracking disable` to skip non-interactively. + `comfyui_setup.sh` does this for you. + +## Verification Checklist + +Use `python3 scripts/health_check.py` to run the whole list at once. 
Manual: + +- [ ] `hardware_check.py` verdict is `ok` OR the user explicitly chose Comfy Cloud +- [ ] `comfy --version` works (or `uvx --from comfy-cli comfy --help`) +- [ ] `curl http://HOST:PORT/system_stats` returns JSON +- [ ] `comfy model list` shows at least one checkpoint (local) OR + `/api/experiment/models/checkpoints` returns models (cloud) +- [ ] Workflow JSON is in API format +- [ ] `check_deps.py` reports `is_ready: true` (or only `node_check_skipped` + on cloud free tier) +- [ ] Test run with a small workflow completes; outputs land in `--output-dir` diff --git a/skills/creative/comfyui/references/official-cli.md b/skills/creative/comfyui/references/official-cli.md new file mode 100644 index 0000000000..59a981b4a8 --- /dev/null +++ b/skills/creative/comfyui/references/official-cli.md @@ -0,0 +1,255 @@ +# comfy-cli Command Reference + +Official CLI from [Comfy-Org/comfy-cli](https://github.com/Comfy-Org/comfy-cli). +Docs: https://docs.comfy.org/comfy-cli/getting-started + +## Installation + +Order of preference: + +```bash +pipx install comfy-cli # recommended (isolated env) +uvx --from comfy-cli comfy --help # zero-install via uv +pip install --user comfy-cli # fallback +``` + +The skill's `comfyui_setup.sh` picks the best available method. + +First run may prompt for analytics. Disable non-interactively: +```bash +comfy --skip-prompt tracking disable +``` + +## Global Options + +| Option | Description | +|--------|-------------| +| `--workspace ` | Target a specific ComfyUI workspace | +| `--recent` | Use most recently used workspace | +| `--here` | Use current directory as workspace | +| `--skip-prompt` | No interactive prompts (use defaults) | +| `-v` / `--version` | Print version | + +Workspace resolution priority: +1. `--workspace` (explicit path) +2. `--recent` (from config) +3. `--here` (cwd) +4. `comfy set-default` path +5. Most recently used +6. 
`~/comfy/ComfyUI` (Linux) or `~/Documents/comfy/ComfyUI` (macOS/Win) + +## Lifecycle Commands + +### `comfy install` + +Download and install ComfyUI + ComfyUI-Manager. + +```bash +comfy install # interactive GPU selection +comfy install --nvidia +comfy install --amd # ROCm (Linux) +comfy install --m-series # Apple Silicon (MPS) +comfy install --cpu # CPU only (slow) +comfy install --fast-deps # use uv for deps +comfy install --skip-manager # skip ComfyUI-Manager +``` + +| Option | Description | +|--------|-------------| +| `--nvidia` / `--amd` / `--m-series` / `--cpu` | GPU type | +| `--cuda-version` | 11.8, 12.1, 12.4, 12.6, 12.8, 12.9, 13.0 | +| `--rocm-version` | 6.1, 6.2, 6.3, 7.0, 7.1 | +| `--fast-deps` | uv-based dependency resolution | +| `--skip-manager` | Don't install ComfyUI-Manager | +| `--skip-torch-or-directml` | Skip PyTorch install | +| `--version <version>` | `0.2.0`, `latest`, `nightly` | +| `--commit <hash>` | Install specific commit | +| `--pr "#1234"` | Install from a PR | +| `--restore` | Restore deps for existing install | + +### `comfy launch` + +```bash +comfy launch # foreground :8188 +comfy launch --background # background daemon +comfy launch -- --listen 0.0.0.0 # LAN-accessible +comfy launch -- --port 8190 # custom port +comfy launch -- --cpu # force CPU mode +comfy launch -- --lowvram # 6 GB cards +comfy launch --background -- --listen 0.0.0.0 --port 8190 +``` + +Common extra args after `--`: `--listen`, `--port`, `--cpu`, `--lowvram`, +`--novram`, `--fp16-vae`, `--force-fp32`, `--disable-cuda-malloc`. + +### `comfy stop` + +```bash +comfy stop +``` + +### `comfy run` + +Submit a raw workflow JSON to a running server. **Limited** — no parameter +injection, no structured output download. For agents, use +`scripts/run_workflow.py` instead. 
+ +```bash +comfy run --workflow workflow_api.json +comfy run --workflow workflow_api.json --host 10.0.0.5 --port 8188 +comfy run --workflow workflow_api.json --timeout 300 --wait +``` + +### `comfy which` + +```bash +comfy which # show targeted workspace +comfy --recent which +``` + +### `comfy set-default` + +```bash +comfy set-default /path/to/ComfyUI +comfy set-default /path/to/ComfyUI --launch-extras="--listen 0.0.0.0" +``` + +### `comfy update` + +```bash +comfy update # update ComfyUI core +comfy node update all # update all custom nodes +``` + +--- + +## `comfy node` — Custom Node Management + +All node operations use ComfyUI-Manager (`cm-cli`) under the hood. + +```bash +comfy node show installed # list installed +comfy node show enabled # list enabled +comfy node show all # all available in registry +comfy node simple-show installed # compact list + +comfy node install comfyui-impact-pack +comfy node install <name> --uv-compile # ComfyUI-Manager v4.1+ unified resolver +comfy node uninstall <name> +comfy node update <name> | all +comfy node enable <name> +comfy node disable <name> +comfy node fix <name> # fix broken deps + +comfy node install-deps --workflow=workflow.json +comfy node deps-in-workflow --workflow=w.json --output=deps.json + +comfy node save-snapshot +comfy node restore-snapshot <path> + +comfy node bisect start # binary-search a culprit node +comfy node bisect good +comfy node bisect bad +comfy node bisect reset +``` + +### Dependency Resolution Options + +| Flag | Description | +|------|-------------| +| `--fast-deps` | comfy-cli built-in uv resolver | +| `--uv-compile` | ComfyUI-Manager v4.1+ unified resolver (recommended) | +| `--no-deps` | Skip dep installation | + +Make `uv-compile` default: `comfy manager uv-compile-default true` + +--- + +## `comfy model` — Model Management + +```bash +comfy model list +comfy model list --relative-path models/checkpoints + +comfy model download --url <URL> +comfy model download --url <URL> --relative-path models/loras +comfy model download --url <URL> --filename 
custom_name.safetensors + +comfy model remove # interactive +comfy model remove --relative-path models/checkpoints --model-names "model.safetensors" +``` + +| Option | Description | +|--------|-------------| +| `--url` | Download URL (CivitAI, HuggingFace, direct) | +| `--relative-path` | Subdirectory under workspace (e.g. `models/checkpoints`) | +| `--filename` | Custom save filename | +| `--set-civitai-api-token` | Persist CivitAI token | +| `--set-hf-api-token` | Persist HuggingFace token | +| `--downloader` | `httpx` (default) or `aria2` | + +Standard model directories: +``` +ComfyUI/models/ +├── checkpoints/ # Full model files +├── loras/ # LoRA adapters +├── vae/ # VAE models +├── controlnet/ # ControlNet models +├── clip/ # CLIP / T5 text encoders +├── clip_vision/ # CLIP vision encoders +├── upscale_models/ # ESRGAN / SwinIR / etc. +├── embeddings/ # Textual inversion embeddings +├── unet/ # Standalone UNet weights +├── diffusion_models/ # Flux / SD3 / Wan diffusion models +├── animatediff_models/ # AnimateDiff motion modules +├── ipadapter/ # IPAdapter weights +└── style_models/ # Style adapters +``` + +--- + +## `comfy manager` — ComfyUI-Manager Settings + +```bash +comfy manager disable # disable Manager completely +comfy manager enable-gui # enable new GUI +comfy manager disable-gui # API-only +comfy manager enable-legacy-gui # legacy GUI +comfy manager uv-compile-default true # make --uv-compile the default +comfy manager clear # clear startup action +``` + +--- + +## `comfy pr-cache` — Frontend PR Cache + +```bash +comfy pr-cache list +comfy pr-cache clean +comfy pr-cache clean 456 +``` + +Cache expires after 7 days; max 10 builds. 
+ +--- + +## Configuration + +| OS | Path | +|----|------| +| Linux | `~/.config/comfy-cli/config.ini` | +| macOS | `~/Library/Application Support/comfy-cli/config.ini` | +| Windows | `~/AppData/Local/comfy-cli/config.ini` | + +Stores: default workspace, recent workspace, background server PID, API +tokens, manager GUI mode, launch extras. + +## Discovery + +Custom-node registry: +- https://registry.comfy.org/ + +Model browsers: +- https://huggingface.co/models +- https://civitai.com (NSFW; requires API token for many) +- https://comfyworkflows.com (community workflows) diff --git a/skills/creative/comfyui/references/rest-api.md b/skills/creative/comfyui/references/rest-api.md new file mode 100644 index 0000000000..64091c9d67 --- /dev/null +++ b/skills/creative/comfyui/references/rest-api.md @@ -0,0 +1,312 @@ +# ComfyUI REST + WebSocket API Reference + +ComfyUI exposes a REST + WebSocket interface for workflow execution and +management. **The same surface is used locally and on Comfy Cloud, with +auth/path differences.** + +## Connection + +| | Local ComfyUI | Comfy Cloud | +|---|---|---| +| Base URL | `http://127.0.0.1:8188` | `https://cloud.comfy.org` | +| API path prefix | none (`/prompt`, `/view`, …) | `/api/...` (`/api/prompt`, `/api/view`, …) | +| Auth | none (or bearer token if configured) | `X-API-Key` header | +| WebSocket | `ws://host:port/ws?clientId={uuid}` | `wss://cloud.comfy.org/ws?clientId={uuid}&token={API_KEY}` | +| `/api/view` response | direct bytes | 302 redirect → signed URL (use `curl -L`) | + +The skill scripts route URLs automatically via `_common.resolve_url()`. + +## Endpoint differences on Comfy Cloud + +The cloud surface diverges from local ComfyUI in several ways. The skill +scripts handle these transparently; document them here so anyone calling +`curl` directly knows. 
+ +| Local path | Cloud path | Notes | +|------------|-----------|-------| +| `/system_stats` | `/api/system_stats` | Cloud version is **public** (no auth required) | +| `/object_info` | `/api/object_info` | **Paid tier only** — free returns 403 | +| `/queue` | `/api/queue` | Paid tier only | +| `/userdata` | `/api/userdata` | Paid tier only | +| `/prompt` (POST) | `/api/prompt` (POST) | Paid tier only | +| `/upload/image` | `/api/upload/image` | Paid tier only; `subfolder` accepted but ignored | +| `/upload/mask` | `/api/upload/mask` | Same as above | +| `/view` | `/api/view` | Paid tier only; **returns 302** to signed URL | +| `/history` | `/api/history_v2` | **Renamed**; old path returns 404 | +| `/history/{id}` | `/api/history_v2/{id}` or `/api/jobs/{id}` | Both work; `/jobs` returns full job | +| `/models` | `/api/experiment/models` | **Renamed** | +| `/models/{folder}` | `/api/experiment/models/{folder}` | **Renamed**; response shape differs (see below) | + +### Cloud model-list response shape + +- **Local:** `["a.safetensors", "b.safetensors", …]` — flat list of strings. +- **Cloud:** `[{"name": "a.safetensors", "pathIndex": 0}, …]` — list of objects. +- **Cloud 404 with `code: "folder_not_found"`** — folder is empty or unknown, + not an "endpoint missing" error. Distinguish by reading the body. + +The skill helper `_common.parse_model_list()` normalizes both. 
+ +## Workflow Execution + +### Submit Workflow + +```bash +# Local +curl -X POST "http://127.0.0.1:8188/prompt" \ + -H "Content-Type: application/json" \ + -d '{"prompt": '"$(cat workflow_api.json)"', "client_id": "'"$(uuidgen)"'"}' + +# Cloud +curl -X POST "https://cloud.comfy.org/api/prompt" \ + -H "X-API-Key: $COMFY_CLOUD_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"prompt": '"$(cat workflow_api.json)"'}' +``` + +**Response:** +```json +{"prompt_id": "abc-123-def", "number": 1, "node_errors": {}} +``` + +If `node_errors` is non-empty, the workflow has validation errors (missing +nodes, bad inputs). + +### Check Job Status (Cloud) + +```bash +curl -X GET "https://cloud.comfy.org/api/job/{prompt_id}/status" \ + -H "X-API-Key: $COMFY_CLOUD_API_KEY" +``` + +| Status | Description | +| ------------- | ---------------------------------- | +| `pending` | Job is queued and waiting to start | +| `in_progress` | Job is currently executing | +| `completed` | Job finished successfully | +| `failed` | Job encountered an error | +| `cancelled` | Job was cancelled by user | + +### Job detail with outputs (Cloud) + +```bash +curl -X GET "https://cloud.comfy.org/api/jobs/{prompt_id}" \ + -H "X-API-Key: $COMFY_CLOUD_API_KEY" +``` + +Response includes `outputs` keyed by node ID. Cloud uses `video` (singular) +in the output structure; local uses `videos` (plural). The skill scripts +accept both. + +### Get History (Local) + +```bash +curl -s "http://127.0.0.1:8188/history" # all +curl -s "http://127.0.0.1:8188/history/{id}" # one prompt_id +``` + +Local entry shape: +```json +{ + "<prompt_id>": { + "prompt": [...], + "outputs": {"<node_id>": {"images": [...]}}, + "status": { + "status_str": "success" | "error", + "completed": true | false, + "messages": [["execution_start", {...}], ["execution_error", {...}], …] + } + } +} +``` + +**Important:** when reading status, check `status_str == "error"` BEFORE +checking `completed`, because both can be true for failed runs. 
+ +### Download Output + +```bash +# Local (direct bytes) +curl -s "http://127.0.0.1:8188/view?filename=ComfyUI_00001_.png&subfolder=&type=output" \ + -o output.png + +# Cloud (302 → signed URL; -L follows; STRIP X-API-Key for the second hop) +curl -L "https://cloud.comfy.org/api/view?filename=...&type=output" \ + -H "X-API-Key: $COMFY_CLOUD_API_KEY" \ + -o output.png +``` + +The skill's `run_workflow.py` strips `X-API-Key` automatically on the +cross-host redirect, so the signed URL never sees your auth. + +## WebSocket Monitoring + +Connect for real-time execution events. + +```bash +# Local +wscat -c "ws://127.0.0.1:8188/ws?clientId=MY-UUID" + +# Cloud +wscat -c "wss://cloud.comfy.org/ws?clientId=MY-UUID&token=$COMFY_CLOUD_API_KEY" +``` + +**Note:** on Cloud the `clientId` is currently ignored — all messages for a +user are broadcast to every connection. Filter messages client-side by +`data.prompt_id`. + +### JSON Message Types + +| Type | When | Key Fields | +|------|------|------------| +| `status` | Queue change | `status.exec_info.queue_remaining` | +| `notification` | User-friendly status string | `value` | +| `execution_start` | Workflow begins | `prompt_id` | +| `executing` | Node running (or end-of-run if `node` is null on local) | `node`, `prompt_id` | +| `progress` | Sampling steps | `node`, `value`, `max` | +| `progress_state` | Extended progress with per-node metadata | `nodes` (dict) | +| `executed` | Node output ready | `node`, `output` (with `images`/`video`/etc.) 
| +| `execution_cached` | Nodes skipped because of cache | `nodes` (list of IDs) | +| `execution_success` | All done | `prompt_id` | +| `execution_error` | Failure | `exception_type`, `exception_message`, `traceback`, `node_id` | +| `execution_interrupted` | Cancelled | `prompt_id` | + +### Binary Frames (Preview Images) + +| Type code | Meaning | +|-----------|---------| +| `0x00000001` | `PREVIEW_IMAGE` — `[type:4][image_type:4][data]` (image_type 1=JPEG, 2=PNG) | +| `0x00000003` | `TEXT` — `[type:4][nid_len:4][nid][text]` (UTF-8) | +| `0x00000004` | `PREVIEW_IMAGE_WITH_METADATA` — `[type:4][meta_len:4][json][image_data]` | + +`scripts/ws_monitor.py --previews
<dir>
` saves preview frames to disk. + +## File Upload + +```bash +# Image +curl -X POST "http://127.0.0.1:8188/upload/image" \ + -F "image=@photo.png" -F "type=input" -F "overwrite=true" +# Returns: {"name": "photo.png", "subfolder": "", "type": "input"} + +# Mask (linked to a previously uploaded image) +curl -X POST "http://127.0.0.1:8188/upload/mask" \ + -F "image=@mask.png" -F "type=input" \ + -F 'original_ref={"filename":"photo.png","subfolder":"","type":"input"}' +``` + +Cloud equivalent: prepend `https://cloud.comfy.org/api` and add `-H "X-API-Key: $COMFY_CLOUD_API_KEY"`. + +## Node & Model Discovery + +```bash +# All node types and their input specs +curl -s "http://127.0.0.1:8188/object_info" | python3 -m json.tool + +# Specific node +curl -s "http://127.0.0.1:8188/object_info/KSampler" + +# Models per folder (local) +curl -s "http://127.0.0.1:8188/models/checkpoints" +curl -s "http://127.0.0.1:8188/models/loras" + +# Models per folder (cloud — note the experimental prefix) +curl -s "https://cloud.comfy.org/api/experiment/models/checkpoints" \ + -H "X-API-Key: $COMFY_CLOUD_API_KEY" +``` + +## Queue Management + +```bash +# View queue +curl -s "http://127.0.0.1:8188/queue" + +# Clear all pending +curl -X POST "http://127.0.0.1:8188/queue" \ + -H "Content-Type: application/json" \ + -d '{"clear": true}' + +# Delete specific items +curl -X POST "http://127.0.0.1:8188/queue" \ + -H "Content-Type: application/json" \ + -d '{"delete": ["prompt_id_1", "prompt_id_2"]}' + +# Cancel currently-running job +curl -X POST "http://127.0.0.1:8188/interrupt" +``` + +## System Management + +```bash +# Stats (VRAM, RAM, GPU, ComfyUI version) +curl -s "http://127.0.0.1:8188/system_stats" + +# Free GPU memory +curl -X POST "http://127.0.0.1:8188/free" \ + -H "Content-Type: application/json" \ + -d '{"unload_models": true, "free_memory": true}' +``` + +## ComfyUI-Manager Endpoints (Optional) + +These require ComfyUI-Manager installed. 
Useful for installing nodes/models +via the API instead of `comfy-cli`. + +```bash +# Install a custom node from a git URL +curl -X POST "http://127.0.0.1:8188/manager/queue/install" \ + -H "Content-Type: application/json" \ + -d '{"git_url": "https://github.com/user/comfyui-node.git"}' + +# Check install queue status +curl -s "http://127.0.0.1:8188/manager/queue/status" + +# Install model +curl -X POST "http://127.0.0.1:8188/manager/queue/install_model" \ + -H "Content-Type: application/json" \ + -d '{"url": "https://...", "path": "models/checkpoints", "filename": "model.safetensors"}' +``` + +## POST /prompt Payload Format + +```json +{ + "prompt": { + "3": { + "class_type": "KSampler", + "inputs": { + "seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0] + } + } + }, + "client_id": "unique-uuid-for-ws-filtering", + "extra_data": { + "api_key_comfy_org": "optional-PARTNER-NODE-key (NOT the cloud auth key)" + } +} +``` + +- `prompt`: workflow graph in API format +- `client_id`: UUID — local server uses it to filter WebSocket events; cloud + ignores it. +- `extra_data.api_key_comfy_org`: ONLY required when the workflow uses + partner nodes (Flux Pro, Ideogram, etc.). Don't conflate with `X-API-Key`. 
+ +## Error Categories (cloud `execution_error` `exception_type`) + +| Type | Meaning | +|------|---------| +| `ValidationError` | Bad workflow / inputs (often nicer to surface from `node_errors`) | +| `ModelDownloadError` | Required model not available | +| `ImageDownloadError` | Failed to fetch input image from URL | +| `OOMError` | Out of GPU memory | +| `InsufficientFundsError` | Account balance too low (partner nodes) | +| `InactiveSubscriptionError` | Subscription not active | diff --git a/skills/creative/comfyui/references/workflow-format.md b/skills/creative/comfyui/references/workflow-format.md new file mode 100644 index 0000000000..e8343de73c --- /dev/null +++ b/skills/creative/comfyui/references/workflow-format.md @@ -0,0 +1,226 @@ +# ComfyUI Workflow JSON Format + +## Two Formats — Only API Format Is Executable + +**API format** is required for `/api/prompt` and every script in this skill. +The web UI also produces an "editor format" used for visual editing, which +**cannot** be submitted directly. + +### API Format + +Top-level keys are string node IDs. Each node has `class_type` and `inputs`: + +```json +{ + "3": { + "class_type": "KSampler", + "inputs": { + "seed": 156680208700286, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0] + }, + "_meta": {"title": "KSampler"} + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"} + } +} +``` + +**Detection:** every top-level value has `class_type`. The skill's +`_common.is_api_format()` does this check. + +### Editor Format (not directly executable) + +Has `nodes[]` and `links[]` arrays — the visual graph. To convert: open in +ComfyUI's web UI and use **Workflow → Export (API)** (newer UI) or the +"Save (API Format)" button (older UI). + +**Detection:** top-level has `"nodes"` and `"links"` keys. 
+ +## Inputs: Literals vs Links + +```json +"inputs": { + "text": "a cat", // literal — modifiable + "seed": 42, // literal — modifiable + "clip": ["4", 1] // link — wiring; do NOT overwrite +} +``` + +Links are length-2 arrays of `[upstream_node_id, output_slot]`. The skill's +parameter injector refuses to overwrite a link with a literal (logs a +warning and skips). + +## Common Node Types and Their Controllable Parameters + +The full catalog lives in `scripts/_common.py` (`PARAM_PATTERNS` and +`MODEL_LOADERS`). Highlights: + +### Text Prompts + +| Node Class | Key Fields | +|------------|------------| +| `CLIPTextEncode` | `text` | +| `CLIPTextEncodeSDXL` | `text_g`, `text_l`, `width`, `height` | +| `CLIPTextEncodeFlux` | `clip_l`, `t5xxl`, `guidance` | + +To distinguish positive from negative the skill traces `KSampler.negative` +back through Reroute / Primitive nodes to the source CLIPTextEncode. Falls +back to `_meta.title` heuristics ("negative", "neg", "anti"). + +### Sampling + +| Node Class | Key Fields | +|------------|------------| +| `KSampler` | `seed`, `steps`, `cfg`, `sampler_name`, `scheduler`, `denoise` | +| `KSamplerAdvanced` | `noise_seed`, `steps`, `cfg`, `start_at_step`, `end_at_step` | +| `SamplerCustom` | `noise_seed`, `cfg`, `sampler`, `sigmas` | +| `SamplerCustomAdvanced` | `noise_seed` (via RandomNoise input) | +| `RandomNoise` | `noise_seed` | +| `BasicScheduler` | `steps`, `scheduler`, `denoise` | +| `KSamplerSelect` | `sampler_name` | +| `CFGGuider` | `cfg` (`BasicGuider` has no `cfg` input) | +| `ModelSamplingFlux` | `max_shift`, `base_shift`, `width`, `height` | +| `SDTurboScheduler` | `steps`, `denoise` | + +### Latent / Dimensions + +| Node Class | Key Fields | +|------------|------------| +| `EmptyLatentImage` | `width`, `height`, `batch_size` | +| `EmptySD3LatentImage` | `width`, `height`, `batch_size` | +| `EmptyHunyuanLatentVideo` | `width`, `height`, `length`, `batch_size` | +| `EmptyMochiLatentVideo` | `width`, `height`, `length`, `batch_size` | 
+| `EmptyLTXVLatentVideo` | `width`, `height`, `length`, `batch_size` | + +### Model Loading + +| Node Class | Key Fields | Folder | +|------------|------------|--------| +| `CheckpointLoaderSimple` | `ckpt_name` | `checkpoints` | +| `LoraLoader` | `lora_name`, `strength_model`, `strength_clip` | `loras` | +| `LoraLoaderModelOnly` | `lora_name`, `strength_model` | `loras` | +| `VAELoader` | `vae_name` | `vae` | +| `ControlNetLoader` | `control_net_name` | `controlnet` | +| `CLIPLoader` | `clip_name` | `clip` | +| `DualCLIPLoader` | `clip_name1`, `clip_name2` | `clip` | +| `TripleCLIPLoader` | `clip_name1/2/3` | `clip` | +| `UNETLoader` | `unet_name` | `unet` | +| `DiffusionModelLoader` | `model_name` | `diffusion_models` | +| `UpscaleModelLoader` | `model_name` | `upscale_models` | +| `IPAdapterModelLoader` | `ipadapter_file` | `ipadapter` | +| `ADE_AnimateDiffLoaderWithContext` | `model_name`, `motion_scale` | `animatediff_models` | + +### Image Input/Output + +| Node Class | Key Fields | +|------------|------------| +| `LoadImage` | `image` (server-side filename, after upload) | +| `LoadImageMask` | `image`, `channel` (`red` / `green` / `blue` / `alpha`) | +| `VAEEncode` / `VAEDecode` | (no controllable fields) | +| `VAEEncodeForInpaint` | `grow_mask_by` | +| `SaveImage` | `filename_prefix` | +| `VHS_VideoCombine` | `frame_rate`, `format`, `filename_prefix`, `loop_count`, `pingpong` | + +### ControlNet + +| Node Class | Key Fields | +|------------|------------| +| `ControlNetApply` | `strength` | +| `ControlNetApplyAdvanced` | `strength`, `start_percent`, `end_percent` | + +### IPAdapter (community pack `comfyui_ipadapter_plus`) + +| Node Class | Key Fields | +|------------|------------| +| `IPAdapterAdvanced` | `weight`, `start_at`, `end_at` | +| `IPAdapter` | `weight` | + +### Embeddings (referenced inside prompt strings) + +ComfyUI scans prompt text for `embedding:NAME` syntax. The skill's +`_common.iter_embedding_refs()` extracts these as model dependencies. 
+ +```text +"a beautiful cat, embedding:goodvibes:1.2, embedding:art-style" +``` + +`extract_schema.py` and `check_deps.py` surface these in +`embedding_dependencies` / `missing_embeddings`. + +## Parameter Injection Pattern + +```python +import json, copy + +with open("workflow_api.json") as f: + workflow = json.load(f) + +wf = copy.deepcopy(workflow) +wf["6"]["inputs"]["text"] = "a beautiful sunset" +wf["7"]["inputs"]["text"] = "ugly, blurry" +wf["3"]["inputs"]["seed"] = 42 +wf["3"]["inputs"]["steps"] = 30 +wf["5"]["inputs"]["width"] = 1024 +wf["5"]["inputs"]["height"] = 1024 +``` + +`scripts/extract_schema.py` automates discovering which node IDs/fields +correspond to which user-facing parameters. It returns a `parameters` dict +that `run_workflow.py` reads to inject values from `--args`. + +## Identifying Controllable Parameters (Heuristics) + +For unknown workflows: + +1. **Prompt text** — any `CLIPTextEncode.text`. Use connection tracing back + from `KSampler.positive` / `.negative` to disambiguate (don't trust + meta-title alone). +2. **Seed** — `KSampler.seed` / `KSamplerAdvanced.noise_seed` / `RandomNoise.noise_seed`. +3. **Dimensions** — `Empty*LatentImage.width/height` (must be multiples of 8). +4. **Steps / CFG** — `KSampler.steps`, `KSampler.cfg`. Steps 20–50 typical. + CFG 5–15 typical (Flux uses guidance, not CFG). +5. **Model / checkpoint** — `CheckpointLoaderSimple.ckpt_name`. Filename must + match an installed file *exactly*. +6. **LoRA** — `LoraLoader.lora_name`, `.strength_model`. +7. **Images for img2img / inpaint** — `LoadImage.image`. Server-side filename + after upload. +8. **Denoise** — `KSampler.denoise`. 0.0–1.0; 1.0 = ignore input image, + 0.0 = pass through. Sweet spot for img2img: 0.4–0.7. + +## Output Nodes + +Output is produced by these node types. The skill's `OUTPUT_NODES` set +extends to common community packs. 
+ +| Node | Output Key | Content | +|------|-----------|---------| +| `SaveImage` | `images` | List of `{filename, subfolder, type}` | +| `PreviewImage` | `images` | Temporary preview (not saved) | +| `VHS_VideoCombine` | `gifs` (older) or `videos`/`video` (newer cloud) | Video file refs | +| `SaveAudio` | `audio` | Audio file refs | +| `SaveAnimatedWEBP` / `SaveAnimatedPNG` | `images` | Animated images | +| `Save3D` | `3d` | 3D asset refs | + +After execution, fetch outputs from `/history/{prompt_id}` (local) or +`/api/jobs/{prompt_id}` (cloud) → `outputs` → `{node_id}` → `{key}`. + +## Wrapper Variants + +Some saved JSON files wrap the workflow under a `"prompt"` key (matching +the `/api/prompt` payload shape). The skill's `_common.unwrap_workflow()` +handles this — pass any of: + +- raw API format: `{"3": {...}, "4": {...}}` +- wrapped: `{"prompt": {"3": {...}}, "client_id": "..."}` + +It rejects editor format with a clear error and a re-export instruction. diff --git a/skills/creative/comfyui/scripts/_common.py b/skills/creative/comfyui/scripts/_common.py new file mode 100644 index 0000000000..ef742733eb --- /dev/null +++ b/skills/creative/comfyui/scripts/_common.py @@ -0,0 +1,835 @@ +""" +_common.py — Shared logic for ComfyUI skill scripts. + +Single source of truth for: +- HTTP transport (with retry/backoff, streaming, timeout handling) +- Cloud detection and endpoint mapping (local ComfyUI vs Comfy Cloud) +- Workflow node-type catalogs (param patterns, model loaders, output nodes) +- API-format validation +- Path-traversal-safe file writes +- API-key loading from env / CLI + +Stdlib-only by design (with optional `requests` upgrade if installed). Python 3.10+. 
+""" + +from __future__ import annotations + +import json +import os +import random +import re +import sys +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterator +from urllib.parse import urlparse + +# Optional: prefer `requests` if installed (better redirects, streaming, header handling) +try: + import requests # type: ignore[import-not-found] + HAS_REQUESTS = True +except ImportError: # pragma: no cover - exercised via stdlib fallback + HAS_REQUESTS = False + import urllib.error + import urllib.request + + +# ============================================================================= +# Constants & catalogs +# ============================================================================= + +DEFAULT_LOCAL_HOST = "http://127.0.0.1:8188" +DEFAULT_CLOUD_HOST = "https://cloud.comfy.org" +ENV_API_KEY = "COMFY_CLOUD_API_KEY" + +# Connection / retry defaults +DEFAULT_HTTP_TIMEOUT = 60 # seconds — single-attempt request timeout +DEFAULT_RETRIES = 3 # total attempts including the first +RETRY_BASE_DELAY = 1.0 # seconds — exponential backoff base +RETRY_MAX_DELAY = 30.0 # seconds — cap on backoff +RETRY_STATUS_CODES = {408, 429, 500, 502, 503, 504, 522, 524} + +# Streaming download chunk size (bytes) +DOWNLOAD_CHUNK_SIZE = 1 << 16 # 64 KiB + +# Heuristic: workflows with these node types tend to be slow → larger default timeout +SLOW_OUTPUT_NODES = { + "VHS_VideoCombine", "SaveAnimatedWEBP", "SaveAnimatedPNG", + "SaveVideo", "SaveAudio", "SaveAnimateDiffVideo", + "SVD_img2vid_Conditioning", + "WanVideoSampler", "HunyuanVideoSampler", + "CogVideoSampler", "LTXVideoSampler", +} + +# --------------------------------------------------------------------------- +# Output node catalog (extensible — community packs add their own) +# --------------------------------------------------------------------------- +OUTPUT_NODES: set[str] = { + # Built-in + "SaveImage", "PreviewImage", + "SaveAudio", "SaveVideo", 
"PreviewAudio", "PreviewVideo", + "SaveAnimatedWEBP", "SaveAnimatedPNG", + # Common community packs + "VHS_VideoCombine", # Video Helper Suite + "ImageSave", # Was Node Suite + "Image Save", # Was Node Suite (alt name) + "easy imageSave", # easy-use + "Image Save With Metadata", + "PreviewImage|pysssss", # pysssss preview + "ShowText|pysssss", + "SaveLatent", + "SaveGLB", # 3D + "Save3D", +} + +# --------------------------------------------------------------------------- +# Folder aliases — handle ComfyUI's gradual folder renames +# --------------------------------------------------------------------------- +# When `check_deps.py` queries `/models/` and gets 404 / empty, +# it tries each alias in turn. Critical for Comfy Cloud which has fully +# migrated to the new naming (unet → diffusion_models, clip → text_encoders). +FOLDER_ALIASES: dict[str, list[str]] = { + "unet": ["unet", "diffusion_models"], + "diffusion_models": ["diffusion_models", "unet"], + "clip": ["clip", "text_encoders"], + "text_encoders": ["text_encoders", "clip"], + "controlnet": ["controlnet", "control_net"], +} + + +def folder_aliases_for(folder: str) -> list[str]: + """Return the search order of folder names (primary first).""" + return FOLDER_ALIASES.get(folder, [folder]) + + +# --------------------------------------------------------------------------- +# Model-loader catalog: class_type -> (input field, model folder) +# --------------------------------------------------------------------------- +# A loader can have multiple fields (e.g., DualCLIPLoader has clip_name1 and +# clip_name2). We list them with explicit entries. The folder name is the +# *canonical* one; FOLDER_ALIASES is consulted when querying. 
+MODEL_LOADERS: dict[str, list[tuple[str, str]]] = { + # Checkpoints + "CheckpointLoaderSimple": [("ckpt_name", "checkpoints")], + "CheckpointLoader": [("ckpt_name", "checkpoints")], + "CheckpointLoader (Simple)": [("ckpt_name", "checkpoints")], + "ImageOnlyCheckpointLoader": [("ckpt_name", "checkpoints")], + "unCLIPCheckpointLoader": [("ckpt_name", "checkpoints")], + # LoRA + "LoraLoader": [("lora_name", "loras")], + "LoraLoaderModelOnly": [("lora_name", "loras")], + "LoraLoaderTagsQuery": [("lora_name", "loras")], + # VAE + "VAELoader": [("vae_name", "vae")], + # ControlNet + "ControlNetLoader": [("control_net_name", "controlnet")], + "DiffControlNetLoader": [("control_net_name", "controlnet")], + "ControlNetLoaderAdvanced": [("control_net_name", "controlnet")], + # CLIP / text encoders (primary "clip" folder; check_deps tries text_encoders too) + "CLIPLoader": [("clip_name", "clip")], + "DualCLIPLoader": [("clip_name1", "clip"), ("clip_name2", "clip")], + "TripleCLIPLoader": [("clip_name1", "clip"), ("clip_name2", "clip"), ("clip_name3", "clip")], + "CLIPVisionLoader": [("clip_name", "clip_vision")], + # UNET / Diffusion model (primary "unet"; check_deps tries diffusion_models too) + "UNETLoader": [("unet_name", "unet")], + "DiffusionModelLoader": [("model_name", "diffusion_models")], + "UNETLoaderGGUF": [("unet_name", "unet")], + # Upscaler + "UpscaleModelLoader": [("model_name", "upscale_models")], + # Style / GLIGEN / Hypernetwork + "StyleModelLoader": [("style_model_name", "style_models")], + "GLIGENLoader": [("gligen_name", "gligen")], + "HypernetworkLoader": [("hypernetwork_name", "hypernetworks")], + # IPAdapter family (community). + # Note: IPAdapterUnifiedLoader's `preset` and IPAdapterInsightFaceLoader's + # `provider` are enums (not file paths), so they're intentionally omitted — + # check_deps would otherwise treat enum values as missing model files. 
+ "IPAdapterModelLoader": [("ipadapter_file", "ipadapter")], + "InstantIDModelLoader": [("instantid_file", "instantid")], + # AnimateDiff / video + "ADE_LoadAnimateDiffModel": [("model_name", "animatediff_models")], + "ADE_AnimateDiffLoaderWithContext": [("model_name", "animatediff_models")], + "ADE_AnimateDiffLoaderGen1": [("model_name", "animatediff_models")], + # Photomaker + "PhotoMakerLoader": [("photomaker_model_name", "photomaker")], + # Sampler / scheduler models + "ModelSamplingFlux": [], # parametric only +} + +# --------------------------------------------------------------------------- +# Param patterns: (class_type, field_name) -> friendly_name +# Order matters — first match wins for naming. Use _meta.title for disambiguation. +# --------------------------------------------------------------------------- +PARAM_PATTERNS: list[tuple[str, str, str]] = [ + # ---- Prompts ---- + ("CLIPTextEncode", "text", "prompt"), + ("CLIPTextEncodeSDXL", "text_g", "prompt"), + ("CLIPTextEncodeSDXL", "text_l", "prompt_l"), + ("CLIPTextEncodeSDXLRefiner", "text", "refiner_prompt"), + ("CLIPTextEncodeFlux", "clip_l", "prompt_l"), + ("CLIPTextEncodeFlux", "t5xxl", "prompt"), + ("CLIPTextEncodeFlux", "guidance", "guidance"), + ("smZ CLIPTextEncode", "text", "prompt"), + ("BNK_CLIPTextEncodeAdvanced", "text", "prompt"), + + # ---- Standard sampling ---- + ("KSampler", "seed", "seed"), + ("KSampler", "steps", "steps"), + ("KSampler", "cfg", "cfg"), + ("KSampler", "sampler_name", "sampler_name"), + ("KSampler", "scheduler", "scheduler"), + ("KSampler", "denoise", "denoise"), + ("KSamplerAdvanced", "noise_seed", "seed"), + ("KSamplerAdvanced", "steps", "steps"), + ("KSamplerAdvanced", "cfg", "cfg"), + ("KSamplerAdvanced", "sampler_name", "sampler_name"), + ("KSamplerAdvanced", "scheduler", "scheduler"), + ("KSamplerAdvanced", "start_at_step", "start_at_step"), + ("KSamplerAdvanced", "end_at_step", "end_at_step"), + + # ---- Modern sampler chain (Flux / SD3 / SDXL refiner via 
SamplerCustom) ---- + ("RandomNoise", "noise_seed", "seed"), + ("BasicScheduler", "steps", "steps"), + ("BasicScheduler", "scheduler", "scheduler"), + ("BasicScheduler", "denoise", "denoise"), + ("KSamplerSelect", "sampler_name", "sampler_name"), + # NB: BasicGuider has no cfg input (it just bundles model+conditioning). + ("CFGGuider", "cfg", "cfg"), + ("DualCFGGuider", "cfg_conds", "cfg"), + ("DualCFGGuider", "cfg_cond2_negative", "cfg_negative"), + ("ModelSamplingFlux", "max_shift", "max_shift"), + ("ModelSamplingFlux", "base_shift", "base_shift"), + ("ModelSamplingFlux", "width", "model_width"), + ("ModelSamplingFlux", "height", "model_height"), + ("ModelSamplingSD3", "shift", "shift"), + ("ModelSamplingDiscrete", "sampling", "sampling"), + ("SDTurboScheduler", "steps", "steps"), + ("SDTurboScheduler", "denoise", "denoise"), + ("SamplerCustom", "noise_seed", "seed"), + ("SamplerCustom", "cfg", "cfg"), + # NB: SamplerCustomAdvanced takes a NOISE input (from RandomNoise) — no seed field directly. 
+ + # ---- Dimensions / latent ---- + ("EmptyLatentImage", "width", "width"), + ("EmptyLatentImage", "height", "height"), + ("EmptyLatentImage", "batch_size", "batch_size"), + ("EmptySD3LatentImage", "width", "width"), + ("EmptySD3LatentImage", "height", "height"), + ("EmptySD3LatentImage", "batch_size", "batch_size"), + ("EmptyHunyuanLatentVideo", "width", "width"), + ("EmptyHunyuanLatentVideo", "height", "height"), + ("EmptyHunyuanLatentVideo", "length", "length"), + ("EmptyHunyuanLatentVideo", "batch_size", "batch_size"), + ("EmptyMochiLatentVideo", "width", "width"), + ("EmptyMochiLatentVideo", "height", "height"), + ("EmptyMochiLatentVideo", "length", "length"), + ("EmptyLTXVLatentVideo", "width", "width"), + ("EmptyLTXVLatentVideo", "height", "height"), + ("EmptyLTXVLatentVideo", "length", "length"), + ("LatentUpscale", "width", "upscale_width"), + ("LatentUpscale", "height", "upscale_height"), + ("LatentUpscaleBy", "scale_by", "scale_by"), + ("ImageScale", "width", "width"), + ("ImageScale", "height", "height"), + + # ---- Image input ---- + ("LoadImage", "image", "image"), + ("LoadImageMask", "image", "mask_image"), + ("LoadImageOutput", "image", "image"), + ("VHS_LoadVideo", "video", "video"), + ("VHS_LoadAudio", "audio", "audio"), + + # ---- Model selection (sometimes useful to swap per run) ---- + ("CheckpointLoaderSimple", "ckpt_name", "ckpt_name"), + ("CheckpointLoader", "ckpt_name", "ckpt_name"), + ("ImageOnlyCheckpointLoader", "ckpt_name", "ckpt_name"), + ("VAELoader", "vae_name", "vae_name"), + ("UNETLoader", "unet_name", "unet_name"), + ("DiffusionModelLoader", "model_name", "diffusion_model_name"), + ("UpscaleModelLoader", "model_name", "upscale_model_name"), + ("CLIPLoader", "clip_name", "clip_name"), + ("DualCLIPLoader", "clip_name1", "clip_name1"), + ("DualCLIPLoader", "clip_name2", "clip_name2"), + ("ControlNetLoader", "control_net_name", "controlnet_name"), + + # ---- LoRA ---- + ("LoraLoader", "lora_name", "lora_name"), + ("LoraLoader", 
"strength_model", "lora_strength"), + ("LoraLoader", "strength_clip", "lora_strength_clip"), + ("LoraLoaderModelOnly", "lora_name", "lora_name"), + ("LoraLoaderModelOnly", "strength_model", "lora_strength"), + + # ---- ControlNet ---- + ("ControlNetApply", "strength", "controlnet_strength"), + ("ControlNetApplyAdvanced", "strength", "controlnet_strength"), + ("ControlNetApplyAdvanced", "start_percent", "controlnet_start"), + ("ControlNetApplyAdvanced", "end_percent", "controlnet_end"), + + # ---- IPAdapter ---- + ("IPAdapterAdvanced", "weight", "ipadapter_weight"), + ("IPAdapterAdvanced", "start_at", "ipadapter_start"), + ("IPAdapterAdvanced", "end_at", "ipadapter_end"), + ("IPAdapter", "weight", "ipadapter_weight"), + + # ---- Upscale ---- + ("ImageUpscaleWithModel", "upscale_method", "upscale_method"), + + # ---- AnimateDiff ---- + ("ADE_AnimateDiffLoaderWithContext", "motion_scale", "motion_scale"), + ("ADE_AnimateDiffLoaderGen1", "motion_scale", "motion_scale"), + + # ---- Video / Save ---- + ("VHS_VideoCombine", "frame_rate", "frame_rate"), + ("VHS_VideoCombine", "format", "video_format"), + ("VHS_VideoCombine", "filename_prefix", "filename_prefix"), + ("SaveImage", "filename_prefix", "filename_prefix"), + + # ---- Hunyuan / Wan / LTX video ---- + ("HunyuanVideoSampler", "seed", "seed"), + ("HunyuanVideoSampler", "steps", "steps"), + ("HunyuanVideoSampler", "cfg", "cfg"), + ("WanVideoSampler", "seed", "seed"), + ("WanVideoSampler", "steps", "steps"), + ("WanVideoSampler", "cfg", "cfg"), + ("LTXVScheduler", "max_shift", "max_shift"), + ("LTXVScheduler", "base_shift", "base_shift"), + + # ---- rgthree primitives (often used as user-facing inputs) ---- + ("Seed (rgthree)", "seed", "seed"), + ("Image Comparer (rgthree)", "image_a", "image"), + ("Power Lora Loader (rgthree)", "PowerLoraLoaderHeaderWidget", "_lora_header"), + + # ---- Easy-use / utility primitives ---- + ("PrimitiveNode", "value", "primitive_value"), + ("easy seed", "seed", "seed"), + ("easy 
positive", "positive", "prompt"), + ("easy negative", "negative", "negative_prompt"), + ("easy fullLoader", "ckpt_name", "ckpt_name"), + ("easy fullLoader", "vae_name", "vae_name"), + ("easy fullLoader", "lora_name", "lora_name"), + ("easy fullLoader", "positive", "prompt"), + ("easy fullLoader", "negative", "negative_prompt"), +] + +# Prompt-like fields whose value should be scanned for embedding references +PROMPT_FIELDS = {"text", "text_g", "text_l", "t5xxl", "clip_l", "positive", "negative"} + +# Pattern matches: embedding:name, embedding:name.pt, embedding:name:1.2, (embedding:name:1.2) +# Word-boundary at start avoids matching things like "no_embedding:foo". +EMBEDDING_REGEX = re.compile( + r"(?:^|[\s,(\[])embedding\s*:\s*([A-Za-z0-9_\-\./\\]+?)(?:\.(?:pt|safetensors|bin))?(?=[\s:,)\(\]]|$)", + re.IGNORECASE, +) + + +# ============================================================================= +# Cloud detection & endpoint routing +# ============================================================================= + +CLOUD_DOMAIN_SUFFIXES = (".comfy.org",) +CLOUD_DOMAIN_EXACT = {"cloud.comfy.org"} + + +def is_cloud_host(host: str) -> bool: + """True if the host points at Comfy Cloud (or staging/preview subdomain).""" + parsed = urlparse(host if "://" in host else f"http://{host}") + hostname = (parsed.hostname or "").lower() + if hostname in CLOUD_DOMAIN_EXACT: + return True + return any(hostname.endswith(s) for s in CLOUD_DOMAIN_SUFFIXES) + + +def build_cloud_aware_url(base: str, path: str, *, force_cloud: bool | None = None) -> str: + """Build a URL that adds /api prefix when targeting Comfy Cloud. + + Local ComfyUI accepts both `/foo` and `/api/foo` for many endpoints. + Cloud requires `/api/foo`. + + `path` should be a path component (e.g. "/prompt") or full path with query + (e.g. "/view?filename=x"). 
+ """ + base = base.rstrip("/") + cloud = is_cloud_host(base) if force_cloud is None else force_cloud + if not path.startswith("/"): + path = "/" + path + if cloud and not path.startswith("/api/"): + path = "/api" + path + return base + path + + +def cloud_endpoint(path: str) -> str: + """Map a cloud endpoint path to its current canonical form. + + Handles known renames documented in the Comfy Cloud API: + /history -> /history_v2 + /models/ -> /experiment/models/ + /models -> /experiment/models + """ + if path.startswith("/history") and not path.startswith("/history_v2"): + return "/history_v2" + path[len("/history"):] + if path.startswith("/models/"): + return "/experiment/models/" + path[len("/models/"):] + if path == "/models": + return "/experiment/models" + return path + + +def resolve_url(base: str, path: str, *, is_cloud: bool | None = None) -> str: + """Top-level URL resolver. Applies cloud rename + /api prefix as needed.""" + cloud = is_cloud_host(base) if is_cloud is None else is_cloud + if cloud: + path = cloud_endpoint(path) + return build_cloud_aware_url(base, path, force_cloud=cloud) + + +# ============================================================================= +# API key resolution +# ============================================================================= + +def resolve_api_key(explicit: str | None) -> str | None: + """Look up API key from CLI flag → env var. 
Strips whitespace and quotes.""" + val = explicit if explicit else os.environ.get(ENV_API_KEY) + if val is None: + return None + val = val.strip().strip("'\"") + return val or None + + +# ============================================================================= +# HTTP transport +# ============================================================================= + +@dataclass +class HTTPResponse: + status: int + headers: dict[str, str] + body: bytes + url: str # final URL after redirects + + def text(self, encoding: str = "utf-8") -> str: + return self.body.decode(encoding, errors="replace") + + def json(self) -> Any: + return json.loads(self.body.decode("utf-8", errors="replace")) + + +def _sleep_backoff(attempt: int, base: float = RETRY_BASE_DELAY, cap: float = RETRY_MAX_DELAY) -> None: + """Sleep with full-jitter exponential backoff.""" + delay = min(cap, base * (2 ** attempt)) + delay = random.uniform(0, delay) + time.sleep(delay) + + +def http_request( + method: str, + url: str, + *, + headers: dict[str, str] | None = None, + json_body: Any = None, + data: bytes | None = None, + files: dict | None = None, + form: dict | None = None, + timeout: float = DEFAULT_HTTP_TIMEOUT, + follow_redirects: bool = True, + retries: int = DEFAULT_RETRIES, + stream: bool = False, + sink: Path | None = None, +) -> HTTPResponse: + """Single entry point for all HTTP traffic. + + Behavior: + - Retries on connection errors and on HTTP statuses in RETRY_STATUS_CODES, + with exponential backoff + jitter. + - For cross-host redirects, drops Authorization-style headers (so signed + URLs don't leak the API key to S3/CloudFront). + - When `stream=True` and `sink` is a Path, streams the response body to + disk in 64 KiB chunks instead of buffering. + + Either `json_body`, `data`, or `files`+`form` may be supplied (mutually exclusive). 
+ """ + if headers is None: + headers = {} + headers = dict(headers) # copy + headers.setdefault("User-Agent", "hermes-comfyui-skill/5.0") + + if files or form is not None: + # Multipart upload — needs `requests`. The stdlib fallback lacks + # multipart encoding helpers; raise a clear error. + if not HAS_REQUESTS: + raise RuntimeError( + "Multipart upload requires the `requests` package. " + "Install with: pip install requests" + ) + + last_exc: Exception | None = None + for attempt in range(retries): + try: + resp = _http_once( + method=method, url=url, headers=headers, + json_body=json_body, data=data, files=files, form=form, + timeout=timeout, follow_redirects=follow_redirects, + stream=stream, sink=sink, + ) + if resp.status in RETRY_STATUS_CODES and attempt + 1 < retries: + _sleep_backoff(attempt) + continue + return resp + except (TimeoutError, ConnectionError, OSError) as e: + last_exc = e + if attempt + 1 < retries: + _sleep_backoff(attempt) + continue + raise + + # Should not reach here unless retries was 0 + if last_exc: + raise last_exc + raise RuntimeError("http_request: retries exhausted with no response") + + +_SENSITIVE_HEADERS = ("x-api-key", "authorization", "cookie") + + +if HAS_REQUESTS: + class _StripSensitiveOnRedirectSession(requests.Session): + """Session that drops sensitive headers on cross-host redirects. + + `requests` already strips `Authorization` cross-host (rebuild_auth), + but it does NOT strip custom headers like `X-API-Key`. We override + `rebuild_auth` to additionally strip every header in + `_SENSITIVE_HEADERS` when the destination is a different host — + critical when ComfyUI Cloud's `/api/view` redirects to a signed S3 URL. 
+ """ + + def rebuild_auth(self, prepared_request, response): # type: ignore[override] + super().rebuild_auth(prepared_request, response) + try: + old_url = response.request.url + new_url = prepared_request.url + old_host = (urlparse(old_url).hostname or "").lower() + new_host = (urlparse(new_url).hostname or "").lower() + if old_host and new_host and old_host != new_host: + headers = prepared_request.headers + for key in list(headers.keys()): + if key.lower() in _SENSITIVE_HEADERS: + del headers[key] + except Exception: + # Defensive: never let header stripping break a redirect. + pass + + +def _http_once( + *, method: str, url: str, headers: dict[str, str], + json_body: Any, data: bytes | None, files: dict | None, form: dict | None, + timeout: float, follow_redirects: bool, + stream: bool, sink: Path | None, +) -> HTTPResponse: + """One HTTP attempt. No retry.""" + if HAS_REQUESTS: + kwargs: dict[str, Any] = { + "method": method, "url": url, "headers": headers, + "timeout": timeout, "allow_redirects": follow_redirects, + } + if json_body is not None: + kwargs["json"] = json_body + elif data is not None: + kwargs["data"] = data + elif files is not None or form is not None: + kwargs["files"] = files + kwargs["data"] = form + if stream: + kwargs["stream"] = True + + # Use the subclass that strips sensitive headers cross-host + with _StripSensitiveOnRedirectSession() as s: + try: + r = s.request(**kwargs) + if stream and sink is not None: + sink.parent.mkdir(parents=True, exist_ok=True) + with sink.open("wb") as f: + for chunk in r.iter_content(DOWNLOAD_CHUNK_SIZE): + if chunk: + f.write(chunk) + body = b"" # already drained + else: + body = r.content + return HTTPResponse( + status=r.status_code, + headers={k: v for k, v in r.headers.items()}, + body=body, + url=r.url, + ) + except requests.exceptions.RequestException as e: + # Convert to TimeoutError / ConnectionError so the retry loop + # picks them up uniformly with the stdlib path. 
+ if isinstance(e, requests.exceptions.Timeout): + raise TimeoutError(str(e)) from e + raise ConnectionError(str(e)) from e + + # ---------- stdlib fallback ---------- + if json_body is not None: + body_bytes = json.dumps(json_body).encode("utf-8") + headers.setdefault("Content-Type", "application/json") + else: + body_bytes = data + req = urllib.request.Request(url, data=body_bytes, headers=headers, method=method) + + # urllib follows redirects by default. We need to: + # 1) intercept cross-host redirects and drop X-API-Key + # 2) optionally NOT follow redirects when follow_redirects=False + class _RedirectHandler(urllib.request.HTTPRedirectHandler): + def __init__(self, original_host: str, follow: bool): + self.original_host = original_host + self.follow = follow + + def redirect_request(self, req2, fp, code, msg, hdrs, newurl): + if not self.follow: + return None + new_host = (urlparse(newurl).hostname or "").lower() + if new_host != self.original_host: + # Build a new request with cleaned headers + clean_headers = { + k: v for k, v in req2.header_items() + if k.lower() not in ("x-api-key", "authorization", "cookie") + } + new_req = urllib.request.Request(newurl, headers=clean_headers, method="GET") + return new_req + return super().redirect_request(req2, fp, code, msg, hdrs, newurl) + + original_host = (urlparse(url).hostname or "").lower() + opener = urllib.request.build_opener(_RedirectHandler(original_host, follow_redirects)) + + try: + resp = opener.open(req, timeout=timeout) + except urllib.error.HTTPError as e: + return HTTPResponse( + status=e.code, + headers=dict(e.headers) if e.headers else {}, + body=e.read() or b"", + url=getattr(e, "url", url), + ) + + final_url = resp.geturl() + final_status = resp.status + final_headers = dict(resp.headers) + + if stream and sink is not None: + sink.parent.mkdir(parents=True, exist_ok=True) + with sink.open("wb") as f: + while True: + chunk = resp.read(DOWNLOAD_CHUNK_SIZE) + if not chunk: + break + f.write(chunk) 
+ return HTTPResponse(status=final_status, headers=final_headers, body=b"", url=final_url) + + return HTTPResponse(status=final_status, headers=final_headers, body=resp.read(), url=final_url) + + +def http_get(url: str, **kwargs: Any) -> HTTPResponse: + return http_request("GET", url, **kwargs) + + +def http_post(url: str, **kwargs: Any) -> HTTPResponse: + return http_request("POST", url, **kwargs) + + +# ============================================================================= +# Workflow validation & helpers +# ============================================================================= + +def is_api_format(workflow: Any) -> bool: + """API format = top-level dict where each value has `class_type`.""" + if not isinstance(workflow, dict): + return False + if "nodes" in workflow and "links" in workflow: + return False + for v in workflow.values(): + if isinstance(v, dict) and "class_type" in v: + return True + return False + + +def unwrap_workflow(payload: Any) -> dict: + """Unwrap common wrapper variants. Returns API-format workflow or raises ValueError.""" + if isinstance(payload, dict) and is_api_format(payload): + return payload + # Some files wrap workflow under "prompt" key (e.g. saved /prompt payloads) + if isinstance(payload, dict) and "prompt" in payload and is_api_format(payload["prompt"]): + return payload["prompt"] + # Editor format + if isinstance(payload, dict) and "nodes" in payload and "links" in payload: + raise ValueError( + "Workflow is in editor format (has top-level 'nodes' and 'links' arrays). " + "Re-export from ComfyUI using 'Workflow → Export (API)' (newer UI) " + "or 'Save (API Format)' (older UI)." + ) + raise ValueError( + "Workflow is not in API format. Each top-level entry must have a 'class_type' field." 
+ ) + + +def is_link(value: Any) -> bool: + """True if `value` is a [node_id, output_index] connection (length-2 list).""" + return ( + isinstance(value, list) + and len(value) == 2 + and isinstance(value[0], str) + and isinstance(value[1], int) + ) + + +def iter_nodes(workflow: dict) -> Iterator[tuple[str, dict]]: + """Yield (node_id, node) for each valid API-format node.""" + for node_id, node in workflow.items(): + if isinstance(node, dict) and "class_type" in node: + yield node_id, node + + +def iter_model_deps(workflow: dict) -> Iterator[dict]: + """Yield {node_id, class_type, field, value, folder} for each model dependency.""" + for node_id, node in iter_nodes(workflow): + cls = node["class_type"] + if cls not in MODEL_LOADERS: + continue + inputs = node.get("inputs", {}) or {} + for field_name, folder in MODEL_LOADERS[cls]: + val = inputs.get(field_name) + if val and isinstance(val, str) and not is_link(val): + yield { + "node_id": node_id, + "class_type": cls, + "field": field_name, + "value": val, + "folder": folder, + } + + +def iter_embedding_refs(workflow: dict) -> Iterator[tuple[str, str]]: + """Yield (node_id, embedding_name) for every embedding mention in prompts.""" + for node_id, node in iter_nodes(workflow): + inputs = node.get("inputs", {}) or {} + for field_name, val in inputs.items(): + if field_name not in PROMPT_FIELDS: + continue + if not isinstance(val, str): + continue + for m in EMBEDDING_REGEX.finditer(val): + yield node_id, m.group(1) + + +# ============================================================================= +# Path safety +# ============================================================================= + +def safe_path_join(base: Path, *parts: str) -> Path: + """Join paths, raising if the result escapes `base`. + + Server-supplied filenames may contain `../` etc. This guards against + path-traversal attacks when downloading outputs. 
+ """ + base_resolved = base.resolve() + candidate = base.joinpath(*parts).resolve() + try: + candidate.relative_to(base_resolved) + except ValueError as e: + raise ValueError( + f"Refusing path traversal: {candidate} is outside {base_resolved}" + ) from e + return candidate + + +def media_type_from_filename(filename: str) -> str: + ext = Path(filename).suffix.lower() + if ext in (".mp4", ".webm", ".avi", ".mov", ".mkv", ".gif", ".webp"): + return "video" + if ext in (".wav", ".mp3", ".flac", ".ogg", ".m4a"): + return "audio" + if ext in (".glb", ".obj", ".ply", ".gltf"): + return "3d" + if ext in (".json", ".txt", ".md"): + return "text" + return "image" + + +def looks_like_video_workflow(workflow: dict) -> bool: + """Used to bump default timeout for video workflows.""" + for _, node in iter_nodes(workflow): + if node["class_type"] in SLOW_OUTPUT_NODES: + return True + if node["class_type"].lower().startswith(("animatediff", "ade_", "wanvideo", "hunyuanvideo", "ltxvideo", "cogvideo")): + return True + return False + + +# ============================================================================= +# Seed handling +# ============================================================================= + +# ComfyUI's max seed range. Many UIs treat `-1` as "randomize on submit". +SEED_MAX = 2**63 - 1 +SEED_MIN = 0 + + +def coerce_seed(value: Any) -> int: + """Convert -1 or None to a fresh random seed; otherwise return int(value). + + Accepts numeric -1 OR string "-1" (both treated as "randomize"). Other + parse failures raise TypeError/ValueError for the caller to surface. 
+ """ + if value is None: + return random.randint(SEED_MIN, SEED_MAX) + # Stringly-typed -1 from CLI / JSON should also randomize + if isinstance(value, str) and value.strip() == "-1": + return random.randint(SEED_MIN, SEED_MAX) + if value == -1: + return random.randint(SEED_MIN, SEED_MAX) + return int(value) + + +# ============================================================================= +# Cloud model-list normalization +# ============================================================================= + +def parse_model_list(payload: Any) -> set[str]: + """Normalize model-list responses from local ComfyUI vs Comfy Cloud. + + Local: `["a.safetensors", "b.safetensors"]` + Cloud: `[{"name": "a.safetensors", "pathIndex": 0}, ...]` + """ + if not isinstance(payload, list): + return set() + out: set[str] = set() + for item in payload: + if isinstance(item, str): + out.add(item) + elif isinstance(item, dict): + name = item.get("name") or item.get("filename") or item.get("path") + if isinstance(name, str): + out.add(name) + return out + + +# ============================================================================= +# Misc utilities +# ============================================================================= + +def new_client_id() -> str: + return str(uuid.uuid4()) + + +def fmt_kv(d: dict) -> str: + """Pretty key=value for log lines.""" + return " ".join(f"{k}={v!r}" for k, v in d.items()) + + +def emit_json(obj: Any, *, indent: int = 2) -> None: + """Print JSON to stdout. 
Centralised so behavior can be tweaked (e.g., --raw).""" + print(json.dumps(obj, indent=indent, default=str)) + + +def log(msg: str) -> None: + """stderr log with consistent prefix (so JSON stdout stays clean).""" + print(f"[comfyui-skill] {msg}", file=sys.stderr) diff --git a/skills/creative/comfyui/scripts/auto_fix_deps.py b/skills/creative/comfyui/scripts/auto_fix_deps.py new file mode 100755 index 0000000000..788bf8e9e3 --- /dev/null +++ b/skills/creative/comfyui/scripts/auto_fix_deps.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 +""" +auto_fix_deps.py — Run check_deps.py, then attempt to install whatever is missing. + +For local servers: + - Missing custom nodes → `comfy node install <package>` + - Missing models → `comfy model download` (only if a URL is supplied via + --models-from-file; model URLs are never guessed) + +For cloud: prints what would be needed but cannot install (cloud preinstalls +custom nodes and most models server-side; if something genuinely isn't there, +ask Comfy support). + +This is conservative: it never installs without an explicit URL for models +(downloading the wrong model is hard to undo). Custom nodes from the registry +are auto-installed by name. 
+ +Usage: + python3 auto_fix_deps.py workflow_api.json + python3 auto_fix_deps.py workflow_api.json --models-from-file urls.json + python3 auto_fix_deps.py workflow_api.json --dry-run +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import subprocess +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import ( # noqa: E402 + DEFAULT_LOCAL_HOST, ENV_API_KEY, emit_json, log, resolve_api_key, +) +from check_deps import check_deps # noqa: E402 +from _common import unwrap_workflow # noqa: E402 + + +def comfy_cli_available() -> str | None: + """Return command prefix for comfy-cli, or None.""" + if shutil.which("comfy"): + return "comfy" + if shutil.which("uvx"): + return "uvx --from comfy-cli comfy" + return None + + +def run_cmd(cmd: list[str], *, dry_run: bool = False) -> tuple[int, str]: + if dry_run: + return 0, "[dry-run]" + log(f"$ {' '.join(cmd)}") + proc = subprocess.run(cmd, capture_output=True, text=True, check=False) + out = (proc.stdout or "") + (proc.stderr or "") + return proc.returncode, out + + +def install_node(package: str, *, dry_run: bool = False, comfy_cmd: str = "comfy") -> bool: + cmd = comfy_cmd.split() + ["--skip-prompt", "node", "install", package] + code, _ = run_cmd(cmd, dry_run=dry_run) + return code == 0 + + +def install_model(url: str, folder: str, filename: str | None = None, + *, dry_run: bool = False, comfy_cmd: str = "comfy", + hf_token: str | None = None, civitai_token: str | None = None) -> bool: + cmd = comfy_cmd.split() + [ + "--skip-prompt", "model", "download", + "--url", url, + "--relative-path", f"models/{folder}", + ] + if filename: + cmd.extend(["--filename", filename]) + if hf_token: + cmd.extend(["--set-hf-api-token", hf_token]) + if civitai_token: + cmd.extend(["--set-civitai-api-token", civitai_token]) + code, _ = run_cmd(cmd, dry_run=dry_run) + return code == 0 + + +def main(argv: list[str] | None = None) -> int: + p = 
argparse.ArgumentParser(description="Run check_deps and install whatever is missing") + p.add_argument("workflow") + p.add_argument("--host", default=DEFAULT_LOCAL_HOST) + p.add_argument("--api-key", help=f"or set ${ENV_API_KEY}") + p.add_argument("--models-from-file", + help="JSON file mapping {model_filename: download_url} for models that need install") + p.add_argument("--hf-token", help="HuggingFace token for downloads") + p.add_argument("--civitai-token", help="CivitAI token for downloads") + p.add_argument("--dry-run", action="store_true", + help="Show what would be installed without doing it") + p.add_argument("--no-restart", action="store_true", + help="Don't suggest restarting the server after node install") + args = p.parse_args(argv) + + api_key = resolve_api_key(args.api_key) + + wf_path = Path(args.workflow).expanduser() + if not wf_path.exists(): + emit_json({"error": f"Workflow not found: {args.workflow}"}) + return 1 + try: + with wf_path.open() as f: + workflow = unwrap_workflow(json.load(f)) + except (ValueError, json.JSONDecodeError) as e: + emit_json({"error": str(e)}) + return 1 + + report = check_deps(workflow, host=args.host, api_key=api_key) + + if report["is_ready"]: + emit_json({"status": "ready", "report": report}) + return 0 + + if report["is_cloud"]: + emit_json({ + "status": "cannot_fix_cloud", + "reason": "Comfy Cloud preinstalls nodes; if something is genuinely missing, contact support.", + "report": report, + }) + return 1 + + comfy_cmd = comfy_cli_available() + if not comfy_cmd: + emit_json({ + "status": "cannot_fix", + "reason": "comfy-cli not on PATH; install with `pip install comfy-cli` or `pipx install comfy-cli`", + "report": report, + }) + return 1 + + actions: list[dict] = [] + failures: list[dict] = [] + + # ---- Install missing custom nodes ---- + seen_packages: set[str] = set() + for entry in report["missing_nodes"]: + cmd = entry.get("fix_command", "") + if cmd.startswith("comfy node install "): + package = cmd.split(" 
")[-1] + if package in seen_packages: + continue + seen_packages.add(package) + ok = install_node(package, dry_run=args.dry_run, comfy_cmd=comfy_cmd) + (actions if ok else failures).append({ + "kind": "node", "package": package, "node_class": entry["class_type"], + "ok": ok, + }) + else: + failures.append({ + "kind": "node", "node_class": entry["class_type"], + "ok": False, "reason": "No registry mapping known. " + entry.get("fix_hint", ""), + }) + + # ---- Install missing models (only when URL provided) ---- + sources: dict[str, str] = {} + if args.models_from_file: + try: + sources = json.loads(Path(args.models_from_file).read_text()) + except (OSError, json.JSONDecodeError) as e: + log(f"Could not read --models-from-file: {e}") + + for entry in report["missing_models"]: + filename = entry["value"] + url = sources.get(filename) + if not url: + failures.append({ + "kind": "model", "filename": filename, "folder": entry["folder"], + "ok": False, "reason": "No URL provided in --models-from-file. 
" + "Refusing to guess.", + }) + continue + ok = install_model( + url, entry["folder"], filename, + dry_run=args.dry_run, comfy_cmd=comfy_cmd, + hf_token=args.hf_token, civitai_token=args.civitai_token, + ) + (actions if ok else failures).append({ + "kind": "model", "filename": filename, "folder": entry["folder"], + "url": url, "ok": ok, + }) + + # ---- Embeddings ---- + for entry in report["missing_embeddings"]: + emb_name = entry["embedding_name"] + # Try common extensions in user-supplied source map + url = (sources.get(f"{emb_name}.pt") + or sources.get(f"{emb_name}.safetensors") + or sources.get(emb_name)) + if not url: + failures.append({ + "kind": "embedding", "name": emb_name, + "ok": False, "reason": "No URL provided in --models-from-file.", + }) + continue + target_filename = ( + f"{emb_name}.safetensors" if url.endswith(".safetensors") + else f"{emb_name}.pt" + ) + ok = install_model( + url, "embeddings", target_filename, + dry_run=args.dry_run, comfy_cmd=comfy_cmd, + hf_token=args.hf_token, civitai_token=args.civitai_token, + ) + (actions if ok else failures).append({ + "kind": "embedding", "name": emb_name, "url": url, "ok": ok, + }) + + needs_restart = any(a["kind"] == "node" and a.get("ok") for a in actions) + + emit_json({ + "status": "fixed" if not failures else "partial", + "actions_taken": actions, + "failures": failures, + "needs_server_restart": needs_restart and not args.no_restart, + "restart_hint": "comfy stop && comfy launch --background", + "dry_run": args.dry_run, + }) + return 0 if not failures else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/skills/creative/comfyui/scripts/check_deps.py b/skills/creative/comfyui/scripts/check_deps.py new file mode 100755 index 0000000000..607e2c0a2d --- /dev/null +++ b/skills/creative/comfyui/scripts/check_deps.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +""" +check_deps.py — Verify a ComfyUI workflow's dependencies (custom nodes, models, +embeddings) against a running server. 
+ +Improvements over v1: + - Cloud-aware endpoint mapping (handles `/api/experiment/models/{folder}` and + `/api/object_info` variants verified against live cloud API) + - Distinguishes 200-empty (genuinely no models in folder) vs 404 + (folder doesn't exist) vs 403 (auth/tier issue) — no silent passes + - Outputs concrete remediation commands (e.g. `comfy node install <package>`) + when nodes are missing + - Detects embedding references inside prompt strings as model deps + - Skips check on cloud free tier `/api/object_info` (403) without false alarm + - Accepts API key from CLI flag OR $COMFY_CLOUD_API_KEY env var + +Usage: + python3 check_deps.py workflow_api.json + python3 check_deps.py workflow_api.json --host 127.0.0.1 --port 8188 + python3 check_deps.py workflow_api.json --host https://cloud.comfy.org + +Stdlib-only. Python 3.10+. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import ( # noqa: E402 + DEFAULT_LOCAL_HOST, ENV_API_KEY, + emit_json, folder_aliases_for, http_get, is_cloud_host, + iter_embedding_refs, iter_model_deps, iter_nodes, parse_model_list, + resolve_api_key, resolve_url, unwrap_workflow, +) + + +# Known node → custom-node-package map. When a workflow needs a node we don't +# recognize, suggesting the right `comfy node install ...` makes the difference +# between a working agent and a stuck one. 
+NODE_TO_PACKAGE: dict[str, str] = { + # rgthree (Reroute is JS-only and doesn't appear in /object_info) + "Power Lora Loader (rgthree)": "rgthree-comfy", + "Image Comparer (rgthree)": "rgthree-comfy", + "Seed (rgthree)": "rgthree-comfy", + "Display Any (rgthree)": "rgthree-comfy", + "Display Int (rgthree)": "rgthree-comfy", + # Impact pack + "FaceDetailer": "comfyui-impact-pack", + "DetailerForEach": "comfyui-impact-pack", + "BboxDetectorSEGS": "comfyui-impact-pack", + "SAMLoader": "comfyui-impact-pack", + "ImpactWildcardProcessor": "comfyui-impact-pack", + # Impact subpack (separate package) + "UltralyticsDetectorProvider": "comfyui-impact-subpack", + # Was Node Suite + "Image Save": "was-node-suite-comfyui", + "Number Counter": "was-node-suite-comfyui", + "Text String": "was-node-suite-comfyui", + # easy-use + "easy fullLoader": "comfyui-easy-use", + "easy positive": "comfyui-easy-use", + "easy negative": "comfyui-easy-use", + "easy seed": "comfyui-easy-use", + "easy imageSave": "comfyui-easy-use", + # Video Helper Suite + "VHS_VideoCombine": "comfyui-videohelpersuite", + "VHS_LoadVideo": "comfyui-videohelpersuite", + "VHS_LoadAudio": "comfyui-videohelpersuite", + # AnimateDiff + "ADE_AnimateDiffLoaderWithContext": "comfyui-animatediff-evolved", + "ADE_AnimateDiffLoaderGen1": "comfyui-animatediff-evolved", + "ADE_LoadAnimateDiffModel": "comfyui-animatediff-evolved", + # ControlNet aux preprocessors (full class names) + "CannyEdgePreprocessor": "comfyui_controlnet_aux", + "DWPreprocessor": "comfyui_controlnet_aux", + "OpenposePreprocessor": "comfyui_controlnet_aux", + "DepthAnythingPreprocessor": "comfyui_controlnet_aux", + "Zoe_DepthAnythingPreprocessor": "comfyui_controlnet_aux", + "AnimalPosePreprocessor": "comfyui_controlnet_aux", + # IPAdapter Plus + "IPAdapterAdvanced": "comfyui_ipadapter_plus", + "IPAdapterUnifiedLoader": "comfyui_ipadapter_plus", + "IPAdapterModelLoader": "comfyui_ipadapter_plus", + "IPAdapterInsightFaceLoader": "comfyui_ipadapter_plus", 
+ # InstantID + "InstantIDModelLoader": "comfyui_instantid", + "ApplyInstantID": "comfyui_instantid", + # Comfy essentials (note: registry slug uses underscore, not hyphen) + "GetImageSize+": "comfyui_essentials", + "ImageBatchMultiple+": "comfyui_essentials", + # pysssss + "ShowText|pysssss": "comfyui-custom-scripts", + "PreviewImage|pysssss": "comfyui-custom-scripts", + # SUPIR + "SUPIR_Upscale": "comfyui-supir", + "SUPIR_first_stage": "comfyui-supir", + # GGUF (case-sensitive registry slug) + "UNETLoaderGGUF": "ComfyUI-GGUF", + "DualCLIPLoaderGGUF": "ComfyUI-GGUF", + # Florence2 + "Florence2Run": "comfyui-florence2", + # WAS + "Image Filter Adjustments": "was-node-suite-comfyui", + # Photomaker (case-sensitive) + "PhotoMakerLoader": "ComfyUI-PhotoMaker-Plus", + # Wan video (case-sensitive) + "WanVideoSampler": "ComfyUI-WanVideoWrapper", + "WanVideoModelLoader": "ComfyUI-WanVideoWrapper", +} + +# Nodes whose package isn't on the comfy registry — need git-URL install via +# ComfyUI-Manager. We surface a helpful hint instead of an unrunnable command. +NODE_TO_GIT_URL: dict[str, str] = { + "HunyuanVideoSampler": "https://github.com/kijai/ComfyUI-HunyuanVideoWrapper", + "HunyuanVideoModelLoader": "https://github.com/kijai/ComfyUI-HunyuanVideoWrapper", +} + + +def fetch_object_info(url: str, headers: dict) -> tuple[set[str] | None, dict | None]: + """Returns (installed_node_set, error_info). Error info is a dict if we + couldn't query (e.g. cloud free tier), else None. 
+ """ + r = http_get(url, headers=headers, retries=2, timeout=30) + if r.status == 200: + try: + data = r.json() + if isinstance(data, dict): + return set(data.keys()), None + except Exception: + pass + return None, {"http_status": 200, "reason": "non-dict response"} + if r.status == 403: + try: + body = r.json() + except Exception: + body = {"raw": r.text()[:200]} + return None, {"http_status": 403, "reason": "forbidden", "body": body} + if r.status == 404: + return None, {"http_status": 404, "reason": "endpoint not found"} + return None, {"http_status": r.status, "reason": "unexpected", "body": r.text()[:200]} + + +def _fetch_one_folder( + base: str, folder: str, headers: dict, *, is_cloud: bool, +) -> tuple[set[str] | None, dict | None]: + """Single-folder fetch, no aliasing. Returns (installed_set, error_info).""" + url = resolve_url(base, f"/models/{folder}", is_cloud=is_cloud) + r = http_get(url, headers=headers, retries=2, timeout=30) + if r.status == 200: + try: + return parse_model_list(r.json()), None + except Exception: + return set(), {"http_status": 200, "reason": "non-list response"} + if r.status == 404: + body_text = r.text() + try: + body = r.json() + except Exception: + body = {"raw": body_text[:200]} + code = body.get("code") if isinstance(body, dict) else None + if code == "folder_not_found": + # Folder is genuinely empty/missing on server — not the same as + # "endpoint missing". Return empty set with informational error. 
+ return set(), {"http_status": 404, "reason": "folder_empty_or_unknown", "body": body} + return None, {"http_status": 404, "reason": "endpoint not found", "body": body} + if r.status == 403: + try: + body = r.json() + except Exception: + body = {} + return None, {"http_status": 403, "reason": "forbidden", "body": body} + return None, {"http_status": r.status, "reason": "unexpected"} + + +def fetch_models_for_folder( + base: str, folder: str, headers: dict, *, is_cloud: bool, +) -> tuple[set[str] | None, dict | None]: + """Fetch installed models for a folder, trying aliases. + + Folder renames over time (e.g. unet → diffusion_models, clip → text_encoders) + mean a workflow asking for a model in `unet` may need to look in + `diffusion_models`. We union models from every reachable alias. + + Returns (combined_set | None, last_error | None). + """ + aliases = folder_aliases_for(folder) + combined: set[str] = set() + any_success = False + last_err: dict | None = None + for alias in aliases: + models, err = _fetch_one_folder(base, alias, headers, is_cloud=is_cloud) + if models is not None: + combined.update(models) + any_success = True + last_err = None + else: + last_err = err + if not any_success: + return None, last_err + return combined, None + + +def fetch_embeddings(base: str, headers: dict, *, is_cloud: bool) -> tuple[set[str] | None, dict | None]: + """Local ComfyUI exposes /embeddings; cloud uses /experiment/models/embeddings.""" + if is_cloud: + return fetch_models_for_folder(base, "embeddings", headers, is_cloud=True) + # Local: dedicated /embeddings returns a flat list of names + r = http_get(resolve_url(base, "/embeddings", is_cloud=False), headers=headers, retries=2) + if r.status == 200: + try: + data = r.json() + if isinstance(data, list): + # Strip extensions from the registered names since prompt syntax + # usually omits them ("embedding:goodvibes" vs "goodvibes.pt") + names = set() + for n in data: + if isinstance(n, str): + names.add(n) + # Also 
store stem for fuzzy matching + names.add(Path(n).stem) + return names, None + except Exception: + pass + return None, {"http_status": r.status, "reason": "unexpected"} + + +def normalize_for_match(name: str) -> set[str]: + """Generate matching variants of a model name (with/without extension, slashes, etc.)""" + s = {name} + s.add(Path(name).stem) + s.add(Path(name).name) + # ComfyUI sometimes strips/keeps the leading folder + if "/" in name or "\\" in name: + flat = name.replace("\\", "/").split("/")[-1] + s.add(flat) + s.add(Path(flat).stem) + return {x for x in s if x} + + +def model_present(needed: str, installed: set[str]) -> bool: + if not installed: + return False + needed_variants = normalize_for_match(needed) + installed_norm: set[str] = set() + for inst in installed: + installed_norm.update(normalize_for_match(inst)) + return bool(needed_variants & installed_norm) + + +def suggest_install_command(node_class: str) -> str | None: + pkg = NODE_TO_PACKAGE.get(node_class) + if pkg: + return f"comfy node install {pkg}" + return None + + +def suggest_git_url(node_class: str) -> str | None: + """For nodes not on the registry, return a git URL the user can hand to + ComfyUI-Manager's `/manager/queue/install` endpoint.""" + return NODE_TO_GIT_URL.get(node_class) + + +def check_deps( + workflow: dict, host: str, *, api_key: str | None = None, +) -> dict: + headers: dict[str, str] = {} + if api_key: + headers["X-API-Key"] = api_key + + is_cloud = is_cloud_host(host) + base = host.rstrip("/") + + # ---- 1. Required nodes ---- + required_nodes: set[str] = set() + for _, node in iter_nodes(workflow): + required_nodes.add(node["class_type"]) + + object_info_url = resolve_url(base, "/object_info", is_cloud=is_cloud) + installed_nodes, obj_err = fetch_object_info(object_info_url, headers) + + missing_nodes: list[dict] = [] + node_check_skipped = False + if installed_nodes is None: + # Couldn't query (e.g. cloud free tier). Don't false-alarm; mark skipped. 
+ node_check_skipped = True + else: + for cls in sorted(required_nodes): + if cls not in installed_nodes: + entry = {"class_type": cls} + cmd = suggest_install_command(cls) + git_url = suggest_git_url(cls) + if cmd: + entry["fix_command"] = cmd + elif git_url: + entry["fix_git_url"] = git_url + entry["fix_hint"] = ( + f"Not on registry. Install via Manager with this git URL: {git_url}" + ) + else: + entry["fix_hint"] = ( + "Search https://registry.comfy.org or " + "use ComfyUI-Manager UI to find the package providing this node." + ) + missing_nodes.append(entry) + + # ---- 2. Required models ---- + model_cache: dict[str, tuple[set[str] | None, dict | None]] = {} + missing_models: list[dict] = [] + folder_errors: dict[str, dict] = {} + + for dep in iter_model_deps(workflow): + folder = dep["folder"] + if folder not in model_cache: + model_cache[folder] = fetch_models_for_folder( + base, folder, headers, is_cloud=is_cloud, + ) + installed, err = model_cache[folder] + if installed is None: + # Couldn't enumerate this folder — record once + folder_errors.setdefault(folder, err or {}) + # Don't flag as missing (we don't know); the folder_errors block surfaces this + continue + if not model_present(dep["value"], installed): + entry = dict(dep) + entry["fix_hint"] = ( + f"comfy model download --url --relative-path models/{folder} " + f"--filename {dep['value']!r}" + ) + missing_models.append(entry) + + # ---- 3. 
Embedding refs in prompts ---- + emb_installed, emb_err = fetch_embeddings(base, headers, is_cloud=is_cloud) + missing_embeddings: list[dict] = [] + seen_emb: set[tuple[str, str]] = set() + for nid, emb_name in iter_embedding_refs(workflow): + if (nid, emb_name) in seen_emb: + continue + seen_emb.add((nid, emb_name)) + if emb_installed is None: + # Couldn't enumerate — skip silently here, surface the error in the + # folder_errors block + continue + if not model_present(emb_name, emb_installed): + missing_embeddings.append({ + "node_id": nid, + "embedding_name": emb_name, + "folder": "embeddings", + "fix_hint": ( + f"Download {emb_name}.pt or .safetensors and place in " + f"models/embeddings/, or `comfy model download --url " + f"--relative-path models/embeddings`" + ), + }) + + if emb_err and emb_installed is None: + folder_errors.setdefault("embeddings", emb_err) + + is_ready = ( + not node_check_skipped + and not missing_nodes + and not missing_models + and not missing_embeddings + ) + + return { + "is_ready": is_ready, + "node_check_skipped": node_check_skipped, + "node_check_skip_reason": obj_err if node_check_skipped else None, + "missing_nodes": missing_nodes, + "missing_models": missing_models, + "missing_embeddings": missing_embeddings, + "folder_errors": folder_errors, + # 0 is a legitimate count (e.g. empty server). Use None only when not queried. 
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: check a workflow's dependencies against a ComfyUI server.

    Parses the command line, loads the workflow JSON file, runs ``check_deps``
    and prints either the resulting report or an ``{"error": ...}`` object
    through ``emit_json``.

    Args:
        argv: Arguments to parse. ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` when the report's ``is_ready`` flag is true (and, with
        ``--strict``, the node check was not skipped); ``1`` for a missing,
        unreadable or invalid workflow file, a failed check, or unmet
        dependencies.
    """
    p = argparse.ArgumentParser(description="Check ComfyUI workflow dependencies against a running server")
    p.add_argument("workflow", help="Path to workflow API JSON file")
    p.add_argument("--host", default=DEFAULT_LOCAL_HOST, help="ComfyUI server URL")
    p.add_argument("--port", type=int, help="Server port (overrides --host port)")
    p.add_argument("--api-key", help=f"API key for cloud (or set ${ENV_API_KEY} env var)")
    p.add_argument("--strict", action="store_true",
                   help="Exit non-zero if node check is skipped (e.g. on cloud free tier)")
    args = p.parse_args(argv)

    host = args.host
    if args.port is not None:
        # Replace whatever port the host carries with --port.
        from urllib.parse import urlparse, urlunparse
        parsed = urlparse(host if "://" in host else f"http://{host}")
        hostname = parsed.hostname
        if not hostname:
            # Without this guard the netloc would become the literal "None:<port>".
            emit_json({"error": f"Cannot determine hostname from --host {args.host!r}"})
            return 1
        if ":" in hostname:
            # ``hostname`` drops the brackets of an IPv6 literal; restore them
            # so the appended port stays unambiguous.
            hostname = f"[{hostname}]"
        # Keep any "user:pass@" prefix instead of silently discarding it.
        userinfo, at, _ = parsed.netloc.rpartition("@")
        new_netloc = f"{userinfo}{at}{hostname}:{args.port}"
        host = urlunparse(parsed._replace(netloc=new_netloc))

    api_key = resolve_api_key(args.api_key)

    wf_path = Path(args.workflow).expanduser()
    if not wf_path.exists():
        emit_json({"error": f"Workflow file not found: {args.workflow}"})
        return 1
    try:
        with wf_path.open(encoding="utf-8") as f:
            payload = json.load(f)
        workflow = unwrap_workflow(payload)
    except json.JSONDecodeError as e:
        # Must come before ValueError: JSONDecodeError is a ValueError
        # subclass, so the reverse order makes this handler unreachable.
        emit_json({"error": f"Invalid JSON: {e}"})
        return 1
    except ValueError as e:
        emit_json({"error": str(e)})
        return 1
    except OSError as e:
        emit_json({"error": f"Cannot read workflow file: {e}"})
        return 1

    try:
        result = check_deps(workflow, host=host, api_key=api_key)
    except Exception as e:
        # Top-level boundary: report any failure as JSON instead of a traceback.
        emit_json({"error": f"Dep check failed: {e}", "host": host})
        return 1

    emit_json(result)

    if not result["is_ready"]:
        return 1
    # NOTE(review): check_deps folds ``node_check_skipped`` into ``is_ready``,
    # so this branch only fires if that definition changes.
    if args.strict and result["node_check_skipped"]:
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
#!/usr/bin/env bash
# ComfyUI Setup — Install, launch, and verify using the official comfy-cli.
#
# Improvements over v1:
#   - Prefers `pipx` / `uvx` over global `pip install` (avoids polluting system Python)
#   - Idempotent: detects already-running server and skips re-launch
#   - Configurable port via --port=N (default 8188)
#   - Configurable workspace via --workspace=PATH
#   - Persistent log file in /tmp/comfyui_setup.<pid>.log for debugging
#   - SIGINT trap cleans up partial state
#   - Refuses local install when hardware_check.py verdict is "cloud"
#   - Forwards extra flags to comfy-cli (e.g. --cuda-version=12.4)
#
# Usage:
#   bash scripts/comfyui_setup.sh
#       (auto-detects GPU; uses recommendation from hardware_check.py)
#   bash scripts/comfyui_setup.sh --nvidia
#   bash scripts/comfyui_setup.sh --m-series --port=8190
#   bash scripts/comfyui_setup.sh --amd --workspace=/data/comfy
#
# Flags:
#   --nvidia | --amd | --m-series | --cpu   GPU selection (skips hw check)
#   --port=N                                HTTP port (default 8188)
#   --workspace=PATH                        ComfyUI install location
#   --skip-launch                           Install only, don't start server
#   --force-cloud-override                  Install locally even if hw says cloud
#   -- <extra args>                         Pass remaining args to `comfy install`

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
HARDWARE_CHECK="$SCRIPT_DIR/hardware_check.py"
LOG_FILE="/tmp/comfyui_setup.$$.log"
PORT=8188
WORKSPACE=""
GPU_FLAG=""
SKIP_LAUNCH=0
FORCE_CLOUD_OVERRIDE=0
EXTRA_INSTALL_ARGS=()

# EXIT handler: on failure, point the user at the log, then preserve the status.
cleanup() {
    local exit_code=$?
    if [ "$exit_code" -ne 0 ]; then
        echo "==> Setup exited with status $exit_code. Log: $LOG_FILE" >&2
    fi
    exit "$exit_code"
}
# Signals only convert to a conventional exit status; the EXIT trap then runs
# cleanup exactly once. (Trapping INT/TERM with cleanup directly made it run
# twice: once for the signal and again for the resulting EXIT.)
trap cleanup EXIT
trap 'exit 130' INT
trap 'exit 143' TERM

log() { echo "==> $*" | tee -a "$LOG_FILE" >&2; }
err() { echo "ERROR: $*" | tee -a "$LOG_FILE" >&2; }

# --- Argument parsing ---
PASSTHROUGH=0
for arg in "$@"; do
    if [ "$PASSTHROUGH" -eq 1 ]; then
        EXTRA_INSTALL_ARGS+=("$arg")
        continue
    fi
    case "$arg" in
        --nvidia|--amd|--m-series|--cpu)
            GPU_FLAG="$arg"
            ;;
        --port=*)
            PORT="${arg#*=}"
            ;;
        --workspace=*)
            WORKSPACE="${arg#*=}"
            ;;
        --skip-launch)
            SKIP_LAUNCH=1
            ;;
        --force-cloud-override)
            FORCE_CLOUD_OVERRIDE=1
            ;;
        --)
            PASSTHROUGH=1
            ;;
        --help|-h)
            # Print the leading comment block, stripping the `# ` prefix.
            # Stops at the first blank line which separates docs from code.
            awk '
                NR == 1 { next }            # skip shebang
                /^[^#]/ { exit }            # stop at first non-comment line
                /^$/    { exit }            # ...or first blank line
                { sub(/^# ?/, ""); print }
            ' "$0"
            exit 0
            ;;
        *)
            err "Unknown argument: $arg"
            exit 64
            ;;
    esac
done

# Reject a non-numeric or out-of-range port up front; otherwise the failure
# only shows up later as a confusing curl / launch error.
case "$PORT" in
    ''|*[!0-9]*)
        err "Invalid --port value: $PORT"
        exit 64
        ;;
esac
if [ "$PORT" -lt 1 ] || [ "$PORT" -gt 65535 ]; then
    err "Invalid --port value: $PORT"
    exit 64
fi

log "Logging to $LOG_FILE"

# --- Step 0: Hardware check (skipped if user gave an explicit GPU flag) ---
if [ -z "$GPU_FLAG" ]; then
    if [ ! -f "$HARDWARE_CHECK" ]; then
        log "hardware_check.py not found — defaulting to --nvidia"
        GPU_FLAG="--nvidia"
    else
        log "Running hardware check…"
        # hardware_check.py encodes its verdict in the exit status (0/1/2), so
        # a non-zero status is expected here and must not trip `set -e`.
        set +e
        HW_JSON="$(python3 "$HARDWARE_CHECK" --json 2>>"$LOG_FILE")"
        HW_EXIT=$?
        set -e

        if [ -z "$HW_JSON" ]; then
            err "hardware_check.py produced no output (exit $HW_EXIT). Pass an explicit flag."
            exit 1
        fi
        echo "$HW_JSON" | tee -a "$LOG_FILE" >&2

        VERDICT="$(echo "$HW_JSON" | python3 -c 'import sys,json; print(json.load(sys.stdin).get("verdict",""))')"
        FLAG="$(echo "$HW_JSON" | python3 -c 'import sys,json; print(json.load(sys.stdin).get("comfy_cli_flag") or "")')"

        if [ "$VERDICT" = "cloud" ] && [ "$FORCE_CLOUD_OVERRIDE" -ne 1 ]; then
            log ""
            log "Hardware check: this machine is not suitable for local ComfyUI."
            log "Recommended: Comfy Cloud — https://platform.comfy.org"
            log ""
            log "To override and force a local install, re-run with --force-cloud-override"
            log "or pass an explicit GPU flag (--nvidia|--amd|--m-series|--cpu)."
            exit 2
        fi

        if [ "$VERDICT" = "marginal" ]; then
            log "Hardware check: verdict is MARGINAL."
            log "  SD1.5 should work; SDXL/Flux may be slow or OOM."
            log "  Consider Comfy Cloud for heavier workflows: https://platform.comfy.org"
        fi

        if [ -z "$FLAG" ]; then
            log "hardware_check could not pick a comfy-cli flag. Defaulting to --nvidia."
            log "(For Intel Arc or unsupported hardware, use the manual install path.)"
            GPU_FLAG="--nvidia"
        else
            GPU_FLAG="$FLAG"
        fi
    fi
fi

log "GPU flag: $GPU_FLAG"
log "Port: $PORT"
[ -n "$WORKSPACE" ] && log "Workspace: $WORKSPACE"
[ "${#EXTRA_INSTALL_ARGS[@]}" -gt 0 ] && log "Extra install args: ${EXTRA_INSTALL_ARGS[*]}"

# --- Step 1: Install comfy-cli (prefer pipx / uvx over global pip) ---
# NOTE: COMFY_BIN is expanded unquoted on purpose so that the multi-word
# "uvx --from comfy-cli comfy" form splits into separate arguments.
COMFY_BIN=""
if command -v comfy >/dev/null 2>&1; then
    COMFY_BIN="comfy"
    log "comfy-cli already on PATH: $(comfy -v 2>/dev/null || echo 'unknown version')"
elif command -v uvx >/dev/null 2>&1; then
    log "Using uvx (no install needed)"
    COMFY_BIN="uvx --from comfy-cli comfy"
elif command -v pipx >/dev/null 2>&1; then
    log "Installing comfy-cli via pipx…"
    pipx install comfy-cli >>"$LOG_FILE" 2>&1
    COMFY_BIN="comfy"
    # pipx adds shims to ~/.local/bin which may need to be on PATH
    if ! command -v comfy >/dev/null 2>&1; then
        if [ -x "$HOME/.local/bin/comfy" ]; then
            export PATH="$HOME/.local/bin:$PATH"
            COMFY_BIN="$HOME/.local/bin/comfy"
        fi
    fi
else
    log "Neither pipx nor uvx found. Falling back to pip install --user…"
    log "  (Recommend installing pipx: https://pipx.pypa.io)"
    if ! pip install --user comfy-cli >>"$LOG_FILE" 2>&1; then
        # macOS: PEP 668 externally-managed-environment may block --user
        log "pip install --user failed. Retrying with --break-system-packages…"
        pip install --user --break-system-packages comfy-cli >>"$LOG_FILE" 2>&1 || {
            err "Could not install comfy-cli. Install pipx or uv first."
            exit 1
        }
    fi
    # Resolve the actual `comfy` script — pip --user puts it in:
    #   Linux: ~/.local/bin/comfy
    #   macOS: ~/Library/Python/<ver>/bin/comfy OR ~/.local/bin/comfy
    COMFY_BIN=""
    for candidate in "$HOME/.local/bin/comfy" \
                     "$HOME/Library/Python/3.13/bin/comfy" \
                     "$HOME/Library/Python/3.12/bin/comfy" \
                     "$HOME/Library/Python/3.11/bin/comfy" \
                     "$HOME/Library/Python/3.10/bin/comfy"; do
        if [ -x "$candidate" ]; then
            COMFY_BIN="$candidate"
            export PATH="$(dirname "$candidate"):$PATH"
            break
        fi
    done
    if [ -z "$COMFY_BIN" ]; then
        if command -v comfy >/dev/null 2>&1; then
            COMFY_BIN="comfy"
        else
            err "Installed comfy-cli but couldn't find the 'comfy' script."
            err "Add the right Python user-bin directory to PATH and retry."
            exit 1
        fi
    fi
fi

# --- Step 2: Disable analytics tracking (avoid interactive prompt) ---
log "Disabling analytics tracking…"
$COMFY_BIN --skip-prompt tracking disable >>"$LOG_FILE" 2>&1 || true

# --- Step 3: Install ComfyUI ---
WORKSPACE_ARG=()
if [ -n "$WORKSPACE" ]; then
    WORKSPACE_ARG=(--workspace "$WORKSPACE")
fi

# Arrays that may be empty are expanded as ${ARR[@]+"${ARR[@]}"}: under
# `set -u`, bash older than 4.4 (e.g. the /bin/bash 3.2 shipped with macOS)
# treats a plain "${ARR[@]}" on an empty array as an unbound variable and aborts.
if $COMFY_BIN ${WORKSPACE_ARG[@]+"${WORKSPACE_ARG[@]}"} which 2>/dev/null | grep -q "ComfyUI"; then
    EXISTING_WS="$($COMFY_BIN ${WORKSPACE_ARG[@]+"${WORKSPACE_ARG[@]}"} which 2>/dev/null || true)"
    log "ComfyUI already installed at: $EXISTING_WS"
else
    log "Installing ComfyUI ($GPU_FLAG)…"
    if ! $COMFY_BIN ${WORKSPACE_ARG[@]+"${WORKSPACE_ARG[@]}"} --skip-prompt install "$GPU_FLAG" ${EXTRA_INSTALL_ARGS[@]+"${EXTRA_INSTALL_ARGS[@]}"} >>"$LOG_FILE" 2>&1; then
        err "Install failed. Tail of log:"
        tail -20 "$LOG_FILE" >&2
        exit 1
    fi
fi

if [ "$SKIP_LAUNCH" -eq 1 ]; then
    log "Setup complete (--skip-launch). Run \`$COMFY_BIN launch --background -- --port $PORT\` when ready."
    exit 0
fi

# --- Step 4: Detect already-running server ---
if curl -fsS "http://127.0.0.1:$PORT/system_stats" >/dev/null 2>&1; then
    log "Server already running on port $PORT — skipping launch."
    log "Stop with \`$COMFY_BIN stop\` if you want a fresh start."
    curl -fsS "http://127.0.0.1:$PORT/system_stats" | python3 -m json.tool 2>/dev/null || true
    log "Done."
    exit 0
fi

# --- Step 5: Launch ---
log "Launching ComfyUI in background on port $PORT…"
LAUNCH_EXTRAS=("--" "--port" "$PORT")
if ! $COMFY_BIN ${WORKSPACE_ARG[@]+"${WORKSPACE_ARG[@]}"} launch --background "${LAUNCH_EXTRAS[@]}" >>"$LOG_FILE" 2>&1; then
    err "Background launch failed. Tail of log:"
    tail -20 "$LOG_FILE" >&2
    err "Try foreground launch to see real-time errors: $COMFY_BIN launch -- --port $PORT"
    exit 1
fi

# --- Step 6: Wait for server ---
log "Waiting for server…"
MAX_WAIT=60
ELAPSED=0
while [ $ELAPSED -lt $MAX_WAIT ]; do
    if curl -fsS "http://127.0.0.1:$PORT/system_stats" >/dev/null 2>&1; then
        log "Server is running!"
        curl -fsS "http://127.0.0.1:$PORT/system_stats" | python3 -m json.tool 2>/dev/null || true
        break
    fi
    sleep 2
    ELAPSED=$((ELAPSED + 2))
done

# ELAPSED only reaches MAX_WAIT when the loop ran out without hitting `break`.
if [ $ELAPSED -ge $MAX_WAIT ]; then
    err "Server did not start within ${MAX_WAIT}s."
    err "Inspect log: $LOG_FILE"
    err "Or run foreground: $COMFY_BIN launch -- --port $PORT"
    exit 1
fi

log ""
log "Setup complete!"
log "  Server: http://127.0.0.1:$PORT"
log "  Web UI: http://127.0.0.1:$PORT (open in browser)"
log "  Stop: $COMFY_BIN stop"
log "  Log: $LOG_FILE (kept until shell closes)"
log ""
log "Next steps:"
log "  - Download a model: $COMFY_BIN model download --url <URL> --relative-path models/checkpoints"
log "  - Run a workflow: python3 $SCRIPT_DIR/run_workflow.py --workflow <file> --args '{...}'"

# Disable trap on success path
trap - EXIT
def infer_type(value: Any) -> str:
    """Map a workflow input value to a coarse schema type label.

    Returns one of ``"bool"``, ``"int"``, ``"float"``, ``"string"``,
    ``"link"`` (any list), ``"object"`` (any dict) or ``"unknown"``.
    """
    # Order matters: ``bool`` is a subclass of ``int`` and must be tested first.
    labelled_types = (
        (bool, "bool"),
        (int, "int"),
        (float, "float"),
        (str, "string"),
        (list, "link"),
        (dict, "object"),
    )
    for py_type, label in labelled_types:
        if isinstance(value, py_type):
            return label
    return "unknown"
# Text-encoder classes accepted in addition to anything whose class_type
# starts with "CLIPTextEncode".
_EXTRA_TEXT_ENCODER_CLASSES = ("smZ CLIPTextEncode", "BNK_CLIPTextEncodeAdvanced")


def _find_prompt_source(workflow: dict, input_name: str) -> str | None:
    """Return the id of the text-encoder node feeding a sampler's *input_name*.

    Walks the sampler-family nodes in ``iter_nodes`` order, follows the link
    stored under ``input_name`` with ``trace_to_node``, and returns the first
    upstream node whose ``class_type`` is a recognised text encoder.
    Returns ``None`` when no sampler has such a connection.
    """
    for _nid, node in iter_nodes(workflow):
        if node["class_type"] not in SAMPLER_NODE_FAMILY:
            continue
        inputs = node.get("inputs", {}) or {}
        link = inputs.get(input_name)
        if not is_link(link):
            continue
        src = trace_to_node(workflow, link)
        if src and isinstance(workflow.get(src), dict):
            cls = workflow[src].get("class_type", "")
            if cls.startswith("CLIPTextEncode") or cls in _EXTRA_TEXT_ENCODER_CLASSES:
                return src
    return None


def find_negative_prompt_node(workflow: dict) -> str | None:
    """Trace `negative` input of a sampler back to the source text encoder."""
    return _find_prompt_source(workflow, "negative")


def find_positive_prompt_node(workflow: dict) -> str | None:
    """Trace `positive` input of a sampler back to the source text encoder."""
    return _find_prompt_source(workflow, "positive")
+ + Returns: + { + "parameters": { friendly_name: {node_id, field, type, value, ...} }, + "output_nodes": [node_id, ...], + "model_dependencies": [{node_id, class_type, field, value, folder}], + "embedding_dependencies": [{node_id, embedding_name, found_in_field, value_excerpt}], + "summary": {...} + } + """ + output_nodes: list[str] = [] + + # First pass: identify positive / negative prompt nodes via connection tracing + pos_node = find_positive_prompt_node(workflow) + neg_node = find_negative_prompt_node(workflow) + + # ----- collect raw parameter candidates ----- + # Each candidate = (friendly_name, node_id, field, value) + # We resolve duplicate friendly_names AFTER the loop so dedup is symmetric. + raw_params: list[dict] = [] + + for node_id, node in iter_nodes(workflow): + cls = node["class_type"] + inputs = node.get("inputs", {}) or {} + + if cls in OUTPUT_NODES: + output_nodes.append(node_id) + + # Match this node against PARAM_PATTERNS + for p_class, p_field, friendly in PARAM_PATTERNS: + if cls != p_class: + continue + if p_field not in inputs: + continue + value = inputs[p_field] + t = infer_type(value) + if t == "link": + continue # connections aren't directly controllable + + actual_name = friendly + + # Disambiguate prompt vs negative_prompt by connection tracing + if friendly == "prompt": + if node_id == neg_node and pos_node != neg_node: + actual_name = "negative_prompt" + elif node_id == pos_node: + actual_name = "prompt" + else: + # Fallback: use _meta.title hints if present + meta_title = (node.get("_meta") or {}).get("title", "").lower() + if any(t_ in meta_title for t_ in ("negative", "neg", "-prompt", "anti")): + actual_name = "negative_prompt" + + raw_params.append({ + "name_hint": actual_name, + "node_id": node_id, + "field": p_field, + "type": t, + "value": value, + "class_type": cls, + }) + + # ----- symmetric duplicate-name resolution ----- + # Group by name_hint. If a hint appears once, keep it. 
If multiple, suffix + # ALL with their node_id. Always-stable, always-uniquely-addressable. + by_name: dict[str, list[dict]] = {} + for r in raw_params: + by_name.setdefault(r["name_hint"], []).append(r) + + parameters: dict[str, dict] = {} + for name, entries in by_name.items(): + if len(entries) == 1: + r = entries[0] + parameters[name] = { + "node_id": r["node_id"], "field": r["field"], + "type": r["type"], "value": r["value"], + "class_type": r["class_type"], + } + else: + # Sort by node_id (string-natural) for stability + entries.sort(key=lambda x: (str(x["node_id"]).zfill(8), x["field"])) + for r in entries: + full_name = f"{name}_{r['node_id']}" + parameters[full_name] = { + "node_id": r["node_id"], "field": r["field"], + "type": r["type"], "value": r["value"], + "class_type": r["class_type"], + "alias_of": name, + } + + # ----- model dependencies ----- + model_deps = list(iter_model_deps(workflow)) + + # ----- embedding dependencies (in prompt text) ----- + embedding_deps: list[dict] = [] + seen_emb: set[tuple[str, str]] = set() + for nid, emb_name in iter_embedding_refs(workflow): + key = (nid, emb_name) + if key in seen_emb: + continue + seen_emb.add(key) + # Find which field had the reference, for context + node = workflow.get(nid, {}) + inputs = node.get("inputs", {}) or {} + found_field = None + excerpt = None + for fname, fval in inputs.items(): + if isinstance(fval, str) and fname in PROMPT_FIELDS and emb_name in fval: + found_field = fname + excerpt = fval[:120] + break + embedding_deps.append({ + "node_id": nid, + "embedding_name": emb_name, + "field": found_field, + "value_excerpt": excerpt, + "folder": "embeddings", + }) + + # ----- summary ----- + summary = { + "parameter_count": len(parameters), + "output_node_count": len(output_nodes), + "model_dep_count": len(model_deps), + "embedding_dep_count": len(embedding_deps), + "has_negative_prompt": "negative_prompt" in parameters, + "has_seed": "seed" in parameters or any(p.startswith("seed_") for p 
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: print or save the schema extracted from a workflow file.

    Args:
        argv: Arguments to parse. ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` on success; ``1`` when the workflow file is missing, is not
        valid JSON, or is rejected by ``unwrap_workflow``. Errors are written
        to stderr.
    """
    p = argparse.ArgumentParser(description="Extract controllable parameters from a ComfyUI workflow")
    p.add_argument("workflow", help="Path to workflow API JSON file")
    p.add_argument("--output", "-o", help="Output file (default: stdout)")
    p.add_argument("--summary-only", action="store_true",
                   help="Only print the summary block")
    args = p.parse_args(argv)

    wf_path = Path(args.workflow).expanduser()
    if not wf_path.exists():
        print(f"Error: {wf_path} not found", file=sys.stderr)
        return 1

    try:
        with wf_path.open(encoding="utf-8") as f:
            payload = json.load(f)
        workflow = unwrap_workflow(payload)
    except json.JSONDecodeError as e:
        # Must come before ValueError: JSONDecodeError is a ValueError
        # subclass, so the reverse order makes this handler unreachable.
        print(f"Error: invalid JSON — {e}", file=sys.stderr)
        return 1
    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    schema = extract_schema(workflow)

    if args.summary_only:
        out = json.dumps(schema["summary"], indent=2)
    else:
        out = json.dumps(schema, indent=2, default=str)

    if args.output:
        # Explicit encoding so the written file does not depend on the locale.
        Path(args.output).write_text(out, encoding="utf-8")
        print(f"Schema written to {args.output}", file=sys.stderr)
    else:
        print(out)

    return 0


if __name__ == "__main__":
    sys.exit(main())
+ +When a workflow errors, the server's /history (local) or /jobs (cloud) entry +contains the full Python traceback. This script makes it easy to fetch by +prompt_id, with sensible formatting. + +Usage: + python3 fetch_logs.py + python3 fetch_logs.py --host https://cloud.comfy.org + python3 fetch_logs.py --tail-queue # show currently queued/running jobs +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import ( # noqa: E402 + DEFAULT_LOCAL_HOST, ENV_API_KEY, emit_json, http_get, is_cloud_host, + resolve_api_key, resolve_url, +) + + +def fetch_history_entry(host: str, headers: dict, prompt_id: str, *, is_cloud: bool) -> dict: + if is_cloud: + # Try /jobs/{id} first + url = resolve_url(host, f"/jobs/{prompt_id}", is_cloud=True) + r = http_get(url, headers=headers, retries=2, timeout=30) + if r.status == 200: + try: + return {"ok": True, "entry": r.json(), "source": "/api/jobs"} + except Exception: + pass + # Fallback to history_v2 + url = resolve_url(host, f"/history/{prompt_id}", is_cloud=True) + r = http_get(url, headers=headers, retries=2, timeout=30) + try: + data = r.json() + except Exception: + data = None + if r.status == 200 and data: + return {"ok": True, "entry": data, "source": "/api/history_v2"} + return {"ok": False, "http_status": r.status, "body": r.text()[:500]} + + url = resolve_url(host, f"/history/{prompt_id}", is_cloud=False) + r = http_get(url, headers=headers, retries=2, timeout=30) + if r.status != 200: + return {"ok": False, "http_status": r.status, "body": r.text()[:500]} + try: + data = r.json() + except Exception: + return {"ok": False, "reason": "non-JSON response"} + if not isinstance(data, dict) or prompt_id not in data: + return {"ok": False, "reason": "prompt_id not found in history", + "history_keys": list(data.keys())[:5] if isinstance(data, dict) else []} + return {"ok": True, "entry": 
def extract_diagnostics(entry: dict) -> dict:
    """Pull out the parts a human cares about: status, errors, traceback, timing.

    Args:
        entry: A history/job entry as returned by the server.

    Returns:
        A dict with ``status_str``, ``completed`` and ``execution_log``;
        ``errors`` only when at least one ``execution_error`` message exists;
        ``output_node_ids`` and ``output_count`` only when ``entry`` has an
        ``outputs`` key.
    """
    diag: dict = {}

    raw_status = entry.get("status")
    if isinstance(raw_status, dict):
        status = raw_status
        diag["status_str"] = status.get("status_str")
    else:
        # Some responses carry ``status`` as a plain string (presumably the
        # cloud /jobs shape — confirm). Calling ``.get`` on it would raise
        # AttributeError, so surface the string itself as the status.
        status = {}
        diag["status_str"] = raw_status if isinstance(raw_status, str) else None
    diag["completed"] = status.get("completed")

    messages = status.get("messages") or []
    execution_log: list = []
    errors: list = []
    for msg in messages:
        if isinstance(msg, list) and len(msg) >= 2:
            mtype, mdata = msg[0], msg[1]
            execution_log.append({"type": mtype, "data": mdata})
            if mtype == "execution_error":
                errors.append(mdata)
        else:
            # Keep unrecognised message shapes verbatim rather than dropping them.
            execution_log.append(msg)
    diag["execution_log"] = execution_log
    if errors:
        diag["errors"] = errors

    # Cloud's /jobs response shape: top-level outputs / status / etc.
    if "outputs" in entry:
        out = entry["outputs"] or {}
        if isinstance(out, dict):
            diag["output_node_ids"] = list(out.keys())
            # Count file refs across all output buckets (images / video / etc.)
            diag["output_count"] = sum(
                len(bucket)
                for node_output in out.values()
                if isinstance(node_output, dict)
                for bucket in node_output.values()
                if isinstance(bucket, list)
            )
        else:
            diag["output_node_ids"] = []
            diag["output_count"] = 0
    return diag
+ +Improvements over v1: + - Multi-GPU detection: scans all NVIDIA / AMD GPUs, picks the best one (most VRAM) + - Apple Silicon: detects Rosetta-via-x86_64 false negative; warns instead of misclassifying + - Apple generation: defaults to None (unknown) instead of mis-tagging as M1 + - WSL2 detection: identifies WSL2 + nvidia-smi situation explicitly + - ROCm: prefers `rocm-smi --json` for new ROCm 6.x output + - Disk space check: warns if /home or workspace volume has < 25 GB free + - PyTorch verification (optional): tries to import torch and check device availability + - Windows: prefers PowerShell `Get-CimInstance` over deprecated `wmic` + - More accurate VRAM thresholds and verdict reasons + +Emits a structured JSON report. Exit codes match `verdict`: + 0 → ok + 1 → marginal + 2 → cloud + +Usage: + python3 hardware_check.py [--json] [--check-pytorch] +""" + +from __future__ import annotations + +import json +import os +import platform +import re +import shutil +import subprocess +import sys +from typing import Any + + +# Thresholds (GiB). 
def detect_nvidia() -> dict | None:
    """Detect NVIDIA GPUs. Returns the GPU with the most VRAM, plus list of all.

    Queries ``nvidia-smi`` for index, name, total memory and driver version.

    Returns:
        ``None`` when ``nvidia-smi`` is not on PATH, prints nothing, or yields
        no parsable GPU line. Otherwise a dict with ``vendor``, ``index``,
        ``name``, ``vram_gb`` and ``driver`` for the GPU with the most VRAM.
        When several GPUs are present the dict also has ``all_gpus``, the list
        of every detected GPU.
    """
    if not shutil.which("nvidia-smi"):
        return None
    out = _run([
        "nvidia-smi",
        "--query-gpu=index,name,memory.total,driver_version",
        "--format=csv,noheader,nounits",
    ])
    if not out.strip():
        return None
    gpus: list[dict] = []
    for line in out.strip().splitlines():
        parts = [p.strip() for p in line.split(",")]
        if len(parts) < 3:
            continue
        try:
            idx = int(parts[0])
            name = parts[1]
            vram_mb = int(parts[2])
        except ValueError:
            # Not a GPU row (e.g. a warning line, or a non-numeric memory field).
            continue
        driver = parts[3] if len(parts) > 3 else ""
        gpus.append({
            "vendor": "nvidia",
            "index": idx,
            "name": name,
            "vram_gb": round(vram_mb / 1024, 1),
            "driver": driver,
        })
    if not gpus:
        return None
    # Pick GPU with most VRAM. Copy it before attaching ``all_gpus``: the
    # chosen dict is itself an element of ``gpus``, so attaching the list to
    # the original would make the structure contain itself and json.dumps
    # would fail with "Circular reference detected".
    best = dict(max(gpus, key=lambda g: g["vram_gb"]))
    if len(gpus) > 1:
        best["all_gpus"] = gpus
    return best
def detect_apple_silicon() -> dict | None:
    """Describe the Apple Silicon chip this interpreter runs on, if any.

    Returns:
        ``None`` unless the platform is Darwin on ``arm64``. Otherwise a dict
        with ``vendor``, ``name`` (the sysctl CPU brand string),
        ``generation`` (the N in "Apple M<N>", or ``None`` when not found),
        ``variant`` ("Ultra", "Max", "Pro" or ``None``) and
        ``unified_memory_gb``.
    """
    if (platform.system(), platform.machine()) != ("Darwin", "arm64"):
        return None

    brand = _run(["sysctl", "-n", "machdep.cpu.brand_string"]).strip()
    gen_match = re.search(r"Apple M(\d+)", brand)

    try:
        total_bytes = int(_run(["sysctl", "-n", "hw.memsize"]).strip() or 0)
    except ValueError:
        total_bytes = 0

    # Chip variant ("Pro", "Max", "Ultra") — affects performance even at same gen
    tier = next((tag for tag in ("Ultra", "Max", "Pro") if tag in brand), None)

    return {
        "vendor": "apple",
        "name": brand or "Apple Silicon",
        "generation": int(gen_match.group(1)) if gen_match else None,
        "variant": tier,
        "unified_memory_gb": round(total_bytes / (1024**3), 1) if total_bytes else 0.0,
    }
def classify(gpu: dict | None, ram_gb: float, free_disk_gb: float, *, wsl: bool, rosetta: bool) -> tuple[str, str, list[str]]:
    """Map detected hardware facts to an install recommendation.

    Args:
        gpu: Detected accelerator dict (reads ``vendor``, ``name``,
            ``vram_gb``, ``generation``, ``variant``, ``unified_memory_gb``)
            or ``None`` when nothing was found.
        ram_gb: Host RAM in GB; ``0`` disables the low-RAM note.
        free_disk_gb: Free disk space in GB; ``0`` disables the disk note.
        wsl: Whether the process runs under WSL.
        rosetta: Whether Python runs under Rosetta.

    Returns:
        ``(verdict, install_path, notes)`` where verdict is ``"ok"``,
        ``"marginal"`` or ``"cloud"``, install_path is ``"comfy-cloud"``,
        ``"apple-silicon"``, ``"intel"`` or the GPU vendor string, and notes
        is the list of human-readable remarks collected along the way.
    """
    remarks: list[str] = []
    to_cloud = ("cloud", "comfy-cloud")

    # Rosetta wins over every other consideration.
    if rosetta:
        remarks.append(
            "Detected Python running under Rosetta on Apple Silicon. "
            "ComfyUI MPS support requires native ARM64 Python — install via "
            "`brew install python` or arm64 Miniforge, then re-run."
        )
        return (*to_cloud, remarks)

    # Advisory remarks: these never change the verdict on their own.
    if wsl and gpu and gpu["vendor"] == "nvidia":
        remarks.append("Detected WSL2 + NVIDIA — confirm `nvidia-smi` works in your WSL distro before installing.")

    if free_disk_gb and free_disk_gb < MIN_FREE_DISK_GB:
        remarks.append(
            f"Free disk space ({free_disk_gb} GB) is below the {MIN_FREE_DISK_GB} GB recommended minimum. "
            "ComfyUI core (~5 GB) plus one SDXL model (~6.5 GB) needs space; Flux Dev needs ~24 GB."
        )

    # Host RAM is relevant for discrete GPUs too (weights are swapped through
    # it). Apple's unified memory gets its own check further down, so it is
    # excluded here to avoid a duplicate warning.
    if ram_gb and ram_gb < 8 and gpu and gpu.get("vendor") != "apple":
        remarks.append(
            f"System RAM ({ram_gb} GB) is low. ComfyUI swaps model weights through "
            "host RAM; <8 GB causes severe slowdowns. 16+ GB recommended."
        )

    if gpu is None:
        remarks += [
            "No supported accelerator found (NVIDIA CUDA / AMD ROCm / Apple Silicon / Intel Arc).",
            "CPU-only ComfyUI works but is unusably slow for modern models — use Comfy Cloud.",
        ]
        return (*to_cloud, remarks)

    if gpu["vendor"] == "apple":
        generation = gpu.get("generation")
        variant = gpu.get("variant")
        mem = gpu.get("unified_memory_gb", 0.0)
        label_parts = [f"M{generation}" if generation else "Apple Silicon"]
        if variant:
            label_parts.append(f"{variant}")
        label = " ".join(label_parts)
        if mem < MIN_MAC_RAM_GB:
            remarks += [
                f"{label} with {mem} GB unified memory — below the {MIN_MAC_RAM_GB} GB practical minimum.",
                "SD1.5 may work; SDXL/Flux will swap or OOM. Recommend Comfy Cloud.",
            ]
            return (*to_cloud, remarks)
        if mem < OK_MAC_RAM_GB:
            remarks.append(f"{label} with {mem} GB — SDXL works but slow. Flux/video likely too tight.")
            return "marginal", "apple-silicon", remarks
        remarks.append(f"{label} with {mem} GB unified memory — good for SDXL/Flux.")
        return "ok", "apple-silicon", remarks

    if gpu["vendor"] == "intel":
        remarks.append("Intel Arc detected — ComfyUI IPEX support is experimental; Comfy Cloud is more reliable.")
        return "marginal", "intel", remarks

    # Discrete NVIDIA / AMD: tier purely by VRAM.
    vram = gpu.get("vram_gb", 0.0)
    model = gpu["name"]
    if vram < MIN_VRAM_GB_USABLE:
        remarks += [
            f"{model} has only {vram} GB VRAM — below the {MIN_VRAM_GB_USABLE} GB practical minimum.",
            "Most modern models won't load. Recommend Comfy Cloud.",
        ]
        return (*to_cloud, remarks)
    if vram < OK_VRAM_GB:
        verdict = "marginal"
        remark = f"{model} ({vram} GB VRAM) — SD1.5 works, SDXL tight, Flux/video unlikely."
    elif vram < GREAT_VRAM_GB:
        verdict = "ok"
        remark = f"{model} ({vram} GB VRAM) — SDXL comfortable, Flux possible with optimizations."
    else:
        verdict = "ok"
        remark = f"{model} ({vram} GB VRAM) — can run everything including Flux/video."
    remarks.append(remark)
    return verdict, gpu["vendor"], remarks
Use Comfy Cloud.", + ] + report = { + "os": sysname, + "arch": arch, + "system_ram_gb": ram_gb, + "free_disk_gb": free_disk_gb, + "wsl": False, + "rosetta": False, + "gpu": None, + "verdict": "cloud", + "recommended_install_path": "comfy-cloud", + "comfy_cli_flag": None, + "notes": notes, + "install_urls": _install_urls(), + } + if check_pytorch: + report["pytorch"] = check_pytorch_cuda() + return report + + verdict, install_path, notes = classify( + gpu, ram_gb, free_disk_gb, wsl=wsl, rosetta=rosetta, + ) + + report = { + "os": sysname, + "arch": arch, + "system_ram_gb": ram_gb, + "free_disk_gb": free_disk_gb, + "wsl": wsl, + "rosetta": rosetta, + "gpu": gpu, + "verdict": verdict, + "recommended_install_path": install_path, + "comfy_cli_flag": _COMFY_CLI_FLAG.get(install_path), + "notes": notes, + "install_urls": _install_urls(), + } + if check_pytorch: + report["pytorch"] = check_pytorch_cuda() + return report + + +def _install_urls() -> dict: + return { + "desktop": "https://docs.comfy.org/installation/desktop", + "manual": "https://docs.comfy.org/installation/manual_install", + "comfy_cli": "https://docs.comfy.org/comfy-cli/getting-started", + "cloud": "https://platform.comfy.org", + } + + +def main(argv: list[str] | None = None) -> int: + import argparse + p = argparse.ArgumentParser(description="Check whether this machine can run ComfyUI locally.") + p.add_argument("--json", action="store_true", help="Emit machine-readable JSON only") + p.add_argument("--check-pytorch", action="store_true", + help="Also probe `torch` for CUDA/MPS availability (slower)") + args = p.parse_args(argv) + + report = build_report(check_pytorch=args.check_pytorch) + + if args.json: + print(json.dumps(report, indent=2)) + else: + print(f"OS: {report['os']} ({report['arch']})") + if report.get("wsl"): + print("Env: WSL2") + if report.get("rosetta"): + print("Env: Rosetta (x86_64 Python on Apple Silicon)") + print(f"RAM: {report['system_ram_gb']} GB") + print(f"Free disk: 
{report['free_disk_gb']} GB (~/)") + if report["gpu"]: + g = report["gpu"] + if g["vendor"] == "apple": + print(f"GPU: {g['name']} — {g.get('unified_memory_gb', 0)} GB unified memory") + else: + print(f"GPU: {g['name']} — {g.get('vram_gb', 0)} GB VRAM") + if g.get("all_gpus") and len(g["all_gpus"]) > 1: + print(f" ({len(g['all_gpus'])} GPUs total; using best by VRAM)") + else: + print("GPU: (none detected)") + print(f"Verdict: {report['verdict']} → {report['recommended_install_path']}") + if report["comfy_cli_flag"]: + print(f" run: comfy --skip-prompt install {report['comfy_cli_flag']}") + if report.get("pytorch"): + pt = report["pytorch"] + if pt.get("available"): + line = f"PyTorch: {pt.get('torch_version')}" + if pt.get("cuda_available"): + line += f" + CUDA ({pt.get('cuda_device_0', '?')})" + if pt.get("mps_available"): + line += " + MPS" + print(line) + else: + print(f"PyTorch: not available — {pt.get('reason')}") + for n in report["notes"]: + print(f" • {n}") + + if report["verdict"] == "ok": + return 0 + if report["verdict"] == "marginal": + return 1 + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/skills/creative/comfyui/scripts/health_check.py b/skills/creative/comfyui/scripts/health_check.py new file mode 100755 index 0000000000..63c5025ca9 --- /dev/null +++ b/skills/creative/comfyui/scripts/health_check.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +health_check.py — One-stop verification that the ComfyUI environment is ready. + +Runs through the verification checklist: + 1. comfy-cli on PATH + 2. server reachable (/system_stats) + 3. at least one checkpoint installed + 4. (optional) a specific workflow's deps are met + 5. 
def comfy_cli_status() -> dict:
    """Report how (or whether) comfy-cli can be invoked on this machine.

    A ``comfy`` executable on PATH is preferred and its path is returned.
    Failing that, the presence of ``uvx`` counts as available and a usage
    hint is returned. Otherwise the result carries an install hint.
    """
    comfy_path = shutil.which("comfy")
    if comfy_path:
        return {"available": True, "method": "comfy", "path": comfy_path}

    if not shutil.which("uvx"):
        return {
            "available": False,
            "hint": "Install with: pipx install comfy-cli (or `pip install comfy-cli`)",
        }

    return {
        "available": True,
        "method": "uvx",
        "hint": "Invoke as `uvx --from comfy-cli comfy ...`",
    }
def smoke_test(host: str, headers: dict, ckpt_name: str | None) -> dict:
    """Submit a tiny workflow and verify the server accepts it.

    Cancels the job immediately after acceptance so we don't burn GPU
    time / cloud minutes on a smoke test.

    Args:
        host: Server base URL passed to ``ComfyRunner``.
        headers: Request headers; only ``X-API-Key`` is read from it.
        ckpt_name: Checkpoint to load in the test workflow. When falsy the
            test is not run.

    Returns:
        ``{"ran": False, "reason": ...}`` when no checkpoint is available.
        Otherwise a dict with ``ran: True`` and ``submitted`` True/False.
        Failed submissions carry ``http_status``/``body``, ``response`` or
        ``error``. Accepted ones carry ``prompt_id``,
        ``cancelled_after_submit`` and, if cancelling raised, ``cancel_error``.
    """
    if not ckpt_name:
        return {"ran": False, "reason": "no checkpoint available"}
    # JSON round-trip gives a deep copy so the module-level template stays pristine.
    wf = json.loads(json.dumps(SMOKE_WORKFLOW))
    wf["4"]["inputs"]["ckpt_name"] = ckpt_name

    # Lazy import to avoid circular issues
    from run_workflow import ComfyRunner
    api_key = headers.get("X-API-Key")
    runner = ComfyRunner(host=host, api_key=api_key)
    try:
        sub = runner.submit(wf)
    except Exception as e:
        # A transport failure here is a failed smoke test, not a reason to
        # crash the whole health check with a traceback.
        return {"ran": True, "submitted": False, "error": str(e)}
    if "_http_error" in sub:
        return {"ran": True, "submitted": False,
                "http_status": sub["_http_error"], "body": sub.get("body")}
    pid = sub.get("prompt_id")
    if not pid:
        return {"ran": True, "submitted": False, "response": sub}

    # Cancel so we don't actually waste compute on the smoke test.
    cancelled = False
    cancel_error = None
    try:
        cancelled = runner.cancel(pid)
    except Exception as e:
        # Cancelling is best-effort, but surface why it failed so the user
        # knows a job may still be running.
        cancel_error = str(e)

    result = {
        "ran": True, "submitted": True, "prompt_id": pid,
        "cancelled_after_submit": cancelled,
        "note": "Submission accepted; cancelled to avoid running the full pipeline.",
    }
    if cancel_error is not None:
        result["cancel_error"] = cancel_error
    return result
and ckpts.get("queryable") and ckpts.get("count", 0) == 0: + verdict = "warn" if verdict == "pass" else verdict + reasons.append("no checkpoints installed") + if workflow_check and workflow_check.get("error"): + verdict = "fail" + reasons.append(f"workflow check failed: {workflow_check['error']}") + elif workflow_check and not workflow_check.get("is_ready"): + if workflow_check.get("node_check_skipped"): + reasons.append("node check skipped (cloud free tier)") + else: + verdict = "fail" + reasons.append("workflow has missing deps") + if smoke and smoke.get("ran") and not smoke.get("submitted"): + verdict = "fail" + reasons.append("smoke-test submission failed") + if not cli.get("available"): + verdict = "warn" if verdict == "pass" else verdict + reasons.append("comfy-cli not on PATH (lifecycle commands won't work)") + + report = { + "verdict": verdict, + "reasons": reasons, + "host": args.host, + "comfy_cli": cli, + "server": server, + "checkpoints": ckpts, + "workflow_check": workflow_check, + "smoke_test": smoke, + } + emit_json(report) + + if verdict == "pass": + return 0 + if verdict == "warn": + return 1 if args.strict else 0 + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/skills/creative/comfyui/scripts/run_batch.py b/skills/creative/comfyui/scripts/run_batch.py new file mode 100755 index 0000000000..7f5b159dbd --- /dev/null +++ b/skills/creative/comfyui/scripts/run_batch.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 +""" +run_batch.py — Run a workflow many times, varying parameters per run. + +Two modes: + 1. --count N --randomize-seed + Submit N runs, each with a fresh random seed. Use for quick variations. + 2. --sweep '{"seed": [1,2,3], "steps": [20,30]}' + Cartesian product of values. With cloud subscription, runs in parallel + up to your tier's concurrent-job limit. + +Both modes write each run's outputs into output-dir/run_NNN/. 
def expand_sweep(sweep: dict, base_args: dict, count: int, randomize_seed: bool) -> list[dict]:
    """Build the per-run argument dicts for a batch.

    With a non-empty ``sweep`` the result is the Cartesian product of its
    value lists (a non-list value counts as a single-element list), each
    combination layered over a copy of ``base_args``. ``count`` and
    ``randomize_seed`` are ignored in that case.

    Otherwise ``count`` copies of ``base_args`` are produced, each given a
    fresh ``seed`` from ``coerce_seed(None)`` when ``randomize_seed`` is set.
    """
    if sweep:
        names = list(sweep)
        axes = [val if isinstance(val, list) else [val] for val in sweep.values()]
        return [
            {**base_args, **dict(zip(names, combo))}
            for combo in itertools.product(*axes)
        ]

    planned: list[dict] = []
    for _ in range(count):
        run_args = dict(base_args)
        if randomize_seed:
            run_args["seed"] = coerce_seed(None)
        planned.append(run_args)
    return planned
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: plan the runs, execute them, emit a JSON summary.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` when every executed run succeeded, ``1`` on a usage/setup error
        or when at least one run failed.
    """
    p = argparse.ArgumentParser(
        description="Submit a workflow many times with varying parameters.",
    )
    p.add_argument("--workflow", required=True)
    p.add_argument("--args", default="{}", help="Base parameters JSON")
    p.add_argument("--count", type=int, default=0,
                   help="Number of runs (use with --randomize-seed)")
    p.add_argument("--sweep", default="",
                   help='JSON dict of param→list of values. Cartesian product. '
                        'e.g. \'{"seed":[1,2,3],"cfg":[5,8]}\'')
    p.add_argument("--randomize-seed", action="store_true",
                   help="In --count mode, vary seed per run")
    p.add_argument("--host", default=DEFAULT_LOCAL_HOST)
    p.add_argument("--api-key", help=f"or set ${ENV_API_KEY}")
    p.add_argument("--partner-key")
    p.add_argument("--parallel", type=int, default=1,
                   help="Concurrent submissions (cloud: up to your tier limit). "
                        "Default 1 (sequential)")
    p.add_argument("--output-dir", default="./outputs/batch")
    p.add_argument("--timeout", type=int, default=0)
    p.add_argument("--ws", action="store_true")
    p.add_argument("--continue-on-error", action="store_true",
                   help="Don't stop the batch when a run fails")
    args = p.parse_args(argv)

    if args.count <= 0 and not args.sweep:
        emit_json({"error": "Specify --count N or --sweep '{...}'"})
        return 1

    # Malformed JSON on the command line is a user error: report it in the
    # same JSON envelope as every other failure instead of a traceback.
    try:
        base_args = json.loads(args.args) if args.args.strip() else {}
        sweep = json.loads(args.sweep) if args.sweep.strip() else {}
    except json.JSONDecodeError as e:
        emit_json({"error": f"Invalid JSON in --args / --sweep: {e}"})
        return 1
    if not isinstance(base_args, dict):
        emit_json({"error": "--args must be a JSON object {param: value}"})
        return 1
    if not isinstance(sweep, dict):
        emit_json({"error": "--sweep must be a JSON object {param: [values]}"})
        return 1

    # Validate sweep shape
    if sweep:
        empty = [k for k, v in sweep.items() if isinstance(v, list) and len(v) == 0]
        if empty:
            emit_json({"error": f"--sweep parameters have empty value lists: {empty}"})
            return 1
        # If user passed BOTH --sweep and --count/--randomize-seed, --sweep wins
        if args.count or args.randomize_seed:
            log("--sweep set; ignoring --count / --randomize-seed (sweep defines the runs)")

    wf_path = Path(args.workflow).expanduser()
    if not wf_path.exists():
        emit_json({"error": f"Workflow not found: {args.workflow}"})
        return 1
    try:
        with wf_path.open() as f:
            workflow = unwrap_workflow(json.load(f))
    except (ValueError, json.JSONDecodeError) as e:
        emit_json({"error": str(e)})
        return 1

    schema = extract_schema(workflow)
    runs = expand_sweep(sweep, base_args, args.count, args.randomize_seed)
    log(f"Planned {len(runs)} run(s)")

    api_key = resolve_api_key(args.api_key)
    runner = ComfyRunner(host=args.host, api_key=api_key, partner_key=args.partner_key)

    ok, info = runner.check_server()
    if not ok:
        emit_json({"error": "Cannot reach server", "details": info, "host": args.host})
        return 1

    # A non-positive --timeout means "pick one": longer for video workflows.
    timeout = args.timeout
    if timeout <= 0:
        timeout = 900 if looks_like_video_workflow(workflow) else 300

    base_dir = Path(args.output_dir).expanduser()
    base_dir.mkdir(parents=True, exist_ok=True)

    results: list[dict] = []
    failures = 0

    def record(i: int, r: dict) -> bool:
        """Store one run's result; return True when the batch must stop."""
        nonlocal failures
        r["index"] = i
        results.append(r)
        if r["status"] == "success":
            log(f" run {i} → success: {len(r.get('outputs', []))} files")
            return False
        failures += 1
        log(f" run {i} → {r['status']}: {r.get('error','?')}")
        if args.continue_on_error:
            return False
        log(" --continue-on-error not set; aborting batch")
        return True

    if args.parallel > 1:
        with ThreadPoolExecutor(max_workers=args.parallel) as ex:
            future_to_idx = {}
            for i, ar in enumerate(runs):
                run_dir = base_dir / f"run_{i:04d}"
                fut = ex.submit(
                    execute_one, runner, workflow, schema, ar,
                    output_dir=run_dir, timeout=timeout, ws=args.ws,
                )
                future_to_idx[fut] = i
            for fut in as_completed(future_to_idx):
                i = future_to_idx[fut]
                try:
                    r = fut.result()
                except Exception as e:
                    r = {"status": "error", "error": str(e), "args": runs[i]}
                if record(i, r):
                    # Leaving the loop alone does not stop queued work: the
                    # executor would still run every pending future on exit.
                    # Cancel what has not started; runs already in flight
                    # cannot be interrupted and are waited for.
                    for pending in future_to_idx:
                        pending.cancel()
                    break
    else:
        for i, ar in enumerate(runs):
            run_dir = base_dir / f"run_{i:04d}"
            r = execute_one(runner, workflow, schema, ar,
                            output_dir=run_dir, timeout=timeout, ws=args.ws)
            if record(i, r):
                break

    # Parallel completion order is arbitrary; report in planned order.
    results.sort(key=lambda x: x.get("index", 0))
    emit_json({
        "status": "success" if failures == 0 else "partial",
        "total": len(runs),
        "completed": sum(1 for r in results if r["status"] == "success"),
        "failed": failures,
        "output_dir": str(base_dir),
        "results": results,
    })
    return 0 if failures == 0 else 1
b/skills/creative/comfyui/scripts/run_workflow.py @@ -0,0 +1,796 @@ +#!/usr/bin/env python3 +""" +run_workflow.py — Inject parameters into a ComfyUI workflow, submit it, monitor +execution, and download outputs. + +Improvements over v1: + - Cloud-aware URL routing (handles /api prefix and /history_v2 / /experiment/models renames) + - API key from CLI flag OR $COMFY_CLOUD_API_KEY env var + - WebSocket progress monitoring (--ws), with HTTP polling fallback + - Streaming download (no whole-file buffering — handles GB-size video outputs) + - Path-traversal-safe output writes + - Subfolder-aware download paths (no silent overwrites) + - Retry with exponential backoff on transient errors + - Status-error correctly classified before "completed: true" + - Image upload helper (--input-image NAME=PATH) + - Auto-randomize seed when value is -1 or omitted on a randomize-seed flag + - Auto-extends timeout heuristically for video workflows + - Editor-format detection with helpful error + - Doesn't pollute extra_data.api_key_comfy_org with the cloud auth key + unless --partner-key is provided (correct semantic per cloud docs) + +Usage: + # Local server + python3 run_workflow.py --workflow workflow_api.json \ + --args '{"prompt": "a cat", "seed": 42}' \ + --output-dir ./outputs + + # Cloud server (API key from env var) + export COMFY_CLOUD_API_KEY="comfyui-xxxxxxx" + python3 run_workflow.py --workflow workflow_api.json \ + --args '{"prompt": "a cat"}' \ + --host https://cloud.comfy.org \ + --output-dir ./outputs + + # With image input (auto-uploads, then references) + python3 run_workflow.py --workflow img2img.json \ + --input-image image=./photo.png \ + --args '{"prompt": "make it cyberpunk"}' + + # WebSocket real-time progress + python3 run_workflow.py --workflow flux_dev.json \ + --args '{"prompt": "..."}' \ + --ws + +Stdlib-only by default (Python 3.10+). Will use `requests`/`websocket-client` +if installed for nicer behavior. 
+""" + +from __future__ import annotations + +import argparse +import copy +import json +import sys +import time +from pathlib import Path +from typing import Any +from urllib.parse import urlencode, urlparse + +# Local import — _common.py sits next to this script. +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from _common import ( # noqa: E402 + DEFAULT_LOCAL_HOST, ENV_API_KEY, + coerce_seed, emit_json, http_get, http_post, http_request, + is_cloud_host, is_link, log, looks_like_video_workflow, + media_type_from_filename, new_client_id, resolve_api_key, resolve_url, + safe_path_join, unwrap_workflow, +) + + +# ============================================================================= +# Runner +# ============================================================================= + +class WorkflowRunError(Exception): + """Raised when a workflow run fails (validation, execution, timeout).""" + + def __init__(self, status: str, message: str, **details: Any): + super().__init__(message) + self.status = status + self.message = message + self.details = details + + def to_dict(self) -> dict: + d = {"status": self.status, "error": self.message} + d.update(self.details) + return d + + +class ComfyRunner: + def __init__( + self, + host: str = DEFAULT_LOCAL_HOST, + api_key: str | None = None, + client_id: str | None = None, + partner_key: str | None = None, + ): + self.host = host.rstrip("/") + self.api_key = api_key + self.partner_key = partner_key + self.is_cloud = is_cloud_host(self.host) + self.client_id = client_id or new_client_id() + + @property + def headers(self) -> dict[str, str]: + h: dict[str, str] = {} + if self.api_key: + h["X-API-Key"] = self.api_key + return h + + def _url(self, path: str) -> str: + return resolve_url(self.host, path, is_cloud=self.is_cloud) + + # ---------- server health ---------- + def check_server(self) -> tuple[bool, dict | None]: + try: + r = http_get(self._url("/system_stats"), headers=self.headers, retries=2) + if r.status 
== 200: + try: + return True, r.json() + except Exception: + return True, None + return False, {"http_status": r.status, "body": r.text()[:500]} + except Exception as e: + return False, {"error": str(e)} + + # ---------- upload ---------- + def upload_image(self, path: Path, *, image_type: str = "input", overwrite: bool = True, + endpoint: str = "/upload/image", extra_form: dict | None = None) -> dict: + """Upload an image file via multipart. Returns server-side ref dict.""" + if not path.exists(): + raise FileNotFoundError(f"input image not found: {path}") + # Stream the file via a handle to avoid OOM on huge inputs (16MP+ photos). + with path.open("rb") as fh: + files = {"image": (path.name, fh)} + form = {"type": image_type} + if overwrite: + form["overwrite"] = "true" + if extra_form: + form.update({k: str(v) for k, v in extra_form.items()}) + r = http_request( + "POST", self._url(endpoint), + headers=self.headers, files=files, form=form, + timeout=300, retries=2, + ) + if r.status != 200: + raise WorkflowRunError( + "upload_failed", + f"Upload of {path.name} failed: HTTP {r.status}", + body=r.text()[:500], + ) + try: + return r.json() + except Exception: + return {"name": path.name} + + def upload_mask(self, path: Path, original_ref: dict) -> dict: + """Upload an inpaint mask, linked to a previously uploaded source image. + + `original_ref` should be the dict returned by `upload_image()` for the + source image (or `{"filename": ..., "subfolder": ..., "type": "input"}`). 
+ """ + return self.upload_image( + path, + endpoint="/upload/mask", + extra_form={ + "subfolder": "clipspace", + "original_ref": json.dumps(original_ref), + }, + ) + + # ---------- submit ---------- + def submit(self, workflow: dict) -> dict: + payload: dict[str, Any] = {"prompt": workflow, "client_id": self.client_id} + if self.partner_key: + payload["extra_data"] = {"api_key_comfy_org": self.partner_key} + + r = http_post(self._url("/prompt"), headers=self.headers, json_body=payload, timeout=120) + try: + body = r.json() + except Exception: + body = {"raw": r.text()[:500]} + if r.status != 200: + return {"_http_error": r.status, "body": body} + return body + + # ---------- HTTP polling ---------- + def poll_status(self, prompt_id: str, *, timeout: float = 300.0, + initial_interval: float = 1.5, max_interval: float = 8.0) -> dict: + start = time.time() + interval = initial_interval + + while time.time() - start < timeout: + if self.is_cloud: + r = http_get( + self._url(f"/job/{prompt_id}/status"), + headers=self.headers, retries=2, timeout=30, + ) + if r.status == 200: + try: + data = r.json() + except Exception: + data = {} + s = data.get("status") + if s == "completed": + return {"status": "success", "data": data} + if s in ("failed",): + return {"status": "error", "data": data} + if s == "cancelled": + return {"status": "cancelled", "data": data} + # pending / in_progress → continue + elif r.status == 404: + # Cloud sometimes 404s briefly between submit and dispatcher pickup + pass + else: + # transient error — retry loop covers it + pass + else: + # Local: /history/{id} grows once execution completes + r = http_get( + self._url(f"/history/{prompt_id}"), + headers=self.headers, retries=2, timeout=30, + ) + if r.status == 200: + try: + data = r.json() or {} + except Exception: + data = {} + entry = data.get(prompt_id) + if isinstance(entry, dict): + st = entry.get("status") or {} + # IMPORTANT: check error first — `completed: true` can coexist with errors + 
def monitor_ws(self, prompt_id: str, *, timeout: float = 300.0,
               on_progress: Any = None) -> dict:
    """Follow a job over the server's ``/ws`` WebSocket until it finishes.

    Falls back to HTTP polling (``poll_status``) if `websocket-client` is not
    installed.

    Args:
        prompt_id: Job to follow. Messages carrying a different ``prompt_id``
            are ignored; messages without one are processed.
        timeout: Overall budget in seconds, covering the connect and every
            read. Each read is bounded by what is left of the budget.
        on_progress: Optional callable invoked with a dict for every
            ``progress``, ``progress_state``, ``executing`` and
            ``notification`` message. Non-callables are ignored.

    Returns:
        ``{"status": "success", "outputs": {...}}`` with the per-node outputs
        collected from ``executed`` messages;
        ``{"status": "error", "data": {...}}`` after ``execution_error`` or
        ``execution_interrupted``; or
        ``{"status": "timeout", "elapsed": <seconds actually waited>}``.

    Raises:
        Whatever `websocket-client` raises for a failed connect or a
        connection closed mid-run; only read timeouts are translated into the
        ``timeout`` result.
    """
    try:
        import websocket  # type: ignore[import-not-found]
    except ImportError:
        log("websocket-client not installed; falling back to HTTP polling")
        return self.poll_status(prompt_id, timeout=timeout)

    # Build WS URL. Preserve any base-path components the user gave us
    # (e.g. http://example.com/comfyui → ws://example.com/comfyui/ws).
    parsed = urlparse(self.host)
    scheme = "wss" if parsed.scheme == "https" else "ws"
    base_path = parsed.path.rstrip("/")
    query = {"clientId": self.client_id}
    if self.is_cloud and self.api_key:
        query["token"] = self.api_key
    # urlencode so ids/keys containing '+', '&', '=' … survive the query string.
    ws_url = f"{scheme}://{parsed.netloc}{base_path}/ws?{urlencode(query)}"

    notify = on_progress if callable(on_progress) else None
    outputs: dict[str, Any] = {}
    error_payload: dict[str, Any] | None = None
    success = False
    seen_executed = False

    started = time.time()
    deadline = started + timeout
    ws = websocket.create_connection(ws_url, timeout=timeout)
    try:
        while True:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            # Bound each read by the remaining budget; otherwise a quiet
            # socket could block for a full `timeout` past the deadline.
            ws.settimeout(remaining)
            try:
                msg = ws.recv()
            except websocket.WebSocketTimeoutException:
                break  # reported as {"status": "timeout"} below

            if isinstance(msg, bytes):
                # Binary preview frame — ignore for now; ws_monitor.py prints them
                continue
            try:
                payload = json.loads(msg)
            except Exception:
                continue
            if not isinstance(payload, dict):
                continue
            mtype = payload.get("type", "")
            mdata = payload.get("data") or {}
            if not isinstance(mdata, dict):
                mdata = {}

            # Filter to our job (cloud broadcasts; local filters via client_id)
            pid = mdata.get("prompt_id")
            if pid is not None and pid != prompt_id:
                continue

            if mtype == "progress":
                if notify:
                    notify({
                        "type": "progress",
                        "value": mdata.get("value"),
                        "max": mdata.get("max"),
                        "node": mdata.get("node"),
                    })
            elif mtype == "progress_state":
                if notify:
                    notify({"type": "progress_state", "nodes": mdata.get("nodes", {})})
            elif mtype == "executing":
                node = mdata.get("node")
                if notify:
                    notify({"type": "executing", "node": node})
                # When `node` is None on a local server, that signals end-of-run.
                # Only trusted once at least one `executed` message was seen.
                if node is None and not self.is_cloud and seen_executed:
                    success = True
                    break
            elif mtype == "executed":
                seen_executed = True
                nid = mdata.get("node")
                if nid:
                    outputs[nid] = mdata.get("output") or {}
            elif mtype == "notification":
                if notify:
                    notify({"type": "notification", "message": mdata.get("value", "")})
            elif mtype == "execution_success":
                success = True
                break
            elif mtype == "execution_error":
                error_payload = mdata
                break
            elif mtype == "execution_interrupted":
                error_payload = {"interrupted": True, **mdata}
                break
    finally:
        try:
            ws.close()
        except Exception:
            pass  # best-effort close; the result is already decided

    if error_payload is not None:
        return {"status": "error", "data": error_payload}
    if success:
        return {"status": "success", "outputs": outputs}
    return {"status": "timeout", "elapsed": time.time() - started}
def cancel(self, prompt_id: str | None = None) -> bool:
    """Ask the server to drop a queued job or interrupt the running one.

    Args:
        prompt_id: When given, a ``{"delete": [prompt_id]}`` request is sent
            to ``/queue``. When empty or ``None``, ``/interrupt`` is called
            instead.

    Returns:
        True when the server answered with HTTP 200.
    """
    if not prompt_id:
        # No id: interrupt whatever is currently running.
        resp = http_post(self._url("/interrupt"), headers=self.headers, retries=1)
    else:
        resp = http_post(
            self._url("/queue"), headers=self.headers,
            json_body={"delete": [prompt_id]}, retries=1,
        )
    return resp.status == 200
def inject_params(
    workflow: dict, schema: dict, args: dict,
    *, randomize_seed_if_unset: bool = False,
) -> tuple[dict, list[str]]:
    """Inject user args into a copy of the workflow.

    Args:
        workflow: Workflow mapping of node id -> node dict; never mutated.
        schema: Dict whose ``"parameters"`` entry maps a parameter name to
            ``{"node_id": ..., "field": ...}``.
        args: Parameter name -> value supplied by the user; never mutated.
        randomize_seed_if_unset: When the schema declares ``seed`` and
            ``args`` has none, generate one via ``coerce_seed(None)``.

    Returns:
        ``(new_workflow, warnings)``. Problems (unknown parameter, malformed
        schema entry, missing node, target field currently wired to a link)
        are reported as warning strings and the parameter is skipped; nothing
        is raised for them.
    """
    wf = copy.deepcopy(workflow)
    params = schema.get("parameters", {}) or {}
    warnings: list[str] = []

    # Auto-randomize seed when it's -1 in args, or when randomize_seed_if_unset
    # and user didn't pass a seed.
    if "seed" in params:
        if "seed" in args and args["seed"] in (None, -1, "-1"):
            args = {**args, "seed": coerce_seed(args["seed"])}
            warnings.append(f"seed=-1 expanded to {args['seed']}")
        elif randomize_seed_if_unset and "seed" not in args:
            args = {**args, "seed": coerce_seed(None)}
            warnings.append(f"seed auto-randomized to {args['seed']}")

    for name, value in args.items():
        if name not in params:
            warnings.append(f"unknown parameter '{name}' (not in schema), skipping")
            continue

        entry = params[name]
        if not isinstance(entry, dict) or "node_id" not in entry or "field" not in entry:
            # A hand-edited schema should produce a warning, not a KeyError.
            warnings.append(
                f"schema entry for parameter '{name}' is malformed "
                f"(needs 'node_id' and 'field'), skipping"
            )
            continue
        nid, field = entry["node_id"], entry["field"]

        node = wf.get(nid)
        if node is None and not isinstance(nid, str):
            # Workflows loaded from JSON key nodes by string, while a
            # hand-written schema may carry the id as a number.
            node = wf.get(str(nid))
        if not isinstance(node, dict) or "inputs" not in node:
            warnings.append(f"node '{nid}' for parameter '{name}' missing in workflow")
            continue

        # Refuse to overwrite a link with a literal — would silently break wiring
        if is_link(node["inputs"].get(field)):
            warnings.append(
                f"parameter '{name}' targets {nid}.{field} which is currently a link; "
                f"refusing to overwrite (set the schema to point at the source node instead)"
            )
            continue
        node["inputs"][field] = value

    return wf, warnings
def parse_input_image_arg(spec: str) -> tuple[str, Path]:
    """Parse an ``--input-image`` value of the form `name=path` or `path`.

    The text before the first ``=`` is the parameter name; a bare path (or an
    empty name, as in ``=photo.png``) uses the name ``"image"``. Surrounding
    whitespace is stripped from both parts and ``~`` is expanded in the path.

    Raises:
        ValueError: The path part is empty.
    """
    name, sep, raw_path = spec.partition("=")
    if not sep:
        name, raw_path = "image", spec
    name = name.strip() or "image"
    raw_path = raw_path.strip()
    if not raw_path:
        raise ValueError("missing image path")
    return name, Path(raw_path).expanduser()


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: load, parameterize, submit and optionally await a workflow.

    Every outcome is reported as one JSON document through ``emit_json``;
    warnings and progress lines go through ``log``.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        Exit code: 0 on success (including a successful ``--submit-only``),
        130 when Ctrl+C arrives while waiting, 1 for every other failure
        (bad input, unreachable server, validation error, timeout, job error,
        cancellation).
    """
    p = argparse.ArgumentParser(
        description="Run a ComfyUI workflow with parameter injection.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--workflow", required=True, help="Path to workflow API JSON file")
    p.add_argument("--args", default="{}",
                   help="JSON parameters to inject (or `@/path/to/args.json`)")
    p.add_argument("--schema", help="Path to schema JSON (auto-generated if omitted)")
    p.add_argument("--host", default=DEFAULT_LOCAL_HOST, help="ComfyUI server URL")
    p.add_argument("--api-key",
                   help=f"API key for cloud (or set ${ENV_API_KEY} env var)")
    # NOTE(review): the help below promises a fallback to --api-key, but no such
    # fallback happens in this function — confirm ComfyRunner applies it.
    p.add_argument("--partner-key",
                   help="Partner-node API key (extra_data.api_key_comfy_org). "
                        "Required for Flux Pro / Ideogram / etc. Defaults to --api-key if not set.")
    p.add_argument("--output-dir", default="./outputs", help="Directory to save outputs")
    p.add_argument("--timeout", type=int, default=0,
                   help="Max seconds to wait (0=auto: 300 / 900 for video workflows)")
    p.add_argument("--input-image", action="append", default=[],
                   help="Upload local image before running. Format: `name=path` or `path`. "
                        "`name` is the schema parameter that receives the uploaded file's "
                        "server-side name (default: `image`).")
    p.add_argument("--randomize-seed", action="store_true",
                   help="If schema has a 'seed' parameter and --args didn't set one, randomize it")
    p.add_argument("--ws", action="store_true",
                   help="Use WebSocket for real-time progress (requires `websocket-client`)")
    p.add_argument("--no-download", action="store_true", help="Skip downloading outputs")
    p.add_argument("--flat-output", action="store_true",
                   help="Don't preserve server-side subfolder structure when saving outputs")
    p.add_argument("--overwrite", action="store_true",
                   help="Overwrite existing files instead of appending _1, _2, ...")
    p.add_argument("--submit-only", action="store_true",
                   help="Submit and return prompt_id without waiting")
    p.add_argument("--client-id", help="Override generated client_id (UUID)")
    p.add_argument("--use-partner-key-as-auth", action="store_true",
                   help="(Compat) Use --partner-key value as cloud X-API-Key. Don't use unless you know why.")

    args = p.parse_args(argv)

    # ---- Load workflow ----
    wf_path = Path(args.workflow).expanduser()
    if not wf_path.exists():
        emit_json({"error": f"Workflow file not found: {args.workflow}"})
        return 1
    try:
        with wf_path.open(encoding="utf-8") as f:
            workflow_raw = json.load(f)
        workflow = unwrap_workflow(workflow_raw)
    except json.JSONDecodeError as e:
        # Must precede ValueError: JSONDecodeError is a ValueError subclass, so
        # the other order makes this branch unreachable.
        emit_json({"error": f"Invalid JSON in workflow file: {e}"})
        return 1
    except ValueError as e:
        emit_json({"error": str(e)})
        return 1
    except OSError as e:
        emit_json({"error": f"Cannot read workflow file: {e}"})
        return 1

    # ---- Parse user args ----
    args_str = args.args
    if args_str.startswith("@"):
        try:
            args_str = Path(args_str[1:]).expanduser().read_text(encoding="utf-8")
        except OSError as e:
            emit_json({"error": f"Cannot read args file: {e}"})
            return 1
    try:
        user_args = json.loads(args_str) if args_str.strip() else {}
    except json.JSONDecodeError as e:
        emit_json({"error": f"Invalid --args JSON: {e}"})
        return 1
    if not isinstance(user_args, dict):
        emit_json({"error": "--args must be a JSON object"})
        return 1

    # ---- Resolve API key ----
    api_key = resolve_api_key(args.api_key)
    partner_key = args.partner_key or None
    if args.use_partner_key_as_auth and not api_key and partner_key:
        api_key = partner_key

    # ---- Connect ----
    runner = ComfyRunner(
        host=args.host, api_key=api_key, partner_key=partner_key,
        client_id=args.client_id,
    )

    # Server reachability
    ok, info = runner.check_server()
    if not ok:
        emit_json({
            "error": f"Cannot reach server at {args.host}",
            "details": info,
            "hint": (
                "Check `comfy launch --background` is running for local, "
                f"or set ${ENV_API_KEY} for cloud."
            ),
        })
        return 1

    # ---- Upload input images ----
    upload_warnings: list[str] = []
    for spec in args.input_image:
        try:
            param_name, path = parse_input_image_arg(spec)
        except Exception as e:
            emit_json({"error": f"Bad --input-image spec '{spec}': {e}"})
            return 1
        try:
            ref = runner.upload_image(path)
        except Exception as e:
            emit_json({"error": f"Upload failed for {path}: {e}"})
            return 1
        # Register as a user arg so inject_params consumes it through the schema
        uploaded_name = (ref.get("name") if isinstance(ref, dict) else None) or path.name
        if param_name in user_args:
            # An explicit --args value wins; say so instead of dropping the
            # upload silently.
            upload_warnings.append(
                f"--input-image '{param_name}' uploaded as '{uploaded_name}' but --args "
                f"already sets '{param_name}'; keeping the --args value"
            )
        else:
            user_args[param_name] = uploaded_name

    # ---- Inject params ----
    try:
        schema = load_schema(args.schema, workflow)
    except (OSError, ValueError) as e:
        emit_json({"error": f"Cannot load schema: {e}"})
        return 1
    workflow, inj_warnings = inject_params(
        workflow, schema, user_args, randomize_seed_if_unset=args.randomize_seed,
    )
    warnings = upload_warnings + inj_warnings
    for w in warnings:
        log(f"WARN: {w}")

    # ---- Submit ----
    submit_resp = runner.submit(workflow)
    if "_http_error" in submit_resp:
        emit_json({
            "error": "Submission HTTP error",
            "http_status": submit_resp["_http_error"],
            "body": submit_resp.get("body"),
        })
        return 1

    if isinstance(submit_resp.get("error"), dict):
        emit_json({
            "error": "Workflow validation failed",
            "details": submit_resp["error"],
            "node_errors": submit_resp.get("node_errors"),
        })
        return 1

    prompt_id = submit_resp.get("prompt_id")
    if not prompt_id:
        emit_json({"error": "No prompt_id in submit response", "response": submit_resp})
        return 1

    node_errors = submit_resp.get("node_errors") or {}
    if node_errors:
        emit_json({"error": "Workflow validation failed", "node_errors": node_errors})
        return 1

    if args.submit_only:
        emit_json({"status": "submitted", "prompt_id": prompt_id, "warnings": warnings})
        return 0

    # ---- Wait ----
    timeout = args.timeout
    if timeout <= 0:
        timeout = 900 if looks_like_video_workflow(workflow) else 300

    log(f"Submitted: prompt_id={prompt_id}, waiting (timeout={timeout}s)…")

    def _on_progress(evt: dict) -> None:
        t = evt.get("type")
        if t == "progress":
            log(f" step {evt.get('value')}/{evt.get('max')} on node {evt.get('node')}")
        elif t == "executing":
            node = evt.get("node")
            if node:
                log(f" executing node {node}")

    try:
        if args.ws:
            wait_result = runner.monitor_ws(prompt_id, timeout=timeout, on_progress=_on_progress)
        else:
            wait_result = runner.poll_status(prompt_id, timeout=timeout)
    except KeyboardInterrupt:
        log(f"Interrupted — cancelling job {prompt_id} on server…")
        try:
            runner.cancel(prompt_id)
        except Exception as e:
            log(f" (cancel request failed: {e})")
        emit_json({
            "status": "interrupted",
            "prompt_id": prompt_id,
            "note": "Ctrl+C received; sent cancellation to server.",
        })
        return 130

    if wait_result["status"] == "timeout":
        emit_json({
            "status": "timeout",
            "prompt_id": prompt_id,
            "elapsed": wait_result.get("elapsed"),
            "hint": "Re-run with larger --timeout, or use --submit-only and check later.",
        })
        return 1
    if wait_result["status"] == "error":
        emit_json({"status": "error", "prompt_id": prompt_id, "details": wait_result.get("data")})
        return 1
    if wait_result["status"] == "cancelled":
        emit_json({"status": "cancelled", "prompt_id": prompt_id})
        return 1

    # ---- Outputs ----
    # The wait result may not carry outputs; fetch them explicitly when absent.
    outputs = wait_result.get("outputs")
    if not outputs:
        outputs = runner.get_outputs(prompt_id)

    if args.no_download:
        emit_json({
            "status": "success", "prompt_id": prompt_id,
            "outputs": outputs, "warnings": warnings,
        })
        return 0

    downloaded = download_outputs(
        runner, outputs, Path(args.output_dir).expanduser(),
        preserve_subfolder=not args.flat_output, overwrite=args.overwrite,
    )

    emit_json({
        "status": "success",
        "prompt_id": prompt_id,
        "outputs": downloaded,
        "warnings": warnings,
    })
    return 0


if __name__ == "__main__":
    sys.exit(main())
# Binary frame types from ComfyUI WebSocket protocol
BINARY_PREVIEW_IMAGE = 1
BINARY_TEXT = 3
BINARY_PREVIEW_IMAGE_WITH_METADATA = 4

# Image type codes inside PREVIEW_IMAGE
IMAGE_TYPE_JPEG = 1
IMAGE_TYPE_PNG = 2

# ANSI escape codes (works on most modern terminals)
RESET = "\033[0m"
DIM = "\033[2m"
BOLD = "\033[1m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
RED = "\033[31m"
CYAN = "\033[36m"


def fmt_color(s: str, color: str, *, color_on: bool = True) -> str:
    """Wrap *s* in the given ANSI colour code and a reset, unless colour is off."""
    return f"{color}{s}{RESET}" if color_on else s


def parse_binary_frame(data: bytes) -> dict | None:
    """Decode one binary WebSocket frame into a small description dict.

    The first 4 bytes are read as a big-endian frame type; the next 4 bytes
    are an image-type code, a metadata length or a node-id length depending on
    that type.

    Returns:
        ``None`` for frames too short to decode, otherwise a dict whose
        ``"kind"`` is ``"preview"``, ``"preview_with_metadata"``, ``"text"``
        or ``"unknown"``. Preview dicts carry ``"image_bytes"`` and a file
        extension in ``"ext"``.
    """
    if len(data) < 8:
        return None
    type_code, second = struct.unpack(">II", data[0:8])

    if type_code == BINARY_PREVIEW_IMAGE:
        ext = {IMAGE_TYPE_JPEG: "jpg", IMAGE_TYPE_PNG: "png"}.get(second, "bin")
        return {
            "kind": "preview",
            "image_type": second,
            "ext": ext,
            "image_bytes": data[8:],
        }

    if type_code == BINARY_PREVIEW_IMAGE_WITH_METADATA:
        if len(data) < 12:
            return None
        meta_end = 8 + second
        if len(data) < meta_end:
            return None
        try:
            meta = json.loads(data[8:meta_end].decode("utf-8"))
        except Exception:
            meta = {"raw": data[8:meta_end][:200].decode("utf-8", "replace")}
        # Prefer the MIME type declared in the metadata over assuming PNG, so
        # JPEG previews are not saved with a .png name. Falls back to "png".
        mime = meta.get("image_type") if isinstance(meta, dict) else None
        ext = {
            "image/jpeg": "jpg",
            "image/jpg": "jpg",
            "image/png": "png",
            "image/webp": "webp",
        }.get(mime if isinstance(mime, str) else "", "png")
        return {
            "kind": "preview_with_metadata",
            "metadata": meta,
            "image_bytes": data[meta_end:],
            "ext": ext,
        }

    if type_code == BINARY_TEXT:
        nid_end = 8 + second
        if len(data) < nid_end:
            return None
        return {
            "kind": "text",
            "node_id": data[8:nid_end].decode("utf-8", "replace"),
            "text": data[nid_end:].decode("utf-8", "replace"),
        }

    return {"kind": "unknown", "type_code": type_code, "size": len(data)}


def main(argv: list[str] | None = None) -> int:
    """Connect to a ComfyUI ``/ws`` endpoint and print execution events.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 after an idle timeout, or after ``execution_success`` when
        ``--prompt-id`` is set; 1 when `websocket-client` is missing, or on
        ``execution_error`` / ``execution_interrupted`` with ``--prompt-id``;
        130 on Ctrl+C.
    """
    p = argparse.ArgumentParser(description="Real-time ComfyUI WebSocket monitor")
    p.add_argument("--host", default=DEFAULT_LOCAL_HOST, help="ComfyUI server URL")
    p.add_argument("--api-key", help=f"API key for cloud (or set ${ENV_API_KEY} env var)")
    p.add_argument("--client-id", default=None, help="Client ID (default: random UUID)")
    p.add_argument("--prompt-id", default=None,
                   help="Filter to a specific prompt_id (default: all jobs)")
    p.add_argument("--previews", default=None,
                   help="Directory to save in-progress preview frames")
    p.add_argument("--no-color", action="store_true", help="Disable ANSI colour")
    # The value is applied as a socket read timeout, i.e. an idle limit.
    p.add_argument("--timeout", type=float, default=600.0,
                   help="Exit after this many seconds without any message (default 600s)")
    args = p.parse_args(argv)

    try:
        import websocket  # type: ignore[import-not-found]
    except ImportError:
        print(json.dumps({
            "error": "websocket-client not installed",
            "install": "pip install websocket-client",
        }))
        return 1

    # Local import: only this function needs it.
    from urllib.parse import urlencode

    api_key = resolve_api_key(args.api_key)
    cloud = is_cloud_host(args.host)
    client_id = args.client_id or new_client_id()

    # Build WS URL preserving any base-path component (e.g. behind reverse proxy).
    target = urlparse(args.host if "://" in args.host else f"http://{args.host}")
    scheme = "wss" if target.scheme == "https" else "ws"
    endpoint = f"{scheme}://{target.netloc}{target.path.rstrip('/')}/ws"
    public_query = urlencode({"clientId": client_id})
    ws_url = f"{endpoint}?{public_query}"
    loggable_url = ws_url
    if cloud and api_key:
        ws_url += "&" + urlencode({"token": api_key})
        # Keep the credential out of logs and terminal scrollback.
        loggable_url += "&token=***"

    color_on = not args.no_color and sys.stdout.isatty()

    preview_dir = Path(args.previews).expanduser() if args.previews else None
    if preview_dir:
        preview_dir.mkdir(parents=True, exist_ok=True)
        log(f"Saving previews to {preview_dir}")

    log(f"Connecting to {loggable_url} (client_id={client_id})")
    if args.prompt_id:
        log(f"Filtering messages to prompt_id={args.prompt_id}")

    ws = websocket.create_connection(ws_url, timeout=args.timeout)
    ws.settimeout(args.timeout)

    preview_counter = 0
    try:
        while True:
            try:
                msg = ws.recv()
            except websocket.WebSocketTimeoutException:
                log(f"Idle for {args.timeout}s — exiting")
                return 0

            if isinstance(msg, bytes):
                frame = parse_binary_frame(msg)
                if frame is None:
                    continue
                if frame["kind"] in ("preview", "preview_with_metadata") and preview_dir:
                    img_bytes = frame.get("image_bytes", b"")
                    if img_bytes:
                        ext = frame.get("ext", "png")
                        out = preview_dir / f"preview_{preview_counter:05d}.{ext}"
                        out.write_bytes(img_bytes)
                        preview_counter += 1
                        log(f" [preview] saved {out.name} ({len(img_bytes)} bytes)")
                continue

            try:
                payload = json.loads(msg)
            except Exception:
                continue
            if not isinstance(payload, dict):
                continue
            mtype = payload.get("type", "")
            mdata = payload.get("data") or {}
            if not isinstance(mdata, dict):
                mdata = {}
            pid = mdata.get("prompt_id")

            # Messages without a prompt_id (e.g. queue status) are always shown.
            if args.prompt_id and pid and pid != args.prompt_id:
                continue

            if mtype == "status":
                status = mdata.get("status")
                exec_info = status.get("exec_info") if isinstance(status, dict) else None
                qr = exec_info.get("queue_remaining", "?") if isinstance(exec_info, dict) else "?"
                print(fmt_color(f"[status] queue_remaining={qr}", DIM, color_on=color_on))
            elif mtype == "execution_start":
                print(fmt_color(f"[start] prompt_id={pid}", BOLD, color_on=color_on))
            elif mtype == "executing":
                node = mdata.get("node")
                if node:
                    print(fmt_color(f" [executing] node={node}", CYAN, color_on=color_on))
                else:
                    print(fmt_color(f" [executing] (workflow done) prompt_id={pid}", DIM, color_on=color_on))
            elif mtype == "progress":
                v, m = mdata.get("value", 0), mdata.get("max", 0)
                # Guard against null / non-numeric fields before dividing.
                numeric = isinstance(v, (int, float)) and isinstance(m, (int, float))
                pct = (v / m * 100) if numeric and m else 0
                print(f" [progress] {v}/{m} ({pct:5.1f}%) node={mdata.get('node')}")
            elif mtype == "progress_state":
                # Newer extended progress message
                nodes = mdata.get("nodes")
                if not isinstance(nodes, dict):
                    nodes = {}
                running = [k for k, v in nodes.items() if isinstance(v, dict) and v.get("running")]
                if running:
                    print(fmt_color(f" [progress_state] running={running}", DIM, color_on=color_on))
            elif mtype == "executed":
                node = mdata.get("node")
                out = mdata.get("output")
                if not isinstance(out, dict):
                    out = {}
                summary_parts = []
                for key in ("images", "video", "videos", "gifs", "audio", "files"):
                    entries = out.get(key)
                    if entries:
                        # A single (non-list) entry counts as one file rather
                        # than the number of keys it contains.
                        count = len(entries) if isinstance(entries, (list, tuple)) else 1
                        summary_parts.append(f"{key}={count}")
                summary = ", ".join(summary_parts) if summary_parts else "(no files)"
                print(fmt_color(f" [executed] node={node} {summary}", GREEN, color_on=color_on))
            elif mtype == "execution_cached":
                cached = mdata.get("nodes") or []
                if cached:
                    print(fmt_color(f" [cached] {len(cached)} nodes skipped", DIM, color_on=color_on))
            elif mtype == "execution_success":
                print(fmt_color(f"[success] prompt_id={pid}", GREEN + BOLD, color_on=color_on))
                if args.prompt_id:
                    return 0
            elif mtype == "execution_error":
                exc_type = mdata.get("exception_type", "?")
                exc_msg = mdata.get("exception_message", "?")
                print(fmt_color(f"[error] {exc_type}: {exc_msg}", RED + BOLD, color_on=color_on))
                tb = mdata.get("traceback")
                if tb:
                    if isinstance(tb, list):
                        for line in tb:
                            print(fmt_color(f" {line}", RED, color_on=color_on))
                    else:
                        print(fmt_color(f" {tb}", RED, color_on=color_on))
                if args.prompt_id:
                    return 1
            elif mtype == "execution_interrupted":
                print(fmt_color(f"[interrupted] prompt_id={pid}", YELLOW, color_on=color_on))
                if args.prompt_id:
                    return 1
            elif mtype == "notification":
                v = mdata.get("value", "")
                print(fmt_color(f"[notification] {v}", DIM, color_on=color_on))
            else:
                # Unknown / lightly-used types: print compactly
                print(fmt_color(f"[{mtype}] {json.dumps(mdata, default=str)[:200]}", DIM, color_on=color_on))

    except KeyboardInterrupt:
        log("Interrupted")
        return 130
    finally:
        try:
            ws.close()
        except Exception:
            pass  # best-effort close on the way out


if __name__ == "__main__":
    sys.exit(main())
-1 seed, link refusal), output download walk, runner construction | +| `test_check_deps.py` | Model-name fuzzy matching, install command suggestions | +| `test_cloud_integration.py` | Live cloud API contract tests (auto-skipped without API key) | + +## Adding tests + +When you change a script: + +1. Add a unit test if the change is pure logic (cloud detection, parsing, etc.) +2. Add a cloud integration test if the change depends on cloud API behavior + (use `pytestmark = pytest.mark.cloud` so it auto-skips without a key) +3. Workflow fixtures live in `conftest.py` (`sd15_workflow`, `flux_workflow`, + `video_workflow`) + +## Why the explicit `-c` / `-o`? + +The parent hermes-agent repo's `pyproject.toml` enables `pytest-xdist` by +default (`-n auto`). This suite is small enough that parallelism isn't +worth the complexity, and pytest-xdist isn't always installed in the user's +environment. The `-c tests/pytest.ini -o addopts="-p no:xdist"` flags make +the suite run identically regardless of the parent project's config. diff --git a/skills/creative/comfyui/tests/conftest.py b/skills/creative/comfyui/tests/conftest.py new file mode 100644 index 0000000000..a800fa79f1 --- /dev/null +++ b/skills/creative/comfyui/tests/conftest.py @@ -0,0 +1,64 @@ +"""Pytest configuration for the comfyui skill test suite. + +Adds `scripts/` to sys.path so tests can `from _common import ...`, and +provides a few common fixtures. 
+""" + +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + +import pytest + +ROOT = Path(__file__).resolve().parent.parent +SCRIPTS = ROOT / "scripts" +WORKFLOWS = ROOT / "workflows" + +sys.path.insert(0, str(SCRIPTS)) + + +@pytest.fixture +def sd15_workflow() -> dict: + return json.loads((WORKFLOWS / "sd15_txt2img.json").read_text()) + + +@pytest.fixture +def flux_workflow() -> dict: + return json.loads((WORKFLOWS / "flux_dev_txt2img.json").read_text()) + + +@pytest.fixture +def video_workflow() -> dict: + return json.loads((WORKFLOWS / "wan_video_t2v.json").read_text()) + + +@pytest.fixture +def workflows_dir() -> Path: + return WORKFLOWS + + +@pytest.fixture +def scripts_dir() -> Path: + return SCRIPTS + + +@pytest.fixture +def cloud_key() -> str | None: + """Cloud API key if set, otherwise None. + + Tests that need cloud connectivity should skip when this is None. + """ + return os.environ.get("COMFY_CLOUD_API_KEY") + + +def pytest_collection_modifyitems(config, items): + """Auto-skip cloud tests when no API key is set.""" + if os.environ.get("COMFY_CLOUD_API_KEY"): + return + skip_cloud = pytest.mark.skip(reason="Set COMFY_CLOUD_API_KEY to run cloud tests") + for item in items: + if "cloud" in item.keywords: + item.add_marker(skip_cloud) diff --git a/skills/creative/comfyui/tests/pytest.ini b/skills/creative/comfyui/tests/pytest.ini new file mode 100644 index 0000000000..2111fe2122 --- /dev/null +++ b/skills/creative/comfyui/tests/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +markers = + cloud: tests that hit live Comfy Cloud API (require COMFY_CLOUD_API_KEY) +testpaths = . 
+addopts = -p no:xdist diff --git a/skills/creative/comfyui/tests/test_check_deps.py b/skills/creative/comfyui/tests/test_check_deps.py new file mode 100644 index 0000000000..30116a7fe7 --- /dev/null +++ b/skills/creative/comfyui/tests/test_check_deps.py @@ -0,0 +1,68 @@ +"""Tests for check_deps.py — focuses on parsing logic that doesn't need a server.""" + +from __future__ import annotations + +from check_deps import ( + NODE_TO_PACKAGE, + model_present, + normalize_for_match, + suggest_install_command, +) + + +class TestNormalizeForMatch: + def test_basic(self): + s = normalize_for_match("model.safetensors") + assert "model.safetensors" in s + assert "model" in s + + def test_subfolder(self): + s = normalize_for_match("subdir/model.pt") + assert "subdir/model.pt" in s + assert "model.pt" in s + assert "model" in s + + +class TestModelPresent: + def test_exact_match(self): + assert model_present("a.safetensors", {"a.safetensors", "b.safetensors"}) is True + + def test_extension_difference(self): + # User said "model" but installed is "model.safetensors" + assert model_present("model", {"model.safetensors"}) is True + # Reverse direction — also matches + assert model_present("model.safetensors", {"model"}) is True + + def test_subfolder_match(self): + # Installed list has "subdir/model.safetensors", workflow asks "model.safetensors" + assert model_present("model.safetensors", {"subdir/model.safetensors"}) is True + + def test_missing(self): + assert model_present("missing.safetensors", {"a.safetensors", "b.safetensors"}) is False + + def test_empty_installed(self): + assert model_present("anything.safetensors", set()) is False + + +class TestSuggestInstallCommand: + def test_known_node(self): + cmd = suggest_install_command("VHS_VideoCombine") + assert cmd == "comfy node install comfyui-videohelpersuite" + + def test_unknown_node(self): + assert suggest_install_command("SomeRandomNodeName123") is None + + +class TestNodePackageMap: + def test_no_duplicates(self): + 
# Each node should map to exactly one package + keys = list(NODE_TO_PACKAGE.keys()) + assert len(keys) == len(set(keys)) + + def test_packages_are_safe_for_shell(self): + # Registry slugs must be alphanumerics + hyphens/underscores only + # (passed straight to `comfy node install `). + import re + safe = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._\-]*$") + for pkg in NODE_TO_PACKAGE.values(): + assert safe.match(pkg), f"Unsafe package slug: {pkg!r}" diff --git a/skills/creative/comfyui/tests/test_cloud_integration.py b/skills/creative/comfyui/tests/test_cloud_integration.py new file mode 100644 index 0000000000..eb7b04ca22 --- /dev/null +++ b/skills/creative/comfyui/tests/test_cloud_integration.py @@ -0,0 +1,95 @@ +"""Integration tests against the live Comfy Cloud API. + +These tests are auto-skipped when COMFY_CLOUD_API_KEY is not set. +They never SUBMIT workflows (would need a paid subscription) — they only +verify the read-only endpoints we rely on. +""" + +from __future__ import annotations + +import pytest + +from _common import http_get, parse_model_list, resolve_url + + +pytestmark = pytest.mark.cloud + + +class TestCloudEndpointsLive: + def test_system_stats_reachable(self, cloud_key): + url = resolve_url("https://cloud.comfy.org", "/system_stats") + r = http_get(url, headers={"X-API-Key": cloud_key}) + assert r.status == 200 + data = r.json() + assert "system" in data + + def test_models_endpoint_routed_to_experiment(self, cloud_key): + # We expect the skill to route /models/checkpoints → /api/experiment/models/checkpoints + url = resolve_url("https://cloud.comfy.org", "/models/checkpoints") + assert "/api/experiment/models/checkpoints" in url + r = http_get(url, headers={"X-API-Key": cloud_key}) + assert r.status == 200 + + def test_models_endpoint_returns_dicts(self, cloud_key): + url = resolve_url("https://cloud.comfy.org", "/models/checkpoints") + r = http_get(url, headers={"X-API-Key": cloud_key}) + data = r.json() + assert isinstance(data, list) + if data: 
+ # Cloud format: list of dicts with `name` + assert isinstance(data[0], dict) + assert "name" in data[0] + # Our parser normalizes both + normalized = parse_model_list(data) + assert len(normalized) == len(data) + + def test_history_renamed_to_v2(self, cloud_key): + # /history → /api/history_v2 on cloud + url = resolve_url("https://cloud.comfy.org", "/history/some-fake-id") + assert "/api/history_v2/some-fake-id" in url + + def test_object_info_paid_tier(self, cloud_key): + # On free tier, /object_info returns 403 with a recognizable message + url = resolve_url("https://cloud.comfy.org", "/object_info") + r = http_get(url, headers={"X-API-Key": cloud_key}) + # Should be either 200 (paid) or 403 (free) — not 404 / 500 + assert r.status in (200, 403) + if r.status == 403: + # Body should mention the limitation + assert "free tier" in r.text().lower() or "subscription" in r.text().lower() + + +class TestCloudCheckDepsLive: + def test_check_deps_against_cloud(self, cloud_key, sd15_workflow): + from check_deps import check_deps + report = check_deps(sd15_workflow, host="https://cloud.comfy.org", api_key=cloud_key) + # Either node check passed OR was skipped (free tier) + assert "missing_models" in report + assert "is_cloud" in report and report["is_cloud"] is True + + def test_flux_workflow_models_resolved_via_aliases(self, cloud_key, flux_workflow): + """Flux uses unet/clip folders; cloud has them in diffusion_models/text_encoders. + With folder aliasing, the check should still find them.""" + from check_deps import check_deps + report = check_deps(flux_workflow, host="https://cloud.comfy.org", api_key=cloud_key) + # The exact required Flux files (flux1-dev.safetensors, t5xxl_fp16, clip_l, ae) + # are present on cloud; with folder aliasing, none should be missing. + # If this fails, either the cloud removed the model or the aliasing logic broke. 
+ missing_filenames = {m["value"] for m in report["missing_models"]} + assert "ae.safetensors" not in missing_filenames, \ + "ae.safetensors should be on cloud's vae folder" + # t5xxl_fp16 / clip_l should be reachable via the clip → text_encoders alias + # flux1-dev.safetensors likewise via unet → diffusion_models + + +class TestHealthCheckLive: + def test_health_check_passes(self, cloud_key, capsys): + from health_check import main as health_main + rc = health_main(["--host", "https://cloud.comfy.org", "--api-key", cloud_key]) + captured = capsys.readouterr() + # Should produce JSON + import json + report = json.loads(captured.out) + assert report["server"]["reachable"] is True + assert report["checkpoints"]["queryable"] is True + assert report["checkpoints"]["count"] > 0 diff --git a/skills/creative/comfyui/tests/test_common.py b/skills/creative/comfyui/tests/test_common.py new file mode 100644 index 0000000000..0263fe1d91 --- /dev/null +++ b/skills/creative/comfyui/tests/test_common.py @@ -0,0 +1,447 @@ +"""Unit tests for _common.py — pure logic only, no network.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from _common import ( + DEFAULT_LOCAL_HOST, + EMBEDDING_REGEX, + FOLDER_ALIASES, + build_cloud_aware_url, + cloud_endpoint, + coerce_seed, + folder_aliases_for, + is_api_format, + is_cloud_host, + is_link, + iter_embedding_refs, + iter_model_deps, + iter_nodes, + looks_like_video_workflow, + media_type_from_filename, + parse_model_list, + resolve_url, + safe_path_join, + unwrap_workflow, +) + + +# ============================================================================= +# Cloud detection / URL routing +# ============================================================================= + +class TestCloudDetection: + def test_cloud_host_exact(self): + assert is_cloud_host("https://cloud.comfy.org") is True + assert is_cloud_host("https://cloud.comfy.org/foo/bar") is True + + def test_cloud_host_subdomain(self): + 
assert is_cloud_host("https://staging.cloud.comfy.org") is True + assert is_cloud_host("https://api.cloud.comfy.org") is True + + def test_local_not_cloud(self): + assert is_cloud_host("http://127.0.0.1:8188") is False + assert is_cloud_host("http://localhost:8188") is False + assert is_cloud_host("http://my-server.local:8188") is False + + def test_no_scheme(self): + # Defaults to http:// + assert is_cloud_host("cloud.comfy.org") is True + assert is_cloud_host("127.0.0.1:8188") is False + + +class TestCloudEndpointRename: + def test_history_renamed(self): + assert cloud_endpoint("/history") == "/history_v2" + assert cloud_endpoint("/history/abc-123") == "/history_v2/abc-123" + + def test_history_v2_preserved(self): + assert cloud_endpoint("/history_v2") == "/history_v2" + + def test_models_renamed(self): + assert cloud_endpoint("/models") == "/experiment/models" + assert cloud_endpoint("/models/checkpoints") == "/experiment/models/checkpoints" + assert cloud_endpoint("/models/loras") == "/experiment/models/loras" + + def test_other_paths_unchanged(self): + assert cloud_endpoint("/prompt") == "/prompt" + assert cloud_endpoint("/queue") == "/queue" + + +class TestResolveURL: + def test_local_no_prefix(self): + assert resolve_url("http://127.0.0.1:8188", "/prompt") == "http://127.0.0.1:8188/prompt" + + def test_cloud_adds_api_prefix(self): + assert resolve_url("https://cloud.comfy.org", "/prompt") == "https://cloud.comfy.org/api/prompt" + + def test_cloud_history_renamed(self): + assert resolve_url("https://cloud.comfy.org", "/history/abc") == "https://cloud.comfy.org/api/history_v2/abc" + + def test_cloud_models_renamed(self): + assert resolve_url("https://cloud.comfy.org", "/models/loras") == "https://cloud.comfy.org/api/experiment/models/loras" + + def test_cloud_already_has_api(self): + # Don't double-prefix + assert resolve_url("https://cloud.comfy.org", "/api/prompt") == "https://cloud.comfy.org/api/prompt" + + def test_trailing_slash_stripped(self): + assert 
resolve_url("http://127.0.0.1:8188/", "/prompt") == "http://127.0.0.1:8188/prompt" + + +# ============================================================================= +# Workflow validation +# ============================================================================= + +class TestAPIFormatDetection: + def test_valid_api(self, sd15_workflow): + assert is_api_format(sd15_workflow) is True + + def test_editor_format_rejected(self): + editor = {"nodes": [], "links": [], "version": 0.4} + assert is_api_format(editor) is False + + def test_empty_dict(self): + assert is_api_format({}) is False + + def test_non_dict(self): + assert is_api_format([]) is False + assert is_api_format(None) is False + assert is_api_format("string") is False + + def test_node_with_class_type(self): + wf = {"3": {"class_type": "KSampler", "inputs": {}}} + assert is_api_format(wf) is True + + +class TestUnwrapWorkflow: + def test_passthrough_api_format(self, sd15_workflow): + result = unwrap_workflow(sd15_workflow) + assert result is sd15_workflow + + def test_unwrap_prompt_key(self, sd15_workflow): + wrapped = {"prompt": sd15_workflow, "client_id": "abc"} + result = unwrap_workflow(wrapped) + assert result is sd15_workflow + + def test_editor_format_raises(self): + with pytest.raises(ValueError, match="editor format"): + unwrap_workflow({"nodes": [], "links": []}) + + def test_garbage_raises(self): + with pytest.raises(ValueError): + unwrap_workflow({"foo": "bar"}) + + +class TestIsLink: + def test_valid_link(self): + assert is_link(["3", 0]) is True + assert is_link(["10", 1]) is True + + def test_non_link(self): + assert is_link("string") is False + assert is_link(42) is False + assert is_link([]) is False + assert is_link(["3"]) is False # missing slot + assert is_link(["3", "0"]) is False # slot must be int + assert is_link([3, 0]) is False # node_id must be string + + +# ============================================================================= +# Workflow iterators +# 
============================================================================= + +class TestIterators: + def test_iter_nodes(self, sd15_workflow): + nodes = dict(iter_nodes(sd15_workflow)) + assert "3" in nodes + assert nodes["3"]["class_type"] == "KSampler" + + def test_iter_nodes_skips_comments(self, sd15_workflow): + # _comment is not a node + nodes = dict(iter_nodes(sd15_workflow)) + assert "_comment" not in nodes + + def test_iter_model_deps(self, sd15_workflow): + deps = list(iter_model_deps(sd15_workflow)) + names = [d["value"] for d in deps] + assert "v1-5-pruned-emaonly.safetensors" in names + + def test_iter_model_deps_flux(self, flux_workflow): + deps = list(iter_model_deps(flux_workflow)) + names = {d["value"]: d["folder"] for d in deps} + assert names["flux1-dev.safetensors"] == "unet" + assert names["t5xxl_fp16.safetensors"] == "clip" + assert names["clip_l.safetensors"] == "clip" + assert names["ae.safetensors"] == "vae" + + +# ============================================================================= +# Embedding extraction +# ============================================================================= + +class TestEmbeddingRegex: + def test_basic_embedding(self): + m = EMBEDDING_REGEX.search("a cat, embedding:goodvibes, more text") + assert m is not None + assert m.group(1) == "goodvibes" + + def test_embedding_with_strength(self): + m = EMBEDDING_REGEX.search("embedding:bad-hands-5:1.2") + assert m is not None + assert m.group(1) == "bad-hands-5" + + def test_embedding_with_extension(self): + # Strips .pt / .safetensors / .bin + m = EMBEDDING_REGEX.search("embedding:my-emb.pt") + assert m is not None + assert m.group(1) == "my-emb" + + def test_embedding_in_parens(self): + m = EMBEDDING_REGEX.search("(embedding:foo:0.8)") + assert m is not None + assert m.group(1) == "foo" + + def test_multiple_in_one_string(self): + text = "a cat, embedding:foo:1.2, and embedding:bar" + matches = [m.group(1) for m in EMBEDDING_REGEX.finditer(text)] + assert 
matches == ["foo", "bar"] + + def test_no_false_positive_on_word_embedding(self): + # "embedding " (with space, no colon) should not match + m = EMBEDDING_REGEX.search("the embedding is great") + assert m is None + + +class TestIterEmbeddingRefs: + def test_finds_in_clip_text_encode(self): + wf = { + "1": {"class_type": "CLIPTextEncode", + "inputs": {"text": "embedding:foo, embedding:bar:0.5", "clip": ["2", 0]}}, + "2": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}}, + } + refs = list(iter_embedding_refs(wf)) + names = [name for _, name in refs] + assert names == ["foo", "bar"] + + def test_ignores_non_prompt_fields(self): + wf = { + "1": {"class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "embedding:foo.safetensors"}}, + } + refs = list(iter_embedding_refs(wf)) + # ckpt_name is not a prompt field — ignored + assert refs == [] + + +# ============================================================================= +# Path safety +# ============================================================================= + +class TestSafePathJoin: + def test_normal_join(self, tmp_path): + p = safe_path_join(tmp_path, "subdir", "file.png") + assert p.is_relative_to(tmp_path) + + def test_blocks_traversal(self, tmp_path): + with pytest.raises(ValueError, match="path traversal"): + safe_path_join(tmp_path, "..", "..", "etc", "passwd") + + def test_blocks_absolute(self, tmp_path): + with pytest.raises(ValueError): + safe_path_join(tmp_path, "/etc/passwd") + + def test_subfolder_with_filename(self, tmp_path): + p = safe_path_join(tmp_path, "outputs", "img.png") + assert p.name == "img.png" + assert p.parent.name == "outputs" + + +# ============================================================================= +# Seed coercion +# ============================================================================= + +class TestCoerceSeed: + def test_explicit_int(self): + assert coerce_seed(42) == 42 + assert coerce_seed(0) == 0 + + def 
test_minus_one_randomizes(self): + s = coerce_seed(-1) + assert isinstance(s, int) + assert 0 <= s < 2**63 + + def test_none_randomizes(self): + s = coerce_seed(None) + assert isinstance(s, int) + + def test_string_int(self): + # str() that converts cleanly is allowed (relaxed) + assert coerce_seed("12345") == 12345 + + def test_string_minus_one_randomizes(self): + # CLI / JSON sometimes carries seed as a string. + s = coerce_seed("-1") + assert isinstance(s, int) + assert 0 <= s < 2**63 + # And whitespace tolerated + s2 = coerce_seed(" -1 ") + assert isinstance(s2, int) + assert 0 <= s2 < 2**63 + + +# ============================================================================= +# Model list normalization (cloud format) +# ============================================================================= + +class TestParseModelList: + def test_local_format_strings(self): + result = parse_model_list(["a.safetensors", "b.safetensors"]) + assert result == {"a.safetensors", "b.safetensors"} + + def test_cloud_format_dicts(self): + result = parse_model_list([ + {"name": "a.safetensors", "pathIndex": 0}, + {"name": "b.safetensors", "pathIndex": 1}, + ]) + assert result == {"a.safetensors", "b.safetensors"} + + def test_empty(self): + assert parse_model_list([]) == set() + + def test_garbage(self): + assert parse_model_list("not a list") == set() + assert parse_model_list(None) == set() + + def test_mixed_format(self): + result = parse_model_list([ + "string-form.safetensors", + {"name": "dict-form.safetensors"}, + ]) + assert result == {"string-form.safetensors", "dict-form.safetensors"} + + +# ============================================================================= +# Folder aliases +# ============================================================================= + +class TestFolderAliases: + def test_unet_aliases_diffusion_models(self): + aliases = folder_aliases_for("unet") + assert "unet" in aliases + assert "diffusion_models" in aliases + + def 
test_clip_aliases_text_encoders(self): + aliases = folder_aliases_for("clip") + assert "clip" in aliases + assert "text_encoders" in aliases + + def test_unknown_folder_returns_self(self): + assert folder_aliases_for("checkpoints") == ["checkpoints"] + + def test_primary_first(self): + # Order matters: primary should be first for human-friendly fix hints + assert folder_aliases_for("unet")[0] == "unet" + assert folder_aliases_for("diffusion_models")[0] == "diffusion_models" + + +# ============================================================================= +# Media-type detection +# ============================================================================= + +class TestMediaType: + def test_video_extensions(self): + assert media_type_from_filename("vid.mp4") == "video" + assert media_type_from_filename("foo.webm") == "video" + assert media_type_from_filename("bar.gif") == "video" + + def test_audio_extensions(self): + assert media_type_from_filename("song.wav") == "audio" + assert media_type_from_filename("music.mp3") == "audio" + + def test_image_default(self): + assert media_type_from_filename("pic.png") == "image" + assert media_type_from_filename("image.jpg") == "image" + assert media_type_from_filename("unknown.xyz") == "image" + + def test_3d(self): + assert media_type_from_filename("model.glb") == "3d" + assert media_type_from_filename("scene.gltf") == "3d" + + +# ============================================================================= +# Cross-host header stripping (security) +# ============================================================================= + +class TestRedirectHeaderStripping: + """Verify X-API-Key is dropped when redirect crosses to a different host + (e.g. cloud /api/view → S3 signed URL). Critical to prevent leaking auth + tokens to the storage backend. 
+ """ + + def _build_session(self): + from _common import _StripSensitiveOnRedirectSession, HAS_REQUESTS + if not HAS_REQUESTS: + import pytest + pytest.skip("requests not installed") + return _StripSensitiveOnRedirectSession() + + def test_strips_x_api_key_cross_host(self): + import requests + s = self._build_session() + prep = requests.PreparedRequest() + prep.prepare(method="GET", url="https://other.example.com/file", + headers={"X-API-Key": "leak", "Authorization": "Bearer x"}) + resp = requests.Response() + orig = requests.PreparedRequest() + orig.prepare(method="GET", url="https://cloud.comfy.org/api/view", headers={}) + resp.request = orig + s.rebuild_auth(prep, resp) + assert "X-API-Key" not in prep.headers + assert "Authorization" not in prep.headers + + def test_preserves_x_api_key_same_host(self): + import requests + s = self._build_session() + prep = requests.PreparedRequest() + prep.prepare(method="GET", url="https://cloud.comfy.org/foo", + headers={"X-API-Key": "keep"}) + resp = requests.Response() + orig = requests.PreparedRequest() + orig.prepare(method="GET", url="https://cloud.comfy.org/bar", headers={}) + resp.request = orig + s.rebuild_auth(prep, resp) + assert prep.headers.get("X-API-Key") == "keep" + + def test_strips_cookie_cross_host(self): + import requests + s = self._build_session() + prep = requests.PreparedRequest() + prep.prepare(method="GET", url="https://other.example.com/x", + headers={"Cookie": "session=secret"}) + resp = requests.Response() + orig = requests.PreparedRequest() + orig.prepare(method="GET", url="https://cloud.comfy.org/foo", headers={}) + resp.request = orig + s.rebuild_auth(prep, resp) + assert "Cookie" not in prep.headers + + +# ============================================================================= +# Video workflow detection +# ============================================================================= + +class TestVideoWorkflow: + def test_image_workflow(self, sd15_workflow): + assert 
looks_like_video_workflow(sd15_workflow) is False + + def test_animatediff_workflow(self, workflows_dir): + import json + wf = json.loads((workflows_dir / "animatediff_video.json").read_text()) + assert looks_like_video_workflow(wf) is True + + def test_wan_workflow(self, video_workflow): + assert looks_like_video_workflow(video_workflow) is True diff --git a/skills/creative/comfyui/tests/test_extract_schema.py b/skills/creative/comfyui/tests/test_extract_schema.py new file mode 100644 index 0000000000..1cb965a1fa --- /dev/null +++ b/skills/creative/comfyui/tests/test_extract_schema.py @@ -0,0 +1,185 @@ +"""Tests for extract_schema.py.""" + +from __future__ import annotations + +import pytest + +from extract_schema import ( + extract_schema, + find_negative_prompt_node, + find_positive_prompt_node, + trace_to_node, +) + + +# ============================================================================= +# Connection tracing +# ============================================================================= + +class TestConnectionTracing: + def test_direct_link(self): + wf = { + "1": {"class_type": "CLIPTextEncode", "inputs": {"text": "x"}}, + "2": {"class_type": "KSampler", + "inputs": {"positive": ["1", 0], "negative": ["1", 0]}}, + } + assert trace_to_node(wf, ["1", 0]) == "1" + + def test_through_reroute(self): + wf = { + "1": {"class_type": "CLIPTextEncode", "inputs": {"text": "x"}}, + "2": {"class_type": "Reroute", "inputs": {"input": ["1", 0]}}, + "3": {"class_type": "Reroute", "inputs": {"input": ["2", 0]}}, + } + assert trace_to_node(wf, ["3", 0]) == "1" + + def test_circular_safe(self): + wf = { + "1": {"class_type": "Reroute", "inputs": {"input": ["2", 0]}}, + "2": {"class_type": "Reroute", "inputs": {"input": ["1", 0]}}, + } + # Should hit max_hops without infinite loop + result = trace_to_node(wf, ["1", 0], max_hops=5) + assert result in ("1", "2") # any node, just don't hang + + +class TestPositiveNegativeDetection: + def test_basic(self, sd15_workflow): + 
# In sd15_workflow.json node 6 is positive, node 7 is negative + assert find_positive_prompt_node(sd15_workflow) == "6" + assert find_negative_prompt_node(sd15_workflow) == "7" + + def test_swapped_order(self): + wf = { + "3": {"class_type": "KSampler", + "inputs": { + "positive": ["7", 0], "negative": ["6", 0], + "model": ["4", 0], "latent_image": ["5", 0], + "seed": 1, "steps": 20, "cfg": 7.5, + "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0, + }}, + "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}}, + "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}}, + "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "ugly", "clip": ["4", 1]}}, + "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "beautiful", "clip": ["4", 1]}}, + } + # Now 7 is the positive (despite higher node ID) + assert find_positive_prompt_node(wf) == "7" + assert find_negative_prompt_node(wf) == "6" + + +# ============================================================================= +# Schema extraction +# ============================================================================= + +class TestExtractSchema: + def test_basic_sd15(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + params = schema["parameters"] + assert "prompt" in params + assert "negative_prompt" in params + assert "seed" in params + assert "steps" in params + assert "cfg" in params + assert "width" in params + assert "height" in params + + def test_prompt_value_correct(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + # The positive prompt in the example is the landscape one + assert "landscape" in schema["parameters"]["prompt"]["value"] + assert "ugly" in schema["parameters"]["negative_prompt"]["value"] + + def test_model_dependencies(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + deps = schema["model_dependencies"] + ckpts = [d["value"] for d in deps if d["folder"] == "checkpoints"] + 
assert "v1-5-pruned-emaonly.safetensors" in ckpts + + def test_output_nodes(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + assert "9" in schema["output_nodes"] + + def test_summary(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + s = schema["summary"] + assert s["has_negative_prompt"] is True + assert s["has_seed"] is True + assert s["is_video_workflow"] is False + assert s["parameter_count"] > 5 + + def test_flux_workflow(self, flux_workflow): + schema = extract_schema(flux_workflow) + # Flux uses RandomNoise for seed + assert schema["summary"]["has_seed"] is True + # Flux has only positive prompt (no negative encoder) + assert schema["summary"]["has_negative_prompt"] is False + + def test_video_detected(self, video_workflow): + schema = extract_schema(video_workflow) + assert schema["summary"]["is_video_workflow"] is True + + +class TestEmbeddingDeps: + def test_extract_from_prompt(self): + wf = { + "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}}, + "5": {"class_type": "EmptyLatentImage", + "inputs": {"width": 512, "height": 512, "batch_size": 1}}, + "6": {"class_type": "CLIPTextEncode", + "inputs": { + "text": "a cat, embedding:goodvibes, embedding:art:1.2", + "clip": ["1", 1] + }}, + "7": {"class_type": "CLIPTextEncode", + "inputs": { + "text": "ugly, embedding:badhands", + "clip": ["1", 1] + }}, + "3": {"class_type": "KSampler", + "inputs": { + "positive": ["6", 0], "negative": ["7", 0], + "model": ["1", 0], "latent_image": ["5", 0], + "seed": 1, "steps": 20, "cfg": 7.5, + "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0, + }}, + "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}}, + } + schema = extract_schema(wf) + names = [d["embedding_name"] for d in schema["embedding_dependencies"]] + assert sorted(names) == ["art", "badhands", "goodvibes"] + + +class TestDuplicateDeduplication: + def test_two_ksamplers_get_unique_names(self): + wf = { + "1": 
{"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}}, + "5": {"class_type": "EmptyLatentImage", + "inputs": {"width": 512, "height": 512, "batch_size": 1}}, + "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "a", "clip": ["1", 1]}}, + "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "b", "clip": ["1", 1]}}, + "3": {"class_type": "KSampler", + "inputs": { + "positive": ["6", 0], "negative": ["7", 0], + "model": ["1", 0], "latent_image": ["5", 0], + "seed": 42, "steps": 20, "cfg": 7.5, + "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0, + }}, + "4": {"class_type": "KSampler", + "inputs": { + "positive": ["6", 0], "negative": ["7", 0], + "model": ["1", 0], "latent_image": ["5", 0], + "seed": 99, "steps": 30, "cfg": 8.0, + "sampler_name": "euler", "scheduler": "normal", "denoise": 0.6, + }}, + "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}}, + } + schema = extract_schema(wf) + params = schema["parameters"] + # Both seeds present with disambiguated names + seed_keys = [k for k in params if "seed" in k] + # Symmetric: both renamed (no bare "seed") + assert "seed" not in params + assert "seed_3" in params and "seed_4" in params + assert params["seed_3"]["value"] == 42 + assert params["seed_4"]["value"] == 99 diff --git a/skills/creative/comfyui/tests/test_run_workflow.py b/skills/creative/comfyui/tests/test_run_workflow.py new file mode 100644 index 0000000000..32eb172ad1 --- /dev/null +++ b/skills/creative/comfyui/tests/test_run_workflow.py @@ -0,0 +1,213 @@ +"""Tests for run_workflow.py — focuses on logic that doesn't require a server.""" + +from __future__ import annotations + +import copy +import json + +import pytest + +from extract_schema import extract_schema +from run_workflow import ( + ComfyRunner, + download_outputs, + inject_params, + parse_input_image_arg, +) + + +class TestParseInputImageArg: + def test_with_name(self, tmp_path): + f = tmp_path / "x.png" + 
f.write_text("x") + n, p = parse_input_image_arg(f"image={f}") + assert n == "image" + assert p == f + + def test_without_name_defaults(self, tmp_path): + f = tmp_path / "x.png" + f.write_text("x") + n, p = parse_input_image_arg(str(f)) + assert n == "image" + + def test_custom_name(self, tmp_path): + f = tmp_path / "x.png" + f.write_text("x") + n, p = parse_input_image_arg(f"mask_image={f}") + assert n == "mask_image" + + +class TestInjectParams: + def test_basic_injection(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + wf, warnings = inject_params(sd15_workflow, schema, { + "prompt": "new prompt", + "seed": 999, + "steps": 25, + }) + assert wf["6"]["inputs"]["text"] == "new prompt" + assert wf["3"]["inputs"]["seed"] == 999 + assert wf["3"]["inputs"]["steps"] == 25 + assert warnings == [] + + def test_unknown_param_warns(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + _, warnings = inject_params(sd15_workflow, schema, {"foobar": "x"}) + assert any("foobar" in w for w in warnings) + + def test_seed_minus_one_randomizes(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + wf, warnings = inject_params(sd15_workflow, schema, {"seed": -1}) + assert wf["3"]["inputs"]["seed"] != -1 + assert isinstance(wf["3"]["inputs"]["seed"], int) + assert any("expanded" in w.lower() for w in warnings) + + def test_randomize_seed_when_unset(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + original = sd15_workflow["3"]["inputs"]["seed"] + wf, warnings = inject_params(sd15_workflow, schema, {}, randomize_seed_if_unset=True) + assert wf["3"]["inputs"]["seed"] != original + assert isinstance(wf["3"]["inputs"]["seed"], int) + + def test_does_not_mutate_original(self, sd15_workflow): + schema = extract_schema(sd15_workflow) + original_text = sd15_workflow["6"]["inputs"]["text"] + inject_params(sd15_workflow, schema, {"prompt": "MUTATED"}) + assert sd15_workflow["6"]["inputs"]["text"] == original_text + + def 
test_refuses_to_overwrite_link(self): + wf = { + "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "x"}}, + "5": {"class_type": "EmptyLatentImage", + "inputs": {"width": 512, "height": 512, "batch_size": 1}}, + "6": {"class_type": "CLIPTextEncode", + "inputs": {"text": ["3", 0], "clip": ["1", 1]}}, # text is a link! + "3": {"class_type": "KSampler", + "inputs": {"seed": 1, "steps": 20, "cfg": 7.5, + "sampler_name": "euler", "scheduler": "normal", "denoise": 1.0, + "model": ["1", 0], "positive": ["6", 0], "negative": ["6", 0], + "latent_image": ["5", 0]}}, + "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "x", "images": ["3", 0]}}, + } + # Manually create a schema that has prompt pointing at 6.text + schema = { + "parameters": { + "prompt": {"node_id": "6", "field": "text", "type": "string", "value": ""}, + } + } + wf2, warnings = inject_params(wf, schema, {"prompt": "literal value"}) + # The link should NOT have been overwritten + assert wf2["6"]["inputs"]["text"] == ["3", 0] + assert any("link" in w.lower() for w in warnings) + + +# ============================================================================= +# Output download walk +# ============================================================================= + +class TestDownloadOutputsWalk: + """Test that download_outputs walks the structure correctly.""" + + def test_handles_videos_plural(self, tmp_path, monkeypatch): + """Local ComfyUI uses 'videos'/'gifs' (plural) keys.""" + downloads = [] + + class FakeRunner: + def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite): + downloads.append((filename, subfolder, file_type)) + p = output_dir / filename + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(b"x") + return p + + outputs = { + "9": {"images": [{"filename": "img1.png", "subfolder": "", "type": "output"}]}, + "10": {"videos": [{"filename": "vid1.mp4", "subfolder": "", "type": "output"}]}, + "11": {"gifs": 
[{"filename": "anim1.gif", "subfolder": "", "type": "output"}]}, + } + + result = download_outputs(FakeRunner(), outputs, tmp_path) + files = sorted(d["filename"] for d in result) + assert files == ["anim1.gif", "img1.png", "vid1.mp4"] + + def test_handles_video_singular_cloud(self, tmp_path): + """Cloud uses 'video' (singular).""" + class FakeRunner: + def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite): + p = output_dir / filename + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(b"x") + return p + + outputs = { + "10": {"video": [{"filename": "cloud.mp4", "subfolder": "", "type": "output"}]}, + } + result = download_outputs(FakeRunner(), outputs, tmp_path) + assert len(result) == 1 + assert result[0]["filename"] == "cloud.mp4" + + def test_preserves_subfolder(self, tmp_path): + """When preserve_subfolder=True, server subfolder becomes local subdir.""" + class FakeRunner: + def download_output(self, *, filename, subfolder, file_type, output_dir, preserve_subfolder, overwrite): + if preserve_subfolder and subfolder: + p = output_dir / subfolder / filename + else: + p = output_dir / filename + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(b"x") + return p + + outputs = { + "9": {"images": [ + {"filename": "img.png", "subfolder": "myrun", "type": "output"}, + {"filename": "img.png", "subfolder": "otherrun", "type": "output"}, + ]}, + } + result = download_outputs(FakeRunner(), outputs, tmp_path, preserve_subfolder=True) + files = [d["file"] for d in result] + assert any("myrun" in f for f in files) + assert any("otherrun" in f for f in files) + # Both must exist (no collision) + assert len({str(f) for f in files}) == 2 + + +# ============================================================================= +# ComfyRunner construction +# ============================================================================= + +class TestRunnerConstruction: + def test_local_default(self): + r = 
ComfyRunner() + assert r.is_cloud is False + assert r.host == "http://127.0.0.1:8188" + + def test_cloud_detection(self): + r = ComfyRunner(host="https://cloud.comfy.org", api_key="abc") + assert r.is_cloud is True + assert "X-API-Key" in r.headers + + def test_cloud_subdomain_detected(self): + r = ComfyRunner(host="https://staging.cloud.comfy.org", api_key="abc") + assert r.is_cloud is True + + def test_partner_key_does_not_pollute_extra_data(self): + r = ComfyRunner(host="https://cloud.comfy.org", api_key="auth-key") + # No partner-key set → no extra_data should appear in submitted prompt + # (This is a static check; runtime check happens in submit()) + assert r.partner_key is None + + def test_url_routing_local(self): + r = ComfyRunner() + url = r._url("/prompt") + assert url == "http://127.0.0.1:8188/prompt" + + def test_url_routing_cloud(self): + r = ComfyRunner(host="https://cloud.comfy.org", api_key="x") + url = r._url("/prompt") + assert url == "https://cloud.comfy.org/api/prompt" + + def test_url_routing_cloud_history_renamed(self): + r = ComfyRunner(host="https://cloud.comfy.org", api_key="x") + url = r._url("/history/abc-123") + assert url == "https://cloud.comfy.org/api/history_v2/abc-123" diff --git a/skills/creative/comfyui/workflows/README.md b/skills/creative/comfyui/workflows/README.md new file mode 100644 index 0000000000..f3f40c2f2d --- /dev/null +++ b/skills/creative/comfyui/workflows/README.md @@ -0,0 +1,86 @@ +# Example Workflows + +These are starter API-format workflows for the most common tasks. They're +ready to run with `scripts/run_workflow.py` once you've installed (or have +cloud access to) the listed models. + +| File | Purpose | Required models | Min VRAM | +|------|---------|-----------------|----------| +| `sd15_txt2img.json` | SD 1.5 text-to-image (512×512) | SD1.5 checkpoint, e.g. 
`v1-5-pruned-emaonly.safetensors` | 4 GB | +| `sdxl_txt2img.json` | SDXL text-to-image (1024×1024) | `sd_xl_base_1.0.safetensors` | 8 GB | +| `flux_dev_txt2img.json` | Flux Dev text-to-image (1024×1024) | `flux1-dev.safetensors`, `t5xxl_fp16.safetensors`, `clip_l.safetensors`, `ae.safetensors` | 24 GB (or use `flux1-dev-fp8`) | +| `sdxl_img2img.json` | SDXL image-to-image | SDXL checkpoint | 8 GB | +| `sdxl_inpaint.json` | SDXL inpainting (image + mask) | SDXL checkpoint | 8 GB | +| `upscale_4x.json` | Standalone 4× ESRGAN upscale | `4x-UltraSharp.pth` (or any upscaler) | 4 GB | +| `animatediff_video.json` | AnimateDiff text-to-video (16 frames) | SD1.5 checkpoint, `mm_sd_v15_v2.ckpt` motion module | 8 GB | +| `wan_video_t2v.json` | Wan 2.x text-to-video (~33 frames) | `wan2.1_t2v_1.3B_fp16.safetensors`, `umt5_xxl_fp16.safetensors`, `wan_2.1_vae.safetensors` | 24 GB | + +## Quick start + +```bash +# Run a workflow with prompt injection +python3 ../scripts/run_workflow.py \ + --workflow sdxl_txt2img.json \ + --args '{"prompt": "majestic eagle in flight", "seed": 12345, "steps": 35}' \ + --output-dir ./out + +# Img2img: upload an input image first via the script's helper +python3 ../scripts/run_workflow.py \ + --workflow sdxl_img2img.json \ + --input-image image=./photo.png \ + --args '{"prompt": "make it watercolor", "denoise": 0.6}' \ + --output-dir ./out + +# Cloud (set API key once) +export COMFY_CLOUD_API_KEY="comfyui-..." +python3 ../scripts/run_workflow.py \ + --workflow flux_dev_txt2img.json \ + --args '{"prompt": "a fox in a misty forest"}' \ + --host https://cloud.comfy.org \ + --output-dir ./out + +# What can I tweak in this workflow? +python3 ../scripts/extract_schema.py sdxl_txt2img.json --summary-only + +# Are all required models / nodes installed? +python3 ../scripts/check_deps.py wan_video_t2v.json +``` + +## Notes + +- **Inpaint masks**: white pixels = "regenerate this region", black = preserve. 
+ The bundled `sdxl_inpaint.json` sets `LoadImageMask` to read the **red channel**; export your + mask as a single-channel image or as a normal RGB where red==intensity. + +- **Denoise strength** in img2img: `0.0` = output identical to input, + `1.0` = ignore input entirely. Sweet spot is usually 0.4–0.7. + +- **Flux Dev** needs ~24 GB VRAM in its base form. The `flux1-dev-fp8.safetensors` + variant (already on Comfy Cloud) cuts that roughly in half. + +- **Video workflows** can take many minutes. The skill auto-detects video + output nodes and bumps the default timeout to 900s. Override with `--timeout 1800`. + +- These JSON files are deliberately **API format** (top-level keys are node IDs + with `class_type`), not editor format. To open them in ComfyUI's web UI for + visual editing, use `Workflow → Load (API Format)` or `Workflow → Open` and + follow the prompt. + +## Cloud vs local model names + +Comfy Cloud's preinstalled checkpoints sometimes have a `-fp16` suffix +(`v1-5-pruned-emaonly-fp16.safetensors`) while the canonical local download +keeps the original name (`v1-5-pruned-emaonly.safetensors`). The example +workflows use the local-canonical names. When running on cloud, override with: + +```bash +python3 ../scripts/run_workflow.py \ + --workflow sd15_txt2img.json \ + --args '{"ckpt_name": "v1-5-pruned-emaonly-fp16.safetensors", "prompt": "..."}' \ + --host https://cloud.comfy.org +``` + +The `ckpt_name`, `vae_name`, `lora_name`, `unet_name`, etc. are all exposed +as controllable parameters by `extract_schema.py` — discover what's installed +with `comfy model list` (local) or `curl /api/experiment/models/checkpoints` +(cloud). diff --git a/skills/creative/comfyui/workflows/animatediff_video.json b/skills/creative/comfyui/workflows/animatediff_video.json new file mode 100644 index 0000000000..cc2b296c3a --- /dev/null +++ b/skills/creative/comfyui/workflows/animatediff_video.json @@ -0,0 +1,64 @@ +{ + "_comment": "AnimateDiff text-to-video at 16 frames. 
Required: comfyui-animatediff-evolved + comfyui-videohelpersuite custom nodes; SD1.5 checkpoint; AnimateDiff motion module (e.g. mm_sd_v15_v2.ckpt in models/animatediff_models/). Outputs an H.264 MP4 video.", + "3": { + "class_type": "KSampler", + "_meta": {"title": "KSampler"}, + "inputs": { + "seed": 42, "steps": 25, "cfg": 7.5, + "sampler_name": "dpmpp_sde", "scheduler": "karras", "denoise": 1.0, + "model": ["10", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0] + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "_meta": {"title": "Checkpoint"}, + "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"} + }, + "5": { + "class_type": "EmptyLatentImage", + "_meta": {"title": "Latent (16 frames)"}, + "inputs": {"width": 512, "height": 512, "batch_size": 16} + }, + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Positive Prompt"}, + "inputs": {"text": "a hot air balloon drifting over a mountain valley, sunset, cinematic", "clip": ["4", 1]} + }, + "7": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Negative Prompt"}, + "inputs": {"text": "low quality, blurry, deformed, watermark", "clip": ["4", 1]} + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["3", 0], "vae": ["4", 2]} + }, + "9": { + "class_type": "VHS_VideoCombine", + "_meta": {"title": "Video Combine"}, + "inputs": { + "frame_rate": 8.0, + "loop_count": 0, + "filename_prefix": "animatediff", + "format": "video/h264-mp4", + "pingpong": false, + "save_output": true, + "images": ["8", 0] + } + }, + "10": { + "class_type": "ADE_AnimateDiffLoaderWithContext", + "_meta": {"title": "AnimateDiff Loader"}, + "inputs": { + "model": ["4", 0], + "model_name": "mm_sd_v15_v2.ckpt", + "beta_schedule": "sqrt_linear (AnimateDiff)", + "motion_scale": 1.0, + "apply_v2_models_properly": true + } + } +} diff --git a/skills/creative/comfyui/workflows/flux_dev_txt2img.json 
b/skills/creative/comfyui/workflows/flux_dev_txt2img.json new file mode 100644 index 0000000000..1791280be2 --- /dev/null +++ b/skills/creative/comfyui/workflows/flux_dev_txt2img.json @@ -0,0 +1,78 @@ +{ + "_comment": "Flux Dev text-to-image using the modern sampler chain (BasicScheduler/Guider/SamplerCustomAdvanced). Required: flux1-dev.safetensors (UNET), t5xxl_fp16.safetensors + clip_l.safetensors (CLIP), ae.safetensors (VAE).", + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Prompt"}, + "inputs": {"text": "a serene mountain landscape at golden hour, photorealistic", "clip": ["11", 0]} + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["13", 0], "vae": ["10", 0]} + }, + "9": { + "class_type": "SaveImage", + "_meta": {"title": "Save Image"}, + "inputs": {"filename_prefix": "flux_dev", "images": ["8", 0]} + }, + "10": { + "class_type": "VAELoader", + "_meta": {"title": "VAE"}, + "inputs": {"vae_name": "ae.safetensors"} + }, + "11": { + "class_type": "DualCLIPLoader", + "_meta": {"title": "DualCLIPLoader"}, + "inputs": { + "clip_name1": "t5xxl_fp16.safetensors", + "clip_name2": "clip_l.safetensors", + "type": "flux" + } + }, + "12": { + "class_type": "UNETLoader", + "_meta": {"title": "UNET Loader"}, + "inputs": {"unet_name": "flux1-dev.safetensors", "weight_dtype": "default"} + }, + "13": { + "class_type": "SamplerCustomAdvanced", + "_meta": {"title": "Sampler Custom"}, + "inputs": { + "noise": ["25", 0], + "guider": ["22", 0], + "sampler": ["16", 0], + "sigmas": ["17", 0], + "latent_image": ["27", 0] + } + }, + "16": { + "class_type": "KSamplerSelect", + "_meta": {"title": "Sampler Select"}, + "inputs": {"sampler_name": "euler"} + }, + "17": { + "class_type": "BasicScheduler", + "_meta": {"title": "Scheduler"}, + "inputs": { + "scheduler": "simple", + "steps": 20, + "denoise": 1.0, + "model": ["12", 0] + } + }, + "22": { + "class_type": "BasicGuider", + "_meta": {"title": "Guider"}, + "inputs": 
{"model": ["12", 0], "conditioning": ["6", 0]} + }, + "25": { + "class_type": "RandomNoise", + "_meta": {"title": "Noise"}, + "inputs": {"noise_seed": 42} + }, + "27": { + "class_type": "EmptySD3LatentImage", + "_meta": {"title": "Latent"}, + "inputs": {"width": 1024, "height": 1024, "batch_size": 1} + } +} diff --git a/skills/creative/comfyui/workflows/sd15_txt2img.json b/skills/creative/comfyui/workflows/sd15_txt2img.json new file mode 100644 index 0000000000..f67eb79f54 --- /dev/null +++ b/skills/creative/comfyui/workflows/sd15_txt2img.json @@ -0,0 +1,49 @@ +{ + "_comment": "SD 1.5 text-to-image. Smallest model, fastest. Required model: v1-5-pruned-emaonly.safetensors (or any SD1.5 checkpoint)", + "3": { + "class_type": "KSampler", + "_meta": {"title": "KSampler"}, + "inputs": { + "seed": 156680208700286, + "steps": 20, + "cfg": 8.0, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0] + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "_meta": {"title": "Load Checkpoint"}, + "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"} + }, + "5": { + "class_type": "EmptyLatentImage", + "_meta": {"title": "Empty Latent"}, + "inputs": {"width": 512, "height": 512, "batch_size": 1} + }, + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Positive Prompt"}, + "inputs": {"text": "a beautiful landscape painting, masterpiece, highly detailed", "clip": ["4", 1]} + }, + "7": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Negative Prompt"}, + "inputs": {"text": "ugly, blurry, low quality, deformed", "clip": ["4", 1]} + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["3", 0], "vae": ["4", 2]} + }, + "9": { + "class_type": "SaveImage", + "_meta": {"title": "Save Image"}, + "inputs": {"filename_prefix": "sd15", "images": ["8", 0]} + } +} diff --git 
a/skills/creative/comfyui/workflows/sdxl_img2img.json b/skills/creative/comfyui/workflows/sdxl_img2img.json new file mode 100644 index 0000000000..a835567aaa --- /dev/null +++ b/skills/creative/comfyui/workflows/sdxl_img2img.json @@ -0,0 +1,54 @@ +{ + "_comment": "SDXL img2img: load an input image, encode to latent, denoise partially. Use --input-image image=./photo.png with run_workflow.py. Lower 'denoise' value preserves more of the source image.", + "1": { + "class_type": "LoadImage", + "_meta": {"title": "Load Source Image"}, + "inputs": {"image": "REPLACE_WITH_UPLOADED_FILENAME.png"} + }, + "3": { + "class_type": "KSampler", + "_meta": {"title": "KSampler"}, + "inputs": { + "seed": 42, + "steps": 30, + "cfg": 7.5, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 0.65, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["12", 0] + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "_meta": {"title": "Load SDXL Base"}, + "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"} + }, + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Positive Prompt"}, + "inputs": {"text": "make it cyberpunk, neon lights, futuristic", "clip": ["4", 1]} + }, + "7": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Negative Prompt"}, + "inputs": {"text": "ugly, blurry, low quality, deformed", "clip": ["4", 1]} + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["3", 0], "vae": ["4", 2]} + }, + "9": { + "class_type": "SaveImage", + "_meta": {"title": "Save Image"}, + "inputs": {"filename_prefix": "sdxl_img2img", "images": ["8", 0]} + }, + "12": { + "class_type": "VAEEncode", + "_meta": {"title": "VAE Encode"}, + "inputs": {"pixels": ["1", 0], "vae": ["4", 2]} + } +} diff --git a/skills/creative/comfyui/workflows/sdxl_inpaint.json b/skills/creative/comfyui/workflows/sdxl_inpaint.json new file mode 100644 index 0000000000..20e50ccf1b --- /dev/null +++ 
b/skills/creative/comfyui/workflows/sdxl_inpaint.json @@ -0,0 +1,59 @@ +{ + "_comment": "SDXL inpainting: given an image + mask, regenerate the masked region. Upload both: --input-image image=./photo.png --input-image mask_image=./mask.png. White pixels in mask = regenerate; black = preserve.", + "1": { + "class_type": "LoadImage", + "_meta": {"title": "Load Source"}, + "inputs": {"image": "REPLACE_WITH_UPLOADED_FILENAME.png"} + }, + "2": { + "class_type": "LoadImageMask", + "_meta": {"title": "Load Mask"}, + "inputs": {"image": "REPLACE_WITH_UPLOADED_MASK.png", "channel": "red"} + }, + "3": { + "class_type": "KSampler", + "_meta": {"title": "KSampler"}, + "inputs": { + "seed": 42, + "steps": 30, + "cfg": 7.5, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["12", 0] + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "_meta": {"title": "Checkpoint"}, + "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"} + }, + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Positive Prompt"}, + "inputs": {"text": "fill with blooming flowers, photorealistic", "clip": ["4", 1]} + }, + "7": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Negative Prompt"}, + "inputs": {"text": "ugly, blurry, deformed, bad anatomy", "clip": ["4", 1]} + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["3", 0], "vae": ["4", 2]} + }, + "9": { + "class_type": "SaveImage", + "_meta": {"title": "Save"}, + "inputs": {"filename_prefix": "sdxl_inpaint", "images": ["8", 0]} + }, + "12": { + "class_type": "VAEEncodeForInpaint", + "_meta": {"title": "VAE Encode for Inpaint"}, + "inputs": {"pixels": ["1", 0], "mask": ["2", 0], "vae": ["4", 2], "grow_mask_by": 6} + } +} diff --git a/skills/creative/comfyui/workflows/sdxl_txt2img.json b/skills/creative/comfyui/workflows/sdxl_txt2img.json new file mode 100644 index 
0000000000..cb590b40f9 --- /dev/null +++ b/skills/creative/comfyui/workflows/sdxl_txt2img.json @@ -0,0 +1,49 @@ +{ + "_comment": "SDXL text-to-image at 1024x1024. Required model: sd_xl_base_1.0.safetensors (or any SDXL checkpoint).", + "3": { + "class_type": "KSampler", + "_meta": {"title": "KSampler"}, + "inputs": { + "seed": 42, + "steps": 30, + "cfg": 7.5, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, + "model": ["4", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0] + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "_meta": {"title": "Load SDXL Base"}, + "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"} + }, + "5": { + "class_type": "EmptyLatentImage", + "_meta": {"title": "Empty Latent"}, + "inputs": {"width": 1024, "height": 1024, "batch_size": 1} + }, + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Positive Prompt"}, + "inputs": {"text": "cinematic photograph, dramatic lighting, intricate detail", "clip": ["4", 1]} + }, + "7": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Negative Prompt"}, + "inputs": {"text": "ugly, blurry, low quality, deformed, watermark", "clip": ["4", 1]} + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["3", 0], "vae": ["4", 2]} + }, + "9": { + "class_type": "SaveImage", + "_meta": {"title": "Save Image"}, + "inputs": {"filename_prefix": "sdxl", "images": ["8", 0]} + } +} diff --git a/skills/creative/comfyui/workflows/upscale_4x.json b/skills/creative/comfyui/workflows/upscale_4x.json new file mode 100644 index 0000000000..91ad7eb1de --- /dev/null +++ b/skills/creative/comfyui/workflows/upscale_4x.json @@ -0,0 +1,27 @@ +{ + "_comment": "Standalone 4x upscale of an input image using ESRGAN. Required model: 4x-UltraSharp.pth (or any upscaler in models/upscale_models/). 
Upload with --input-image image=./photo.png.", + "1": { + "class_type": "LoadImage", + "_meta": {"title": "Load Image"}, + "inputs": {"image": "REPLACE_WITH_UPLOADED_FILENAME.png"} + }, + "2": { + "class_type": "UpscaleModelLoader", + "_meta": {"title": "Load Upscale Model"}, + "inputs": {"model_name": "4x-UltraSharp.pth"} + }, + "3": { + "class_type": "ImageUpscaleWithModel", + "_meta": {"title": "Upscale Image (with Model)"}, + "inputs": { + "upscale_method": "lanczos", + "upscale_model": ["2", 0], + "image": ["1", 0] + } + }, + "4": { + "class_type": "SaveImage", + "_meta": {"title": "Save"}, + "inputs": {"filename_prefix": "upscaled_4x", "images": ["3", 0]} + } +} diff --git a/skills/creative/comfyui/workflows/wan_video_t2v.json b/skills/creative/comfyui/workflows/wan_video_t2v.json new file mode 100644 index 0000000000..7514e3a627 --- /dev/null +++ b/skills/creative/comfyui/workflows/wan_video_t2v.json @@ -0,0 +1,69 @@ +{ + "_comment": "Wan 2.1 text-to-video. Cloud: confirmed available. Local: download wan2.1_t2v_1.3B_fp16.safetensors → models/diffusion_models/ (or models/unet/), umt5_xxl_fp16.safetensors → models/text_encoders/ (or models/clip/), wan_2.1_vae.safetensors → models/vae/. Output: MP4. 
Large model — only on cloud or 24 GB+ local GPU.", + "6": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Prompt"}, + "inputs": { + "text": "a graceful crane taking flight from a misty lake at dawn, slow motion, 4k", + "clip": ["38", 0] + } + }, + "7": { + "class_type": "CLIPTextEncode", + "_meta": {"title": "Negative Prompt"}, + "inputs": { + "text": "static, blurry, watermark, low quality", + "clip": ["38", 0] + } + }, + "8": { + "class_type": "VAEDecode", + "_meta": {"title": "VAE Decode"}, + "inputs": {"samples": ["3", 0], "vae": ["39", 0]} + }, + "37": { + "class_type": "UNETLoader", + "_meta": {"title": "Wan UNET"}, + "inputs": {"unet_name": "wan2.1_t2v_1.3B_fp16.safetensors", "weight_dtype": "default"} + }, + "38": { + "class_type": "CLIPLoader", + "_meta": {"title": "Wan CLIP"}, + "inputs": {"clip_name": "umt5_xxl_fp16.safetensors", "type": "wan"} + }, + "39": { + "class_type": "VAELoader", + "_meta": {"title": "Wan VAE"}, + "inputs": {"vae_name": "wan_2.1_vae.safetensors"} + }, + "3": { + "class_type": "KSampler", + "_meta": {"title": "KSampler"}, + "inputs": { + "seed": 42, "steps": 30, "cfg": 6.0, + "sampler_name": "uni_pc", "scheduler": "simple", "denoise": 1.0, + "model": ["37", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["40", 0] + } + }, + "40": { + "class_type": "EmptyHunyuanLatentVideo", + "_meta": {"title": "Latent Video (33 frames)"}, + "inputs": {"width": 832, "height": 480, "length": 33, "batch_size": 1} + }, + "9": { + "class_type": "VHS_VideoCombine", + "_meta": {"title": "Video Combine"}, + "inputs": { + "frame_rate": 16.0, + "loop_count": 0, + "filename_prefix": "wan_t2v", + "format": "video/h264-mp4", + "pingpong": false, + "save_output": true, + "images": ["8", 0] + } + } +} diff --git a/skills/creative/pretext/SKILL.md b/skills/creative/pretext/SKILL.md new file mode 100644 index 0000000000..429dd8798f --- /dev/null +++ b/skills/creative/pretext/SKILL.md @@ -0,0 +1,219 @@ +--- +name: pretext 
+description: "Use when building creative browser demos with @chenglou/pretext — DOM-free text layout for ASCII art, typographic flow around obstacles, text-as-geometry games, kinetic typography, and text-powered generative art. Produces single-file HTML demos by default." +version: 1.0.0 +author: Hermes Agent +license: MIT +metadata: + hermes: + tags: [creative-coding, typography, pretext, ascii-art, canvas, generative, text-layout, kinetic-typography] + related_skills: [p5js, claude-design, excalidraw, architecture-diagram] +--- + +# Pretext Creative Demos + +## Overview + +[`@chenglou/pretext`](https://github.com/chenglou/pretext) is a 15KB zero-dependency TypeScript library by Cheng Lou (React core, ReasonML, Midjourney) for **DOM-free multiline text measurement and layout**. It does one thing: given `(text, font, width)`, return the line breaks, per-line widths, per-grapheme positions, and total height — all via canvas measurement, no reflow. + +That sounds like plumbing. It is not. Because it is fast and geometric, it is a **creative primitive**: you can reflow paragraphs around a moving sprite at 60fps, build games whose level geometry is made of real words, drive ASCII logos through prose, shatter text into particles with exact per-grapheme starting positions, or pack shrink-wrapped multiline UI without any `getBoundingClientRect` thrash. + +This skill exists so Hermes can make **cool demos** with it — the kind people post to X. See `pretext.cool` and `chenglou.me/pretext` for the community demo corpus. 
+ +## When to Use + +Use when the user asks for: +- A "pretext demo" / "cool pretext thing" / "text-as-X" +- Text flowing around a moving shape (hero sections, editorial layouts, animated long-form pages) +- ASCII-art effects using **real words or prose**, not monospace rasters +- Games where the playfield / obstacles / bricks are made of text (Tetris-from-letters, Breakout-of-prose) +- Kinetic typography with per-glyph physics (shatter, scatter, flock, flow) +- Typographic generative art, especially with non-Latin scripts or mixed scripts +- Multiline "shrink-wrap" UI (smallest container width that still fits the text) +- Anything that would require knowing line breaks *before* rendering + +Don't use for: +- Static SVG/HTML pages where CSS already solves layout — just use CSS +- Rich text editors, general inline formatting engines (pretext is intentionally narrow) +- Image → text (use `ascii-art` / `ascii-video` skills) +- Pure canvas generative art with no text role — use `p5js` + +## Creative Standard + +This is visual art rendered in a browser. Pretext returns numbers; **you** draw the thing. + +- **Don't ship a "hello world" demo.** The `hello-orb-flow.html` template is the *starting* point. Every delivered demo must add intentional color, motion, composition, and one visual detail the user didn't ask for but will appreciate. +- **Dark backgrounds, warm cores, considered palette.** Classic amber-on-black (CRT / terminal) works, but so do cold-white-on-charcoal (editorial) and desaturated pastels (risograph). Pick one and commit. +- **Proportional fonts are the point.** Pretext's whole vibe is "not monospaced" — lean into it. Use Iowan Old Style, Inter, JetBrains Mono, Helvetica Neue, or a variable font. Never default sans. +- **Real source/text, not lorem ipsum.** The corpus should mean something. Short manifestos, poetry, real source code, a found text, the library's own README — never `lorem ipsum`. 
+- **First-paint excellence.** No loading states, no blank frames. The demo must look shippable the instant it opens. + +## Stack + +Single self-contained HTML file per demo. No build step. + +| Layer | Tool | Purpose | +|-------|------|---------| +| Core | `@chenglou/pretext` via `esm.sh` CDN | Text measurement + line layout | +| Render | HTML5 Canvas 2D | Glyph rendering, per-frame composition | +| Segmentation | `Intl.Segmenter` (built-in) | Grapheme splitting for emoji / CJK / combining marks | +| Interaction | Raw DOM events | Mouse / touch / wheel — no framework | + +```html +<script type="module"> + import { prepare, layout, prepareWithSegments, layoutWithLines, layoutNextLineRange, materializeLineRange } from "https://esm.sh/@chenglou/pretext@0.0.6"; +</script> +``` + +Pin the version. `@0.0.6` at time of writing — check [npm](https://www.npmjs.com/package/@chenglou/pretext) for the latest if demo behavior is off. + +## The Two Use Cases + +Almost everything reduces to one of these two shapes. Learn both. + +### Use-case 1 — measure, then render with CSS/DOM + +```js +const prepared = prepare(text, "16px Inter"); +const { height, lineCount } = layout(prepared, 320, 20); +``` + +You still let the browser draw the text. Pretext just tells you how tall the box will be at a given width, **without** a DOM read. Use for: +- Virtualized lists where rows contain wrapping text +- Masonry with precise card heights +- "Does this label fit?" dev-time checks +- Preventing layout shift when remote text loads + +**Keep `font` and `letterSpacing` exactly in sync with your CSS.** The canvas `ctx.font` format (e.g. `"16px Inter"`, `"500 17px 'JetBrains Mono'"`) must match the rendered CSS, or measurements drift. + +### Use-case 2 — measure *and* render yourself + +```js +const prepared = prepareWithSegments(text, FONT); +const { lines } = layoutWithLines(prepared, 320, 26); +for (let i = 0; i < lines.length; i++) { + ctx.fillText(lines[i].text, 0, i * 26); +} +``` + +This is where the creative work lives. 
You own the drawing, so you can: +- Render to canvas, SVG, WebGL, or any coordinate system +- Substitute per-glyph transforms (rotation, jitter, scale, opacity) +- Use line metadata (width, grapheme positions) as geometry + +For **variable-width-per-line** flow (text around a shape, text in a donut band, text in a non-rectangular column): + +```js +let cursor = { segmentIndex: 0, graphemeIndex: 0 }; +let y = 0; +while (true) { + const lineWidth = widthAtY(y); // your function: how wide is the corridor at this y? + const range = layoutNextLineRange(prepared, cursor, lineWidth); + if (!range) break; + const line = materializeLineRange(prepared, range); + ctx.fillText(line.text, leftEdgeAtY(y), y); + cursor = range.end; + y += lineHeight; +} +``` + +This is the most important pattern in the whole library. It's what unlocks "text flowing around a dragged sprite" — the demo that went viral on X. + +### Helpers worth knowing + +- `measureLineStats(prepared, maxWidth)` → `{ lineCount, maxLineWidth }` — the widest line, i.e. multiline shrink-wrap width. +- `walkLineRanges(prepared, maxWidth, callback)` — iterate lines without allocating strings. Use for stats/physics over graphemes when you don't need the characters. +- `@chenglou/pretext/rich-inline` — the same system but for paragraphs mixing fonts / chips / mentions. Import from the subpath. + +## Demo Recipe Patterns + +The community corpus (see `references/patterns.md`) clusters into a handful of strong patterns. Pick one and riff — don't invent a new category unless asked. 
+ +| Pattern | Key API | Example idea | +|---|---|---| +| **Reflow around obstacle** | `layoutNextLineRange` + per-row width function | Editorial paragraph that parts around a dragged cursor sprite | +| **Text-as-geometry game** | `layoutWithLines` + per-line collision rects | Breakout where each brick is a measured word | +| **Shatter / particles** | `walkLineRanges` → per-grapheme (x,y) → physics | Sentence that explodes into letters on click | +| **ASCII obstacle typography** | `layoutNextLineRange` + measured per-row obstacle spans | Bitmap ASCII logo, shape morphs, and draggable wire objects that make text open around their actual geometry | +| **Editorial multi-column** | `layoutNextLineRange` per column + shared cursor | Animated magazine spread with pull quotes | +| **Kinetic type** | `layoutWithLines` + per-line transform over time | Star Wars crawl, wave, bounce, glitch | +| **Multiline shrink-wrap** | `measureLineStats` | Quote card that auto-sizes to its tightest container | + +See `templates/donut-orbit.html` and `templates/hello-orb-flow.html` for working single-file starters. + +## Workflow + +1. **Pick a pattern** from the table above based on the user's brief. +2. **Start from a template**: + - `templates/hello-orb-flow.html` — text reflowing around a moving orb (reflow-around-obstacle pattern) + - `templates/donut-orbit.html` — advanced example: measured ASCII logo obstacles, draggable wire sphere/cube, morphing shape fields, selectable DOM text, and dev-only controls + - `write_file` to a new `<name>.html` in `/tmp/` or the user's workspace. +3. **Swap the corpus** for something intentional to the brief. Real prose, 10-100 sentences, no lorem. +4. **Tune the aesthetic** — font, palette, composition, interaction. This is the work; don't skip it. +5. **Verify locally**: + ```sh + cd <demo-dir> && python3 -m http.server 8765 + # then open http://localhost:8765/<name>.html + ``` +6. 
**Check the console** — pretext will throw if `prepareWithSegments` is called with a bad font string; `Intl.Segmenter` is available in every modern browser. +7. **Show the user the file path**, not just the code — they want to open it. + +## Performance Notes + +- `prepare()` / `prepareWithSegments()` is the expensive call. Do it **once** per text+font pair. Cache the handle. +- On resize, only rerun `layout()` / `layoutWithLines()` — never re-prepare. +- For per-frame animations where text doesn't change but geometry does, `layoutNextLineRange` in a tight loop is cheap enough to do every frame at 60fps for normal-length paragraphs. +- When rendering ASCII masks per frame, keep a cell buffer (`Uint8Array`/typed arrays), derive measured per-row obstacle spans from the cells or projected geometry, merge spans, then feed those spans into `layoutNextLineRange` before drawing text. +- Keep visual animation and layout animation coupled. If a sphere morphs into a cube, tween both the rendered cell buffer and the obstacle spans with the same value; otherwise the demo looks painted-on instead of physically reflowed. +- For fades, prefer layer opacity over changing glyph intensity or obstacle scale. Put transient ASCII sprites on their own canvas and fade the canvas with CSS/GSAP opacity so geometry does not appear to shrink. +- Canvas `ctx.font` setting is surprisingly slow; set it **once** per frame if font doesn't vary, not per `fillText` call. + +## Common Pitfalls + +1. **Drifting CSS/canvas font strings.** `ctx.font = "16px Inter"` measured, but CSS says `font-family: Inter, sans-serif; font-size: 16px`. Fine *if* Inter loads. If Inter 404s, CSS falls back to sans-serif and measurements drift by 5-20%. Always `preload` the font or use a web-safe family. + +2. **Re-preparing inside the animation loop.** Only `layout*` is cheap. Re-calling `prepare` every frame will tank perf. Keep the prepared handle in module scope. + +3. 
**Forgetting `Intl.Segmenter` for grapheme splits.** Emoji, combining marks, CJK — `"é".split("")` gives you two chars. Use `new Intl.Segmenter(undefined, { granularity: "grapheme" })` when sampling individual visible glyphs. + +4. **`break: 'never'` chips without `extraWidth`.** In `rich-inline`, if you use `break: 'never'` for an atomic chip/mention, you must also supply `extraWidth` for the pill padding — otherwise chip chrome overflows the container. + +5. **Using `@chenglou/pretext` from `unpkg` with TypeScript-only entry.** Use `esm.sh` — it compiles the TS exports to browser-ready ESM automatically. `unpkg` will 404 or serve raw TS. + +6. **Monospace fallbacks silently erasing the whole point.** Users seeing monospace-looking output often have a CSS `font-family` that fell through to `monospace`. Verify the actual rendered font via DevTools. + +7. **Skipping rows vs adjusting width** when flowing around a shape. If the corridor on this row is too narrow to fit a line, *skip the row* (`y += lineHeight; continue;`) rather than passing a tiny maxWidth to `layoutNextLineRange` — pretext will return one-grapheme lines that look broken. + +8. **Shipping a cold demo.** The default first-paint looks tutorial-grade. Add: vignette, subtle scanline, idle auto-motion, one carefully chosen interactive response (drag, hover, scroll, click). Without these, "cool pretext demo" lands as "intern repro of the README." 
+ +## Verification Checklist + +- [ ] Demo is a single self-contained `.html` file — opens by double-click or `python3 -m http.server` +- [ ] `@chenglou/pretext` imported via `esm.sh` with pinned version +- [ ] Corpus is real prose, not lorem ipsum, and matches the demo's concept +- [ ] Font string passed to `prepare` matches the CSS font exactly +- [ ] `prepare()` / `prepareWithSegments()` called once, not per frame +- [ ] Dark background + considered palette — not the default white canvas +- [ ] At least one interactive response (drag / hover / scroll / click) or idle auto-motion +- [ ] Tested locally with `python3 -m http.server` and confirmed no console errors +- [ ] 60fps on a mid-tier laptop (or graceful degradation documented) +- [ ] One "extra mile" detail the user didn't ask for + +## Reference: Community Demos + +Clone these for inspiration / patterns (all MIT-ish, linked from [pretext.cool](https://www.pretext.cool/)): + +- **Pretext Breaker** — breakout with word-bricks — `github.com/rinesh/pretext-breaker` +- **Tetris × Pretext** — `github.com/shinichimochizuki/tetris-pretext` +- **Dragon animation** — `github.com/qtakmalay/PreTextExperiments` +- **Somnai editorial engine** — `github.com/somnai-dreams/pretext-demos` +- **Bad Apple!! ASCII** — `github.com/frmlinn/bad-apple-pretext` +- **Drag-sprite reflow** — `github.com/dokobot/pretext-demo` +- **Alarmy editorial clock** — `github.com/SmisLee/alarmy-pretext-demo` + +Official playground: [chenglou.me/pretext](https://chenglou.me/pretext/) — accordion, bubbles, dynamic-layout, editorial-engine, justification-comparison, masonry, markdown-chat, rich-note. diff --git a/skills/creative/pretext/references/patterns.md b/skills/creative/pretext/references/patterns.md new file mode 100644 index 0000000000..2fa867232d --- /dev/null +++ b/skills/creative/pretext/references/patterns.md @@ -0,0 +1,258 @@ +# Pretext Patterns + +Copy-pasteable snippets for the most common pretext demo shapes. 
Each pattern is self-contained — drop into an HTML ` + + diff --git a/skills/creative/pretext/templates/hello-orb-flow.html b/skills/creative/pretext/templates/hello-orb-flow.html new file mode 100644 index 0000000000..b7bdbca2f4 --- /dev/null +++ b/skills/creative/pretext/templates/hello-orb-flow.html @@ -0,0 +1,95 @@ + + + + +pretext hello — text flowing around an orb + + + + + + + diff --git a/skills/creative/sketch/SKILL.md b/skills/creative/sketch/SKILL.md new file mode 100644 index 0000000000..b84f143dd4 --- /dev/null +++ b/skills/creative/sketch/SKILL.md @@ -0,0 +1,217 @@ +--- +name: sketch +description: "Throwaway HTML mockups: 2-3 design variants to compare." +version: 1.0.0 +author: Hermes Agent (adapted from gsd-build/get-shit-done) +license: MIT +metadata: + hermes: + tags: [sketch, mockup, design, ui, prototype, html, variants, exploration, wireframe, comparison] + related_skills: [spike, claude-design, popular-web-designs, excalidraw] +--- + +# Sketch + +Use this skill when the user wants to **see a design direction before committing** to one — exploring a UI/UX idea as disposable HTML mockups. The point is to generate 2-3 interactive variants so the user can compare visual directions side-by-side, not to produce shippable code. + +Load this when the user says things like "sketch this screen", "show me what X could look like", "compare layout A vs B", "give me 2-3 takes on this UI", "let me see some variants", "mockup this before I build". 
+ +## When NOT to use this + +- User wants a production component — use `claude-design` or build it properly +- User wants a polished one-off HTML artifact (landing page, deck) — `claude-design` +- User wants a diagram — `excalidraw`, `architecture-diagram` +- The design is already locked — just build it + +## If the user has the full GSD system installed + +If `gsd-sketch` shows up as a sibling skill (installed via `npx get-shit-done-cc --hermes`), prefer **`gsd-sketch`** for the full workflow: persistent `.planning/sketches/` with MANIFEST, frontier mode analysis, consistency audits across past sketches, and integration with the rest of GSD. This skill is the lightweight standalone version — one-off sketching without the state machinery. + +## Core method + +``` +intake → variants → head-to-head → pick winner (or iterate) +``` + +### 1. Intake (skip if the user already gave you enough) + +Before generating variants, get three things — one question at a time, not all at once: + +1. **Feel.** "What should this feel like? Adjectives, emotions, a vibe." — *"calm, editorial, like Linear"* tells you more than *"minimal"*. +2. **References.** "What apps, sites, or products capture the feel you're imagining?" — actual references beat abstract descriptions. +3. **Core action.** "What's the single most important thing a user does on this screen?" — the variants should all serve this well; if they don't, they're just decoration. + +Reflect each answer briefly before the next question. If the user already gave you all three upfront, skip straight to variants. + +### 2. Variants (2-3, never 1, rarely 4+) + +Produce **2-3 variants** in one go. Each variant is a complete, standalone HTML file. Don't describe variants — build them. The point is comparison. + +Each variant should take a **different design stance**, not different pixel values. 
Good variant axes: + +- **Density:** compact / airy / ultra-dense (pick two contrasting poles) +- **Emphasis:** content-first / action-first / tool-first +- **Aesthetic:** editorial / utilitarian / playful +- **Layout:** single-column / sidebar / split-pane +- **Grounding:** card-based / bare-content / document-style + +Pick one axis and pull the variants apart along it. Two variants that differ only in accent color are wasted effort — the user can't distinguish them. + +**Variant naming:** describe the stance, not the number. + +``` +sketches/ +├── 001-calm-editorial/ +│ ├── index.html +│ └── README.md +├── 001-utilitarian-dense/ +│ ├── index.html +│ └── README.md +└── 001-playful-split/ + ├── index.html + └── README.md +``` + +### 3. Make them real HTML + +Each variant is a **single self-contained HTML file**: + +- Inline ` +``` + +### 4. Variant README + +Each variant's `README.md` answers: + +```markdown +## Variant: {stance name} + +### Design stance +One sentence on the principle driving this variant. + +### Key choices +- Layout: ... +- Typography: ... +- Color: ... +- Interaction: ... + +### Trade-offs +- Strong at: ... +- Weak at: ... + +### Best for +- The kind of user or use case this variant actually serves +``` + +### 5. Head-to-head + +After all variants are built, present them as a comparison. Don't just list — **opinionate**: + +```markdown +## Three takes on the home screen + +| Dimension | Calm editorial | Utilitarian dense | Playful split | +|-----------|----------------|-------------------|---------------| +| Density | Low | High | Medium | +| Primary action visibility | Low | High | Medium | +| Scan-ability | High | Medium | Low | +| Feel | Calm, trusted | Sharp, tool-like | Inviting, energetic | + +**My take:** Utilitarian dense for power users, calm editorial for content-forward audiences. Playful split is weakest — tries to do both and commits to neither. +``` + +Let the user pick a winner, or combine two into a hybrid, or ask for another round. 
+ +## Theming (when the project has a visual identity) + +If the user has an existing theme (colors, fonts, tokens), put shared tokens in `sketches/themes/tokens.css` and `@import` them in each variant. Keep tokens minimal: + +```css +/* sketches/themes/tokens.css */ +:root { + --color-bg: #fafafa; + --color-fg: #1a1a1a; + --color-accent: #0066ff; + --color-muted: #666; + --radius: 8px; + --font-display: "Inter", sans-serif; + --font-body: -apple-system, BlinkMacSystemFont, sans-serif; +} +``` + +Don't over-tokenize a throwaway sketch — three colors and one font is usually enough. + +## Interactivity bar + +A sketch is interactive enough when the user can: + +1. **Click a primary action** and something visible happens (state change, modal, toast, navigation feint) +2. **See one meaningful state transition** (filter a list, toggle a mode, open/close a panel) +3. **Hover recognizable affordances** (buttons, rows, tabs) + +More than that is over-engineering a throwaway. Less than that is a screenshot. + +## Frontier mode (picking what to sketch next) + +If sketches already exist and the user says "what should I sketch next?": + +- **Consistency gaps** — two winning variants from different sketches made independent choices that haven't been composed together yet +- **Unsketched screens** — referenced but never explored +- **State coverage** — happy path sketched, but not empty / loading / error / 1000-items +- **Responsive gaps** — validated at one viewport; does it hold at mobile / ultrawide? +- **Interaction patterns** — static layouts exist; transitions, drag, scroll behavior don't + +Propose 2-4 named candidates. Let the user pick. 
+ +## Output + +- Create `sketches/` (or `.planning/sketches/` if the user is using GSD conventions) in the repo root +- One subdir per variant: `NNN-stance-name/index.html` + `README.md` +- Tell the user how to open them: `open sketches/001-calm-editorial/index.html` on macOS, `xdg-open` on Linux, `start` on Windows +- Keep variants disposable — a sketch that you felt the need to preserve should be promoted into real project code, not curated as an asset + +**Typical tool sequence for one variant:** + +``` +terminal("mkdir -p sketches/001-calm-editorial") +write_file("sketches/001-calm-editorial/index.html", "...") +write_file("sketches/001-calm-editorial/README.md", "## Variant: Calm editorial\n...") +browser_navigate(url="file://$(pwd)/sketches/001-calm-editorial/index.html") +browser_vision(question="How does this look? Any obvious layout issues?") +``` + +Repeat for each variant, then present the comparison table. + +## Attribution + +Adapted from the GSD (Get Shit Done) project's `/gsd-sketch` workflow — MIT © 2025 Lex Christopherson ([gsd-build/get-shit-done](https://github.com/gsd-build/get-shit-done)). The full GSD system ships persistent sketch state, theme/variant pattern references, and consistency-audit workflows; install with `npx get-shit-done-cc --hermes --global`. diff --git a/skills/software-development/spike/SKILL.md b/skills/software-development/spike/SKILL.md new file mode 100644 index 0000000000..79d66bda14 --- /dev/null +++ b/skills/software-development/spike/SKILL.md @@ -0,0 +1,196 @@ +--- +name: spike +description: "Throwaway experiments to validate an idea before build." 
+version: 1.0.0 +author: Hermes Agent (adapted from gsd-build/get-shit-done) +license: MIT +metadata: + hermes: + tags: [spike, prototype, experiment, feasibility, throwaway, exploration, research, planning, mvp, proof-of-concept] + related_skills: [sketch, writing-plans, subagent-driven-development, plan] +--- + +# Spike + +Use this skill when the user wants to **feel out an idea** before committing to a real build — validating feasibility, comparing approaches, or surfacing unknowns that no amount of research will answer. Spikes are disposable by design. Throw them away once they've paid their debt. + +Load this when the user says things like "let me try this", "I want to see if X works", "spike this out", "before I commit to Y", "quick prototype of Z", "is this even possible?", or "compare A vs B". + +## When NOT to use this + +- The answer is knowable from docs or reading code — just do research, don't build +- The work is on the production path — use `writing-plans` / `plan` instead +- The idea is already validated — jump straight to implementation + +## If the user has the full GSD system installed + +If `gsd-spike` shows up as a sibling skill (installed via `npx get-shit-done-cc --hermes`), prefer **`gsd-spike`** when the user wants the full GSD workflow: persistent `.planning/spikes/` state, MANIFEST tracking across sessions, Given/When/Then verdict format, and commit patterns that integrate with the rest of GSD. This skill is the lightweight standalone version for users who don't have (or don't want) the full system. + +## Core method + +Regardless of scale, every spike follows this loop: + +``` +decompose → research → build → verdict + ↑__________________________________________↓ + iterate on findings +``` + +### 1. Decompose + +Break the user's idea into **2-5 independent feasibility questions**. Each question is one spike. 
Present them as a table with Given/When/Then framing: + +| # | Spike | Validates (Given/When/Then) | Risk | +|---|-------|----------------------------|------| +| 001 | websocket-streaming | Given a WS connection, when LLM streams tokens, then client receives chunks < 100ms | High | +| 002a | pdf-parse-pdfjs | Given a multi-page PDF, when parsed with pdfjs, then structured text is extractable | Medium | +| 002b | pdf-parse-camelot | Given a multi-page PDF, when parsed with camelot, then structured text is extractable | Medium | + +**Spike types:** +- **standard** — one approach answering one question +- **comparison** — same question, different approaches (shared number, letter suffix `a`/`b`/`c`) + +**Good spike questions:** specific feasibility with observable output. +**Bad spike questions:** too broad, no observable output, or just "read the docs about X". + +**Order by risk.** The spike most likely to kill the idea runs first. No point prototyping the easy parts if the hard part doesn't work. + +**Skip decomposition** only if the user already knows exactly what they want to spike and says so. Then take their idea as a single spike. + +### 2. Align (for multi-spike ideas) + +Present the spike table. Ask: "Build all in this order, or adjust?" Let the user drop, reorder, or re-frame before you write any code. + +### 3. Research (per spike, before building) + +Spikes are not research-free — you research enough to pick the right approach, then you build. Per spike: + +1. **Brief it.** 2-3 sentences: what this spike is, why it matters, key risk. +2. **Surface competing approaches** if there's real choice: + + | Approach | Tool/Library | Pros | Cons | Status | + |----------|-------------|------|------|--------| + | ... | ... | ... | ... | maintained / abandoned / beta | + +3. **Pick one.** State why. If 2+ are credible, build quick variants within the spike. +4. **Skip research** for pure logic with no external dependencies. 
+ +Use Hermes tools for the research step: + +- `web_search("python websocket streaming libraries 2025")` — find candidates +- `web_extract(urls=["https://websockets.readthedocs.io/..."])` — read the actual docs (returns markdown) +- `terminal("pip show websockets | grep Version")` — check what's installed in the project's venv + +For libraries without docs pages, clone and read their `README.md` / `examples/` via `read_file`. Context7 MCP (if the user has it configured) is also a good source — `mcp_*_resolve-library-id` then `mcp_*_query-docs`. + +### 4. Build + +One directory per spike. Keep it standalone. + +``` +spikes/ +├── 001-websocket-streaming/ +│ ├── README.md +│ └── main.py +├── 002a-pdf-parse-pdfjs/ +│ ├── README.md +│ └── parse.js +└── 002b-pdf-parse-camelot/ + ├── README.md + └── parse.py +``` + +**Bias toward something the user can interact with.** Spikes fail when the only output is a log line that says "it works." The user wants to *feel* the spike working. Default choices, in order of preference: + +1. A runnable CLI that takes input and prints observable output +2. A minimal HTML page that demonstrates the behavior +3. A small web server with one endpoint +4. A unit test that exercises the question with recognizable assertions + +**Depth over speed.** Never declare "it works" after one happy-path run. Test edge cases. Follow surprising findings. The verdict is only trustworthy when the investigation was honest. + +**Avoid** unless the spike specifically requires it: complex package management, build tools/bundlers, Docker, env files, config systems. Hardcode everything — it's a spike. + +**Building one spike** — a typical tool sequence: + +``` +terminal("mkdir -p spikes/001-websocket-streaming") +write_file("spikes/001-websocket-streaming/README.md", "# 001: websocket-streaming\n\n...") +write_file("spikes/001-websocket-streaming/main.py", "...") +terminal("cd spikes/001-websocket-streaming && python3 main.py") +# Observe output, iterate. 
+``` + +**Parallel comparison spikes (002a / 002b) — delegate.** When two approaches can run in parallel and both need real engineering (not 10-line prototypes), fan out with `delegate_task`: + +``` +delegate_task(tasks=[ + {"goal": "Build 002a-pdf-parse-pdfjs: ...", "toolsets": ["terminal", "file", "web"]}, + {"goal": "Build 002b-pdf-parse-camelot: ...", "toolsets": ["terminal", "file", "web"]}, +]) +``` + +Each subagent returns its own verdict; you write the head-to-head. + +### 5. Verdict + +Each spike's `README.md` closes with: + +```markdown +## Verdict: VALIDATED | PARTIAL | INVALIDATED + +### What worked +- ... + +### What didn't +- ... + +### Surprises +- ... + +### Recommendation for the real build +- ... +``` + +**VALIDATED** = the core question was answered yes, with evidence. +**PARTIAL** = it works under constraints X, Y, Z — document them. +**INVALIDATED** = doesn't work, for this reason. This is a successful spike. + +## Comparison spikes + +When two approaches answer the same question (002a / 002b), build them **back to back**, then do a head-to-head comparison at the end: + +```markdown +## Head-to-head: pdfjs vs camelot + +| Dimension | pdfjs (002a) | camelot (002b) | +|-----------|--------------|----------------| +| Extraction quality | 9/10 structured | 7/10 table-only | +| Setup complexity | npm install, 1 line | pip + ghostscript | +| Perf on 100-page PDF | 3s | 18s | +| Handles rotated text | no | yes | + +**Winner:** pdfjs for our use case. Camelot if we need table-first extraction later. 
+``` + +## Frontier mode (picking what to spike next) + +If spikes already exist and the user says "what should I spike next?", walk the existing directories and look for: + +- **Integration risks** — two validated spikes that touch the same resource but were tested independently +- **Data handoffs** — spike A's output was assumed compatible with spike B's input; never proven +- **Gaps in the vision** — capabilities assumed but unproven +- **Alternative approaches** — different angles for PARTIAL or INVALIDATED spikes + +Propose 2-4 candidates as Given/When/Then. Let the user pick. + +## Output + +- Create `spikes/` (or `.planning/spikes/` if the user is using GSD conventions) in the repo root +- One dir per spike: `NNN-descriptive-name/` +- `README.md` per spike captures question, approach, results, verdict +- Keep the code throwaway — a spike that takes 2 days to "clean up for production" was a bad spike + +## Attribution + +Adapted from the GSD (Get Shit Done) project's `/gsd-spike` workflow — MIT © 2025 Lex Christopherson ([gsd-build/get-shit-done](https://github.com/gsd-build/get-shit-done)). The full GSD system offers persistent spike state, MANIFEST tracking, and integration with a broader spec-driven development pipeline; install with `npx get-shit-done-cc --hermes --global`. diff --git a/skills/software-development/subagent-driven-development/SKILL.md b/skills/software-development/subagent-driven-development/SKILL.md index 5d349c9720..23c5bf47da 100644 --- a/skills/software-development/subagent-driven-development/SKILL.md +++ b/skills/software-development/subagent-driven-development/SKILL.md @@ -340,3 +340,12 @@ Catch issues early ``` **Quality is not an accident. 
It's the result of systematic process.** + +## Further reading (load when relevant) + +When the orchestration involves significant context usage, long review loops, or complex validation checkpoints, load these references for the specific discipline: + +- **`references/context-budget-discipline.md`** — Four-tier context degradation model (PEAK / GOOD / DEGRADING / POOR), read-depth rules that scale with context window size, and early warning signs of silent degradation. Load when a run will clearly consume significant context (multi-phase plans, many subagents, large artifacts). +- **`references/gates-taxonomy.md`** — The four canonical gate types (Pre-flight, Revision, Escalation, Abort) with behavior, recovery, and examples. Load when designing or reviewing any workflow that has validation checkpoints — use the vocabulary explicitly so each gate has defined entry, failure behavior, and resumption rules. + +Both references are adapted from gsd-build/get-shit-done (MIT © 2025 Lex Christopherson). diff --git a/skills/software-development/subagent-driven-development/references/context-budget-discipline.md b/skills/software-development/subagent-driven-development/references/context-budget-discipline.md new file mode 100644 index 0000000000..2728160c16 --- /dev/null +++ b/skills/software-development/subagent-driven-development/references/context-budget-discipline.md @@ -0,0 +1,53 @@ +# Context Budget Discipline + +Practical rules for keeping orchestrator context lean when spawning subagents or reading large artifacts. Use these whenever you're running a multi-step agent loop that will consume significant context — plan execution, subagent orchestration, review pipelines, multi-file refactors. + +Adapted from the GSD (Get Shit Done) project's context-budget reference — MIT © 2025 Lex Christopherson ([gsd-build/get-shit-done](https://github.com/gsd-build/get-shit-done)). 
+ +## Universal rules + +Every workflow that spawns agents or reads significant content must follow these: + +1. **Never read agent definition files.** `delegate_task` auto-loads them — you reading them too just doubles the cost. +2. **Never inline large files into subagent prompts.** Tell the agent to read the file from disk with `read_file` instead. The subagent gets full content; your context stays lean. +3. **Read depth scales with context window.** See the table below. +4. **Delegate heavy work to subagents.** The orchestrator routes; it doesn't execute. +5. **Proactively warn** the user when you've consumed significant context ("Context is getting heavy — consider checkpointing progress before we continue"). + +## Read depth by context window + +Check the model's actual context window (not "it's Claude so 200K"). Some Sonnet deployments are 1M, some are 200K. If you don't know, assume the smaller one — err toward leanness. + +| Context window | Subagent output reading | Summary files | Verification files | Plans for other phases | +|----------------|-------------------------|---------------|--------------------|-----------------------| +| < 500k (e.g. 200k) | Frontmatter only | Frontmatter only | Frontmatter only | Current phase only | +| >= 500k (1M models) | Full body permitted | Full body permitted | Full body permitted | Current phase only | + +"Frontmatter only" means: read enough to see the final status/verdict/conclusion. If the subagent wrote a 3000-line debug log, read the summary section it produced, not the log. + +## Four-tier degradation model + +Monitor your context usage and shift behavior as you climb the tiers. The point is to notice *before* you hit the wall, not when responses start truncating. + +| Tier | Usage | Behavior | +|------|-------|----------| +| **PEAK** | 0 – 30% | Full operations. Read bodies, spawn multiple agents in parallel, inline results freely. | +| **GOOD** | 30 – 50% | Normal operations. Prefer frontmatter reads. 
Delegate aggressively. | +| **DEGRADING** | 50 – 70% | Economize. Frontmatter-only reads, minimal inlining, **warn the user** about budget. | +| **POOR** | 70%+ | Emergency mode. **Checkpoint progress immediately.** No new reads unless critical. Finish the current task and stop cleanly. | + +## Early warning signs (before panic thresholds fire) + +Quality degrades *gradually* before hard limits hit. Watch for these: + +- **Silent partial completion.** Subagent claims it is done but the implementation is incomplete. Self-checks catch file existence, not semantic completeness. Always verify subagent output against the plan's must-haves, not just "did a file appear?" +- **Increasing vagueness.** Agent starts using phrases like "appropriate handling" or "standard patterns" instead of specific code. This is context pressure showing up before budget warnings fire. +- **Skipped protocol steps.** Agent omits steps it would normally follow. If the success criteria have 8 items and the report covers 5, suspect context pressure, not "the agent decided 5 was enough." + +When these signs appear, checkpoint the work and either reset context or hand off to a fresh subagent. + +## Fundamental limitation + +When you orchestrate, you cannot verify semantic correctness of subagent output — only structural completeness ("did the file appear?", "does the test pass?"). Semantic verification requires either running the code yourself or delegating a review pass to another fresh subagent. + +**Mitigation:** in every task you delegate, include explicit "must-have" truths the subagent must confirm in its response (e.g., "confirm your test actually tests X, not just that X was imported"). The subagent re-asserting concrete facts is evidence; vague summaries are not. 
diff --git a/skills/software-development/subagent-driven-development/references/gates-taxonomy.md b/skills/software-development/subagent-driven-development/references/gates-taxonomy.md new file mode 100644 index 0000000000..206f71efc9 --- /dev/null +++ b/skills/software-development/subagent-driven-development/references/gates-taxonomy.md @@ -0,0 +1,93 @@ +# Gates Taxonomy + +Canonical gate types for validation checkpoints across any workflow that spawns subagents, runs review loops, or has human-approval pauses. Every validation checkpoint maps to one of these four types — naming them explicitly makes the workflow legible and prevents "what happens when this check fails?" confusion. + +Adapted from the GSD (Get Shit Done) project's gates reference — MIT © 2025 Lex Christopherson ([gsd-build/get-shit-done](https://github.com/gsd-build/get-shit-done)). + +## The four gate types + +### 1. Pre-flight gate + +**Purpose:** Validates preconditions before starting an operation. + +**Behavior:** Blocks entry if conditions unmet. No partial work created — bail before anything changes. + +**Recovery:** Fix the missing precondition, then retry. + +**Examples:** +- Implementation phase checks that the plan file exists before it starts writing code. +- Delegated subagent checks that required env vars are set before making API calls. +- Commit checks that tests passed before pushing. + +### 2. Revision gate + +**Purpose:** Evaluates output quality and routes to revision if insufficient. + +**Behavior:** Loops back to the producer with specific feedback. Bounded by an iteration cap (typically 3). + +**Recovery:** Producer addresses feedback; checker re-evaluates. The loop escalates early if issue count does not decrease between consecutive iterations (stall detection). After max iterations, escalates to the user unconditionally — never loop forever. + +**Examples:** +- Plan reviewer reads a draft plan, returns specific issues, planner revises, reviewer re-reads (max 3 cycles). 
+- Code reviewer checks subagent-produced code against must-haves; dispatches fixes back to the implementer if any must-have failed. +- Test coverage checker validates new tests exercise the new paths; if not, sends back to author. + +### 3. Escalation gate + +**Purpose:** Surfaces unresolvable issues to the human for a decision. + +**Behavior:** Pauses workflow, presents options, waits for human input. Never guesses, never picks a default. + +**Recovery:** Human chooses action; workflow resumes on the selected path. + +**Examples:** +- Revision loop exhausted after 3 iterations. +- Merge conflict during automated worktree cleanup. +- Ambiguous requirement — two reasonable interpretations and the choice changes the approach. +- Subagent reports "the plan says X but the codebase actually does Y" — human decides which is right. + +### 4. Abort gate + +**Purpose:** Terminates the operation to prevent damage or waste. + +**Behavior:** Stops immediately, preserves state (checkpoint current progress), reports the specific reason. + +**Recovery:** Human investigates root cause, fixes, restarts from checkpoint. + +**Examples:** +- Context window critically low during execution (POOR tier, >70%) — abort cleanly rather than produce truncated output. +- Critical dependency unavailable mid-run (network down, API key revoked). +- Unrecoverable filesystem state (disk full, permissions lost). +- Safety invariant violated (agent attempted an irreversible destructive action outside approved scope). + +## How to use this in a skill + +When you write an orchestration skill that has validation checkpoints, **name each checkpoint by its gate type explicitly** and answer three questions: + +1. **What condition triggers this gate?** (e.g., "plan file missing", "issue count didn't decrease", "context >70%") +2. **What happens when it fails?** (block / loop back / ask human / abort) +3. 
**Who resumes, and from where?** (fix precondition + retry, revise + re-check, human decision, restart from checkpoint) + +Answering these three up front means your skill never hits "what do we do now?" at runtime. + +## Example — a review loop with all four gate types + +``` +[Pre-flight] plan.md exists and is non-empty? → no: bail, ask user to write a plan first + ↓ yes +[Execute] subagent implements task + ↓ +[Revision] reviewer checks against must-haves → fail: loop back to subagent (max 3) + ↓ pass +[Pre-flight] tests pass? → no: bail, report failing tests + ↓ yes +[Commit] + ↓ +(on revision loop exhaustion) +[Escalation] "3 review cycles failed to converge on issue X — pick: force-merge, rewrite task, abandon" + ↓ user picks +(on any tier-POOR context pressure during loop) +[Abort] "context at 73%, checkpointing and stopping" +``` + +The vocabulary is small on purpose. Every gate in every workflow should fit one of these four. If you find yourself inventing a fifth, it's probably a revision gate with extra branching, or an escalation gate in disguise. 
diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index d4afed101f..6628f0da26 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -11,6 +11,7 @@ import acp from acp.agent.router import build_agent_router from acp.schema import ( AgentCapabilities, + AgentMessageChunk, AuthenticateResponse, AvailableCommandsUpdate, Implementation, @@ -27,6 +28,7 @@ from acp.schema import ( SessionInfo, TextContentBlock, Usage, + UserMessageChunk, ) from acp_adapter.server import HermesACPAgent, HERMES_VERSION from acp_adapter.session import SessionManager @@ -224,6 +226,58 @@ class TestSessionOps: resp = await agent.load_session(cwd="/tmp", session_id="bogus") assert resp is None + @pytest.mark.asyncio + async def test_load_session_replays_persisted_history_to_client(self, agent): + mock_conn = MagicMock(spec=acp.Client) + mock_conn.session_update = AsyncMock() + agent._conn = mock_conn + + new_resp = await agent.new_session(cwd="/tmp") + state = agent.session_manager.get_session(new_resp.session_id) + state.history = [ + {"role": "system", "content": "hidden system"}, + {"role": "user", "content": "what controls the / slash commands?"}, + {"role": "assistant", "content": "HermesACPAgent._ADVERTISED_COMMANDS controls them."}, + {"role": "tool", "content": "tool output should not replay"}, + ] + + mock_conn.session_update.reset_mock() + resp = await agent.load_session(cwd="/tmp", session_id=new_resp.session_id) + + assert isinstance(resp, LoadSessionResponse) + calls = mock_conn.session_update.await_args_list + replay_calls = [ + call for call in calls + if getattr(call.kwargs.get("update"), "session_update", None) + in {"user_message_chunk", "agent_message_chunk"} + ] + assert len(replay_calls) == 2 + assert isinstance(replay_calls[0].kwargs["update"], UserMessageChunk) + assert replay_calls[0].kwargs["update"].content.text == "what controls the / slash commands?" 
+ assert isinstance(replay_calls[1].kwargs["update"], AgentMessageChunk) + assert replay_calls[1].kwargs["update"].content.text.startswith("HermesACPAgent") + + @pytest.mark.asyncio + async def test_resume_session_replays_persisted_history_to_client(self, agent): + mock_conn = MagicMock(spec=acp.Client) + mock_conn.session_update = AsyncMock() + agent._conn = mock_conn + + new_resp = await agent.new_session(cwd="/tmp") + state = agent.session_manager.get_session(new_resp.session_id) + state.history = [{"role": "user", "content": "So tell me the current state"}] + + mock_conn.session_update.reset_mock() + resp = await agent.resume_session(cwd="/tmp", session_id=new_resp.session_id) + + assert isinstance(resp, ResumeSessionResponse) + updates = [call.kwargs["update"] for call in mock_conn.session_update.await_args_list] + assert any( + isinstance(update, UserMessageChunk) + and update.content.text == "So tell me the current state" + for update in updates + ) + @pytest.mark.asyncio async def test_resume_session_creates_new_if_missing(self, agent): resume_resp = await agent.resume_session(cwd="/tmp", session_id="nonexistent") diff --git a/tests/agent/test_anthropic_adapter.py b/tests/agent/test_anthropic_adapter.py index b78ae48590..8105363b2e 100644 --- a/tests/agent/test_anthropic_adapter.py +++ b/tests/agent/test_anthropic_adapter.py @@ -66,34 +66,29 @@ class TestBuildAnthropicClient: assert "claude-code-20250219" in betas assert "interleaved-thinking-2025-05-14" in betas assert "fine-grained-tool-streaming-2025-05-14" in betas + # Default: 1M-context beta stays IN for OAuth so 1M-capable + # subscriptions keep full context. The reactive recovery path + # in run_agent.py flips it off only after a subscription + # actually rejects the beta. + assert "context-1m-2025-08-07" in betas assert "api_key" not in kwargs - def test_oauth_does_not_send_claude_code_spoof_headers(self): - """OAuth requests identify as Hermes — no claude-cli UA, no x-app: cli. 
- - Anthropic's OAuth-gated Messages API accepts requests from non-Claude-Code - clients as long as auth is correct and the OAuth beta headers are present. - See commit that removed fingerprinting for the live-test write-up. - """ + def test_oauth_drop_context_1m_beta_strips_only_1m(self): + """drop_context_1m_beta=True strips context-1m-2025-08-07 while + preserving every other OAuth-relevant beta.""" with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk: - build_anthropic_client("sk-ant-oat01-" + "x" * 60) - headers = mock_sdk.Anthropic.call_args[1]["default_headers"] - assert "user-agent" not in {k.lower() for k in headers} - assert "x-app" not in {k.lower() for k in headers} - - def test_oauth_strips_context_1m_beta(self): - """context-1m-2025-08-07 is incompatible with OAuth auth — must be stripped. - - Anthropic returns HTTP 400 "This authentication style is incompatible - with the long context beta header." when OAuth traffic carries it. - """ - with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk: - build_anthropic_client("sk-ant-oat01-" + "x" * 60) - betas = mock_sdk.Anthropic.call_args[1]["default_headers"]["anthropic-beta"] + build_anthropic_client( + "sk-ant-oat01-" + "x" * 60, + drop_context_1m_beta=True, + ) + kwargs = mock_sdk.Anthropic.call_args[1] + betas = kwargs["default_headers"]["anthropic-beta"] assert "context-1m-2025-08-07" not in betas - # But other common betas still flow through - assert "interleaved-thinking-2025-05-14" in betas + # Everything else must still be there. 
assert "oauth-2025-04-20" in betas + assert "claude-code-20250219" in betas + assert "interleaved-thinking-2025-05-14" in betas + assert "fine-grained-tool-streaming-2025-05-14" in betas def test_api_key_uses_api_key(self): with patch("agent.anthropic_adapter._anthropic_sdk") as mock_sdk: @@ -104,6 +99,7 @@ class TestBuildAnthropicClient: # API key auth should still get common betas betas = kwargs["default_headers"]["anthropic-beta"] assert "interleaved-thinking-2025-05-14" in betas + assert "context-1m-2025-08-07" in betas assert "oauth-2025-04-20" not in betas # OAuth-only beta NOT present assert "claude-code-20250219" not in betas # OAuth-only beta NOT present @@ -113,7 +109,7 @@ class TestBuildAnthropicClient: kwargs = mock_sdk.Anthropic.call_args[1] assert kwargs["base_url"] == "https://custom.api.com" assert kwargs["default_headers"] == { - "anthropic-beta": "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14" + "anthropic-beta": "interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07" } def test_minimax_anthropic_endpoint_uses_bearer_auth_for_regular_api_keys(self): @@ -990,6 +986,42 @@ class TestBuildAnthropicKwargs: ) assert kwargs["model"] == "claude-sonnet-4-20250514" + def test_fast_mode_oauth_default_keeps_context_1m_beta(self): + """Default OAuth fast-mode requests still carry context-1m-2025-08-07.""" + kwargs = build_anthropic_kwargs( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=4096, + reasoning_config=None, + is_oauth=True, + fast_mode=True, + ) + betas = kwargs["extra_headers"]["anthropic-beta"] + assert "fast-mode-2026-02-01" in betas + assert "oauth-2025-04-20" in betas + assert "context-1m-2025-08-07" in betas + + def test_fast_mode_oauth_drop_context_1m_beta_strips_only_1m(self): + """drop_context_1m_beta=True strips context-1m from fast-mode + extra_headers while preserving every other OAuth + fast-mode beta.""" + kwargs = 
build_anthropic_kwargs( + model="claude-opus-4-6", + messages=[{"role": "user", "content": "Hi"}], + tools=None, + max_tokens=4096, + reasoning_config=None, + is_oauth=True, + fast_mode=True, + drop_context_1m_beta=True, + ) + betas = kwargs["extra_headers"]["anthropic-beta"] + assert "context-1m-2025-08-07" not in betas + assert "fast-mode-2026-02-01" in betas + assert "oauth-2025-04-20" in betas + assert "claude-code-20250219" in betas + assert "interleaved-thinking-2025-05-14" in betas + def test_reasoning_config_maps_to_manual_thinking_for_pre_4_6_models(self): kwargs = build_anthropic_kwargs( model="claude-sonnet-4-20250514", diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index fb23a59bc4..32290b0612 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -259,7 +259,7 @@ class TestAnthropicOAuthFlag: assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled" -class TestTryCodex: +class TestBuildCodexClient: def test_pool_without_selected_entry_falls_back_to_auth_store(self): with ( patch("agent.auxiliary_client._select_pool_entry", return_value=(True, None)), @@ -267,15 +267,23 @@ class TestTryCodex: patch("agent.auxiliary_client.OpenAI") as mock_openai, ): mock_openai.return_value = MagicMock() - from agent.auxiliary_client import _try_codex + from agent.auxiliary_client import _build_codex_client - client, model = _try_codex() + client, model = _build_codex_client("gpt-5.4") assert client is not None - assert model == "gpt-5.2-codex" + assert model == "gpt-5.4" assert mock_openai.call_args.kwargs["api_key"] == "codex-auth-token" assert mock_openai.call_args.kwargs["base_url"] == "https://chatgpt.com/backend-api/codex" + def test_rejects_missing_model(self): + """Callers must pass an explicit model; no hardcoded default.""" + from agent.auxiliary_client import _build_codex_client + + client, model = _build_codex_client("") + assert client is None + assert model is None + 
class TestExpiredCodexFallback: """Test that expired Codex tokens don't block the auto chain.""" @@ -507,14 +515,14 @@ class TestGetTextAuxiliaryClient: patch("agent.auxiliary_client.OpenAI"), patch("hermes_cli.auth._read_codex_tokens", side_effect=AssertionError("legacy codex store should not run")), ): - from agent.auxiliary_client import _try_codex + from agent.auxiliary_client import _build_codex_client - client, model = _try_codex() + client, model = _build_codex_client("gpt-5.4") from agent.auxiliary_client import CodexAuxiliaryClient assert isinstance(client, CodexAuxiliaryClient) - assert model == "gpt-5.2-codex" + assert model == "gpt-5.4" def test_returns_none_when_nothing_available(self, monkeypatch): monkeypatch.delenv("OPENAI_BASE_URL", raising=False) @@ -783,11 +791,15 @@ class TestIsPaymentError: class TestGetProviderChain: """_get_provider_chain() resolves functions at call time (testable).""" - def test_returns_five_entries(self): + def test_returns_four_entries(self): chain = _get_provider_chain() - assert len(chain) == 5 + assert len(chain) == 4 labels = [label for label, _ in chain] - assert labels == ["openrouter", "nous", "local/custom", "openai-codex", "api-key"] + assert labels == ["openrouter", "nous", "local/custom", "api-key"] + # Codex is deliberately NOT in this chain — see _get_provider_chain + # docstring. ChatGPT-account Codex has a shifting model allow-list; + # guessing a model to fall back on breaks more often than it helps. 
+ assert "openai-codex" not in labels def test_picks_up_patched_functions(self): """Patches on _try_* functions must be visible in the chain.""" @@ -814,7 +826,6 @@ class TestTryPaymentFallback: with patch("agent.auxiliary_client._try_openrouter", return_value=(None, None)), \ patch("agent.auxiliary_client._try_nous", return_value=(None, None)), \ patch("agent.auxiliary_client._try_custom_endpoint", return_value=(None, None)), \ - patch("agent.auxiliary_client._try_codex", return_value=(None, None)), \ patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \ patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"): client, model, label = _try_payment_fallback("openrouter") @@ -825,23 +836,26 @@ class TestTryPaymentFallback: """'codex' should map to 'openai-codex' in the skip set.""" mock_client = MagicMock() with patch("agent.auxiliary_client._try_openrouter", return_value=(mock_client, "or-model")), \ - patch("agent.auxiliary_client._try_codex", return_value=(None, None)), \ patch("agent.auxiliary_client._read_main_provider", return_value="openai-codex"): client, model, label = _try_payment_fallback("openai-codex", task="vision") assert client is mock_client assert label == "openrouter" - def test_skips_to_codex_when_or_and_nous_fail(self): - mock_codex = MagicMock() + def test_codex_not_in_fallback_chain(self): + """Codex is deliberately NOT a fallback rung (shifting model allow-list). + + When OR/Nous/custom/api-key all fail, payment-fallback returns None — + Codex is never tried with a guessed model. 
+ """ with patch("agent.auxiliary_client._try_openrouter", return_value=(None, None)), \ patch("agent.auxiliary_client._try_nous", return_value=(None, None)), \ patch("agent.auxiliary_client._try_custom_endpoint", return_value=(None, None)), \ - patch("agent.auxiliary_client._try_codex", return_value=(mock_codex, "gpt-5.2-codex")), \ + patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)), \ patch("agent.auxiliary_client._read_main_provider", return_value="openrouter"): client, model, label = _try_payment_fallback("openrouter") - assert client is mock_codex - assert model == "gpt-5.2-codex" - assert label == "openai-codex" + assert client is None + assert model is None + assert label == "" class TestCallLlmPaymentFallback: @@ -1360,14 +1374,14 @@ class TestAuxiliaryAuthRefreshRetry: with ( patch( "agent.auxiliary_client.resolve_vision_provider_client", - side_effect=[("openai-codex", failing_client, "gpt-5.2-codex"), ("openai-codex", fresh_client, "gpt-5.2-codex")], + side_effect=[("openai-codex", failing_client, "gpt-5.4"), ("openai-codex", fresh_client, "gpt-5.4")], ), patch("agent.auxiliary_client._refresh_provider_credentials", return_value=True) as mock_refresh, ): resp = call_llm( task="vision", provider="openai-codex", - model="gpt-5.2-codex", + model="gpt-5.4", messages=[{"role": "user", "content": "hi"}], ) @@ -1384,14 +1398,14 @@ class TestAuxiliaryAuthRefreshRetry: fresh_client.chat.completions.create.return_value = _DummyResponse("fresh-non-vision") with ( - patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.2-codex", None, None, None)), - patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.2-codex"), (fresh_client, "gpt-5.2-codex")]), + patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.4", None, None, None)), + patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, 
"gpt-5.4"), (fresh_client, "gpt-5.4")]), patch("agent.auxiliary_client._refresh_provider_credentials", return_value=True) as mock_refresh, ): resp = call_llm( task="compression", provider="openai-codex", - model="gpt-5.2-codex", + model="gpt-5.4", messages=[{"role": "user", "content": "hi"}], ) @@ -1439,14 +1453,14 @@ class TestAuxiliaryAuthRefreshRetry: with ( patch( "agent.auxiliary_client.resolve_vision_provider_client", - side_effect=[("openai-codex", failing_client, "gpt-5.2-codex"), ("openai-codex", fresh_client, "gpt-5.2-codex")], + side_effect=[("openai-codex", failing_client, "gpt-5.4"), ("openai-codex", fresh_client, "gpt-5.4")], ), patch("agent.auxiliary_client._refresh_provider_credentials", return_value=True) as mock_refresh, ): resp = await async_call_llm( task="vision", provider="openai-codex", - model="gpt-5.2-codex", + model="gpt-5.4", messages=[{"role": "user", "content": "hi"}], ) @@ -1635,3 +1649,106 @@ class TestCodexAdapterReasoningTranslation: ) assert "reasoning" not in captured + + +class TestVisionAutoSkipsKimiCoding: + """_resolve_auto vision branch skips providers that have no vision on + their main endpoint (e.g. Kimi Coding Plan /coding) and falls through + to the aggregator chain instead of handing back a client that will 404 + on every request (#17076). + """ + + def test_kimi_coding_skipped_falls_through_to_openrouter(self, monkeypatch): + """kimi-coding as main + vision auto → OpenRouter (not kimi).""" + fake_or_client = MagicMock(name="openrouter_client") + + monkeypatch.setattr( + "agent.auxiliary_client._read_main_provider", lambda: "kimi-coding", + ) + monkeypatch.setattr( + "agent.auxiliary_client._read_main_model", lambda: "kimi-code", + ) + # Guard: if the skip doesn't fire, _resolve_strict_vision_backend + # and resolve_provider_client both would try kimi-coding — detect + # either via the main-provider call and fail loud. 
+ rpc_mock = MagicMock(side_effect=AssertionError( + "resolve_provider_client should NOT be called for kimi-coding " + "on the vision auto path")) + monkeypatch.setattr( + "agent.auxiliary_client.resolve_provider_client", rpc_mock, + ) + + def fake_strict(provider, model=None): + if provider == "openrouter": + return fake_or_client, "google/gemini-3-flash-preview" + if provider == "nous": + return None, None + raise AssertionError( + f"strict vision backend should not be called for {provider!r} " + "when main provider is kimi-coding" + ) + monkeypatch.setattr( + "agent.auxiliary_client._resolve_strict_vision_backend", + fake_strict, + ) + + provider, client, model = resolve_vision_provider_client() + assert provider == "openrouter" + assert client is fake_or_client + assert model == "google/gemini-3-flash-preview" + + def test_kimi_coding_cn_skipped_too(self, monkeypatch): + """Same skip applies to the CN variant.""" + fake_or_client = MagicMock(name="openrouter_client") + + monkeypatch.setattr( + "agent.auxiliary_client._read_main_provider", lambda: "kimi-coding-cn", + ) + monkeypatch.setattr( + "agent.auxiliary_client._read_main_model", lambda: "kimi-code", + ) + rpc_mock = MagicMock(side_effect=AssertionError( + "resolve_provider_client should NOT be called for kimi-coding-cn")) + monkeypatch.setattr( + "agent.auxiliary_client.resolve_provider_client", rpc_mock, + ) + monkeypatch.setattr( + "agent.auxiliary_client._resolve_strict_vision_backend", + lambda p, m=None: (fake_or_client, "gemini") + if p == "openrouter" + else (None, None), + ) + + provider, client, _ = resolve_vision_provider_client() + assert provider == "openrouter" + assert client is fake_or_client + + def test_explicit_override_to_kimi_coding_still_honored(self, monkeypatch): + """When a user *explicitly* requests kimi-coding for vision (e.g. 
+ they know what they're doing, or are running a future build that + adds image_in capability to Kimi Code), the explicit path still + routes to kimi-coding — only the auto branch applies the skip. + """ + monkeypatch.setattr( + "agent.auxiliary_client._read_main_provider", lambda: "openrouter", + ) + fake_kimi_client = MagicMock(name="kimi_client") + gcc_mock = MagicMock(return_value=(fake_kimi_client, "kimi-code")) + monkeypatch.setattr( + "agent.auxiliary_client._get_cached_client", gcc_mock, + ) + + provider, client, model = resolve_vision_provider_client( + provider="kimi-coding", + ) + assert provider == "kimi-coding" + assert client is fake_kimi_client + gcc_mock.assert_called_once() + + def test_skip_set_covers_exactly_known_entries(self): + """Guard against accidental widening of the skip list.""" + from agent.auxiliary_client import _PROVIDERS_WITHOUT_VISION + assert _PROVIDERS_WITHOUT_VISION == frozenset({ + "kimi-coding", + "kimi-coding-cn", + }) diff --git a/tests/agent/test_codex_cloudflare_headers.py b/tests/agent/test_codex_cloudflare_headers.py index 6a343c8f84..2d9633a803 100644 --- a/tests/agent/test_codex_cloudflare_headers.py +++ b/tests/agent/test_codex_cloudflare_headers.py @@ -10,7 +10,7 @@ of auth correctness. ``_codex_cloudflare_headers`` in ``agent.auxiliary_client`` centralizes the header set so the primary chat client (``run_agent.AIAgent.__init__`` + ``_apply_client_headers_for_base_url``) and the auxiliary client paths -(``_try_codex`` and the ``raw_codex`` branch of ``resolve_provider_client``) +(``_build_codex_client`` and the ``raw_codex`` branch of ``resolve_provider_client``) all emit the same headers. 
These tests pin: @@ -207,9 +207,10 @@ class TestPrimaryClientWiring: # --------------------------------------------------------------------------- class TestAuxiliaryClientWiring: - def test_try_codex_passes_codex_headers(self, monkeypatch): - """_try_codex builds the OpenAI client used for compression / vision / - title generation when routed through Codex. Must emit codex headers.""" + def test_build_codex_client_passes_codex_headers(self, monkeypatch): + """_build_codex_client builds the OpenAI client used for compression / + vision / title generation when routed through Codex. Must emit codex + headers.""" from agent import auxiliary_client token = _make_codex_jwt("acct-aux-try-codex") @@ -225,7 +226,7 @@ class TestAuxiliaryClientWiring: ) with patch("agent.auxiliary_client.OpenAI") as mock_openai: mock_openai.return_value = MagicMock() - client, model = auxiliary_client._try_codex() + client, model = auxiliary_client._build_codex_client("gpt-5.4") assert client is not None headers = mock_openai.call_args.kwargs.get("default_headers") or {} assert headers.get("originator") == "codex_cli_rs" @@ -244,7 +245,7 @@ class TestAuxiliaryClientWiring: with patch("agent.auxiliary_client.OpenAI") as mock_openai: mock_openai.return_value = MagicMock() client, model = auxiliary_client.resolve_provider_client( - "openai-codex", raw_codex=True, + "openai-codex", model="gpt-5.4", raw_codex=True, ) assert client is not None headers = mock_openai.call_args.kwargs.get("default_headers") or {} diff --git a/tests/agent/test_copilot_acp_client.py b/tests/agent/test_copilot_acp_client.py index 63c87fdabd..dfc336b41c 100644 --- a/tests/agent/test_copilot_acp_client.py +++ b/tests/agent/test_copilot_acp_client.py @@ -80,15 +80,19 @@ class CopilotACPClientSafetyTests(unittest.TestCase): secret_file = root / "config.env" secret_file.write_text("OPENAI_API_KEY=sk-proj-abc123def456ghi789jkl012") - response = self._dispatch( - { - "jsonrpc": "2.0", - "id": 3, - "method": 
"fs/read_text_file", - "params": {"path": str(secret_file)}, - }, - cwd=str(root), - ) + # agent.redact snapshots HERMES_REDACT_SECRETS at import time into + # _REDACT_ENABLED, so patching os.environ is a no-op. Flip the + # module-level constant directly for the duration of the call. + with patch("agent.redact._REDACT_ENABLED", True): + response = self._dispatch( + { + "jsonrpc": "2.0", + "id": 3, + "method": "fs/read_text_file", + "params": {"path": str(secret_file)}, + }, + cwd=str(root), + ) content = ((response.get("result") or {}).get("content") or "") self.assertNotIn("abc123def456", content) diff --git a/tests/agent/test_curator.py b/tests/agent/test_curator.py index a8a4b5ada3..70040ec01d 100644 --- a/tests/agent/test_curator.py +++ b/tests/agent/test_curator.py @@ -271,10 +271,17 @@ def test_run_review_synchronous_invokes_llm_stub(curator_env, monkeypatch): _write_skill(skills_dir, "a") calls = [] - monkeypatch.setattr( - c, "_run_llm_review", - lambda prompt: (calls.append(prompt), "stubbed-summary")[1], - ) + def _stub(prompt): + calls.append(prompt) + return { + "final": "stubbed-summary", + "summary": "stubbed-summary", + "model": "stub-model", + "provider": "stub-provider", + "tool_calls": [], + "error": None, + } + monkeypatch.setattr(c, "_run_llm_review", _stub) captured = [] c.run_curator_review(on_summary=lambda s: captured.append(s), synchronous=True) @@ -478,3 +485,153 @@ def test_cli_pin_refuses_bundled_skill(curator_env, capsys): captured = capsys.readouterr() assert rc == 1 assert "bundled" in captured.out.lower() or "hub" in captured.out.lower() + + +# --------------------------------------------------------------------------- +# curator review-model resolution (canonical auxiliary.curator slot) +# +# Curator was unified with the rest of the aux task system in Apr 2026 so +# `hermes model` → auxiliary picker, the dashboard Models tab, and the full +# per-task config (timeout, base_url, api_key, extra_body) all work for it. 
+# Voscko report: curator.auxiliary.{provider,model} was advertised but never +# read. Fix wires curator through auxiliary.curator with a legacy fallback. +# --------------------------------------------------------------------------- + + +def test_review_model_defaults_to_main_when_slot_is_auto(curator_env): + """auxiliary.curator absent (or auto/empty) → use main model.provider/model.""" + curator = curator_env["curator"] + cfg = { + "model": {"provider": "openrouter", "default": "openai/gpt-5.5"}, + } + assert curator._resolve_review_model(cfg) == ("openrouter", "openai/gpt-5.5") + + # Explicit auto/empty slot — still main model. + cfg["auxiliary"] = {"curator": {"provider": "auto", "model": ""}} + assert curator._resolve_review_model(cfg) == ("openrouter", "openai/gpt-5.5") + + +def test_review_model_honors_auxiliary_curator_slot(curator_env): + """auxiliary.curator.{provider,model} fully set → that pair wins.""" + curator = curator_env["curator"] + cfg = { + "model": {"provider": "openrouter", "default": "openai/gpt-5.5"}, + "auxiliary": { + "curator": { + "provider": "openrouter", + "model": "openai/gpt-5.4-mini", + }, + }, + } + assert curator._resolve_review_model(cfg) == ( + "openrouter", "openai/gpt-5.4-mini", + ) + + +def test_review_model_auxiliary_curator_partial_override_falls_back(curator_env): + """Only one of slot provider/model set → fall back to the main pair. + + Prevents half-configured overrides from sending an empty side to + resolve_runtime_provider. 
+ """ + curator = curator_env["curator"] + base_main = {"provider": "openrouter", "default": "openai/gpt-5.5"} + + cfg_provider_only = { + "model": dict(base_main), + "auxiliary": {"curator": {"provider": "openrouter", "model": ""}}, + } + assert curator._resolve_review_model(cfg_provider_only) == ( + "openrouter", "openai/gpt-5.5", + ) + + cfg_model_only = { + "model": dict(base_main), + "auxiliary": {"curator": {"provider": "auto", "model": "gpt-5.4-mini"}}, + } + assert curator._resolve_review_model(cfg_model_only) == ( + "openrouter", "openai/gpt-5.5", + ) + + +def test_review_model_legacy_curator_auxiliary_still_works(curator_env, caplog): + """Pre-unification users set curator.auxiliary.{provider,model} — honor it. + + Emits a deprecation log line but keeps their config working. + """ + curator = curator_env["curator"] + cfg = { + "model": {"provider": "openrouter", "default": "openai/gpt-5.5"}, + "curator": { + "auxiliary": { + "provider": "openrouter", + "model": "openai/gpt-5.4-mini", + }, + }, + } + import logging + with caplog.at_level(logging.INFO, logger="agent.curator"): + result = curator._resolve_review_model(cfg) + assert result == ("openrouter", "openai/gpt-5.4-mini") + assert any( + "deprecated curator.auxiliary" in rec.message for rec in caplog.records + ), "expected deprecation warning when legacy curator.auxiliary is used" + + +def test_review_model_new_slot_wins_over_legacy(curator_env): + """When BOTH new and legacy are set, the canonical slot wins.""" + curator = curator_env["curator"] + cfg = { + "model": {"provider": "openrouter", "default": "openai/gpt-5.5"}, + "auxiliary": { + "curator": {"provider": "nous", "model": "new-winner"}, + }, + "curator": { + "auxiliary": {"provider": "openrouter", "model": "legacy-loser"}, + }, + } + assert curator._resolve_review_model(cfg) == ("nous", "new-winner") + + +def test_review_model_handles_missing_sections(curator_env): + """Missing auxiliary/curator sections never raise — fall back cleanly.""" + 
curator = curator_env["curator"] + cfg = {"model": {"provider": "anthropic", "model": "claude-sonnet-4-6"}} + assert curator._resolve_review_model(cfg) == ( + "anthropic", "claude-sonnet-4-6", + ) + + # Completely empty config → ("auto", "") — resolve_runtime_provider + # handles the auto-detection chain from there. + assert curator._resolve_review_model({}) == ("auto", "") + + +def test_curator_slot_is_canonical_aux_task(): + """Curator must be a first-class slot in every aux-task registry. + + Four sources of truth, all checked by the shared registry test + (test_aux_config.py) for the main tasks — this test pins `curator` + specifically so the unification doesn't silently regress. + """ + from hermes_cli.config import DEFAULT_CONFIG + from hermes_cli.main import _AUX_TASKS + from hermes_cli.web_server import _AUX_TASK_SLOTS + + # 1. DEFAULT_CONFIG.auxiliary — schema source + assert "curator" in DEFAULT_CONFIG["auxiliary"], \ + "curator missing from DEFAULT_CONFIG['auxiliary']" + slot = DEFAULT_CONFIG["auxiliary"]["curator"] + assert slot["provider"] == "auto" + assert slot["model"] == "" + assert slot["timeout"] > 0, "curator timeout should be set (reviews run long)" + + # 2. hermes_cli/main.py _AUX_TASKS — CLI picker + aux_keys = {k for k, _name, _desc in _AUX_TASKS} + assert "curator" in aux_keys, "curator missing from _AUX_TASKS (CLI picker)" + + # 3. hermes_cli/web_server.py _AUX_TASK_SLOTS — REST API allowlist + assert "curator" in _AUX_TASK_SLOTS, \ + "curator missing from _AUX_TASK_SLOTS (dashboard REST API)" + + # 4. web/src/pages/ModelsPage.tsx is checked at build time; the tsx + # array and this tuple share a ``Must match _AUX_TASK_SLOTS`` comment. diff --git a/tests/agent/test_curator_reports.py b/tests/agent/test_curator_reports.py new file mode 100644 index 0000000000..3c94c231c1 --- /dev/null +++ b/tests/agent/test_curator_reports.py @@ -0,0 +1,258 @@ +"""Tests for the curator per-run report writer (run.json + REPORT.md). 
+ +Reports live under ``~/.hermes/logs/curator/{YYYYMMDD-HHMMSS}/`` alongside +the standard log dir, not inside the user's ``skills/`` data directory. +""" + +from __future__ import annotations + +import json +import os +from datetime import datetime, timezone, timedelta +from pathlib import Path + +import pytest + + +@pytest.fixture +def curator_env(tmp_path, monkeypatch): + """Isolated HERMES_HOME with a skills/ dir + reset curator module state.""" + home = tmp_path / ".hermes" + home.mkdir() + (home / "skills").mkdir() + (home / "logs").mkdir() + monkeypatch.setenv("HERMES_HOME", str(home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + import importlib + import hermes_constants + importlib.reload(hermes_constants) + from agent import curator + importlib.reload(curator) + from tools import skill_usage + importlib.reload(skill_usage) + yield {"home": home, "curator": curator, "skill_usage": skill_usage} + + +def _make_llm_meta(**overrides): + base = { + "final": "short summary of the pass", + "summary": "short summary", + "model": "test-model", + "provider": "test-provider", + "tool_calls": [], + "error": None, + } + base.update(overrides) + return base + + +def test_reports_root_is_under_logs_not_skills(curator_env): + """Reports live in logs/curator/, not skills/ — operational telemetry + belongs with the logs, not with user-authored skill data.""" + curator = curator_env["curator"] + root = curator._reports_root() + home = curator_env["home"] + # Must be under logs/ + assert root == home / "logs" / "curator" + # Must NOT be under skills/ + assert "skills" not in root.parts + + +def test_write_run_report_creates_both_files(curator_env): + """Each run writes both a run.json (machine) and a REPORT.md (human).""" + curator = curator_env["curator"] + start = datetime.now(timezone.utc) + + run_dir = curator._write_run_report( + started_at=start, + elapsed_seconds=12.345, + auto_counts={"checked": 5, "marked_stale": 1, "archived": 0, "reactivated": 0}, + 
auto_summary="1 marked stale", + before_report=[], + before_names=set(), + after_report=[], + llm_meta=_make_llm_meta(), + ) + assert run_dir is not None + assert run_dir.is_dir() + assert (run_dir / "run.json").exists() + assert (run_dir / "REPORT.md").exists() + + # The directory name is a timestamp under logs/curator/ + assert run_dir.parent == curator._reports_root() + + +def test_run_json_has_expected_shape(curator_env): + """run.json must carry the machine-readable fields downstream tooling needs.""" + curator = curator_env["curator"] + start = datetime.now(timezone.utc) + + before_report = [ + {"name": "old-thing", "state": "active", "pinned": False}, + {"name": "keeper", "state": "active", "pinned": True}, + ] + after_report = [ + {"name": "keeper", "state": "active", "pinned": True}, + {"name": "new-umbrella", "state": "active", "pinned": False}, + ] + + run_dir = curator._write_run_report( + started_at=start, + elapsed_seconds=42.0, + auto_counts={"checked": 2, "marked_stale": 0, "archived": 0, "reactivated": 0}, + auto_summary="no changes", + before_report=before_report, + before_names={r["name"] for r in before_report}, + after_report=after_report, + llm_meta=_make_llm_meta( + final="I consolidated the whole universe.", + tool_calls=[ + {"name": "skills_list", "arguments": "{}"}, + {"name": "skill_manage", "arguments": '{"action":"create"}'}, + {"name": "terminal", "arguments": "mv ..."}, + ], + ), + ) + payload = json.loads((run_dir / "run.json").read_text()) + + # top-level shape + for k in ( + "started_at", "duration_seconds", "model", "provider", + "auto_transitions", "counts", "tool_call_counts", + "archived", "added", "state_transitions", + "llm_final", "llm_summary", "llm_error", "tool_calls", + ): + assert k in payload, f"missing key: {k}" + + # Diff logic + assert payload["archived"] == ["old-thing"] + assert payload["added"] == ["new-umbrella"] + # Counts reflect the diff + assert payload["counts"]["before"] == 2 + assert 
payload["counts"]["after"] == 2 + assert payload["counts"]["archived_this_run"] == 1 + assert payload["counts"]["added_this_run"] == 1 + # Tool call counts are aggregated + assert payload["tool_call_counts"]["skills_list"] == 1 + assert payload["tool_call_counts"]["skill_manage"] == 1 + assert payload["tool_call_counts"]["terminal"] == 1 + assert payload["counts"]["tool_calls_total"] == 3 + + +def test_report_md_is_human_readable(curator_env): + """REPORT.md should be a valid markdown doc with the key sections visible.""" + curator = curator_env["curator"] + start = datetime.now(timezone.utc) + + run_dir = curator._write_run_report( + started_at=start, + elapsed_seconds=75.0, + auto_counts={"checked": 10, "marked_stale": 2, "archived": 1, "reactivated": 0}, + auto_summary="2 marked stale, 1 archived", + before_report=[{"name": "foo", "state": "active", "pinned": False}], + before_names={"foo"}, + after_report=[{"name": "foo-umbrella", "state": "active", "pinned": False}], + llm_meta=_make_llm_meta( + final="Consolidated foo-like skills into foo-umbrella.", + model="claude-opus-4.7", + provider="openrouter", + ), + ) + md = (run_dir / "REPORT.md").read_text() + + # Structural checks + assert "# Curator run" in md + assert "Auto-transitions" in md + assert "LLM consolidation pass" in md + assert "Recovery" in md + + # The model / provider we passed in show up + assert "claude-opus-4.7" in md + assert "openrouter" in md + + # The added/archived lists are present + assert "Skills archived" in md + assert "`foo`" in md + assert "New skills this run" in md + assert "`foo-umbrella`" in md + + # The full LLM final response is included verbatim (no 240-char truncation) + assert "Consolidated foo-like skills into foo-umbrella." 
in md + + +def test_same_second_reruns_get_unique_dirs(curator_env): + """If the curator somehow runs twice in the same second, the second + report still gets its own directory rather than overwriting the first.""" + curator = curator_env["curator"] + start = datetime(2026, 4, 29, 5, 33, 34, tzinfo=timezone.utc) + + kwargs = dict( + started_at=start, + elapsed_seconds=1.0, + auto_counts={"checked": 0, "marked_stale": 0, "archived": 0, "reactivated": 0}, + auto_summary="no changes", + before_report=[], + before_names=set(), + after_report=[], + llm_meta=_make_llm_meta(), + ) + a = curator._write_run_report(**kwargs) + b = curator._write_run_report(**kwargs) + assert a != b + assert a is not None and b is not None + # Second dir has a numeric disambiguator suffix + assert b.name.startswith(a.name) + + +def test_report_captures_llm_error_and_continues(curator_env): + """If the LLM pass recorded an error, the report still writes and + surfaces the error prominently.""" + curator = curator_env["curator"] + run_dir = curator._write_run_report( + started_at=datetime.now(timezone.utc), + elapsed_seconds=2.0, + auto_counts={"checked": 0, "marked_stale": 0, "archived": 0, "reactivated": 0}, + auto_summary="no changes", + before_report=[], + before_names=set(), + after_report=[], + llm_meta=_make_llm_meta( + error="HTTP 400: No models provided", + final="", + summary="error", + ), + ) + md = (run_dir / "REPORT.md").read_text() + assert "HTTP 400" in md + payload = json.loads((run_dir / "run.json").read_text()) + assert payload["llm_error"] == "HTTP 400: No models provided" + + +def test_state_transitions_captured_in_report(curator_env): + """When a skill moves active → stale or stale → archived between + before/after snapshots, the report records it.""" + curator = curator_env["curator"] + start = datetime.now(timezone.utc) + + before = [{"name": "getting-old", "state": "active", "pinned": False}] + after = [{"name": "getting-old", "state": "stale", "pinned": False}] + + 
run_dir = curator._write_run_report( + started_at=start, + elapsed_seconds=1.0, + auto_counts={"checked": 1, "marked_stale": 1, "archived": 0, "reactivated": 0}, + auto_summary="1 marked stale", + before_report=before, + before_names={r["name"] for r in before}, + after_report=after, + llm_meta=_make_llm_meta(), + ) + payload = json.loads((run_dir / "run.json").read_text()) + assert payload["state_transitions"] == [ + {"name": "getting-old", "from": "active", "to": "stale"} + ] + md = (run_dir / "REPORT.md").read_text() + assert "State transitions" in md + assert "getting-old" in md + assert "active → stale" in md diff --git a/tests/agent/test_deepseek_anthropic_thinking.py b/tests/agent/test_deepseek_anthropic_thinking.py new file mode 100644 index 0000000000..4d032fa359 --- /dev/null +++ b/tests/agent/test_deepseek_anthropic_thinking.py @@ -0,0 +1,242 @@ +"""Regression guard: preserve thinking blocks on DeepSeek's /anthropic endpoint. + +DeepSeek's ``api.deepseek.com/anthropic`` route speaks the Anthropic Messages +protocol but, when thinking mode is enabled, requires ``thinking`` blocks from +prior assistant turns to round-trip on subsequent requests. The generic +third-party path strips them (signatures are Anthropic-proprietary and other +proxies cannot validate them), so without a DeepSeek-specific carve-out the +next tool-call turn fails with HTTP 400:: + + The content[].thinking in the thinking mode must be passed back to the + API. + +DeepSeek's compatibility matrix lists ``thinking`` as supported but +``redacted_thinking`` and ``cache_control`` on thinking blocks as not +supported. Handling is the same as Kimi's ``/coding`` endpoint: strip +Anthropic-signed blocks (DeepSeek can't validate them) but preserve unsigned +blocks that Hermes synthesises from ``reasoning_content``. + +See hermes-agent#16748. 
+""" + +from __future__ import annotations + +import pytest + + +class TestDeepSeekAnthropicPreservesThinking: + """convert_messages_to_anthropic must replay DeepSeek thinking blocks.""" + + @pytest.mark.parametrize( + "base_url", + [ + "https://api.deepseek.com/anthropic", + "https://api.deepseek.com/anthropic/", + "https://api.deepseek.com/anthropic/v1", + "https://API.DeepSeek.com/anthropic", + ], + ) + def test_unsigned_thinking_block_survives_replay(self, base_url: str) -> None: + """Unsigned thinking (synthesised from reasoning_content) must be preserved.""" + from agent.anthropic_adapter import convert_messages_to_anthropic + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "reasoning_content": "planning the tool call", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "skill_view", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "ok"}, + ] + _system, converted = convert_messages_to_anthropic( + messages, base_url=base_url + ) + + assistant_msg = next(m for m in converted if m["role"] == "assistant") + thinking_blocks = [ + b for b in assistant_msg["content"] + if isinstance(b, dict) and b.get("type") == "thinking" + ] + assert len(thinking_blocks) == 1, ( + f"DeepSeek /anthropic ({base_url}) must preserve unsigned thinking " + "blocks synthesised from reasoning_content — upstream rejects " + "replayed tool-call messages without them." 
+ ) + assert thinking_blocks[0]["thinking"] == "planning the tool call" + # Synthesised block — never has a signature + assert "signature" not in thinking_blocks[0] + + def test_unsigned_thinking_preserved_on_non_latest_assistant_turn(self) -> None: + """DeepSeek validates history across every prior assistant turn, not just last.""" + from agent.anthropic_adapter import convert_messages_to_anthropic + + messages = [ + {"role": "user", "content": "q1"}, + { + "role": "assistant", + "reasoning_content": "r1", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "ok"}, + {"role": "user", "content": "q2"}, + { + "role": "assistant", + "reasoning_content": "r2", + "tool_calls": [ + { + "id": "call_2", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_2", "content": "ok"}, + ] + _system, converted = convert_messages_to_anthropic( + messages, base_url="https://api.deepseek.com/anthropic" + ) + + assistants = [m for m in converted if m["role"] == "assistant"] + assert len(assistants) == 2 + for assistant, expected in zip(assistants, ("r1", "r2")): + thinking = [ + b for b in assistant["content"] + if isinstance(b, dict) and b.get("type") == "thinking" + ] + assert len(thinking) == 1 + assert thinking[0]["thinking"] == expected + + def test_signed_anthropic_thinking_block_is_stripped(self) -> None: + """Anthropic-signed blocks (that leaked through) must still be stripped. + + DeepSeek issues its own signatures and cannot validate Anthropic's — + the strip-signed / keep-unsigned split matches the Kimi policy. 
+ """ + from agent.anthropic_adapter import convert_messages_to_anthropic + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "anthropic-signed payload", + "signature": "anthropic-sig-xyz", + }, + {"type": "text", "text": "hello"}, + ], + }, + {"role": "user", "content": "again"}, + ] + _system, converted = convert_messages_to_anthropic( + messages, base_url="https://api.deepseek.com/anthropic" + ) + + assistant_msg = next(m for m in converted if m["role"] == "assistant") + thinking_blocks = [ + b for b in assistant_msg["content"] + if isinstance(b, dict) and b.get("type") == "thinking" + ] + assert thinking_blocks == [], ( + "Signed Anthropic thinking blocks must be stripped on DeepSeek — " + "DeepSeek cannot validate Anthropic-proprietary signatures." + ) + + def test_cache_control_stripped_from_thinking_block(self) -> None: + """cache_control must still be stripped even when the block is preserved. + + DeepSeek's compatibility matrix lists cache_control on thinking blocks + as ignored — cache markers interfere with signature validation on + upstreams that do check them, so Hermes strips them everywhere. + """ + from agent.anthropic_adapter import convert_messages_to_anthropic + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "reasoning_content": "r1", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "ok"}, + ] + # Inject cache_control on the synthesised thinking block after-the-fact + # by running conversion once, mutating, then re-running would be + # indirect. Instead check the simpler invariant: no thinking block in + # the converted output carries cache_control. 
+ _system, converted = convert_messages_to_anthropic( + messages, base_url="https://api.deepseek.com/anthropic" + ) + for m in converted: + if not isinstance(m.get("content"), list): + continue + for b in m["content"]: + if isinstance(b, dict) and b.get("type") in ("thinking", "redacted_thinking"): + assert "cache_control" not in b + + def test_openai_compat_deepseek_base_is_not_matched(self) -> None: + """The OpenAI-compatible ``api.deepseek.com`` base must NOT trigger the + DeepSeek /anthropic branch — it never reaches this adapter, but the + detector should still fail closed so an accidental misuse doesn't + quietly send signed Anthropic blocks to an OpenAI endpoint. + """ + from agent.anthropic_adapter import _is_deepseek_anthropic_endpoint + + assert _is_deepseek_anthropic_endpoint("https://api.deepseek.com") is False + assert _is_deepseek_anthropic_endpoint("https://api.deepseek.com/v1") is False + assert _is_deepseek_anthropic_endpoint("https://api.deepseek.com/anthropic") is True + assert _is_deepseek_anthropic_endpoint("https://api.deepseek.com/anthropic/v1") is True + + def test_non_deepseek_third_party_still_strips_all_thinking(self) -> None: + """MiniMax and other third-party Anthropic endpoints must keep the + generic strip-all behaviour (they reject unsigned blocks outright). 
+ """ + from agent.anthropic_adapter import convert_messages_to_anthropic + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "reasoning_content": "r1", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "ok"}, + ] + _system, converted = convert_messages_to_anthropic( + messages, base_url="https://api.minimax.io/anthropic" + ) + assistant_msg = next(m for m in converted if m["role"] == "assistant") + thinking_blocks = [ + b for b in assistant_msg["content"] + if isinstance(b, dict) and b.get("type") == "thinking" + ] + assert thinking_blocks == [], ( + "Non-DeepSeek third-party endpoints must keep the generic " + "strip-all-thinking behaviour — unsigned blocks get rejected." + ) diff --git a/tests/agent/test_error_classifier.py b/tests/agent/test_error_classifier.py index d6598b66a3..9d52c7bdf2 100644 --- a/tests/agent/test_error_classifier.py +++ b/tests/agent/test_error_classifier.py @@ -57,7 +57,9 @@ class TestFailoverReason: "context_overflow", "payload_too_large", "image_too_large", "model_not_found", "format_error", "provider_policy_blocked", - "thinking_signature", "long_context_tier", "unknown", + "thinking_signature", "long_context_tier", + "oauth_long_context_beta_forbidden", + "unknown", } actual = {r.value for r in FailoverReason} assert expected == actual @@ -458,6 +460,40 @@ class TestClassifyApiError: result = classify_api_error(e, provider="anthropic") assert result.reason == FailoverReason.rate_limit + # ── Provider-specific: Anthropic OAuth 1M-context beta forbidden ── + + def test_anthropic_oauth_1m_beta_forbidden(self): + """400 + 'long context beta is not yet available for this subscription' + → oauth_long_context_beta_forbidden (retryable, no compression).""" + e = MockAPIError( + "The long context beta is not yet available for this subscription.", + status_code=400, + ) + result = 
classify_api_error(e, provider="anthropic", model="claude-sonnet-4.6") + assert result.reason == FailoverReason.oauth_long_context_beta_forbidden + assert result.retryable is True + assert result.should_compress is False + + def test_anthropic_oauth_1m_beta_forbidden_does_not_collide_with_tier_gate(self): + """The 429 'extra usage' + 'long context' tier gate keeps its own + classification even though its message mentions 'long context'.""" + e = MockAPIError( + "Extra usage is required for long context requests over 200k tokens", + status_code=429, + ) + result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4.6") + assert result.reason == FailoverReason.long_context_tier + + def test_400_without_beta_phrase_is_not_1m_beta_forbidden(self): + """A generic 400 that happens to mention 'long context' but not the + exact beta-availability phrase should not be misclassified.""" + e = MockAPIError( + "long context window exceeded", + status_code=400, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason != FailoverReason.oauth_long_context_beta_forbidden + # ── Transport errors ── def test_read_timeout(self): diff --git a/tests/agent/test_kimi_coding_anthropic_thinking.py b/tests/agent/test_kimi_coding_anthropic_thinking.py index 706f7e0e16..89872cc2f0 100644 --- a/tests/agent/test_kimi_coding_anthropic_thinking.py +++ b/tests/agent/test_kimi_coding_anthropic_thinking.py @@ -94,13 +94,16 @@ class TestKimiCodingSkipsAnthropicThinking: ) assert "thinking" in kwargs - def test_kimi_root_endpoint_unaffected(self) -> None: - """Only the /coding route is special-cased — plain api.kimi.com is not. + def test_kimi_root_endpoint_via_anthropic_transport_omits_thinking(self) -> None: + """Plain ``api.kimi.com`` hit via the Anthropic transport also omits thinking. 
- ``api.kimi.com`` without ``/coding`` uses the chat_completions transport - (see runtime_provider._detect_api_mode_for_url); build_anthropic_kwargs - should never see it, but if it somehow does we should not suppress - thinking there — that path has different semantics. + Auto-detection routes ``api.kimi.com/v1`` to ``chat_completions`` by + default, but users can explicitly configure + ``api_mode: anthropic_messages`` against any Kimi host. The upstream + validation (reasoning_content required on replayed tool-call + messages) is the same regardless of URL path, so the thinking + suppression must apply to every Kimi host, not just ``/coding``. + See #17057. """ from agent.anthropic_adapter import build_anthropic_kwargs @@ -112,4 +115,98 @@ class TestKimiCodingSkipsAnthropicThinking: reasoning_config={"enabled": True, "effort": "medium"}, base_url="https://api.kimi.com/v1", ) + assert "thinking" not in kwargs + + # ── #17057: custom / proxied Kimi-compatible endpoints ────────── + @pytest.mark.parametrize( + "base_url,model", + [ + # Custom host with Kimi-family model — the reporter's case + ("http://my-kimi-proxy.internal", "kimi-2.6"), + ("https://llm.example.com/anthropic", "kimi-k2.5"), + ("https://llm.example.com/anthropic", "moonshot-v1-8k"), + ("https://llm.example.com/anthropic", "kimi_thinking"), + ("https://llm.example.com/anthropic", "moonshotai/kimi-k2.5"), + # Official Moonshot host (previously uncovered) + ("https://api.moonshot.ai/anthropic", "moonshot-v1-32k"), + ("https://api.moonshot.cn/anthropic", "moonshot-v1-32k"), + ], + ) + def test_kimi_family_custom_endpoint_omits_thinking( + self, base_url: str, model: str + ) -> None: + """Custom / proxied Kimi endpoints must also strip Anthropic thinking.""" + from agent.anthropic_adapter import build_anthropic_kwargs + + kwargs = build_anthropic_kwargs( + model=model, + messages=[{"role": "user", "content": "hello"}], + tools=None, + max_tokens=4096, + reasoning_config={"enabled": True, "effort": 
"medium"}, + base_url=base_url, + ) + assert "thinking" not in kwargs, ( + f"Kimi-family endpoint ({base_url}, {model}) must not receive " + f"Anthropic thinking — upstream validates reasoning_content on " + f"replayed tool-call history we don't preserve." + ) + assert "output_config" not in kwargs + + def test_custom_endpoint_non_kimi_model_keeps_thinking(self) -> None: + """Custom endpoint with a non-Kimi model must keep thinking intact. + + Guards against over-broad model-family matching — only model names + starting with a Kimi/Moonshot prefix should trigger suppression. + """ + from agent.anthropic_adapter import build_anthropic_kwargs + + kwargs = build_anthropic_kwargs( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": "hello"}], + tools=None, + max_tokens=4096, + reasoning_config={"enabled": True, "effort": "medium"}, + base_url="https://my-llm-proxy.example.com/anthropic", + ) assert "thinking" in kwargs + assert kwargs["thinking"]["type"] == "enabled" + + def test_kimi_family_replay_preserves_unsigned_thinking(self) -> None: + """On a custom Kimi endpoint, unsigned reasoning_content thinking + blocks must survive the third-party signature-stripping pass so + the upstream's message-history validation passes. + """ + from agent.anthropic_adapter import convert_messages_to_anthropic + + messages = [ + {"role": "user", "content": "hi"}, + { + "role": "assistant", + "reasoning_content": "planning the tool call", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "skill_view", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "ok"}, + ] + _, converted = convert_messages_to_anthropic( + messages, + base_url="http://my-kimi-proxy.internal", + model="kimi-2.6", + ) + # The assistant message still carries the unsigned thinking block + # synthesised from reasoning_content (required by Kimi's history + # validation). A plain third-party endpoint would have stripped it. 
+ assistant_msg = next(m for m in converted if m["role"] == "assistant") + assistant_blocks = assistant_msg["content"] + thinking_blocks = [ + b for b in assistant_blocks + if isinstance(b, dict) and b.get("type") == "thinking" + ] + assert len(thinking_blocks) == 1 + assert thinking_blocks[0]["thinking"] == "planning the tool call" diff --git a/tests/agent/test_memory_session_switch.py b/tests/agent/test_memory_session_switch.py new file mode 100644 index 0000000000..610c09b29f --- /dev/null +++ b/tests/agent/test_memory_session_switch.py @@ -0,0 +1,320 @@ +"""Tests for the on_session_switch hook and session_id propagation. + +Covers #6672: memory providers must be notified when AIAgent.session_id +rotates mid-process (via /resume, /branch, /reset, /new, or context +compression). Without the notification, providers that cache per-session +state in initialize() (Hindsight, and any plugin that stores session_id +for scoped writes) keep writing into the old session's record. +""" + +import json + +import pytest + +from agent.memory_manager import MemoryManager +from agent.memory_provider import MemoryProvider + + +class _RecordingProvider(MemoryProvider): + """Provider that records every lifecycle call for assertion.""" + + def __init__(self, name="rec"): + self._name = name + self.switch_calls: list[dict] = [] + self.sync_calls: list[dict] = [] + self.queue_calls: list[dict] = [] + self.initialize_calls: list[dict] = [] + + @property + def name(self) -> str: + return self._name + + def is_available(self) -> bool: # pragma: no cover - unused + return True + + def initialize(self, session_id, **kwargs): + self.initialize_calls.append({"session_id": session_id, **kwargs}) + + def get_tool_schemas(self): + return [] + + def sync_turn(self, user_content, assistant_content, *, session_id=""): + self.sync_calls.append( + {"user": user_content, "asst": assistant_content, "session_id": session_id} + ) + + def queue_prefetch(self, query, *, session_id=""): + 
self.queue_calls.append({"query": query, "session_id": session_id}) + + def on_session_switch( + self, + new_session_id, + *, + parent_session_id="", + reset=False, + **kwargs, + ): + self.switch_calls.append( + { + "new": new_session_id, + "parent": parent_session_id, + "reset": reset, + "extra": kwargs, + } + ) + + +# --------------------------------------------------------------------------- +# MemoryProvider ABC — default on_session_switch is a no-op +# --------------------------------------------------------------------------- + + +class _MinimalProvider(MemoryProvider): + """Provider that does NOT override on_session_switch — ABC default must no-op.""" + + @property + def name(self) -> str: + return "minimal" + + def is_available(self) -> bool: + return True + + def initialize(self, session_id, **kwargs): # pragma: no cover - unused + pass + + def get_tool_schemas(self): + return [] + + +def test_abc_default_on_session_switch_is_noop(): + """Providers that don't override the hook must not raise.""" + p = _MinimalProvider() + # All four call styles must be accepted without raising + p.on_session_switch("new-id") + p.on_session_switch("new-id", parent_session_id="old-id") + p.on_session_switch("new-id", parent_session_id="old-id", reset=True) + p.on_session_switch("new-id", parent_session_id="old-id", reset=True, reason="new_session") + + +# --------------------------------------------------------------------------- +# MemoryManager.on_session_switch — fan-out +# --------------------------------------------------------------------------- + + +def test_manager_fans_out_to_all_providers(): + mm = MemoryManager() + # Only one external provider is allowed; use the builtin slot for p1.
+ p1 = _RecordingProvider(name="builtin") + p2 = _RecordingProvider(name="hindsight") + mm.add_provider(p1) + mm.add_provider(p2) + + mm.on_session_switch("new-sid", parent_session_id="old-sid", reset=False, reason="resume") + + assert len(p1.switch_calls) == 1 + assert len(p2.switch_calls) == 1 + for call in (p1.switch_calls[0], p2.switch_calls[0]): + assert call["new"] == "new-sid" + assert call["parent"] == "old-sid" + assert call["reset"] is False + assert call["extra"] == {"reason": "resume"} + + +def test_manager_ignores_empty_session_id(): + """Empty string session_id must not trigger provider hooks. + + Prevents accidental fires during shutdown when self.session_id may be + cleared. Providers expect a meaningful id to switch TO. + """ + mm = MemoryManager() + p = _RecordingProvider() + mm.add_provider(p) + mm.on_session_switch("") + mm.on_session_switch(None) # type: ignore[arg-type] + assert p.switch_calls == [] + + +def test_manager_isolates_provider_failures(): + """A provider that raises must not block other providers.""" + + class _Broken(_RecordingProvider): + def on_session_switch(self, *args, **kwargs): # type: ignore[override] + raise RuntimeError("boom") + + mm = MemoryManager() + # MemoryManager rejects a second external provider, so pair broken + # (builtin slot) with a good external one. 
+ broken = _Broken(name="builtin") + good = _RecordingProvider(name="good") + mm.add_provider(broken) + mm.add_provider(good) + + # Must not raise — exceptions in one provider are swallowed + logged + mm.on_session_switch("new-sid", parent_session_id="old-sid") + assert len(good.switch_calls) == 1 + assert good.switch_calls[0]["new"] == "new-sid" + + +def test_manager_reset_flag_preserved(): + mm = MemoryManager() + p = _RecordingProvider() + mm.add_provider(p) + mm.on_session_switch("new-sid", reset=True, reason="new_session") + assert p.switch_calls[0]["reset"] is True + assert p.switch_calls[0]["extra"] == {"reason": "new_session"} + + +# --------------------------------------------------------------------------- +# MemoryManager.sync_all / queue_prefetch_all — session_id propagation +# --------------------------------------------------------------------------- + + +def test_sync_all_propagates_session_id_to_providers(): + """run_agent.py's sync_all call must pass session_id through to providers. + + Without this, a provider that updates _session_id defensively in + sync_turn (as Hindsight does at hindsight/__init__.py:1199) never + sees the new id and keeps writing under the old one. 
+ """ + mm = MemoryManager() + p = _RecordingProvider() + mm.add_provider(p) + mm.sync_all("hello", "world", session_id="sess-42") + assert p.sync_calls == [ + {"user": "hello", "asst": "world", "session_id": "sess-42"} + ] + + +def test_queue_prefetch_all_propagates_session_id_to_providers(): + mm = MemoryManager() + p = _RecordingProvider() + mm.add_provider(p) + mm.queue_prefetch_all("next query", session_id="sess-42") + assert p.queue_calls == [{"query": "next query", "session_id": "sess-42"}] + + +# --------------------------------------------------------------------------- +# Hindsight reference implementation — state-flush semantics +# --------------------------------------------------------------------------- + + +def _make_hindsight_provider(): + """Build a bare HindsightMemoryProvider that skips network setup. + + We instantiate without importing optional deps at class-level by + bypassing __init__ and seeding the attributes on_session_switch + reads/writes. This keeps the test hermetic. + """ + import threading + hindsight_mod = pytest.importorskip("plugins.memory.hindsight") + provider = object.__new__(hindsight_mod.HindsightMemoryProvider) + provider._session_id = "old-sid" + provider._parent_session_id = "" + provider._document_id = "old-sid-20260101_000000_000000" + provider._session_turns = ["turn-1", "turn-2"] + provider._turn_counter = 2 + provider._turn_index = 2 + # Attrs read by _build_metadata / _build_retain_kwargs when the + # buffer-flush path on session switch fires. Empty strings keep the + # metadata minimal but well-formed. 
+ provider._retain_source = "" + provider._platform = "" + provider._user_id = "" + provider._user_name = "" + provider._chat_id = "" + provider._chat_name = "" + provider._chat_type = "" + provider._thread_id = "" + provider._agent_identity = "" + provider._agent_workspace = "" + provider._retain_tags = [] + provider._retain_context = "test-context" + provider._retain_async = False + provider._bank_id = "test-bank" + # Prefetch state the switch path drains/clears. + provider._prefetch_thread = None + provider._prefetch_lock = threading.Lock() + provider._prefetch_result = "" + # Sync thread tracking (legacy alias at the writer). + provider._sync_thread = None + # Writer queue infra the flush-on-switch path enqueues onto. We stub + # _ensure_writer / _register_atexit so no real thread is spawned; + # tests exercising flush delivery live in + # tests/plugins/memory/test_hindsight_provider.py where the full + # writer-queue wiring is in place. + import queue as _queue + provider._retain_queue = _queue.Queue() + provider._shutting_down = threading.Event() + provider._atexit_registered = True + provider._ensure_writer = lambda: None + provider._register_atexit = lambda: None + # Stub the network-touching helper so any enqueued flush closure is + # a no-op if ever drained in a unit test. 
+ provider._run_hindsight_operation = lambda _op: None + return provider + + +def test_hindsight_on_session_switch_updates_session_id_and_mints_fresh_doc(): + provider = _make_hindsight_provider() + old_doc = provider._document_id + + provider.on_session_switch( + "new-sid", parent_session_id="old-sid", reset=False, reason="resume" + ) + + assert provider._session_id == "new-sid" + assert provider._parent_session_id == "old-sid" + # Document id MUST be fresh — else next retain overwrites old session doc + assert provider._document_id != old_doc + assert provider._document_id.startswith("new-sid-") + + +def test_hindsight_on_session_switch_clears_turn_buffers(): + """Accumulated _session_turns must not leak into the next session. + + Hindsight batches turns under a single _document_id. If the buffer + isn't cleared on switch, the next retain under the new _document_id + flushes turns that belong to the previous session. + """ + provider = _make_hindsight_provider() + provider.on_session_switch("new-sid", parent_session_id="old-sid") + assert provider._session_turns == [] + assert provider._turn_counter == 0 + assert provider._turn_index == 0 + + +def test_hindsight_on_session_switch_clears_on_reset_true(): + """reset=True (from /new, /reset) must also flush buffers.""" + provider = _make_hindsight_provider() + provider.on_session_switch("new-sid", reset=True, reason="new_session") + assert provider._session_id == "new-sid" + assert provider._session_turns == [] + assert provider._turn_counter == 0 + + +def test_hindsight_on_session_switch_ignores_empty_id(): + """Empty new_session_id must be a no-op to avoid corrupting state.""" + provider = _make_hindsight_provider() + before = ( + provider._session_id, + provider._document_id, + list(provider._session_turns), + provider._turn_counter, + ) + provider.on_session_switch("") + provider.on_session_switch(None) # type: ignore[arg-type] + after = ( + provider._session_id, + provider._document_id, + 
list(provider._session_turns), + provider._turn_counter, + ) + assert before == after + + +def test_hindsight_preserves_parent_across_empty_parent_arg(): + """Omitting parent_session_id must NOT overwrite an existing one.""" + provider = _make_hindsight_provider() + provider._parent_session_id = "original-parent" + provider.on_session_switch("new-sid") # no parent passed + assert provider._parent_session_id == "original-parent" diff --git a/tests/agent/test_minimax_provider.py b/tests/agent/test_minimax_provider.py index 9ae865d57e..7c64b3575a 100644 --- a/tests/agent/test_minimax_provider.py +++ b/tests/agent/test_minimax_provider.py @@ -308,10 +308,15 @@ class TestMinimaxPreserveDots: from agent.anthropic_adapter import normalize_model_name assert normalize_model_name("MiniMax-M2.7", preserve_dots=True) == "MiniMax-M2.7" - def test_normalize_converts_without_preserve(self): + def test_normalize_preserves_non_anthropic_dots_without_preserve(self): from agent.anthropic_adapter import normalize_model_name - # Without preserve_dots, dots become hyphens (broken for MiniMax) - assert normalize_model_name("MiniMax-M2.7", preserve_dots=False) == "MiniMax-M2-7" + # Non-Anthropic model families use dots as canonical version separators; + # only Claude/Anthropic names are hyphen-normalized by default. 
+ assert normalize_model_name("MiniMax-M2.7", preserve_dots=False) == "MiniMax-M2.7" + + def test_normalize_still_converts_claude_dots_without_preserve(self): + from agent.anthropic_adapter import normalize_model_name + assert normalize_model_name("claude-opus-4.6", preserve_dots=False) == "claude-opus-4-6" class TestMinimaxSwitchModelCredentialGuard: diff --git a/tests/agent/test_onboarding.py b/tests/agent/test_onboarding.py index c886979898..1eaf0d01d2 100644 --- a/tests/agent/test_onboarding.py +++ b/tests/agent/test_onboarding.py @@ -205,11 +205,22 @@ class TestDetectOpenclawResidue: class TestOpenclawResidueHint: - def test_hint_mentions_cleanup_command(self): + def test_hint_mentions_migrate_command(self): + # `migrate` is the non-destructive path — should lead the banner. msg = openclaw_residue_hint_cli() - assert "hermes claw cleanup" in msg + assert "hermes claw migrate" in msg assert "~/.openclaw" in msg + def test_hint_mentions_cleanup_command(self): + # `cleanup` is mentioned as the follow-up archive step. + assert "hermes claw cleanup" in openclaw_residue_hint_cli() + + def test_hint_warns_cleanup_breaks_openclaw(self): + # Archiving the directory breaks OpenClaw for users still running it — + # the banner must flag that side effect. + msg = openclaw_residue_hint_cli().lower() + assert "openclaw will stop working" in msg or "stop working" in msg + def test_hint_not_empty(self): assert openclaw_residue_hint_cli().strip() diff --git a/tests/agent/test_skill_commands_reload.py b/tests/agent/test_skill_commands_reload.py new file mode 100644 index 0000000000..ee77141d19 --- /dev/null +++ b/tests/agent/test_skill_commands_reload.py @@ -0,0 +1,160 @@ +"""Tests for ``agent.skill_commands.reload_skills``. + +Covers the helper that powers ``/reload-skills`` (CLI + gateway slash command). +The helper rescans the skills directory and returns a diff of what changed. 
+It does NOT invalidate the skills system-prompt cache — skills are invoked +at runtime via ``/skill-name``, ``skills_list``, or ``skill_view`` and don't +need to live in the system prompt. + +``added`` and ``removed`` are lists of ``{"name": str, "description": str}`` +dicts. Descriptions are passed through verbatim (no truncation). +""" + +import shutil +import tempfile +import textwrap +from pathlib import Path + +import pytest + + +def _write_skill(skills_dir: Path, name: str, description: str = "") -> Path: + skill_dir = skills_dir / name + skill_dir.mkdir(parents=True, exist_ok=True) + (skill_dir / "SKILL.md").write_text( + textwrap.dedent( + f"""\ + --- + name: {name} + description: {description or f'{name} skill'} + --- + body + """ + ) + ) + return skill_dir + + +@pytest.fixture +def hermes_home(monkeypatch): + """Isolate HERMES_HOME for ``reload_skills`` tests. + + Rather than popping cache-bearing modules from ``sys.modules`` (which + races against pytest-xdist's parallel workers), we monkeypatch the + module-level ``HERMES_HOME`` / ``SKILLS_DIR`` constants in place so the + isolation is local to this fixture's scope. + """ + td = tempfile.mkdtemp(prefix="hermes-reload-skills-") + monkeypatch.setenv("HERMES_HOME", td) + home = Path(td) + (home / "skills").mkdir(parents=True, exist_ok=True) + + # Import lazily (inside fixture) so the modules are already resident, + # then redirect their captured paths at the new temp dir. + import tools.skills_tool as _st + import agent.skill_commands as _sc + + monkeypatch.setattr(_st, "HERMES_HOME", home, raising=False) + monkeypatch.setattr(_st, "SKILLS_DIR", home / "skills", raising=False) + # Reset the in-process slash-command cache so each test starts from zero.
+ monkeypatch.setattr(_sc, "_skill_commands", {}, raising=False) + + yield home + + shutil.rmtree(td, ignore_errors=True) + + +class TestReloadSkillsHelper: + """``agent.skill_commands.reload_skills``.""" + + def test_returns_expected_keys(self, hermes_home): + from agent.skill_commands import reload_skills + + result = reload_skills() + assert set(result) == {"added", "removed", "unchanged", "total", "commands"} + assert result["total"] == 0 + assert result["added"] == [] + assert result["removed"] == [] + + def test_detects_newly_added_skill_with_description(self, hermes_home): + from agent.skill_commands import reload_skills, get_skill_commands + + # Prime the cache so subsequent diff is meaningful + get_skill_commands() + + _write_skill(hermes_home / "skills", "demo", "a demo skill") + result = reload_skills() + + assert result["added"] == [{"name": "demo", "description": "a demo skill"}] + assert result["removed"] == [] + assert result["total"] == 1 + assert result["commands"] == 1 + + def test_detects_removed_skill_carries_description(self, hermes_home): + from agent.skill_commands import reload_skills + + skill_dir = _write_skill(hermes_home / "skills", "demo", "soon to be gone") + # First reload: demo present + first = reload_skills() + assert first["total"] == 1 + assert first["added"] == [{"name": "demo", "description": "soon to be gone"}] + + # Remove and reload — the description must survive the removal diff + # (we cached it from the pre-rescan snapshot). + shutil.rmtree(skill_dir) + second = reload_skills() + + assert second["removed"] == [{"name": "demo", "description": "soon to be gone"}] + assert second["added"] == [] + assert second["total"] == 0 + + def test_description_passes_through_verbatim(self, hermes_home): + """``description`` must be the full SKILL.md frontmatter string — no + truncation. 
The system prompt renders skills as + `` - name: description`` without a length cap, and the reload + note mirrors that format, so truncating here would make the diff + render differently from the original catalog.""" + from agent.skill_commands import reload_skills, get_skill_commands + + get_skill_commands() # prime + long_desc = "x" * 200 + _write_skill(hermes_home / "skills", "longdesc", long_desc) + + result = reload_skills() + assert len(result["added"]) == 1 + assert result["added"][0]["description"] == long_desc + + def test_unchanged_skills_appear_in_unchanged_list(self, hermes_home): + from agent.skill_commands import reload_skills, get_skill_commands + + _write_skill(hermes_home / "skills", "alpha") + # Prime cache + get_skill_commands() + + # Call reload again with no FS changes + result = reload_skills() + assert "alpha" in result["unchanged"] + assert result["added"] == [] + assert result["removed"] == [] + + def test_does_not_invalidate_prompt_cache_snapshot(self, hermes_home): + """reload_skills must NOT delete the skills prompt-cache snapshot. + + Skills are called at runtime — the system prompt doesn't need to + mention them for the model to use them — so reloading them should + preserve prefix caching. 
+ """ + from agent.prompt_builder import _skills_prompt_snapshot_path + from agent.skill_commands import reload_skills + + snapshot = _skills_prompt_snapshot_path() + snapshot.parent.mkdir(parents=True, exist_ok=True) + snapshot.write_text("{}") + assert snapshot.exists() + + reload_skills() + + assert snapshot.exists(), ( + "prompt cache snapshot should be preserved — skills don't live " + "in the system prompt so there's no reason to invalidate it" + ) diff --git a/tests/agent/transports/test_chat_completions.py b/tests/agent/transports/test_chat_completions.py index e558fa3de7..bec7dc58a0 100644 --- a/tests/agent/transports/test_chat_completions.py +++ b/tests/agent/transports/test_chat_completions.py @@ -122,21 +122,25 @@ class TestChatCompletionsBuildKwargs: ) assert kw["extra_body"]["think"] is False - def test_gemini_without_explicit_reasoning_config_keeps_existing_behavior(self, transport): + def test_gemini_native_without_explicit_reasoning_config_keeps_existing_behavior(self, transport): msgs = [{"role": "user", "content": "Hi"}] kw = transport.build_kwargs( model="gemini-3-flash-preview", messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta", ) assert "thinking_config" not in kw.get("extra_body", {}) + assert "google" not in kw.get("extra_body", {}) + assert "extra_body" not in kw.get("extra_body", {}) - def test_gemini_flash_reasoning_maps_to_thinking_config(self, transport): + def test_gemini_native_flash_reasoning_maps_to_top_level_thinking_config(self, transport): msgs = [{"role": "user", "content": "Hi"}] kw = transport.build_kwargs( model="gemini-3-flash-preview", messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta", reasoning_config={"enabled": True, "effort": "high"}, ) assert kw["extra_body"]["thinking_config"] == { @@ -144,52 +148,85 @@ class TestChatCompletionsBuildKwargs: "thinkingLevel": "high", } - def 
test_gemini_25_reasoning_only_enables_visible_thoughts(self, transport): + def test_gemini_openai_compat_flash_reasoning_maps_to_nested_google_thinking_config(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gemini-3-flash-preview", + messages=msgs, + provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", + reasoning_config={"enabled": True, "effort": "high"}, + ) + assert "thinking_config" not in kw["extra_body"] + assert kw["extra_body"]["extra_body"]["google"]["thinking_config"] == { + "include_thoughts": True, + "thinking_level": "high", + } + + def test_gemini_native_25_reasoning_only_enables_visible_thoughts(self, transport): msgs = [{"role": "user", "content": "Hi"}] kw = transport.build_kwargs( model="gemini-2.5-flash", messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta", reasoning_config={"enabled": True, "effort": "high"}, ) assert kw["extra_body"]["thinking_config"] == { "includeThoughts": True, } - def test_gemini_pro_reasoning_clamps_to_supported_levels(self, transport): + def test_gemini_openai_compat_pro_reasoning_clamps_to_supported_levels(self, transport): msgs = [{"role": "user", "content": "Hi"}] kw = transport.build_kwargs( model="google/gemini-3.1-pro-preview", messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", reasoning_config={"enabled": True, "effort": "medium"}, ) - assert kw["extra_body"]["thinking_config"] == { - "includeThoughts": True, - "thinkingLevel": "low", + assert kw["extra_body"]["extra_body"]["google"]["thinking_config"] == { + "include_thoughts": True, + "thinking_level": "low", } - def test_gemini_disabled_reasoning_hides_thoughts(self, transport): + def test_gemini_native_disabled_reasoning_hides_thoughts(self, transport): msgs = [{"role": "user", "content": "Hi"}] kw = transport.build_kwargs( model="gemini-3-flash-preview", 
messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta", reasoning_config={"enabled": False}, ) assert kw["extra_body"]["thinking_config"] == { "includeThoughts": False, } - def test_gemini_xhigh_clamps_to_high(self, transport): + def test_gemini_openai_compat_xhigh_clamps_to_high(self, transport): msgs = [{"role": "user", "content": "Hi"}] kw = transport.build_kwargs( model="gemini-3-flash-preview", messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", reasoning_config={"enabled": True, "effort": "xhigh"}, ) - assert kw["extra_body"]["thinking_config"]["thinkingLevel"] == "high" + assert kw["extra_body"]["extra_body"]["google"]["thinking_config"]["thinking_level"] == "high" + + def test_google_gemini_cli_keeps_top_level_thinking_config(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gemini-3-flash-preview", + messages=msgs, + provider_name="google-gemini-cli", + reasoning_config={"enabled": True, "effort": "high"}, + ) + assert kw["extra_body"]["thinking_config"] == { + "includeThoughts": True, + "thinkingLevel": "high", + } + assert "google" not in kw["extra_body"] def test_gemini_flash_minimal_clamps_to_low(self, transport): # Gemini 3 Flash documents low/medium/high; "minimal" isn't accepted, @@ -199,11 +236,12 @@ class TestChatCompletionsBuildKwargs: model="gemini-3-flash-preview", messages=msgs, provider_name="gemini", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", reasoning_config={"enabled": True, "effort": "minimal"}, ) - assert kw["extra_body"]["thinking_config"] == { - "includeThoughts": True, - "thinkingLevel": "low", + assert kw["extra_body"]["extra_body"]["google"]["thinking_config"] == { + "include_thoughts": True, + "thinking_level": "low", } def test_max_tokens_with_fn(self, transport): diff --git a/tests/cli/test_branch_command.py b/tests/cli/test_branch_command.py index 
581cdbdb6a..5e78815b8f 100644 --- a/tests/cli/test_branch_command.py +++ b/tests/cli/test_branch_command.py @@ -192,6 +192,33 @@ class TestBranchCommandCLI: assert cli_instance._resumed is True + def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db): + """The /branch command must notify memory providers of the rotation. + + Without this, providers that cache per-session state in + initialize() keep writing under the old session_id. See #6672. + """ + from cli import HermesCLI + + # Wire a real-ish agent object with a MagicMock memory_manager + agent = MagicMock() + mm = MagicMock() + agent._memory_manager = mm + cli_instance.agent = agent + original_id = cli_instance.session_id + + HermesCLI._handle_branch_command(cli_instance, "/branch") + + # Hook must have been called exactly once with the new session_id, + # parent pointing at the branched-from session, reset=False, and + # reason="branch" for diagnostics. + assert mm.on_session_switch.call_count == 1 + _, kwargs = mm.on_session_switch.call_args + assert mm.on_session_switch.call_args.args[0] == cli_instance.session_id + assert kwargs["parent_session_id"] == original_id + assert kwargs["reset"] is False + assert kwargs["reason"] == "branch" + def test_fork_alias(self): """The /fork alias should resolve to 'branch'.""" from hermes_cli.commands import resolve_command diff --git a/tests/cli/test_cli_init.py b/tests/cli/test_cli_init.py index b926d55f53..e4e6426325 100644 --- a/tests/cli/test_cli_init.py +++ b/tests/cli/test_cli_init.py @@ -296,6 +296,30 @@ class TestRootLevelProviderOverride: # Root-level "opencode-go" must NOT leak through assert cfg["model"]["provider"] != "opencode-go" + def test_terminal_vercel_runtime_bridged_to_env(self, tmp_path, monkeypatch): + """Classic CLI must expose terminal.vercel_runtime to terminal_tool.py.""" + import yaml + + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + 
monkeypatch.delenv("TERMINAL_VERCEL_RUNTIME", raising=False) + + config_path = hermes_home / "config.yaml" + config_path.write_text(yaml.safe_dump({ + "terminal": { + "backend": "vercel_sandbox", + "vercel_runtime": "python3.13", + }, + })) + + import cli + monkeypatch.setattr(cli, "_hermes_home", hermes_home) + cfg = cli.load_cli_config() + + assert cfg["terminal"]["vercel_runtime"] == "python3.13" + assert os.environ["TERMINAL_VERCEL_RUNTIME"] == "python3.13" + def test_normalize_root_model_keys_moves_to_model(self): """_normalize_root_model_keys migrates root keys into model section.""" from hermes_cli.config import _normalize_root_model_keys diff --git a/tests/cli/test_cli_loading_indicator.py b/tests/cli/test_cli_loading_indicator.py index 6cec9eca3d..dd7bdb68d1 100644 --- a/tests/cli/test_cli_loading_indicator.py +++ b/tests/cli/test_cli_loading_indicator.py @@ -49,8 +49,15 @@ class TestCLILoadingIndicator: seen["status"] = cli_obj._command_status print("reload done") + # /reload-mcp now wraps the actual reload in a prompt-cache-invalidation + # confirmation prompt (commit 4d7fc0f37). This test exercises the + # loading-indicator path, not the confirmation UX, so pre-approve the + # reload via config so the handler goes straight into _reload_mcp(). + fake_cfg = {"approvals": {"mcp_reload_confirm": False}} + with patch.object(cli_obj, "_reload_mcp", side_effect=fake_reload), \ - patch.object(cli_obj, "_invalidate") as invalidate_mock: + patch.object(cli_obj, "_invalidate") as invalidate_mock, \ + patch("cli.load_cli_config", return_value=fake_cfg): assert cli_obj.process_command("/reload-mcp") output = capsys.readouterr().out diff --git a/tests/cli/test_cli_reload_skills.py b/tests/cli/test_cli_reload_skills.py new file mode 100644 index 0000000000..1b728bc3c1 --- /dev/null +++ b/tests/cli/test_cli_reload_skills.py @@ -0,0 +1,99 @@ +"""Tests for the ``/reload-skills`` CLI slash command (``HermesCLI._reload_skills``). 
+ +The CLI handler prints the diff (name + description) for the user and — +when any skills were added or removed — queues a one-shot note on +``self._pending_skills_reload_note``. The note is prepended to the NEXT +user message (see cli.py ~L8770, same pattern as +``_pending_model_switch_note``) and cleared after use, so no phantom user +turn is persisted to ``conversation_history``. +""" + +from unittest.mock import patch + + +def _make_cli(): + """Build a minimal HermesCLI shell exposing ``_reload_skills``.""" + import cli as cli_mod + + obj = object.__new__(cli_mod.HermesCLI) + obj._command_running = False + obj.conversation_history = [] + obj.agent = None + return obj + + +class TestReloadSkillsCLI: + def test_reports_added_and_removed_and_queues_note(self, capsys): + cli = _make_cli() + with patch( + "agent.skill_commands.reload_skills", + return_value={ + "added": [ + {"name": "alpha", "description": "Run alpha to do xyz"}, + {"name": "beta", "description": "Run beta to do abc"}, + ], + "removed": [ + {"name": "gamma", "description": "Old removed skill"}, + ], + "unchanged": ["delta"], + "total": 3, + "commands": 3, + }, + ): + cli._reload_skills() + + out = capsys.readouterr().out + assert "Added Skills:" in out + assert "- alpha: Run alpha to do xyz" in out + assert "- beta: Run beta to do abc" in out + assert "Removed Skills:" in out + assert "- gamma: Old removed skill" in out + assert "3 skill(s) available" in out + + # Must NOT pollute conversation_history — alternation-safe. + assert cli.conversation_history == [] + + # One-shot note queued with system-prompt-style formatting. 
+ note = getattr(cli, "_pending_skills_reload_note", None) + assert note is not None + assert note.startswith("[USER INITIATED SKILLS RELOAD:") + assert note.endswith("Use skills_list to see the updated catalog.]") + assert "Added Skills:" in note + assert " - alpha: Run alpha to do xyz" in note + assert " - beta: Run beta to do abc" in note + assert "Removed Skills:" in note + assert " - gamma: Old removed skill" in note + + def test_reports_no_changes_and_queues_nothing(self, capsys): + cli = _make_cli() + with patch( + "agent.skill_commands.reload_skills", + return_value={ + "added": [], + "removed": [], + "unchanged": ["alpha"], + "total": 1, + "commands": 1, + }, + ): + cli._reload_skills() + + out = capsys.readouterr().out + assert "No new skills detected" in out + assert "1 skill(s) available" in out + assert cli.conversation_history == [] + assert getattr(cli, "_pending_skills_reload_note", None) is None + + def test_handles_reload_failure_gracefully(self, capsys): + cli = _make_cli() + with patch( + "agent.skill_commands.reload_skills", + side_effect=RuntimeError("boom"), + ): + cli._reload_skills() + + out = capsys.readouterr().out + assert "Skills reload failed" in out + assert "boom" in out + assert cli.conversation_history == [] + assert getattr(cli, "_pending_skills_reload_note", None) is None diff --git a/tests/conftest.py b/tests/conftest.py index 6386e26ec1..f9ad9d9b2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ test runner at ``scripts/run_tests.sh``. 
""" import asyncio +import logging import os import re import signal @@ -174,7 +175,10 @@ _HERMES_BEHAVIORAL_VARS = frozenset({ "HERMES_SESSION_KEY", "HERMES_GATEWAY_SESSION", "HERMES_PLATFORM", + "HERMES_MODEL", + "HERMES_INFERENCE_MODEL", "HERMES_INFERENCE_PROVIDER", + "HERMES_TUI_PROVIDER", "HERMES_MANAGED", "HERMES_DEV", "HERMES_CONTAINER", @@ -184,6 +188,14 @@ _HERMES_BEHAVIORAL_VARS = frozenset({ "HERMES_BACKGROUND_NOTIFICATIONS", "HERMES_EXEC_ASK", "HERMES_HOME_MODE", + "TERMINAL_CWD", + "TERMINAL_ENV", + "TERMINAL_VERCEL_RUNTIME", + "TERMINAL_CONTAINER_CPU", + "TERMINAL_CONTAINER_DISK", + "TERMINAL_CONTAINER_MEMORY", + "TERMINAL_CONTAINER_PERSISTENT", + "TERMINAL_DOCKER_RUN_AS_HOST_USER", "BROWSER_CDP_URL", "CAMOFOX_URL", # Platform allowlists — not credentials, but if set from any source @@ -326,6 +338,14 @@ def _reset_module_state(): that don't exist yet (test collection before production import) are skipped silently — production import later creates fresh empty state. """ + # --- logging — quiet/one-shot paths mutate process-global logger state --- + logging.disable(logging.NOTSET) + for _logger_name in ("tools", "run_agent", "trajectory_compressor", "cron", "hermes_cli"): + _logger = logging.getLogger(_logger_name) + _logger.disabled = False + _logger.setLevel(logging.NOTSET) + _logger.propagate = True + # --- tools.approval — the single biggest source of cross-test pollution --- try: from tools import approval as _approval_mod @@ -380,6 +400,26 @@ def _reset_module_state(): except Exception: pass + # --- tools.terminal_tool — active environment/cwd cache --- + # File tools prefer a live terminal cwd when one is cached for the task. + # Clear terminal environments between tests so a prior terminal call can't + # override TERMINAL_CWD in path-resolution tests. 
+ try: + from tools import terminal_tool as _term_mod + _envs_to_cleanup = [] + with _term_mod._env_lock: + _envs_to_cleanup = list(_term_mod._active_environments.values()) + _term_mod._active_environments.clear() + _term_mod._last_activity.clear() + _term_mod._creation_locks.clear() + for _env in _envs_to_cleanup: + try: + _env.cleanup() + except Exception: + pass + except Exception: + pass + # --- tools.credential_files — ContextVar --- try: from tools import credential_files as _credf_mod diff --git a/tests/cron/test_compute_next_run_last_run_at.py b/tests/cron/test_compute_next_run_last_run_at.py new file mode 100644 index 0000000000..0585aab09a --- /dev/null +++ b/tests/cron/test_compute_next_run_last_run_at.py @@ -0,0 +1,87 @@ +"""Test that compute_next_run uses last_run_at for cron jobs. + +Regression test for: cron jobs computing next_run_at from _hermes_now() +instead of from last_run_at, making them inconsistent with interval jobs. +""" +import pytest +from datetime import datetime +from zoneinfo import ZoneInfo + +pytest.importorskip("croniter") + +from cron.jobs import compute_next_run + + +class TestCronComputeNextRunUsesLastRunAt: + """compute_next_run MUST use last_run_at as the croniter base for cron jobs, + consistent with how interval jobs work.""" + + def test_cron_uses_last_run_at_for_every_6h_schedule(self, monkeypatch): + """For a schedule like 'every 6 hours', the base time matters. + If last_run_at is Apr 6 14:10, next should be Apr 6 18:00. + If now is Apr 10 22:00, next should be Apr 11 00:00. 
+ compute_next_run must use last_run_at, not now.""" + morocco = ZoneInfo("Africa/Casablanca") + + # Job last ran April 6 at 14:10 + last_run = datetime(2026, 4, 6, 14, 10, 0, tzinfo=morocco) + + # But now it's April 10 at 22:00 (e.g., gateway restarted) + now = datetime(2026, 4, 10, 22, 0, 0, tzinfo=morocco) + monkeypatch.setattr("cron.jobs._hermes_now", lambda: now) + + schedule = {"kind": "cron", "expr": "0 */6 * * *"} # every 6 hours + + result = compute_next_run(schedule, last_run_at=last_run.isoformat()) + assert result is not None + next_dt = datetime.fromisoformat(result) + + # With last_run_at as base (Apr 6 14:10), next is Apr 6 18:00. + # With now as base (Apr 10 22:00), next is Apr 11 00:00. + # The fix should use last_run_at, returning Apr 6 18:00 + # (stale detection in get_due_jobs() fast-forwards from there). + assert next_dt.date().isoformat() == "2026-04-06", ( + f"Expected next run on Apr 6 (from last_run_at), got {next_dt}" + ) + assert next_dt.hour == 18 + + def test_cron_without_last_run_at_uses_now(self, monkeypatch): + """When last_run_at is NOT provided, compute_next_run falls back to + _hermes_now() as the croniter base (existing behavior).""" + morocco = ZoneInfo("Africa/Casablanca") + + now = datetime(2026, 4, 10, 22, 0, 0, tzinfo=morocco) + monkeypatch.setattr("cron.jobs._hermes_now", lambda: now) + + schedule = {"kind": "cron", "expr": "0 */6 * * *"} + + result = compute_next_run(schedule) + assert result is not None + next_dt = datetime.fromisoformat(result) + + # Without last_run_at, should compute from now -> Apr 11 00:00 + assert next_dt.date().isoformat() == "2026-04-11", ( + f"Expected next run on Apr 11 (from now), got {next_dt}" + ) + assert next_dt.hour == 0 + + def test_cron_weekly_consistent_with_interval(self, monkeypatch): + """Both cron and interval jobs should anchor to last_run_at when + provided, producing consistent behavior after a crash/restart.""" + morocco = ZoneInfo("Africa/Casablanca") + + last_run = 
datetime(2026, 4, 6, 14, 10, 0, tzinfo=morocco) + now = datetime(2026, 4, 10, 22, 0, 0, tzinfo=morocco) + monkeypatch.setattr("cron.jobs._hermes_now", lambda: now) + + cron_schedule = {"kind": "cron", "expr": "0 14 * * 1"} + interval_schedule = {"kind": "interval", "minutes": 7 * 24 * 60} + + cron_result = compute_next_run(cron_schedule, last_run_at=last_run.isoformat()) + interval_result = compute_next_run(interval_schedule, last_run_at=last_run.isoformat()) + + # Both should be after last_run_at + cron_dt = datetime.fromisoformat(cron_result) + interval_dt = datetime.fromisoformat(interval_result) + assert cron_dt > last_run, f"Cron next {cron_dt} should be after last_run {last_run}" + assert interval_dt > last_run, f"Interval next {interval_dt} should be after last_run {last_run}" diff --git a/tests/cron/test_cron_inactivity_timeout.py b/tests/cron/test_cron_inactivity_timeout.py index 0b83f64f07..67e932089f 100644 --- a/tests/cron/test_cron_inactivity_timeout.py +++ b/tests/cron/test_cron_inactivity_timeout.py @@ -169,10 +169,20 @@ class TestInactivityTimeout: assert result["final_response"] == "Done" + def _parse_cron_timeout(self, raw_value): + """Mirror the defensive parsing logic from cron/scheduler.py run_job().""" + if raw_value: + try: + return float(raw_value) + except (ValueError, TypeError): + return 600.0 + return 600.0 + def test_timeout_env_var_parsing(self, monkeypatch): """HERMES_CRON_TIMEOUT env var is respected.""" monkeypatch.setenv("HERMES_CRON_TIMEOUT", "1200") - _cron_timeout = float(os.getenv("HERMES_CRON_TIMEOUT", 600)) + raw = os.getenv("HERMES_CRON_TIMEOUT", "").strip() + _cron_timeout = self._parse_cron_timeout(raw) assert _cron_timeout == 1200.0 _cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None @@ -181,10 +191,27 @@ class TestInactivityTimeout: def test_timeout_zero_means_unlimited(self, monkeypatch): """HERMES_CRON_TIMEOUT=0 yields None (unlimited).""" monkeypatch.setenv("HERMES_CRON_TIMEOUT", "0") - _cron_timeout 
= float(os.getenv("HERMES_CRON_TIMEOUT", 600)) + raw = os.getenv("HERMES_CRON_TIMEOUT", "").strip() + _cron_timeout = self._parse_cron_timeout(raw) _cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None assert _cron_inactivity_limit is None + def test_timeout_invalid_value_falls_back_to_default(self, monkeypatch): + """HERMES_CRON_TIMEOUT=abc should fall back to 600s, not raise ValueError.""" + monkeypatch.setenv("HERMES_CRON_TIMEOUT", "abc") + raw = os.getenv("HERMES_CRON_TIMEOUT", "").strip() + _cron_timeout = self._parse_cron_timeout(raw) + assert _cron_timeout == 600.0 + _cron_inactivity_limit = _cron_timeout if _cron_timeout > 0 else None + assert _cron_inactivity_limit == 600.0 + + def test_timeout_empty_string_uses_default(self, monkeypatch): + """HERMES_CRON_TIMEOUT='' (empty) should use the 600s default.""" + monkeypatch.setenv("HERMES_CRON_TIMEOUT", "") + raw = os.getenv("HERMES_CRON_TIMEOUT", "").strip() + _cron_timeout = self._parse_cron_timeout(raw) + assert _cron_timeout == 600.0 + def test_timeout_error_includes_diagnostics(self): """The TimeoutError message should include last activity info.""" agent = SlowFakeAgent( diff --git a/tests/cron/test_cron_workdir.py b/tests/cron/test_cron_workdir.py index 03777dd470..5f317c4f4c 100644 --- a/tests/cron/test_cron_workdir.py +++ b/tests/cron/test_cron_workdir.py @@ -265,6 +265,7 @@ class TestRunJobTerminalCwd: class FakeAgent: def __init__(self, **kwargs): observed["skip_context_files"] = kwargs.get("skip_context_files") + observed["load_soul_identity"] = kwargs.get("load_soul_identity") observed["terminal_cwd_during_init"] = os.environ.get( "TERMINAL_CWD", "_UNSET_" ) @@ -335,6 +336,7 @@ class TestRunJobTerminalCwd: # AIAgent was built with skip_context_files=False (feature ON). assert observed["skip_context_files"] is False + assert observed["load_soul_identity"] is True # TERMINAL_CWD was pointing at the job workdir while the agent ran. 
assert observed["terminal_cwd_during_init"] == str(tmp_path.resolve()) assert observed["terminal_cwd_during_run"] == str(tmp_path.resolve()) @@ -373,6 +375,8 @@ class TestRunJobTerminalCwd: # Feature is OFF — skip_context_files stays True. assert observed["skip_context_files"] is True + # Cron still forces SOUL.md identity even when cwd context files stay off. + assert observed["load_soul_identity"] is True # TERMINAL_CWD saw the same value during init as it had before. assert observed["terminal_cwd_during_init"] == before # And after run_job completes, it's still the sentinel (nothing diff --git a/tests/cron/test_scheduler.py b/tests/cron/test_scheduler.py index 23565511cf..638146989b 100644 --- a/tests/cron/test_scheduler.py +++ b/tests/cron/test_scheduler.py @@ -279,6 +279,44 @@ class TestResolveDeliveryTarget: "thread_id": None, } + def test_list_form_deliver_is_normalized(self, monkeypatch): + """deliver=['telegram'] (Python list) should resolve like 'telegram' string. + + Regression test for #17139: MCP clients / scripts that pass the deliver + field as an array-shaped value used to fail with "no delivery target + resolved for deliver=['telegram']" because ``str(['telegram'])`` was + passed through to ``split(',')`` verbatim. 
+ """ + monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "-4004") + job = { + "deliver": ["telegram"], + "origin": None, + } + + assert _resolve_delivery_target(job) == { + "platform": "telegram", + "chat_id": "-4004", + "thread_id": None, + } + + def test_list_form_multiple_platforms_normalized(self, monkeypatch): + """deliver=['telegram', 'discord'] resolves to multiple targets.""" + from cron.scheduler import _resolve_delivery_targets + + monkeypatch.setenv("TELEGRAM_HOME_CHANNEL", "-111") + monkeypatch.setenv("DISCORD_HOME_CHANNEL", "-222") + job = {"deliver": ["telegram", "discord"], "origin": None} + + targets = _resolve_delivery_targets(job) + platforms = sorted(t["platform"] for t in targets) + assert platforms == ["discord", "telegram"] + + def test_empty_list_form_deliver_resolves_to_local(self): + """deliver=[] is treated as local (no delivery).""" + from cron.scheduler import _resolve_delivery_targets + + assert _resolve_delivery_targets({"deliver": []}) == [] + class TestDeliverResultWrapping: """Verify that cron deliveries are wrapped with header/footer and no longer mirrored.""" @@ -513,14 +551,14 @@ class TestDeliverResultWrapping: patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro): _deliver_result( job, - "MEDIA:/tmp/voice.ogg", + "[[audio_as_voice]]\nMEDIA:/tmp/voice.ogg", adapters={Platform.TELEGRAM: adapter}, loop=loop, ) # Text send should NOT be called (no text after stripping MEDIA tag) adapter.send.assert_not_called() - # Audio should still be delivered + # Audio should still be delivered as a voice bubble adapter.send_voice.assert_called_once() def test_live_adapter_sends_cleaned_text_not_raw(self): @@ -989,6 +1027,80 @@ class TestRunJobSessionPersistence: assert os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID") is None fake_db.close.assert_called_once() + def test_run_job_clears_stale_auto_delivery_thread_id_between_jobs(self, tmp_path, monkeypatch): + jobs = [ + { + "id": "threaded-job", + "name": "threaded", + "prompt": 
"hello", + "deliver": "telegram:-1001:42", + }, + { + "id": "threadless-job", + "name": "threadless", + "prompt": "hello again", + "deliver": "telegram:-2002", + }, + ] + fake_db = MagicMock() + seen = [] + + monkeypatch.delenv("HERMES_CRON_AUTO_DELIVER_PLATFORM", raising=False) + monkeypatch.delenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID", raising=False) + monkeypatch.delenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID", raising=False) + + class FakeAgent: + def __init__(self, *args, **kwargs): + pass + + def run_conversation(self, *args, **kwargs): + from gateway.session_context import get_session_env + + seen.append( + { + "platform": get_session_env("HERMES_CRON_AUTO_DELIVER_PLATFORM") or None, + "chat_id": get_session_env("HERMES_CRON_AUTO_DELIVER_CHAT_ID") or None, + "thread_id": get_session_env("HERMES_CRON_AUTO_DELIVER_THREAD_ID") or None, + } + ) + return {"final_response": "ok"} + + with patch("cron.scheduler._hermes_home", tmp_path), \ + patch("hermes_state.SessionDB", return_value=fake_db), \ + patch( + "hermes_cli.runtime_provider.resolve_runtime_provider", + return_value={ + "api_key": "***", + "base_url": "https://example.invalid/v1", + "provider": "openrouter", + "api_mode": "chat_completions", + }, + ), \ + patch("run_agent.AIAgent", FakeAgent): + for job in jobs: + success, output, final_response, error = run_job(job) + assert success is True + assert error is None + assert final_response == "ok" + assert "ok" in output + + assert seen == [ + { + "platform": "telegram", + "chat_id": "-1001", + "thread_id": "42", + }, + { + "platform": "telegram", + "chat_id": "-2002", + "thread_id": None, + }, + ] + assert os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM") is None + assert os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID") is None + assert os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID") is None + assert fake_db.close.call_count == 2 + class TestRunJobConfigLogging: """Verify that config.yaml parse failures are logged, not silently swallowed.""" diff --git 
a/tests/gateway/_plugin_adapter_loader.py b/tests/gateway/_plugin_adapter_loader.py new file mode 100644 index 0000000000..4174a7161c --- /dev/null +++ b/tests/gateway/_plugin_adapter_loader.py @@ -0,0 +1,72 @@ +"""Shared helper for loading platform-plugin ``adapter.py`` modules in tests. + +Every platform plugin under ``plugins/platforms//`` ships its own +``adapter.py``. If two tests independently do:: + + sys.path.insert(0, "plugins/platforms/irc") + from adapter import IRCAdapter + + sys.path.insert(0, "plugins/platforms/teams") + from adapter import TeamsAdapter + +…then whichever collects first in an xdist worker wins +``sys.modules["adapter"]``, and the other raises ``ImportError`` at +collection time. The fallout cascades across unrelated tests sharing that +worker because ``sys.path`` is still polluted. + +Use :func:`load_plugin_adapter` instead of ad-hoc ``sys.path`` tricks. +It loads the adapter from an explicit file path under a unique module +name (``plugin_adapter_``), so it cannot collide with any +other plugin's adapter module. + +The ``tests/gateway/conftest.py`` guard rejects the anti-pattern at +collection time so this can't regress when new plugin adapter tests are +added. +""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +_REPO_ROOT = Path(__file__).resolve().parents[2] +_PLUGINS_DIR = _REPO_ROOT / "plugins" / "platforms" + + +def load_plugin_adapter(plugin_name: str) -> ModuleType: + """Import ``plugins/platforms//adapter.py`` in isolation. + + The module is registered under the unique name + ``plugin_adapter_`` in ``sys.modules``. No ``sys.path`` + mutation. Safe to call multiple times — repeat calls return the + already-loaded module. 
+ """ + module_name = f"plugin_adapter_{plugin_name}" + cached = sys.modules.get(module_name) + if cached is not None: + return cached + + adapter_path = _PLUGINS_DIR / plugin_name / "adapter.py" + if not adapter_path.is_file(): + raise FileNotFoundError( + f"Plugin adapter not found: {adapter_path}. " + f"Known plugins: {sorted(p.name for p in _PLUGINS_DIR.iterdir() if p.is_dir())}" + ) + + spec = importlib.util.spec_from_file_location(module_name, adapter_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not build import spec for {adapter_path}") + + module = importlib.util.module_from_spec(spec) + # Register BEFORE exec so the module can find itself if needed (some + # modules do ``sys.modules[__name__]`` reflection during import). + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop(module_name, None) + raise + return module diff --git a/tests/gateway/conftest.py b/tests/gateway/conftest.py index 3e734e0d40..da8a2d3364 100644 --- a/tests/gateway/conftest.py +++ b/tests/gateway/conftest.py @@ -12,11 +12,32 @@ ImportError fallback, causing 30+ downstream test failures wherever Individual test files may still call their own ``_ensure_telegram_mock`` — it short-circuits when the mock is already present. + +Plugin-adapter anti-pattern guard +--------------------------------- +Tests for platform plugins (``plugins/platforms//adapter.py``) +must load the adapter via +:func:`tests.gateway._plugin_adapter_loader.load_plugin_adapter`, not by +adding the plugin directory to ``sys.path`` and doing a bare +``from adapter import ...``. The guard at the bottom of this file +scans test module ASTs at collection time and fails collection with a +pointer to the helper if the anti-pattern is detected. + +Rationale: every plugin ships its own ``adapter.py``, and two tests each +inserting their plugin dir on ``sys.path[0]`` race for +``sys.modules["adapter"]`` in the same xdist worker. 
Whichever collects +first wins; the other fails with ``ImportError``, and the polluted +``sys.path`` cascades into unrelated tests. See PR #17764 for the +incident. """ +import ast import sys +from pathlib import Path from unittest.mock import MagicMock +import pytest + def _ensure_telegram_mock() -> None: """Install a comprehensive telegram mock in sys.modules. @@ -197,3 +218,128 @@ def _ensure_discord_mock() -> None: # Run at collection time — before any test file's module-level imports. _ensure_telegram_mock() _ensure_discord_mock() + + +# --------------------------------------------------------------------------- +# Plugin-adapter anti-pattern guard +# --------------------------------------------------------------------------- + +_GATEWAY_DIR = Path(__file__).resolve().parent +_GUARD_HINT = ( + "Plugin adapter tests must use " + "``from tests.gateway._plugin_adapter_loader import load_plugin_adapter`` " + "and call ``load_plugin_adapter('')`` instead of inserting " + "``plugins/platforms//`` on sys.path and doing a bare ``import " + "adapter`` / ``from adapter import ...``. See the 'Plugin-adapter " + "anti-pattern guard' docstring in tests/gateway/conftest.py." +) + + +def _scan_for_plugin_adapter_antipattern(source: str) -> list[str]: + """Return a list of offending-line descriptions, or [] if clean. + + Flags two things: + 1. ``sys.path.insert(..., )`` + 2. ``import adapter`` or ``from adapter import ...`` at module level. + """ + try: + tree = ast.parse(source) + except SyntaxError: + return [] # Let pytest surface the real syntax error. 
+ + offenses: list[str] = [] + + for node in ast.walk(tree): + # sys.path.insert(0, ".../plugins/platforms/...") + if isinstance(node, ast.Call): + func = node.func + target_name: str | None = None + if isinstance(func, ast.Attribute): + # sys.path.insert / sys.path.append + if ( + isinstance(func.value, ast.Attribute) + and isinstance(func.value.value, ast.Name) + and func.value.value.id == "sys" + and func.value.attr == "path" + and func.attr in ("insert", "append", "extend") + ): + target_name = f"sys.path.{func.attr}" + + if target_name is not None: + call_src = ast.unparse(node) + # Match both the string-literal form + # ``.../plugins/platforms/...`` and the Path-operator form + # ``Path(...) / 'plugins' / 'platforms' / ...`` that + # plugin tests typically use. + _src_no_ws = "".join(call_src.split()) + if ( + "plugins/platforms" in call_src + or "plugins\\platforms" in call_src + or "'plugins'/'platforms'" in _src_no_ws + or '"plugins"/"platforms"' in _src_no_ws + ): + offenses.append( + f"line {node.lineno}: {target_name}(...) points into " + f"plugins/platforms/" + ) + + # Bare `import adapter` / `from adapter import ...` anywhere (module level + # OR inside functions — both are symptoms of the same pattern). + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name == "adapter": + offenses.append( + f"line {node.lineno}: ``import adapter`` " + f"(bare — resolves to whichever plugin's adapter.py " + f"is first on sys.path)" + ) + elif isinstance(node, ast.ImportFrom): + if node.module == "adapter" and node.level == 0: + offenses.append( + f"line {node.lineno}: ``from adapter import ...`` " + f"(bare — resolves to whichever plugin's adapter.py " + f"is first on sys.path)" + ) + + return offenses + + +def pytest_configure(config): + """Reject plugin-adapter tests that use the sys.path anti-pattern. + + Runs once per pytest session on the controller, BEFORE any xdist + worker is spawned. 
If any file under ``tests/gateway/`` matches the + anti-pattern, we fail the whole session with a clear message — + before a polluted ``sys.path`` can cascade across workers. + """ + # Only run on the xdist controller (or in non-xdist runs). Skip on + # worker subprocesses so we don't scan the filesystem N times. + if hasattr(config, "workerinput"): + return + + violations: list[str] = [] + for path in _GATEWAY_DIR.rglob("test_*.py"): + if path.name in {"_plugin_adapter_loader.py", "conftest.py"}: + continue + try: + source = path.read_text(encoding="utf-8") + except OSError: + continue + if "adapter" not in source and "plugins/platforms" not in source: + continue + offenses = _scan_for_plugin_adapter_antipattern(source) + if offenses: + violations.append( + f" {path.relative_to(_GATEWAY_DIR.parent.parent)}:\n " + + "\n ".join(offenses) + ) + + if violations: + raise pytest.UsageError( + "Plugin-adapter-import anti-pattern detected in gateway tests:\n" + + "\n".join(violations) + + "\n\n" + + _GUARD_HINT + ) + diff --git a/tests/gateway/test_agent_cache.py b/tests/gateway/test_agent_cache.py index f3e63b0726..abf0ce3481 100644 --- a/tests/gateway/test_agent_cache.py +++ b/tests/gateway/test_agent_cache.py @@ -170,6 +170,22 @@ class TestAgentConfigSignature: ) assert sig_a == sig_b + def test_tool_registry_generation_change_busts_cache(self): + """MCP reloads mutate the tool registry, so cached agents must rebuild.""" + from gateway.run import GatewayRunner + + runtime = {"api_key": "k", "base_url": "u", "provider": "p"} + sig_before = GatewayRunner._agent_config_signature( + "m", runtime, ["telegram"], "", + cache_keys={"tools.registry_generation": 10}, + ) + sig_after = GatewayRunner._agent_config_signature( + "m", runtime, ["telegram"], "", + cache_keys={"tools.registry_generation": 11}, + ) + + assert sig_before != sig_after + class TestExtractCacheBustingConfig: """Verify _extract_cache_busting_config pulls the documented subset of @@ -229,6 +245,17 @@ class 
TestExtractCacheBustingConfig: out = GatewayRunner._extract_cache_busting_config(None) for section, key in GatewayRunner._CACHE_BUSTING_CONFIG_KEYS: assert out[f"{section}.{key}"] is None + assert "tools.registry_generation" in out + + def test_extract_includes_live_tool_registry_generation(self, monkeypatch): + from gateway.run import GatewayRunner + from tools.registry import registry + + monkeypatch.setattr(registry, "_generation", 12345) + + out = GatewayRunner._extract_cache_busting_config({}) + + assert out["tools.registry_generation"] == 12345 def test_full_round_trip_busts_cache_on_real_edit(self): """End-to-end: simulate a config edit on main and verify the diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 8285851064..2ebb48bcf4 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -314,6 +314,7 @@ def _create_app(adapter: APIServerAdapter) -> web.Application: app.router.add_get("/health/detailed", adapter._handle_health_detailed) app.router.add_get("/v1/health", adapter._handle_health) app.router.add_get("/v1/models", adapter._handle_models) + app.router.add_get("/v1/capabilities", adapter._handle_capabilities) app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions) app.router.add_post("/v1/responses", adapter._handle_responses) app.router.add_get("/v1/responses/{response_id}", adapter._handle_get_response) @@ -491,6 +492,46 @@ class TestModelsEndpoint: assert resp.status == 200 +# --------------------------------------------------------------------------- +# /v1/capabilities endpoint +# --------------------------------------------------------------------------- + + +class TestCapabilitiesEndpoint: + @pytest.mark.asyncio + async def test_capabilities_advertises_plugin_safe_contract(self, adapter): + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/capabilities") + assert resp.status == 200 + data = await 
resp.json() + assert data["object"] == "hermes.api_server.capabilities" + assert data["platform"] == "hermes-agent" + assert data["model"] == "hermes-agent" + assert data["auth"]["type"] == "bearer" + assert data["auth"]["required"] is False + assert data["features"]["chat_completions"] is True + assert data["features"]["run_status"] is True + assert data["features"]["run_events_sse"] is True + assert data["features"]["session_continuity_header"] == "X-Hermes-Session-Id" + assert data["endpoints"]["run_status"]["path"] == "/v1/runs/{run_id}" + + @pytest.mark.asyncio + async def test_capabilities_requires_auth_when_key_configured(self, auth_adapter): + app = _create_app(auth_adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/capabilities") + assert resp.status == 401 + + authed = await cli.get( + "/v1/capabilities", + headers={"Authorization": "Bearer sk-secret"}, + ) + assert authed.status == 200 + data = await authed.json() + assert data["auth"]["required"] is True + + # --------------------------------------------------------------------------- # /v1/chat/completions endpoint # --------------------------------------------------------------------------- @@ -647,17 +688,17 @@ class TestChatCompletionsEndpoint: @pytest.mark.asyncio async def test_stream_includes_tool_progress(self, adapter): - """tool_progress_callback fires → progress appears as custom SSE event, not in delta.content.""" + """tool_start_callback fires → progress appears as custom SSE event, not in delta.content.""" import asyncio app = _create_app(adapter) async with TestClient(TestServer(app)) as cli: async def _mock_run_agent(**kwargs): cb = kwargs.get("stream_delta_callback") - tp_cb = kwargs.get("tool_progress_callback") - # Simulate tool progress before streaming content - if tp_cb: - tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"}) + ts_cb = kwargs.get("tool_start_callback") + # Simulate the structured tool start the gateway now consumes. 
+ if ts_cb: + ts_cb("call_terminal_1", "terminal", {"command": "ls -la"}) if cb: await asyncio.sleep(0.05) cb("Here are the files.") @@ -683,7 +724,10 @@ class TestChatCompletionsEndpoint: # markers instead of calling tools (#6972). assert "event: hermes.tool.progress" in body assert '"tool": "terminal"' in body - assert '"label": "ls -la"' in body + # ``label`` is now derived by ``build_tool_preview`` from the + # tool args rather than passed by the caller, so we assert + # only that *some* label exists rather than a literal value. + assert '"label":' in body # The progress marker must NOT appear inside any # chat.completion.chunk delta.content field. import json as _json @@ -703,17 +747,17 @@ class TestChatCompletionsEndpoint: @pytest.mark.asyncio async def test_stream_tool_progress_skips_internal_events(self, adapter): - """Internal events (name starting with _) are not streamed.""" + """Internal tool calls (name starting with ``_``) are not streamed.""" import asyncio app = _create_app(adapter) async with TestClient(TestServer(app)) as cli: async def _mock_run_agent(**kwargs): cb = kwargs.get("stream_delta_callback") - tp_cb = kwargs.get("tool_progress_callback") - if tp_cb: - tp_cb("tool.started", "_thinking", "some internal state", {}) - tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"}) + ts_cb = kwargs.get("tool_start_callback") + if ts_cb: + ts_cb("call_internal_1", "_thinking", {"text": "some internal state"}) + ts_cb("call_search_1", "web_search", {"query": "Python docs"}) if cb: await asyncio.sleep(0.05) cb("Found it.") @@ -735,10 +779,142 @@ class TestChatCompletionsEndpoint: body = await resp.text() # Internal _thinking event should NOT appear anywhere assert "some internal state" not in body + assert "call_internal_1" not in body # Real tool progress should appear as custom SSE event assert "event: hermes.tool.progress" in body assert '"tool": "web_search"' in body - assert '"label": "Python docs"' in body + # Label is 
derived from the args dict by build_tool_preview; + # asserting on the structural fact (label exists, call id + # is correlated) rather than a literal preview string keeps + # the test robust against preview-formatter tweaks. + assert '"label":' in body + assert '"toolCallId": "call_search_1"' in body + + @pytest.mark.asyncio + async def test_stream_emits_tool_lifecycle_with_call_id(self, adapter): + """Regression for #16588. + + ``/v1/chat/completions`` streaming previously emitted only a + ``tool.started``-style ``hermes.tool.progress`` event; clients + rendering tool lifecycle UI had no way to mark a tool as finished + because no matching ``status: completed`` event was emitted, and + no ``toolCallId`` was carried for correlation. + + The fix adds ``tool_start_callback`` / ``tool_complete_callback`` + to the chat completions agent invocation and writes both halves + of the lifecycle pair on the same ``event: hermes.tool.progress`` + SSE line, with stable ``toolCallId`` and ``status``. + """ + import asyncio + import json as _json + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + ts_cb = kwargs.get("tool_start_callback") + tc_cb = kwargs.get("tool_complete_callback") + # The structured callbacks own the chat-completions SSE + # channel now; ``tool_progress_callback`` is intentionally + # not wired so each tool start emits exactly one event. 
+ if ts_cb: + ts_cb("call_terminal_1", "terminal", {"command": "ls -la"}) + if tc_cb: + tc_cb("call_terminal_1", "terminal", {"command": "ls -la"}, "ok") + if cb: + await asyncio.sleep(0.05) + cb("done.") + return ( + {"final_response": "done.", "messages": [], "api_calls": 1}, + {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "list"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + + # Walk the SSE body and collect *(status, toolCallId)* pairs + # per event so the assertions verify per-event correlation — + # an event missing ``toolCallId`` would not pass even if a + # different event happens to carry the right id. + pairs: list[tuple[str | None, str | None]] = [] + lines = body.splitlines() + for i, line in enumerate(lines): + if line.strip() != "event: hermes.tool.progress": + continue + for follow in lines[i + 1: i + 4]: + if follow.startswith("data: "): + try: + payload = _json.loads(follow[len("data: "):]) + except _json.JSONDecodeError: + break + pairs.append((payload.get("status"), payload.get("toolCallId"))) + break + + # Each tool start must emit exactly one event (no duplicate + # legacy + new emit), and each lifecycle pair must carry the + # same toolCallId on every event — not just somewhere in the + # aggregate. 
+ assert len(pairs) == 2, f"expected 2 events (running+completed), got {pairs}" + assert pairs[0] == ("running", "call_terminal_1"), pairs + assert pairs[1] == ("completed", "call_terminal_1"), pairs + + @pytest.mark.asyncio + async def test_stream_tool_lifecycle_skips_internal_and_orphan_completes(self, adapter): + """Internal tools (``_thinking``-style) and ``completed`` events + without a prior matching ``running`` must produce no lifecycle + events on the wire — otherwise clients would see orphaned + ``status: completed`` updates they cannot correlate.""" + import asyncio + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + ts_cb = kwargs.get("tool_start_callback") + tc_cb = kwargs.get("tool_complete_callback") + # Internal tool — must be filtered. + if ts_cb: + ts_cb("call_internal_1", "_thinking", {}) + if tc_cb: + tc_cb("call_internal_1", "_thinking", {}, "") + # Completion without start — orphan, must be dropped. + if tc_cb: + tc_cb("call_orphan_1", "web_search", {}, "ok") + if cb: + await asyncio.sleep(0.05) + cb("ok.") + return ( + {"final_response": "ok.", "messages": [], "api_calls": 1}, + {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "ok"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + + # Neither the internal call_id nor the orphan call_id should + # surface as a lifecycle payload on the wire. 
+ assert "call_internal_1" not in body + assert "call_orphan_1" not in body + assert '"status": "running"' not in body + assert '"status": "completed"' not in body @pytest.mark.asyncio async def test_no_user_message_returns_400(self, adapter): diff --git a/tests/gateway/test_api_server_runs.py b/tests/gateway/test_api_server_runs.py index e485bad5ce..900eb3c869 100644 --- a/tests/gateway/test_api_server_runs.py +++ b/tests/gateway/test_api_server_runs.py @@ -1,7 +1,8 @@ -"""Tests for /v1/runs endpoints: start, events, and stop. +"""Tests for /v1/runs endpoints: start, status, events, and stop. Covers: - POST /v1/runs — start a run (202) +- GET /v1/runs/{run_id} — poll run status - GET /v1/runs/{run_id}/events — SSE event stream - POST /v1/runs/{run_id}/stop — interrupt a running agent - Auth, error handling, and cleanup @@ -46,6 +47,7 @@ def _create_runs_app(adapter: APIServerAdapter) -> web.Application: app = web.Application(middlewares=mws) app["api_server_adapter"] = adapter app.router.add_post("/v1/runs", adapter._handle_runs) + app.router.add_get("/v1/runs/{run_id}", adapter._handle_get_run) app.router.add_get("/v1/runs/{run_id}/events", adapter._handle_run_events) app.router.add_post("/v1/runs/{run_id}/stop", adapter._handle_stop_run) return app @@ -116,6 +118,13 @@ class TestStartRun: assert data["status"] == "started" assert data["run_id"].startswith("run_") + status_resp = await cli.get(f"/v1/runs/{data['run_id']}") + assert status_resp.status == 200 + status = await status_resp.json() + assert status["run_id"] == data["run_id"] + assert status["status"] in {"queued", "running", "completed"} + assert status["object"] == "hermes.run" + @pytest.mark.asyncio async def test_start_invalid_json_returns_400(self, adapter): app = _create_runs_app(adapter) @@ -143,6 +152,18 @@ class TestStartRun: resp = await cli.post("/v1/runs", json={"input": ""}) assert resp.status == 400 + @pytest.mark.asyncio + async def test_start_invalid_history_does_not_allocate_run(self, 
adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post( + "/v1/runs", + json={"input": "hello", "conversation_history": {"role": "user"}}, + ) + assert resp.status == 400 + assert adapter._run_streams == {} + assert adapter._run_statuses == {} + @pytest.mark.asyncio async def test_start_requires_auth(self, auth_adapter): app = _create_runs_app(auth_adapter) @@ -170,6 +191,89 @@ class TestStartRun: assert resp.status == 202 +# --------------------------------------------------------------------------- +# GET /v1/runs/{run_id} — poll run status +# --------------------------------------------------------------------------- + + +class TestRunStatus: + @pytest.mark.asyncio + async def test_status_completed_run_includes_output_and_usage(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() + mock_agent.run_conversation.return_value = {"final_response": "done"} + mock_agent.session_prompt_tokens = 4 + mock_agent.session_completion_tokens = 2 + mock_agent.session_total_tokens = 6 + mock_create.return_value = mock_agent + + resp = await cli.post("/v1/runs", json={"input": "hello"}) + data = await resp.json() + run_id = data["run_id"] + + for _ in range(20): + status_resp = await cli.get(f"/v1/runs/{run_id}") + assert status_resp.status == 200 + status = await status_resp.json() + if status["status"] == "completed": + break + await asyncio.sleep(0.05) + + assert status["status"] == "completed" + assert status["output"] == "done" + assert status["usage"]["total_tokens"] == 6 + assert status["last_event"] == "run.completed" + + @pytest.mark.asyncio + async def test_status_reflects_explicit_session_id(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() 
+ mock_agent.run_conversation.return_value = {"final_response": "done"} + mock_agent.session_prompt_tokens = 0 + mock_agent.session_completion_tokens = 0 + mock_agent.session_total_tokens = 0 + mock_create.return_value = mock_agent + + resp = await cli.post( + "/v1/runs", + json={"input": "hello", "session_id": "space-session"}, + ) + data = await resp.json() + run_id = data["run_id"] + + for _ in range(20): + status_resp = await cli.get(f"/v1/runs/{run_id}") + status = await status_resp.json() + if status["status"] == "completed": + break + await asyncio.sleep(0.05) + + mock_agent.run_conversation.assert_called_once() + # task_id stays "default" so the Runs API shares one sandbox + # container with CLI/gateway; session_id is surfaced in status + # for external UIs to correlate runs with their own session IDs. + assert mock_agent.run_conversation.call_args.kwargs["task_id"] == "default" + assert status["session_id"] == "space-session" + + @pytest.mark.asyncio + async def test_status_not_found_returns_404(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/runs/run_nonexistent") + assert resp.status == 404 + + @pytest.mark.asyncio + async def test_status_requires_auth(self, auth_adapter): + app = _create_runs_app(auth_adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/runs/run_any") + assert resp.status == 401 + + # --------------------------------------------------------------------------- # GET /v1/runs/{run_id}/events — SSE event stream # --------------------------------------------------------------------------- @@ -257,6 +361,11 @@ class TestStopRun: # Agent interrupt should have been called mock_agent.interrupt.assert_called_once_with("Stop requested via API") + status_resp = await cli.get(f"/v1/runs/{run_id}") + assert status_resp.status == 200 + status_data = await status_resp.json() + assert status_data["status"] in {"stopping", "cancelled"} + # 
Refs should be cleaned up await asyncio.sleep(0.5) assert run_id not in adapter._active_run_agents diff --git a/tests/gateway/test_config_cwd_bridge.py b/tests/gateway/test_config_cwd_bridge.py index af967af24b..2366625388 100644 --- a/tests/gateway/test_config_cwd_bridge.py +++ b/tests/gateway/test_config_cwd_bridge.py @@ -33,6 +33,11 @@ def _simulate_config_bridge(cfg: dict, initial_env: dict | None = None): "backend": "TERMINAL_ENV", "cwd": "TERMINAL_CWD", "timeout": "TERMINAL_TIMEOUT", + "vercel_runtime": "TERMINAL_VERCEL_RUNTIME", + "container_persistent": "TERMINAL_CONTAINER_PERSISTENT", + "container_cpu": "TERMINAL_CONTAINER_CPU", + "container_memory": "TERMINAL_CONTAINER_MEMORY", + "container_disk": "TERMINAL_CONTAINER_DISK", } for cfg_key, env_var in terminal_env_map.items(): if cfg_key in terminal_cfg: @@ -240,3 +245,24 @@ class TestTildeExpansion: } result = _simulate_config_bridge(cfg) assert result["TERMINAL_CWD"] == os.path.expanduser("~/nested") + + +class TestVercelTerminalBridge: + def test_vercel_terminal_settings_bridge(self): + cfg = { + "terminal": { + "backend": "vercel_sandbox", + "vercel_runtime": "python3.13", + "container_persistent": True, + "container_cpu": 2, + "container_memory": 4096, + "container_disk": 51200, + } + } + result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/from/env"}) + assert result["TERMINAL_ENV"] == "vercel_sandbox" + assert result["TERMINAL_VERCEL_RUNTIME"] == "python3.13" + assert result["TERMINAL_CONTAINER_PERSISTENT"] == "True" + assert result["TERMINAL_CONTAINER_CPU"] == "2" + assert result["TERMINAL_CONTAINER_MEMORY"] == "4096" + assert result["TERMINAL_CONTAINER_DISK"] == "51200" diff --git a/tests/gateway/test_irc_adapter.py b/tests/gateway/test_irc_adapter.py new file mode 100644 index 0000000000..a1718fbdaf --- /dev/null +++ b/tests/gateway/test_irc_adapter.py @@ -0,0 +1,502 @@ +"""Tests for the IRC platform adapter plugin.""" + +import asyncio +import os +import sys +import pytest +from pathlib 
import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from tests.gateway._plugin_adapter_loader import load_plugin_adapter + +# Load plugins/platforms/irc/adapter.py under a unique module name +# (plugin_adapter_irc) so it cannot collide with other plugin adapters +# loaded by sibling tests in the same xdist worker. +_irc_mod = load_plugin_adapter("irc") + +_parse_irc_message = _irc_mod._parse_irc_message +_extract_nick = _irc_mod._extract_nick +IRCAdapter = _irc_mod.IRCAdapter +check_requirements = _irc_mod.check_requirements +validate_config = _irc_mod.validate_config +register = _irc_mod.register + + +class TestIRCProtocolHelpers: + + def test_parse_simple_command(self): + msg = _parse_irc_message("PING :server.example.com") + assert msg["command"] == "PING" + assert msg["params"] == ["server.example.com"] + assert msg["prefix"] == "" + + def test_parse_prefixed_message(self): + msg = _parse_irc_message(":nick!user@host PRIVMSG #channel :Hello world") + assert msg["prefix"] == "nick!user@host" + assert msg["command"] == "PRIVMSG" + assert msg["params"] == ["#channel", "Hello world"] + + def test_parse_numeric_reply(self): + msg = _parse_irc_message(":server 001 hermes-bot :Welcome to IRC") + assert msg["prefix"] == "server" + assert msg["command"] == "001" + assert msg["params"] == ["hermes-bot", "Welcome to IRC"] + + def test_parse_nick_collision(self): + msg = _parse_irc_message(":server 433 * hermes-bot :Nickname is already in use") + assert msg["command"] == "433" + + def test_extract_nick_full_prefix(self): + assert _extract_nick("nick!user@host") == "nick" + + def test_extract_nick_bare(self): + assert _extract_nick("server.example.com") == "server.example.com" + + +# ── IRC Adapter ────────────────────────────────────────────────────────── + + +class TestIRCAdapterInit: + + def test_init_from_env(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.setenv("IRC_PORT", "6667") + 
monkeypatch.setenv("IRC_NICKNAME", "testbot") + monkeypatch.setenv("IRC_CHANNEL", "#test") + monkeypatch.setenv("IRC_USE_TLS", "false") + + from gateway.config import PlatformConfig + cfg = PlatformConfig(enabled=True) + adapter = IRCAdapter(cfg) + + assert adapter.server == "irc.test.net" + assert adapter.port == 6667 + assert adapter.nickname == "testbot" + assert adapter.channel == "#test" + assert adapter.use_tls is False + + def test_init_from_config_extra(self, monkeypatch): + # Clear any env vars + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "irc.libera.chat", + "port": 6697, + "nickname": "hermes", + "channel": "#hermes-dev", + "use_tls": True, + }, + ) + adapter = IRCAdapter(cfg) + + assert adapter.server == "irc.libera.chat" + assert adapter.port == 6697 + assert adapter.nickname == "hermes" + assert adapter.channel == "#hermes-dev" + assert adapter.use_tls is True + + def test_env_overrides_config(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "env-server.net") + + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={"server": "config-server.net", "channel": "#ch"}, + ) + adapter = IRCAdapter(cfg) + assert adapter.server == "env-server.net" + + +class TestIRCAdapterSend: + + @pytest.fixture + def adapter(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "localhost", + "port": 6667, + "nickname": "testbot", + "channel": "#test", + "use_tls": False, + }, + ) + return IRCAdapter(cfg) + + @pytest.mark.asyncio + async def test_send_not_connected(self, adapter): + result = await adapter.send("#test", "hello") + assert 
result.success is False + assert "Not connected" in result.error + + @pytest.mark.asyncio + async def test_send_success(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + result = await adapter.send("#test", "hello world") + assert result.success is True + assert result.message_id is not None + # Verify PRIVMSG was sent + writer.write.assert_called() + sent_data = writer.write.call_args[0][0] + assert b"PRIVMSG #test :hello world" in sent_data + + @pytest.mark.asyncio + async def test_send_splits_long_messages(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + long_msg = "x" * 1000 + result = await adapter.send("#test", long_msg) + assert result.success is True + # Should have been split into multiple PRIVMSG calls + assert writer.write.call_count > 1 + + +class TestIRCAdapterMessageParsing: + + @pytest.fixture + def adapter(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "localhost", + "port": 6667, + "nickname": "hermes", + "channel": "#test", + "use_tls": False, + }, + ) + a = IRCAdapter(cfg) + a._current_nick = "hermes" + a._registered = True + return a + + @pytest.mark.asyncio + async def test_handle_ping(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + await adapter._handle_line("PING :test-server") + sent = writer.write.call_args[0][0] + assert b"PONG :test-server" in sent + + @pytest.mark.asyncio + async def test_handle_welcome(self, adapter): + adapter._registered = False 
+ adapter._registration_event = asyncio.Event() + + await adapter._handle_line(":server 001 hermes :Welcome to IRC") + assert adapter._registered is True + assert adapter._registration_event.is_set() + + @pytest.mark.asyncio + async def test_handle_nick_collision(self, adapter): + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + await adapter._handle_line(":server 433 * hermes :Nickname in use") + assert adapter._current_nick == "hermes_" + sent = writer.write.call_args[0][0] + assert b"NICK hermes_" in sent + + @pytest.mark.asyncio + async def test_handle_addressed_channel_message(self, adapter): + """Messages addressed to the bot (nick: msg) should be dispatched.""" + handler = AsyncMock(return_value="response") + adapter._message_handler = handler + + # Mock handle_message to capture the event + dispatched = [] + original_dispatch = adapter._dispatch_message + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + + await adapter._handle_line(":user!u@host PRIVMSG #test :hermes: hello there") + assert len(dispatched) == 1 + assert dispatched[0]["text"] == "hello there" + assert dispatched[0]["chat_id"] == "#test" + + @pytest.mark.asyncio + async def test_ignores_unaddressed_channel_message(self, adapter): + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":user!u@host PRIVMSG #test :just talking") + assert len(dispatched) == 0 + + @pytest.mark.asyncio + async def test_handle_dm(self, adapter): + """DMs (target == bot nick) should always be dispatched.""" + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + 
await adapter._handle_line(":user!u@host PRIVMSG hermes :private message") + assert len(dispatched) == 1 + assert dispatched[0]["text"] == "private message" + assert dispatched[0]["chat_type"] == "dm" + assert dispatched[0]["chat_id"] == "user" + + @pytest.mark.asyncio + async def test_ignores_own_messages(self, adapter): + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":hermes!bot@host PRIVMSG #test :my own msg") + assert len(dispatched) == 0 + + @pytest.mark.asyncio + async def test_ctcp_action_converted(self, adapter): + """CTCP ACTION (/me) should be converted to text.""" + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":user!u@host PRIVMSG hermes :\x01ACTION waves\x01") + assert len(dispatched) == 1 + assert dispatched[0]["text"] == "* user waves" + + @pytest.mark.asyncio + async def test_allowed_users_case_insensitive(self, monkeypatch): + """Allowlist should match nicks case-insensitively.""" + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "localhost", + "port": 6667, + "nickname": "hermes", + "channel": "#test", + "use_tls": False, + "allowed_users": ["Admin", "BOB"], + }, + ) + adapter = IRCAdapter(cfg) + adapter._current_nick = "hermes" + adapter._registered = True + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + # "admin" matches "Admin" in allowlist + await adapter._handle_line(":admin!u@host PRIVMSG #test :hermes: hello") + 
assert len(dispatched) == 1 + assert dispatched[0]["text"] == "hello" + + @pytest.mark.asyncio + async def test_unauthorized_user_blocked(self, monkeypatch): + """Nicks not in allowlist should be ignored.""" + for key in ("IRC_SERVER", "IRC_PORT", "IRC_NICKNAME", "IRC_CHANNEL", "IRC_USE_TLS"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig( + enabled=True, + extra={ + "server": "localhost", + "port": 6667, + "nickname": "hermes", + "channel": "#test", + "use_tls": False, + "allowed_users": ["Admin", "BOB"], + }, + ) + adapter = IRCAdapter(cfg) + adapter._current_nick = "hermes" + adapter._registered = True + dispatched = [] + + async def capture_dispatch(**kwargs): + dispatched.append(kwargs) + + adapter._dispatch_message = capture_dispatch + adapter._message_handler = AsyncMock() + + await adapter._handle_line(":eve!u@host PRIVMSG #test :hermes: hello") + assert len(dispatched) == 0 + + @pytest.mark.asyncio + async def test_nick_collision_retry(self, adapter): + """Multiple 433 responses should keep incrementing the suffix.""" + writer = MagicMock() + writer.is_closing = MagicMock(return_value=False) + writer.write = MagicMock() + writer.drain = AsyncMock() + adapter._writer = writer + + await adapter._handle_line(":server 433 * hermes :Nickname in use") + assert adapter._current_nick == "hermes_" + await adapter._handle_line(":server 433 * hermes_ :Nickname in use") + assert adapter._current_nick == "hermes_1" + await adapter._handle_line(":server 433 * hermes_1 :Nickname in use") + assert adapter._current_nick == "hermes_2" + + +class TestIRCAdapterSplitting: + + def test_split_respects_byte_limit(self): + """Multi-byte characters should not exceed IRC byte limit.""" + # 100 japanese chars = 300 bytes in utf-8 + text = "あ" * 100 + from gateway.config import PlatformConfig + cfg = PlatformConfig(enabled=True, extra={"server": "x", "channel": "#x"}) + adapter = IRCAdapter(cfg) + adapter._current_nick = 
"bot" + lines = adapter._split_message(text, "#test") + for line in lines: + overhead = len(f"PRIVMSG #test :{line}\r\n".encode("utf-8")) + assert overhead <= 512, f"line over 512 bytes: {overhead}" + + def test_split_prefers_word_boundary(self): + text = "hello world foo bar baz qux" + from gateway.config import PlatformConfig + cfg = PlatformConfig(enabled=True, extra={"server": "x", "channel": "#x"}) + adapter = IRCAdapter(cfg) + adapter._current_nick = "bot" + lines = adapter._split_message(text, "#test") + # Should not split in the middle of "world" + assert any("hello" in ln for ln in lines) + assert any("world" in ln for ln in lines) + + +class TestIRCProtocolHelpersExtra: + + def test_parse_malformed_no_space(self): + """A line starting with : but no space should not crash.""" + msg = _parse_irc_message(":justaprefix") + assert msg["prefix"] == "justaprefix" + assert msg["command"] == "" + assert msg["params"] == [] + + def test_parse_empty(self): + msg = _parse_irc_message("") + assert msg["prefix"] == "" + assert msg["command"] == "" + assert msg["params"] == [] + + +class TestIRCAdapterMarkdown: + + def test_strip_bold(self): + assert IRCAdapter._strip_markdown("**bold**") == "bold" + + def test_strip_italic(self): + assert IRCAdapter._strip_markdown("*italic*") == "italic" + + def test_strip_code(self): + assert IRCAdapter._strip_markdown("`code`") == "code" + + def test_strip_link(self): + result = IRCAdapter._strip_markdown("[click here](https://example.com)") + assert result == "click here (https://example.com)" + + def test_strip_image(self): + result = IRCAdapter._strip_markdown("![alt](https://example.com/img.png)") + assert result == "https://example.com/img.png" + + +# ── Requirements / validation ──────────────────────────────────────────── + + +class TestIRCRequirements: + + def test_check_requirements_with_env(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.setenv("IRC_CHANNEL", "#test") + assert 
check_requirements() is True + + def test_check_requirements_missing_server(self, monkeypatch): + monkeypatch.delenv("IRC_SERVER", raising=False) + monkeypatch.setenv("IRC_CHANNEL", "#test") + assert check_requirements() is False + + def test_check_requirements_missing_channel(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.delenv("IRC_CHANNEL", raising=False) + assert check_requirements() is False + + def test_validate_config_from_extra(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_CHANNEL"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig(extra={"server": "irc.test.net", "channel": "#test"}) + assert validate_config(cfg) is True + + def test_validate_config_missing(self, monkeypatch): + for key in ("IRC_SERVER", "IRC_CHANNEL"): + monkeypatch.delenv(key, raising=False) + from gateway.config import PlatformConfig + cfg = PlatformConfig(extra={}) + assert validate_config(cfg) is False + + +# ── Plugin registration ────────────────────────────────────────────────── + + +class TestIRCPluginRegistration: + """Test the register() entry point.""" + + def test_register_adds_to_registry(self, monkeypatch): + monkeypatch.setenv("IRC_SERVER", "irc.test.net") + monkeypatch.setenv("IRC_CHANNEL", "#test") + + from gateway.platform_registry import platform_registry + + # Clean up if already registered + platform_registry.unregister("irc") + + ctx = MagicMock() + register(ctx) + ctx.register_platform.assert_called_once() + call_kwargs = ctx.register_platform.call_args + assert call_kwargs[1]["name"] == "irc" or call_kwargs[0][0] == "irc" if call_kwargs[0] else call_kwargs[1]["name"] == "irc" diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 722fc9f703..75e1a1e148 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -1276,9 +1276,10 @@ class TestMatrixUploadAndSend: mock_client.send_message_event = 
AsyncMock(return_value="$event") adapter._client = mock_client - result = await adapter._upload_and_send( - "!room:example.org", b"secret", "secret.txt", "text/plain", "m.file", - ) + with patch.dict("sys.modules", _make_fake_mautrix()): + result = await adapter._upload_and_send( + "!room:example.org", b"secret", "secret.txt", "text/plain", "m.file", + ) assert result.success is True # Should have uploaded ciphertext, not plaintext diff --git a/tests/gateway/test_platform_base.py b/tests/gateway/test_platform_base.py index 59246b7990..a6e0d51d60 100644 --- a/tests/gateway/test_platform_base.py +++ b/tests/gateway/test_platform_base.py @@ -323,6 +323,55 @@ class TestExtractMedia: assert "Here" in cleaned assert "After" in cleaned + def test_media_tag_supports_unquoted_flac_paths_with_spaces(self): + content = "MEDIA:/tmp/Jane Doe/speech.flac" + media, cleaned = BasePlatformAdapter.extract_media(content) + assert media == [("/tmp/Jane Doe/speech.flac", False)] + assert cleaned == "" + + +# --------------------------------------------------------------------------- +# should_send_media_as_audio +# --------------------------------------------------------------------------- + +class TestShouldSendMediaAsAudio: + """Audio-routing policy shared by gateway + scheduler + send_message.""" + + def test_unknown_extension_returns_false(self): + from gateway.platforms.base import should_send_media_as_audio + assert should_send_media_as_audio(None, ".png") is False + assert should_send_media_as_audio("telegram", ".pdf") is False + + def test_non_telegram_platforms_route_all_audio(self): + from gateway.platforms.base import should_send_media_as_audio + for ext in (".mp3", ".m4a", ".wav", ".flac", ".ogg", ".opus"): + assert should_send_media_as_audio("discord", ext) is True + assert should_send_media_as_audio("slack", ext) is True + + def test_telegram_mp3_and_m4a_route_to_audio(self): + from gateway.platforms.base import should_send_media_as_audio + assert 
should_send_media_as_audio("telegram", ".mp3") is True + assert should_send_media_as_audio("telegram", ".m4a") is True + + def test_telegram_wav_and_flac_fall_through_to_document(self): + from gateway.platforms.base import should_send_media_as_audio + assert should_send_media_as_audio("telegram", ".wav") is False + assert should_send_media_as_audio("telegram", ".flac") is False + + def test_telegram_ogg_opus_only_when_voice_flagged(self): + from gateway.platforms.base import should_send_media_as_audio + assert should_send_media_as_audio("telegram", ".ogg", is_voice=True) is True + assert should_send_media_as_audio("telegram", ".opus", is_voice=True) is True + assert should_send_media_as_audio("telegram", ".ogg") is False + assert should_send_media_as_audio("telegram", ".opus") is False + + def test_accepts_platform_enum(self): + from gateway.config import Platform + from gateway.platforms.base import should_send_media_as_audio + assert should_send_media_as_audio(Platform.TELEGRAM, ".mp3") is True + assert should_send_media_as_audio(Platform.TELEGRAM, ".flac") is False + assert should_send_media_as_audio(Platform.DISCORD, ".flac") is True + # --------------------------------------------------------------------------- # truncate_message diff --git a/tests/gateway/test_platform_connected_checkers.py b/tests/gateway/test_platform_connected_checkers.py new file mode 100644 index 0000000000..ba16ac4954 --- /dev/null +++ b/tests/gateway/test_platform_connected_checkers.py @@ -0,0 +1,99 @@ +""" +Verify that every gateway platform — built-in and plugin — has a connection +checker so ``GatewayConfig.get_connected_platforms()`` doesn't silently drop +platforms with bespoke auth requirements. 
+""" + +from unittest.mock import MagicMock + +import pytest + +from gateway.config import Platform, _PLATFORM_CONNECTED_CHECKERS, _BUILTIN_PLATFORM_VALUES + + +def test_all_builtins_have_checker_or_generic_token_path(): + """Every built-in Platform member must be reachable by either: + + 1. The generic ``config.token or config.api_key`` check, OR + 2. A platform-specific entry in ``_PLATFORM_CONNECTED_CHECKERS``. + + This guarantees ``get_connected_platforms()`` doesn't silently ignore + a built-in just because nobody added it to the checker dict. + """ + # Platforms covered by the generic token/api_key branch + generic_token_values = {p.value for p in { + Platform.TELEGRAM, + Platform.DISCORD, + Platform.SLACK, + Platform.MATRIX, + Platform.MATTERMOST, + Platform.HOMEASSISTANT, + }} + + # Platforms with a bespoke checker + checker_values = {p.value for p in set(_PLATFORM_CONNECTED_CHECKERS.keys())} + + # Every built-in should be in one of the two sets + all_builtins = set(_BUILTIN_PLATFORM_VALUES) + missing = all_builtins - generic_token_values - checker_values - {"local"} + + assert not missing, ( + f"Built-in platforms missing a connection checker: " + f"{sorted(missing)}. " + f"Add them to _PLATFORM_CONNECTED_CHECKERS or generic_token_platforms." 
+ ) + + +@pytest.mark.parametrize("platform, checker", list(_PLATFORM_CONNECTED_CHECKERS.items())) +def test_checker_handles_minimal_config(platform, checker): + """Each bespoke checker must not crash on a minimal PlatformConfig.""" + mock_config = MagicMock() + mock_config.extra = {} + mock_config.token = None + mock_config.api_key = None + mock_config.enabled = True + + # Should return a bool without raising + result = checker(mock_config) + assert isinstance(result, bool) + + +@pytest.mark.parametrize("platform, checker", list(_PLATFORM_CONNECTED_CHECKERS.items())) +def test_checker_returns_true_when_configured(platform, checker, monkeypatch): + """Each bespoke checker must return True when the config looks valid.""" + mock_config = MagicMock() + mock_config.token = None + mock_config.api_key = None + mock_config.enabled = True + + # Set up platform-specific mock extra fields so the checker succeeds + if platform == Platform.WEIXIN: + mock_config.extra = {"account_id": "123", "token": "***"} + elif platform == Platform.SIGNAL: + mock_config.extra = {"http_url": "http://signal:8080"} + elif platform == Platform.EMAIL: + mock_config.extra = {"address": "hermes@example.com"} + elif platform == Platform.SMS: + monkeypatch.setenv("TWILIO_ACCOUNT_SID", "ACtest") + mock_config.extra = {} + elif platform in (Platform.API_SERVER, Platform.WEBHOOK, Platform.WHATSAPP): + mock_config.extra = {} + elif platform == Platform.FEISHU: + mock_config.extra = {"app_id": "app"} + elif platform == Platform.WECOM: + mock_config.extra = {"bot_id": "bot"} + elif platform == Platform.WECOM_CALLBACK: + mock_config.extra = {"corp_id": "corp"} + elif platform == Platform.BLUEBUBBLES: + mock_config.extra = {"server_url": "http://bb:1234", "password": "pw"} + elif platform == Platform.QQBOT: + mock_config.extra = {"app_id": "app", "client_secret": "sec"} + elif platform == Platform.YUANBAO: + mock_config.extra = {"app_id": "app", "app_secret": "sec"} + elif platform == Platform.DINGTALK: + 
mock_config.extra = {"client_id": "id", "client_secret": "sec"} + else: + pytest.skip(f"No synthetic config defined for {platform.value}") + + result = checker(mock_config) + assert result is True, f"{platform.value} checker should return True with valid-looking config" diff --git a/tests/gateway/test_platform_reconnect.py b/tests/gateway/test_platform_reconnect.py index 5667427232..a0bd7ab9ee 100644 --- a/tests/gateway/test_platform_reconnect.py +++ b/tests/gateway/test_platform_reconnect.py @@ -14,8 +14,15 @@ from gateway.run import GatewayRunner class StubAdapter(BasePlatformAdapter): """Adapter whose connect() result can be controlled.""" - def __init__(self, *, succeed=True, fatal_error=None, fatal_retryable=True): - super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM) + def __init__( + self, + *, + platform=Platform.TELEGRAM, + succeed=True, + fatal_error=None, + fatal_retryable=True, + ): + super().__init__(PlatformConfig(enabled=True, token="test"), platform) self._succeed = succeed self._fatal_error = fatal_error self._fatal_retryable = fatal_retryable @@ -65,6 +72,85 @@ def _make_runner(): # --- Startup queueing --- +class TestStartupPlatformIsolation: + """Verify one blocked platform cannot prevent later platforms from starting.""" + + @pytest.mark.asyncio + async def test_start_continues_after_platform_connect_timeout(self, tmp_path): + """A timeout on Telegram should queue it and still connect Feishu.""" + runner = _make_runner() + runner.config = GatewayConfig( + platforms={ + Platform.TELEGRAM: PlatformConfig(enabled=True, token="test"), + Platform.FEISHU: PlatformConfig(enabled=True, token="test"), + }, + sessions_dir=tmp_path, + ) + runner.hooks = MagicMock() + runner.hooks.loaded_hooks = [] + runner.hooks.emit = AsyncMock() + runner._suspend_stuck_loop_sessions = MagicMock(return_value=0) + runner._update_runtime_status = MagicMock() + runner._update_platform_runtime_status = MagicMock() + 
runner._sync_voice_mode_state_to_adapter = MagicMock() + runner._send_update_notification = AsyncMock(return_value=True) + runner._send_restart_notification = AsyncMock() + + adapters = { + Platform.TELEGRAM: StubAdapter(platform=Platform.TELEGRAM), + Platform.FEISHU: StubAdapter(platform=Platform.FEISHU), + } + runner._create_adapter = MagicMock( + side_effect=lambda platform, _config: adapters[platform] + ) + runner._connect_adapter_with_timeout = AsyncMock( + side_effect=[ + TimeoutError("telegram connect timed out after 30s"), + True, + ] + ) + + def fake_create_task(coro): + coro.close() + return MagicMock() + + with patch("gateway.status.write_runtime_status"): + with patch("hermes_cli.plugins.discover_plugins"): + with patch("hermes_cli.config.load_config", return_value={}): + with patch("agent.shell_hooks.register_from_config"): + with patch( + "tools.process_registry.process_registry.recover_from_checkpoint", + return_value=0, + ): + with patch( + "gateway.channel_directory.build_channel_directory", + new=AsyncMock(return_value={"platforms": {}}), + ): + with patch("gateway.run.asyncio.create_task", side_effect=fake_create_task): + assert await runner.start() is True + + assert Platform.TELEGRAM in runner._failed_platforms + assert Platform.FEISHU in runner.adapters + assert Platform.TELEGRAM not in runner.adapters + assert runner._create_adapter.call_count == 2 + + @pytest.mark.asyncio + async def test_connect_adapter_timeout_raises_retryable_exception(self, monkeypatch): + """The timeout helper turns a hanging connect into a caught startup error.""" + runner = _make_runner() + adapter = StubAdapter() + + async def hang(): + await asyncio.sleep(60) + return True + + adapter.connect = hang + monkeypatch.setenv("HERMES_GATEWAY_PLATFORM_CONNECT_TIMEOUT", "0.001") + + with pytest.raises(TimeoutError, match="telegram connect timed out"): + await runner._connect_adapter_with_timeout(adapter, Platform.TELEGRAM) + + class TestStartupFailureQueuing: """Verify that 
failed platforms are queued during startup.""" diff --git a/tests/gateway/test_platform_registry.py b/tests/gateway/test_platform_registry.py new file mode 100644 index 0000000000..e6bb823aa6 --- /dev/null +++ b/tests/gateway/test_platform_registry.py @@ -0,0 +1,396 @@ +"""Tests for the platform adapter registry and dynamic Platform enum.""" + +import os +import pytest +from unittest.mock import MagicMock, patch +from dataclasses import dataclass + +from gateway.platform_registry import PlatformRegistry, PlatformEntry, platform_registry +from gateway.config import Platform, PlatformConfig, GatewayConfig + + +# ── Platform enum dynamic members ───────────────────────────────────────── + + +class TestPlatformEnumDynamic: + """Test that Platform enum accepts unknown values for plugin platforms.""" + + def test_builtin_members_still_work(self): + assert Platform.TELEGRAM.value == "telegram" + assert Platform("telegram") is Platform.TELEGRAM + + def test_dynamic_member_created(self): + p = Platform("irc") + assert p.value == "irc" + assert p.name == "IRC" + + def test_dynamic_member_identity_stable(self): + """Same value returns same object (cached).""" + a = Platform("irc") + b = Platform("irc") + assert a is b + + def test_dynamic_member_case_normalised(self): + """Mixed case normalised to lowercase.""" + a = Platform("IRC") + b = Platform("irc") + assert a is b + assert a.value == "irc" + + def test_dynamic_member_with_hyphens(self): + """Registered plugin platforms with hyphens work once registered.""" + from gateway.platform_registry import platform_registry as _reg + + entry = PlatformEntry( + name="my-platform", + label="My Platform", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + source="plugin", + ) + _reg.register(entry) + try: + p = Platform("my-platform") + assert p.value == "my-platform" + assert p.name == "MY_PLATFORM" + finally: + _reg.unregister("my-platform") + + def test_dynamic_member_rejects_unregistered(self): + """Arbitrary 
strings are rejected to prevent enum pollution.""" + with pytest.raises(ValueError): + Platform("totally-fake-platform") + + def test_dynamic_member_rejects_non_string(self): + with pytest.raises(ValueError): + Platform(123) + + def test_dynamic_member_rejects_empty(self): + with pytest.raises(ValueError): + Platform("") + + def test_dynamic_member_rejects_whitespace_only(self): + with pytest.raises(ValueError): + Platform(" ") + + +# ── PlatformRegistry ────────────────────────────────────────────────────── + + +class TestPlatformRegistry: + """Test the PlatformRegistry itself.""" + + def _make_entry(self, name="test", check_ok=True, validate_ok=True, factory_ok=True): + adapter_mock = MagicMock() + return PlatformEntry( + name=name, + label=name.title(), + adapter_factory=lambda cfg, _m=adapter_mock: _m if factory_ok else (_ for _ in ()).throw(RuntimeError("factory error")), + check_fn=lambda: check_ok, + validate_config=lambda cfg: validate_ok, + required_env=[], + source="plugin", + ), adapter_mock + + def test_register_and_get(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("alpha") + reg.register(entry) + assert reg.get("alpha") is entry + assert reg.is_registered("alpha") + + def test_get_unknown_returns_none(self): + reg = PlatformRegistry() + assert reg.get("nonexistent") is None + + def test_unregister(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("beta") + reg.register(entry) + assert reg.unregister("beta") is True + assert reg.get("beta") is None + assert reg.unregister("beta") is False # already gone + + def test_create_adapter_success(self): + reg = PlatformRegistry() + entry, mock_adapter = self._make_entry("gamma") + reg.register(entry) + result = reg.create_adapter("gamma", MagicMock()) + assert result is mock_adapter + + def test_create_adapter_unknown_name(self): + reg = PlatformRegistry() + assert reg.create_adapter("unknown", MagicMock()) is None + + def test_create_adapter_check_fails(self): + reg = 
PlatformRegistry() + entry, _ = self._make_entry("delta", check_ok=False) + reg.register(entry) + assert reg.create_adapter("delta", MagicMock()) is None + + def test_create_adapter_validate_fails(self): + reg = PlatformRegistry() + entry, _ = self._make_entry("epsilon", validate_ok=False) + reg.register(entry) + assert reg.create_adapter("epsilon", MagicMock()) is None + + def test_create_adapter_factory_exception(self): + reg = PlatformRegistry() + entry = PlatformEntry( + name="broken", + label="Broken", + adapter_factory=lambda cfg: (_ for _ in ()).throw(RuntimeError("boom")), + check_fn=lambda: True, + validate_config=None, + source="plugin", + ) + reg.register(entry) + # factory raises → create_adapter returns None instead of propagating + assert reg.create_adapter("broken", MagicMock()) is None + + def test_create_adapter_no_validate(self): + """When validate_config is None, skip validation.""" + reg = PlatformRegistry() + mock_adapter = MagicMock() + entry = PlatformEntry( + name="novalidate", + label="NoValidate", + adapter_factory=lambda cfg: mock_adapter, + check_fn=lambda: True, + validate_config=None, + source="plugin", + ) + reg.register(entry) + assert reg.create_adapter("novalidate", MagicMock()) is mock_adapter + + def test_all_entries(self): + reg = PlatformRegistry() + e1, _ = self._make_entry("one") + e2, _ = self._make_entry("two") + reg.register(e1) + reg.register(e2) + names = {e.name for e in reg.all_entries()} + assert names == {"one", "two"} + + def test_plugin_entries(self): + reg = PlatformRegistry() + plugin_entry, _ = self._make_entry("plugged") + builtin_entry = PlatformEntry( + name="core", + label="Core", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + source="builtin", + ) + reg.register(plugin_entry) + reg.register(builtin_entry) + plugin_names = {e.name for e in reg.plugin_entries()} + assert plugin_names == {"plugged"} + + def test_re_register_replaces(self): + reg = PlatformRegistry() + entry1, mock1 = 
self._make_entry("dup") + entry2 = PlatformEntry( + name="dup", + label="Dup v2", + adapter_factory=lambda cfg: "v2", + check_fn=lambda: True, + source="plugin", + ) + reg.register(entry1) + reg.register(entry2) + assert reg.get("dup").label == "Dup v2" + + +# ── GatewayConfig integration ──────────────────────────────────────────── + + +class TestGatewayConfigPluginPlatform: + """Test that GatewayConfig parses and validates plugin platforms.""" + + def test_from_dict_accepts_plugin_platform(self): + data = { + "platforms": { + "telegram": {"enabled": True, "token": "test-token"}, + "irc": {"enabled": True, "extra": {"server": "irc.libera.chat"}}, + } + } + cfg = GatewayConfig.from_dict(data) + platform_values = {p.value for p in cfg.platforms} + assert "telegram" in platform_values + assert "irc" in platform_values + + def test_get_connected_platforms_includes_registered_plugin(self): + """Plugin platform with registry entry passes get_connected_platforms.""" + # Register a fake plugin platform + from gateway.platform_registry import platform_registry as _reg + + test_entry = PlatformEntry( + name="testplat", + label="TestPlat", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + validate_config=lambda cfg: bool(cfg.extra.get("token")), + source="plugin", + ) + _reg.register(test_entry) + try: + data = { + "platforms": { + "testplat": {"enabled": True, "extra": {"token": "abc"}}, + } + } + cfg = GatewayConfig.from_dict(data) + connected = cfg.get_connected_platforms() + connected_values = {p.value for p in connected} + assert "testplat" in connected_values + finally: + _reg.unregister("testplat") + + def test_get_connected_platforms_excludes_unregistered_plugin(self): + """Plugin platform without registry entry is excluded.""" + data = { + "platforms": { + "unknown_plugin": {"enabled": True, "extra": {"token": "abc"}}, + } + } + cfg = GatewayConfig.from_dict(data) + connected = cfg.get_connected_platforms() + connected_values = {p.value for p in 
connected} + assert "unknown_plugin" not in connected_values + + def test_get_connected_platforms_excludes_invalid_config(self): + """Plugin platform with failing validate_config is excluded.""" + from gateway.platform_registry import platform_registry as _reg + + test_entry = PlatformEntry( + name="badconfig", + label="BadConfig", + adapter_factory=lambda cfg: MagicMock(), + check_fn=lambda: True, + validate_config=lambda cfg: False, # always fails + source="plugin", + ) + _reg.register(test_entry) + try: + data = { + "platforms": { + "badconfig": {"enabled": True, "extra": {}}, + } + } + cfg = GatewayConfig.from_dict(data) + connected = cfg.get_connected_platforms() + connected_values = {p.value for p in connected} + assert "badconfig" not in connected_values + finally: + _reg.unregister("badconfig") + + +# ── Extended PlatformEntry fields ───────────────────────────────────── + + +class TestPlatformEntryExtendedFields: + """Test the auth, message length, and display fields on PlatformEntry.""" + + def test_default_field_values(self): + entry = PlatformEntry( + name="test", + label="Test", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + ) + assert entry.allowed_users_env == "" + assert entry.allow_all_env == "" + assert entry.max_message_length == 0 + assert entry.pii_safe is False + assert entry.emoji == "🔌" + assert entry.allow_update_command is True + + def test_custom_auth_fields(self): + entry = PlatformEntry( + name="irc", + label="IRC", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + allowed_users_env="IRC_ALLOWED_USERS", + allow_all_env="IRC_ALLOW_ALL_USERS", + max_message_length=450, + pii_safe=False, + emoji="💬", + ) + assert entry.allowed_users_env == "IRC_ALLOWED_USERS" + assert entry.allow_all_env == "IRC_ALLOW_ALL_USERS" + assert entry.max_message_length == 450 + assert entry.emoji == "💬" + + +# ── Cron platform resolution ───────────────────────────────────────── + + +class TestCronPlatformResolution: + """Test that 
cron delivery accepts plugin platform names.""" + + def test_builtin_platform_resolves(self): + """Built-in platform names resolve via Platform() call.""" + p = Platform("telegram") + assert p is Platform.TELEGRAM + + def test_plugin_platform_resolves(self): + """Plugin platform names create dynamic enum members.""" + p = Platform("irc") + assert p.value == "irc" + + def test_invalid_platform_type_rejected(self): + """Non-string values are still rejected.""" + with pytest.raises(ValueError): + Platform(None) + + +# ── platforms.py integration ────────────────────────────────────────── + + +class TestPlatformsMerge: + """Test get_all_platforms() merges with registry.""" + + def test_get_all_platforms_includes_builtins(self): + from hermes_cli.platforms import get_all_platforms, PLATFORMS + merged = get_all_platforms() + for key in PLATFORMS: + assert key in merged + + def test_get_all_platforms_includes_plugin(self): + from hermes_cli.platforms import get_all_platforms + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="testmerge", + label="TestMerge", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + source="plugin", + emoji="🧪", + )) + try: + merged = get_all_platforms() + assert "testmerge" in merged + assert "TestMerge" in merged["testmerge"].label + finally: + _reg.unregister("testmerge") + + def test_platform_label_plugin_fallback(self): + from hermes_cli.platforms import platform_label + from gateway.platform_registry import platform_registry as _reg + + _reg.register(PlatformEntry( + name="labeltest", + label="LabelTest", + adapter_factory=lambda cfg: None, + check_fn=lambda: True, + source="plugin", + emoji="🏷️", + )) + try: + label = platform_label("labeltest") + assert "LabelTest" in label + finally: + _reg.unregister("labeltest") diff --git a/tests/gateway/test_plugin_platform_interface.py b/tests/gateway/test_plugin_platform_interface.py new file mode 100644 index 0000000000..c2392cf827 
--- /dev/null +++ b/tests/gateway/test_plugin_platform_interface.py @@ -0,0 +1,230 @@ +""" +Interface compliance tests for all plugin-based gateway platforms. + +Discovers platforms dynamically under ``plugins/platforms/`` — no manual +enumeration — and verifies each one implements the required contract. +""" + +import importlib +import sys +from pathlib import Path +from types import ModuleType +from typing import Any +from unittest.mock import MagicMock + +import pytest + +PROJECT_ROOT = Path(__file__).parent.parent.resolve() +PLATFORMS_DIR = PROJECT_ROOT / "plugins" / "platforms" + + +def _discover_platform_plugins() -> list[str]: + """Return names of all bundled platform plugins.""" + if not PLATFORMS_DIR.is_dir(): + return [] + names = [] + for child in sorted(PLATFORMS_DIR.iterdir()): + if child.is_dir() and (child / "__init__.py").exists(): + names.append(child.name) + return names + + +# Dynamically parametrise over discovered platforms +_PLATFORM_NAMES = _discover_platform_plugins() + + +@pytest.fixture +def clean_registry(): + """Yield with a clean platform registry, restoring state afterwards.""" + from gateway.platform_registry import platform_registry + + original = dict(platform_registry._entries) + platform_registry._entries.clear() + yield platform_registry + platform_registry._entries.clear() + platform_registry._entries.update(original) + + +class _MockPluginContext: + """Minimal mock of hermes_cli.plugins.PluginContext. + + Only implements register_platform so we can exercise the plugin's + register() entrypoint without importing the real plugin system. 
+ """ + + def __init__(self): + self.registered_names: list[str] = [] + + def register_platform( + self, + *, + name: str, + label: str, + adapter_factory: Any, + check_fn: Any, + **kwargs: Any, + ) -> None: + from gateway.platform_registry import platform_registry, PlatformEntry + + entry = PlatformEntry( + name=name, + label=label, + adapter_factory=adapter_factory, + check_fn=check_fn, + **kwargs, + ) + platform_registry.register(entry) + self.registered_names.append(name) + + +def _import_platform_module(name: str) -> ModuleType: + """Import the ``plugins.platforms`` submodule called *name* in a test-safe way.""" + # Make sure the project root is on sys.path so relative imports work + if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + module = importlib.import_module(f"plugins.platforms.{name}") + return module + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def test_plugin_exposes_register_function(platform_name: str): + """Every platform plugin must expose a callable register function.""" + module = _import_platform_module(platform_name) + assert hasattr(module, "register"), f"{platform_name} missing register()" + assert callable(module.register), f"{platform_name}.register not callable" + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def test_plugin_registers_valid_platform_entry(platform_name: str, clean_registry): + """Calling register() must create a valid PlatformEntry.""" + module = _import_platform_module(platform_name) + ctx = _MockPluginContext() + module.register(ctx) + + assert platform_name in ctx.registered_names + + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform_name) + assert entry is not None, f"{platform_name} did not register an entry" + assert entry.name == platform_name + assert entry.label + assert callable(entry.adapter_factory) + assert callable(entry.check_fn) + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def
test_platform_entry_has_required_fields(platform_name: str, clean_registry): + """PlatformEntry must have the mandatory metadata fields.""" + module = _import_platform_module(platform_name) + ctx = _MockPluginContext() + module.register(ctx) + + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform_name) + assert entry is not None + + # Mandatory fields + assert isinstance(entry.name, str) and entry.name + assert isinstance(entry.label, str) and entry.label + assert callable(entry.adapter_factory) + assert callable(entry.check_fn) + + # Optional but recommended fields + if entry.validate_config is not None: + assert callable(entry.validate_config) + if entry.is_connected is not None: + assert callable(entry.is_connected) + if entry.setup_fn is not None: + assert callable(entry.setup_fn) + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def test_adapter_factory_produces_valid_adapter(platform_name: str, clean_registry): + """The adapter factory must return an object with the base interface.""" + module = _import_platform_module(platform_name) + ctx = _MockPluginContext() + module.register(ctx) + + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform_name) + assert entry is not None + + # Build a minimal synthetic config that shouldn't crash __init__ + mock_config = MagicMock() + mock_config.extra = {} + mock_config.enabled = True + mock_config.token = None + mock_config.api_key = None + mock_config.home_channel = None + mock_config.reply_to_mode = "first" + + adapter = entry.adapter_factory(mock_config) + assert adapter is not None, f"{platform_name} adapter_factory returned None" + + # Required adapter interface + assert hasattr(adapter, "connect") and callable(adapter.connect) + assert hasattr(adapter, "disconnect") and callable(adapter.disconnect) + assert hasattr(adapter, "send") and callable(adapter.send) + assert hasattr(adapter, "name") + + # Should be a 
BasePlatformAdapter subclass if importable + try: + from gateway.platforms.base import BasePlatformAdapter + assert isinstance(adapter, BasePlatformAdapter) + except Exception: + pytest.skip("BasePlatformAdapter not available for isinstance check") + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def test_check_fn_returns_bool(platform_name: str, clean_registry): + """check_fn() must return a boolean.""" + module = _import_platform_module(platform_name) + ctx = _MockPluginContext() + module.register(ctx) + + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform_name) + assert entry is not None + + result = entry.check_fn() + assert isinstance(result, bool), f"{platform_name}.check_fn() returned {type(result)}, expected bool" + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def test_validate_config_if_present(platform_name: str, clean_registry): + """If validate_config is provided, it must accept a config object.""" + module = _import_platform_module(platform_name) + ctx = _MockPluginContext() + module.register(ctx) + + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform_name) + assert entry is not None + + if entry.validate_config is None: + pytest.skip("No validate_config provided") + + mock_config = MagicMock() + mock_config.extra = {} + result = entry.validate_config(mock_config) + assert isinstance(result, bool) + + +@pytest.mark.parametrize("platform_name", _PLATFORM_NAMES) +def test_is_connected_if_present(platform_name: str, clean_registry): + """If is_connected is provided, it must accept a config object.""" + module = _import_platform_module(platform_name) + ctx = _MockPluginContext() + module.register(ctx) + + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform_name) + assert entry is not None + + if entry.is_connected is None: + pytest.skip("No is_connected provided") + + mock_config = 
MagicMock() + mock_config.extra = {} + result = entry.is_connected(mock_config) + assert isinstance(result, bool) diff --git a/tests/gateway/test_reload_skills_command.py b/tests/gateway/test_reload_skills_command.py new file mode 100644 index 0000000000..5b9804bb1d --- /dev/null +++ b/tests/gateway/test_reload_skills_command.py @@ -0,0 +1,200 @@ +"""Tests for the ``/reload-skills`` gateway slash command handler. + +Verifies: + * dispatcher routes ``/reload-skills`` to ``_handle_reload_skills_command`` + * the underscored alias ``/reload_skills`` is not flagged as unknown + * the handler invokes ``agent.skill_commands.reload_skills`` and renders a + human-readable diff + * when any skills changed, a one-shot note is queued on + ``runner._pending_skills_reload_notes[session_key]`` (the agent loop + consumes and clears it on the next user turn — see ``gateway/run.py`` + near the ``_has_fresh_tool_tail`` block) + * the handler does NOT append to the session transcript out-of-band — + message alternation must not be broken by a phantom user turn +""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import MessageEvent +from gateway.session import SessionEntry, SessionSource, build_session_key + + +def _make_source() -> SessionSource: + return SessionSource( + platform=Platform.TELEGRAM, + user_id="u1", + chat_id="c1", + user_name="tester", + chat_type="dm", + ) + + +def _make_event(text: str) -> MessageEvent: + return MessageEvent(text=text, source=_make_source(), message_id="m1") + + +def _make_runner(): + from gateway.run import GatewayRunner + + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")} + ) + adapter = MagicMock() + adapter.send = AsyncMock() + runner.adapters = {Platform.TELEGRAM: 
adapter} + runner._voice_mode = {} + runner.hooks = SimpleNamespace( + emit=AsyncMock(), + emit_collect=AsyncMock(return_value=[]), + loaded_hooks=False, + ) + + session_entry = SessionEntry( + session_key=build_session_key(_make_source()), + session_id="sess-1", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + runner.session_store = MagicMock() + runner.session_store.get_or_create_session.return_value = session_entry + runner.session_store.load_transcript.return_value = [] + runner.session_store.has_any_sessions.return_value = True + runner.session_store.append_to_transcript = MagicMock() + runner.session_store.rewrite_transcript = MagicMock() + runner.session_store.update_session = MagicMock() + runner._running_agents = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._session_db = None + runner._reasoning_config = None + runner._provider_routing = {} + runner._fallback_model = None + runner._show_reasoning = False + runner._is_user_authorized = lambda _source: True + runner._set_session_env = lambda _context: None + runner._should_send_voice_reply = lambda *_args, **_kwargs: False + # Use the real _session_key_for_source binding so the key matches what + # the agent-loop consumer will look up later. 
+ from gateway.run import GatewayRunner as _GR + runner._session_key_for_source = _GR._session_key_for_source.__get__(runner, _GR) + return runner + + +@pytest.mark.asyncio +async def test_reload_skills_handler_queues_note_on_diff(monkeypatch): + """Diff non-empty → handler queues a one-shot note and does NOT touch transcript.""" + fake_result = { + "added": [ + {"name": "alpha", "description": "Run alpha to do xyz"}, + {"name": "beta", "description": "Run beta to do abc"}, + ], + "removed": [ + {"name": "gamma", "description": "Old removed skill"}, + ], + "unchanged": ["delta"], + "total": 3, + "commands": 3, + } + + import agent.skill_commands as skill_commands_mod + monkeypatch.setattr(skill_commands_mod, "reload_skills", lambda: fake_result) + + runner = _make_runner() + event = _make_event("/reload-skills") + out = await runner._handle_reload_skills_command(event) + + assert out is not None + assert "Skills Reloaded" in out + assert "Added Skills:" in out + assert "- alpha: Run alpha to do xyz" in out + assert "- beta: Run beta to do abc" in out + assert "Removed Skills:" in out + assert "- gamma: Old removed skill" in out + assert "3 skill(s) available" in out + + # MUST NOT write to the session transcript — that would break alternation. + runner.session_store.append_to_transcript.assert_not_called() + + # MUST have queued a one-shot note keyed on the session. 
+ pending = getattr(runner, "_pending_skills_reload_notes", None) + assert pending is not None + session_key = runner._session_key_for_source(event.source) + assert session_key in pending + note = pending[session_key] + assert note.startswith("[USER INITIATED SKILLS RELOAD:") + assert note.endswith("Use skills_list to see the updated catalog.]") + assert "Added Skills:" in note + assert " - alpha: Run alpha to do xyz" in note + assert " - beta: Run beta to do abc" in note + assert "Removed Skills:" in note + assert " - gamma: Old removed skill" in note + + +@pytest.mark.asyncio +async def test_reload_skills_handler_reports_no_changes(monkeypatch): + """No diff → no queued note, no transcript write.""" + import agent.skill_commands as skill_commands_mod + + monkeypatch.setattr( + skill_commands_mod, + "reload_skills", + lambda: { + "added": [], + "removed": [], + "unchanged": ["alpha"], + "total": 1, + "commands": 1, + }, + ) + + runner = _make_runner() + out = await runner._handle_reload_skills_command(_make_event("/reload-skills")) + + assert "No new skills detected" in out + assert "1 skill(s) available" in out + runner.session_store.append_to_transcript.assert_not_called() + # No queued note when nothing changed. 
+ pending = getattr(runner, "_pending_skills_reload_notes", None) + assert not pending # None or empty dict + + +@pytest.mark.asyncio +async def test_dispatcher_routes_reload_skills(monkeypatch): + """``/reload-skills`` must reach ``_handle_reload_skills_command``.""" + import gateway.run as gateway_run + + runner = _make_runner() + sentinel = "reload-skills handler reached" + runner._handle_reload_skills_command = AsyncMock(return_value=sentinel) # type: ignore[attr-defined] + + monkeypatch.setattr( + gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"} + ) + + result = await runner._handle_message(_make_event("/reload-skills")) + assert result == sentinel + + +@pytest.mark.asyncio +async def test_underscored_alias_not_flagged_unknown(monkeypatch): + """Telegram autocomplete sends ``/reload_skills`` for ``/reload-skills``.""" + import gateway.run as gateway_run + + runner = _make_runner() + runner._handle_reload_skills_command = AsyncMock(return_value="ok") # type: ignore[attr-defined] + + monkeypatch.setattr( + gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"} + ) + + result = await runner._handle_message(_make_event("/reload_skills")) + if result is not None: + assert "Unknown command" not in result diff --git a/tests/gateway/test_resume_command.py b/tests/gateway/test_resume_command.py index 42377325e9..0d2060ef31 100644 --- a/tests/gateway/test_resume_command.py +++ b/tests/gateway/test_resume_command.py @@ -230,3 +230,30 @@ class TestHandleResumeCommand: assert real_key not in runner._running_agents db.close() + + @pytest.mark.asyncio + async def test_resume_evicts_cached_agent(self, tmp_path): + """Gateway /resume evicts the cached AIAgent so the next message + rebuilds with the correct session_id end-to-end — mirrors /branch + and /reset. Without this, the cached agent's memory provider keeps + writing into the wrong session. See #6672. 
+ """ + import threading + from hermes_state import SessionDB + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("old_session", "telegram") + db.set_session_title("old_session", "Old Work") + db.create_session("current_session_001", "telegram") + + event = _make_event(text="/resume Old Work") + runner = _make_runner(session_db=db, current_session_id="current_session_001", + event=event) + # Seed the cache with a fake agent + real_key = _session_key_for_event(event) + runner._agent_cache = {real_key: (MagicMock(), object())} + runner._agent_cache_lock = threading.RLock() + + await runner._handle_resume_command(event) + + assert real_key not in runner._agent_cache + db.close() diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 49fb91d449..478a9e2773 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -67,14 +67,20 @@ class NonEditingProgressCaptureAdapter(ProgressCaptureAdapter): class FakeAgent: def __init__(self, **kwargs): + # Capture anything passed via kwargs (older code path) but don't + # freeze it — production now assigns tool_progress_callback after + # construction (see gateway/run.py around the agent-cache hit), + # so we must read it at call time, not at init. 
self.tool_progress_callback = kwargs.get("tool_progress_callback") self.tools = [] def run_conversation(self, message, conversation_history=None, task_id=None): - self.tool_progress_callback("tool.started", "terminal", "pwd", {}) - time.sleep(0.35) - self.tool_progress_callback("tool.started", "browser_navigate", "https://example.com", {}) - time.sleep(0.35) + cb = self.tool_progress_callback + if cb is not None: + cb("tool.started", "terminal", "pwd", {}) + time.sleep(0.35) + cb("tool.started", "browser_navigate", "https://example.com", {}) + time.sleep(0.35) return { "final_response": "done", "messages": [], @@ -251,6 +257,14 @@ async def test_run_agent_progress_does_not_use_event_message_id_for_telegram_dm( async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch, tmp_path): """Slack DM progress should keep event ts fallback threading.""" monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all") + # Since PR #8006, Slack's built-in display tier sets tool_progress="off" + # by default. Override via config so this test still exercises the + # progress-callback path the Slack DM event_message_id threading depends on. + import yaml + (tmp_path / "config.yaml").write_text( + yaml.dump({"display": {"platforms": {"slack": {"tool_progress": "all"}}}}), + encoding="utf-8", + ) fake_dotenv = types.ModuleType("dotenv") fake_dotenv.load_dotenv = lambda *args, **kwargs: None diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 45afc67121..5e8af49e3e 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -12,9 +12,13 @@ from gateway.session import ( build_session_context_prompt, build_session_key, canonical_whatsapp_identifier, - normalize_whatsapp_identifier, ) +# Legacy name preserved for these tests; product renamed the function to +# canonical_whatsapp_identifier. Keep the tests referencing the old name +# working without duplicating the suite. 
+normalize_whatsapp_identifier = canonical_whatsapp_identifier + class TestSessionSourceRoundtrip: def test_full_roundtrip(self): @@ -85,8 +89,13 @@ class TestSessionSourceRoundtrip: assert restored.chat_topic is None assert restored.chat_type == "dm" - def test_invalid_platform_raises(self): - with pytest.raises((ValueError, KeyError)): + def test_unknown_platform_rejected_for_bad_names(self): + """Arbitrary platform names are rejected (no accidental enum pollution). + + Only bundled platform plugins (discovered under ``plugins/platforms/``) + and runtime-registered plugins get dynamic enum members. + """ + with pytest.raises(ValueError): SessionSource.from_dict({"platform": "nonexistent", "chat_id": "1"}) diff --git a/tests/gateway/test_signal.py b/tests/gateway/test_signal.py index b51ec713f2..ca8f458a27 100644 --- a/tests/gateway/test_signal.py +++ b/tests/gateway/test_signal.py @@ -800,15 +800,23 @@ class TestSignalSendDocumentViaHelper: # --------------------------------------------------------------------------- -# send() returns message_id from timestamp (#4647) +# Signal streaming edit capability / message_id behavior # --------------------------------------------------------------------------- +class TestSignalStreamingCapabilities: + """Signal must opt out of edit-based streaming behavior.""" + + def test_signal_declares_no_message_editing(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + + assert adapter.SUPPORTS_MESSAGE_EDITING is False + + class TestSignalSendReturnsMessageId: - """Signal send() must return a timestamp-based message_id so the stream - consumer can follow its edit→fallback path correctly.""" + """Signal send() should not pretend sent messages are editable.""" @pytest.mark.asyncio - async def test_send_returns_timestamp_as_message_id(self, monkeypatch): + async def test_send_returns_none_message_id_even_with_timestamp(self, monkeypatch): adapter = _make_signal_adapter(monkeypatch) mock_rpc, _ = _stub_rpc({"timestamp": 
1712345678000}) adapter._rpc = mock_rpc @@ -817,7 +825,7 @@ class TestSignalSendReturnsMessageId: result = await adapter.send(chat_id="+155****4567", content="hello") assert result.success is True - assert result.message_id == "1712345678000" + assert result.message_id is None @pytest.mark.asyncio async def test_send_returns_none_message_id_when_no_timestamp(self, monkeypatch): @@ -997,3 +1005,100 @@ class TestSignalTypingBackoff: assert "+155****4567" not in adapter._typing_failures assert "+155****4567" not in adapter._typing_skip_until + + +# --------------------------------------------------------------------------- +# Reply quote extraction +# --------------------------------------------------------------------------- + +class TestSignalQuoteExtraction: + """Verify Signal reply quote fields are propagated to MessageEvent.""" + + @pytest.mark.asyncio + async def test_handle_envelope_sets_reply_context_from_quote(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + captured = {} + + async def fake_handle(event): + captured["event"] = event + + adapter.handle_message = fake_handle + + await adapter._handle_envelope({ + "envelope": { + "sourceNumber": "+15550001111", + "sourceUuid": "uuid-sender", + "sourceName": "Tester", + "timestamp": 1000000000, + "dataMessage": { + "message": "yes I agree", + "quote": { + "id": 99, + "text": "want to grab lunch?", + "author": "+15550002222", + }, + }, + } + }) + + event = captured["event"] + assert event.text == "yes I agree" + assert event.reply_to_message_id == "99" + assert event.reply_to_text == "want to grab lunch?" 
+ + @pytest.mark.asyncio + async def test_handle_envelope_without_quote_leaves_reply_fields_none(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + captured = {} + + async def fake_handle(event): + captured["event"] = event + + adapter.handle_message = fake_handle + + await adapter._handle_envelope({ + "envelope": { + "sourceNumber": "+15550001111", + "sourceUuid": "uuid-sender", + "sourceName": "Tester", + "timestamp": 1000000000, + "dataMessage": { + "message": "plain message", + }, + } + }) + + event = captured["event"] + assert event.text == "plain message" + assert event.reply_to_message_id is None + assert event.reply_to_text is None + + @pytest.mark.asyncio + async def test_handle_envelope_quote_without_text_sets_only_reply_id(self, monkeypatch): + adapter = _make_signal_adapter(monkeypatch) + captured = {} + + async def fake_handle(event): + captured["event"] = event + + adapter.handle_message = fake_handle + + await adapter._handle_envelope({ + "envelope": { + "sourceNumber": "+15550001111", + "sourceUuid": "uuid-sender", + "sourceName": "Tester", + "timestamp": 1000000000, + "dataMessage": { + "message": "reply without quote text", + "quote": { + "id": 123, + "author": "+15550002222", + }, + }, + } + }) + + event = captured["event"] + assert event.reply_to_message_id == "123" + assert event.reply_to_text is None diff --git a/tests/gateway/test_signal_format.py b/tests/gateway/test_signal_format.py new file mode 100644 index 0000000000..ef50f62fd0 --- /dev/null +++ b/tests/gateway/test_signal_format.py @@ -0,0 +1,452 @@ +"""Tests for Signal _markdown_to_signal() formatting. + +Covers the markdown-to-bodyRanges conversion pipeline: bold, italic, +strikethrough, monospace, code blocks, headings, and — critically — the +false-positive regressions that caused spurious italics in production. 
+""" + +import pytest + +from gateway.config import PlatformConfig +from gateway.platforms.signal import SignalAdapter + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _m2s(text: str): + """Shorthand: call the static method and return (plain_text, styles).""" + return SignalAdapter._markdown_to_signal(text) + + +def _style_types(styles: list[str]) -> list[str]: + """Extract just the STYLE part from '0:4:BOLD' strings.""" + return [s.rsplit(":", 1)[1] for s in styles] + + +def _find_style(styles: list[str], style_type: str) -> list[str]: + """Return only styles matching a given type.""" + return [s for s in styles if s.endswith(f":{style_type}")] + + +# =========================================================================== +# Basic formatting +# =========================================================================== + +class TestMarkdownToSignalBasic: + """Core formatting: bold, italic, strikethrough, monospace.""" + + def test_bold_double_asterisk(self): + text, styles = _m2s("hello **world**") + assert text == "hello world" + assert len(styles) == 1 + assert styles[0].endswith(":BOLD") + + def test_bold_double_underscore(self): + text, styles = _m2s("hello __world__") + assert text == "hello world" + assert len(styles) == 1 + assert styles[0].endswith(":BOLD") + + def test_italic_single_asterisk(self): + text, styles = _m2s("hello *world*") + assert text == "hello world" + assert len(styles) == 1 + assert styles[0].endswith(":ITALIC") + + def test_italic_single_underscore(self): + text, styles = _m2s("hello _world_") + assert text == "hello world" + assert len(styles) == 1 + assert styles[0].endswith(":ITALIC") + + def test_strikethrough(self): + text, styles = _m2s("hello ~~world~~") + assert text == "hello world" + assert len(styles) == 1 + assert styles[0].endswith(":STRIKETHROUGH") + + def test_inline_monospace(self): + 
class TestItalicFalsePositives:
    """Guard against spurious ITALIC spans in the markdown-to-Signal conversion.

    Regressions from signal-italic-false-positive-fix.md and
    signal-italic-bullet-list-fix.md: snake_case identifiers, ``*`` bullet
    lists and delimiters split across lines must not be italicized, while
    genuine single-delimiter italics must still be detected.
    """

    @staticmethod
    def _plain_and_italics(markdown: str):
        """Convert *markdown* and return ``(plain_text, ITALIC style entries)``."""
        plain, all_styles = _m2s(markdown)
        return plain, _find_style(all_styles, "ITALIC")

    # --- snake_case (original fix) ---

    def test_snake_case_not_italic(self):
        """A snake_case identifier passes through unchanged and unstyled."""
        source = "the config_file is ready"
        plain, italics = self._plain_and_italics(source)
        assert plain == source
        assert italics == []

    def test_multiple_snake_case(self):
        """Several upper-case identifiers with underscores produce no ITALIC span."""
        _, italics = self._plain_and_italics("set OPENAI_API_KEY and ANTHROPIC_API_KEY")
        assert italics == []

    def test_snake_case_path(self):
        """An underscore inside a file path produces no ITALIC span."""
        _, italics = self._plain_and_italics("/tools/delegate_tool.py")
        assert italics == []

    def test_snake_case_between_words(self):
        """Two separate snake_case words on one line are not paired up as delimiters."""
        _, italics = self._plain_and_italics("file_path and error_code")
        assert italics == []

    # --- Bullet lists (second fix) ---

    def test_bullet_list_not_italic(self):
        """Leading ``* `` bullet markers are not treated as italic delimiters."""
        bullets = "* item one\n* item two\n* item three"
        _, italics = self._plain_and_italics(bullets)
        assert italics == []

    def test_bullet_list_with_content_before(self):
        """A bullet list following an introductory paragraph stays unstyled."""
        document = "Here are things:\n\n* first thing\n* second thing"
        _, italics = self._plain_and_italics(document)
        assert italics == []

    def test_bullet_list_file_paths(self):
        """The real-world bullet list of file paths that triggered the bug."""
        document = (
            "* tools/delegate_tool.py — delegation\n"
            "* tools/file_tools.py — file operations\n"
            "* tools/web_tools.py — web operations"
        )
        _, italics = self._plain_and_italics(document)
        assert italics == []

    def test_bullet_with_italic_inside(self):
        """Emphasis inside a bullet item still yields exactly one ITALIC span."""
        document = "* this has *emphasis* inside\n* plain item"
        plain, italics = self._plain_and_italics(document)
        assert len(italics) == 1
        # Only checks that the word survived conversion; the span offsets are not verified here.
        assert "emphasis" in plain

    # --- Cross-line spans (DOTALL removal) ---

    def test_star_italic_no_cross_line(self):
        """A ``*`` pair separated by a newline is not an italic span."""
        _, italics = self._plain_and_italics("*foo\nbar*")
        assert italics == []

    def test_underscore_italic_no_cross_line(self):
        """A ``_`` pair separated by a newline is not an italic span."""
        _, italics = self._plain_and_italics("_foo\nbar_")
        assert italics == []

    def test_star_italic_multiline_response(self):
        """A multi-paragraph reply containing a bullet list produces no ITALIC span."""
        document = (
            "I checked the following files:\n\n"
            "* tools/delegate_tool.py — sub-agent delegation\n"
            "* tools/file_tools.py — file read/write/search\n"
            "* tools/web_tools.py — web search/extract\n\n"
            "Everything looks good."
        )
        _, italics = self._plain_and_italics(document)
        assert italics == []

    # --- Legitimate italic still works ---

    def test_star_italic_still_works(self):
        """``*italic*`` in running text is stripped and styled once."""
        plain, italics = self._plain_and_italics("this is *italic* text")
        assert plain == "this is italic text"
        assert len(italics) == 1

    def test_underscore_italic_still_works(self):
        """``_italic_`` in running text is stripped and styled once."""
        plain, italics = self._plain_and_italics("this is _italic_ text")
        assert plain == "this is italic text"
        assert len(italics) == 1

    def test_multiple_italic_same_line(self):
        """Two italic spans on one line are both recognised."""
        plain, italics = self._plain_and_italics("*foo* and *bar* ok")
        assert plain == "foo and bar ok"
        assert len(italics) == 2

    def test_italic_single_word(self):
        """A message consisting solely of one italic word."""
        plain, italics = self._plain_and_italics("*word*")
        assert plain == "word"
        assert len(italics) == 1

    def test_italic_multi_word(self):
        """An italic span may contain spaces."""
        plain, italics = self._plain_and_italics("*several words here*")
        assert plain == "several words here"
        assert len(italics) == 1
class TestEdgeCases:
    """Awkward inputs that have misbehaved before or are likely to regress."""

    def test_bold_inside_bullet(self):
        """A bullet item starting with bold text yields one BOLD span and no ITALIC."""
        _, spans = _m2s("* **important** item\n* normal item")
        assert len(_find_style(spans, "BOLD")) == 1
        assert _find_style(spans, "ITALIC") == []

    def test_code_span_with_underscores(self):
        """Underscores inside a backtick span are monospace content, not italic delimiters."""
        plain, spans = _m2s("use `my_var_name` here")
        assert plain == "use my_var_name here"
        kinds = _style_types(spans)
        assert "MONOSPACE" in kinds
        assert "ITALIC" not in kinds

    def test_bold_and_italic_nested(self):
        """Triple asterisks are ambiguous; only require that the word itself survives."""
        plain, _ = _m2s("***word***")
        # No assertion on the styles: either bold-around-italic or another reading is accepted.
        assert "word" in plain

    def test_lone_asterisk(self):
        """An unpaired ``*`` used as a multiplication sign must not break conversion."""
        plain, _ = _m2s("5 * 3 = 15")
        # Smoke check only: both operands are still present in the output.
        assert "5" in plain
        assert "15" in plain

    def test_lone_underscore(self):
        """An unpaired ``_`` is left in the text untouched."""
        plain, _ = _m2s("this _ that")
        assert plain == "this _ that"

    def test_consecutive_underscored_words(self):
        """Leading underscores on two different words have no closers, so nothing is italic."""
        _, spans = _m2s("call _init and _setup")
        assert _find_style(spans, "ITALIC") == []

    def test_mixed_formatting_no_bleed(self):
        """Four different styles on one line each produce exactly one span of their own type."""
        source = "**bold** and `code` and *italic* and ~~strike~~"
        plain, spans = _m2s(source)
        assert plain == "bold and code and italic and strike"
        kinds = sorted(_style_types(spans))
        assert kinds == ["BOLD", "ITALIC", "MONOSPACE", "STRIKETHROUGH"]
+ """ + + def test_fenced_code_block_with_language_tag(self): + """```python\\ncode\\n``` — language tag is stripped, content is MONOSPACE.""" + text, styles = _m2s("```python\nprint('hello')\n```") + assert "```" not in text + assert "python" not in text # language tag stripped + assert "print('hello')" in text + assert any(s.endswith(":MONOSPACE") for s in styles) + + def test_fenced_code_block_multiline(self): + """Multi-line code blocks preserve all lines.""" + md = "```\nline1\nline2\nline3\n```" + text, styles = _m2s(md) + assert "line1" in text + assert "line2" in text + assert "line3" in text + assert "```" not in text + + def test_links_preserved(self): + """[text](url) links are kept as-is — Signal auto-linkifies.""" + md = "Check [this link](https://example.com) for details" + text, styles = _m2s(md) + # Links should pass through — either as markdown or just preserved + assert "https://example.com" in text + + def test_heading_h1(self): + """# H1 becomes bold text.""" + text, styles = _m2s("# Main Title") + assert text == "Main Title" + assert len(styles) == 1 + assert styles[0].endswith(":BOLD") + + def test_heading_h3(self): + """### H3 becomes bold text.""" + text, styles = _m2s("### Sub Section") + assert text == "Sub Section" + assert len(styles) == 1 + assert styles[0].endswith(":BOLD") + + def test_multiple_headings(self): + """Multiple headings each become separate bold spans.""" + md = "## First\n\nSome text\n\n## Second" + text, styles = _m2s(md) + assert "First" in text + assert "Second" in text + assert "##" not in text + bold_styles = _find_style(styles, "BOLD") + assert len(bold_styles) == 2 + + def test_no_raw_markdown_markers_in_output(self): + """All markdown syntax is stripped from plain text output.""" + md = "**bold** and *italic* and ~~struck~~ and `code` and ## heading" + text, styles = _m2s(md) + assert "**" not in text + assert "~~" not in text + assert "`" not in text + # ## at end might remain if not at line start — that's ok + 
# The important thing is styled markers are stripped + + def test_utf16_surrogate_pair_emoji(self): + """Emoji requiring UTF-16 surrogate pairs don't corrupt offsets.""" + # 🎉 is U+1F389 — requires surrogate pair (2 UTF-16 code units) + text, styles = _m2s("🎉🎉 **test**") + assert "test" in text + assert len(styles) == 1 + # Verify the style position is correct + parts = styles[0].split(":") + start, length = int(parts[0]), int(parts[1]) + # 🎉🎉 = 4 UTF-16 code units + space = 5, then "test" = 4 + assert start == 5 + assert length == 4 + + def test_consecutive_newlines_collapsed(self): + """3+ consecutive newlines are collapsed to 2.""" + text, styles = _m2s("first\n\n\n\n\nsecond") + assert "\n\n\n" not in text + assert "first" in text + assert "second" in text + + def test_empty_bold_not_crash(self): + """**** (empty bold) should not crash.""" + text, styles = _m2s("before **** after") + # Should not raise — exact output doesn't matter much + assert "before" in text + + +# =========================================================================== +# signal-streaming-patch: SUPPORTS_MESSAGE_EDITING and send() behavior +# =========================================================================== + +class TestSignalStreamingPatch: + """Tests for signal-streaming-patch: cursor suppression and edit support. + + These verify the adapter-level properties that prevent the streaming + cursor from leaking into Signal messages. 
+ """ + + def test_signal_does_not_support_editing(self, monkeypatch): + """SignalAdapter.SUPPORTS_MESSAGE_EDITING must be False.""" + monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "") + from gateway.platforms.signal import SignalAdapter + assert SignalAdapter.SUPPORTS_MESSAGE_EDITING is False + + @pytest.mark.asyncio + async def test_send_returns_no_message_id(self, monkeypatch): + """send() returns message_id=None so stream consumer uses no-edit path.""" + monkeypatch.setenv("SIGNAL_GROUP_ALLOWED_USERS", "") + from gateway.platforms.signal import SignalAdapter + from gateway.config import PlatformConfig + + config = PlatformConfig(enabled=True) + config.extra = { + "http_url": "http://localhost:8080", + "account": "+15551234567", + } + adapter = SignalAdapter(config) + + # Mock the RPC call + async def mock_rpc(method, params, rpc_id=None): + return {"timestamp": 1234567890} + + adapter._rpc = mock_rpc + + result = await adapter.send( + chat_id="+15559876543", + content="Hello", + ) + assert result.message_id is None diff --git a/tests/gateway/test_teams.py b/tests/gateway/test_teams.py new file mode 100644 index 0000000000..7a035142ed --- /dev/null +++ b/tests/gateway/test_teams.py @@ -0,0 +1,560 @@ +"""Tests for the Microsoft Teams platform adapter plugin.""" + +import asyncio +import os +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import Platform, PlatformConfig, HomeChannel +from tests.gateway._plugin_adapter_loader import load_plugin_adapter + + +# --------------------------------------------------------------------------- +# SDK Mock — install in sys.modules before importing the adapter +# --------------------------------------------------------------------------- + +def _ensure_teams_mock(): + """Install a teams SDK mock in sys.modules if the real package isn't present.""" + if "microsoft_teams" in sys.modules and hasattr(sys.modules["microsoft_teams"], 
"__file__"): + return + + # Build the module hierarchy + microsoft_teams = types.ModuleType("microsoft_teams") + microsoft_teams_apps = types.ModuleType("microsoft_teams.apps") + microsoft_teams_api = types.ModuleType("microsoft_teams.api") + microsoft_teams_api_activities = types.ModuleType("microsoft_teams.api.activities") + microsoft_teams_api_activities_typing = types.ModuleType("microsoft_teams.api.activities.typing") + microsoft_teams_api_activities_invoke = types.ModuleType("microsoft_teams.api.activities.invoke") + microsoft_teams_api_activities_invoke_adaptive_card = types.ModuleType( + "microsoft_teams.api.activities.invoke.adaptive_card" + ) + microsoft_teams_api_models = types.ModuleType("microsoft_teams.api.models") + microsoft_teams_api_models_adaptive_card = types.ModuleType("microsoft_teams.api.models.adaptive_card") + microsoft_teams_api_models_invoke_response = types.ModuleType("microsoft_teams.api.models.invoke_response") + microsoft_teams_cards = types.ModuleType("microsoft_teams.cards") + microsoft_teams_apps_http = types.ModuleType("microsoft_teams.apps.http") + microsoft_teams_apps_http_adapter = types.ModuleType("microsoft_teams.apps.http.adapter") + + # App class mock + class MockApp: + def __init__(self, **kwargs): + self._client_id = kwargs.get("client_id") + self.server = MagicMock() + self.server.handle_request = AsyncMock(return_value={"status": 200, "body": None}) + self.credentials = MagicMock() + self.credentials.client_id = self._client_id + + @property + def id(self): + return self._client_id + + def on_message(self, func): + self._message_handler = func + return func + + def on_card_action(self, func): + self._card_action_handler = func + return func + + async def initialize(self): + pass + + async def send(self, conversation_id, activity): + result = MagicMock() + result.id = "sent-activity-id" + return result + + async def start(self, port=3978): + pass + + async def stop(self): + pass + + microsoft_teams_apps.App = MockApp + 
microsoft_teams_apps.ActivityContext = MagicMock + + # MessageActivity mock + microsoft_teams_api.MessageActivity = MagicMock + microsoft_teams_api.ConversationReference = MagicMock + microsoft_teams_api.MessageActivityInput = MagicMock + + # TypingActivityInput mock + class MockTypingActivityInput: + pass + + microsoft_teams_api_activities_typing.TypingActivityInput = MockTypingActivityInput + + # Adaptive card invoke activity mock + microsoft_teams_api_activities_invoke_adaptive_card.AdaptiveCardInvokeActivity = MagicMock + + # Adaptive card response mocks + microsoft_teams_api_models_adaptive_card.AdaptiveCardActionCardResponse = MagicMock + microsoft_teams_api_models_adaptive_card.AdaptiveCardActionMessageResponse = MagicMock + + # Invoke response mocks + class MockInvokeResponse: + def __init__(self, status=200, body=None): + self.status = status + self.body = body + + microsoft_teams_api_models_invoke_response.InvokeResponse = MockInvokeResponse + microsoft_teams_api_models_invoke_response.AdaptiveCardInvokeResponse = MagicMock + + # Cards mocks + class MockAdaptiveCard: + def with_version(self, v): + return self + + def with_body(self, body): + return self + + def with_actions(self, actions): + return self + + microsoft_teams_cards.AdaptiveCard = MockAdaptiveCard + microsoft_teams_cards.ExecuteAction = MagicMock + microsoft_teams_cards.TextBlock = MagicMock + + # HttpRequest TypedDict mock + def HttpRequest(body=None, headers=None): + return {"body": body, "headers": headers} + + # HttpResponse TypedDict mock + HttpResponse = dict + HttpMethod = str + from typing import Callable + HttpRouteHandler = Callable + + microsoft_teams_apps_http_adapter.HttpRequest = HttpRequest + microsoft_teams_apps_http_adapter.HttpResponse = HttpResponse + microsoft_teams_apps_http_adapter.HttpMethod = HttpMethod + microsoft_teams_apps_http_adapter.HttpRouteHandler = HttpRouteHandler + + # Wire the hierarchy + for name, mod in { + "microsoft_teams": microsoft_teams, + 
"microsoft_teams.apps": microsoft_teams_apps, + "microsoft_teams.api": microsoft_teams_api, + "microsoft_teams.api.activities": microsoft_teams_api_activities, + "microsoft_teams.api.activities.typing": microsoft_teams_api_activities_typing, + "microsoft_teams.api.activities.invoke": microsoft_teams_api_activities_invoke, + "microsoft_teams.api.activities.invoke.adaptive_card": microsoft_teams_api_activities_invoke_adaptive_card, + "microsoft_teams.api.models": microsoft_teams_api_models, + "microsoft_teams.api.models.adaptive_card": microsoft_teams_api_models_adaptive_card, + "microsoft_teams.api.models.invoke_response": microsoft_teams_api_models_invoke_response, + "microsoft_teams.cards": microsoft_teams_cards, + "microsoft_teams.apps.http": microsoft_teams_apps_http, + "microsoft_teams.apps.http.adapter": microsoft_teams_apps_http_adapter, + }.items(): + sys.modules.setdefault(name, mod) + + +_ensure_teams_mock() + +# Load plugins/platforms/teams/adapter.py under a unique module name +# (plugin_adapter_teams) so it cannot collide with sibling plugin adapters. 
def _make_config(**extra):
    """Build an enabled ``PlatformConfig`` whose ``extra`` is the given keyword arguments."""
    settings = {"enabled": True, "extra": extra}
    return PlatformConfig(**settings)
class TestTeamsAdapterInit:
    """Constructor behaviour of ``TeamsAdapter``: credential and port resolution.

    Every test that expects a value to come from ``extra`` (or from the
    built-in default) first clears the matching ``TEAMS_*`` environment
    variable, so the result does not depend on what happens to be exported
    in the developer's shell or the CI runner.
    """

    _CREDENTIAL_ENV_VARS = ("TEAMS_CLIENT_ID", "TEAMS_CLIENT_SECRET", "TEAMS_TENANT_ID")

    def test_reads_config_from_extra(self, monkeypatch):
        """Credentials supplied via ``extra`` end up on the adapter."""
        # With the environment cleared, the asserted values can only come from ``extra``.
        for var in self._CREDENTIAL_ENV_VARS:
            monkeypatch.delenv(var, raising=False)
        config = _make_config(
            client_id="cfg-id",
            client_secret="cfg-secret",
            tenant_id="cfg-tenant",
        )
        adapter = TeamsAdapter(config)
        assert adapter._client_id == "cfg-id"
        assert adapter._client_secret == "cfg-secret"
        assert adapter._tenant_id == "cfg-tenant"

    def test_falls_back_to_env_vars(self, monkeypatch):
        """With an empty ``extra``, credentials are taken from the ``TEAMS_*`` variables."""
        monkeypatch.setenv("TEAMS_CLIENT_ID", "env-id")
        monkeypatch.setenv("TEAMS_CLIENT_SECRET", "env-secret")
        monkeypatch.setenv("TEAMS_TENANT_ID", "env-tenant")
        adapter = TeamsAdapter(_make_config())
        assert adapter._client_id == "env-id"
        assert adapter._client_secret == "env-secret"
        assert adapter._tenant_id == "env-tenant"

    def test_default_port(self, monkeypatch):
        """Without any port configuration the adapter listens on 3978."""
        # The adapter honours TEAMS_PORT (see test_custom_port_from_env); an
        # ambient value would make this default check fail spuriously.
        monkeypatch.delenv("TEAMS_PORT", raising=False)
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
        assert adapter._port == 3978

    def test_custom_port_from_extra(self, monkeypatch):
        """A ``port`` entry in ``extra`` overrides the default."""
        # Precedence between ``extra`` and TEAMS_PORT is not visible from this
        # test; clear the variable so only ``extra`` can supply the port.
        monkeypatch.delenv("TEAMS_PORT", raising=False)
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant", port=4000))
        assert adapter._port == 4000

    def test_custom_port_from_env(self, monkeypatch):
        """``TEAMS_PORT`` is parsed to an int and used when ``extra`` has no port."""
        monkeypatch.setenv("TEAMS_PORT", "5000")
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
        assert adapter._port == 5000

    def test_platform_value(self):
        """The adapter reports the ``teams`` platform."""
        adapter = TeamsAdapter(_make_config(client_id="id", client_secret="secret", tenant_id="tenant"))
        assert adapter.platform.value == "teams"
class TestTeamsSend:
    """Outbound behaviour of ``TeamsAdapter.send`` and ``send_typing``."""

    @staticmethod
    def _adapter_with_app(app):
        """Return a credentialed adapter whose ``_app`` has been replaced by *app*."""
        credentials = {"client_id": "id", "client_secret": "secret", "tenant_id": "tenant"}
        teams = TeamsAdapter(_make_config(**credentials))
        teams._app = app
        return teams

    @pytest.mark.asyncio
    async def test_send_returns_error_without_app(self):
        """Sending with no app reports failure and mentions that it is not initialized."""
        teams = self._adapter_with_app(None)
        outcome = await teams.send("conv-id", "Hello")
        assert outcome.success is False
        assert "not initialized" in outcome.error

    @pytest.mark.asyncio
    async def test_send_calls_app_send(self):
        """A successful send forwards both arguments and surfaces the returned activity id."""
        sdk_send = AsyncMock(return_value=MagicMock(id="msg-123"))
        teams = self._adapter_with_app(MagicMock(send=sdk_send))

        outcome = await teams.send("conv-id", "Hello")
        assert outcome.success is True
        assert outcome.message_id == "msg-123"
        sdk_send.assert_awaited_once_with("conv-id", "Hello")

    @pytest.mark.asyncio
    async def test_send_handles_error(self):
        """An exception raised by the app is converted into a failed result carrying its message."""
        failing_send = AsyncMock(side_effect=Exception("Network error"))
        teams = self._adapter_with_app(MagicMock(send=failing_send))

        outcome = await teams.send("conv-id", "Hello")
        assert outcome.success is False
        assert "Network error" in outcome.error

    @pytest.mark.asyncio
    async def test_send_typing(self):
        """``send_typing`` awaits the app's ``send`` once, addressed to the conversation."""
        sdk_send = AsyncMock()
        teams = self._adapter_with_app(MagicMock(send=sdk_send))

        await teams.send_typing("conv-id")
        sdk_send.assert_awaited_once()
        # First positional argument is the conversation id; the activity itself is not inspected.
        assert sdk_send.call_args.args[0] == "conv-id"
@pytest.mark.asyncio + async def test_personal_message_creates_dm_event(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity(conversation_type="personal") + await adapter._on_message(self._make_ctx(activity)) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.call_args[0][0] + assert event.source.chat_type == "dm" + + @pytest.mark.asyncio + async def test_group_message_creates_group_event(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity(conversation_type="groupChat") + await adapter._on_message(self._make_ctx(activity)) + + event = adapter.handle_message.call_args[0][0] + assert event.source.chat_type == "group" + + @pytest.mark.asyncio + async def test_channel_message_creates_channel_event(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity(conversation_type="channel") + await adapter._on_message(self._make_ctx(activity)) + + event = adapter.handle_message.call_args[0][0] + assert event.source.chat_type == "channel" + + @pytest.mark.asyncio + async def test_user_id_uses_aad_object_id(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity(from_aad_id="aad-stable-id", from_id="teams-id") + await adapter._on_message(self._make_ctx(activity)) + + event = 
adapter.handle_message.call_args[0][0] + assert event.source.user_id == "aad-stable-id" + + @pytest.mark.asyncio + async def test_self_message_filtered(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity(from_id="bot-id") + await adapter._on_message(self._make_ctx(activity)) + + adapter.handle_message.assert_not_awaited() + + @pytest.mark.asyncio + async def test_bot_mention_stripped_from_text(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity( + text="Hermes what is the weather?", + from_id="user-id", + ) + await adapter._on_message(self._make_ctx(activity)) + + event = adapter.handle_message.call_args[0][0] + assert event.text == "what is the weather?" 
+ + @pytest.mark.asyncio + async def test_deduplication(self): + adapter = TeamsAdapter(_make_config( + client_id="bot-id", client_secret="secret", tenant_id="tenant", + )) + adapter._app = MagicMock() + adapter._app.id = "bot-id" + adapter.handle_message = AsyncMock() + + activity = self._make_activity(activity_id="msg-dup-001", from_id="user-id") + ctx = self._make_ctx(activity) + + await adapter._on_message(ctx) + await adapter._on_message(ctx) + + assert adapter.handle_message.await_count == 1 diff --git a/tests/gateway/test_telegram_documents.py b/tests/gateway/test_telegram_documents.py index d5564cbf46..4b3e58f459 100644 --- a/tests/gateway/test_telegram_documents.py +++ b/tests/gateway/test_telegram_documents.py @@ -453,6 +453,87 @@ class TestMediaGroups: adapter.handle_message.assert_not_awaited() +# --------------------------------------------------------------------------- +# TestSendVoice — outbound audio delivery +# --------------------------------------------------------------------------- + +class TestSendVoice: + """Tests for TelegramAdapter.send_voice() routing across audio formats.""" + + @pytest.fixture() + def connected_adapter(self, adapter): + """Adapter with a mock bot attached.""" + bot = AsyncMock() + adapter._bot = bot + return adapter + + @pytest.mark.asyncio + async def test_flac_falls_back_to_document(self, connected_adapter, tmp_path): + """Telegram sendAudio does not accept FLAC — must fall back to sendDocument.""" + audio_file = tmp_path / "clip.flac" + audio_file.write_bytes(b"fLaC" + b"\x00" * 32) + + mock_msg = MagicMock() + mock_msg.message_id = 101 + connected_adapter._bot.send_voice = AsyncMock() + connected_adapter._bot.send_audio = AsyncMock() + connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg) + + result = await connected_adapter.send_voice( + chat_id="12345", + audio_path=str(audio_file), + caption="Audio", + ) + + assert result.success is True + assert result.message_id == "101" + 
connected_adapter._bot.send_document.assert_awaited_once() + connected_adapter._bot.send_audio.assert_not_awaited() + connected_adapter._bot.send_voice.assert_not_awaited() + + @pytest.mark.asyncio + async def test_wav_falls_back_to_document(self, connected_adapter, tmp_path): + """Telegram sendAudio does not accept WAV — must fall back to sendDocument.""" + audio_file = tmp_path / "clip.wav" + audio_file.write_bytes(b"RIFF" + b"\x00" * 32) + + mock_msg = MagicMock() + mock_msg.message_id = 102 + connected_adapter._bot.send_voice = AsyncMock() + connected_adapter._bot.send_audio = AsyncMock() + connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg) + + result = await connected_adapter.send_voice( + chat_id="12345", + audio_path=str(audio_file), + ) + + assert result.success is True + connected_adapter._bot.send_document.assert_awaited_once() + connected_adapter._bot.send_audio.assert_not_awaited() + + @pytest.mark.asyncio + async def test_mp3_routes_to_send_audio(self, connected_adapter, tmp_path): + """MP3 is Telegram-sendAudio-compatible.""" + audio_file = tmp_path / "clip.mp3" + audio_file.write_bytes(b"ID3" + b"\x00" * 32) + + mock_msg = MagicMock() + mock_msg.message_id = 103 + connected_adapter._bot.send_voice = AsyncMock() + connected_adapter._bot.send_audio = AsyncMock(return_value=mock_msg) + connected_adapter._bot.send_document = AsyncMock() + + result = await connected_adapter.send_voice( + chat_id="12345", + audio_path=str(audio_file), + ) + + assert result.success is True + connected_adapter._bot.send_audio.assert_awaited_once() + connected_adapter._bot.send_document.assert_not_awaited() + + # --------------------------------------------------------------------------- # TestSendDocument — outbound file attachment delivery # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_telegram_group_gating.py b/tests/gateway/test_telegram_group_gating.py index ababe5ec61..a560d6cdd6 100644 
--- a/tests/gateway/test_telegram_group_gating.py +++ b/tests/gateway/test_telegram_group_gating.py @@ -5,7 +5,14 @@ from unittest.mock import AsyncMock from gateway.config import Platform, PlatformConfig, load_gateway_config -def _make_adapter(require_mention=None, free_response_chats=None, mention_patterns=None, ignored_threads=None): +def _make_adapter( + require_mention=None, + free_response_chats=None, + mention_patterns=None, + ignored_threads=None, + allow_from=None, + group_allow_from=None, +): from gateway.platforms.telegram import TelegramAdapter extra = {} @@ -17,6 +24,10 @@ def _make_adapter(require_mention=None, free_response_chats=None, mention_patter extra["mention_patterns"] = mention_patterns if ignored_threads is not None: extra["ignored_threads"] = ignored_threads + if allow_from is not None: + extra["allow_from"] = allow_from + if group_allow_from is not None: + extra["group_allow_from"] = group_allow_from adapter = object.__new__(TelegramAdapter) adapter.platform = Platform.TELEGRAM @@ -34,6 +45,7 @@ def _group_message( text="hello", *, chat_id=-100, + from_user_id=111, thread_id=None, reply_to_bot=False, entities=None, @@ -50,10 +62,24 @@ def _group_message( caption_entities=caption_entities or [], message_thread_id=thread_id, chat=SimpleNamespace(id=chat_id, type="group"), + from_user=SimpleNamespace(id=from_user_id), reply_to_message=reply_to_message, ) +def _dm_message(text="hello", *, from_user_id=111): + return SimpleNamespace( + text=text, + caption=None, + entities=[], + caption_entities=[], + message_thread_id=None, + chat=SimpleNamespace(id=from_user_id, type="private"), + from_user=SimpleNamespace(id=from_user_id), + reply_to_message=None, + ) + + def _mention_entity(text, mention="@hermes_bot"): offset = text.index(mention) return SimpleNamespace(type="mention", offset=offset, length=len(mention)) @@ -173,6 +199,68 @@ def test_config_bridges_telegram_group_settings(monkeypatch, tmp_path): assert 
__import__("os").environ["TELEGRAM_FREE_RESPONSE_CHATS"] == "-123" +def test_config_bridges_telegram_user_allowlists(monkeypatch, tmp_path): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text( + "telegram:\n" + " allow_from:\n" + " - \"111\"\n" + " - \"222\"\n" + " group_allow_from:\n" + " - \"333\"\n" + " group_allowed_chats:\n" + " - \"-100\"\n", + encoding="utf-8", + ) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.delenv("TELEGRAM_ALLOWED_USERS", raising=False) + monkeypatch.delenv("TELEGRAM_GROUP_ALLOWED_USERS", raising=False) + monkeypatch.delenv("TELEGRAM_GROUP_ALLOWED_CHATS", raising=False) + + config = load_gateway_config() + + assert config is not None + assert __import__("os").environ["TELEGRAM_ALLOWED_USERS"] == "111,222" + assert __import__("os").environ["TELEGRAM_GROUP_ALLOWED_USERS"] == "333" + assert __import__("os").environ["TELEGRAM_GROUP_ALLOWED_CHATS"] == "-100" + + +def test_config_env_overrides_telegram_user_allowlists(monkeypatch, tmp_path): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text( + "telegram:\n" + " allow_from: \"111\"\n" + " group_allow_from: \"222\"\n", + encoding="utf-8", + ) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setenv("TELEGRAM_ALLOWED_USERS", "999") + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "888") + + config = load_gateway_config() + + assert config is not None + assert __import__("os").environ["TELEGRAM_ALLOWED_USERS"] == "999" + assert __import__("os").environ["TELEGRAM_GROUP_ALLOWED_USERS"] == "888" + + +def test_dm_allow_from_is_enforced_by_gateway_authorization_not_trigger_gate(): + adapter = _make_adapter(allow_from=["111", "222"]) + + assert adapter._should_process_message(_dm_message("hello", from_user_id=111)) is True + assert adapter._should_process_message(_dm_message("hello", from_user_id=333)) is True + + +def 
test_group_allow_from_is_enforced_by_gateway_authorization_not_trigger_gate(): + adapter = _make_adapter(group_allow_from=["111"]) + + assert adapter._should_process_message(_group_message("hello", from_user_id=333)) is True + + def test_config_bridges_telegram_ignored_threads(monkeypatch, tmp_path): hermes_home = tmp_path / ".hermes" hermes_home.mkdir() diff --git a/tests/gateway/test_tts_media_routing.py b/tests/gateway/test_tts_media_routing.py new file mode 100644 index 0000000000..0ef37deb3e --- /dev/null +++ b/tests/gateway/test_tts_media_routing.py @@ -0,0 +1,195 @@ +""" +Tests for cross-platform audio/voice media routing. + +These tests pin the expected delivery path for audio media files across +Telegram (where Bot-API sendAudio only accepts MP3/M4A and .ogg/.opus +only renders as a voice bubble when explicitly flagged) and via +``GatewayRunner._deliver_media_from_response``. +""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult +from gateway.run import GatewayRunner +from gateway.session import SessionSource, build_session_key + + +class _MediaRoutingAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM) + + async def connect(self): + return True + + async def disconnect(self): + pass + + async def send(self, chat_id, content=None, **kwargs): + return SendResult(success=True, message_id="text") + + async def get_chat_info(self, chat_id): + return {"id": chat_id, "type": "dm"} + + +def _event(thread_id=None): + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="chat-1", + chat_type="dm", + thread_id=thread_id, + ) + return MessageEvent( + text="make speech", + message_type=MessageType.TEXT, + source=source, + message_id="msg-1", + ) + + +@pytest.mark.asyncio +async def 
test_base_adapter_routes_telegram_flac_media_tag_to_document_sender(): + adapter = _MediaRoutingAdapter() + event = _event() + adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.flac") + adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice")) + adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc")) + + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.send_document.assert_awaited_once_with( + chat_id="chat-1", + file_path="/tmp/speech.flac", + metadata=None, + ) + adapter.send_voice.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender(): + adapter = _MediaRoutingAdapter() + event = _event() + adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.ogg") + adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice")) + adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc")) + + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.send_document.assert_awaited_once_with( + chat_id="chat-1", + file_path="/tmp/speech.ogg", + metadata=None, + ) + adapter.send_voice.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender(): + adapter = _MediaRoutingAdapter() + event = _event() + adapter._message_handler = AsyncMock( + return_value="[[audio_as_voice]]\nMEDIA:/tmp/speech.ogg" + ) + adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice")) + adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc")) + + await adapter._process_message_background(event, build_session_key(event.source)) + + adapter.send_voice.assert_awaited_once_with( + chat_id="chat-1", + audio_path="/tmp/speech.ogg", + metadata=None, + ) + 
adapter.send_document.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender(): + event = _event(thread_id="topic-1") + adapter = SimpleNamespace( + name="test", + extract_media=BasePlatformAdapter.extract_media, + extract_images=BasePlatformAdapter.extract_images, + extract_local_files=BasePlatformAdapter.extract_local_files, + send_voice=AsyncMock(return_value=SendResult(success=True, message_id="voice")), + send_document=AsyncMock(return_value=SendResult(success=True, message_id="doc")), + send_image_file=AsyncMock(return_value=SendResult(success=True, message_id="image")), + send_video=AsyncMock(return_value=SendResult(success=True, message_id="video")), + ) + + await GatewayRunner._deliver_media_from_response( + object(), + "MEDIA:/tmp/speech.flac", + event, + adapter, + ) + + adapter.send_document.assert_awaited_once_with( + chat_id="chat-1", + file_path="/tmp/speech.flac", + metadata={"thread_id": "topic-1"}, + ) + adapter.send_voice.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender(): + event = _event(thread_id="topic-1") + adapter = SimpleNamespace( + name="test", + extract_media=BasePlatformAdapter.extract_media, + extract_images=BasePlatformAdapter.extract_images, + extract_local_files=BasePlatformAdapter.extract_local_files, + send_voice=AsyncMock(return_value=SendResult(success=True, message_id="voice")), + send_document=AsyncMock(return_value=SendResult(success=True, message_id="doc")), + send_image_file=AsyncMock(return_value=SendResult(success=True, message_id="image")), + send_video=AsyncMock(return_value=SendResult(success=True, message_id="video")), + ) + + await GatewayRunner._deliver_media_from_response( + object(), + "MEDIA:/tmp/speech.ogg", + event, + adapter, + ) + + adapter.send_document.assert_awaited_once_with( + chat_id="chat-1", + file_path="/tmp/speech.ogg", + 
metadata={"thread_id": "topic-1"}, + ) + adapter.send_voice.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(): + """MP3 audio on Telegram must go through send_voice (which routes to + sendAudio internally); Telegram accepts MP3 for the audio player.""" + event = _event(thread_id="topic-1") + adapter = SimpleNamespace( + name="test", + extract_media=BasePlatformAdapter.extract_media, + extract_images=BasePlatformAdapter.extract_images, + extract_local_files=BasePlatformAdapter.extract_local_files, + send_voice=AsyncMock(return_value=SendResult(success=True, message_id="voice")), + send_document=AsyncMock(return_value=SendResult(success=True, message_id="doc")), + send_image_file=AsyncMock(return_value=SendResult(success=True, message_id="image")), + send_video=AsyncMock(return_value=SendResult(success=True, message_id="video")), + ) + + await GatewayRunner._deliver_media_from_response( + object(), + "MEDIA:/tmp/speech.mp3", + event, + adapter, + ) + + adapter.send_voice.assert_awaited_once_with( + chat_id="chat-1", + audio_path="/tmp/speech.mp3", + metadata={"thread_id": "topic-1"}, + ) + adapter.send_document.assert_not_awaited() diff --git a/tests/gateway/test_unauthorized_dm_behavior.py b/tests/gateway/test_unauthorized_dm_behavior.py index 9571f3f4e4..bedd3a1f69 100644 --- a/tests/gateway/test_unauthorized_dm_behavior.py +++ b/tests/gateway/test_unauthorized_dm_behavior.py @@ -16,6 +16,8 @@ def _clear_auth_env(monkeypatch) -> None: "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", "SIGNAL_ALLOWED_USERS", + "SIGNAL_GROUP_ALLOWED_USERS", + "TELEGRAM_GROUP_ALLOWED_CHATS", "EMAIL_ALLOWED_USERS", "SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS", @@ -178,7 +180,109 @@ def test_qq_group_allowlist_does_not_authorize_other_groups(monkeypatch): assert runner._is_user_authorized(source) is False -def test_telegram_group_allowlist_authorizes_forum_chat_without_user_allowlist(monkeypatch): +def 
test_telegram_group_user_allowlist_authorizes_forum_sender_without_dm_allowlist(monkeypatch): + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="999", + chat_id="-1001878443972", + user_name="tester", + chat_type="forum", + ) + + assert runner._is_user_authorized(source) is True + + +def test_telegram_group_user_allowlist_rejects_other_senders(monkeypatch): + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="123", + chat_id="-1001878443972", + user_name="tester", + chat_type="group", + ) + + assert runner._is_user_authorized(source) is False + + +def test_telegram_group_user_allowlist_wildcard_authorizes_any_sender(monkeypatch): + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "*") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="123", + chat_id="-1001878443972", + user_name="tester", + chat_type="group", + ) + + assert runner._is_user_authorized(source) is True + + +def test_telegram_group_user_allowlist_does_not_authorize_dms(monkeypatch): + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="999", + 
chat_id="999", + user_name="tester", + chat_type="dm", + ) + + assert runner._is_user_authorized(source) is False + + +def test_telegram_group_chat_allowlist_authorizes_group_chat_without_user_allowlist(monkeypatch): + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_CHATS", "-1001878443972") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="999", + chat_id="-1001878443972", + user_name="tester", + chat_type="forum", + ) + + assert runner._is_user_authorized(source) is True + + +def test_telegram_group_users_legacy_chat_ids_still_authorize(monkeypatch): + """Backward-compat: PR #15027 shipped TELEGRAM_GROUP_ALLOWED_USERS as a + chat-ID allowlist. PR #17686 renamed it to sender IDs and added + TELEGRAM_GROUP_ALLOWED_CHATS. Users on the old guidance must keep working: + chat-ID-shaped values (starting with "-") in the _USERS var are honored as + chat IDs with a deprecation warning. 
+ """ _clear_auth_env(monkeypatch) monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "-1001878443972") @@ -198,6 +302,58 @@ def test_telegram_group_allowlist_authorizes_forum_chat_without_user_allowlist(m assert runner._is_user_authorized(source) is True +def test_telegram_group_users_legacy_does_not_cross_chats(monkeypatch): + """Legacy chat-ID value only authorizes the listed chat, not any group.""" + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "-1001878443972") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + + source = SessionSource( + platform=Platform.TELEGRAM, + user_id="999", + chat_id="-1009999999999", + user_name="tester", + chat_type="group", + ) + + assert runner._is_user_authorized(source) is False + + +def test_telegram_group_users_mixed_sender_and_legacy_chat(monkeypatch): + """Mixed values: positive user ID gates senders; negative chat ID gates chat.""" + _clear_auth_env(monkeypatch) + monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_USERS", "999,-1001878443972") + + runner, _adapter = _make_runner( + Platform.TELEGRAM, + GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}), + ) + + # Legacy chat ID path: any sender in the listed chat is authorized + legacy_chat_source = SessionSource( + platform=Platform.TELEGRAM, + user_id="123", + chat_id="-1001878443972", + user_name="tester", + chat_type="group", + ) + assert runner._is_user_authorized(legacy_chat_source) is True + + # Sender path: listed sender user ID authorized in any group + sender_source = SessionSource( + platform=Platform.TELEGRAM, + user_id="999", + chat_id="-1009999999999", + user_name="tester", + chat_type="group", + ) + assert runner._is_user_authorized(sender_source) is True + + @pytest.mark.asyncio async def test_unauthorized_dm_pairs_by_default(monkeypatch): _clear_auth_env(monkeypatch) diff --git 
a/tests/gateway/test_weixin.py b/tests/gateway/test_weixin.py index 3a377effbd..506936f711 100644 --- a/tests/gateway/test_weixin.py +++ b/tests/gateway/test_weixin.py @@ -758,3 +758,33 @@ class TestWeixinVoiceSending: assert voice_item["encode_type"] == 6 assert voice_item["sample_rate"] == 24000 assert voice_item["bits_per_sample"] == 16 + + +class TestIsStaleSessionRet: + """Regression test for #17228: distinguish stale-session ret=-2 from rate-limit ret=-2.""" + + def test_ret_minus_2_with_unknown_error_is_stale(self): + assert weixin._is_stale_session_ret(-2, None, "unknown error") is True + + def test_errcode_minus_2_with_unknown_error_is_stale(self): + assert weixin._is_stale_session_ret(None, -2, "unknown error") is True + + def test_unknown_error_case_insensitive(self): + assert weixin._is_stale_session_ret(-2, None, "Unknown Error") is True + + def test_ret_minus_2_with_freq_limit_is_not_stale(self): + # Genuine rate limit — must NOT be treated as stale session. + assert weixin._is_stale_session_ret(-2, None, "freq limit") is False + + def test_ret_minus_2_with_no_errmsg_is_not_stale(self): + assert weixin._is_stale_session_ret(-2, None, None) is False + assert weixin._is_stale_session_ret(-2, None, "") is False + + def test_errcode_minus_14_is_not_matched_here(self): + # -14 is handled by the separate SESSION_EXPIRED_ERRCODE path; the + # helper only disambiguates -2 from a genuine rate limit. 
+ assert weixin._is_stale_session_ret(-14, None, "session expired") is False + + def test_success_codes_are_not_stale(self): + assert weixin._is_stale_session_ret(0, 0, "") is False + assert weixin._is_stale_session_ret(None, None, "unknown error") is False diff --git a/tests/hermes_cli/test_api_key_providers.py b/tests/hermes_cli/test_api_key_providers.py index 530075238f..291b8b70d4 100644 --- a/tests/hermes_cli/test_api_key_providers.py +++ b/tests/hermes_cli/test_api_key_providers.py @@ -1097,3 +1097,63 @@ class TestHuggingFaceModels: from hermes_cli.models import _PROVIDER_LABELS assert "huggingface" in _PROVIDER_LABELS assert _PROVIDER_LABELS["huggingface"] == "Hugging Face" + + +# ============================================================================= +# MiniMax OAuth provider tests (added by feat/minimax-oauth-provider) +# ============================================================================= + +class TestMinimaxOAuthProvider: + """Tests for the minimax-oauth OAuth provider.""" + + def test_minimax_oauth_in_provider_registry(self): + assert "minimax-oauth" in PROVIDER_REGISTRY + pconfig = PROVIDER_REGISTRY["minimax-oauth"] + assert pconfig.auth_type == "oauth_minimax" + assert pconfig.id == "minimax-oauth" + + def test_minimax_oauth_has_correct_endpoints(self): + from hermes_cli.auth import ( + MINIMAX_OAUTH_GLOBAL_BASE, + MINIMAX_OAUTH_GLOBAL_INFERENCE, + MINIMAX_OAUTH_CN_BASE, + MINIMAX_OAUTH_CN_INFERENCE, + ) + pconfig = PROVIDER_REGISTRY["minimax-oauth"] + assert pconfig.portal_base_url == MINIMAX_OAUTH_GLOBAL_BASE + assert pconfig.inference_base_url == MINIMAX_OAUTH_GLOBAL_INFERENCE + assert pconfig.extra["cn_portal_base_url"] == MINIMAX_OAUTH_CN_BASE + assert pconfig.extra["cn_inference_base_url"] == MINIMAX_OAUTH_CN_INFERENCE + + def test_minimax_oauth_alias_resolves_portal(self): + result = resolve_provider("minimax-portal") + assert result == "minimax-oauth" + + def test_minimax_oauth_alias_resolves_global(self): + result = 
resolve_provider("minimax-global") + assert result == "minimax-oauth" + + def test_minimax_oauth_alias_resolves_underscore(self): + result = resolve_provider("minimax_oauth") + assert result == "minimax-oauth" + + def test_minimax_oauth_listed_in_canonical_providers(self): + from hermes_cli.models import CANONICAL_PROVIDERS + slugs = [p.slug for p in CANONICAL_PROVIDERS] + assert "minimax-oauth" in slugs + + def test_minimax_oauth_models_alias_in_models_py(self): + from hermes_cli.models import _PROVIDER_ALIASES + assert _PROVIDER_ALIASES.get("minimax-portal") == "minimax-oauth" + assert _PROVIDER_ALIASES.get("minimax-global") == "minimax-oauth" + assert _PROVIDER_ALIASES.get("minimax_oauth") == "minimax-oauth" + + def test_minimax_oauth_has_models(self): + from hermes_cli.models import _PROVIDER_MODELS + models = _PROVIDER_MODELS.get("minimax-oauth", []) + assert len(models) >= 1 + + def test_minimax_oauth_aux_model_registered(self): + from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS + assert "minimax-oauth" in _API_KEY_PROVIDER_AUX_MODELS + assert _API_KEY_PROVIDER_AUX_MODELS["minimax-oauth"] # non-empty diff --git a/tests/hermes_cli/test_auth_commands.py b/tests/hermes_cli/test_auth_commands.py index 23602c9f01..824d0608c0 100644 --- a/tests/hermes_cli/test_auth_commands.py +++ b/tests/hermes_cli/test_auth_commands.py @@ -1446,23 +1446,36 @@ def test_seed_custom_pool_respects_config_suppression(tmp_path, monkeypatch): def test_credential_sources_registry_has_expected_steps(): """Sanity check — the registry contains the expected RemovalSteps. - Guards against accidentally dropping a step during future refactors. - If you add a new credential source, add it to the expected set below. + Adding a new credential source is routine, so this is a structural + invariant check (every step has a description, every step is unique, + core steps are present) rather than a frozen snapshot. 
Frozen + snapshots of catalog-like data violate the AGENTS.md "don't write + change-detector tests" rule — they break every time someone adds a + provider. """ from agent.credential_sources import _REGISTRY - descriptions = {step.description for step in _REGISTRY} - expected = { + descriptions = [step.description for step in _REGISTRY] + # No empty descriptions, no duplicates. + assert all(d for d in descriptions), "Every removal step must have a description" + assert len(descriptions) == len(set(descriptions)), ( + f"Registry has duplicate step descriptions: {descriptions}" + ) + # Core steps must be present — these are the ones the rest of the code + # assumes exist. When deliberately dropping one, update this list. + required = { "gh auth token / COPILOT_GITHUB_TOKEN / GH_TOKEN", "Any env-seeded credential (XAI_API_KEY, DEEPSEEK_API_KEY, etc.)", "~/.claude/.credentials.json", "~/.hermes/.anthropic_oauth.json", "auth.json providers.nous", "auth.json providers.openai-codex + ~/.codex/auth.json", + "auth.json providers.minimax-oauth", "~/.qwen/oauth_creds.json", "Custom provider config.yaml api_key field", } - assert descriptions == expected, f"Registry mismatch. 
Got: {descriptions}" + missing = required - set(descriptions) + assert not missing, f"Registry missing required steps: {missing}" def test_credential_sources_find_step_returns_none_for_manual(): diff --git a/tests/hermes_cli/test_claw.py b/tests/hermes_cli/test_claw.py index a613b37023..96817320a0 100644 --- a/tests/hermes_cli/test_claw.py +++ b/tests/hermes_cli/test_claw.py @@ -526,6 +526,11 @@ class TestCmdMigrate: class TestCmdCleanup: """Test the cleanup command handler.""" + @pytest.fixture(autouse=True) + def _mock_openclaw_running(self): + with patch.object(claw_mod, "_detect_openclaw_processes", return_value=[]): + yield + def test_no_dirs_found(self, tmp_path, capsys): args = Namespace(source=None, dry_run=False, yes=False) with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[]): diff --git a/tests/hermes_cli/test_config_env_expansion.py b/tests/hermes_cli/test_config_env_expansion.py index 860129ce81..4de3480f73 100644 --- a/tests/hermes_cli/test_config_env_expansion.py +++ b/tests/hermes_cli/test_config_env_expansion.py @@ -72,7 +72,10 @@ class TestLoadConfigExpansion: monkeypatch.setenv("GOOGLE_API_KEY", "gsk-test-key") monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "1234567:ABC-token") - monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file) + # Patch the imported function's own globals. Other tests may reload + # hermes_cli.config, making string-target monkeypatches hit a different + # module object than this collection-time imported load_config(). 
+ monkeypatch.setitem(load_config.__globals__, "get_config_path", lambda: config_file) config = load_config() @@ -86,7 +89,7 @@ class TestLoadConfigExpansion: config_file.write_text(config_yaml) monkeypatch.delenv("NOT_SET_XYZ_123", raising=False) - monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file) + monkeypatch.setitem(load_config.__globals__, "get_config_path", lambda: config_file) config = load_config() diff --git a/tests/hermes_cli/test_container_aware_cli.py b/tests/hermes_cli/test_container_aware_cli.py index 4422df845d..3291fc7cf5 100644 --- a/tests/hermes_cli/test_container_aware_cli.py +++ b/tests/hermes_cli/test_container_aware_cli.py @@ -105,7 +105,7 @@ def test_get_container_exec_info_defaults(): ) with patch("hermes_constants.is_container", return_value=False), \ - patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \ + patch.dict(get_container_exec_info.__globals__, {"get_hermes_home": lambda: hermes_home}), \ patch.dict(os.environ, {}, clear=False): os.environ.pop("HERMES_DEV", None) info = get_container_exec_info() diff --git a/tests/hermes_cli/test_dashboard_browser_safe_imports.py b/tests/hermes_cli/test_dashboard_browser_safe_imports.py index 4c2293503e..05f3a33bc1 100644 --- a/tests/hermes_cli/test_dashboard_browser_safe_imports.py +++ b/tests/hermes_cli/test_dashboard_browser_safe_imports.py @@ -7,9 +7,10 @@ WEB_SRC = Path(__file__).resolve().parents[2] / "web" / "src" def test_dashboard_does_not_import_nous_ui_root_barrel(): offenders = [] - for path in WEB_SRC.rglob("*.tsx"): - content = path.read_text(encoding="utf-8") - if 'from "@nous-research/ui"' in content or "from '@nous-research/ui'" in content: - offenders.append(str(path.relative_to(WEB_SRC))) + for ext in ("*.tsx", "*.ts"): + for path in WEB_SRC.rglob(ext): + content = path.read_text(encoding="utf-8") + if 'from "@nous-research/ui"' in content or "from '@nous-research/ui'" in content: + offenders.append(str(path.relative_to(WEB_SRC))) 
assert offenders == [] diff --git a/tests/hermes_cli/test_dashboard_lifecycle_flags.py b/tests/hermes_cli/test_dashboard_lifecycle_flags.py new file mode 100644 index 0000000000..c0c505fc33 --- /dev/null +++ b/tests/hermes_cli/test_dashboard_lifecycle_flags.py @@ -0,0 +1,181 @@ +"""Tests for ``hermes dashboard --stop`` / ``--status`` flags. + +These flags share the detection + kill path with the post-``hermes update`` +cleanup, so the heavy coverage of SIGTERM / SIGKILL / Windows taskkill lives +in ``test_update_stale_dashboard.py``. This file just verifies the flag +dispatch: argparse wiring, no-op when nothing is running, and correct +exit codes. +""" + +from __future__ import annotations + +import argparse +import sys +from unittest.mock import patch, MagicMock + +import pytest + +from hermes_cli.main import cmd_dashboard, _report_dashboard_status + + +def _ns(**kw): + """Build an argparse.Namespace with dashboard defaults plus overrides.""" + defaults = dict( + port=9119, host="127.0.0.1", no_open=False, insecure=False, + tui=False, stop=False, status=False, + ) + defaults.update(kw) + return argparse.Namespace(**defaults) + + +class TestDashboardStatus: + def test_status_no_processes(self, capsys): + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(status=True)) + assert exc.value.code == 0 + out = capsys.readouterr().out + assert "No hermes dashboard processes running" in out + + def test_status_with_processes(self, capsys): + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[12345, 12346]), \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(status=True)) + # Status is informational — always exits 0. 
+ assert exc.value.code == 0 + out = capsys.readouterr().out + assert "2 hermes dashboard process(es) running" in out + assert "PID 12345" in out + assert "PID 12346" in out + + def test_status_does_not_try_to_import_fastapi(self): + """`--status` must not require dashboard runtime deps — it's a + process-table scan only. We prove this by making fastapi import + fail and confirming --status still succeeds.""" + orig_import = __import__ + def fake_import(name, *a, **kw): + if name == "fastapi": + raise ImportError("fastapi missing") + return orig_import(name, *a, **kw) + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + patch("builtins.__import__", side_effect=fake_import), \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(status=True)) + assert exc.value.code == 0 + + +class TestDashboardStop: + def test_stop_when_nothing_running(self, capsys): + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(stop=True)) + assert exc.value.code == 0 + out = capsys.readouterr().out + assert "No hermes dashboard processes running" in out + + def test_stop_kills_and_exits_zero_when_all_killed(self, capsys): + """After the kill, if the second scan returns empty we exit 0.""" + # First scan: finds two processes. Second (verification) scan: empty. + scans = iter([[12345, 12346], []]) + with patch("hermes_cli.main._find_stale_dashboard_pids", + side_effect=lambda: next(scans)), \ + patch("hermes_cli.main._kill_stale_dashboard_processes") as mock_kill, \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(stop=True)) + mock_kill.assert_called_once() + # --stop should pass a reason so the output doesn't say "running + # backend no longer matches the updated frontend" (that wording is + # for the post-`hermes update` path). 
+ kwargs = mock_kill.call_args.kwargs + assert "reason" in kwargs + assert "stop" in kwargs["reason"].lower() + assert exc.value.code == 0 + + def test_stop_exits_nonzero_if_kill_leaves_survivors(self): + """If the second scan still finds PIDs, we exit 1 so scripts can + detect that the stop didn't succeed (e.g. permission denied).""" + scans = iter([[12345], [12345]]) # both scans find the same PID + with patch("hermes_cli.main._find_stale_dashboard_pids", + side_effect=lambda: next(scans)), \ + patch("hermes_cli.main._kill_stale_dashboard_processes"), \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(stop=True)) + assert exc.value.code == 1 + + def test_stop_does_not_try_to_import_fastapi(self): + """Like --status, --stop must work without dashboard runtime deps.""" + orig_import = __import__ + def fake_import(name, *a, **kw): + if name == "fastapi": + raise ImportError("fastapi missing") + return orig_import(name, *a, **kw) + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + patch("builtins.__import__", side_effect=fake_import), \ + pytest.raises(SystemExit) as exc: + cmd_dashboard(_ns(stop=True)) + assert exc.value.code == 0 + + +class TestLifecycleFlagsTakePrecedence: + """If both --stop and --status are set, --status wins (it's listed + first in cmd_dashboard). Neither is allowed to fall through to the + server-start path, which is the critical safety property — a user + who typed ``hermes dashboard --stop`` must not end up ALSO starting + a new server.""" + + def test_status_wins_over_stop(self, capsys): + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + patch("hermes_cli.main._kill_stale_dashboard_processes") as mock_kill, \ + pytest.raises(SystemExit): + cmd_dashboard(_ns(status=True, stop=True)) + # Kill path must NOT run when --status is also set. 
+ mock_kill.assert_not_called() + + def test_stop_does_not_fall_through_to_server_start(self): + """Covers the worst-case regression: if --stop ever stopped exiting + early, the user would start the dashboard they just asked to stop.""" + called = {"start": False} + def fake_start_server(**kw): + called["start"] = True + + # Provide a fake web_server module so the import doesn't matter. + fake_ws = MagicMock() + fake_ws.start_server = fake_start_server + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + patch.dict(sys.modules, {"hermes_cli.web_server": fake_ws}), \ + pytest.raises(SystemExit): + cmd_dashboard(_ns(stop=True)) + assert called["start"] is False + + +class TestArgparseWiring: + """Confirm the flags are exposed via the real argparse tree so + ``hermes dashboard --stop`` / ``--status`` actually parse.""" + + def test_flags_are_registered(self): + from hermes_cli.main import main as _cli_main # noqa: F401 + # Rebuild the argparse tree by re-running the section of main() + # that builds it. Cheapest way: introspect via --help on the + # already-built parser would require refactoring; instead we + # parse the flags directly via a minimal replay. + import importlib + mod = importlib.import_module("hermes_cli.main") + # Find the dashboard_parser instance by running build logic would + # be too invasive. Instead parse args as if via the CLI by + # intercepting parse_args. This is overkill for a smoke test — + # we just want to know the flags don't KeyError. 
+ with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[]), \ + pytest.raises(SystemExit) as exc: + mod.cmd_dashboard(_ns(status=True)) + assert exc.value.code == 0 diff --git a/tests/hermes_cli/test_doctor.py b/tests/hermes_cli/test_doctor.py index c696e105a5..5fafcb81f6 100644 --- a/tests/hermes_cli/test_doctor.py +++ b/tests/hermes_cli/test_doctor.py @@ -161,6 +161,38 @@ def test_check_gateway_service_linger_skips_when_service_not_installed(monkeypat assert issues == [] +def test_doctor_reports_vercel_backend_diagnostics(monkeypatch, tmp_path): + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_VERCEL_RUNTIME", "python3.13") + monkeypatch.setenv("TERMINAL_CONTAINER_DISK", "2048") + monkeypatch.setenv("VERCEL_TOKEN", "super-secret-value") + monkeypatch.delenv("VERCEL_PROJECT_ID", raising=False) + monkeypatch.setenv("VERCEL_TEAM_ID", "team") + monkeypatch.setattr(doctor_mod.importlib.util, "find_spec", lambda name: object() if name == "vercel" else None) + + fake_model_tools = types.SimpleNamespace( + check_tool_availability=lambda *a, **kw: ([], []), + TOOLSET_REQUIREMENTS={}, + ) + monkeypatch.setitem(sys.modules, "model_tools", fake_model_tools) + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + doctor_mod.run_doctor(Namespace(fix=False)) + + out = buf.getvalue() + assert "Vercel runtime" in out + assert "python3.13" in out + assert "Vercel custom disk unsupported" in out + assert "Vercel auth incomplete" in out + assert "VERCEL_PROJECT_ID" in out + assert "Vercel auth mode: incomplete access token" in out + assert "Vercel auth present env: VERCEL_TOKEN, VERCEL_TEAM_ID" in out + assert "Vercel auth missing env: VERCEL_PROJECT_ID" in out + assert "super-secret-value" not in out + assert "snapshot filesystem only" in out + + # ── Memory provider section (doctor should only check the *active* provider) ── diff --git a/tests/hermes_cli/test_gateway_service.py 
b/tests/hermes_cli/test_gateway_service.py index bd429bff2b..f2bfa8b870 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -14,6 +14,26 @@ from gateway.restart import ( ) +class TestUserSystemdPrivateSocketPreflight: + def test_preflight_accepts_private_socket_without_dbus_bus(self, monkeypatch): + monkeypatch.setattr(gateway_cli, "_ensure_user_systemd_env", lambda: None) + monkeypatch.setattr(gateway_cli, "_user_dbus_socket_path", lambda: Path("/tmp/missing-bus")) + monkeypatch.setattr(gateway_cli, "_user_systemd_private_socket_path", lambda: Path("/tmp/private-socket")) + monkeypatch.setattr(Path, "exists", lambda self: str(self) == "/tmp/private-socket") + + gateway_cli._preflight_user_systemd(auto_enable_linger=False) + + def test_wait_for_user_dbus_socket_accepts_private_socket(self, monkeypatch): + calls = [] + monkeypatch.setattr(gateway_cli, "_ensure_user_systemd_env", lambda: calls.append("env")) + monkeypatch.setattr(gateway_cli, "_user_dbus_socket_path", lambda: Path("/tmp/missing-bus")) + monkeypatch.setattr(gateway_cli, "_user_systemd_private_socket_path", lambda: Path("/tmp/private-socket")) + monkeypatch.setattr(Path, "exists", lambda self: str(self) == "/tmp/private-socket") + + assert gateway_cli._wait_for_user_dbus_socket(timeout=0.1) is True + assert calls == ["env"] + + class TestSystemdServiceRefresh: def test_systemd_install_repairs_outdated_unit_without_force(self, tmp_path, monkeypatch): unit_path = tmp_path / "hermes-gateway.service" @@ -235,7 +255,8 @@ class TestLaunchdServiceRecovery: target = f"{domain}/{label}" def fake_run(cmd, check=False, **kwargs): - calls.append(cmd) + if cmd and cmd[0] == "launchctl": + calls.append(cmd) if cmd == ["launchctl", "kickstart", target] and calls.count(cmd) == 1: raise gateway_cli.subprocess.CalledProcessError(3, cmd, stderr="Could not find service") return SimpleNamespace(returncode=0, stdout="", stderr="") @@ -262,7 +283,8 @@ class 
TestLaunchdServiceRecovery: target = f"{domain}/{label}" def fake_run(cmd, check=False, **kwargs): - calls.append(cmd) + if cmd and cmd[0] == "launchctl": + calls.append(cmd) if cmd == ["launchctl", "kickstart", target] and calls.count(cmd) == 1: raise gateway_cli.subprocess.CalledProcessError(113, cmd, stderr="Could not find service") return SimpleNamespace(returncode=0, stdout="", stderr="") @@ -1105,6 +1127,10 @@ class TestPreflightUserSystemd: gateway_cli, "_user_dbus_socket_path", lambda: type("P", (), {"exists": lambda self: True})(), ) + monkeypatch.setattr( + gateway_cli, "_user_systemd_private_socket_path", + lambda: type("P", (), {"exists": lambda self: False})(), + ) # Should not raise, no subprocess calls needed. gateway_cli._preflight_user_systemd() @@ -1114,6 +1140,10 @@ class TestPreflightUserSystemd: gateway_cli, "_user_dbus_socket_path", lambda: type("P", (), {"exists": lambda self: False})(), ) + monkeypatch.setattr( + gateway_cli, "_user_systemd_private_socket_path", + lambda: type("P", (), {"exists": lambda self: False})(), + ) monkeypatch.setattr( gateway_cli, "get_systemd_linger_status", lambda: (False, ""), ) @@ -1142,6 +1172,10 @@ class TestPreflightUserSystemd: gateway_cli, "_user_dbus_socket_path", lambda: type("P", (), {"exists": lambda self: False})(), ) + monkeypatch.setattr( + gateway_cli, "_user_systemd_private_socket_path", + lambda: type("P", (), {"exists": lambda self: False})(), + ) monkeypatch.setattr( gateway_cli, "get_systemd_linger_status", lambda: (None, "loginctl not found"), @@ -1159,6 +1193,10 @@ class TestPreflightUserSystemd: gateway_cli, "_user_dbus_socket_path", lambda: type("P", (), {"exists": lambda self: False})(), ) + monkeypatch.setattr( + gateway_cli, "_user_systemd_private_socket_path", + lambda: type("P", (), {"exists": lambda self: False})(), + ) monkeypatch.setattr( gateway_cli, "get_systemd_linger_status", lambda: (True, ""), ) @@ -1177,6 +1215,10 @@ class TestPreflightUserSystemd: gateway_cli, 
"_user_dbus_socket_path", lambda: type("P", (), {"exists": lambda self: False})(), ) + monkeypatch.setattr( + gateway_cli, "_user_systemd_private_socket_path", + lambda: type("P", (), {"exists": lambda self: False})(), + ) monkeypatch.setattr( gateway_cli, "get_systemd_linger_status", lambda: (False, ""), ) diff --git a/tests/hermes_cli/test_ignore_user_config_flags.py b/tests/hermes_cli/test_ignore_user_config_flags.py index 3d5336cfca..6073877932 100644 --- a/tests/hermes_cli/test_ignore_user_config_flags.py +++ b/tests/hermes_cli/test_ignore_user_config_flags.py @@ -224,22 +224,21 @@ class TestArgparseFlagsRegistered: assert args.ignore_rules is True def test_main_py_registers_both_flags(self): - """E2E: the real hermes_cli/main.py parser accepts both flags. + """E2E: the real hermes parser accepts both flags.""" + from hermes_cli._parser import build_top_level_parser - We invoke the real argparse tree builder from hermes_cli.main. - """ - import hermes_cli.main as hm + parser, _subparsers, chat_parser = build_top_level_parser() + + top_dests = {a.dest for a in parser._actions} + chat_dests = {a.dest for a in chat_parser._actions} + assert "ignore_user_config" in top_dests + assert "ignore_rules" in top_dests + assert "ignore_user_config" in chat_dests + assert "ignore_rules" in chat_dests - # hm has a helper that builds the argparse tree inside main(). - # We can extract it by catching the SystemExit on --help. - # Simpler: just grep the source for the flag strings. Both approaches - # are brittle; we use a combined test. 
- import inspect - src = inspect.getsource(hm) - assert '"--ignore-user-config"' in src, \ - "chat subparser must register --ignore-user-config" - assert '"--ignore-rules"' in src, \ - "chat subparser must register --ignore-rules" # And the cmd_chat env-var wiring must be present + import inspect + import hermes_cli.main as hm + src = inspect.getsource(hm) assert "HERMES_IGNORE_USER_CONFIG" in src assert "HERMES_IGNORE_RULES" in src diff --git a/tests/hermes_cli/test_mcp_reload_confirm_gate.py b/tests/hermes_cli/test_mcp_reload_confirm_gate.py new file mode 100644 index 0000000000..871f46fe7e --- /dev/null +++ b/tests/hermes_cli/test_mcp_reload_confirm_gate.py @@ -0,0 +1,91 @@ +"""Tests for the approvals.mcp_reload_confirm config gate. + +When the user runs /reload-mcp, the MCP tool set is rebuilt which +invalidates the provider prompt cache for the active session. That's +expensive on long-context / high-reasoning models. The config gate +adds a three-option confirmation (Approve Once / Always Approve / +Cancel); "Always Approve" flips this key to false so subsequent reloads +run silently. +""" + +from __future__ import annotations + +from copy import deepcopy + +from hermes_cli.config import DEFAULT_CONFIG + + +class TestMcpReloadConfirmDefault: + def test_default_config_has_the_key(self): + approvals = DEFAULT_CONFIG.get("approvals") + assert isinstance(approvals, dict) + assert "mcp_reload_confirm" in approvals + + def test_default_is_true(self): + # New installs confirm by default — this is the safe behavior. + assert DEFAULT_CONFIG["approvals"]["mcp_reload_confirm"] is True + + def test_shape_matches_other_approval_keys(self): + # Same flat dict level as `mode` / `timeout` / `cron_mode`. 
+ approvals = DEFAULT_CONFIG["approvals"] + assert isinstance(approvals.get("mode"), str) + assert isinstance(approvals.get("timeout"), int) + assert isinstance(approvals.get("cron_mode"), str) + assert isinstance(approvals.get("mcp_reload_confirm"), bool) + + +class TestUserConfigMerge: + """If a user has a pre-existing config without this key, load_config + should fill it in from DEFAULT_CONFIG (deep merge preserves keys the + user didn't override). + """ + + def test_existing_user_config_without_key_gets_default(self, tmp_path, monkeypatch): + import yaml + + # Simulate a legacy user config without the new key. + home = tmp_path / ".hermes" + home.mkdir() + cfg_path = home / "config.yaml" + legacy = { + "approvals": {"mode": "manual", "timeout": 60, "cron_mode": "deny"}, + } + cfg_path.write_text(yaml.safe_dump(legacy)) + + monkeypatch.setenv("HERMES_HOME", str(home)) + # Force a fresh reimport of config.py so the HERMES_HOME is honored. + import importlib + import hermes_cli.config as cfg_mod + importlib.reload(cfg_mod) + + cfg = cfg_mod.load_config() + assert cfg["approvals"]["mcp_reload_confirm"] is True + + def test_existing_user_config_with_false_key_survives_merge( + self, tmp_path, monkeypatch, + ): + """A user who has clicked "Always Approve" (key=false) must keep + that setting across reloads — the default_true value must not win. 
+ """ + import yaml + + home = tmp_path / ".hermes" + home.mkdir() + cfg_path = home / "config.yaml" + user_cfg = { + "approvals": { + "mode": "manual", + "timeout": 60, + "cron_mode": "deny", + "mcp_reload_confirm": False, + }, + } + cfg_path.write_text(yaml.safe_dump(user_cfg)) + + monkeypatch.setenv("HERMES_HOME", str(home)) + import importlib + import hermes_cli.config as cfg_mod + importlib.reload(cfg_mod) + + cfg = cfg_mod.load_config() + assert cfg["approvals"]["mcp_reload_confirm"] is False diff --git a/tests/hermes_cli/test_profiles.py b/tests/hermes_cli/test_profiles.py index bcf4da244e..9177930f22 100644 --- a/tests/hermes_cli/test_profiles.py +++ b/tests/hermes_cli/test_profiles.py @@ -188,6 +188,23 @@ class TestCreateProfile: assert not (profile_dir / "gateway_state.json").exists() assert not (profile_dir / "processes.json").exists() + def test_clone_all_excludes_sibling_profiles_tree(self, profile_env): + """--clone-all from default ~/.hermes must not copy profiles/* (nested explosion).""" + tmp_path = profile_env + default_home = tmp_path / ".hermes" + profiles_root = default_home / "profiles" + profiles_root.mkdir(exist_ok=True) + (profiles_root / "other").mkdir(parents=True, exist_ok=True) + (profiles_root / "other" / "marker.txt").write_text("sibling data") + + (default_home / "memories").mkdir(exist_ok=True) + (default_home / "memories" / "note.md").write_text("remember this") + + profile_dir = create_profile("coder", clone_all=True, no_alias=True) + + assert (profile_dir / "memories" / "note.md").read_text() == "remember this" + assert not (profile_dir / "profiles").exists() + def test_clone_config_missing_files_skipped(self, profile_env): """Clone config gracefully skips files that don't exist in source.""" profile_dir = create_profile("coder", clone_config=True, no_alias=True) diff --git a/tests/hermes_cli/test_pty_bridge.py b/tests/hermes_cli/test_pty_bridge.py index cd6983b90c..054f5a8d80 100644 --- a/tests/hermes_cli/test_pty_bridge.py +++ 
b/tests/hermes_cli/test_pty_bridge.py @@ -96,10 +96,17 @@ class TestPtyBridgeIO: @skip_on_windows class TestPtyBridgeResize: def test_resize_updates_child_winsize(self): - # tput reads COLUMNS/LINES from the TTY ioctl (TIOCGWINSZ). - # Spawn a shell, resize, then ask tput for the dimensions. + # Query the TTY ioctl directly instead of using tput, which requires + # TERM and fails in GitHub Actions' non-interactive environment. + winsize_script = ( + "import fcntl, struct, termios, time; " + "time.sleep(0.1); " + "rows, cols, *_ = struct.unpack('HHHH', " + "fcntl.ioctl(0, termios.TIOCGWINSZ, b'\\0' * 8)); " + "print(cols); print(rows)" + ) bridge = PtyBridge.spawn( - ["/bin/sh", "-c", "sleep 0.1; tput cols; tput lines"], + [sys.executable, "-c", winsize_script], cols=80, rows=24, ) diff --git a/tests/hermes_cli/test_relaunch.py b/tests/hermes_cli/test_relaunch.py new file mode 100644 index 0000000000..33b3ffb4b3 --- /dev/null +++ b/tests/hermes_cli/test_relaunch.py @@ -0,0 +1,155 @@ +"""Tests for hermes_cli.relaunch — unified self-relaunch utility.""" + +import sys + +import pytest + +from hermes_cli import relaunch as relaunch_mod + + +class TestResolveHermesBin: + def test_prefers_absolute_argv0_when_executable(self, monkeypatch): + fake = "/nix/store/abc/bin/hermes" + monkeypatch.setattr(sys, "argv", [fake]) + monkeypatch.setattr(relaunch_mod.os.path, "isfile", lambda p: p == fake) + monkeypatch.setattr(relaunch_mod.os, "access", lambda p, mode: p == fake) + assert relaunch_mod.resolve_hermes_bin() == fake + + def test_resolves_relative_argv0(self, monkeypatch, tmp_path): + fake = tmp_path / "hermes" + fake.write_text("#!/bin/sh\n") + fake.chmod(0o755) + monkeypatch.setattr(sys, "argv", [str(fake.name)]) + monkeypatch.chdir(tmp_path) + # Ensure we don't accidentally match a real 'hermes' on PATH + monkeypatch.setattr(relaunch_mod.shutil, "which", lambda _name: None) + assert relaunch_mod.resolve_hermes_bin() == str(fake) + + def 
test_falls_back_to_path_which(self, monkeypatch): + monkeypatch.setattr(sys, "argv", ["-c"]) # not a real path + monkeypatch.setattr( + relaunch_mod.shutil, "which", lambda name: "/usr/bin/hermes" if name == "hermes" else None + ) + assert relaunch_mod.resolve_hermes_bin() == "/usr/bin/hermes" + + def test_returns_none_when_unresolvable(self, monkeypatch): + monkeypatch.setattr(sys, "argv", ["-c"]) + monkeypatch.setattr(relaunch_mod.shutil, "which", lambda _name: None) + assert relaunch_mod.resolve_hermes_bin() is None + + +class TestExtractInheritedFlags: + def test_extracts_tui_and_dev(self): + argv = ["--tui", "--dev", "chat"] + assert relaunch_mod._extract_inherited_flags(argv) == ["--tui", "--dev"] + + def test_extracts_profile_with_value(self): + argv = ["--profile", "work", "chat"] + assert relaunch_mod._extract_inherited_flags(argv) == ["--profile", "work"] + + def test_extracts_short_p_with_value(self): + argv = ["-p", "work"] + assert relaunch_mod._extract_inherited_flags(argv) == ["-p", "work"] + + def test_extracts_equals_form(self): + argv = ["--profile=work", "--model=anthropic/claude-sonnet-4"] + assert relaunch_mod._extract_inherited_flags(argv) == [ + "--profile=work", + "--model=anthropic/claude-sonnet-4", + ] + + def test_skips_unknown_flags(self): + argv = ["--foo", "bar", "--tui"] + assert relaunch_mod._extract_inherited_flags(argv) == ["--tui"] + + def test_does_not_consume_flag_like_value(self): + argv = ["--tui", "--resume", "abc123"] + assert relaunch_mod._extract_inherited_flags(argv) == ["--tui"] + + def test_preserves_multiple_skills(self): + argv = ["-s", "foo", "-s", "bar", "--tui"] + assert relaunch_mod._extract_inherited_flags(argv) == ["-s", "foo", "-s", "bar", "--tui"] + + +class TestInheritedFlagTable: + """Sanity-check the argparse-introspected table that drives extraction.""" + + def test_short_and_long_aliases_are_paired(self): + table = dict(relaunch_mod._INHERITED_FLAGS_TABLE) + # Each pair declared together in the parser 
shares takes_value. + for short, long_ in [ + ("-p", "--profile"), + ("-m", "--model"), + ("-s", "--skills"), + ]: + assert table[short] == table[long_], f"{short}/{long_} disagree" + + def test_store_true_flags_do_not_take_value(self): + table = dict(relaunch_mod._INHERITED_FLAGS_TABLE) + for flag in ["--tui", "--dev", "--yolo", "--ignore-user-config", "--ignore-rules"]: + assert table[flag] is False, f"{flag} should not take a value" + + def test_value_flags_take_value(self): + table = dict(relaunch_mod._INHERITED_FLAGS_TABLE) + for flag in ["--profile", "--model", "--provider", "--skills"]: + assert table[flag] is True, f"{flag} should take a value" + + def test_excluded_flags_are_not_inherited(self): + table = dict(relaunch_mod._INHERITED_FLAGS_TABLE) + # --worktree creates a new worktree per process; inheriting would + # orphan the parent's. Chat-only flags (--quiet/-Q, --verbose/-v, + # --source) can't be in argv at the existing relaunch callsites. + for flag in ["-w", "--worktree", "-Q", "--quiet", "-v", "--verbose", "--source"]: + assert flag not in table, f"{flag} should not be inherited" + + +class TestBuildRelaunchArgv: + def test_uses_bin_when_available(self, monkeypatch): + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes") + argv = relaunch_mod.build_relaunch_argv(["--resume", "abc"]) + assert argv[0] == "/usr/bin/hermes" + + def test_falls_back_to_python_module(self, monkeypatch): + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: None) + argv = relaunch_mod.build_relaunch_argv(["--resume", "abc"]) + assert argv == [sys.executable, "-m", "hermes_cli.main", "--resume", "abc"] + + def test_preserves_inherited_flags(self, monkeypatch): + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes") + original = ["--tui", "--dev", "--profile", "work", "sessions", "browse"] + argv = relaunch_mod.build_relaunch_argv(["--resume", "abc"], original_argv=original) + assert "--tui" in argv + 
assert "--dev" in argv + assert "--profile" in argv + assert "work" in argv + assert "--resume" in argv + assert "abc" in argv + # The original subcommand should not survive + assert "sessions" not in argv + assert "browse" not in argv + + def test_can_disable_preserve(self, monkeypatch): + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes") + original = ["--tui", "chat"] + argv = relaunch_mod.build_relaunch_argv( + ["--resume", "abc"], preserve_inherited=False, original_argv=original + ) + assert "--tui" not in argv + assert argv == ["/usr/bin/hermes", "--resume", "abc"] + + +class TestRelaunch: + def test_calls_execvp(self, monkeypatch): + calls = [] + + def fake_execvp(path, argv): + calls.append((path, argv)) + raise SystemExit(0) + + monkeypatch.setattr(relaunch_mod.os, "execvp", fake_execvp) + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/bin/hermes") + + with pytest.raises(SystemExit): + relaunch_mod.relaunch(["--resume", "abc"]) + + assert calls == [("/usr/bin/hermes", ["/usr/bin/hermes", "--resume", "abc"])] \ No newline at end of file diff --git a/tests/hermes_cli/test_runtime_provider_resolution.py b/tests/hermes_cli/test_runtime_provider_resolution.py index a30cbaecdc..c7adfe1482 100644 --- a/tests/hermes_cli/test_runtime_provider_resolution.py +++ b/tests/hermes_cli/test_runtime_provider_resolution.py @@ -1998,6 +1998,7 @@ class TestAzureAnthropicEnvVarHint: assert resolved["api_key"] == "fallback-works" + def test_no_key_anywhere_raises_helpful_error(self, monkeypatch): """When nothing resolves, the error message mentions key_env as an option.""" monkeypatch.delenv("AZURE_ANTHROPIC_KEY", raising=False) @@ -2168,3 +2169,67 @@ class TestTencentTokenhubRuntimeResolution: assert resolved["base_url"] == "https://explicit-proxy.example.com/v1" assert resolved["source"] == "explicit" +# --------------------------------------------------------------------------- +# minimax-oauth runtime resolution tests 
(added by feat/minimax-oauth-provider) +# --------------------------------------------------------------------------- + +def test_minimax_oauth_runtime_returns_anthropic_messages_mode(monkeypatch): + """resolve_runtime_provider for minimax-oauth must return api_mode='anthropic_messages'.""" + from hermes_cli.auth import MINIMAX_OAUTH_GLOBAL_INFERENCE + + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax-oauth") + monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "minimax-oauth"}) + monkeypatch.setattr(rp, "load_pool", lambda provider: None) + monkeypatch.setattr( + rp, + "_resolve_named_custom_runtime", + lambda **k: None, + ) + monkeypatch.setattr( + rp, + "_resolve_explicit_runtime", + lambda **k: None, + ) + + fake_creds = { + "provider": "minimax-oauth", + "api_key": "mock-access-token", + "base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE.rstrip("/"), + "source": "oauth", + } + + import hermes_cli.auth as auth_mod + monkeypatch.setattr(auth_mod, "resolve_minimax_oauth_runtime_credentials", + lambda **k: fake_creds) + + resolved = rp.resolve_runtime_provider(requested="minimax-oauth") + + assert resolved["provider"] == "minimax-oauth" + assert resolved["api_mode"] == "anthropic_messages" + assert resolved["api_key"] == "mock-access-token" + + +def test_minimax_oauth_runtime_uses_inference_base_url(monkeypatch): + """Base URL returned by resolve_runtime_provider should match the OAuth credentials.""" + from hermes_cli.auth import MINIMAX_OAUTH_CN_INFERENCE + + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax-oauth") + monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "minimax-oauth"}) + monkeypatch.setattr(rp, "load_pool", lambda provider: None) + monkeypatch.setattr(rp, "_resolve_named_custom_runtime", lambda **k: None) + monkeypatch.setattr(rp, "_resolve_explicit_runtime", lambda **k: None) + + fake_creds = { + "provider": "minimax-oauth", + "api_key": "cn-token", + "base_url": 
MINIMAX_OAUTH_CN_INFERENCE.rstrip("/"), + "source": "oauth", + } + + import hermes_cli.auth as auth_mod + monkeypatch.setattr(auth_mod, "resolve_minimax_oauth_runtime_credentials", + lambda **k: fake_creds) + + resolved = rp.resolve_runtime_provider(requested="minimax-oauth") + + assert MINIMAX_OAUTH_CN_INFERENCE.rstrip("/") in resolved["base_url"] diff --git a/tests/hermes_cli/test_set_config_value.py b/tests/hermes_cli/test_set_config_value.py index fbd71dbb53..adbd0ae1e0 100644 --- a/tests/hermes_cli/test_set_config_value.py +++ b/tests/hermes_cli/test_set_config_value.py @@ -127,6 +127,13 @@ class TestConfigYamlRouting: or "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=True" in env_content ) + def test_terminal_vercel_runtime_goes_to_config_and_env(self, _isolated_hermes_home): + set_config_value("terminal.vercel_runtime", "python3.13") + config = _read_config(_isolated_hermes_home) + env_content = _read_env(_isolated_hermes_home) + assert "vercel_runtime: python3.13" in config + assert "TERMINAL_VERCEL_RUNTIME=python3.13" in env_content + # --------------------------------------------------------------------------- # Empty / falsy values — regression tests for #4277 diff --git a/tests/hermes_cli/test_setup.py b/tests/hermes_cli/test_setup.py index 03b4068755..72adc27c0c 100644 --- a/tests/hermes_cli/test_setup.py +++ b/tests/hermes_cli/test_setup.py @@ -1,5 +1,6 @@ """Tests for setup.py configuration flows.""" import json +import os import sys import types @@ -29,6 +30,17 @@ def _clear_provider_env(monkeypatch): monkeypatch.delenv(key, raising=False) +def _clear_vercel_env(monkeypatch): + for key in ( + "TERMINAL_VERCEL_RUNTIME", + "VERCEL_OIDC_TOKEN", + "VERCEL_TOKEN", + "VERCEL_PROJECT_ID", + "VERCEL_TEAM_ID", + ): + monkeypatch.delenv(key, raising=False) + + def _stub_tts(monkeypatch): """Stub out TTS prompts so setup_model_provider doesn't block.""" monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda q, c, d=0: ( @@ -162,12 +174,13 @@ def 
test_setup_gateway_skips_service_install_when_systemctl_missing(monkeypatch, "WEBHOOK_ENABLED": "", } + import hermes_cli.gateway as gateway_mod + monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, "")) + monkeypatch.setattr(gateway_mod, "get_env_value", lambda key: env.get(key, "")) monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False) monkeypatch.setattr("platform.system", lambda: "Linux") - import hermes_cli.gateway as gateway_mod - monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False) monkeypatch.setattr(gateway_mod, "is_macos", lambda: False) monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False) @@ -200,12 +213,13 @@ def test_setup_gateway_in_container_shows_docker_guidance(monkeypatch, capsys): "WEBHOOK_ENABLED": "", } + import hermes_cli.gateway as gateway_mod + monkeypatch.setattr(setup_mod, "get_env_value", lambda key: env.get(key, "")) + monkeypatch.setattr(gateway_mod, "get_env_value", lambda key: env.get(key, "")) monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *args, **kwargs: False) monkeypatch.setattr("platform.system", lambda: "Linux") - import hermes_cli.gateway as gateway_mod - monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False) monkeypatch.setattr(gateway_mod, "is_macos", lambda: False) monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False) @@ -480,28 +494,91 @@ def test_modal_setup_persists_direct_mode_when_user_chooses_their_own_account(tm assert config["terminal"]["modal_mode"] == "direct" -def test_resolve_hermes_chat_argv_prefers_which(monkeypatch): - from hermes_cli import setup as setup_mod - - monkeypatch.setattr(setup_mod.shutil, "which", lambda name: "/usr/local/bin/hermes" if name == "hermes" else None) - - assert setup_mod._resolve_hermes_chat_argv() == ["/usr/local/bin/hermes", "chat"] - - -def test_resolve_hermes_chat_argv_falls_back_to_module(monkeypatch): - from hermes_cli import setup as 
setup_mod - - monkeypatch.setattr(setup_mod.shutil, "which", lambda _name: None) - monkeypatch.setattr(setup_mod.importlib.util, "find_spec", lambda name: object() if name == "hermes_cli" else None) - - assert setup_mod._resolve_hermes_chat_argv() == [sys.executable, "-m", "hermes_cli.main", "chat"] - - -def test_offer_launch_chat_execs_fresh_process(monkeypatch): +def test_vercel_setup_configures_access_token_auth(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + _clear_vercel_env(monkeypatch) + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "old-oidc") + monkeypatch.setitem(sys.modules, "vercel", types.ModuleType("vercel")) + config = load_config() + + def fake_prompt_choice(question, choices, default=0): + if question == "Select terminal backend:": + return 5 + raise AssertionError(f"Unexpected prompt_choice call: {question}") + + prompt_values = iter(["python3.13", "yes", "2", "4096", "token", "project", "team"]) + + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) + monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: next(prompt_values)) + + from hermes_cli.setup import setup_terminal_backend + + setup_terminal_backend(config) + + assert config["terminal"]["backend"] == "vercel_sandbox" + assert config["terminal"]["vercel_runtime"] == "python3.13" + assert config["terminal"]["container_disk"] == 51200 + assert os.environ["TERMINAL_VERCEL_RUNTIME"] == "python3.13" + assert "VERCEL_OIDC_TOKEN" not in os.environ + assert os.environ["VERCEL_TOKEN"] == "token" + assert os.environ["VERCEL_PROJECT_ID"] == "project" + assert os.environ["VERCEL_TEAM_ID"] == "team" + + +def test_vercel_setup_prefills_project_and_team_from_link_file(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + _clear_vercel_env(monkeypatch) + project_root = tmp_path / "project" + nested = project_root / "app" / "src" + nested.mkdir(parents=True) + vercel_dir = project_root / ".vercel" + vercel_dir.mkdir() + 
(vercel_dir / "project.json").write_text( + json.dumps({"projectId": "linked-project", "orgId": "linked-team"}), + encoding="utf-8", + ) + monkeypatch.chdir(nested) + monkeypatch.setitem(sys.modules, "vercel", types.ModuleType("vercel")) + config = load_config() + config["terminal"]["container_disk"] = 999 + + def fake_prompt_choice(question, choices, default=0): + if question == "Select terminal backend:": + return 5 + raise AssertionError(f"Unexpected prompt_choice call: {question}") + + prompt_values = iter(["node24", "no", "1", "5120", "token", "", ""]) + defaults = {} + + def fake_prompt(message, default="", **kwargs): + defaults[message] = default + value = next(prompt_values) + return value or default + + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) + monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt) + + from hermes_cli.setup import setup_terminal_backend + + setup_terminal_backend(config) + + assert config["terminal"]["backend"] == "vercel_sandbox" + assert config["terminal"]["container_persistent"] is False + assert config["terminal"]["container_disk"] == 51200 + assert "VERCEL_OIDC_TOKEN" not in os.environ + assert os.environ["VERCEL_TOKEN"] == "token" + assert os.environ["VERCEL_PROJECT_ID"] == "linked-project" + assert os.environ["VERCEL_TEAM_ID"] == "linked-team" + assert defaults[" Vercel project ID"] == "linked-project" + assert defaults[" Vercel team ID"] == "linked-team" + + +def test_offer_launch_chat_relaunches_via_bin(monkeypatch): from hermes_cli import setup as setup_mod + from hermes_cli import relaunch as relaunch_mod monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *_args, **_kwargs: True) - monkeypatch.setattr(setup_mod, "_resolve_hermes_chat_argv", lambda: ["/usr/local/bin/hermes", "chat"]) + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: "/usr/local/bin/hermes") exec_calls = [] @@ -509,7 +586,7 @@ def test_offer_launch_chat_execs_fresh_process(monkeypatch): 
exec_calls.append((path, argv)) raise SystemExit(0) - monkeypatch.setattr(setup_mod.os, "execvp", fake_execvp) + monkeypatch.setattr(relaunch_mod.os, "execvp", fake_execvp) with pytest.raises(SystemExit): setup_mod._offer_launch_chat() @@ -517,13 +594,22 @@ def test_offer_launch_chat_execs_fresh_process(monkeypatch): assert exec_calls == [("/usr/local/bin/hermes", ["/usr/local/bin/hermes", "chat"])] -def test_offer_launch_chat_manual_fallback_when_unresolvable(monkeypatch, capsys): +def test_offer_launch_chat_falls_back_to_module(monkeypatch): from hermes_cli import setup as setup_mod + from hermes_cli import relaunch as relaunch_mod monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *_args, **_kwargs: True) - monkeypatch.setattr(setup_mod, "_resolve_hermes_chat_argv", lambda: None) + monkeypatch.setattr(relaunch_mod, "resolve_hermes_bin", lambda: None) - setup_mod._offer_launch_chat() + exec_calls = [] - captured = capsys.readouterr() - assert "Run 'hermes chat' manually" in captured.out + def fake_execvp(path, argv): + exec_calls.append((path, argv)) + raise SystemExit(0) + + monkeypatch.setattr(relaunch_mod.os, "execvp", fake_execvp) + + with pytest.raises(SystemExit): + setup_mod._offer_launch_chat() + + assert exec_calls == [(sys.executable, [sys.executable, "-m", "hermes_cli.main", "chat"])] diff --git a/tests/hermes_cli/test_setup_irc.py b/tests/hermes_cli/test_setup_irc.py new file mode 100644 index 0000000000..1e5baa5cc0 --- /dev/null +++ b/tests/hermes_cli/test_setup_irc.py @@ -0,0 +1,245 @@ +"""Tests for IRC gateway configuration via `hermes setup gateway` UI. + +Covers the full plugin-platform discovery → status → configure flow so that +a fresh Hermes install (no state, no env vars) can set up IRC through the +interactive setup menus. 
+""" + +import os +import pytest + +from gateway.platform_registry import PlatformEntry, platform_registry + + +def _register_irc_platform(**overrides): + """Manually register the IRC platform entry as if discover_plugins() found it. + + Tests run outside the normal plugin-discovery path, so we inject the entry + directly into the singleton registry and yield its dict shape. + """ + defaults = dict( + name="irc", + label="IRC", + adapter_factory=lambda cfg: None, + check_fn=lambda: bool(os.getenv("IRC_SERVER", "") and os.getenv("IRC_CHANNEL", "")), + validate_config=None, + required_env=["IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"], + install_hint="No extra packages needed (stdlib only)", + setup_fn=lambda: None, + source="plugin", + plugin_name="irc_platform", + allowed_users_env="IRC_ALLOWED_USERS", + allow_all_env="IRC_ALLOW_ALL_USERS", + max_message_length=450, + pii_safe=False, + emoji="💬", + allow_update_command=True, + platform_hint="You are chatting via IRC.", + ) + defaults.update(overrides) + entry = PlatformEntry(**defaults) + platform_registry.register(entry) + return { + "key": entry.name, + "label": entry.label, + "emoji": entry.emoji, + "token_var": entry.required_env[0] if entry.required_env else "", + "install_hint": entry.install_hint, + "_registry_entry": entry, + } + + +def _unregister_irc_platform(): + platform_registry.unregister("irc") + + +# ── Fresh-install discovery ───────────────────────────────────────────────── + + +class TestIRCFreshInstallDiscovery: + """IRC appears in the setup menu on a brand-new Hermes install.""" + + def test_irc_appears_in_all_platforms(self, monkeypatch): + """When the IRC plugin is registered, _all_platforms() surfaces it.""" + import hermes_cli.gateway as gateway_mod + + _register_irc_platform() + try: + # Ensure no stale env vars leak in + for key in ("IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"): + monkeypatch.delenv(key, raising=False) + + platforms = gateway_mod._all_platforms() + keys = {p["key"] for p in 
platforms} + assert "irc" in keys + + irc_plat = next(p for p in platforms if p["key"] == "irc") + assert irc_plat["label"] == "IRC" + assert irc_plat["emoji"] == "💬" + finally: + _unregister_irc_platform() + + def test_irc_status_not_configured_when_fresh(self, monkeypatch): + """On a fresh install with no env vars, IRC shows 'not configured'.""" + import hermes_cli.gateway as gateway_mod + + plat = _register_irc_platform() + try: + for key in ("IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"): + monkeypatch.delenv(key, raising=False) + + status = gateway_mod._platform_status(plat) + assert status == "not configured" + finally: + _unregister_irc_platform() + + def test_irc_status_configured_when_env_set(self, monkeypatch): + """After the user sets IRC_SERVER and IRC_CHANNEL, status is 'configured'.""" + import hermes_cli.gateway as gateway_mod + + plat = _register_irc_platform() + try: + monkeypatch.setenv("IRC_SERVER", "irc.libera.chat") + monkeypatch.setenv("IRC_CHANNEL", "#hermes") + monkeypatch.setenv("IRC_NICKNAME", "hermes-bot") + + status = gateway_mod._platform_status(plat) + assert status == "configured" + finally: + _unregister_irc_platform() + + def test_irc_status_partial_when_only_server_set(self, monkeypatch): + """If only IRC_SERVER is set, the platform is still not configured.""" + import hermes_cli.gateway as gateway_mod + + plat = _register_irc_platform() + try: + monkeypatch.delenv("IRC_CHANNEL", raising=False) + monkeypatch.delenv("IRC_NICKNAME", raising=False) + monkeypatch.setenv("IRC_SERVER", "irc.libera.chat") + + status = gateway_mod._platform_status(plat) + assert status == "not configured" + finally: + _unregister_irc_platform() + + +# ── Interactive setup dispatch ────────────────────────────────────────────── + + +class TestIRCInteractiveSetup: + """The setup UI dispatches to IRC's interactive_setup() correctly.""" + + def test_configure_platform_dispatches_to_irc_setup_fn(self, monkeypatch, capsys): + """_configure_platform() calls the IRC 
plugin's setup_fn when selected.""" + import hermes_cli.gateway as gateway_mod + + calls = [] + + def fake_setup(): + calls.append("setup_called") + print("IRC setup complete!") + + plat = _register_irc_platform(setup_fn=fake_setup) + try: + gateway_mod._configure_platform(plat) + finally: + _unregister_irc_platform() + + assert "setup_called" in calls + out = capsys.readouterr().out + assert "IRC setup complete!" in out + + + def test_configure_platform_fallback_when_no_setup_fn(self, monkeypatch, capsys): + """A plugin with no setup_fn falls back to env-var instructions.""" + import hermes_cli.gateway as gateway_mod + + plat = _register_irc_platform(setup_fn=None) + try: + gateway_mod._configure_platform(plat) + finally: + _unregister_irc_platform() + + out = capsys.readouterr().out + assert "IRC" in out + assert "IRC_SERVER" in out + + +# ── End-to-end fresh-install gateway setup ────────────────────────────────── + + +class TestIRCGatewaySetupFreshInstall: + """Simulate the full `hermes setup gateway` experience with IRC present.""" + + def test_setup_gateway_shows_irc_in_platform_menu(self, monkeypatch, capsys, tmp_path): + """The gateway setup menu lists IRC among the available platforms.""" + import hermes_cli.gateway as gateway_mod + from hermes_cli import setup as setup_mod + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + _register_irc_platform() + try: + for key in ("IRC_SERVER", "IRC_CHANNEL", "IRC_NICKNAME"): + monkeypatch.delenv(key, raising=False) + + # Sanity-check: IRC must be visible to _all_platforms() + platforms = gateway_mod._all_platforms() + assert any(p["key"] == "irc" for p in platforms), \ + f"IRC not in platforms: {[p['key'] for p in platforms]}" + + # Capture what prompt_checklist is asked to display + checklist_calls = [] + + def capture_prompt_checklist(question, choices, pre_selected=None): + checklist_calls.append({"question": question, "choices": choices}) + return [] # nothing selected → clean exit + + 
monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *a, **kw: False) + monkeypatch.setattr(setup_mod, "prompt_checklist", capture_prompt_checklist) + monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False) + monkeypatch.setattr(gateway_mod, "is_macos", lambda: False) + monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False) + monkeypatch.setattr(gateway_mod, "_is_service_running", lambda: False) + + setup_mod.setup_gateway({}) + + # Find the platform-selection prompt + platform_prompt = next( + (c for c in checklist_calls if "platform" in c["question"].lower()), + None, + ) + assert platform_prompt is not None, \ + f"No platform prompt found in {checklist_calls}" + choices_text = "\n".join(platform_prompt["choices"]) + assert "IRC" in choices_text + assert "💬" in choices_text + assert "not configured" in choices_text.lower() + finally: + _unregister_irc_platform() + + def test_setup_gateway_irc_counts_as_messaging_platform(self, monkeypatch, capsys, tmp_path): + """When IRC is configured, setup_gateway counts it as a messaging platform.""" + import hermes_cli.gateway as gateway_mod + from hermes_cli import setup as setup_mod + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + _register_irc_platform() + try: + monkeypatch.setenv("IRC_SERVER", "irc.libera.chat") + monkeypatch.setenv("IRC_CHANNEL", "#hermes") + monkeypatch.setenv("IRC_NICKNAME", "hermes-bot") + + monkeypatch.setattr(setup_mod, "prompt_yes_no", lambda *a, **kw: False) + monkeypatch.setattr(setup_mod, "prompt_choice", lambda *a, **kw: 0) + monkeypatch.setattr(gateway_mod, "supports_systemd_services", lambda: False) + monkeypatch.setattr(gateway_mod, "is_macos", lambda: False) + monkeypatch.setattr(gateway_mod, "_is_service_installed", lambda: False) + monkeypatch.setattr(gateway_mod, "_is_service_running", lambda: False) + + setup_mod.setup_gateway({}) + + out = capsys.readouterr().out + assert "Messaging platforms configured!" 
in out + finally: + _unregister_irc_platform() diff --git a/tests/hermes_cli/test_setup_openclaw_migration.py b/tests/hermes_cli/test_setup_openclaw_migration.py index a458bd3761..e627b61963 100644 --- a/tests/hermes_cli/test_setup_openclaw_migration.py +++ b/tests/hermes_cli/test_setup_openclaw_migration.py @@ -419,7 +419,12 @@ class TestGetSectionConfigSummary: return "disc456" return "" - with patch.object(setup_mod, "get_env_value", side_effect=env_side): + # Also patch gateway module's binding since _platform_status() + # reads from hermes_cli.gateway.get_env_value after the setup + # flows were unified via platform_registry. + import hermes_cli.gateway as gateway_mod + with patch.object(setup_mod, "get_env_value", side_effect=env_side), \ + patch.object(gateway_mod, "get_env_value", side_effect=env_side): result = setup_mod._get_section_config_summary({}, "gateway") assert "Telegram" in result assert "Discord" in result @@ -471,7 +476,9 @@ class TestGetSectionConfigSummary: def env_side(key): return "true" if key == "WHATSAPP_ENABLED" else "" - with patch.object(setup_mod, "get_env_value", side_effect=env_side): + import hermes_cli.gateway as gateway_mod + with patch.object(setup_mod, "get_env_value", side_effect=env_side), \ + patch.object(gateway_mod, "get_env_value", side_effect=env_side): result = setup_mod._get_section_config_summary({}, "gateway") assert result is not None assert "WhatsApp" in result @@ -481,7 +488,9 @@ class TestGetSectionConfigSummary: def env_side(key): return "http://signal.local" if key == "SIGNAL_HTTP_URL" else "" - with patch.object(setup_mod, "get_env_value", side_effect=env_side): + import hermes_cli.gateway as gateway_mod + with patch.object(setup_mod, "get_env_value", side_effect=env_side), \ + patch.object(gateway_mod, "get_env_value", side_effect=env_side): result = setup_mod._get_section_config_summary({}, "gateway") assert result is not None assert "Signal" in result @@ -529,13 +538,28 @@ class 
TestGetSectionConfigSummary: assert result == "gpt-5" def test_gateway_matches_platform_registry(self): - """Every platform in _GATEWAY_PLATFORMS should be recognised by its - own env-var sentinel — i.e. the summary must not drift from the + """Every built-in platform should be recognised by its primary + env-var sentinel — i.e. the summary must not drift from the registry used by the setup checklist.""" - for label, env_var, _fn in setup_mod._GATEWAY_PLATFORMS: + from hermes_cli.gateway import _PLATFORMS + + for plat in _PLATFORMS: + label = plat["label"] + env_var = plat.get("token_var") + if not env_var: + continue + # Some platforms require a specific value shape (e.g. WhatsApp + # needs the literal "true"). Use a sentinel that satisfies every + # real validator _platform_status() currently checks. def env_side(key, _target=env_var): - return "x" if key == _target else "" - with patch.object(setup_mod, "get_env_value", side_effect=env_side): + if key != _target: + return "" + if _target == "WHATSAPP_ENABLED": + return "true" + return "x" + import hermes_cli.gateway as gateway_mod + with patch.object(setup_mod, "get_env_value", side_effect=env_side), \ + patch.object(gateway_mod, "get_env_value", side_effect=env_side): result = setup_mod._get_section_config_summary({}, "gateway") expected = setup_mod._gateway_platform_short_label(label) assert result is not None, f"{label} ({env_var}) not recognised" diff --git a/tests/hermes_cli/test_status.py b/tests/hermes_cli/test_status.py index 216687660b..a13e843faf 100644 --- a/tests/hermes_cli/test_status.py +++ b/tests/hermes_cli/test_status.py @@ -79,3 +79,33 @@ def test_show_status_reports_nous_auth_error(monkeypatch, capsys, tmp_path): assert "Error: Refresh session has been revoked" in output assert "Access exp:" in output assert "Key exp:" in output + + +def test_show_status_reports_vercel_backend_contract(monkeypatch, capsys, tmp_path): + from hermes_cli import status as status_mod + import hermes_cli.auth as 
auth_mod + import hermes_cli.gateway as gateway_mod + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_VERCEL_RUNTIME", "python3.13") + monkeypatch.setenv("TERMINAL_CONTAINER_PERSISTENT", "true") + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(status_mod.importlib.util, "find_spec", lambda name: object() if name == "vercel" else None) + monkeypatch.setattr(status_mod, "load_config", lambda: {"terminal": {"backend": "vercel_sandbox"}}, raising=False) + monkeypatch.setattr(auth_mod, "get_nous_auth_status", lambda: {}, raising=False) + monkeypatch.setattr(auth_mod, "get_codex_auth_status", lambda: {}, raising=False) + monkeypatch.setattr(auth_mod, "get_qwen_auth_status", lambda: {}, raising=False) + monkeypatch.setattr(gateway_mod, "find_gateway_pids", lambda exclude_pids=None: [], raising=False) + + status_mod.show_status(SimpleNamespace(all=False, deep=False)) + + output = capsys.readouterr().out + assert "Backend: vercel_sandbox" in output + assert "Runtime: python3.13" in output + assert "Auth:" in output and "OIDC token via VERCEL_OIDC_TOKEN" in output + assert "Auth detail: mode: OIDC" in output + assert "Auth detail: active env: VERCEL_OIDC_TOKEN" in output + assert "oidc-token" not in output + assert "snapshot filesystem" in output + assert "live processes do not survive" in output diff --git a/tests/hermes_cli/test_tui_resume_flow.py b/tests/hermes_cli/test_tui_resume_flow.py index eb9274984d..8086ee87e3 100644 --- a/tests/hermes_cli/test_tui_resume_flow.py +++ b/tests/hermes_cli/test_tui_resume_flow.py @@ -12,6 +12,7 @@ def _args(**overrides): "model": None, "provider": None, "resume": None, + "toolsets": None, "tui": True, "tui_dev": False, } @@ -35,7 +36,7 @@ def test_cmd_chat_tui_continue_uses_latest_tui_session(monkeypatch, main_mod): calls.append(source) return "20260408_235959_a1b2c3" if source == "tui" else None - def 
fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None): + def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None): captured["resume"] = resume_session_id raise SystemExit(0) @@ -62,7 +63,7 @@ def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch, mai return "20260408_235959_d4e5f6" return None - def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None): + def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None): captured["resume"] = resume_session_id raise SystemExit(0) @@ -80,7 +81,7 @@ def test_cmd_chat_tui_continue_falls_back_to_latest_cli_session(monkeypatch, mai def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch, main_mod): captured = {} - def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None): + def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None): captured["resume"] = resume_session_id raise SystemExit(0) @@ -98,12 +99,13 @@ def test_cmd_chat_tui_resume_resolves_title_before_launch(monkeypatch, main_mod) def test_cmd_chat_tui_passes_model_and_provider(monkeypatch, main_mod): captured = {} - def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None): + def fake_launch(resume_session_id=None, tui_dev=False, model=None, provider=None, toolsets=None): captured.update( { "model": model, "provider": provider, "resume": resume_session_id, + "toolsets": toolsets, "tui_dev": tui_dev, } ) @@ -120,11 +122,193 @@ def test_cmd_chat_tui_passes_model_and_provider(monkeypatch, main_mod): "model": "anthropic/claude-sonnet-4.6", "provider": "anthropic", "resume": None, + "toolsets": None, "tui_dev": False, } -def test_launch_tui_exports_model_and_provider(monkeypatch, main_mod): +def test_cmd_chat_tui_passes_toolsets(monkeypatch, main_mod): + captured = {} + + def fake_launch(resume_session_id=None, tui_dev=False, 
model=None, provider=None, toolsets=None): + captured["toolsets"] = toolsets + raise SystemExit(0) + + monkeypatch.setattr(main_mod, "_launch_tui", fake_launch) + + with pytest.raises(SystemExit): + main_mod.cmd_chat(_args(toolsets="web,terminal")) + + assert captured["toolsets"] == "web,terminal" + + +def test_main_top_level_tui_accepts_toolsets(monkeypatch, main_mod): + captured = {} + + import hermes_cli.config as config_mod + + monkeypatch.setattr(sys, "argv", ["hermes", "--tui", "--toolsets", "web,terminal"]) + monkeypatch.setitem(sys.modules, "hermes_cli.plugins", types.SimpleNamespace(discover_plugins=lambda: None)) + monkeypatch.setitem(sys.modules, "tools.mcp_tool", types.SimpleNamespace(discover_mcp_tools=lambda: None)) + monkeypatch.setattr(config_mod, "load_config", lambda: {}) + monkeypatch.setattr(config_mod, "get_container_exec_info", lambda: None) + monkeypatch.setitem( + sys.modules, + "agent.shell_hooks", + types.SimpleNamespace(register_from_config=lambda _cfg, accept_hooks=False: None), + ) + monkeypatch.setattr(main_mod, "cmd_chat", lambda args: captured.update({"toolsets": args.toolsets, "tui": args.tui})) + + main_mod.main() + + assert captured == {"toolsets": "web,terminal", "tui": True} + + +def test_main_top_level_oneshot_accepts_toolsets(monkeypatch, main_mod): + captured = {} + + import hermes_cli.config as config_mod + + monkeypatch.setattr(sys, "argv", ["hermes", "-z", "hello", "--toolsets", "web,terminal"]) + monkeypatch.setitem(sys.modules, "hermes_cli.plugins", types.SimpleNamespace(discover_plugins=lambda: None)) + monkeypatch.setitem(sys.modules, "tools.mcp_tool", types.SimpleNamespace(discover_mcp_tools=lambda: None)) + monkeypatch.setattr(config_mod, "load_config", lambda: {}) + monkeypatch.setattr(config_mod, "get_container_exec_info", lambda: None) + monkeypatch.setitem( + sys.modules, + "agent.shell_hooks", + types.SimpleNamespace(register_from_config=lambda _cfg, accept_hooks=False: None), + ) + monkeypatch.setitem( + 
sys.modules, + "hermes_cli.oneshot", + types.SimpleNamespace(run_oneshot=lambda prompt, **kwargs: captured.update({"prompt": prompt, **kwargs}) or 0), + ) + + with pytest.raises(SystemExit) as exc: + main_mod.main() + + assert exc.value.code == 0 + assert captured == {"prompt": "hello", "model": None, "provider": None, "toolsets": "web,terminal"} + + +def _stub_plugin_discovery(monkeypatch): + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: None), + ) + + +def test_oneshot_rejects_invalid_only_toolsets(monkeypatch, capsys): + _stub_plugin_discovery(monkeypatch) + from hermes_cli.oneshot import run_oneshot + + assert run_oneshot("hello", toolsets="nope") == 2 + err = capsys.readouterr().err + assert "nope" in err + assert "did not contain any valid toolsets" in err + + +def test_oneshot_filters_invalid_toolsets_before_redirect(monkeypatch, capsys): + _stub_plugin_discovery(monkeypatch) + from hermes_cli.oneshot import _validate_explicit_toolsets + + valid, error = _validate_explicit_toolsets("web,nope") + + assert valid == ["web"] + assert error is None + assert "nope" in capsys.readouterr().err + + +def test_oneshot_all_toolsets_means_all_not_configured_cli(): + from hermes_cli.oneshot import _validate_explicit_toolsets + + valid, error = _validate_explicit_toolsets("all") + + assert valid is None + assert error is None + + +def test_oneshot_all_toolsets_warns_about_ignored_extra_entries(monkeypatch, capsys): + _stub_plugin_discovery(monkeypatch) + from hermes_cli.oneshot import _validate_explicit_toolsets + + valid, error = _validate_explicit_toolsets("all,nope") + + assert valid is None + assert error is None + assert "ignoring additional entries: nope" in capsys.readouterr().err + + +def test_oneshot_accepts_plugin_toolset_after_discovery(monkeypatch): + import toolsets + + from hermes_cli.oneshot import _validate_explicit_toolsets + + discovered = {"ready": False} + original_validate = 
toolsets.validate_toolset + + def fake_validate(name): + return name == "plugin_demo" and discovered["ready"] or original_validate(name) + + monkeypatch.setattr(toolsets, "validate_toolset", fake_validate) + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: discovered.update({"ready": True})), + ) + + valid, error = _validate_explicit_toolsets("plugin_demo") + + assert valid == ["plugin_demo"] + assert error is None + + +def test_oneshot_rejects_disabled_mcp_toolset(monkeypatch, capsys): + _stub_plugin_discovery(monkeypatch) + import hermes_cli.config as config_mod + + from hermes_cli.oneshot import _validate_explicit_toolsets + + monkeypatch.setattr( + config_mod, + "read_raw_config", + lambda: {"mcp_servers": {"mcp-off": {"enabled": False}}}, + ) + + valid, error = _validate_explicit_toolsets("mcp-off") + + assert valid is None + assert error == "hermes -z: --toolsets did not contain any valid toolsets.\n" + err = capsys.readouterr().err + assert "ignoring disabled MCP servers" in err + assert "mcp-off" in err + + +def test_oneshot_distinguishes_disabled_mcp_from_unknown(monkeypatch, capsys): + _stub_plugin_discovery(monkeypatch) + import hermes_cli.config as config_mod + + from hermes_cli.oneshot import _validate_explicit_toolsets + + monkeypatch.setattr( + config_mod, + "read_raw_config", + lambda: {"mcp_servers": {"mcp-off": {"enabled": False}}}, + ) + + valid, error = _validate_explicit_toolsets("web,mcp-off,nope") + + assert valid == ["web"] + assert error is None + err = capsys.readouterr().err + assert "ignoring unknown --toolsets entries: nope" in err + assert "ignoring disabled MCP servers" in err + assert "mcp-off" in err + + +def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod): captured = {} active_path_during_call = None @@ -144,13 +328,14 @@ def test_launch_tui_exports_model_and_provider(monkeypatch, main_mod): monkeypatch.setattr(main_mod.subprocess, "call", 
fake_call) with pytest.raises(SystemExit): - main_mod._launch_tui(model="nous/hermes-test", provider="nous") + main_mod._launch_tui(model="nous/hermes-test", provider="nous", toolsets="web, terminal") env = captured["env"] assert env["HERMES_MODEL"] == "nous/hermes-test" assert env["HERMES_INFERENCE_MODEL"] == "nous/hermes-test" assert env["HERMES_TUI_PROVIDER"] == "nous" assert env["HERMES_INFERENCE_PROVIDER"] == "nous" + assert env["HERMES_TUI_TOOLSETS"] == "web,terminal" active_path = Path(env["HERMES_TUI_ACTIVE_SESSION_FILE"]) assert active_path.name.startswith("hermes-tui-active-session-") assert active_path.suffix == ".json" diff --git a/tests/hermes_cli/test_update_autostash.py b/tests/hermes_cli/test_update_autostash.py index dee8cc1fbd..df8bccb209 100644 --- a/tests/hermes_cli/test_update_autostash.py +++ b/tests/hermes_cli/test_update_autostash.py @@ -333,7 +333,10 @@ def test_cmd_update_retries_optional_extras_individually_when_all_fails(monkeypa raise CalledProcessError(returncode=1, cmd=cmd) if cmd == ["/usr/bin/uv", "pip", "install", "-e", ".[mcp]", "--quiet"]: return SimpleNamespace(returncode=0) - return SimpleNamespace(returncode=0) + # Catch-all must include stdout/stderr so consumers that parse + # output (e.g. the dashboard-restart `ps -A` scan added in the + # updater) don't crash on AttributeError. 
+ return SimpleNamespace(returncode=0, stdout="", stderr="") monkeypatch.setattr(hermes_main.subprocess, "run", fake_run) @@ -370,7 +373,7 @@ def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path): return SimpleNamespace(stdout="1\n", stderr="", returncode=0) if cmd == ["git", "pull", "origin", "main"]: return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0) - return SimpleNamespace(returncode=0) + return SimpleNamespace(returncode=0, stdout="", stderr="") monkeypatch.setattr(hermes_main.subprocess, "run", fake_run) diff --git a/tests/hermes_cli/test_update_stale_dashboard.py b/tests/hermes_cli/test_update_stale_dashboard.py index 20c5eee98c..546fd48991 100644 --- a/tests/hermes_cli/test_update_stale_dashboard.py +++ b/tests/hermes_cli/test_update_stale_dashboard.py @@ -1,19 +1,62 @@ -"""Tests for _warn_stale_dashboard_processes — stale dashboard detection. +"""Tests for the stale-dashboard handling run at the end of ``hermes update``. -Ensures ``hermes update`` warns the user when dashboard processes from a -previous version are still running after files on disk have been replaced. -See #16872. +``hermes update`` detects ``hermes dashboard`` processes left over from the +previous version and kills them (SIGTERM + SIGKILL grace, or ``taskkill /F`` +on Windows). Without this, the running backend silently serves stale Python +against a freshly-updated JS bundle, producing 401s / empty data. + +History: +- #16872 introduced the warn-only helper (``_warn_stale_dashboard_processes``). +- #17049 fixed a Windows wmic UnicodeDecodeError crash on non-UTF-8 locales. +- This file now also covers the kill semantics that replaced the warning. 
""" from __future__ import annotations +import importlib import os import sys -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, call import pytest -from hermes_cli.main import _warn_stale_dashboard_processes +from hermes_cli.main import ( + _find_stale_dashboard_pids, + _kill_stale_dashboard_processes, + _warn_stale_dashboard_processes, # back-compat alias +) + + +@pytest.fixture(autouse=True) +def _refresh_bindings_against_live_module(): + """Rebind module-level names to the *current* ``hermes_cli.main``. + + Other tests in the suite (notably ``test_env_loader.py`` and + ``test_skills_subparser.py``) reload or delete ``hermes_cli.main`` from + ``sys.modules``. When that happens on the same xdist worker before we + run, our top-of-file ``from hermes_cli.main import ...`` bindings end + up pointing at the *old* module object. ``patch(\"hermes_cli.main.X\")`` + then patches the *new* module, but the function we call still resolves + ``_find_stale_dashboard_pids`` via its stale ``__globals__``, so every + patch becomes a no-op and the kill path silently returns early. + + Refreshing the bindings (and the patch target) to the live module + object — and keeping them consistent — makes the tests immune to + ordering within the worker. The fix lives in the test module because + the two pollutants above are load-bearing for their own tests. 
+ """ + global _find_stale_dashboard_pids + global _kill_stale_dashboard_processes + global _warn_stale_dashboard_processes + + live = sys.modules.get("hermes_cli.main") + if live is None: + live = importlib.import_module("hermes_cli.main") + + _find_stale_dashboard_pids = live._find_stale_dashboard_pids + _kill_stale_dashboard_processes = live._kill_stale_dashboard_processes + _warn_stale_dashboard_processes = live._warn_stale_dashboard_processes + yield def _ps_line(pid: int, cmd: str) -> str: @@ -21,11 +64,26 @@ def _ps_line(pid: int, cmd: str) -> str: return f"{pid:>7} {cmd}" -class TestWarnStaleDashboardProcesses: - """Unit tests for the stale dashboard process warning.""" +def _ps_runner(stdout: str): + """Build a subprocess.run side_effect that only stubs ps -A calls. - def test_no_warning_when_no_dashboard_running(self, capsys): - """ps returns no matching processes — no warning should be printed.""" + Any other subprocess.run invocation (e.g. taskkill on Windows) is + handed back as a successful no-op. This lets tests exercise the real + scan path without having to re-stub every unrelated subprocess call + made later in ``_kill_stale_dashboard_processes``. + """ + def _side_effect(args, *a, **kw): + if isinstance(args, (list, tuple)) and args and args[0] == "ps": + return MagicMock(returncode=0, stdout=stdout, stderr="") + # Any other subprocess.run (e.g. taskkill) — benign success stub. 
+ return MagicMock(returncode=0, stdout="", stderr="") + return _side_effect + + +class TestFindStaleDashboardPids: + """Unit tests for the ps/wmic-based detection step.""" + + def test_no_matches_returns_empty(self): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( returncode=0, @@ -35,26 +93,18 @@ class TestWarnStaleDashboardProcesses: + "\n", stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "dashboard process" not in output + assert _find_stale_dashboard_pids() == [] - def test_warning_printed_for_running_dashboard(self, capsys): - """ps finds a dashboard PID — warning with PID should appear.""" + def test_matches_running_dashboard(self): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( returncode=0, stdout=_ps_line(12345, "python3 -m hermes_cli.main dashboard --port 9119") + "\n", stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "1 dashboard process" in output - assert "PID 12345" in output - assert "kill " in output + assert _find_stale_dashboard_pids() == [12345] - def test_multiple_dashboard_pids(self, capsys): - """Multiple dashboard processes — all PIDs listed.""" + def test_multiple_matches(self): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( returncode=0, @@ -65,15 +115,9 @@ class TestWarnStaleDashboardProcesses: ]) + "\n", stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "3 dashboard process" in output - assert "PID 12345" in output - assert "PID 12346" in output - assert "PID 12347" in output + assert sorted(_find_stale_dashboard_pids()) == [12345, 12346, 12347] - def test_self_pid_excluded(self, capsys): - """The current process PID should not be reported.""" + def test_self_pid_excluded(self): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( returncode=0, @@ -83,41 +127,51 @@ class 
TestWarnStaleDashboardProcesses: ]) + "\n", stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - # The self PID may still appear inside an unrelated context, so anchor - # the check to "PID " which is how the warning prints. - assert f"PID {os.getpid()}" not in output - assert "PID 12345" in output + pids = _find_stale_dashboard_pids() + assert os.getpid() not in pids + assert 12345 in pids - def test_ps_not_found_silently_ignored(self, capsys): - """If ps is missing (FileNotFoundError), no crash, no warning.""" + def test_ps_not_found_returns_empty(self): with patch("subprocess.run", side_effect=FileNotFoundError): - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert output == "" + assert _find_stale_dashboard_pids() == [] - def test_ps_timeout_silently_ignored(self, capsys): - """If ps times out, no crash, no warning.""" + def test_ps_timeout_returns_empty(self): import subprocess as sp - with patch("subprocess.run", side_effect=sp.TimeoutExpired("ps", 10)): - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert output == "" + assert _find_stale_dashboard_pids() == [] - def test_empty_ps_output_no_warning(self, capsys): - """ps returns 0 but empty stdout — no warning.""" + def test_unrelated_process_containing_word_dashboard_not_matched(self): + """Guards against greedy pgrep-style matching catching chat sessions + or unrelated processes whose cmdline happens to contain 'dashboard'. 
+ """ with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( - returncode=0, stdout="\n", stderr="" + returncode=0, + stdout="\n".join([ + _ps_line(12345, "python3 -m hermes_cli.main dashboard --port 9119"), + _ps_line(22222, "python3 -m hermes_cli.main chat -q 'rewrite my dashboard'"), + _ps_line(33333, "node /opt/grafana/dashboard-server.js"), + ]) + "\n", + stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "dashboard process" not in output + pids = _find_stale_dashboard_pids() + assert pids == [12345] - def test_invalid_pid_lines_skipped(self, capsys): - """Malformed ps lines should be skipped gracefully.""" + def test_grep_lines_ignored(self): + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, + stdout="\n".join([ + _ps_line(99999, "grep hermes dashboard"), + _ps_line(12345, "hermes dashboard --port 9119"), + ]) + "\n", + stderr="", + ) + pids = _find_stale_dashboard_pids() + assert 99999 not in pids + assert 12345 in pids + + def test_invalid_pid_lines_skipped(self): with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( returncode=0, @@ -128,50 +182,213 @@ class TestWarnStaleDashboardProcesses: ]) + "\n", stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "PID 12345" in output - assert "1 dashboard process" in output + pids = _find_stale_dashboard_pids() + assert pids == [12345] - def test_unrelated_process_containing_word_dashboard_not_matched(self, capsys): - """A process whose cmdline contains 'dashboard' but isn't a hermes - dashboard process must NOT be flagged. This guards against the old - ``pgrep -f "hermes.*dashboard"`` greedy regex that matched e.g. a - chat session argv containing both words. 
- """ + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX kill semantics") +class TestKillStaleDashboardPosix: + """Kill path on Linux / macOS: SIGTERM then SIGKILL any survivors.""" + + def test_no_stale_processes_is_a_noop(self, capsys): + with patch("hermes_cli.main._find_stale_dashboard_pids", return_value=[]): + _kill_stale_dashboard_processes() + assert capsys.readouterr().out == "" + + def test_sigterm_graceful_exit(self, capsys): + """Processes that exit on SIGTERM (the probe gets ProcessLookupError) + are reported as stopped and SIGKILL is never sent.""" + import signal as _signal + + killed_signals: list[tuple[int, int]] = [] + + def fake_kill(pid, sig): + killed_signals.append((pid, sig)) + if sig == 0: + # Probe after SIGTERM → "process gone". + raise ProcessLookupError + # SIGTERM itself: succeed silently. + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[12345, 12346]), \ + patch("os.kill", side_effect=fake_kill), \ + patch("time.sleep"): + _kill_stale_dashboard_processes() + + # Both got SIGTERM. + sigterms = [pid for pid, sig in killed_signals if sig == _signal.SIGTERM] + assert sorted(sigterms) == [12345, 12346] + # No SIGKILL was needed. + assert not any(sig == _signal.SIGKILL for _, sig in killed_signals) + + out = capsys.readouterr().out + assert "Stopping 2 dashboard" in out + assert "✓ stopped PID 12345" in out + assert "✓ stopped PID 12346" in out + assert "Restart the dashboard" in out + + def test_sigkill_fallback_for_survivors(self, capsys): + """If a process survives SIGTERM + the grace window, SIGKILL is sent.""" + import signal as _signal + + sent: list[tuple[int, int]] = [] + + def fake_kill(pid, sig): + sent.append((pid, sig)) + # Simulate stubborn process: probe (sig 0) always succeeds, + # SIGTERM does nothing, SIGKILL is where it "dies". + if sig in (_signal.SIGTERM, 0, _signal.SIGKILL): + return + # Any other signal — also fine. 
+ + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[99999]), \ + patch("os.kill", side_effect=fake_kill), \ + patch("time.sleep"), \ + patch("time.monotonic", side_effect=[0.0] + [10.0] * 20): + # monotonic jumps past the 3s deadline on the second read so the + # grace loop exits immediately after one iteration. + _kill_stale_dashboard_processes() + + signals_sent = [sig for _, sig in sent] + assert _signal.SIGTERM in signals_sent + assert _signal.SIGKILL in signals_sent + + out = capsys.readouterr().out + assert "✓ stopped PID 99999" in out + + def test_permission_error_is_reported_not_raised(self, capsys): + """os.kill raising PermissionError (e.g. another user's process) + must not abort hermes update — it's reported as a failure and we + move on.""" + def fake_kill(pid, sig): + raise PermissionError("Operation not permitted") + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[12345]), \ + patch("os.kill", side_effect=fake_kill), \ + patch("time.sleep"): + _kill_stale_dashboard_processes() # must not raise + + out = capsys.readouterr().out + assert "✗ failed to stop PID 12345" in out + assert "Operation not permitted" in out + + def test_process_already_gone_counts_as_stopped(self, capsys): + """ProcessLookupError on the initial SIGTERM means the process + already exited between detection and the kill — treat as success.""" + def fake_kill(pid, sig): + raise ProcessLookupError + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[12345]), \ + patch("os.kill", side_effect=fake_kill), \ + patch("time.sleep"): + _kill_stale_dashboard_processes() + + out = capsys.readouterr().out + assert "✓ stopped PID 12345" in out + assert "failed to stop" not in out + + +class TestKillStaleDashboardWindows: + """Kill path on Windows: taskkill /F.""" + + def test_taskkill_invoked_for_each_pid(self, monkeypatch, capsys): + monkeypatch.setattr(sys, "platform", "win32") + + def fake_run(args, *a, **kw): + # 
taskkill returns 0 on success + return MagicMock(returncode=0, stdout="", stderr="") + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[12345, 12346]), \ + patch("subprocess.run", side_effect=fake_run) as mock_run: + _kill_stale_dashboard_processes() + + # Each PID triggered a taskkill /PID /F invocation. + taskkill_calls = [ + c for c in mock_run.call_args_list + if c.args and isinstance(c.args[0], list) and c.args[0][:1] == ["taskkill"] + ] + assert len(taskkill_calls) == 2 + assert ["taskkill", "/PID", "12345", "/F"] in [c.args[0] for c in taskkill_calls] + assert ["taskkill", "/PID", "12346", "/F"] in [c.args[0] for c in taskkill_calls] + + out = capsys.readouterr().out + assert "✓ stopped PID 12345" in out + assert "✓ stopped PID 12346" in out + + def test_taskkill_failure_is_reported(self, monkeypatch, capsys): + monkeypatch.setattr(sys, "platform", "win32") + + def fake_run(args, *a, **kw): + return MagicMock(returncode=128, stdout="", + stderr="ERROR: Access is denied.") + + with patch("hermes_cli.main._find_stale_dashboard_pids", + return_value=[12345]), \ + patch("subprocess.run", side_effect=fake_run): + _kill_stale_dashboard_processes() # must not raise + + out = capsys.readouterr().out + assert "✗ failed to stop PID 12345" in out + assert "Access is denied" in out + + +class TestBackCompatAlias: + """``_warn_stale_dashboard_processes`` is kept as an alias for the + new kill function so old imports don't break.""" + + def test_alias_is_the_kill_function(self): + assert _warn_stale_dashboard_processes is _kill_stale_dashboard_processes + + +class TestWindowsWmicEncoding: + """Regression tests for #17049 — the Windows wmic branch must not crash + `hermes update` on non-UTF-8 system locales (e.g. cp936 on zh-CN). 
+ """ + + def test_wmic_invoked_with_utf8_ignore_errors(self, monkeypatch): + """The wmic subprocess.run call must pass encoding='utf-8' and + errors='ignore' so the subprocess reader thread cannot raise + UnicodeDecodeError on non-UTF-8 wmic output.""" + monkeypatch.setattr(sys, "platform", "win32") with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( returncode=0, - stdout="\n".join([ - # Legitimate dashboard — should match. - _ps_line(12345, "python3 -m hermes_cli.main dashboard --port 9119"), - # hermes running something else, with "dashboard" as a - # substring of an unrelated arg — should NOT match. - _ps_line(22222, "python3 -m hermes_cli.main chat -q 'rewrite my dashboard'"), - # Completely unrelated process mentioning dashboard. - _ps_line(33333, "node /opt/grafana/dashboard-server.js"), - ]) + "\n", + stdout=( + "CommandLine=python -m hermes_cli.main dashboard\n" + "ProcessId=12345\n" + ), stderr="", ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "1 dashboard process" in output - assert "PID 12345" in output - assert "PID 22222" not in output - assert "PID 33333" not in output + _find_stale_dashboard_pids() - def test_grep_lines_ignored(self, capsys): - """Lines containing 'grep' (from a pipe in ps output) are ignored.""" + # The wmic call is the first subprocess.run invocation. + assert mock_run.called, "subprocess.run was not invoked" + wmic_call = mock_run.call_args_list[0] + kwargs = wmic_call.kwargs + assert kwargs.get("encoding") == "utf-8", ( + "encoding kwarg must be 'utf-8' so wmic output is decoded " + "deterministically rather than via the implicit reader-thread " + "default that crashes on non-UTF-8 locales (#17049)." + ) + assert kwargs.get("errors") == "ignore", ( + "errors kwarg must be 'ignore' so undecodable bytes don't take " + "down the reader thread (#17049)." 
+ ) + + def test_wmic_returns_none_stdout_does_not_crash(self, monkeypatch): + """If subprocess.run returns successfully but stdout is None — which + is what Python 3.11 leaves behind when the reader thread silently + crashed on UnicodeDecodeError before this fix landed — detection + must short-circuit instead of raising AttributeError on + ``None.split('\\n')`` and aborting `hermes update` (#17049).""" + monkeypatch.setattr(sys, "platform", "win32") with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( - returncode=0, - stdout="\n".join([ - _ps_line(99999, "grep hermes dashboard"), - _ps_line(12345, "hermes dashboard --port 9119"), - ]) + "\n", - stderr="", + returncode=0, stdout=None, stderr="" ) - _warn_stale_dashboard_processes() - output = capsys.readouterr().out - assert "PID 99999" not in output - assert "PID 12345" in output + # Must not raise. + assert _find_stale_dashboard_pids() == [] diff --git a/tests/hermes_cli/test_user_providers_model_switch.py b/tests/hermes_cli/test_user_providers_model_switch.py index b86dcdba3b..0a357c21fc 100644 --- a/tests/hermes_cli/test_user_providers_model_switch.py +++ b/tests/hermes_cli/test_user_providers_model_switch.py @@ -453,6 +453,142 @@ def test_list_authenticated_providers_no_duplicate_labels_across_schemas(monkeyp ) +def test_list_authenticated_providers_hides_custom_shadowing_builtin_endpoint(monkeypatch): + """#16970: a custom_providers entry whose ``base_url`` matches a built-in + provider's endpoint should be hidden. The built-in row already represents + that endpoint with its canonical slug, curated model list, and auth wiring. + + Repro: user sets ``DASHSCOPE_API_KEY`` (triggers the built-in ``alibaba`` + row pointing at the static ``inference_base_url``) AND defines a + ``my-alibaba`` custom provider pointing at the same URL. Before the fix, + the picker showed both rows for one endpoint. 
+ """ + monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test") + monkeypatch.setattr( + "agent.models_dev.fetch_models_dev", + lambda: { + "alibaba": { + "name": "Alibaba Cloud (DashScope)", + "env": ["DASHSCOPE_API_KEY"], + } + }, + ) + monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {}) + + custom_providers = [ + { + "name": "my-alibaba", + # Matches PROVIDER_REGISTRY['alibaba'].inference_base_url exactly. + "base_url": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + "api_key": "sk-sp-test", + "model": "qwen3.6-plus", + "models": {"qwen3.6-plus": {"context_length": 500000}}, + } + ] + + providers = list_authenticated_providers( + current_provider="my-alibaba", + user_providers={}, + custom_providers=custom_providers, + max_models=50, + ) + + slugs = [p["slug"] for p in providers] + # Built-in alibaba row should be present. + assert "alibaba" in slugs, ( + f"Expected built-in alibaba row, got slugs: {slugs}" + ) + # Custom shadow row should be hidden — its base_url matches the built-in's. + assert not any("my-alibaba" in s for s in slugs), ( + f"Custom my-alibaba should have been dedup'd against the built-in " + f"alibaba endpoint, got slugs: {slugs}" + ) + + +def test_list_authenticated_providers_keeps_custom_with_distinct_endpoint(monkeypatch): + """Dedup must only apply when the endpoint matches a built-in. 
A custom + provider on a genuinely distinct endpoint stays visible even if a + built-in is also authenticated.""" + monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test") + monkeypatch.setattr( + "agent.models_dev.fetch_models_dev", + lambda: { + "alibaba": { + "name": "Alibaba Cloud (DashScope)", + "env": ["DASHSCOPE_API_KEY"], + } + }, + ) + monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {}) + + custom_providers = [ + { + "name": "my-private-relay", + "base_url": "https://relay.example.internal/v1", + "api_key": "sk-relay-test", + "model": "qwen3.6-plus", + "models": {"qwen3.6-plus": {}}, + } + ] + + providers = list_authenticated_providers( + current_provider="my-private-relay", + user_providers={}, + custom_providers=custom_providers, + max_models=50, + ) + + slugs = [p["slug"] for p in providers] + assert any("my-private-relay" in s for s in slugs), ( + f"Custom provider on distinct endpoint must stay visible, got: {slugs}" + ) + + +def test_list_authenticated_providers_dedup_honors_base_url_env_override(monkeypatch): + """The dedup must track the EFFECTIVE endpoint — if DASHSCOPE_BASE_URL + overrides the static inference_base_url, a custom provider pointing at + the overridden URL (not the static one) should still be recognized as + a duplicate.""" + monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test") + monkeypatch.setenv( + "DASHSCOPE_BASE_URL", + "https://custom-dashscope.example.com/v1", + ) + monkeypatch.setattr( + "agent.models_dev.fetch_models_dev", + lambda: { + "alibaba": { + "name": "Alibaba Cloud (DashScope)", + "env": ["DASHSCOPE_API_KEY"], + } + }, + ) + monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {}) + + custom_providers = [ + { + "name": "my-dashscope-override", + # Same URL as DASHSCOPE_BASE_URL env override above. 
+ "base_url": "https://custom-dashscope.example.com/v1", + "api_key": "sk-test", + "model": "qwen3.6-plus", + } + ] + + providers = list_authenticated_providers( + current_provider="alibaba", + user_providers={}, + custom_providers=custom_providers, + max_models=50, + ) + + slugs = [p["slug"] for p in providers] + assert not any("my-dashscope-override" in s for s in slugs), ( + f"Custom entry matching env-overridden built-in endpoint should be " + f"dedup'd, got: {slugs}" + ) + + # ============================================================================= # Tests for _get_named_custom_provider with providers: dict # ============================================================================= diff --git a/tests/hermes_cli/test_web_server.py b/tests/hermes_cli/test_web_server.py index 60390d47ce..f2aed86d42 100644 --- a/tests/hermes_cli/test_web_server.py +++ b/tests/hermes_cli/test_web_server.py @@ -29,7 +29,7 @@ class TestReloadEnv: """reload_env() adds vars from .env that are not in os.environ.""" env_file = tmp_path / ".env" env_file.write_text("TEST_RELOAD_VAR=hello123\n") - with patch("hermes_cli.config.get_env_path", return_value=env_file): + with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}): os.environ.pop("TEST_RELOAD_VAR", None) count = reload_env() assert count >= 1 @@ -40,7 +40,7 @@ class TestReloadEnv: """reload_env() updates vars whose value changed on disk.""" env_file = tmp_path / ".env" env_file.write_text("TEST_RELOAD_VAR=old_value\n") - with patch("hermes_cli.config.get_env_path", return_value=env_file): + with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}): os.environ["TEST_RELOAD_VAR"] = "old_value" # Now change the file env_file.write_text("TEST_RELOAD_VAR=new_value\n") @@ -55,7 +55,7 @@ class TestReloadEnv: env_file.write_text("") # empty .env # Pick a known key from OPTIONAL_ENV_VARS known_key = next(iter(OPTIONAL_ENV_VARS.keys())) - with patch("hermes_cli.config.get_env_path", 
return_value=env_file): + with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}): os.environ[known_key] = "stale_value" count = reload_env() assert known_key not in os.environ @@ -65,7 +65,7 @@ class TestReloadEnv: """reload_env() preserves non-Hermes env vars even when absent from .env.""" env_file = tmp_path / ".env" env_file.write_text("") - with patch("hermes_cli.config.get_env_path", return_value=env_file): + with patch.dict(reload_env.__globals__, {"get_env_path": lambda: env_file}): os.environ["MY_CUSTOM_UNRELATED_VAR"] = "keep_me" reload_env() assert os.environ.get("MY_CUSTOM_UNRELATED_VAR") == "keep_me" @@ -371,6 +371,12 @@ class TestBuildSchemaFromConfig: assert entry["type"] == "select" assert "options" in entry assert "local" in entry["options"] + assert "vercel_sandbox" in entry["options"] + runtime_entry = CONFIG_SCHEMA["terminal.vercel_runtime"] + assert runtime_entry["type"] == "select" + assert "node24" in runtime_entry["options"] + assert "python3.13" in runtime_entry["options"] + assert len(runtime_entry["options"]) >= 3 def test_empty_prefix_produces_correct_keys(self): from hermes_cli.web_server import _build_schema_from_config @@ -671,8 +677,12 @@ class TestNewEndpoints: assert resp.status_code == 200 assert resp.json()["command"] == "hermes setup" - def test_profiles_create_creates_wrapper_alias_when_safe(self): - from pathlib import Path + def test_profiles_create_creates_wrapper_alias_when_safe(self, monkeypatch, tmp_path): + import hermes_cli.profiles as profiles_mod + + wrapper_dir = tmp_path / "bin" + wrapper_dir.mkdir() + monkeypatch.setattr(profiles_mod, "_get_wrapper_dir", lambda: wrapper_dir) resp = self.client.post( "/api/profiles", @@ -680,7 +690,7 @@ class TestNewEndpoints: ) assert resp.status_code == 200 - wrapper_path = Path.home() / ".local" / "bin" / "writer" + wrapper_path = wrapper_dir / "writer" assert wrapper_path.exists() assert wrapper_path.read_text() == '#!/bin/sh\nexec hermes -p writer "$@"\n' @@ 
-2057,14 +2067,24 @@ class TestPtyWebSocket: assert b"round-trip-payload" in buf def test_resize_escape_is_forwarded(self, monkeypatch): - # Resize escape gets intercepted and applied via TIOCSWINSZ, - # then ``tput cols/lines`` reports the new dimensions back. + # Resize escape gets intercepted and applied via TIOCSWINSZ, then the + # child reads the TTY ioctl directly. Avoid tput because CI may not set + # TERM for non-interactive shells. + import sys + + winsize_script = ( + "import fcntl, struct, termios, time; " + "time.sleep(0.15); " + "rows, cols, *_ = struct.unpack('HHHH', " + "fcntl.ioctl(0, termios.TIOCGWINSZ, b'\\0' * 8)); " + "print(cols); print(rows)" + ) monkeypatch.setattr( self.ws_module, "_resolve_chat_argv", - # sleep gives the test time to push the resize before tput runs + # sleep gives the test time to push the resize before the child reads the ioctl. lambda resume=None, sidecar_url=None: ( - ["/bin/sh", "-c", "sleep 0.15; tput cols; tput lines"], + [sys.executable, "-c", winsize_script], None, None, ), @@ -2153,13 +2173,30 @@ class TestPtyWebSocket: def test_pub_broadcasts_to_events_subscribers(self, monkeypatch): """Frame written to /api/pub is rebroadcast verbatim to every /api/events subscriber on the same channel.""" + import time from urllib.parse import urlencode + from hermes_cli import web_server as ws_mod qs = urlencode({"token": self.token, "channel": "broadcast-test"}) pub_path = f"/api/pub?{qs}" sub_path = f"/api/events?{qs}" with self.client.websocket_connect(sub_path) as sub: + # Wait for the subscriber to be registered on the server side. + # websocket_connect returns when ws.accept() completes, but the + # server adds us to ``_event_channels`` in a follow-up await, + # so a publish immediately after connect can race ahead of the + # subscriber registration and the message is dropped. 
+ deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if ws_mod._event_channels.get("broadcast-test"): + break + time.sleep(0.01) + else: + raise AssertionError( + "subscriber did not register on channel within 5s" + ) + with self.client.websocket_connect(pub_path) as pub: pub.send_text('{"type":"tool.start","payload":{"tool_id":"t1"}}') received = sub.receive_text() diff --git a/tests/openviking_plugin/test_openviking.py b/tests/openviking_plugin/test_openviking.py new file mode 100644 index 0000000000..6848afc475 --- /dev/null +++ b/tests/openviking_plugin/test_openviking.py @@ -0,0 +1,233 @@ +"""Tests for plugins/memory/openviking/__init__.py — URI normalization and payload handling.""" + +import json + +from plugins.memory.openviking import OpenVikingMemoryProvider + + +class FakeVikingClient: + def __init__(self, responses): + self.responses = responses + self.calls = [] + + def get(self, path, params=None, **kwargs): + self.calls.append((path, params or {})) + response = self.responses[(path, tuple(sorted((params or {}).items())))] + if isinstance(response, Exception): + raise response + return response + + +class TestOpenVikingSummaryUriNormalization: + def test_normalize_summary_uri_maps_pseudo_files_to_parent_directory(self): + assert OpenVikingMemoryProvider._normalize_summary_uri("viking://user/hermes/.overview.md") == "viking://user/hermes" + assert OpenVikingMemoryProvider._normalize_summary_uri("viking://resources/.abstract.md") == "viking://resources" + assert OpenVikingMemoryProvider._normalize_summary_uri("viking://") == "viking://" + assert OpenVikingMemoryProvider._normalize_summary_uri("viking://user/hermes/memories/profile.md") == "viking://user/hermes/memories/profile.md" + + +class TestOpenVikingRead: + def test_overview_read_normalizes_uri_and_unwraps_result(self): + provider = OpenVikingMemoryProvider() + provider._client = FakeVikingClient( + { + ( + "/api/v1/content/overview", + (("uri", "viking://user/hermes"),), + 
): {"result": {"content": "overview text"}}, + } + ) + + result = json.loads(provider._tool_read({"uri": "viking://user/hermes/.overview.md", "level": "overview"})) + + assert result["uri"] == "viking://user/hermes/.overview.md" + assert result["resolved_uri"] == "viking://user/hermes" + assert result["level"] == "overview" + assert result["content"] == "overview text" + assert provider._client.calls == [( + "/api/v1/content/overview", + {"uri": "viking://user/hermes"}, + )] + + def test_full_read_keeps_original_uri(self): + provider = OpenVikingMemoryProvider() + provider._client = FakeVikingClient( + { + ( + "/api/v1/content/read", + (("uri", "viking://user/hermes/memories/profile.md"),), + ): {"result": "full text"}, + } + ) + + result = json.loads(provider._tool_read({"uri": "viking://user/hermes/memories/profile.md", "level": "full"})) + + assert result["uri"] == "viking://user/hermes/memories/profile.md" + assert result["resolved_uri"] == "viking://user/hermes/memories/profile.md" + assert result["level"] == "full" + assert result["content"] == "full text" + assert provider._client.calls == [( + "/api/v1/content/read", + {"uri": "viking://user/hermes/memories/profile.md"}, + )] + + def test_overview_file_uri_routes_straight_to_content_read_via_stat_probe(self): + """Pre-check via fs/stat: file URIs skip the directory-only endpoint entirely.""" + provider = OpenVikingMemoryProvider() + file_uri = "viking://user/hermes/memories/entities/mem_abc.md" + provider._client = FakeVikingClient( + { + ( + "/api/v1/fs/stat", + (("uri", file_uri),), + ): {"result": {"isDir": False}}, + ( + "/api/v1/content/read", + (("uri", file_uri),), + ): {"result": {"content": "full content"}}, + } + ) + + result = json.loads(provider._tool_read({"uri": file_uri, "level": "overview"})) + + assert result["uri"] == file_uri + assert result["resolved_uri"] == file_uri + assert result["level"] == "overview" + assert result["fallback"] == "content/read" + assert result["content"] == "full 
content" + assert provider._client.calls == [ + ("/api/v1/fs/stat", {"uri": file_uri}), + ("/api/v1/content/read", {"uri": file_uri}), + ] + + def test_overview_dir_uri_skips_stat_when_pseudo_summary(self): + """Pseudo-URI path already resolves to dir, so no stat probe needed.""" + provider = OpenVikingMemoryProvider() + provider._client = FakeVikingClient( + { + ( + "/api/v1/content/overview", + (("uri", "viking://user/hermes"),), + ): {"result": "overview"}, + } + ) + + result = json.loads(provider._tool_read({"uri": "viking://user/hermes/.overview.md", "level": "overview"})) + + assert result["content"] == "overview" + # No fs/stat call — normalization already determined it's a directory. + assert provider._client.calls == [ + ("/api/v1/content/overview", {"uri": "viking://user/hermes"}), + ] + + def test_overview_directory_uri_uses_stat_probe_then_overview(self): + """Non-pseudo directory URI: stat → isDir=True → summary endpoint.""" + provider = OpenVikingMemoryProvider() + dir_uri = "viking://user/hermes/memories" + provider._client = FakeVikingClient( + { + ( + "/api/v1/fs/stat", + (("uri", dir_uri),), + ): {"result": {"isDir": True}}, + ( + "/api/v1/content/overview", + (("uri", dir_uri),), + ): {"result": "dir overview"}, + } + ) + + result = json.loads(provider._tool_read({"uri": dir_uri, "level": "overview"})) + + assert result["content"] == "dir overview" + assert "fallback" not in result + assert provider._client.calls == [ + ("/api/v1/fs/stat", {"uri": dir_uri}), + ("/api/v1/content/overview", {"uri": dir_uri}), + ] + + def test_overview_file_uri_falls_back_via_exception_when_stat_indeterminate(self): + """If fs/stat raises or returns unknown shape, legacy exception fallback still kicks in.""" + provider = OpenVikingMemoryProvider() + file_uri = "viking://user/hermes/memories/entities/mem_abc.md" + provider._client = FakeVikingClient( + { + ( + "/api/v1/fs/stat", + (("uri", file_uri),), + ): RuntimeError("stat unavailable"), + ( + 
"/api/v1/content/overview", + (("uri", file_uri),), + ): RuntimeError("500 Internal Server Error"), + ( + "/api/v1/content/read", + (("uri", file_uri),), + ): {"result": {"content": "fallback full content"}}, + } + ) + + result = json.loads(provider._tool_read({"uri": file_uri, "level": "overview"})) + + assert result["uri"] == file_uri + assert result["level"] == "overview" + assert result["fallback"] == "content/read" + assert result["content"] == "fallback full content" + assert provider._client.calls == [ + ("/api/v1/fs/stat", {"uri": file_uri}), + ("/api/v1/content/overview", {"uri": file_uri}), + ("/api/v1/content/read", {"uri": file_uri}), + ] + + def test_summary_uri_error_does_not_fallback_and_raises(self): + provider = OpenVikingMemoryProvider() + provider._client = FakeVikingClient( + { + ( + "/api/v1/content/overview", + (("uri", "viking://user/hermes"),), + ): RuntimeError("500 Internal Server Error"), + } + ) + + try: + provider._tool_read({"uri": "viking://user/hermes/.overview.md", "level": "overview"}) + assert False, "Expected summary endpoint error to be raised" + except RuntimeError: + pass + + assert provider._client.calls == [ + ("/api/v1/content/overview", {"uri": "viking://user/hermes"}), + ] + + +class TestOpenVikingBrowse: + def test_list_browse_unwraps_and_normalizes_entry_shapes(self): + provider = OpenVikingMemoryProvider() + provider._client = FakeVikingClient( + { + ( + "/api/v1/fs/ls", + (("uri", "viking://user/hermes"),), + ): { + "result": { + "entries": [ + {"name": "memories", "uri": "viking://user/hermes/memories", "type": "dir"}, + {"rel_path": "profile.md", "uri": "viking://user/hermes/memories/profile.md", "isDir": False, "abstract": "Profile"}, + ] + } + }, + } + ) + + result = json.loads(provider._tool_browse({"action": "list", "path": "viking://user/hermes"})) + + assert result["path"] == "viking://user/hermes" + assert result["entries"] == [ + {"name": "memories", "uri": "viking://user/hermes/memories", "type": "dir", 
"abstract": ""}, + {"name": "profile.md", "uri": "viking://user/hermes/memories/profile.md", "type": "file", "abstract": "Profile"}, + ] + assert provider._client.calls == [( + "/api/v1/fs/ls", + {"uri": "viking://user/hermes"}, + )] diff --git a/tests/plugins/memory/test_hindsight_provider.py b/tests/plugins/memory/test_hindsight_provider.py index 4d363db326..334e6ab5ea 100644 --- a/tests/plugins/memory/test_hindsight_provider.py +++ b/tests/plugins/memory/test_hindsight_provider.py @@ -669,7 +669,7 @@ class TestSyncTurn: p._client = _make_mock_client() p.sync_turn("hello", "hi there") - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() p._client.aretain_batch.assert_called_once() call_kwargs = p._client.aretain_batch.call_args.kwargs @@ -710,8 +710,7 @@ class TestSyncTurn: def test_sync_turn_with_tags(self, provider_with_config): p = provider_with_config(retain_tags=["conv", "session1"]) p.sync_turn("hello", "hi") - if p._sync_thread: - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() item = p._client.aretain_batch.call_args.kwargs["items"][0] assert "conv" in item["tags"] assert "session1" in item["tags"] @@ -720,8 +719,7 @@ class TestSyncTurn: def test_sync_turn_uses_aretain_batch(self, provider): """sync_turn should use aretain_batch with retain_async.""" provider.sync_turn("hello", "hi") - if provider._sync_thread: - provider._sync_thread.join(timeout=5.0) + provider._retain_queue.join() provider._client.aretain_batch.assert_called_once() call_kwargs = provider._client.aretain_batch.call_args.kwargs assert call_kwargs["document_id"].startswith("test-session-") @@ -732,8 +730,7 @@ class TestSyncTurn: def test_sync_turn_custom_context(self, provider_with_config): p = provider_with_config(retain_context="my-agent") p.sync_turn("hello", "hi") - if p._sync_thread: - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() item = p._client.aretain_batch.call_args.kwargs["items"][0] assert item["context"] == "my-agent" @@ -744,7 +741,7 @@ class 
TestSyncTurn: p.sync_turn("turn2-user", "turn2-asst") assert p._sync_thread is None p.sync_turn("turn3-user", "turn3-asst") - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() p._client.aretain_batch.assert_called_once() call_kwargs = p._client.aretain_batch.call_args.kwargs assert call_kwargs["document_id"].startswith("test-session-") @@ -765,15 +762,13 @@ class TestSyncTurn: p.sync_turn("turn1-user", "turn1-asst") p.sync_turn("turn2-user", "turn2-asst") - if p._sync_thread: - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() p._client.aretain_batch.reset_mock() p.sync_turn("turn3-user", "turn3-asst") p.sync_turn("turn4-user", "turn4-asst") - if p._sync_thread: - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"] # Should contain ALL turns from the session @@ -785,8 +780,7 @@ class TestSyncTurn: def test_sync_turn_passes_document_id(self, provider): """sync_turn should pass document_id (session_id + per-startup ts).""" provider.sync_turn("hello", "hi") - if provider._sync_thread: - provider._sync_thread.join(timeout=5.0) + provider._retain_queue.join() call_kwargs = provider._client.aretain_batch.call_args.kwargs # Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds} assert call_kwargs["document_id"].startswith("test-session-") @@ -819,8 +813,7 @@ class TestSyncTurn: def test_sync_turn_session_tag(self, provider): """Each retain should be tagged with session: for filtering.""" provider.sync_turn("hello", "hi") - if provider._sync_thread: - provider._sync_thread.join(timeout=5.0) + provider._retain_queue.join() item = provider._client.aretain_batch.call_args.kwargs["items"][0] assert "session:test-session" in item["tags"] @@ -841,8 +834,7 @@ class TestSyncTurn: ) p._client = _make_mock_client() p.sync_turn("hello", "hi") - if p._sync_thread: - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() item = p._client.aretain_batch.call_args.kwargs["items"][0] assert 
"session:child-session" in item["tags"] @@ -851,15 +843,14 @@ class TestSyncTurn: def test_sync_turn_error_does_not_raise(self, provider): provider._client.aretain_batch.side_effect = RuntimeError("network error") provider.sync_turn("hello", "hi") - if provider._sync_thread: - provider._sync_thread.join(timeout=5.0) + provider._retain_queue.join() def test_sync_turn_preserves_unicode(self, provider_with_config): """Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact.""" p = provider_with_config() p._client = _make_mock_client() p.sync_turn("안녕 こんにちは 你好", "👨‍👩‍👧‍👦 family") - p._sync_thread.join(timeout=5.0) + p._retain_queue.join() p._client.aretain_batch.assert_called_once() item = p._client.aretain_batch.call_args.kwargs["items"][0] # ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON, @@ -871,6 +862,216 @@ class TestSyncTurn: assert "👨‍👩‍👧‍👦" in raw_json +# --------------------------------------------------------------------------- +# Shutdown / writer tests +# --------------------------------------------------------------------------- + + +class TestShutdownRace: + def test_sync_turn_uses_single_writer_thread(self, provider): + """All retains run through one long-lived writer thread.""" + provider.sync_turn("a", "b") + provider._retain_queue.join() + first_writer = provider._writer_thread + assert first_writer is not None + assert first_writer.is_alive() + + provider.sync_turn("c", "d") + provider._retain_queue.join() + # Same thread reused — no ad-hoc thread per call. + assert provider._writer_thread is first_writer + assert provider._client.aretain_batch.call_count == 2 + + def test_sync_turn_after_shutdown_is_dropped(self, provider): + """Once shutdown has fired, new sync_turn() calls are no-ops. + + This is the core of the fix: the plugin must not enqueue a retain + during interpreter teardown — that's what causes the + 'cannot schedule new futures' RuntimeError + unclosed aiohttp + sessions on CLI exit. 
+ """ + client = provider._client + provider.shutdown() + before_calls = client.aretain_batch.call_count + provider.sync_turn("late", "turn") + # No new enqueue — the retain queue stays empty. + assert provider._retain_queue.empty() + # And no new client call (would be impossible anyway since shutdown + # nulled self._client; we assert via the captured handle). + assert client.aretain_batch.call_count == before_calls + + def test_queue_prefetch_after_shutdown_is_dropped(self, provider): + provider.shutdown() + provider.queue_prefetch("late query") + assert provider._prefetch_thread is None + + def test_shutdown_drains_pending_retains(self, provider): + """Shutdown must wait for queued retains to complete, not abandon them. + + Otherwise the LAST in-flight turn — typically the most important — + is silently lost. + """ + client = provider._client + provider.sync_turn("a", "b") + provider.sync_turn("c", "d") + provider.shutdown() + # Both retains drained before shutdown returned. + assert client.aretain_batch.call_count == 2 + assert provider._retain_queue.empty() + + def test_shutdown_is_idempotent(self, provider): + provider.sync_turn("a", "b") + provider.shutdown() + # Second shutdown shouldn't blow up or re-close the client. + provider.shutdown() + assert provider._shutting_down.is_set() + + +# --------------------------------------------------------------------------- +# on_session_switch — flush + prefetch reset behavior +# --------------------------------------------------------------------------- + + +class TestSessionSwitchBufferFlush: + def test_buffered_turns_flushed_before_clear(self, provider_with_config): + """retain_every_n_turns > 1 must not silently drop partial buffers + on session switch. 
Whatever's in _session_turns at switch time + should land in the OLD document under the OLD session id.""" + p = provider_with_config(retain_every_n_turns=3, retain_async=False) + old_doc = p._document_id + + # Two turns buffered, no retain yet (boundary is at turn 3). The + # writer hasn't been started either — sync_turn's early return + # skips _ensure_writer when no retain is due. + p.sync_turn("turn1-user", "turn1-asst") + p.sync_turn("turn2-user", "turn2-asst") + assert p._sync_thread is None + p._client.aretain_batch.assert_not_called() + + # Switch — flush should fire under OLD document_id via the writer queue. + p.on_session_switch("new-sid", parent_session_id="test-session", reset=True) + p._retain_queue.join() + + p._client.aretain_batch.assert_called_once() + kw = p._client.aretain_batch.call_args.kwargs + assert kw["document_id"] == old_doc + item = kw["items"][0] + # Both buffered turns must be present in the flushed payload. + content = json.loads(item["content"]) + flat = json.dumps(content) + assert "turn1-user" in flat + assert "turn2-user" in flat + # Old session id must appear in lineage tags / metadata. + assert "session:test-session" in item["tags"] + assert item["metadata"]["session_id"] == "test-session" + + # And the new session must start with a clean slate. + assert p._session_id == "new-sid" + assert p._session_turns == [] + assert p._turn_counter == 0 + assert p._document_id != old_doc + assert p._document_id.startswith("new-sid-") + + def test_no_flush_when_buffer_empty(self, provider): + """Switch with no buffered turns must not fire a spurious retain.""" + provider.on_session_switch("new-sid") + # Nothing enqueued — join is immediate. 
+ provider._retain_queue.join() + provider._client.aretain_batch.assert_not_called() + assert provider._session_id == "new-sid" + + def test_prefetch_result_cleared_on_switch(self, provider): + """Stale recall text from the old session must not leak into the + next session's first prefetch read.""" + provider._prefetch_result = "old-session recall: User likes Rust" + provider.on_session_switch("new-sid") + assert provider._prefetch_result == "" + # And subsequent prefetch() should now report empty, not the leftover. + assert provider.prefetch("anything") == "" + + def test_in_flight_prefetch_thread_drained_on_switch(self, provider, monkeypatch): + """on_session_switch must wait for an in-flight prefetch from the + old session to settle before clearing _prefetch_result, otherwise + the thread can race and re-populate the field after the clear.""" + import threading + import time as _time + + gate = threading.Event() + finished = threading.Event() + + def _slow_prefetch(): + gate.wait(timeout=5.0) + with provider._prefetch_lock: + provider._prefetch_result = "old-session recall" + finished.set() + + provider._prefetch_thread = threading.Thread(target=_slow_prefetch, daemon=True) + provider._prefetch_thread.start() + + # Release the prefetch worker so it writes _prefetch_result, then + # call on_session_switch — it must join the thread before clearing. + gate.set() + provider.on_session_switch("new-sid") + + assert finished.is_set(), "switch returned before prefetch thread settled" + assert provider._prefetch_result == "" + + def test_flush_serializes_behind_pending_retains_via_writer_queue( + self, provider_with_config + ): + """The flush closure must ride the same _retain_queue sync_turn + uses, so it lands FIFO behind any still-queued old-session + retains rather than racing them on a separate thread. + + Regression guard: an earlier draft spawned a raw threading.Thread + for flush, overwriting _sync_thread and racing the writer against + the same document_id. 
+ """ + import threading as _threading + + p = provider_with_config(retain_every_n_turns=2, retain_async=False) + + # Block the first writer job until we've enqueued the flush + # behind it. This proves ordering — the flush MUST wait. + gate = _threading.Event() + call_order: list[str] = [] + + def _aretain_batch_tracking(**kw): + idx = kw["items"][0]["metadata"].get("turn_index", "") + call_order.append(str(idx)) + if idx == "2": + # First retain blocks until we've enqueued the flush. + gate.wait(timeout=5.0) + + p._client.aretain_batch = AsyncMock(side_effect=_aretain_batch_tracking) + + # Turn 1+2 → boundary hit → retain enqueued (will block). + p.sync_turn("turn1-user", "turn1-asst") + p.sync_turn("turn2-user", "turn2-asst") + + # One more buffered turn so flush has something to land. + p.sync_turn("turn3-user", "turn3-asst") + + # Switch while the first retain is still blocked on `gate`. + p.on_session_switch("new-sid", parent_session_id="test-session") + + # Release the first retain. Flush must have been enqueued + # BEHIND it, and run second. + gate.set() + p._retain_queue.join() + + # The flush carries all buffered turns; sync_turn's retain #2 + # carried the batch at boundary time. Two distinct calls. + assert p._client.aretain_batch.call_count == 2 + # First call landed while buffer was [t1, t2]; flush landed + # after we added t3. So the second call must be strictly after. + assert call_order[0] == "2" + # Flush retain has turn_index matching the buffered count at + # switch time (3 turns accumulated, _turn_index was set to 3 + # by the last sync_turn). 
+ assert call_order[1] == "3" + + # --------------------------------------------------------------------------- # System prompt tests # --------------------------------------------------------------------------- diff --git a/tests/plugins/test_achievements_plugin.py b/tests/plugins/test_achievements_plugin.py new file mode 100644 index 0000000000..782aea7b39 --- /dev/null +++ b/tests/plugins/test_achievements_plugin.py @@ -0,0 +1,377 @@ +"""Tests for the bundled hermes-achievements dashboard plugin. + +These target the two behaviors that matter for official integration: + +* The 200-session scan cap is removed — the plugin now walks the entire + session history by default. Lifetime badges (tens of thousands of + tool calls) were unreachable before this fix on long-running installs. +* First-ever scans run in a background thread so the dashboard request + path never blocks, even on 8000+ session databases where a cold scan + takes minutes. + +The upstream repo ships its own unittest suite under +``plugins/hermes-achievements/tests/`` covering the achievement engine +internals (tier math, secret-state handling, catalog invariants). These +tests live at the hermes-agent level and focus on the integration +contract: the plugin scans ALL of your sessions, not the first 200. +""" +from __future__ import annotations + +import importlib.util +import sys +import threading +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest + +PLUGIN_MODULE_PATH = ( + Path(__file__).resolve().parents[2] + / "plugins" + / "hermes-achievements" + / "dashboard" + / "plugin_api.py" +) + + +@pytest.fixture +def plugin_api(tmp_path, monkeypatch): + """Load plugin_api with isolated ~/.hermes so state/snapshot files don't collide. + + We load the module fresh per test because the plugin keeps module-level + caches (``_SNAPSHOT_CACHE``, ``_SCAN_STATUS``, background thread handle). + Reloading gives each test a clean world. 
+ """ + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + spec = importlib.util.spec_from_file_location( + f"plugin_api_test_{id(tmp_path)}", PLUGIN_MODULE_PATH + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Stash monkeypatch so ``_install_fake_session_db`` can use it to + # swap ``sys.modules['hermes_state']`` with auto-restoration. Without + # this, a raw ``sys.modules[...] = fake`` assignment would leak the + # fake into later tests in the same xdist worker — breaking every + # test that does ``from hermes_state import SessionDB``. + module._test_monkeypatch = monkeypatch + yield module + + +class _FakeSessionDB: + """Stand-in for hermes_state.SessionDB that records scan calls.""" + + def __init__(self, session_count: int): + self.session_count = session_count + self.last_limit: Optional[int] = None + self.last_include_children: Optional[bool] = None + self.list_calls = 0 + self.messages_calls = 0 + + def list_sessions_rich( + self, + source: Optional[str] = None, + exclude_sources: Optional[List[str]] = None, + limit: int = 20, + offset: int = 0, + include_children: bool = False, + project_compression_tips: bool = True, + ) -> List[Dict[str, Any]]: + self.last_limit = limit + self.last_include_children = include_children + self.list_calls += 1 + # SQLite semantics: LIMIT -1 = unlimited. Honor that here. 
+ effective = self.session_count if limit == -1 else min(self.session_count, limit) + now = int(time.time()) + return [ + { + "id": f"sess-{i}", + "title": f"Session {i}", + "preview": f"preview {i}", + "started_at": now - (self.session_count - i) * 60, + "last_active": now - (self.session_count - i) * 60 + 30, + "source": "cli", + "model": "test-model", + } + for i in range(effective) + ] + + def get_messages(self, session_id: str) -> List[Dict[str, Any]]: + self.messages_calls += 1 + return [ + {"role": "user", "content": f"ask {session_id}"}, + { + "role": "assistant", + "tool_calls": [{"function": {"name": "terminal"}}], + }, + {"role": "tool", "tool_name": "terminal", "content": "ok"}, + ] + + def close(self) -> None: + pass + + +def _install_fake_session_db(plugin_api, fake_db): + """Inject a fake SessionDB so ``scan_sessions`` finds it via its local import. + + Uses the monkeypatch stashed on ``plugin_api`` by the fixture, so the + ``sys.modules['hermes_state']`` swap is auto-restored at test teardown + and cannot leak into unrelated tests in the same xdist worker. + """ + fake_module = type(sys)("hermes_state") + fake_module.SessionDB = lambda: fake_db + plugin_api._test_monkeypatch.setitem(sys.modules, "hermes_state", fake_module) + + +def test_scan_sessions_default_scans_all_history_not_first_200(plugin_api): + """Bug regression: ``scan_sessions()`` used to cap at limit=200. + + A user with 8000+ sessions would only see ~2% of their history in + achievement totals, making lifetime badges unreachable. The default + now passes ``LIMIT -1`` (SQLite "unlimited") to ``list_sessions_rich``. 
+ """ + fake_db = _FakeSessionDB(session_count=500) # > old 200 cap + _install_fake_session_db(plugin_api, fake_db) + + result = plugin_api.scan_sessions() + + assert fake_db.last_limit == -1, ( + "scan_sessions() must pass LIMIT=-1 (unlimited) to list_sessions_rich " + f"by default, got {fake_db.last_limit}" + ) + assert fake_db.last_include_children is True, ( + "scan_sessions() must include subagent/compression child sessions so " + "tool calls made in delegated agents still count toward achievements" + ) + assert len(result["sessions"]) == 500 + assert result["scan_meta"]["sessions_total"] == 500 + + +def test_scan_sessions_explicit_positive_limit_is_honored(plugin_api): + """Callers can still pass a small limit for smoke tests.""" + fake_db = _FakeSessionDB(session_count=500) + _install_fake_session_db(plugin_api, fake_db) + + result = plugin_api.scan_sessions(limit=10) + + assert fake_db.last_limit == 10 + assert len(result["sessions"]) == 10 + + +def test_scan_sessions_zero_or_negative_limit_means_unlimited(plugin_api): + """``limit=0`` and ``limit=-1`` both map to the unlimited path.""" + fake_db = _FakeSessionDB(session_count=300) + _install_fake_session_db(plugin_api, fake_db) + + plugin_api.scan_sessions(limit=0) + assert fake_db.last_limit == -1 + + plugin_api.scan_sessions(limit=-1) + assert fake_db.last_limit == -1 + + +def test_evaluate_all_first_run_returns_pending_and_starts_background_scan(plugin_api): + """First-ever evaluate_all with no cache returns a pending placeholder + immediately and kicks off a background scan thread. Cold scans on + large DBs take minutes — blocking the dashboard request path is not + acceptable. + """ + fake_db = _FakeSessionDB(session_count=50) + _install_fake_session_db(plugin_api, fake_db) + + # Wrap _run_scan_and_update_cache so we can release it on demand, + # simulating a slow cold scan without actually waiting. 
+ scan_started = threading.Event() + allow_scan_finish = threading.Event() + original_run = plugin_api._run_scan_and_update_cache + + def gated_run(*args, **kwargs): + scan_started.set() + allow_scan_finish.wait(timeout=5) + original_run(*args, **kwargs) + + plugin_api._run_scan_and_update_cache = gated_run + + t0 = time.time() + result = plugin_api.evaluate_all() + elapsed = time.time() - t0 + + # Immediate return — should not block waiting for the scan. + assert elapsed < 1.0, f"evaluate_all blocked for {elapsed:.2f}s on first run" + assert result["scan_meta"]["mode"] == "pending" + assert result["unlocked_count"] == 0 + # Catalog still rendered so UI has something to draw. + assert result["total_count"] >= 60 + + # Background scan is running. + assert scan_started.wait(timeout=2), "background scan did not start" + + # Let the scan complete, then a second call returns real data. + allow_scan_finish.set() + # Wait for thread to finish. + thread = plugin_api._BACKGROUND_SCAN_THREAD + assert thread is not None + thread.join(timeout=5) + assert not thread.is_alive() + + second = plugin_api.evaluate_all() + assert second["scan_meta"]["mode"] != "pending" + assert second["scan_meta"].get("sessions_total") == 50 + + +def test_evaluate_all_stale_cache_serves_stale_and_refreshes_in_background(plugin_api): + """When the snapshot is on-disk but older than TTL, evaluate_all returns + the stale data immediately and kicks a background refresh. Users don't + stare at a loading spinner every time TTL expires. + """ + fake_db = _FakeSessionDB(session_count=10) + _install_fake_session_db(plugin_api, fake_db) + + # Seed a stale snapshot on disk. 
+ stale_generated_at = int(time.time()) - plugin_api.SNAPSHOT_TTL_SECONDS - 60 + stale_payload = { + "achievements": [], + "sessions": [], + "aggregate": {}, + "scan_meta": {"mode": "full", "sessions_total": 1, "sessions_rescanned": 1, "sessions_reused": 0}, + "error": None, + "unlocked_count": 0, + "discovered_count": 0, + "secret_count": 0, + "total_count": 0, + "generated_at": stale_generated_at, + } + plugin_api.save_snapshot(stale_payload) + + t0 = time.time() + result = plugin_api.evaluate_all() + elapsed = time.time() - t0 + + assert elapsed < 1.0, f"evaluate_all blocked for {elapsed:.2f}s serving stale data" + assert result["generated_at"] == stale_generated_at + + # Background scan should be running or have completed. + thread = plugin_api._BACKGROUND_SCAN_THREAD + assert thread is not None + thread.join(timeout=5) + + fresh = plugin_api.evaluate_all() + assert fresh["generated_at"] >= stale_generated_at + + +def test_evaluate_all_force_runs_synchronously(plugin_api): + """Manual /rescan (force=True) blocks the caller — users clicking + the rescan button expect up-to-date data when the call returns. + """ + fake_db = _FakeSessionDB(session_count=25) + _install_fake_session_db(plugin_api, fake_db) + + result = plugin_api.evaluate_all(force=True) + + # Synchronous — snapshot is fresh on return. 
+ assert result["scan_meta"].get("sessions_total") == 25 + assert result["scan_meta"]["mode"] in ("full", "incremental") + + +def test_start_background_scan_is_idempotent_while_running(plugin_api): + """Multiple concurrent dashboard requests must not spawn duplicate scans.""" + fake_db = _FakeSessionDB(session_count=5) + _install_fake_session_db(plugin_api, fake_db) + + release = threading.Event() + original_run = plugin_api._run_scan_and_update_cache + + def gated_run(*args, **kwargs): + release.wait(timeout=5) + original_run(*args, **kwargs) + + plugin_api._run_scan_and_update_cache = gated_run + + plugin_api._start_background_scan() + first_thread = plugin_api._BACKGROUND_SCAN_THREAD + assert first_thread is not None and first_thread.is_alive() + + plugin_api._start_background_scan() + plugin_api._start_background_scan() + + assert plugin_api._BACKGROUND_SCAN_THREAD is first_thread + + release.set() + first_thread.join(timeout=5) + + +def test_background_scan_publishes_partial_snapshots(plugin_api): + """The background scanner publishes intermediate snapshots to the cache + every ~N sessions. Each dashboard refresh during a long cold scan sees + more badges unlocked instead of staring at zeros for minutes and then + having everything pop at the end. + """ + fake_db = _FakeSessionDB(session_count=750) + _install_fake_session_db(plugin_api, fake_db) + + # Record every partial snapshot the scanner publishes. + partial_snapshots: List[Dict[str, Any]] = [] + original_compute_from_scan = plugin_api._compute_from_scan + + def recording_compute(scan, *, is_partial=False): + result = original_compute_from_scan(scan, is_partial=is_partial) + if is_partial: + partial_snapshots.append(result) + return result + + plugin_api._compute_from_scan = recording_compute + + # scan 750 sessions with progress_every=250 → expect 2 intermediate + # publications (at 250 and 500; the final 750 call goes through the + # finished, non-partial path). 
+ plugin_api._run_scan_and_update_cache(publish_partial_snapshots=True) + + assert len(partial_snapshots) >= 2, ( + f"expected at least 2 partial publications on a 750-session scan with " + f"progress_every=250, got {len(partial_snapshots)}" + ) + # Partial snapshots should report growing session counts. + counts = [p["scan_meta"].get("sessions_scanned_so_far") for p in partial_snapshots] + assert counts == sorted(counts), f"partial session counts not monotonic: {counts}" + assert counts[0] < 750 and counts[-1] < 750, ( + f"partial counts should be less than the final total; got {counts}" + ) + # Every partial reports the expected end-state total so the UI can + # show an accurate progress bar. + for p in partial_snapshots: + assert p["scan_meta"].get("sessions_expected_total") == 750 + + # Final snapshot in cache is the real (non-partial) one. + final = plugin_api._SNAPSHOT_CACHE + assert final is not None + assert final["scan_meta"].get("mode") != "in_progress" + assert final["scan_meta"].get("sessions_total") == 750 + + +def test_partial_snapshots_do_not_persist_unlock_timestamps(plugin_api): + """Intermediate snapshots must not write to state.json — an unlock + that appears at 30% scan progress could disappear when a later session + rebalances the aggregate. Only the final snapshot records ``unlocked_at``. + """ + fake_db = _FakeSessionDB(session_count=10) + _install_fake_session_db(plugin_api, fake_db) + + # Seed empty state, then invoke partial compute directly. + plugin_api.save_state({"unlocks": {}}) + partial_scan = { + "sessions": [{"session_id": "x", "tool_call_count": 99999, "tool_names": set()}], + "aggregate": {"max_tool_calls_in_session": 99999, "total_tool_calls": 99999}, + "scan_meta": {"mode": "in_progress"}, + } + result = plugin_api._compute_from_scan(partial_scan, is_partial=True) + + # Some achievements should evaluate as unlocked in this aggregate... 
+ assert any(a["unlocked"] for a in result["achievements"]) + + # ...but state.json on disk stays empty (no timestamps were recorded). + persisted = plugin_api.load_state() + assert persisted.get("unlocks", {}) == {}, ( + "partial scans must not record unlock timestamps — a later session " + "could change whether the badge deserves to be unlocked yet" + ) diff --git a/tests/run_agent/test_anthropic_prompt_cache_policy.py b/tests/run_agent/test_anthropic_prompt_cache_policy.py index 7a85022a5c..b8a380a62e 100644 --- a/tests/run_agent/test_anthropic_prompt_cache_policy.py +++ b/tests/run_agent/test_anthropic_prompt_cache_policy.py @@ -89,15 +89,75 @@ class TestThirdPartyAnthropicGateway: assert should is True, "Third-party Anthropic gateway with Claude must cache" assert native is True, "Third-party Anthropic gateway uses native cache_control layout" - def test_third_party_without_claude_name_does_not_cache(self): - # A provider exposing e.g. GLM via anthropic_messages transport — we - # don't know whether it supports cache_control, so stay conservative. + def test_third_party_anthropic_non_claude_unknown_provider_does_not_cache(self): + # A provider exposing e.g. GLM via anthropic_messages transport from + # a host we don't recognize — we don't know whether it supports + # cache_control, so stay conservative. + agent = _make_agent( + provider="custom", + base_url="https://some-unknown-gateway.example.com/anthropic", + api_mode="anthropic_messages", + model="glm-4.5", + ) + assert agent._anthropic_prompt_cache_policy() == (False, False) + + +class TestMiniMaxAnthropicWire: + """MiniMax's own model family on its Anthropic-compatible endpoint. + + MiniMax documents cache_control support on ``/anthropic`` (0.1× read + pricing, 5-minute TTL). Issue #17332: the blanket ``is_claude`` gate on + the third-party-gateway branch left MiniMax-M2.7 etc. paying full input + cost every turn. Allowlist MiniMax explicitly via provider id or host. 
+ """ + + def test_minimax_m27_on_provider_minimax_caches_native_layout(self): + agent = _make_agent( + provider="minimax", + base_url="https://api.minimax.io/anthropic", + api_mode="anthropic_messages", + model="minimax-m2.7", + ) + assert agent._anthropic_prompt_cache_policy() == (True, True) + + def test_minimax_m25_on_provider_minimax_cn_caches_native_layout(self): + agent = _make_agent( + provider="minimax-cn", + base_url="https://api.minimaxi.com/anthropic", + api_mode="anthropic_messages", + model="minimax-m2.5", + ) + assert agent._anthropic_prompt_cache_policy() == (True, True) + + def test_custom_provider_pointed_at_minimax_host_caches(self): + # User wires a custom provider manually at MiniMax's Anthropic URL; + # host match alone should be sufficient to enable caching. agent = _make_agent( provider="custom", base_url="https://api.minimax.io/anthropic", api_mode="anthropic_messages", model="minimax-m2.7", ) + assert agent._anthropic_prompt_cache_policy() == (True, True) + + def test_minimax_host_china_endpoint_caches(self): + agent = _make_agent( + provider="custom", + base_url="https://api.minimaxi.com/anthropic", + api_mode="anthropic_messages", + model="minimax-m2.1", + ) + assert agent._anthropic_prompt_cache_policy() == (True, True) + + def test_minimax_provider_on_openai_wire_does_not_cache(self): + # chat_completions transport — MiniMax's cache_control support is + # documented only for the /anthropic endpoint. Stay off. 
+ agent = _make_agent( + provider="minimax", + base_url="https://api.minimax.io/v1", + api_mode="chat_completions", + model="minimax-m2.7", + ) assert agent._anthropic_prompt_cache_policy() == (False, False) diff --git a/tests/run_agent/test_background_review_toolset_restriction.py b/tests/run_agent/test_background_review_toolset_restriction.py index 0ee3324872..d1193dc6f9 100644 --- a/tests/run_agent/test_background_review_toolset_restriction.py +++ b/tests/run_agent/test_background_review_toolset_restriction.py @@ -8,12 +8,10 @@ effects (terminal, send_message, delegate_task, etc.). import threading from unittest.mock import patch -from run_agent import AIAgent - -def _make_agent_stub(): +def _make_agent_stub(agent_cls): """Create a minimal AIAgent-like object with just enough state for _spawn_background_review.""" - agent = object.__new__(AIAgent) + agent = object.__new__(agent_cls) agent.model = "test-model" agent.platform = "test" agent.provider = "openai" @@ -45,14 +43,16 @@ class _SyncThread: def test_background_review_agent_uses_restricted_toolsets(): """The review agent must only have access to 'memory' and 'skills' toolsets.""" - agent = _make_agent_stub() + import run_agent + + agent = _make_agent_stub(run_agent.AIAgent) captured = {} def _capture_init(self, *args, **kwargs): captured["enabled_toolsets"] = kwargs.get("enabled_toolsets") raise RuntimeError("stop after capturing init args") - with patch.object(AIAgent, "__init__", _capture_init), \ + with patch.object(run_agent.AIAgent, "__init__", _capture_init), \ patch("threading.Thread", _SyncThread): agent._spawn_background_review( messages_snapshot=[], diff --git a/tests/run_agent/test_memory_sync_interrupted.py b/tests/run_agent/test_memory_sync_interrupted.py index 32313740dc..feeb028927 100644 --- a/tests/run_agent/test_memory_sync_interrupted.py +++ b/tests/run_agent/test_memory_sync_interrupted.py @@ -31,6 +31,10 @@ def _bare_agent(): agent = AIAgent.__new__(AIAgent) agent._memory_manager = 
MagicMock() + # session_id is now propagated into sync_all / queue_prefetch_all so + # providers that cache per-session state can update it mid-process + # (see #6672). + agent.session_id = "test_session_001" return agent @@ -80,9 +84,11 @@ class TestSyncExternalMemoryForTurn: ) agent._memory_manager.sync_all.assert_called_once_with( "What's the weather in Paris?", "It's sunny and 22°C.", + session_id="test_session_001", ) agent._memory_manager.queue_prefetch_all.assert_called_once_with( "What's the weather in Paris?", + session_id="test_session_001", ) # --- Edge cases (pre-existing behaviour preserved) ------------------ diff --git a/tests/run_agent/test_provider_parity.py b/tests/run_agent/test_provider_parity.py index 5ad40e8a88..8eb7478b41 100644 --- a/tests/run_agent/test_provider_parity.py +++ b/tests/run_agent/test_provider_parity.py @@ -144,6 +144,36 @@ class TestBuildApiKwargsOpenRouter: assert messages[1]["tool_calls"][0]["response_item_id"] == "fc_123" assert "codex_reasoning_items" in messages[1] + def test_gemini_native_passes_base_url_for_top_level_thinking_config(self, monkeypatch): + agent = _make_agent( + monkeypatch, + "gemini", + base_url="https://generativelanguage.googleapis.com/v1beta", + model="gemini-3-flash-preview", + ) + agent.reasoning_config = {"enabled": True, "effort": "high"} + kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}]) + assert kwargs["extra_body"]["thinking_config"] == { + "includeThoughts": True, + "thinkingLevel": "high", + } + assert "extra_body" not in kwargs["extra_body"] + + def test_gemini_openai_compat_passes_base_url_for_nested_google_thinking_config(self, monkeypatch): + agent = _make_agent( + monkeypatch, + "gemini", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", + model="gemini-3.1-pro-preview", + ) + agent.reasoning_config = {"enabled": True, "effort": "high"} + kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}]) + assert "thinking_config" not in 
kwargs["extra_body"] + assert kwargs["extra_body"]["extra_body"]["google"]["thinking_config"] == { + "include_thoughts": True, + "thinking_level": "high", + } + def test_should_sanitize_tool_calls_codex_vs_chat(self, monkeypatch): """Codex API should NOT sanitize, all other APIs should sanitize.""" # Codex mode should NOT need sanitization @@ -936,17 +966,25 @@ class TestAuxiliaryClientProviderPriority: client, model = get_text_auxiliary_client() assert mock.call_args.kwargs["base_url"] == "http://localhost:1234/v1" - def test_codex_fallback_last_resort(self, monkeypatch): + def test_codex_not_in_auto_fallback(self, monkeypatch): + """Codex is deliberately NOT part of the auto fallback chain. + + ChatGPT-account Codex gates which models it accepts via an + undocumented, shifting allow-list, so falling through to Codex with + a hardcoded default model breaks silently whenever OpenAI rotates + the list. When nothing else is available, ``get_text_auxiliary_client`` + now returns (None, None) rather than guessing a Codex model. 
+ """ monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) monkeypatch.delenv("OPENAI_BASE_URL", raising=False) monkeypatch.delenv("OPENAI_API_KEY", raising=False) - from agent.auxiliary_client import get_text_auxiliary_client, CodexAuxiliaryClient + from agent.auxiliary_client import get_text_auxiliary_client with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \ patch("agent.auxiliary_client._read_codex_access_token", return_value="codex-tok"), \ patch("agent.auxiliary_client.OpenAI"): client, model = get_text_auxiliary_client() - assert model == "gpt-5.2-codex" - assert isinstance(client, CodexAuxiliaryClient) + assert client is None + assert model is None # ── Provider routing tests ─────────────────────────────────────────────────── diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 2c13a1569f..5585eea484 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -862,6 +862,26 @@ class TestBuildSystemPrompt: prompt = agent._build_system_prompt() assert DEFAULT_AGENT_IDENTITY in prompt + def test_can_use_soul_identity_even_when_context_files_are_skipped(self): + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("terminal")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + patch("run_agent.load_soul_md", return_value="SOUL IDENTITY"), + ): + agent = AIAgent( + api_key="test-k...7890", + base_url="https://openrouter.ai/api/v1", + quiet_mode=True, + skip_context_files=True, + load_soul_identity=True, + skip_memory=True, + ) + prompt = agent._build_system_prompt() + + assert "SOUL IDENTITY" in prompt + assert DEFAULT_AGENT_IDENTITY not in prompt + def test_includes_system_message(self, agent): prompt = agent._build_system_prompt(system_message="Custom instruction") assert "Custom instruction" in prompt diff --git a/tests/test_cli_skin_integration.py b/tests/test_cli_skin_integration.py index 
ad99358ab1..40b396fb1b 100644 --- a/tests/test_cli_skin_integration.py +++ b/tests/test_cli_skin_integration.py @@ -96,7 +96,7 @@ class TestCompactBannerSkinIntegration: set_active_skin("default") with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \ - patch("cli.format_banner_version_label", return_value="Hermes Agent v0.1.0 (test)"): + patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v0.1.0 (test)"}): banner = _build_compact_banner() assert "NOUS HERMES" in banner @@ -105,7 +105,7 @@ class TestCompactBannerSkinIntegration: set_active_skin("poseidon") with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \ - patch("cli.format_banner_version_label", return_value="Hermes Agent v0.1.0 (test)"): + patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v0.1.0 (test)"}): banner = _build_compact_banner() assert "Poseidon Agent" in banner @@ -116,7 +116,7 @@ class TestCompactBannerSkinIntegration: skin = get_active_skin() with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \ - patch("cli.format_banner_version_label", return_value="Hermes Agent v0.1.0 (test)"): + patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v0.1.0 (test)"}): banner = _build_compact_banner() assert skin.get_color("banner_border") in banner @@ -127,7 +127,7 @@ class TestCompactBannerSkinIntegration: set_active_skin("default") with patch("cli.shutil.get_terminal_size", return_value=SimpleNamespace(columns=90)), \ - patch("cli.format_banner_version_label", return_value="Hermes Agent v1.0 (test) · upstream abc12345"): + patch.dict(_build_compact_banner.__globals__, {"format_banner_version_label": lambda: "Hermes Agent v1.0 (test) · upstream abc12345"}): banner = _build_compact_banner() assert "upstream abc12345" in banner diff --git a/tests/test_minimax_oauth.py 
b/tests/test_minimax_oauth.py new file mode 100644 index 0000000000..0e63800e91 --- /dev/null +++ b/tests/test_minimax_oauth.py @@ -0,0 +1,466 @@ +"""Tests for MiniMax OAuth provider (hermes_cli/auth.py). + +Covers: +- PKCE pair generation (S256 challenge) +- _minimax_request_user_code happy path and state-mismatch error +- _minimax_poll_token: pending→success flow, error status, timeout +- _refresh_minimax_oauth_state: skip when not expired, update on success, + re-login required on invalid_grant +- resolve_minimax_oauth_runtime_credentials: error when not logged in +""" +from __future__ import annotations + +import base64 +import hashlib +import json +import time +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from hermes_cli.auth import ( + PROVIDER_REGISTRY, + AuthError, + MINIMAX_OAUTH_CLIENT_ID, + MINIMAX_OAUTH_GLOBAL_BASE, + MINIMAX_OAUTH_GLOBAL_INFERENCE, + MINIMAX_OAUTH_CN_BASE, + MINIMAX_OAUTH_CN_INFERENCE, + MINIMAX_OAUTH_REFRESH_SKEW_SECONDS, + _minimax_pkce_pair, + _minimax_request_user_code, + _minimax_poll_token, + _refresh_minimax_oauth_state, + resolve_minimax_oauth_runtime_credentials, + get_minimax_oauth_auth_status, + get_provider_auth_state, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_httpx_response(status_code: int, body: dict | None = None, text: str = ""): + """Return a minimal mock that quacks like httpx.Response.""" + resp = MagicMock() + resp.status_code = status_code + if body is not None: + resp.json.return_value = body + resp.text = json.dumps(body) + else: + resp.json.side_effect = Exception("No body") + resp.text = text + resp.reason_phrase = "OK" if status_code == 200 else "Error" + return resp + + +def _future_iso(seconds_from_now: int = 3600) -> str: + ts = time.time() + seconds_from_now + return 
datetime.fromtimestamp(ts, tz=timezone.utc).isoformat() + + +def _past_iso(seconds_ago: int = 3600) -> str: + ts = time.time() - seconds_ago + return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat() + + +# --------------------------------------------------------------------------- +# 1. test_pkce_pair_produces_valid_s256 +# --------------------------------------------------------------------------- + +def test_pkce_pair_produces_valid_s256(): + verifier, challenge, state = _minimax_pkce_pair() + + # Verifier must be non-empty and URL-safe + assert isinstance(verifier, str) + assert len(verifier) >= 32 + + # Challenge must be URL-safe base64 without trailing "=" + assert isinstance(challenge, str) + assert "=" not in challenge + + # Re-compute challenge from verifier and verify it matches + expected = base64.urlsafe_b64encode( + hashlib.sha256(verifier.encode()).digest() + ).decode().rstrip("=") + assert challenge == expected + + # State must be non-empty + assert isinstance(state, str) + assert len(state) >= 8 + + # Two calls must return different values (randomness) + v2, c2, s2 = _minimax_pkce_pair() + assert verifier != v2 + assert state != s2 + + +# --------------------------------------------------------------------------- +# 2. 
test_request_user_code_happy_path +# --------------------------------------------------------------------------- + +def test_request_user_code_happy_path(): + state = "test-state-abc" + mock_response = _make_httpx_response(200, { + "user_code": "ABC-123", + "verification_uri": "https://minimax.io/verify", + "expired_in": int(time.time() * 1000) + 300_000, + "state": state, + }) + + client = MagicMock() + client.post.return_value = mock_response + + result = _minimax_request_user_code( + client, + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + code_challenge="test-challenge", + state=state, + ) + + assert result["user_code"] == "ABC-123" + assert result["verification_uri"] == "https://minimax.io/verify" + assert result["state"] == state + + # Verify correct endpoint was called + call_args = client.post.call_args + assert "/oauth/code" in call_args[0][0] + headers = call_args[1].get("headers", {}) + assert "x-request-id" in headers + + +# --------------------------------------------------------------------------- +# 3. test_request_user_code_state_mismatch_raises +# --------------------------------------------------------------------------- + +def test_request_user_code_state_mismatch_raises(): + mock_response = _make_httpx_response(200, { + "user_code": "XYZ", + "verification_uri": "https://minimax.io/verify", + "expired_in": 300, + "state": "wrong-state", # Mismatched! + }) + + client = MagicMock() + client.post.return_value = mock_response + + with pytest.raises(AuthError) as exc_info: + _minimax_request_user_code( + client, + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + code_challenge="challenge", + state="correct-state", + ) + + assert exc_info.value.code == "state_mismatch" + assert "CSRF" in str(exc_info.value) or "mismatch" in str(exc_info.value).lower() + + +# --------------------------------------------------------------------------- +# 4. 
test_request_user_code_non_200_raises +# --------------------------------------------------------------------------- + +def test_request_user_code_non_200_raises(): + mock_response = _make_httpx_response(400, text="Bad Request") + mock_response.json.side_effect = Exception("no json") + mock_response.text = "Bad Request" + + client = MagicMock() + client.post.return_value = mock_response + + with pytest.raises(AuthError) as exc_info: + _minimax_request_user_code( + client, + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + code_challenge="challenge", + state="state", + ) + + assert exc_info.value.code == "authorization_failed" + + +# --------------------------------------------------------------------------- +# 5. test_poll_token_pending_then_success +# --------------------------------------------------------------------------- + +def test_poll_token_pending_then_success(): + # Set a deadline far enough in the future for polling + deadline_ms = int(time.time() * 1000) + 60_000 # 60 seconds from now + + pending_body = {"status": "pending"} + success_body = { + "status": "success", + "access_token": "access-abc", + "refresh_token": "refresh-xyz", + "expired_in": 3600, + "token_type": "Bearer", + } + + pending_resp = _make_httpx_response(200, pending_body) + success_resp = _make_httpx_response(200, success_body) + + client = MagicMock() + client.post.side_effect = [pending_resp, pending_resp, success_resp] + + with patch("time.sleep"): # don't actually sleep + result = _minimax_poll_token( + client, + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + user_code="USER-CODE", + code_verifier="verifier", + expired_in=deadline_ms, + interval_ms=2000, + ) + + assert result["status"] == "success" + assert result["access_token"] == "access-abc" + assert result["refresh_token"] == "refresh-xyz" + assert client.post.call_count == 3 + + +# --------------------------------------------------------------------------- 
+# 6. test_poll_token_error_raises +# --------------------------------------------------------------------------- + +def test_poll_token_error_raises(): + deadline_ms = int(time.time() * 1000) + 60_000 + error_body = {"status": "error"} + error_resp = _make_httpx_response(200, error_body) + + client = MagicMock() + client.post.return_value = error_resp + + with pytest.raises(AuthError) as exc_info: + _minimax_poll_token( + client, + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + user_code="U", + code_verifier="v", + expired_in=deadline_ms, + interval_ms=2000, + ) + + assert exc_info.value.code == "authorization_denied" + + +# --------------------------------------------------------------------------- +# 7. test_poll_token_timeout_raises +# --------------------------------------------------------------------------- + +def test_poll_token_timeout_raises(): + # expired_in is a small duration (treated as seconds from now, already expired) + expired_in = 1 # 1 second from now + # Make sleep a no-op and time.time advance quickly by using a small deadline + # We use a duration-style expired_in (small enough to not be a unix timestamp) + # duration mode: deadline = time.time() + max(1, expired_in) + # We need time() to exceed deadline immediately. + + fixed_now = time.time() + call_count = [0] + + def fake_time(): + call_count[0] += 1 + # After 2 calls, return a time past the deadline + if call_count[0] > 2: + return fixed_now + 10 # past deadline + return fixed_now + + client = MagicMock() + pending_resp = _make_httpx_response(200, {"status": "pending"}) + client.post.return_value = pending_resp + + import hermes_cli.auth as auth_module + with patch.object(auth_module, "time") as mock_time_mod: + # We need to patch the 'time' module used inside _minimax_poll_token + # The function imports 'import time as _time' locally. + # Patch time.sleep and time.time in the auth module's local scope. 
+ pass + + # Use a simpler approach: expired_in as past timestamp (already expired) + past_deadline_ms = int((time.time() - 1) * 1000) # 1 second ago + + with pytest.raises(AuthError) as exc_info: + _minimax_poll_token( + client, + portal_base_url=MINIMAX_OAUTH_GLOBAL_BASE, + client_id=MINIMAX_OAUTH_CLIENT_ID, + user_code="U", + code_verifier="v", + expired_in=past_deadline_ms, + interval_ms=2000, + ) + + assert exc_info.value.code == "timeout" + + +# --------------------------------------------------------------------------- +# 8. test_refresh_skip_when_not_expired +# --------------------------------------------------------------------------- + +def test_refresh_skip_when_not_expired(): + """When token is far from expiry, refresh should return the same state.""" + state = { + "access_token": "old-access", + "refresh_token": "refresh-token", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(3600), # 1 hour in the future + } + + result = _refresh_minimax_oauth_state(state) + assert result["access_token"] == "old-access" + assert result is state # Same object returned (no refresh) + + +# --------------------------------------------------------------------------- +# 9. 
test_refresh_updates_access_token +# --------------------------------------------------------------------------- + +def test_refresh_updates_access_token(): + """When token is close to expiry, refresh should update the state.""" + # expires_at just MINIMAX_OAUTH_REFRESH_SKEW_SECONDS - 1 from now (close to expiry) + state = { + "access_token": "old-access", + "refresh_token": "my-refresh", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _future_iso(MINIMAX_OAUTH_REFRESH_SKEW_SECONDS - 1), + } + + new_token_body = { + "status": "success", + "access_token": "new-access", + "refresh_token": "new-refresh", + "expired_in": 7200, + } + + mock_resp = _make_httpx_response(200, new_token_body) + + with patch("httpx.Client") as mock_client_class: + mock_client_instance = MagicMock() + mock_client_instance.__enter__ = MagicMock(return_value=mock_client_instance) + mock_client_instance.__exit__ = MagicMock(return_value=False) + mock_client_instance.post.return_value = mock_resp + mock_client_class.return_value = mock_client_instance + + # Patch _minimax_save_auth_state to avoid touching the auth store + with patch("hermes_cli.auth._minimax_save_auth_state"): + result = _refresh_minimax_oauth_state(state) + + assert result["access_token"] == "new-access" + assert result["refresh_token"] == "new-refresh" + assert result["expires_in"] == 7200 + + +# --------------------------------------------------------------------------- +# 10. 
test_refresh_reuse_triggers_relogin_required +# --------------------------------------------------------------------------- + +def test_refresh_reuse_triggers_relogin_required(): + """On 400 + invalid_grant body, relogin_required should be set.""" + state = { + "access_token": "old-access", + "refresh_token": "old-refresh", + "portal_base_url": MINIMAX_OAUTH_GLOBAL_BASE, + "client_id": MINIMAX_OAUTH_CLIENT_ID, + "inference_base_url": MINIMAX_OAUTH_GLOBAL_INFERENCE, + "expires_at": _past_iso(100), # already expired + } + + bad_resp = _make_httpx_response(400, text="invalid_grant") + bad_resp.json.side_effect = Exception("no json") + bad_resp.text = "invalid_grant" + bad_resp.reason_phrase = "Bad Request" + + with patch("httpx.Client") as mock_client_class: + mock_client_instance = MagicMock() + mock_client_instance.__enter__ = MagicMock(return_value=mock_client_instance) + mock_client_instance.__exit__ = MagicMock(return_value=False) + mock_client_instance.post.return_value = bad_resp + mock_client_class.return_value = mock_client_instance + + with pytest.raises(AuthError) as exc_info: + _refresh_minimax_oauth_state(state) + + assert exc_info.value.code == "refresh_failed" + assert exc_info.value.relogin_required is True + + +# --------------------------------------------------------------------------- +# 11. test_resolve_credentials_requires_login +# --------------------------------------------------------------------------- + +def test_resolve_credentials_requires_login(): + """When no state is stored, resolve_minimax_oauth_runtime_credentials raises.""" + with patch("hermes_cli.auth.get_provider_auth_state", return_value=None): + with pytest.raises(AuthError) as exc_info: + resolve_minimax_oauth_runtime_credentials() + + assert exc_info.value.code == "not_logged_in" + assert exc_info.value.relogin_required is True + + +# --------------------------------------------------------------------------- +# 12. 
test_provider_registry_contains_minimax_oauth +# --------------------------------------------------------------------------- + +def test_provider_registry_contains_minimax_oauth(): + assert "minimax-oauth" in PROVIDER_REGISTRY + pconfig = PROVIDER_REGISTRY["minimax-oauth"] + assert pconfig.auth_type == "oauth_minimax" + assert pconfig.client_id == MINIMAX_OAUTH_CLIENT_ID + assert MINIMAX_OAUTH_GLOBAL_BASE in pconfig.portal_base_url + assert MINIMAX_OAUTH_GLOBAL_INFERENCE in pconfig.inference_base_url + assert "cn_portal_base_url" in pconfig.extra + assert "cn_inference_base_url" in pconfig.extra + + +# --------------------------------------------------------------------------- +# 13. test_minimax_oauth_alias_resolves +# --------------------------------------------------------------------------- + +def test_minimax_oauth_alias_resolves(): + from hermes_cli.auth import resolve_provider + # Only test that minimax-oauth itself resolves (alias resolution is tested in models) + result = resolve_provider("minimax-oauth") + assert result == "minimax-oauth" + + +# --------------------------------------------------------------------------- +# 14. test_get_minimax_oauth_auth_status_not_logged_in +# --------------------------------------------------------------------------- + +def test_get_minimax_oauth_auth_status_not_logged_in(): + with patch("hermes_cli.auth.get_provider_auth_state", return_value=None): + status = get_minimax_oauth_auth_status() + + assert status["logged_in"] is False + assert status["provider"] == "minimax-oauth" + + +# --------------------------------------------------------------------------- +# 15. 
test_get_minimax_oauth_auth_status_logged_in +# --------------------------------------------------------------------------- + +def test_get_minimax_oauth_auth_status_logged_in(): + state = { + "access_token": "tok", + "expires_at": _future_iso(3600), + "region": "global", + } + + with patch("hermes_cli.auth.get_provider_auth_state", return_value=state): + status = get_minimax_oauth_auth_status() + + assert status["logged_in"] is True + assert status["region"] == "global" diff --git a/tests/test_model_tools.py b/tests/test_model_tools.py index c8fd3581aa..379aac2bbc 100644 --- a/tests/test_model_tools.py +++ b/tests/test_model_tools.py @@ -193,8 +193,15 @@ class TestPreToolCallBlocking: result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1")) assert result == {"ok": True} - def test_skip_flag_prevents_double_block_check(self, monkeypatch): - """When skip_pre_tool_call_hook=True, blocking is not checked (caller did it).""" + def test_skip_flag_prevents_double_fire(self, monkeypatch): + """When skip_pre_tool_call_hook=True, the hook does not fire again. + + The caller (e.g. run_agent._invoke_tool) has already called + get_pre_tool_call_block_message(), which fires the hook once. + handle_function_call must NOT fire it a second time — that was + the classic double-fire bug where observer hooks logged every + tool call twice. + """ hook_calls = [] def fake_invoke_hook(hook_name, **kwargs): @@ -208,10 +215,58 @@ class TestPreToolCallBlocking: handle_function_call("web_search", {"q": "test"}, task_id="t1", skip_pre_tool_call_hook=True) - # Hook still fires for observer notification, but get_pre_tool_call_block_message - # is not called — invoke_hook fires directly in the skip=True branch. - assert "pre_tool_call" in hook_calls + # Single-fire contract: when skip=True the caller already fired + # pre_tool_call, so handle_function_call must not fire it again. 
+ assert hook_calls.count("pre_tool_call") == 0, ( + f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times " + f"with skip_pre_tool_call_hook=True; expected 0 " + f"(caller already fired it). hook_calls={hook_calls}" + ) + # post_tool_call and transform_tool_result still fire — only the + # pre-call block-check path is suppressed by the skip flag. assert "post_tool_call" in hook_calls + assert "transform_tool_result" in hook_calls + + def test_run_agent_pattern_fires_pre_tool_call_exactly_once(self, monkeypatch): + """End-to-end regression for the double-fire bug. + + Mirrors run_agent._invoke_tool: first calls + get_pre_tool_call_block_message() (which fires the hook as part of + its block-directive poll), then calls + handle_function_call(skip_pre_tool_call_hook=True). The plugin + hook MUST fire exactly once across both calls — not twice as it + did before the fix (observer plugins were seeing every tool + execution logged twice). + """ + from hermes_cli.plugins import get_pre_tool_call_block_message + + hook_calls = [] + + def fake_invoke_hook(hook_name, **kwargs): + hook_calls.append(hook_name) + return [] + + monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook) + monkeypatch.setattr("model_tools.registry.dispatch", + lambda *a, **kw: json.dumps({"ok": True})) + + # Step 1: caller checks for a block directive (this fires pre_tool_call once). + block = get_pre_tool_call_block_message( + "web_search", {"q": "test"}, task_id="t1", + ) + assert block is None + + # Step 2: caller dispatches with skip=True so the hook isn't re-fired. + handle_function_call( + "web_search", {"q": "test"}, task_id="t1", + skip_pre_tool_call_hook=True, + ) + + assert hook_calls.count("pre_tool_call") == 1, ( + f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times " + f"across the run_agent (block-check + dispatch) path; " + f"expected exactly 1. 
hook_calls={hook_calls}" + ) # ========================================================================= diff --git a/tests/test_model_tools_async_bridge.py b/tests/test_model_tools_async_bridge.py index d6266d7c36..ed0a85cd35 100644 --- a/tests/test_model_tools_async_bridge.py +++ b/tests/test_model_tools_async_bridge.py @@ -199,20 +199,22 @@ class TestRunAsyncWithRunningLoop: @pytest.mark.asyncio async def test_timeout_uses_nonblocking_executor_shutdown(self, monkeypatch): - """A timeout in the running-loop branch must not wait for the worker. + """A timeout in the running-loop branch must not block the caller. - ThreadPoolExecutor's context manager performs shutdown(wait=True). - If _run_async relies on that path after future.result(timeout=...) - times out, the timeout does not bound wall-clock time because the - caller still waits for the stuck coroutine's thread to finish. + If shutdown ever waits for a stuck worker, a tool coroutine that + ignores (or can't observe) cancellation would hang the whole agent. + Guard: the caller must raise TimeoutError and pool.shutdown must be + called with wait=False. The worker's own event loop handles cleanup + (cancellation is scheduled via call_soon_threadsafe before the + caller returns). """ import concurrent.futures from model_tools import _run_async events = { - "cancelled": False, "result_timeout": None, "shutdown_calls": [], + "submitted_fn": None, } class TimeoutFuture: @@ -221,7 +223,6 @@ class TestRunAsyncWithRunningLoop: raise concurrent.futures.TimeoutError() def cancel(self): - events["cancelled"] = True return True class FakeExecutor: @@ -236,8 +237,10 @@ class TestRunAsyncWithRunningLoop: return False def submit(self, fn, *args, **kwargs): - if args and hasattr(args[0], "close"): - args[0].close() + # Record which function got submitted -- should be the + # in-function worker wrapper, not bare asyncio.run, so we + # know _run_async is using a loop it owns and can cancel. 
+ events["submitted_fn"] = getattr(fn, "__name__", repr(fn)) return TimeoutFuture() def shutdown(self, wait=True, cancel_futures=False): @@ -256,8 +259,82 @@ class TestRunAsyncWithRunningLoop: _run_async(_never_finishes()) assert events["result_timeout"] == 300 - assert events["cancelled"] is True - assert events["shutdown_calls"] == [(False, True)] + # The worker wrapper creates its own event loop so _run_async can + # cancel the task on timeout — this must NOT be bare asyncio.run. + assert events["submitted_fn"] != "run", ( + "_run_async submitted asyncio.run directly — it must submit a " + "worker wrapper that owns the event loop so timeouts can cancel " + "the task" + ) + # Critical: shutdown must NOT wait. If wait=True, a stuck coroutine + # would freeze the caller (converts a thread leak into a hang). + assert events["shutdown_calls"], "shutdown was never called" + for wait, _cancel in events["shutdown_calls"]: + assert wait is False, ( + f"shutdown called with wait={wait} — a stuck tool coroutine " + f"would hang the caller indefinitely" + ) + + @pytest.mark.asyncio + async def test_timeout_cancels_coroutine_in_worker_loop(self, monkeypatch): + """On timeout, the worker's event loop must receive a cancel request + so the coroutine stops and the thread exits — not leaked. + + Before the fix, future.cancel() on a running ThreadPoolExecutor + future is a no-op, so the worker thread kept running the coroutine + to completion (leaking one thread per tool-timeout). + """ + from model_tools import _run_async + + # Shrink the 300s internal timeout by patching future.result. + # We do this surgically: let everything else run for real so the + # worker loop actually exists and can observe cancellation. + import concurrent.futures as _cf + + real_pool_cls = _cf.ThreadPoolExecutor + + class FastTimeoutPool(real_pool_cls): + def __init__(self, *a, **kw): + super().__init__(*a, **kw) + + # Patch future.result to time out after 1s instead of 300s. 
+ real_result = _cf.Future.result + + def fast_result(self, timeout=None): + return real_result(self, timeout=1.0 if timeout == 300 else timeout) + + monkeypatch.setattr(_cf.Future, "result", fast_result) + + cancel_observed = threading.Event() + + async def _slow_cancellable(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancel_observed.set() + raise + + import time as _time + t0 = _time.time() + with pytest.raises(_cf.TimeoutError): + _run_async(_slow_cancellable()) + elapsed = _time.time() - t0 + + # Caller must return fast (no hang waiting for the coro). + assert elapsed < 3.0, ( + f"_run_async blocked caller for {elapsed:.1f}s — should return " + f"on timeout regardless of whether the coroutine has finished" + ) + + # Worker thread must cancel the task (not leak). + deadline = _time.time() + 5 + while not cancel_observed.is_set() and _time.time() < deadline: + _time.sleep(0.05) + assert cancel_observed.is_set(), ( + "Coroutine never received CancelledError — worker thread leaked " + "(ThreadPoolExecutor.cancel() is a no-op on a running future; " + "_run_async must cancel the task inside its worker loop)" + ) # --------------------------------------------------------------------------- diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index dacc55df5b..b9d7c1b0dc 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -59,6 +59,147 @@ def test_write_json_returns_false_on_broken_pipe(monkeypatch): assert server.write_json({"ok": True}) is False +def test_load_enabled_toolsets_prefers_tui_env(monkeypatch): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web, terminal, ,memory") + + assert server._load_enabled_toolsets() == ["web", "terminal", "memory"] + + +def test_load_enabled_toolsets_filters_invalid_tui_env(monkeypatch, capsys): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web, nope") + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + 
types.SimpleNamespace(discover_plugins=lambda: None), + ) + + assert server._load_enabled_toolsets() == ["web"] + assert "nope" in capsys.readouterr().err + + +def test_load_enabled_toolsets_accepts_plugin_env_after_discovery(monkeypatch): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "plugin_demo") + + import toolsets + + discovered = {"ready": False} + original_validate = toolsets.validate_toolset + + def fake_validate(name): + return name == "plugin_demo" and discovered["ready"] or original_validate(name) + + monkeypatch.setattr(toolsets, "validate_toolset", fake_validate) + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: discovered.update({"ready": True})), + ) + + assert server._load_enabled_toolsets() == ["plugin_demo"] + + +def test_load_enabled_toolsets_rejects_disabled_mcp_env(monkeypatch, capsys): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "mcp-off") + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: None), + ) + + import hermes_cli.config as config_mod + + monkeypatch.setattr( + config_mod, + "read_raw_config", + lambda: {"mcp_servers": {"mcp-off": {"enabled": False}}}, + ) + monkeypatch.setattr(config_mod, "load_config", lambda: {"platform_toolsets": {"cli": ["memory"]}}) + + assert server._load_enabled_toolsets() == ["memory"] + err = capsys.readouterr().err + assert "ignoring disabled MCP servers" in err + assert "mcp-off" in err + assert "using configured CLI toolsets" in err + + +def test_load_enabled_toolsets_falls_back_when_tui_env_invalid(monkeypatch, capsys): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "nope") + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: None), + ) + + import hermes_cli.config as config_mod + + monkeypatch.setattr(config_mod, "load_config", lambda: {"platform_toolsets": {"cli": ["memory"]}}) + + assert server._load_enabled_toolsets() == 
["memory"] + assert "using configured CLI toolsets" in capsys.readouterr().err + + +def test_load_enabled_toolsets_warns_when_config_fallback_fails(monkeypatch, capsys): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "nope") + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: None), + ) + + import hermes_cli.config as config_mod + + monkeypatch.setattr(config_mod, "load_config", lambda: (_ for _ in ()).throw(RuntimeError("boom"))) + + assert server._load_enabled_toolsets() is None + assert "could not be loaded" in capsys.readouterr().err + + +def test_load_enabled_toolsets_honors_builtin_env_if_config_fails(monkeypatch): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web") + + import hermes_cli.config as config_mod + + monkeypatch.setattr(config_mod, "load_config", lambda: (_ for _ in ()).throw(RuntimeError("boom"))) + + assert server._load_enabled_toolsets() == ["web"] + + +def test_load_enabled_toolsets_all_env_means_all(monkeypatch): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "all") + + assert server._load_enabled_toolsets() is None + + +def test_load_enabled_toolsets_all_env_warns_about_ignored_extra_entries(monkeypatch, capsys): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "all,nope") + + assert server._load_enabled_toolsets() is None + assert "ignoring additional entries: nope" in capsys.readouterr().err + + +def test_load_enabled_toolsets_reports_disabled_mcp_separately(monkeypatch, capsys): + monkeypatch.setenv("HERMES_TUI_TOOLSETS", "web,mcp-off,nope") + monkeypatch.setitem( + sys.modules, + "hermes_cli.plugins", + types.SimpleNamespace(discover_plugins=lambda: None), + ) + + import hermes_cli.config as config_mod + + monkeypatch.setattr( + config_mod, + "read_raw_config", + lambda: {"mcp_servers": {"mcp-off": {"enabled": False}}}, + ) + + assert server._load_enabled_toolsets() == ["web"] + err = capsys.readouterr().err + assert "ignoring unknown HERMES_TUI_TOOLSETS entries: nope" in err + assert 
"ignoring disabled MCP servers" in err + assert "mcp-off" in err + + def test_history_to_messages_preserves_tool_calls_for_resume_display(): history = [ {"role": "user", "content": "first prompt"}, @@ -879,6 +1020,36 @@ def test_config_set_statusbar_survives_non_dict_display(tmp_path, monkeypatch): assert saved["display"]["tui_statusbar"] == "bottom" +def test_config_set_details_mode_pins_all_sections(tmp_path, monkeypatch): + import yaml + + cfg_path = tmp_path / "config.yaml" + cfg_path.write_text( + yaml.safe_dump( + {"display": {"sections": {"tools": "expanded", "activity": "hidden"}}} + ) + ) + monkeypatch.setattr(server, "_hermes_home", tmp_path) + + resp = server.handle_request( + { + "id": "1", + "method": "config.set", + "params": {"key": "details_mode", "value": "collapsed"}, + } + ) + + assert resp["result"] == {"key": "details_mode", "value": "collapsed"} + saved = yaml.safe_load(cfg_path.read_text()) + assert saved["display"]["details_mode"] == "collapsed" + assert saved["display"]["sections"] == { + "thinking": "collapsed", + "tools": "collapsed", + "subagents": "collapsed", + "activity": "collapsed", + } + + def test_config_set_section_writes_per_section_override(tmp_path, monkeypatch): import yaml @@ -1066,6 +1237,18 @@ def test_config_set_reasoning_updates_live_session_and_agent(tmp_path, monkeypat ) assert resp_show["result"]["value"] == "show" assert server._sessions["sid"]["show_reasoning"] is True + assert server._load_cfg()["display"]["sections"]["thinking"] == "expanded" + + resp_hide = server.handle_request( + { + "id": "3", + "method": "config.set", + "params": {"session_id": "sid", "key": "reasoning", "value": "hide"}, + } + ) + assert resp_hide["result"]["value"] == "hide" + assert server._sessions["sid"]["show_reasoning"] is False + assert server._load_cfg()["display"]["sections"]["thinking"] == "hidden" def test_config_set_verbose_updates_session_mode_and_agent(tmp_path, monkeypatch): @@ -1383,7 +1566,7 @@ def 
test_session_compress_uses_compress_helper(monkeypatch): monkeypatch.setattr( server, "_compress_session_history", - lambda session, focus_topic=None: (2, {"total": 42}), + lambda session, focus_topic=None, **_kw: (2, {"total": 42}), ) monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"}) @@ -1394,7 +1577,52 @@ def test_session_compress_uses_compress_helper(monkeypatch): assert resp["result"]["removed"] == 2 assert resp["result"]["usage"]["total"] == 42 - emit.assert_called_once_with("session.info", "sid", {"model": "x"}) + emit.assert_any_call("session.info", "sid", {"model": "x"}) + # Final status.update clears the pinned "compressing" indicator so the + # status bar can revert to the neutral state when compaction finishes. + emit.assert_any_call( + "status.update", "sid", {"kind": "status", "text": "ready"} + ) + + +def test_session_compress_syncs_session_key_after_rotation(monkeypatch): + """When AIAgent._compress_context rotates session_id (compression split), + the gateway session_key must follow so subsequent approval routing, + DB title/history lookups, and slash worker resume target the new + continuation session — mirrors HermesCLI._manual_compress's + session_id sync (cli.py). 
+ """ + agent = types.SimpleNamespace(session_id="rotated-id") + server._sessions["sid"] = _session(agent=agent) + server._sessions["sid"]["session_key"] = "old-key" + server._sessions["sid"]["pending_title"] = "stale title" + + monkeypatch.setattr( + server, + "_compress_session_history", + lambda session, focus_topic=None, **_kw: (2, {"total": 42}), + ) + monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"}) + restart_calls = [] + monkeypatch.setattr( + server, "_restart_slash_worker", lambda s: restart_calls.append(s) + ) + + try: + with patch("tui_gateway.server._emit"): + server.handle_request( + { + "id": "1", + "method": "session.compress", + "params": {"session_id": "sid"}, + } + ) + + assert server._sessions["sid"]["session_key"] == "rotated-id" + assert server._sessions["sid"]["pending_title"] is None + assert len(restart_calls) == 1 + finally: + server._sessions.pop("sid", None) def test_prompt_submit_sets_approval_session_key(monkeypatch): @@ -2240,6 +2468,39 @@ def test_mirror_slash_side_effects_allowed_when_idle(monkeypatch): assert applied["model"] +def test_mirror_slash_compress_does_not_prelock_history(monkeypatch): + """Regression guard: /compress side effect must not hold history_lock + when calling _compress_session_history (the helper snapshots under + the same non-reentrant lock internally).""" + import types + + seen = {"compress": False, "sync": False} + emitted = [] + + def _fake_compress(session, focus_topic=None, **_kw): + seen["compress"] = True + assert not session["history_lock"].locked() + return (0, {"total": 0}) + + def _fake_sync(_sid, _session): + seen["sync"] = True + + monkeypatch.setattr(server, "_compress_session_history", _fake_compress) + monkeypatch.setattr(server, "_sync_session_key_after_compress", _fake_sync) + monkeypatch.setattr(server, "_session_info", lambda _agent: {"model": "x"}) + monkeypatch.setattr(server, "_emit", lambda *args: emitted.append(args)) + + session = _session(running=False) + 
session["agent"] = types.SimpleNamespace(model="x") + + warning = server._mirror_slash_side_effects("sid", session, "/compress") + + assert warning == "" + assert seen["compress"] + assert seen["sync"] + assert ("session.info", "sid", {"model": "x"}) in emitted + + # --------------------------------------------------------------------------- # session.create / session.close race: fast /new churn must not orphan the # slash_worker subprocess or the global approval-notify registration. @@ -2274,10 +2535,20 @@ def test_session_create_close_race_does_not_orphan_worker(monkeypatch): self.base_url = "" self.api_key = "" - # Make _build block until we release it — simulates slow agent init + # Make _build block until we release it — simulates slow agent init. + # Also signal when _build actually reaches _make_agent so the test + # can close the session at the right moment: session.create now + # defers _start_agent_build behind a 50ms timer (see the + # `_deferred_build` path in @method("session.create")), so closing + # before the build thread has even started would skip the orphan + # detection entirely and the test would race a non-event. + build_started = threading.Event() release_build = threading.Event() + build_entered = threading.Event() - def _slow_make_agent(sid, key): + def _slow_make_agent(sid, key, session_id=None): + build_started.set() + build_entered.set() release_build.wait(timeout=3.0) return _FakeAgent() @@ -2315,6 +2586,13 @@ def test_session_create_close_race_does_not_orphan_worker(monkeypatch): ) assert resp.get("result"), f"got error: {resp.get('error')}" sid = resp["result"]["session_id"] + assert build_entered.wait(timeout=1.0), "deferred build did not start" + + # Wait until the (deferred) build thread has actually entered + # _make_agent — otherwise session.close pops _sessions[sid] before + # _build ever runs, _start_agent_build never calls _build, and we + # never exercise the orphan-cleanup path. 
+ assert build_started.wait(timeout=2.0), "build thread never entered _make_agent" # Build thread is blocked in _slow_make_agent. Close the session # NOW — this pops _sessions[sid] before _build can install the @@ -2497,6 +2775,155 @@ def test_session_list_returns_clean_error_when_state_db_is_unavailable(monkeypat assert "state.db unavailable: locking protocol" in resp["error"]["message"] +# -------------------------------------------------------------------------- +# session.delete — TUI resume picker `d` key +# -------------------------------------------------------------------------- + + +def test_session_delete_requires_session_id(monkeypatch): + """Empty / missing session_id is a 4006 client error (no DB call).""" + called: list[tuple] = [] + + class _DB: + def delete_session(self, *a, **kw): + called.append((a, kw)) + return True + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + + resp = server.handle_request({"id": "1", "method": "session.delete", "params": {}}) + assert "error" in resp + assert resp["error"]["code"] == 4006 + assert called == [] + + +def test_session_delete_returns_db_unavailable_when_no_db(monkeypatch): + monkeypatch.setattr(server, "_get_db", lambda: None) + monkeypatch.setattr(server, "_db_error", "locked") + + resp = server.handle_request( + {"id": "1", "method": "session.delete", "params": {"session_id": "abc"}} + ) + + assert "error" in resp + assert resp["error"]["code"] == 5036 + assert "state.db unavailable" in resp["error"]["message"] + + +def test_session_delete_refuses_active_session(monkeypatch): + """Cannot delete a session currently bound to a live TUI session.""" + called: list[str] = [] + + class _DB: + def delete_session(self, sid, sessions_dir=None): + called.append(sid) + return True + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + monkeypatch.setitem(server._sessions, "live", {"session_key": "key-live"}) + try: + resp = server.handle_request( + { + "id": "1", + "method": "session.delete", + 
"params": {"session_id": "key-live"}, + } + ) + finally: + server._sessions.pop("live", None) + + assert "error" in resp + assert resp["error"]["code"] == 4023 + assert "active session" in resp["error"]["message"] + assert called == [], "delete_session must not be called for active sessions" + + +def test_session_delete_fails_closed_when_active_snapshot_raises(monkeypatch): + """Concurrent ``_sessions`` mutation from another RPC thread can raise + ``RuntimeError: dictionary changed size during iteration``. When the + handler can't enumerate active sessions safely it must refuse the + delete (fail closed) rather than fall through and allow it.""" + + class _DB: + def delete_session(self, *a, **kw): + raise AssertionError("delete must not run when active snapshot fails") + + class _ExplodingDict: + def values(self): + raise RuntimeError("dictionary changed size during iteration") + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + monkeypatch.setattr(server, "_sessions", _ExplodingDict()) + + resp = server.handle_request( + {"id": "1", "method": "session.delete", "params": {"session_id": "x"}} + ) + + assert "error" in resp + assert resp["error"]["code"] == 5036 + assert "enumerate active sessions" in resp["error"]["message"] + + +def test_session_delete_returns_4007_when_missing(monkeypatch): + class _DB: + def delete_session(self, sid, sessions_dir=None): + return False + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + + resp = server.handle_request( + {"id": "1", "method": "session.delete", "params": {"session_id": "ghost"}} + ) + + assert "error" in resp + assert resp["error"]["code"] == 4007 + + +def test_session_delete_propagates_db_exception(monkeypatch): + class _DB: + def delete_session(self, sid, sessions_dir=None): + raise RuntimeError("disk full") + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + + resp = server.handle_request( + {"id": "1", "method": "session.delete", "params": {"session_id": "x"}} + ) + + assert "error" in 
resp + assert resp["error"]["code"] == 5036 + assert "disk full" in resp["error"]["message"] + + +def test_session_delete_success_returns_deleted_id(monkeypatch): + """Happy path — DB delete succeeds, response carries the deleted id + and the on-disk sessions dir is forwarded so transcript files get + cleaned up alongside the row.""" + captured: dict = {} + + class _DB: + def delete_session(self, sid, sessions_dir=None): + captured["sid"] = sid + captured["sessions_dir"] = sessions_dir + return True + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + + resp = server.handle_request( + {"id": "1", "method": "session.delete", "params": {"session_id": "old-1"}} + ) + + assert "result" in resp, resp + assert resp["result"] == {"deleted": "old-1"} + assert captured["sid"] == "old-1" + # sessions_dir must be forwarded so transcript files get cleaned up + # too — not just the SQLite row. The autouse _isolate_hermes_home + # fixture pins HERMES_HOME to a temp dir; the handler should append + # /sessions to it. + assert captured["sessions_dir"] is not None + assert str(captured["sessions_dir"]).endswith("sessions") + + # -------------------------------------------------------------------------- # model.options — curated-list parity with `hermes model` and classic /model # -------------------------------------------------------------------------- diff --git a/tests/tools/test_accretion_caps.py b/tests/tools/test_accretion_caps.py index bdc9b41c37..dcd3c09fd9 100644 --- a/tests/tools/test_accretion_caps.py +++ b/tests/tools/test_accretion_caps.py @@ -127,7 +127,11 @@ class TestReadTrackerCaps: td = ft._read_tracker["long-session"] assert len(td["read_history"]) <= 3 assert len(td["dedup"]) <= 3 - assert len(td["read_timestamps"]) <= 3 + # read_timestamps is populated lazily (via setdefault) only + # when os.path.getmtime() succeeds. On some CI filesystems + # that stat can race with file creation — skip rather than + # hard-error if the dict hasn't been created yet. 
+ assert len(td.get("read_timestamps", {})) <= 3 class TestCompletionConsumedPrune: diff --git a/tests/tools/test_approval_heartbeat.py b/tests/tools/test_approval_heartbeat.py index cdbba406db..d54a5b1421 100644 --- a/tests/tools/test_approval_heartbeat.py +++ b/tests/tools/test_approval_heartbeat.py @@ -131,15 +131,15 @@ class TestApprovalHeartbeat: """Polling slices don't delay responsiveness — resolve is near-instant.""" from tools.approval import ( check_all_command_guards, + has_blocking_approval, register_gateway_notify, resolve_gateway_approval, ) - register_gateway_notify(self.SESSION_KEY, lambda _payload: None) - - start_time = time.monotonic() result_holder: dict = {} + register_gateway_notify(self.SESSION_KEY, lambda _payload: None) + def _run_check(): result_holder["result"] = check_all_command_guards( "rm -rf /tmp/nonexistent-fast-target", "local" @@ -148,9 +148,18 @@ class TestApprovalHeartbeat: thread = threading.Thread(target=_run_check, daemon=True) thread.start() + # Wait until the worker has actually enqueued the approval. Resolving + # before registration is a test race, not a responsiveness signal. + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if has_blocking_approval(self.SESSION_KEY): + break + time.sleep(0.01) + assert has_blocking_approval(self.SESSION_KEY) + # Resolve almost immediately — the wait loop should return within # its current 1s poll slice. 
- time.sleep(0.1) + start_time = time.monotonic() resolve_gateway_approval(self.SESSION_KEY, "once") thread.join(timeout=5) elapsed = time.monotonic() - start_time diff --git a/tests/tools/test_browser_orphan_reaper.py b/tests/tools/test_browser_orphan_reaper.py index 27352960b4..202aa6f9a2 100644 --- a/tests/tools/test_browser_orphan_reaper.py +++ b/tests/tools/test_browser_orphan_reaper.py @@ -354,6 +354,7 @@ class TestOwnerPidCrossProcess: monkeypatch.setattr( bt, "_requires_real_termux_browser_install", lambda *a: False ) + monkeypatch.setattr(bt, "_chromium_installed", lambda: True) monkeypatch.setattr( bt, "_get_session_info", lambda task_id: {"session_name": session_name}, diff --git a/tests/tools/test_clipboard.py b/tests/tools/test_clipboard.py index 17f929eb9c..90e2ea847f 100644 --- a/tests/tools/test_clipboard.py +++ b/tests/tools/test_clipboard.py @@ -205,36 +205,53 @@ class TestMacosOsascript: class TestIsWsl: def setup_method(self): - # _is_wsl is now hermes_constants.is_wsl — reset its cache + # _is_wsl is hermes_constants.is_wsl; reset the function's own module + # globals so this stays stable even if hermes_constants was imported + # through a different module object earlier in a large xdist run. import hermes_constants hermes_constants._wsl_detected = None + _is_wsl.__globals__["_wsl_detected"] = None + + def teardown_method(self): + # Reset again after the test so we don't leak a cached value + # (True/False) into whichever test the xdist worker runs next. 
+ import hermes_constants + hermes_constants._wsl_detected = None + _is_wsl.__globals__["_wsl_detected"] = None def test_wsl2_detected(self): content = "Linux version 5.15.0 (microsoft-standard-WSL2)" - with patch("builtins.open", mock_open(read_data=content)): + with patch.dict(_is_wsl.__globals__, {"open": mock_open(read_data=content)}): assert _is_wsl() is True def test_wsl1_detected(self): content = "Linux version 4.4.0-microsoft-standard" - with patch("builtins.open", mock_open(read_data=content)): + with patch.dict(_is_wsl.__globals__, {"open": mock_open(read_data=content)}): assert _is_wsl() is True def test_regular_linux(self): + # GHA hosted runners are Azure VMs whose real /proc/version often + # contains "microsoft". Patching builtins.open with mock_open is + # supposed to intercept hermes_constants.is_wsl's `open` call, + # but if another test on the same xdist worker already cached + # _wsl_detected=True, the mock never runs because the function + # short-circuits on the cache. setup_method resets, so we just + # need to be sure the patched `open` is actually reached. 
content = "Linux version 6.14.0-37-generic (buildd@lcy02-amd64-049)" - with patch("builtins.open", mock_open(read_data=content)): + with patch.dict(_is_wsl.__globals__, {"open": mock_open(read_data=content)}): assert _is_wsl() is False def test_proc_version_missing(self): - with patch("builtins.open", side_effect=FileNotFoundError): + with patch.dict(_is_wsl.__globals__, {"open": MagicMock(side_effect=FileNotFoundError)}): assert _is_wsl() is False def test_result_is_cached(self): - import hermes_constants content = "Linux version 5.15.0 (microsoft-standard-WSL2)" - with patch("builtins.open", mock_open(read_data=content)) as m: + opener = mock_open(read_data=content) + with patch.dict(_is_wsl.__globals__, {"open": opener}): assert _is_wsl() is True assert _is_wsl() is True - m.assert_called_once() # only read once + opener.assert_called_once() # only read once # ── WSL (powershell.exe) ──────────────────────────────────────────────── diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index 15f8faa9bb..6f6260ffe2 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -770,11 +770,19 @@ class TestLoadConfig(unittest.TestCase): def test_returns_code_execution_section(self): from tools.code_execution_tool import _load_config - mock_cli = MagicMock() - mock_cli.CLI_CONFIG = {"code_execution": {"timeout": 120, "max_tool_calls": 10}} - with patch.dict("sys.modules", {"cli": mock_cli}): + with patch("hermes_cli.config.read_raw_config", + return_value={"code_execution": {"timeout": 120, "max_tool_calls": 10}}): result = _load_config() - self.assertIsInstance(result, dict) + self.assertEqual(result, {"timeout": 120, "max_tool_calls": 10}) + + def test_does_not_import_interactive_cli(self): + from tools.code_execution_tool import _load_config + mock_cli = MagicMock() + mock_cli.CLI_CONFIG = {"code_execution": {"timeout": 999}} + with patch.dict("sys.modules", {"cli": mock_cli}), \ + 
patch("hermes_cli.config.read_raw_config", return_value={}): + result = _load_config() + self.assertEqual(result, {}) # --------------------------------------------------------------------------- diff --git a/tests/tools/test_command_guards.py b/tests/tools/test_command_guards.py index bb0b46053b..a2fd394304 100644 --- a/tests/tools/test_command_guards.py +++ b/tests/tools/test_command_guards.py @@ -73,6 +73,10 @@ class TestContainerSkip: result = check_all_command_guards("rm -rf /", "daytona") assert result["approved"] is True + def test_vercel_sandbox_skips_both(self): + result = check_all_command_guards("rm -rf /", "vercel_sandbox") + assert result["approved"] is True + # --------------------------------------------------------------------------- # tirith allow + safe command diff --git a/tests/tools/test_cronjob_tools.py b/tests/tools/test_cronjob_tools.py index 38fc12cc8c..ab6f8eef08 100644 --- a/tests/tools/test_cronjob_tools.py +++ b/tests/tools/test_cronjob_tools.py @@ -231,3 +231,60 @@ class TestUnifiedCronjobTool: assert updated["success"] is True assert updated["job"]["skills"] == [] assert updated["job"]["skill"] is None + + def test_create_normalizes_list_form_deliver(self): + """deliver=['telegram'] (list) is stored as the string 'telegram'. + + Regression for #17139: MCP clients / scripts sometimes pass ``deliver`` + as an array. Prior to the fix, ``['telegram']`` was written verbatim + to ``jobs.json`` and the scheduler then tried to resolve the literal + string ``"['telegram']"`` as a platform, failing with + "no delivery target resolved". 
+ """ + from cron.jobs import get_job + + created = json.loads( + cronjob( + action="create", + prompt="Daily briefing", + schedule="every 1h", + deliver=["telegram"], + ) + ) + assert created["success"] is True + stored = get_job(created["job_id"]) + assert stored["deliver"] == "telegram" + + def test_create_normalizes_multi_element_list_deliver(self): + """deliver=['telegram', 'discord'] is stored as 'telegram,discord'.""" + from cron.jobs import get_job + + created = json.loads( + cronjob( + action="create", + prompt="Daily briefing", + schedule="every 1h", + deliver=["telegram", "discord"], + ) + ) + assert created["success"] is True + stored = get_job(created["job_id"]) + assert stored["deliver"] == "telegram,discord" + + def test_update_normalizes_list_form_deliver(self): + """update with deliver=['telegram'] stores the canonical string.""" + from cron.jobs import get_job + + created = json.loads( + cronjob(action="create", prompt="x", schedule="every 1h") + ) + updated = json.loads( + cronjob( + action="update", + job_id=created["job_id"], + deliver=["telegram"], + ) + ) + assert updated["success"] is True + stored = get_job(created["job_id"]) + assert stored["deliver"] == "telegram" diff --git a/tests/tools/test_docker_environment.py b/tests/tools/test_docker_environment.py index 62b8b83df1..cd3b7aae6f 100644 --- a/tests/tools/test_docker_environment.py +++ b/tests/tools/test_docker_environment.py @@ -45,6 +45,7 @@ def _make_dummy_env(**kwargs): host_cwd=kwargs.get("host_cwd"), auto_mount_cwd=kwargs.get("auto_mount_cwd", False), env=kwargs.get("env"), + run_as_host_user=kwargs.get("run_as_host_user", False), ) @@ -384,9 +385,10 @@ def test_normalize_env_dict_rejects_complex_values(): assert result == {"GOOD": "string"} -def test_security_args_include_setuid_setgid_for_gosu_drop(): - """_SECURITY_ARGS must include SETUID and SETGID so the image entrypoint - can drop from root to the non-root `hermes` user via gosu. 
+def test_security_args_include_setuid_setgid_for_gosu_drop(monkeypatch): + """The default (run_as_host_user=False) invocation must include SETUID and + SETGID caps so the image entrypoint can drop from root to the non-root + `hermes` user via gosu. Without these caps gosu exits with ``error: failed switching to 'hermes': operation not permitted`` @@ -396,17 +398,117 @@ def test_security_args_include_setuid_setgid_for_gosu_drop(): after the drop — the drop is a one-way transition performed before the `no_new_privs` bit is enforced on the exec boundary. """ - args = docker_env._SECURITY_ARGS + monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker") + calls = _mock_subprocess_run(monkeypatch) + + _make_dummy_env() + + run_calls = [c for c in calls if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"] + assert run_calls, "docker run should have been called" + run_args = run_calls[0][0] - # Flatten to set of added caps for clarity. added = { - args[i + 1] - for i, flag in enumerate(args[:-1]) + run_args[i + 1] + for i, flag in enumerate(run_args[:-1]) if flag == "--cap-add" } assert "SETUID" in added, "SETUID cap missing — gosu drop in entrypoint will fail" assert "SETGID" in added, "SETGID cap missing — gosu drop in entrypoint will fail" - # Sanity: the hardening posture is still in place. 
- assert "--cap-drop" in args and "ALL" in args - assert "--security-opt" in args and "no-new-privileges" in args + +# ── run_as_host_user tests ──────────────────────────────────────── + + +def test_run_as_host_user_passes_uid_gid(monkeypatch): + """With run_as_host_user=True, --user : is added to docker run.""" + monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker") + monkeypatch.setattr(docker_env.os, "getuid", lambda: 1234, raising=False) + monkeypatch.setattr(docker_env.os, "getgid", lambda: 5678, raising=False) + calls = _mock_subprocess_run(monkeypatch) + + _make_dummy_env(run_as_host_user=True) + + run_calls = [c for c in calls if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"] + assert run_calls, "docker run should have been called" + run_args = run_calls[0][0] + + # --user must be present and must be paired with "1234:5678" + assert "--user" in run_args, f"--user flag missing from docker run args: {run_args}" + idx = run_args.index("--user") + assert run_args[idx + 1] == "1234:5678", ( + f"expected --user 1234:5678, got --user {run_args[idx + 1]}" + ) + + +def test_run_as_host_user_drops_setuid_setgid_caps(monkeypatch): + """When --user is passed, the container never needs gosu, so SETUID/SETGID + caps are omitted for a tighter security posture.""" + monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker") + monkeypatch.setattr(docker_env.os, "getuid", lambda: 1000, raising=False) + monkeypatch.setattr(docker_env.os, "getgid", lambda: 1000, raising=False) + calls = _mock_subprocess_run(monkeypatch) + + _make_dummy_env(run_as_host_user=True) + + run_calls = [c for c in calls if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"] + run_args = run_calls[0][0] + + added = { + run_args[i + 1] + for i, flag in enumerate(run_args[:-1]) + if flag == "--cap-add" + } + assert "SETUID" not in added, ( + "SETUID cap should be dropped when running as host user — no gosu drop is needed" + ) + assert 
"SETGID" not in added, ( + "SETGID cap should be dropped when running as host user — no gosu drop is needed" + ) + # Core non-privilege-drop caps must still be there (pip/npm/apt need them). + assert "DAC_OVERRIDE" in added + assert "CHOWN" in added + assert "FOWNER" in added + + +def test_run_as_host_user_default_off(monkeypatch): + """Without the opt-in, no --user flag is emitted — preserving existing behavior.""" + monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker") + calls = _mock_subprocess_run(monkeypatch) + + _make_dummy_env() # run_as_host_user defaults to False + + run_calls = [c for c in calls if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"] + run_args = run_calls[0][0] + assert "--user" not in run_args, ( + f"--user should not be in docker run args when opt-in is off: {run_args}" + ) + + +def test_run_as_host_user_warns_and_skips_when_no_posix_ids(monkeypatch, caplog): + """On platforms without POSIX getuid/getgid, log a warning and leave the + container at its image default user (no --user flag, full cap set).""" + monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker") + # Simulate a platform where os.getuid is absent (e.g. Windows host). + monkeypatch.delattr(docker_env.os, "getuid", raising=False) + monkeypatch.delattr(docker_env.os, "getgid", raising=False) + calls = _mock_subprocess_run(monkeypatch) + + with caplog.at_level(logging.WARNING): + _make_dummy_env(run_as_host_user=True) + + run_calls = [c for c in calls if isinstance(c[0], list) and len(c[0]) >= 2 and c[0][1] == "run"] + run_args = run_calls[0][0] + + assert "--user" not in run_args + # Fall back to the full cap set since the container still starts as root. 
+ added = { + run_args[i + 1] + for i, flag in enumerate(run_args[:-1]) + if flag == "--cap-add" + } + assert "SETUID" in added + assert "SETGID" in added + assert any( + "does not expose POSIX uid/gid" in rec.getMessage() + for rec in caplog.records + ), "expected a warning when POSIX ids are unavailable" diff --git a/tests/tools/test_hardline_blocklist.py b/tests/tools/test_hardline_blocklist.py index 3f65cc0869..a3a08cd464 100644 --- a/tests/tools/test_hardline_blocklist.py +++ b/tests/tools/test_hardline_blocklist.py @@ -241,7 +241,7 @@ def test_container_backends_still_bypass(clean_session): Hardline only protects environments with real host impact (local, ssh). """ - for env in ("docker", "singularity", "modal", "daytona"): + for env in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"): r1 = check_dangerous_command("rm -rf /", env) assert r1["approved"] is True, f"container {env} should still bypass" r2 = check_all_command_guards("rm -rf /", env) diff --git a/tests/tools/test_local_env_blocklist.py b/tests/tools/test_local_env_blocklist.py index 0377d59b36..e3e7c310c5 100644 --- a/tests/tools/test_local_env_blocklist.py +++ b/tests/tools/test_local_env_blocklist.py @@ -132,6 +132,10 @@ class TestProviderEnvBlocklist: "MODAL_TOKEN_ID": "modal-id", "MODAL_TOKEN_SECRET": "modal-secret", "DAYTONA_API_KEY": "daytona-key", + "VERCEL_OIDC_TOKEN": "vercel-oidc-token", + "VERCEL_TOKEN": "vercel-token", + "VERCEL_PROJECT_ID": "vercel-project", + "VERCEL_TEAM_ID": "vercel-team", } result_env = _run_with_env(extra_os_env=leaked_vars) @@ -287,6 +291,10 @@ class TestBlocklistCoverage: "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET", "DAYTONA_API_KEY", + "VERCEL_OIDC_TOKEN", + "VERCEL_TOKEN", + "VERCEL_PROJECT_ID", + "VERCEL_TEAM_ID", } assert extras.issubset(_HERMES_PROVIDER_ENV_BLOCKLIST) diff --git a/tests/tools/test_local_interrupt_cleanup.py b/tests/tools/test_local_interrupt_cleanup.py index 72310009a5..a9b7455938 100644 --- 
a/tests/tools/test_local_interrupt_cleanup.py +++ b/tests/tools/test_local_interrupt_cleanup.py @@ -16,6 +16,7 @@ import signal import subprocess import threading import time +from types import SimpleNamespace import pytest @@ -37,6 +38,58 @@ def _pgid_still_alive(pgid: int) -> bool: return False +def _process_group_snapshot(pgid: int) -> str: + """Return a process-table snapshot for diagnostics.""" + return subprocess.run( + ["ps", "-o", "pid,ppid,pgid,stat,cmd", "-g", str(pgid)], + capture_output=True, + text=True, + check=False, + ).stdout.strip() + + +def _wait_for_pgid_exit(pgid: int, timeout: float = 10.0) -> bool: + """Wait for a process group to disappear under loaded xdist hosts.""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if not _pgid_still_alive(pgid): + return True + time.sleep(0.1) + return not _pgid_still_alive(pgid) + + +def test_kill_process_uses_cached_pgid_if_wrapper_already_exited(monkeypatch): + """If the shell wrapper exits before cleanup, still kill its process group. + + Without the cached pgid fallback, ``os.getpgid(proc.pid)`` raises for the + dead wrapper and cleanup falls back to ``proc.kill()``, which cannot reach + orphaned grandchildren still running in the original process group. 
+ """ + env = object.__new__(LocalEnvironment) + proc = SimpleNamespace( + pid=12345, + _hermes_pgid=67890, + poll=lambda: 0, + kill=lambda: None, + ) + killpg_calls = [] + + def fake_getpgid(_pid): + raise ProcessLookupError + + def fake_killpg(pgid, sig): + killpg_calls.append((pgid, sig)) + if sig == 0: + raise ProcessLookupError + + monkeypatch.setattr(os, "getpgid", fake_getpgid) + monkeypatch.setattr(os, "killpg", fake_killpg) + + env._kill_process(proc) + + assert killpg_calls == [(67890, signal.SIGTERM), (67890, 0)] + + def test_wait_for_process_kills_subprocess_on_keyboardinterrupt(): """When KeyboardInterrupt arrives mid-poll, the subprocess group must be killed before the exception is re-raised.""" @@ -118,19 +171,15 @@ def test_wait_for_process_kills_subprocess_on_keyboardinterrupt(): assert not t.is_alive(), "worker didn't exit within 5 s of the interrupt" # The critical assertion: the subprocess GROUP must be dead. Not - # just the bash wrapper — the 'sleep 30' child too. - # Give the SIGTERM+1s wait+SIGKILL escalation a moment to complete. - deadline = time.monotonic() + 3.0 - while time.monotonic() < deadline: - if not _pgid_still_alive(pgid): - break - time.sleep(0.1) - assert not _pgid_still_alive(pgid), ( + # just the bash wrapper — the 'sleep 30' child too. Under xdist load, + # process-group disappearance can lag briefly after the worker exits, + # especially if the process is already dying or waiting to be reaped. + assert _wait_for_pgid_exit(pgid), ( f"subprocess group {pgid} is STILL ALIVE after worker received " f"KeyboardInterrupt — orphan bug regressed. This is the " f"sleep-300-survives-SIGTERM scenario from Physikal's Apr 2026 " f"report. See tools/environments/base.py _wait_for_process " - f"except-block." + f"except-block.\n{_process_group_snapshot(pgid)}" ) # And the worker should have observed the KeyboardInterrupt (i.e. # it re-raised cleanly, not silently swallowed). 
diff --git a/tests/tools/test_mcp_dynamic_discovery.py b/tests/tools/test_mcp_dynamic_discovery.py index 891770319f..c9adf545ed 100644 --- a/tests/tools/test_mcp_dynamic_discovery.py +++ b/tests/tools/test_mcp_dynamic_discovery.py @@ -88,24 +88,29 @@ class TestMessageHandler: from mcp.types import ServerNotification, ToolListChangedNotification server = MCPServerTask("notif_srv") - with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh: + # Product now schedules the refresh as a background task (see + # _schedule_tools_refresh in mcp_tool.py ~L918) rather than awaiting + # it directly, to avoid wedging the stdio JSON-RPC stream. Patch at + # the scheduler seam so we can still assert dispatch happened without + # reaching into asyncio.create_task internals. + with patch.object(MCPServerTask, "_schedule_tools_refresh") as mock_schedule: handler = server._make_message_handler() notification = ServerNotification( root=ToolListChangedNotification(method="notifications/tools/list_changed") ) await handler(notification) - mock_refresh.assert_awaited_once() + mock_schedule.assert_called_once() @pytest.mark.asyncio async def test_ignores_exceptions_and_other_messages(self): server = MCPServerTask("notif_srv") - with patch.object(MCPServerTask, "_refresh_tools", new_callable=AsyncMock) as mock_refresh: + with patch.object(MCPServerTask, "_schedule_tools_refresh") as mock_schedule: handler = server._make_message_handler() # Exceptions should not trigger refresh await handler(RuntimeError("connection dead")) # Unknown message types should not trigger refresh await handler({"jsonrpc": "2.0", "result": "ok"}) - mock_refresh.assert_not_awaited() + mock_schedule.assert_not_called() class TestDeregister: diff --git a/tests/tools/test_mcp_structured_content.py b/tests/tools/test_mcp_structured_content.py index 520872e8a5..2870ce1e86 100644 --- a/tests/tools/test_mcp_structured_content.py +++ b/tests/tools/test_mcp_structured_content.py @@ -35,7 
+35,15 @@ def _fake_run_on_mcp_loop(coro, timeout=30): """Run an MCP coroutine directly in a fresh event loop.""" loop = asyncio.new_event_loop() try: - return loop.run_until_complete(coro) + # `_rpc_lock` must be created inside the loop that awaits it, or asyncio + # raises "attached to a different loop". Build it here and attach it to + # whatever fake server is currently registered under _servers. + async def _install_lock_and_run(): + for srv in list(mcp_tool._servers.values()): + if getattr(srv, "_rpc_lock", None) is None: + srv._rpc_lock = asyncio.Lock() + return await coro + return loop.run_until_complete(_install_lock_and_run()) finally: loop.close() @@ -44,7 +52,10 @@ def _fake_run_on_mcp_loop(coro, timeout=30): def _patch_mcp_server(): """Patch _servers and the MCP event loop so _make_tool_handler can run.""" fake_session = MagicMock() - fake_server = SimpleNamespace(session=fake_session) + # `_rpc_lock` is acquired by _make_tool_handler's call path (mcp_tool.py + # ~L2008) to serialize JSON-RPC against the server — build it inside the + # fresh loop that _fake_run_on_mcp_loop spins up, not at fixture import. + fake_server = SimpleNamespace(session=fake_session, _rpc_lock=None) with patch.dict(mcp_tool._servers, {"test-server": fake_server}), \ patch("tools.mcp_tool._run_on_mcp_loop", side_effect=_fake_run_on_mcp_loop): yield fake_session diff --git a/tests/tools/test_modal_sandbox_fixes.py b/tests/tools/test_modal_sandbox_fixes.py index 570ef5b218..9113c892d3 100644 --- a/tests/tools/test_modal_sandbox_fixes.py +++ b/tests/tools/test_modal_sandbox_fixes.py @@ -7,6 +7,7 @@ Covers the bugs discovered while setting up TBLite evaluation: 4. ensurepip fix in Modal image builder 5. No swe-rex dependency — uses native Modal SDK 6. /home/ added to host prefix check +7. 
Vercel sandbox cwd normalization """ import os @@ -101,6 +102,26 @@ class TestCwdHandling: config = _tt_mod._get_env_config() assert config["cwd"] == "/root" + def test_host_path_replaced_for_vercel_sandbox(self, monkeypatch): + """Host paths should be discarded for Vercel Sandbox.""" + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_CWD", "/Users/someone/projects") + config = _tt_mod._get_env_config() + assert config["cwd"] == "/vercel/sandbox" + + def test_relative_path_replaced_for_vercel_sandbox(self, monkeypatch): + """Relative cwd should not map into a remote Vercel sandbox.""" + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_CWD", "src") + config = _tt_mod._get_env_config() + assert config["cwd"] == "/vercel/sandbox" + + def test_default_cwd_is_workspace_root_for_vercel_sandbox(self, monkeypatch): + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.delenv("TERMINAL_CWD", raising=False) + config = _tt_mod._get_env_config() + assert config["cwd"] == "/vercel/sandbox" + @pytest.mark.parametrize("backend", ["modal", "docker", "singularity", "daytona"]) def test_default_cwd_is_root_for_container_backends(self, backend, monkeypatch): """Container backends should default to /root, not ~.""" diff --git a/tests/tools/test_process_registry.py b/tests/tools/test_process_registry.py index d981878a31..83059915e4 100644 --- a/tests/tools/test_process_registry.py +++ b/tests/tools/test_process_registry.py @@ -103,6 +103,134 @@ class TestGetAndPoll: assert result["exit_code"] == 0 +# ========================================================================= +# Orphaned-pipe reconciliation (issue #17327) +# ========================================================================= + +@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only: uses setsid/fcntl") +class TestOrphanedPipeReconciliation: + """Regression tests for issue #17327. 
+ + `hermes update` in Feishu spawned a background subprocess that restarted + the gateway; the direct child exited quickly but a descendant daemon + held the stdout pipe open. `_reader_loop.finally` never ran, so + `session.exited` stayed False and the agent polled 74 times over 7 + minutes, all returning `status: running`. + + The fix is `_reconcile_local_exit()`: poll() and wait() now check the + direct `Popen.poll()` before trusting `session.exited`. + """ + + def test_reconcile_flips_exited_when_direct_child_done(self, registry): + """Direct child exited but reader thread is blocked on orphaned pipe.""" + # Simulate the orphaned-pipe scenario: direct child exited, but a + # descendant holds stdout open so the reader never sees EOF. + # Approach: spawn `sh -c 'sleep 10 &'` with setsid — sh forks the + # sleep into a new session group, exits immediately, but sleep + # inherits the stdout pipe and keeps it open. + proc = subprocess.Popen( + ["sh", "-c", "exec 1>&2; ( sleep 30 ) & disown; exit 0"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + preexec_fn=os.setsid, + ) + + s = _make_session(sid="proc_orphan_test") + s.process = proc + s.pid = proc.pid + registry._running[s.id] = s + + # Wait for the direct child to exit. We don't start a reader thread, + # so session.exited stays False (mimicking the stuck-reader state). + assert _wait_until(lambda: proc.poll() is not None, timeout=5.0), ( + "Direct child should exit quickly (sh exits, sleep descendant " + "holds the pipe open)" + ) + + # Before the fix: poll would return "running" forever. + # After the fix: poll reconciles against proc.poll() and flips. + assert s.exited is False # Precondition: reader hasn't updated it. + result = registry.poll(s.id) + assert result["status"] == "exited", ( + f"Expected reconciled 'exited' status; got {result!r}. " + "This is issue #17327 — reader is blocked on orphaned pipe." 
+ ) + assert result["exit_code"] == 0 + assert s.exited is True + assert s.id in registry._finished + assert s.id not in registry._running + + # Clean up the orphaned descendant. + try: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except (ProcessLookupError, PermissionError): + pass + + def test_reconcile_noop_when_child_still_running(self, registry): + """Reconcile must NOT flip exited when the direct child is alive.""" + proc = _spawn_python_sleep(5.0) + s = _make_session(sid="proc_running_test") + s.process = proc + s.pid = proc.pid + registry._running[s.id] = s + + result = registry.poll(s.id) + assert result["status"] == "running" + assert s.exited is False + + proc.kill() + proc.wait() + + def test_reconcile_noop_on_already_exited(self, registry): + """Reconcile is a no-op when session.exited is already True.""" + s = _make_session(sid="proc_already_exited", exited=True, exit_code=7) + s.process = MagicMock() + s.process.poll = MagicMock(return_value=0) # Would say exit 0 + registry._finished[s.id] = s + + registry._reconcile_local_exit(s) + # Must not overwrite the existing exit_code with proc.poll()'s 0. + assert s.exit_code == 7 + + def test_reconcile_noop_on_no_process(self, registry): + """Reconcile is a no-op for sessions without a local Popen (env/PTY).""" + s = _make_session(sid="proc_no_popen") + assert getattr(s, "process", None) is None + # Must not raise. 
+ registry._reconcile_local_exit(s) + assert s.exited is False + + def test_wait_returns_when_reader_blocked(self, registry): + """wait() must also reconcile — not just poll().""" + proc = subprocess.Popen( + ["sh", "-c", "( sleep 30 ) & disown; exit 0"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + preexec_fn=os.setsid, + ) + + s = _make_session(sid="proc_wait_orphan") + s.process = proc + s.pid = proc.pid + registry._running[s.id] = s + + assert _wait_until(lambda: proc.poll() is not None, timeout=5.0) + + start = time.monotonic() + result = registry.wait(s.id, timeout=10) + elapsed = time.monotonic() - start + + assert result["status"] == "exited", result + assert elapsed < 5.0, ( + f"wait() should return ~immediately via reconcile; took {elapsed:.1f}s" + ) + + try: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except (ProcessLookupError, PermissionError): + pass + + # ========================================================================= # Read log # ========================================================================= diff --git a/tests/tools/test_skill_manager_tool.py b/tests/tools/test_skill_manager_tool.py index 9918a826cb..9fc8957f1e 100644 --- a/tests/tools/test_skill_manager_tool.py +++ b/tests/tools/test_skill_manager_tool.py @@ -566,3 +566,262 @@ class TestSecurityScanGate: with patch("hermes_cli.config.load_config", side_effect=RuntimeError("boom")): assert _guard_agent_created_enabled() is False + + +# --------------------------------------------------------------------------- +# External skills directories (skills.external_dirs) — mutations in place +# --------------------------------------------------------------------------- + + +@contextmanager +def _two_roots(local_dir: Path, external_dir: Path): + """Patch the skill manager so local SKILLS_DIR = local_dir and + get_all_skills_dirs() returns [local_dir, external_dir] in order.""" + with patch("tools.skill_manager_tool.SKILLS_DIR", local_dir), \ + 
patch("agent.skill_utils.get_all_skills_dirs", + return_value=[local_dir, external_dir]): + yield + + +def _write_external_skill(external_dir: Path, name: str = "ext-skill") -> Path: + skill_dir = external_dir / name + skill_dir.mkdir(parents=True) + (skill_dir / "SKILL.md").write_text( + f"---\nname: {name}\ndescription: An external skill.\n---\n\n" + "# External\n\nBody with OLD_MARKER here.\n" + ) + return skill_dir + + +class TestExternalSkillMutations: + """Verify skill_manage can patch/edit/write/remove/delete skills that live + under skills.external_dirs — in place, without duplicating to local. + + Regression for issues #4759 and #4381: the read-only gate used to refuse + with 'Skill X is in an external directory and cannot be modified', which + caused agents to create duplicate copies in ~/.hermes/skills/ as a + workaround. + """ + + def test_patch_external_skill_writes_in_place(self, tmp_path): + local = tmp_path / "local" + external = tmp_path / "vault" + local.mkdir(); external.mkdir() + skill_dir = _write_external_skill(external) + + with _two_roots(local, external): + result = _patch_skill("ext-skill", "OLD_MARKER", "NEW_MARKER") + + assert result["success"] is True, result + assert "NEW_MARKER" in (skill_dir / "SKILL.md").read_text() + # No duplicate in local + assert not (local / "ext-skill").exists() + + def test_edit_external_skill_writes_in_place(self, tmp_path): + local = tmp_path / "local" + external = tmp_path / "vault" + local.mkdir(); external.mkdir() + skill_dir = _write_external_skill(external) + + new_content = ( + "---\nname: ext-skill\ndescription: Rewritten.\n---\n\n" + "# Rewritten\n\nBrand new body.\n" + ) + with _two_roots(local, external): + result = _edit_skill("ext-skill", new_content) + + assert result["success"] is True, result + assert "Brand new body" in (skill_dir / "SKILL.md").read_text() + assert not (local / "ext-skill").exists() + + def test_write_file_on_external_skill(self, tmp_path): + local = tmp_path / "local" + 
external = tmp_path / "vault" + local.mkdir(); external.mkdir() + skill_dir = _write_external_skill(external) + + with _two_roots(local, external): + result = _write_file("ext-skill", "references/notes.md", "# Notes\n") + + assert result["success"] is True, result + assert (skill_dir / "references" / "notes.md").read_text() == "# Notes\n" + assert not (local / "ext-skill").exists() + + def test_remove_file_on_external_skill(self, tmp_path): + local = tmp_path / "local" + external = tmp_path / "vault" + local.mkdir(); external.mkdir() + skill_dir = _write_external_skill(external) + (skill_dir / "references").mkdir() + (skill_dir / "references" / "notes.md").write_text("# Notes\n") + + with _two_roots(local, external): + result = _remove_file("ext-skill", "references/notes.md") + + assert result["success"] is True, result + assert not (skill_dir / "references" / "notes.md").exists() + + def test_delete_external_skill_removes_skill_not_root(self, tmp_path): + local = tmp_path / "local" + external = tmp_path / "vault" + local.mkdir(); external.mkdir() + skill_dir = _write_external_skill(external) + + with _two_roots(local, external): + result = _delete_skill("ext-skill") + + assert result["success"] is True, result + assert not skill_dir.exists() + # The external root must NOT be rmdir'd, even when empty after deletion + assert external.exists() and external.is_dir() + + def test_delete_external_skill_cleans_empty_category(self, tmp_path): + """When a skill lives under external/<category>/<skill>, deleting the + last skill in the category should rmdir the empty category dir but + stop at the external root.""" + local = tmp_path / "local" + external = tmp_path / "vault" + local.mkdir(); external.mkdir() + cat_dir = external / "team" + cat_dir.mkdir() + skill_dir = cat_dir / "ext-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: ext-skill\ndescription: An external skill.\n---\n\n" + "# External\n\nBody.\n" + ) + + with _two_roots(local, external): + result =
_delete_skill("ext-skill") + + assert result["success"] is True, result + assert not skill_dir.exists() + assert not cat_dir.exists() # empty category cleaned up + assert external.exists() # but never the external root + + def test_create_still_writes_to_local_root(self, tmp_path): + """Creating a new skill always lands in local SKILLS_DIR, never + external_dirs — create is unchanged by this PR.""" + local = tmp_path / "local" + external = tmp_path / "vault" + local.mkdir(); external.mkdir() + + with _two_roots(local, external): + result = _create_skill("fresh-skill", VALID_SKILL_CONTENT.replace( + "name: test-skill", "name: fresh-skill")) + + assert result["success"] is True, result + assert (local / "fresh-skill" / "SKILL.md").exists() + assert not (external / "fresh-skill").exists() + + + +# --------------------------------------------------------------------------- +# Pinned-skill guard — skill_manage refuses all writes to pinned skills. +# The user unpins via `hermes curator unpin <name>`.
+# --------------------------------------------------------------------------- + +class TestPinnedGuard: + """Every mutation action must refuse when the skill is pinned.""" + + @staticmethod + def _pin(name: str): + """Return a patch context that marks *name* as pinned in skill_usage.""" + def _fake_get_record(skill_name, _name=name): + return {"pinned": True} if skill_name == _name else {"pinned": False} + return patch("tools.skill_usage.get_record", side_effect=_fake_get_record) + + def test_edit_refuses_pinned(self, tmp_path): + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + with self._pin("my-skill"): + result = _edit_skill("my-skill", VALID_SKILL_CONTENT_2) + assert result["success"] is False + assert "pinned" in result["error"].lower() + assert "hermes curator unpin my-skill" in result["error"] + # Original content preserved + content = (tmp_path / "my-skill" / "SKILL.md").read_text() + assert "A test skill" in content + + def test_patch_refuses_pinned(self, tmp_path): + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + with self._pin("my-skill"): + result = _patch_skill("my-skill", "Do the thing.", "Do the new thing.") + assert result["success"] is False + assert "pinned" in result["error"].lower() + assert "hermes curator unpin my-skill" in result["error"] + content = (tmp_path / "my-skill" / "SKILL.md").read_text() + assert "Do the thing." 
in content # unchanged + + def test_patch_supporting_file_refuses_pinned(self, tmp_path): + """Pin covers supporting files too, not just SKILL.md.""" + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + _write_file("my-skill", "references/api.md", "original") + with self._pin("my-skill"): + result = _patch_skill( + "my-skill", "original", "modified", + file_path="references/api.md", + ) + assert result["success"] is False + assert "pinned" in result["error"].lower() + assert (tmp_path / "my-skill" / "references" / "api.md").read_text() == "original" + + def test_delete_refuses_pinned(self, tmp_path): + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + with self._pin("my-skill"): + result = _delete_skill("my-skill") + assert result["success"] is False + assert "pinned" in result["error"].lower() + # Skill still exists + assert (tmp_path / "my-skill" / "SKILL.md").exists() + + def test_write_file_refuses_pinned(self, tmp_path): + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + with self._pin("my-skill"): + result = _write_file("my-skill", "references/api.md", "content") + assert result["success"] is False + assert "pinned" in result["error"].lower() + assert not (tmp_path / "my-skill" / "references" / "api.md").exists() + + def test_remove_file_refuses_pinned(self, tmp_path): + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + _write_file("my-skill", "references/api.md", "content") + with self._pin("my-skill"): + result = _remove_file("my-skill", "references/api.md") + assert result["success"] is False + assert "pinned" in result["error"].lower() + # File still there + assert (tmp_path / "my-skill" / "references" / "api.md").exists() + + def test_unpinned_skills_still_editable(self, tmp_path): + """Sanity check: the guard doesn't fire for unpinned skills. + + Only the specifically-pinned skill is refused; a sibling skill must + still be freely editable. 
+ """ + with _skill_dir(tmp_path): + _create_skill("pinned-one", VALID_SKILL_CONTENT) + _create_skill("free-one", VALID_SKILL_CONTENT) + with self._pin("pinned-one"): + blocked = _edit_skill("pinned-one", VALID_SKILL_CONTENT_2) + allowed = _edit_skill("free-one", VALID_SKILL_CONTENT_2) + assert blocked["success"] is False + assert allowed["success"] is True + + def test_broken_sidecar_fails_open(self, tmp_path): + """If skill_usage.get_record raises, we allow the write through. + + Rationale: a corrupted telemetry file shouldn't lock the agent out + of skills it would otherwise be allowed to touch. + """ + with _skill_dir(tmp_path): + _create_skill("my-skill", VALID_SKILL_CONTENT) + with patch("tools.skill_usage.get_record", + side_effect=RuntimeError("sidecar broken")): + result = _edit_skill("my-skill", VALID_SKILL_CONTENT_2) + assert result["success"] is True diff --git a/tests/tools/test_skills_tool.py b/tests/tools/test_skills_tool.py index 79470710b0..d95fc0671d 100644 --- a/tests/tools/test_skills_tool.py +++ b/tests/tools/test_skills_tool.py @@ -932,7 +932,7 @@ class TestSkillViewPrerequisites: @pytest.mark.parametrize( "backend", - ["ssh", "daytona", "docker", "singularity", "modal"], + ["ssh", "daytona", "docker", "singularity", "modal", "vercel_sandbox"], ) def test_remote_backend_becomes_available_after_local_secret_capture( self, tmp_path, monkeypatch, backend diff --git a/tests/tools/test_slash_confirm.py b/tests/tools/test_slash_confirm.py new file mode 100644 index 0000000000..e02f1c752e --- /dev/null +++ b/tests/tools/test_slash_confirm.py @@ -0,0 +1,197 @@ +"""Tests for tools/slash_confirm.py — the generic slash-command confirmation primitive. + +Covers register/resolve/clear lifecycle, stale-entry behavior, confirm_id +mismatch, handler exceptions, and async resolution. 
+""" + +import asyncio +import time + +import pytest + +from tools import slash_confirm + + +@pytest.fixture(autouse=True) +def _clean_pending(): + """Every test gets a clean primitive state.""" + slash_confirm._pending.clear() + yield + slash_confirm._pending.clear() + + +class TestRegisterAndGetPending: + def test_register_stores_entry(self): + async def handler(choice): + return f"got {choice}" + + slash_confirm.register("sess1", "cid1", "reload-mcp", handler) + + pending = slash_confirm.get_pending("sess1") + assert pending is not None + assert pending["confirm_id"] == "cid1" + assert pending["command"] == "reload-mcp" + assert pending["handler"] is handler + assert "created_at" in pending + + def test_get_pending_missing_returns_none(self): + assert slash_confirm.get_pending("nobody") is None + + def test_register_supersedes_prior_entry(self): + async def h1(choice): + return "first" + + async def h2(choice): + return "second" + + slash_confirm.register("sess1", "cid1", "reload-mcp", h1) + slash_confirm.register("sess1", "cid2", "reload-mcp", h2) + + pending = slash_confirm.get_pending("sess1") + assert pending["confirm_id"] == "cid2" + assert pending["handler"] is h2 + + def test_get_pending_returns_copy_not_reference(self): + async def h(choice): + return "x" + + slash_confirm.register("sess1", "cid1", "cmd", h) + + p1 = slash_confirm.get_pending("sess1") + p1["command"] = "mutated" + + p2 = slash_confirm.get_pending("sess1") + assert p2["command"] == "cmd" + + +class TestResolve: + @pytest.mark.asyncio + async def test_resolve_runs_handler_and_pops_entry(self): + calls = [] + + async def handler(choice): + calls.append(choice) + return f"resolved {choice}" + + slash_confirm.register("sess1", "cid1", "reload-mcp", handler) + + result = await slash_confirm.resolve("sess1", "cid1", "once") + assert result == "resolved once" + assert calls == ["once"] + + # Entry should be popped. 
+ assert slash_confirm.get_pending("sess1") is None + + @pytest.mark.asyncio + async def test_resolve_no_pending_returns_none(self): + result = await slash_confirm.resolve("sess1", "cid1", "once") + assert result is None + + @pytest.mark.asyncio + async def test_resolve_confirm_id_mismatch_returns_none(self): + async def handler(choice): + return "should not run" + + slash_confirm.register("sess1", "cid_real", "cmd", handler) + + result = await slash_confirm.resolve("sess1", "cid_wrong", "once") + assert result is None + + # Stale entry should still be present (mismatch doesn't pop). + assert slash_confirm.get_pending("sess1") is not None + + @pytest.mark.asyncio + async def test_resolve_stale_entry_returns_none(self): + async def handler(choice): + return "should not run" + + slash_confirm.register("sess1", "cid1", "cmd", handler) + # Force entry age past timeout + slash_confirm._pending["sess1"]["created_at"] = time.time() - 10000 + + result = await slash_confirm.resolve("sess1", "cid1", "once") + assert result is None + + @pytest.mark.asyncio + async def test_resolve_handler_exception_returns_error_string(self): + async def handler(choice): + raise RuntimeError("boom") + + slash_confirm.register("sess1", "cid1", "cmd", handler) + + result = await slash_confirm.resolve("sess1", "cid1", "once") + assert result is not None + assert "boom" in result + # Entry should still be popped even when handler raises. 
+ assert slash_confirm.get_pending("sess1") is None + + @pytest.mark.asyncio + async def test_resolve_non_string_return_becomes_none(self): + async def handler(choice): + return {"not": "a string"} + + slash_confirm.register("sess1", "cid1", "cmd", handler) + result = await slash_confirm.resolve("sess1", "cid1", "once") + assert result is None + + @pytest.mark.asyncio + async def test_resolve_double_click_only_runs_handler_once(self): + calls = [] + + async def handler(choice): + calls.append(choice) + return "ran" + + slash_confirm.register("sess1", "cid1", "cmd", handler) + + # Simulate two near-simultaneous button clicks. + r1, r2 = await asyncio.gather( + slash_confirm.resolve("sess1", "cid1", "once"), + slash_confirm.resolve("sess1", "cid1", "once"), + ) + # Exactly one should have run the handler. + assert calls == ["once"] + assert (r1 == "ran") ^ (r2 == "ran") + + +class TestClear: + def test_clear_removes_entry(self): + async def h(c): + return "x" + + slash_confirm.register("sess1", "cid1", "cmd", h) + assert slash_confirm.get_pending("sess1") is not None + + slash_confirm.clear("sess1") + assert slash_confirm.get_pending("sess1") is None + + def test_clear_missing_is_noop(self): + # Should not raise. 
+ slash_confirm.clear("nobody") + + +class TestClearIfStale: + def test_clears_stale_entry(self): + async def h(c): + return "x" + + slash_confirm.register("sess1", "cid1", "cmd", h) + slash_confirm._pending["sess1"]["created_at"] = time.time() - 10000 + + cleared = slash_confirm.clear_if_stale("sess1", timeout=300) + assert cleared is True + assert slash_confirm.get_pending("sess1") is None + + def test_preserves_fresh_entry(self): + async def h(c): + return "x" + + slash_confirm.register("sess1", "cid1", "cmd", h) + + cleared = slash_confirm.clear_if_stale("sess1", timeout=300) + assert cleared is False + assert slash_confirm.get_pending("sess1") is not None + + def test_returns_false_for_missing_entry(self): + cleared = slash_confirm.clear_if_stale("nobody") + assert cleared is False diff --git a/tests/tools/test_terminal_config_env_sync.py b/tests/tools/test_terminal_config_env_sync.py new file mode 100644 index 0000000000..892062fae7 --- /dev/null +++ b/tests/tools/test_terminal_config_env_sync.py @@ -0,0 +1,210 @@ +"""Regression tests for terminal config -> env-var bridging. + +terminal_tool._get_env_config() reads ALL terminal settings from os.environ +(TERMINAL_*). config.yaml values therefore have to be bridged into env vars +at startup, by THREE separate code paths: + + 1. cli.py -> ``env_mappings`` dict (CLI / TUI startup) + 2. gateway/run.py -> ``_terminal_env_map`` dict (gateway / messaging + platforms) + 3. hermes_cli/config.py:save_config_value + -> ``_config_to_env_sync`` dict (one-shot when the + user runs ``hermes config set …``) + +If any one of these is missing a key, the corresponding config.yaml setting +silently does nothing for that entry-point. This bug already shipped once +for ``docker_run_as_host_user`` (gateway and CLI maps) and once for +``docker_mount_cwd_to_workspace`` (gateway map). 
+ +This test guards against future drift by extracting all three maps via source +inspection and asserting they all bridge the same set of writable +``terminal.*`` keys. Source inspection (rather than importing the live +dicts) keeps the test independent of the user's ~/.hermes/config.yaml and +mirrors the pattern used in tests/hermes_cli/test_config_drift.py. +""" + +import ast +import inspect + + +def _extract_dict_values(source: str, dict_name: str) -> set[str]: + """Return the set of *value* strings in `dict_name = { "k": "VALUE", ... }`. + + We parse the source with ast (so multi-line dicts and comments are + handled) instead of regex. The first matching assignment wins. + """ + tree = ast.parse(source) + for node in ast.walk(tree): + if not isinstance(node, ast.Assign): + continue + targets = [t for t in node.targets if isinstance(t, ast.Name)] + if not any(t.id == dict_name for t in targets): + continue + if not isinstance(node.value, ast.Dict): + continue + out: set[str] = set() + for k, v in zip(node.value.keys, node.value.values): + if isinstance(k, ast.Constant) and isinstance(v, ast.Constant): + if isinstance(v.value, str): + out.add(v.value) + return out + raise AssertionError(f"Could not find `{dict_name} = {{...}}` literal in source") + + +def _extract_dict_keys(source: str, dict_name: str) -> set[str]: + """Return the set of *key* strings in `dict_name = { "KEY": "v", ... 
}`.""" + tree = ast.parse(source) + for node in ast.walk(tree): + if not isinstance(node, ast.Assign): + continue + targets = [t for t in node.targets if isinstance(t, ast.Name)] + if not any(t.id == dict_name for t in targets): + continue + if not isinstance(node.value, ast.Dict): + continue + out: set[str] = set() + for k in node.value.keys: + if isinstance(k, ast.Constant) and isinstance(k.value, str): + out.add(k.value) + return out + raise AssertionError(f"Could not find `{dict_name} = {{...}}` literal in source") + + +def _cli_env_map_keys() -> set[str]: + """terminal config keys bridged by cli.load_cli_config().""" + import cli + source = inspect.getsource(cli.load_cli_config) + return _extract_dict_keys(source, "env_mappings") + + +def _gateway_env_map_keys() -> set[str]: + """terminal config keys bridged by gateway/run.py at module load.""" + # gateway/run.py builds the dict at module top-level (not inside a + # function), so inspect the whole module source. + import gateway.run as gr + source = inspect.getsource(gr) + return _extract_dict_keys(source, "_terminal_env_map") + + +def _save_config_env_sync_keys() -> set[str]: + """terminal config keys bridged by ``hermes config set foo bar``.""" + from hermes_cli import config as hc_config + source = inspect.getsource(hc_config.set_config_value) + keys = _extract_dict_keys(source, "_config_to_env_sync") + # set_config_value uses fully-qualified ``terminal.foo`` keys; strip the + # prefix so we can compare against the other two maps which use bare + # leaf keys. + return {k.split(".", 1)[1] for k in keys if k.startswith("terminal.")} + + +# Keys present in cli.py env_mappings but intentionally absent from +# gateway/run.py or set_config_value. Each entry must be justified. +_CLI_ONLY_OK = frozenset({ + # `env_type` is a legacy YAML key alias for `backend` that cli.py + # accepts for backwards-compat with older cli-config.yaml. 
The + # gateway path normalizes on the canonical `backend` key, which is + # also in the map and handles the same bridging. See cli.py ~line 515. + "env_type", + # sudo_password is not a terminal-backend option — it's a credential + # used across backends, bridged to $SUDO_PASSWORD (not TERMINAL_*). + # Treating it as terminal-only would be misleading. + "sudo_password", +}) + + +def _terminal_tool_env_var_names() -> set[str]: + """All TERMINAL_* env vars actually consumed by terminal_tool.""" + import tools.terminal_tool as tt + source = inspect.getsource(tt) + # Naive scan: every os.getenv("TERMINAL_X", ...) and _parse_env_var("TERMINAL_X", ...). + import re + pat = re.compile(r'["\'](TERMINAL_[A-Z0-9_]+)["\']') + return set(pat.findall(source)) + + +def test_cli_and_gateway_env_maps_agree(): + """cli.py and gateway/run.py must bridge the same set of terminal keys. + + Both feed the same downstream consumer (terminal_tool). Drift between + them means a config.yaml setting that "works in CLI mode but not gateway + mode" (or vice-versa) — the bug class that shipped twice already. + """ + cli_keys = _cli_env_map_keys() - _CLI_ONLY_OK + gw_keys = _gateway_env_map_keys() + + # Normalize the legacy `env_type` alias: cli.py accepts both `env_type` + # and `backend` as source keys for TERMINAL_ENV; gateway only accepts + # `backend`. Since cli.py copies `backend` → `env_type` before the + # lookup, they're equivalent. Remove `backend` from the gateway side + # to avoid a spurious "backend missing from cli" failure. + gw_keys = gw_keys - {"backend"} + + missing_in_gateway = cli_keys - gw_keys + missing_in_cli = gw_keys - cli_keys + + assert not missing_in_gateway, ( + f"Keys in cli.py env_mappings but missing from gateway/run.py " + f"_terminal_env_map: {sorted(missing_in_gateway)}. Add them to " + f"both maps (same bug class as docker_run_as_host_user shipping " + f"wired in cli but not gateway in April 2026)." 
+ ) + assert not missing_in_cli, ( + f"Keys in gateway/run.py _terminal_env_map but missing from cli.py " + f"env_mappings: {sorted(missing_in_cli)}. Add them to both maps." + ) + + +def test_save_config_set_supports_critical_bridged_keys(): + """``hermes config set terminal.X true`` must propagate to .env for + known-critical keys. This used to be an all-keys invariant but several + pre-existing terminal keys (ssh_*, docker_forward_env, docker_volumes) + aren't in _config_to_env_sync and are instead handled via the separate + api_keys TERMINAL_SSH_* fallback path or user-edits-yaml-directly. + + Until those gaps are audited and fixed, pin the specific keys that are + load-bearing for the docker backend's ownership flag so the bug we just + fixed cannot silently regress. + """ + save_keys = _save_config_env_sync_keys() + required = { + "docker_run_as_host_user", + "docker_mount_cwd_to_workspace", + "backend", + "docker_image", + "container_cpu", + "container_memory", + "container_disk", + "container_persistent", + } + missing = required - save_keys + assert not missing, ( + f"`hermes config set terminal.X` doesn't sync these load-bearing " + f"keys to .env: {sorted(missing)}. Add them to _config_to_env_sync " + f"in hermes_cli/config.py:set_config_value." + ) + + +def test_docker_run_as_host_user_is_bridged_everywhere(): + """Explicit pin for the bug we just fixed. + + docker_run_as_host_user was added to terminal_tool._get_env_config and + DockerEnvironment but NOT to cli.py's env_mappings or gateway/run.py's + _terminal_env_map, so ``terminal.docker_run_as_host_user: true`` in + config.yaml had no effect at runtime. This guard makes the regression + impossible to reintroduce silently. 
+ """ + assert "docker_run_as_host_user" in _cli_env_map_keys() + assert "docker_run_as_host_user" in _gateway_env_map_keys() + assert "docker_run_as_host_user" in _save_config_env_sync_keys() + assert "TERMINAL_DOCKER_RUN_AS_HOST_USER" in _terminal_tool_env_var_names() + + +def test_docker_mount_cwd_to_workspace_is_bridged_everywhere(): + """Same regression class — docker_mount_cwd_to_workspace was missing from + gateway/run.py's _terminal_env_map until the docker_run_as_host_user + audit caught it. + """ + assert "docker_mount_cwd_to_workspace" in _cli_env_map_keys() + assert "docker_mount_cwd_to_workspace" in _gateway_env_map_keys() + assert "docker_mount_cwd_to_workspace" in _save_config_env_sync_keys() + assert "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE" in _terminal_tool_env_var_names() diff --git a/tests/tools/test_terminal_requirements.py b/tests/tools/test_terminal_requirements.py index 7859043ab5..265fd567fd 100644 --- a/tests/tools/test_terminal_requirements.py +++ b/tests/tools/test_terminal_requirements.py @@ -1,6 +1,8 @@ import importlib import logging +import pytest + terminal_tool_module = importlib.import_module("tools.terminal_tool") @@ -8,11 +10,24 @@ def _clear_terminal_env(monkeypatch): """Remove terminal env vars that could affect requirements checks.""" keys = [ "TERMINAL_ENV", + "TERMINAL_CONTAINER_CPU", + "TERMINAL_CONTAINER_DISK", + "TERMINAL_CONTAINER_MEMORY", + "TERMINAL_DOCKER_FORWARD_ENV", + "TERMINAL_DOCKER_VOLUMES", + "TERMINAL_LIFETIME_SECONDS", "TERMINAL_MODAL_MODE", "TERMINAL_SSH_HOST", + "TERMINAL_SSH_PORT", "TERMINAL_SSH_USER", + "TERMINAL_TIMEOUT", + "TERMINAL_VERCEL_RUNTIME", "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET", + "VERCEL_OIDC_TOKEN", + "VERCEL_TOKEN", + "VERCEL_PROJECT_ID", + "VERCEL_TEAM_ID", "HOME", "USERPROFILE", ] @@ -176,3 +191,126 @@ def test_modal_backend_managed_mode_without_feature_flag_logs_clear_error(monkey "paid Nous subscription is required" in record.getMessage() for record in caplog.records ) + + +def 
test_vercel_backend_without_sdk_logs_specific_error(monkeypatch, caplog): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: None) + + with caplog.at_level(logging.ERROR): + ok = terminal_tool_module.check_terminal_requirements() + + assert ok is False + assert any( + "vercel is required for the Vercel Sandbox terminal backend" in record.getMessage() + for record in caplog.records + ) + + +def test_vercel_backend_without_auth_logs_specific_error(monkeypatch, caplog): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + with caplog.at_level(logging.ERROR): + ok = terminal_tool_module.check_terminal_requirements() + + assert ok is False + assert any( + "no supported auth configuration was found" in record.getMessage() + for record in caplog.records + ) + + +def test_vercel_backend_accepts_oidc_auth(monkeypatch): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + assert terminal_tool_module.check_terminal_requirements() is True + + +def test_vercel_backend_accepts_token_tuple_auth(monkeypatch): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("VERCEL_TOKEN", "token") + monkeypatch.setenv("VERCEL_PROJECT_ID", "project") + monkeypatch.setenv("VERCEL_TEAM_ID", "team") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + assert terminal_tool_module.check_terminal_requirements() is True + + +@pytest.mark.parametrize("runtime", ["node24", "node22", "python3.13"]) +def 
test_vercel_backend_accepts_supported_runtimes(monkeypatch, runtime): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_VERCEL_RUNTIME", runtime) + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + assert terminal_tool_module.check_terminal_requirements() is True + + +def test_vercel_backend_accepts_blank_runtime(monkeypatch): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_VERCEL_RUNTIME", " ") + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + assert terminal_tool_module.check_terminal_requirements() is True + + +def test_vercel_backend_rejects_unsupported_runtime(monkeypatch, caplog): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_VERCEL_RUNTIME", "node20") + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + with caplog.at_level(logging.ERROR): + ok = terminal_tool_module.check_terminal_requirements() + + assert ok is False + assert any( + "Vercel Sandbox runtime 'node20' is not supported" in record.getMessage() + and "node24, node22, python3.13" in record.getMessage() + for record in caplog.records + ) + + +def test_vercel_backend_rejects_nondefault_disk(monkeypatch, caplog): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_CONTAINER_DISK", "8192") + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + with caplog.at_level(logging.ERROR): + ok = 
terminal_tool_module.check_terminal_requirements() + + assert ok is False + assert any( + "does not support custom TERMINAL_CONTAINER_DISK=8192" in record.getMessage() + for record in caplog.records + ) + + +def test_vercel_backend_rejects_malformed_disk_without_raising(monkeypatch, caplog): + _clear_terminal_env(monkeypatch) + monkeypatch.setenv("TERMINAL_ENV", "vercel_sandbox") + monkeypatch.setenv("TERMINAL_CONTAINER_DISK", "large") + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object()) + + with caplog.at_level(logging.ERROR): + ok = terminal_tool_module.check_terminal_requirements() + + assert ok is False + assert any( + "Invalid value for TERMINAL_CONTAINER_DISK" in record.getMessage() + for record in caplog.records + ) diff --git a/tests/tools/test_terminal_tool_requirements.py b/tests/tools/test_terminal_tool_requirements.py index 1fbaef8e31..fe22bd26c5 100644 --- a/tests/tools/test_terminal_tool_requirements.py +++ b/tests/tools/test_terminal_tool_requirements.py @@ -49,3 +49,68 @@ class TestTerminalRequirements: assert "terminal" in names assert "execute_code" in names + + def test_terminal_and_execute_code_tools_resolve_for_vercel_sandbox(self, monkeypatch): + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr( + terminal_tool_module, + "_get_env_config", + lambda: {"env_type": "vercel_sandbox", "container_disk": 51200}, + ) + monkeypatch.setattr( + terminal_tool_module.importlib.util, + "find_spec", + lambda _name: object(), + ) + tools = get_tool_definitions(enabled_toolsets=["terminal", "code_execution"], quiet_mode=True) + names = {tool["function"]["name"] for tool in tools} + + assert "terminal" in names + assert "execute_code" in names + + def test_terminal_and_execute_code_tools_hide_for_unsupported_vercel_runtime(self, monkeypatch): + monkeypatch.setenv("VERCEL_OIDC_TOKEN", "oidc-token") + monkeypatch.setattr( + 
terminal_tool_module, + "_get_env_config", + lambda: { + "env_type": "vercel_sandbox", + "container_disk": 51200, + "vercel_runtime": "node20", + }, + ) + monkeypatch.setattr( + terminal_tool_module.importlib.util, + "find_spec", + lambda _name: object(), + ) + tools = get_tool_definitions(enabled_toolsets=["terminal", "code_execution"], quiet_mode=True) + names = {tool["function"]["name"] for tool in tools} + + assert "terminal" not in names + assert "execute_code" not in names + + def test_terminal_and_execute_code_tools_hide_for_vercel_without_auth(self, monkeypatch): + monkeypatch.delenv("VERCEL_OIDC_TOKEN", raising=False) + monkeypatch.delenv("VERCEL_TOKEN", raising=False) + monkeypatch.delenv("VERCEL_PROJECT_ID", raising=False) + monkeypatch.delenv("VERCEL_TEAM_ID", raising=False) + monkeypatch.setattr( + terminal_tool_module, + "_get_env_config", + lambda: { + "env_type": "vercel_sandbox", + "container_disk": 51200, + "vercel_runtime": "node22", + }, + ) + monkeypatch.setattr( + terminal_tool_module.importlib.util, + "find_spec", + lambda _name: object(), + ) + tools = get_tool_definitions(enabled_toolsets=["terminal", "code_execution"], quiet_mode=True) + names = {tool["function"]["name"] for tool in tools} + + assert "terminal" not in names + assert "execute_code" not in names diff --git a/tests/tools/test_tirith_security.py b/tests/tools/test_tirith_security.py index 10a92e9b94..20d20ccfa1 100644 --- a/tests/tools/test_tirith_security.py +++ b/tests/tools/test_tirith_security.py @@ -997,10 +997,13 @@ class TestHermesHomeIsolation: assert "hermes_test" in hermes_home, "Should point to test temp dir" def test_get_hermes_home_fallback(self): - """Without HERMES_HOME set, falls back to ~/.hermes.""" + """Without HERMES_HOME set, falls back to the active OS home.""" from tools.tirith_security import _get_hermes_home with patch.dict(os.environ, {}, clear=True): - # Remove HERMES_HOME entirely + # Remove HERMES_HOME entirely. 
With HOME also absent, expanduser + # falls back to the account database; compute expected under the + # same environment instead of after patch.dict restores HOME. os.environ.pop("HERMES_HOME", None) + expected = os.path.join(os.path.expanduser("~"), ".hermes") result = _get_hermes_home() - assert result == os.path.join(os.path.expanduser("~"), ".hermes") + assert result == expected diff --git a/tests/tools/test_transcription.py b/tests/tools/test_transcription.py index 9983f9031b..e56577ca55 100644 --- a/tests/tools/test_transcription.py +++ b/tests/tools/test_transcription.py @@ -36,14 +36,16 @@ class TestGetProvider: monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") monkeypatch.delenv("GROQ_API_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ - patch("tools.transcription_tools._HAS_OPENAI", True): + patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("tools.transcription_tools._has_local_command", return_value=False): from tools.transcription_tools import _get_provider assert _get_provider({"provider": "local"}) == "none" def test_local_nothing_available(self, monkeypatch): monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ - patch("tools.transcription_tools._HAS_OPENAI", False): + patch("tools.transcription_tools._HAS_OPENAI", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False): from tools.transcription_tools import _get_provider assert _get_provider({"provider": "local"}) == "none" diff --git a/tests/tools/test_transcription_dotenv_fallback.py b/tests/tools/test_transcription_dotenv_fallback.py new file mode 100644 index 0000000000..39f5ca108e --- /dev/null +++ b/tests/tools/test_transcription_dotenv_fallback.py @@ -0,0 +1,230 @@ +"""Regression tests for the transcription_tools variant of #17140. 
+ +Same class of bug as ``tools/tts_tool.py`` (fixed in PR #17163): the STT +provider call sites read API keys via ``os.getenv()``, which bypasses +``~/.hermes/.env`` entries. These tests confirm each STT provider now +consults ``get_env_value()`` and the provider auto-detect + explicit +selection gate (``_get_provider``) do the same. +""" + +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def isolate_env(monkeypatch): + """Strip every STT-related env var so the test really exercises the + dotenv code path. If any of these survive into the test, the assertion + that ``get_env_value`` was consulted becomes meaningless because + ``os.environ`` already satisfies the lookup. + """ + for key in ( + "GROQ_API_KEY", + "MISTRAL_API_KEY", + "XAI_API_KEY", + "XAI_STT_BASE_URL", + ): + monkeypatch.delenv(key, raising=False) + + +class TestProviderSelectionGate: + """``_get_provider`` picks the STT backend. If it only consulted + ``os.environ`` a user with keys in ``~/.hermes/.env`` would be told + "no STT available" even though the actual transcribe call would + succeed. The gate lives behind ``is_stt_enabled(stt_config)``, so + configure ``{"enabled": True, "provider": ...}`` for explicit tests. + """ + + def test_import_after_config_env_patch_uses_restored_dotenv_loader(self): + """Importing STT while hermes_cli.config.get_env_value is patched must + not freeze that temporary helper into this module forever. 
+ """ + import importlib + import hermes_cli.config as config_mod + from tools import transcription_tools as tt + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(config_mod, "get_env_value", lambda name, default=None: "") + tt = importlib.reload(tt) + + try: + with patch.object(tt, "_HAS_FASTER_WHISPER", False), \ + patch.object(tt, "_HAS_OPENAI", True), \ + patch.object(tt, "_has_local_command", return_value=False), \ + patch("hermes_cli.config.load_env", + return_value={"GROQ_API_KEY": "dotenv-secret"}): + assert tt._get_provider({"enabled": True, "provider": "groq"}) == "groq" + finally: + importlib.reload(tt) + + def test_explicit_groq_sees_dotenv(self): + from tools import transcription_tools as tt + + with patch.object(tt, "_HAS_FASTER_WHISPER", False), \ + patch.object(tt, "_HAS_OPENAI", True), \ + patch.object(tt, "_has_local_command", return_value=False), \ + patch("hermes_cli.config.load_env", + return_value={"GROQ_API_KEY": "dotenv-secret"}): + assert tt._get_provider({"enabled": True, "provider": "groq"}) == "groq" + + def test_explicit_mistral_sees_dotenv(self): + from tools import transcription_tools as tt + + with patch.object(tt, "_HAS_FASTER_WHISPER", False), \ + patch.object(tt, "_HAS_MISTRAL", True), \ + patch.object(tt, "_has_local_command", return_value=False), \ + patch("hermes_cli.config.load_env", + return_value={"MISTRAL_API_KEY": "dotenv-secret"}): + assert tt._get_provider({"enabled": True, "provider": "mistral"}) == "mistral" + + def test_explicit_xai_sees_dotenv(self): + from tools import transcription_tools as tt + + with patch.object(tt, "_HAS_FASTER_WHISPER", False), \ + patch.object(tt, "_has_local_command", return_value=False), \ + patch("hermes_cli.config.load_env", + return_value={"XAI_API_KEY": "dotenv-secret"}): + assert tt._get_provider({"enabled": True, "provider": "xai"}) == "xai" + + def test_auto_detect_sees_dotenv_groq(self): + """No local backend, no explicit provider — auto-detect should fall + through to Groq 
when its key lives in dotenv only. Before the fix + it would return 'none'.""" + from tools import transcription_tools as tt + + with patch.object(tt, "_HAS_FASTER_WHISPER", False), \ + patch.object(tt, "_HAS_OPENAI", True), \ + patch.object(tt, "_HAS_MISTRAL", False), \ + patch.object(tt, "_has_local_command", return_value=False), \ + patch.object(tt, "_has_openai_audio_backend", return_value=False), \ + patch("hermes_cli.config.load_env", + return_value={"GROQ_API_KEY": "dotenv-secret"}): + # No "provider" key → explicit=False → auto-detect branch + assert tt._get_provider({"enabled": True}) == "groq" + + +class TestTranscribeCallSitesReadDotenv: + """The actual transcribe functions must forward the dotenv-resolved + key into the provider SDK / HTTP call. We mock ``get_env_value`` and + capture what gets passed through.""" + + def test_transcribe_groq_forwards_dotenv_key(self): + from tools import transcription_tools as tt + + seen_keys: list = [] + + class FakeOpenAIClient: + def __init__(self, *, api_key=None, base_url=None, timeout=None, max_retries=None): + seen_keys.append(api_key) + self.audio = MagicMock() + self.audio.transcriptions.create.return_value = "hello" + def close(self): + pass + + fake_openai_module = MagicMock() + fake_openai_module.OpenAI = FakeOpenAIClient + fake_openai_module.APIError = Exception + fake_openai_module.APIConnectionError = Exception + fake_openai_module.APITimeoutError = Exception + + with patch.object(tt, "get_env_value", return_value="groq-dotenv-key"), \ + patch.object(tt, "_HAS_OPENAI", True), \ + patch.dict("sys.modules", {"openai": fake_openai_module}), \ + patch("builtins.open", MagicMock()): + result = tt._transcribe_groq("/tmp/fake.mp3", "whisper-large-v3-turbo") + + assert result["success"] is True + assert seen_keys == ["groq-dotenv-key"] + + def test_transcribe_mistral_forwards_dotenv_key(self): + from tools import transcription_tools as tt + + seen_keys: list = [] + + class FakeMistralClient: + def __init__(self, 
*, api_key=None): + seen_keys.append(api_key) + self.audio = MagicMock() + completion = MagicMock() + completion.text = "hi" + self.audio.transcriptions.complete.return_value = completion + def __enter__(self): return self + def __exit__(self, *a): return False + + fake_client_module = MagicMock() + fake_client_module.Mistral = FakeMistralClient + + with patch.object(tt, "get_env_value", return_value="mistral-dotenv-key"), \ + patch.dict("sys.modules", {"mistralai.client": fake_client_module}), \ + patch("builtins.open", MagicMock()): + result = tt._transcribe_mistral("/tmp/fake.mp3", "voxtral-mini-latest") + + assert result["success"] is True + assert seen_keys == ["mistral-dotenv-key"] + + def test_transcribe_xai_forwards_dotenv_key(self): + from tools import transcription_tools as tt + + captured: dict = {} + + def fake_post(url, **kwargs): + captured["url"] = url + captured["headers"] = kwargs.get("headers", {}) + response = MagicMock() + response.status_code = 200 + response.raise_for_status = MagicMock() + response.json.return_value = {"text": "hello"} + return response + + # get_env_value is consulted for both XAI_API_KEY and XAI_STT_BASE_URL. + # Return the key for the first call, None for base-url override + # (so it defaults to the module-level XAI_STT_BASE_URL). + def fake_get_env_value(name, default=None): + if name == "XAI_API_KEY": + return "xai-dotenv-key" + return None + + with patch.object(tt, "get_env_value", side_effect=fake_get_env_value), \ + patch("requests.post", side_effect=fake_post), \ + patch("builtins.open", MagicMock()): + result = tt._transcribe_xai("/tmp/fake.mp3", "grok-stt") + + assert result["success"] is True + assert captured["headers"]["Authorization"] == "Bearer xai-dotenv-key" + + +class TestEndToEndRegressionGuard: + """End-to-end probe: patch ``hermes_cli.config.load_env`` to simulate + ``~/.hermes/.env`` carrying the key while ``os.environ`` does not. 
+ Before the fix ``_transcribe_xai`` called ``os.getenv("XAI_API_KEY")`` + directly and returned ``XAI_API_KEY not set``.""" + + def test_xai_key_only_in_dotenv_before_fix(self, monkeypatch): + from tools import transcription_tools as tt + + monkeypatch.delenv("XAI_API_KEY", raising=False) + + captured: dict = {} + + def fake_post(url, **kwargs): + captured["headers"] = kwargs.get("headers", {}) + response = MagicMock() + response.status_code = 200 + response.raise_for_status = MagicMock() + response.json.return_value = {"text": "ok"} + return response + + with patch("hermes_cli.config.load_env", + return_value={"XAI_API_KEY": "dotenv-secret"}): + # Sanity: get_env_value resolves through load_env when + # os.environ is empty. + from hermes_cli.config import get_env_value as live_get + assert live_get("XAI_API_KEY") == "dotenv-secret" + + with patch("requests.post", side_effect=fake_post), \ + patch("builtins.open", MagicMock()): + result = tt._transcribe_xai("/tmp/fake.mp3", "grok-stt") + + assert result["success"] is True + assert captured["headers"]["Authorization"] == "Bearer dotenv-secret" diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py index 50cbe22a6b..5e4a9ad716 100644 --- a/tests/tools/test_transcription_tools.py +++ b/tests/tools/test_transcription_tools.py @@ -758,19 +758,12 @@ class TestValidateAudioFileEdgeCases: f = tmp_path / "test.ogg" f.write_bytes(b"data") from tools.transcription_tools import _validate_audio_file - real_stat = f.stat() - call_count = 0 - def stat_side_effect(*args, **kwargs): - nonlocal call_count - call_count += 1 - # First calls are from exists() and is_file(), let them pass - if call_count <= 2: - return real_stat - raise OSError("disk error") - - with patch("pathlib.Path.stat", side_effect=stat_side_effect): + with patch("pathlib.Path.exists", return_value=True), \ + patch("pathlib.Path.is_file", return_value=True), \ + patch("pathlib.Path.stat", side_effect=OSError("disk 
error")): result = _validate_audio_file(str(f)) + assert result is not None assert "Failed to access" in result["error"] diff --git a/tests/tools/test_tts_command_providers.py b/tests/tools/test_tts_command_providers.py new file mode 100644 index 0000000000..583abcb588 --- /dev/null +++ b/tests/tools/test_tts_command_providers.py @@ -0,0 +1,500 @@ +""" +Tests for custom command-type TTS providers. + +These tests cover the ``tts.providers.*`` registry: built-in +precedence, command resolution, placeholder rendering, shell-quote +context handling, timeout / failure cleanup, voice_compatible opt-in, +and max_text_length lookup. + +Nothing here talks to a real TTS engine. The shell command itself is +portable: we write bytes to ``{output_path}`` using ``python -c`` so +the tests run identically on Linux, macOS, and (with minor quoting +differences) Windows. +""" + +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import Optional +from unittest.mock import patch + +import pytest + +from tools.tts_tool import ( + BUILTIN_TTS_PROVIDERS, + COMMAND_TTS_OUTPUT_FORMATS, + DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH, + DEFAULT_COMMAND_TTS_OUTPUT_FORMAT, + DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS, + _generate_command_tts, + _get_command_tts_output_format, + _get_command_tts_timeout, + _get_named_provider_config, + _has_any_command_tts_provider, + _is_command_provider_config, + _is_command_tts_voice_compatible, + _iter_command_providers, + _render_command_tts_template, + _resolve_command_provider_config, + _resolve_max_text_length, + _shell_quote_context, + check_tts_requirements, + text_to_speech_tool, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _python_copy_command(output_placeholder: str = "{output_path}") -> str: + """Return a cross-platform shell command that copies {input_path} -> output.""" + interpreter = 
sys.executable + return ( + f'"{interpreter}" -c "import shutil, sys; ' + f'shutil.copyfile(sys.argv[1], sys.argv[2])" ' + f'{{input_path}} {output_placeholder}' + ) + + +# --------------------------------------------------------------------------- +# _resolve_command_provider_config / built-in precedence +# --------------------------------------------------------------------------- + +class TestResolveCommandProviderConfig: + def test_builtin_names_are_never_command_providers(self): + cfg = { + "providers": { + "openai": {"type": "command", "command": "echo hi"}, + "edge": {"type": "command", "command": "echo hi"}, + }, + } + for name in BUILTIN_TTS_PROVIDERS: + assert _resolve_command_provider_config(name, cfg) is None + + def test_missing_provider_returns_none(self): + cfg = {"providers": {}} + assert _resolve_command_provider_config("nope", cfg) is None + + def test_user_declared_command_provider_resolves(self): + cfg = { + "providers": { + "piper-cli": {"type": "command", "command": "piper-cli foo"}, + }, + } + resolved = _resolve_command_provider_config("piper-cli", cfg) + assert resolved is not None + assert resolved["command"] == "piper-cli foo" + + def test_type_command_is_implied_when_command_is_set(self): + cfg = {"providers": {"piper-cli": {"command": "piper-cli foo"}}} + resolved = _resolve_command_provider_config("piper-cli", cfg) + assert resolved is not None + + def test_other_type_values_reject(self): + cfg = {"providers": {"piper-cli": {"type": "python", "command": "piper-cli foo"}}} + assert _resolve_command_provider_config("piper-cli", cfg) is None + + def test_empty_command_rejects(self): + cfg = {"providers": {"piper-cli": {"type": "command", "command": " "}}} + assert _resolve_command_provider_config("piper-cli", cfg) is None + + def test_case_insensitive_lookup(self): + cfg = {"providers": {"piper-cli": {"type": "command", "command": "x"}}} + assert _resolve_command_provider_config("PIPER-CLI", cfg) is not None + + def 
test_native_piper_cannot_be_shadowed_by_command_entry(self): + """Regression guard for PR that added native Piper as a built-in. + A user's ``tts.providers.piper`` must not override the built-in.""" + cfg = { + "providers": { + "piper": {"type": "command", "command": "some-script"}, + }, + } + assert _resolve_command_provider_config("piper", cfg) is None + + +class TestGetNamedProviderConfig: + def test_providers_block_wins(self): + cfg = {"providers": {"voxcpm": {"command": "new"}}, + "voxcpm": {"command": "legacy"}} + assert _get_named_provider_config(cfg, "voxcpm") == {"command": "new"} + + def test_legacy_tts_name_block_still_resolves(self): + cfg = {"voxcpm": {"type": "command", "command": "legacy"}} + assert _get_named_provider_config(cfg, "voxcpm") == { + "type": "command", "command": "legacy" + } + + def test_builtin_names_do_not_leak_through_legacy_path(self): + """``tts.openai`` must never be mistaken for a command provider.""" + cfg = {"openai": {"command": "oops", "type": "command"}} + assert _get_named_provider_config(cfg, "openai") == {} + + +class TestIsCommandProviderConfig: + def test_empty_dict_is_false(self): + assert _is_command_provider_config({}) is False + + def test_non_dict_is_false(self): + assert _is_command_provider_config("foo") is False + assert _is_command_provider_config(None) is False + + def test_type_mismatch_is_false(self): + assert _is_command_provider_config({"type": "native", "command": "x"}) is False + + +# --------------------------------------------------------------------------- +# _iter_command_providers / _has_any_command_tts_provider +# --------------------------------------------------------------------------- + +class TestIterCommandProviders: + def test_iterates_only_user_command_providers(self): + cfg = { + "providers": { + "openai": {"type": "command", "command": "shouldnt show up"}, + "piper-cli": {"type": "command", "command": "piper-cli"}, + "voxcpm": {"type": "command", "command": "voxcpm"}, + "broken": 
{"type": "command", "command": ""}, + }, + } + names = sorted(name for name, _ in _iter_command_providers(cfg)) + assert names == ["piper-cli", "voxcpm"] + + def test_has_any_command_provider_detects_declared(self): + cfg = {"providers": {"piper-cli": {"type": "command", "command": "piper-cli"}}} + assert _has_any_command_tts_provider(cfg) is True + + def test_has_any_command_provider_when_none(self): + assert _has_any_command_tts_provider({"providers": {}}) is False + assert _has_any_command_tts_provider({}) is False + + +# --------------------------------------------------------------------------- +# config getters +# --------------------------------------------------------------------------- + +class TestConfigGetters: + def test_timeout_defaults(self): + assert _get_command_tts_timeout({}) == float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS) + + def test_timeout_coerces_string(self): + assert _get_command_tts_timeout({"timeout": "45"}) == 45.0 + + def test_timeout_rejects_non_positive(self): + assert _get_command_tts_timeout({"timeout": 0}) == float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS) + assert _get_command_tts_timeout({"timeout": -1}) == float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS) + + def test_timeout_rejects_garbage(self): + assert _get_command_tts_timeout({"timeout": "fast"}) == float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS) + + def test_timeout_seconds_alias(self): + assert _get_command_tts_timeout({"timeout_seconds": 90}) == 90.0 + + def test_output_format_defaults(self): + assert _get_command_tts_output_format({}) == DEFAULT_COMMAND_TTS_OUTPUT_FORMAT + + def test_output_format_path_override(self): + assert _get_command_tts_output_format({}, "/tmp/clip.wav") == "wav" + + def test_output_format_unknown_path_falls_back_to_config(self): + assert _get_command_tts_output_format({"format": "ogg"}, "/tmp/clip.xyz") == "ogg" + + def test_output_format_rejects_unknown(self): + assert _get_command_tts_output_format({"format": "m4a"}) == DEFAULT_COMMAND_TTS_OUTPUT_FORMAT + + def 
test_output_format_supported_set(self): + assert COMMAND_TTS_OUTPUT_FORMATS == frozenset({"mp3", "wav", "ogg", "flac"}) + + def test_voice_compatible_boolean(self): + assert _is_command_tts_voice_compatible({"voice_compatible": True}) is True + assert _is_command_tts_voice_compatible({"voice_compatible": False}) is False + + def test_voice_compatible_string(self): + assert _is_command_tts_voice_compatible({"voice_compatible": "yes"}) is True + assert _is_command_tts_voice_compatible({"voice_compatible": "0"}) is False + + def test_voice_compatible_default_off(self): + assert _is_command_tts_voice_compatible({}) is False + + +# --------------------------------------------------------------------------- +# _resolve_max_text_length for command providers +# --------------------------------------------------------------------------- + +class TestMaxTextLengthForCommandProviders: + def test_default_for_command_provider(self): + cfg = {"providers": {"piper-cli": {"type": "command", "command": "x"}}} + assert _resolve_max_text_length("piper-cli", cfg) == DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH + + def test_override_under_providers(self): + cfg = {"providers": {"piper-cli": {"type": "command", "command": "x", "max_text_length": 2500}}} + assert _resolve_max_text_length("piper-cli", cfg) == 2500 + + def test_override_under_legacy_tts_name_block(self): + cfg = {"piper-cli": {"type": "command", "command": "x", "max_text_length": 7777}} + assert _resolve_max_text_length("piper-cli", cfg) == 7777 + + def test_non_command_unknown_provider_still_falls_back(self): + assert _resolve_max_text_length("unknown", {}) > 0 + + +# --------------------------------------------------------------------------- +# _shell_quote_context / template rendering +# --------------------------------------------------------------------------- + +class TestShellQuoteContext: + def test_bare_context(self): + tpl = 'tts {output_path}' + pos = tpl.index("{output_path}") + assert _shell_quote_context(tpl, pos) is 
None + + def test_inside_single_quotes(self): + tpl = "tts '{output_path}'" + pos = tpl.index("{output_path}") + assert _shell_quote_context(tpl, pos) == "'" + + def test_inside_double_quotes(self): + tpl = 'tts "{output_path}"' + pos = tpl.index("{output_path}") + assert _shell_quote_context(tpl, pos) == '"' + + def test_escaped_double_quote_inside_double(self): + tpl = r'tts "foo \" {output_path}"' + pos = tpl.index("{output_path}") + assert _shell_quote_context(tpl, pos) == '"' + + +class TestRenderCommandTtsTemplate: + def test_substitutes_all_placeholders(self): + placeholders = { + "input_path": "/tmp/in.txt", + "text_path": "/tmp/in.txt", + "output_path": "/tmp/out.mp3", + "format": "mp3", + "voice": "af_sky", + "model": "tiny", + "speed": "1.0", + } + rendered = _render_command_tts_template( + "tts --voice {voice} --in {input_path} --out {output_path}", + placeholders, + ) + assert "af_sky" in rendered + assert "/tmp/out.mp3" in rendered + + def test_quotes_paths_with_spaces(self): + placeholders = { + "input_path": "/tmp/Jane Doe/in.txt", + "text_path": "/tmp/Jane Doe/in.txt", + "output_path": "/tmp/out.mp3", + "format": "mp3", + "voice": "", + "model": "", + "speed": "1.0", + } + rendered = _render_command_tts_template( + "tts --in {input_path} --out {output_path}", + placeholders, + ) + # shlex.quote wraps space-containing paths in single quotes on POSIX. 
+ if os.name != "nt": + assert "'/tmp/Jane Doe/in.txt'" in rendered + + def test_literal_braces_survive(self): + placeholders = { + "input_path": "/tmp/in.txt", "text_path": "/tmp/in.txt", + "output_path": "/tmp/out.mp3", "format": "mp3", + "voice": "", "model": "", "speed": "1.0", + } + rendered = _render_command_tts_template( + "echo '{{not a placeholder}}' && tts --in {input_path}", + placeholders, + ) + assert "{not a placeholder}" in rendered + + def test_injection_is_neutralized(self): + """Embedded shell metacharacters in a placeholder value must be quoted.""" + placeholders = { + "input_path": "/tmp/in.txt", "text_path": "/tmp/in.txt", + "output_path": "/tmp/out; rm -rf /", + "format": "mp3", + "voice": "$(whoami)", "model": "", "speed": "1.0", + } + rendered = _render_command_tts_template( + "tts --voice {voice} --out {output_path}", + placeholders, + ) + # The injection payload must not appear unquoted in the rendered + # command. On POSIX shlex.quote wraps the value in single quotes. + if os.name != "nt": + assert "'$(whoami)'" in rendered or "'\\''" in rendered + assert "; rm -rf /" not in rendered.replace( + "'/tmp/out; rm -rf /'", "", + ) + + def test_preserves_shell_quoting_style(self): + placeholders = { + "input_path": "/tmp/in.txt", "text_path": "/tmp/in.txt", + "output_path": "/tmp/out.mp3", "format": "mp3", + "voice": "bob's voice", "model": "", "speed": "1.0", + } + # When the template wraps the placeholder in double quotes we must + # escape for that context, not collapse to single-quoted form. 
+ rendered = _render_command_tts_template( + 'tts --voice "{voice}"', + placeholders, + ) + assert '"bob\'s voice"' in rendered + + +# --------------------------------------------------------------------------- +# End-to-end: _generate_command_tts +# --------------------------------------------------------------------------- + +class TestGenerateCommandTts: + def test_writes_output_file(self, tmp_path): + out = tmp_path / "clip.mp3" + config = {"command": _python_copy_command()} + result = _generate_command_tts( + "hello world", + str(out), + "py-copy", + config, + {}, + ) + assert result == str(out) + assert out.exists() + # The command copied the input text file over to output, so it + # contains the original UTF-8 text. + assert out.read_text(encoding="utf-8") == "hello world" + + def test_empty_command_raises(self, tmp_path): + with pytest.raises(ValueError, match="is not configured"): + _generate_command_tts( + "hello", + str(tmp_path / "x.mp3"), + "empty", + {"command": " "}, + {}, + ) + + def test_nonzero_exit_raises_runtime(self, tmp_path): + config = {"command": f'"{sys.executable}" -c "import sys; sys.exit(3)"'} + with pytest.raises(RuntimeError, match="exited with code 3"): + _generate_command_tts( + "hello", + str(tmp_path / "x.mp3"), + "failing", + config, + {}, + ) + + def test_empty_output_raises_runtime(self, tmp_path): + # This command completes successfully but writes nothing. 
+ config = {"command": f'"{sys.executable}" -c "pass"'} + with pytest.raises(RuntimeError, match="produced no output"): + _generate_command_tts( + "hello", + str(tmp_path / "x.mp3"), + "silent", + config, + {}, + ) + + @pytest.mark.skipif(os.name == "nt", reason="POSIX-only timeout semantics") + def test_timeout_raises_runtime(self, tmp_path): + config = { + "command": f'"{sys.executable}" -c "import time; time.sleep(10)"', + "timeout": 1, + } + with pytest.raises(RuntimeError, match="timed out"): + _generate_command_tts( + "hello", + str(tmp_path / "x.mp3"), + "slow", + config, + {}, + ) + + +# --------------------------------------------------------------------------- +# text_to_speech_tool integration +# --------------------------------------------------------------------------- + +class TestTextToSpeechToolWithCommandProvider: + def test_command_provider_dispatches_end_to_end(self, tmp_path): + cfg = { + "tts": { + "provider": "py-copy", + "providers": { + "py-copy": { + "type": "command", + "command": _python_copy_command(), + "output_format": "mp3", + }, + }, + }, + } + out = tmp_path / "clip.mp3" + + # Patch the config loader used by the tool so we don't touch disk. 
+ def fake_load(): + return cfg["tts"] + + with patch("tools.tts_tool._load_tts_config", fake_load): + result = text_to_speech_tool(text="hi", output_path=str(out)) + data = json.loads(result) + assert data["success"] is True, data + assert data["provider"] == "py-copy" + assert data["voice_compatible"] is False + assert Path(data["file_path"]).exists() + + def test_voice_compatible_opt_in_toggles_flag(self, tmp_path): + """voice_compatible=true is reflected in the response when the + file is already .ogg (no ffmpeg needed).""" + cfg = { + "provider": "py-copy-ogg", + "providers": { + "py-copy-ogg": { + "type": "command", + "command": _python_copy_command(), + "output_format": "ogg", + "voice_compatible": True, + }, + }, + } + out = tmp_path / "clip.ogg" + + with patch("tools.tts_tool._load_tts_config", return_value=cfg): + result = text_to_speech_tool(text="hi", output_path=str(out)) + data = json.loads(result) + assert data["success"] is True + assert data["voice_compatible"] is True + assert data["media_tag"].startswith("[[audio_as_voice]]") + + def test_missing_command_falls_through_to_builtin(self, tmp_path): + """A provider entry with an empty command is not a command + provider; the tool should not raise a "command not configured" + error but fall through to the built-in resolution path.""" + cfg = { + "provider": "broken", + "providers": { + "broken": {"type": "command", "command": " "}, + }, + } + with patch("tools.tts_tool._load_tts_config", return_value=cfg): + result = text_to_speech_tool(text="hi", output_path=str(tmp_path / "x.mp3")) + data = json.loads(result) + # The response should not carry the command-provider error text. 
+ err = (data.get("error") or "").lower() + assert "tts.providers.broken.command is not configured" not in err + + +class TestCheckTtsRequirements: + def test_configured_command_provider_satisfies_requirement(self): + cfg = {"providers": {"x": {"type": "command", "command": "echo x"}}} + with patch("tools.tts_tool._load_tts_config", return_value=cfg): + assert check_tts_requirements() is True diff --git a/tests/tools/test_tts_dotenv_fallback.py b/tests/tools/test_tts_dotenv_fallback.py new file mode 100644 index 0000000000..0508320870 --- /dev/null +++ b/tests/tools/test_tts_dotenv_fallback.py @@ -0,0 +1,272 @@ +"""Regression tests for #17140. + +TTS provider tools must resolve API keys from ``~/.hermes/.env`` (via +``hermes_cli.config.get_env_value``) and not only from ``os.environ`` — +otherwise users who keep their keys in the dotenv file see "API key not set" +errors even though the key is configured. Same class of bug as #15914 (auth) +already addressed for ``agent/credential_pool`` and ``hermes_cli/auth``. +""" + +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def isolate_env(monkeypatch): + """Strip every TTS-related env var so the test really exercises the + dotenv code path. If any of these survive into the test, the assertion + that ``get_env_value`` was consulted becomes meaningless because + ``os.environ`` already satisfies the lookup. + """ + for key in ( + "ELEVENLABS_API_KEY", + "XAI_API_KEY", + "XAI_BASE_URL", + "MINIMAX_API_KEY", + "MISTRAL_API_KEY", + "GEMINI_API_KEY", + "GEMINI_BASE_URL", + "GOOGLE_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + +class TestDotenvFallbackPerProvider: + """For each affected provider, when only ``~/.hermes/.env`` carries the + key, the provider must find it. 
These per-provider tests model that + dotenv-backed lookup by mocking ``tools.tts_tool.get_env_value`` directly; + the separate regression-guard tests cover the lower-level + ``hermes_cli.config.load_env`` integration. Before the fix, ``os.getenv`` + returned ``None`` and the provider raised + ``ValueError("X_API_KEY not set")``. + """ + + def test_elevenlabs_reads_dotenv_key(self, tmp_path): + from tools import tts_tool + + with patch.object(tts_tool, "get_env_value", return_value="el-dotenv-key"), \ + patch.object(tts_tool, "_import_elevenlabs") as mock_import: + mock_client = MagicMock() + mock_client.text_to_speech.convert.return_value = iter([b"audio"]) + mock_import.return_value = MagicMock(return_value=mock_client) + + output = str(tmp_path / "out.mp3") + tts_tool._generate_elevenlabs("hi", output, {}) + + mock_import.return_value.assert_called_once_with(api_key="el-dotenv-key") + + def test_xai_reads_dotenv_key(self, tmp_path): + from tools import tts_tool + + captured: dict = {} + + def fake_post(url, **kwargs): + captured["url"] = url + captured["headers"] = kwargs.get("headers", {}) + response = MagicMock() + response.content = b"audio" + response.raise_for_status = MagicMock() + return response + + with patch.object(tts_tool, "get_env_value", return_value="xai-dotenv-key"), \ + patch("requests.post", side_effect=fake_post): + tts_tool._generate_xai_tts("hi", str(tmp_path / "out.mp3"), {}) + + assert captured["headers"]["Authorization"] == "Bearer xai-dotenv-key" + + def test_minimax_reads_dotenv_key(self, tmp_path): + from tools import tts_tool + + captured: dict = {} + + def fake_post(url, **kwargs): + captured["headers"] = kwargs.get("headers", {}) + response = MagicMock() + response.json.return_value = { + "data": {"audio": b"\x00\x01".hex()}, + "base_resp": {"status_code": 0}, + } + response.raise_for_status = MagicMock() + return response + + with patch.object(tts_tool, "get_env_value", return_value="mm-dotenv-key"), \ + patch("requests.post", 
side_effect=fake_post): + tts_tool._generate_minimax_tts("hi", str(tmp_path / "out.mp3"), {}) + + assert captured["headers"]["Authorization"] == "Bearer mm-dotenv-key" + + def test_mistral_reads_dotenv_key(self, tmp_path): + import base64 + + from tools import tts_tool + + seen_keys: list = [] + + def fake_mistral_factory(*, api_key=None): + seen_keys.append(api_key) + client = MagicMock() + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + client.audio.speech.complete.return_value = MagicMock( + audio_data=base64.b64encode(b"data").decode() + ) + return client + + with patch.object(tts_tool, "get_env_value", return_value="mistral-dotenv-key"), \ + patch.object(tts_tool, "_import_mistral_client", return_value=fake_mistral_factory): + tts_tool._generate_mistral_tts("hi", str(tmp_path / "out.mp3"), {}) + + assert seen_keys == ["mistral-dotenv-key"] + + def test_gemini_reads_dotenv_key(self, tmp_path): + from tools import tts_tool + + captured: dict = {} + + def fake_post(url, **kwargs): + captured["params"] = kwargs.get("params", {}) + response = MagicMock() + response.status_code = 200 + response.json.return_value = { + "candidates": [ + { + "content": { + "parts": [ + { + "inlineData": { + "data": "AAAA", + "mimeType": "audio/L16;codec=pcm;rate=24000", + } + } + ] + } + } + ] + } + response.raise_for_status = MagicMock() + return response + + # GEMINI_API_KEY hits the first branch; GOOGLE_API_KEY would only be + # consulted if the first returned None. Use a side-effect-style mock + # to verify the lookup order matches the production code. 
+ seen_lookups: list = [] + + def fake_get_env_value(key): + seen_lookups.append(key) + if key == "GEMINI_API_KEY": + return "gemini-dotenv-key" + return None + + with patch.object(tts_tool, "get_env_value", side_effect=fake_get_env_value), \ + patch("requests.post", side_effect=fake_post): + tts_tool._generate_gemini_tts("hi", str(tmp_path / "out.wav"), {}) + + assert "GEMINI_API_KEY" in seen_lookups + assert captured["params"]["key"] == "gemini-dotenv-key" + + +class TestRegressionGuard: + """Goal-backward proof that the old behaviour ('only check ``os.environ``') + breaks reading from a dotenv-only key, and the new behaviour fixes it. + Implemented as an end-to-end probe that patches + ``hermes_cli.config.load_env`` to simulate ``~/.hermes/.env`` carrying the + key while ``os.environ`` does not. + """ + + def test_import_after_config_env_patch_uses_restored_dotenv_loader(self, tmp_path, monkeypatch): + """Importing TTS while hermes_cli.config.get_env_value is patched must + not freeze that temporary helper into this module forever. 
+ """ + import importlib + import hermes_cli.config as config_mod + from tools import tts_tool + + monkeypatch.delenv("MINIMAX_API_KEY", raising=False) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(config_mod, "get_env_value", lambda name: "") + tts_tool = importlib.reload(tts_tool) + + try: + captured: dict = {} + + def fake_post(url, **kwargs): + captured["headers"] = kwargs.get("headers", {}) + response = MagicMock() + response.json.return_value = { + "data": {"audio": b"\x00".hex()}, + "base_resp": {"status_code": 0}, + } + response.raise_for_status = MagicMock() + return response + + with patch( + "hermes_cli.config.load_env", + return_value={"MINIMAX_API_KEY": "dotenv-secret"}, + ), patch("requests.post", side_effect=fake_post): + tts_tool._generate_minimax_tts( + "hi", str(tmp_path / "out.mp3"), {} + ) + + assert captured["headers"]["Authorization"] == "Bearer dotenv-secret" + finally: + importlib.reload(tts_tool) + + def test_minimax_missing_when_only_in_dotenv_before_fix(self, tmp_path, monkeypatch): + from tools import tts_tool + + monkeypatch.delenv("MINIMAX_API_KEY", raising=False) + + # Simulate ~/.hermes/.env carrying the key (load_env returns the dict + # that get_env_value falls back to). The pre-fix ``os.getenv`` call + # ignores this entirely and raises ValueError. + with patch( + "hermes_cli.config.load_env", + return_value={"MINIMAX_API_KEY": "dotenv-secret"}, + ): + # Sanity-check: get_env_value resolves through load_env when + # os.environ is empty. + from hermes_cli.config import get_env_value as live_get + assert live_get("MINIMAX_API_KEY") == "dotenv-secret" + + # And the production code path now consumes the resolved value + # instead of raising "MINIMAX_API_KEY not set". 
+ captured: dict = {} + + def fake_post(url, **kwargs): + captured["headers"] = kwargs.get("headers", {}) + response = MagicMock() + response.json.return_value = { + "data": {"audio": b"\x00".hex()}, + "base_resp": {"status_code": 0}, + } + response.raise_for_status = MagicMock() + return response + + with patch("requests.post", side_effect=fake_post): + tts_tool._generate_minimax_tts( + "hi", str(tmp_path / "out.mp3"), {} + ) + + assert captured["headers"]["Authorization"] == "Bearer dotenv-secret" + + def test_check_tts_requirements_sees_dotenv_minimax(self, monkeypatch): + """``check_tts_requirements`` is the gate that decides whether + ``/voice on`` is even offered. If it only checked ``os.environ`` it + would say "no provider available" for users who keep MINIMAX_API_KEY + in ``~/.hermes/.env``, even though the dispatcher would later succeed. + """ + from tools import tts_tool + + monkeypatch.delenv("MINIMAX_API_KEY", raising=False) + + with patch( + "hermes_cli.config.load_env", + return_value={"MINIMAX_API_KEY": "dotenv-secret"}, + ), patch.object(tts_tool, "_import_edge_tts", side_effect=ImportError), \ + patch.object(tts_tool, "_import_elevenlabs", side_effect=ImportError), \ + patch.object(tts_tool, "_import_openai_client", side_effect=ImportError), \ + patch.object(tts_tool, "_check_neutts_available", return_value=False), \ + patch.object(tts_tool, "_check_kittentts_available", return_value=False): + assert tts_tool.check_tts_requirements() is True diff --git a/tests/tools/test_tts_mistral.py b/tests/tools/test_tts_mistral.py index 36088f3f0a..6e98946b6c 100644 --- a/tests/tools/test_tts_mistral.py +++ b/tests/tools/test_tts_mistral.py @@ -216,5 +216,8 @@ class TestCheckTtsRequirementsMistral: with patch("tools.tts_tool._import_edge_tts", side_effect=ImportError), \ patch("tools.tts_tool._import_elevenlabs", side_effect=ImportError), \ patch("tools.tts_tool._import_openai_client", side_effect=ImportError), \ - 
patch("tools.tts_tool._check_neutts_available", return_value=False): + patch("tools.tts_tool._check_neutts_available", return_value=False), \ + patch("tools.tts_tool._check_kittentts_available", return_value=False), \ + patch("tools.tts_tool._check_piper_available", return_value=False), \ + patch("tools.tts_tool._has_any_command_tts_provider", return_value=False): assert check_tts_requirements() is False diff --git a/tests/tools/test_tts_piper.py b/tests/tools/test_tts_piper.py new file mode 100644 index 0000000000..ef7330a18c --- /dev/null +++ b/tests/tools/test_tts_piper.py @@ -0,0 +1,306 @@ +""" +Tests for the native Piper TTS provider. + +These tests pin the resolution / caching / dispatch paths for Piper +without requiring the ``piper-tts`` package to actually be installed +(the synthesis step is monkey-patched to avoid needing the ONNX wheel). +""" + +import json +import os +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from tools import tts_tool +from tools.tts_tool import ( + BUILTIN_TTS_PROVIDERS, + DEFAULT_PIPER_VOICE, + PROVIDER_MAX_TEXT_LENGTH, + _check_piper_available, + _resolve_piper_voice_path, + check_tts_requirements, + text_to_speech_tool, +) + + +# --------------------------------------------------------------------------- +# Registry / constants +# --------------------------------------------------------------------------- + +class TestPiperRegistration: + def test_piper_is_a_builtin_provider(self): + assert "piper" in BUILTIN_TTS_PROVIDERS + + def test_piper_has_a_text_length_cap(self): + assert PROVIDER_MAX_TEXT_LENGTH.get("piper", 0) > 0 + + +# --------------------------------------------------------------------------- +# _check_piper_available +# --------------------------------------------------------------------------- + +class TestCheckPiperAvailable: + def test_returns_bool_without_raising(self): + # We don't care about the current environment's answer — just that + # the probe 
never raises on a machine without piper installed. + assert isinstance(_check_piper_available(), bool) + + +# --------------------------------------------------------------------------- +# _resolve_piper_voice_path +# --------------------------------------------------------------------------- + +class TestResolvePiperVoicePath: + def test_direct_onnx_path_returned_as_is(self, tmp_path): + model = tmp_path / "custom.onnx" + model.write_bytes(b"fake onnx bytes") + result = _resolve_piper_voice_path(str(model), tmp_path) + assert result == str(model) + + def test_cached_voice_name_not_redownloaded(self, tmp_path): + """If both .onnx and .onnx.json exist in the + download dir, no subprocess is spawned.""" + voice = "en_US-test-medium" + (tmp_path / f"{voice}.onnx").write_bytes(b"model") + (tmp_path / f"{voice}.onnx.json").write_text("{}") + + with patch("tools.tts_tool.subprocess.run") as mock_run: + result = _resolve_piper_voice_path(voice, tmp_path) + + mock_run.assert_not_called() + assert result == str(tmp_path / f"{voice}.onnx") + + def test_missing_voice_triggers_download(self, tmp_path): + voice = "en_US-new-medium" + + def fake_run(cmd, *a, **kw): + # Simulate a successful download: write the expected files. 
+ (tmp_path / f"{voice}.onnx").write_bytes(b"model") + (tmp_path / f"{voice}.onnx.json").write_text("{}") + return MagicMock(returncode=0, stderr="", stdout="") + + with patch("tools.tts_tool.subprocess.run", side_effect=fake_run) as mock_run: + result = _resolve_piper_voice_path(voice, tmp_path) + + mock_run.assert_called_once() + # Verify the command shape: python -m piper.download_voices --download-dir + call_args = mock_run.call_args.args[0] + assert "piper.download_voices" in " ".join(call_args) + assert voice in call_args + assert "--download-dir" in call_args + assert str(tmp_path) in call_args + assert result == str(tmp_path / f"{voice}.onnx") + + def test_download_failure_raises_runtime(self, tmp_path): + voice = "en_US-broken-medium" + fake_result = MagicMock(returncode=1, stderr="voice not found", stdout="") + with patch("tools.tts_tool.subprocess.run", return_value=fake_result): + with pytest.raises(RuntimeError, match="Piper voice download failed"): + _resolve_piper_voice_path(voice, tmp_path) + + def test_download_success_but_missing_file_raises(self, tmp_path): + voice = "en_US-weird-medium" + fake_result = MagicMock(returncode=0, stderr="", stdout="") + # Subprocess "succeeds" but doesn't actually write the files. 
+ with patch("tools.tts_tool.subprocess.run", return_value=fake_result): + with pytest.raises(RuntimeError, match="completed but .+ is missing"): + _resolve_piper_voice_path(voice, tmp_path) + + def test_empty_voice_falls_back_to_default_name(self, tmp_path): + (tmp_path / f"{DEFAULT_PIPER_VOICE}.onnx").write_bytes(b"model") + (tmp_path / f"{DEFAULT_PIPER_VOICE}.onnx.json").write_text("{}") + result = _resolve_piper_voice_path("", tmp_path) + assert result.endswith(f"{DEFAULT_PIPER_VOICE}.onnx") + + +# --------------------------------------------------------------------------- +# _generate_piper_tts — stubbed so we don't need piper-tts installed +# --------------------------------------------------------------------------- + +class _StubPiperVoice: + """Stand-in for piper.PiperVoice used by the synthesis tests.""" + + loaded: list[str] = [] + calls: list[tuple] = [] + + @classmethod + def load(cls, model_path, use_cuda=False): + cls.loaded.append(model_path) + instance = cls() + instance.model_path = model_path + instance.use_cuda = use_cuda + return instance + + def synthesize_wav(self, text, wav_file, syn_config=None): + # Minimal valid WAV: an empty frame set is fine for our size check. + # The wave module accepts any frames; we just need the file to exist + # with non-zero bytes after close. 
+ wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(22050) + wav_file.writeframes(b"\x00\x00" * 1024) + _StubPiperVoice.calls.append((text, getattr(self, "model_path", ""), syn_config)) + + +@pytest.fixture(autouse=True) +def _reset_piper_cache(): + """Clear the module-level voice cache between tests.""" + tts_tool._piper_voice_cache.clear() + _StubPiperVoice.loaded = [] + _StubPiperVoice.calls = [] + yield + tts_tool._piper_voice_cache.clear() + + +class TestGeneratePiperTts: + def _prepare_voice_files(self, tmp_path, voice=DEFAULT_PIPER_VOICE): + model = tmp_path / f"{voice}.onnx" + model.write_bytes(b"model") + (tmp_path / f"{voice}.onnx.json").write_text("{}") + return model + + def test_loads_voice_and_writes_wav(self, tmp_path, monkeypatch): + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + out_path = str(tmp_path / "out.wav") + config = {"piper": {"voice": str(model)}} + + result = tts_tool._generate_piper_tts("hello", out_path, config) + + assert result == out_path + assert Path(out_path).exists() + assert Path(out_path).stat().st_size > 0 + assert _StubPiperVoice.loaded == [str(model)] + assert _StubPiperVoice.calls[0][0] == "hello" + + def test_voice_cache_reused_across_calls(self, tmp_path, monkeypatch): + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + config = {"piper": {"voice": str(model)}} + tts_tool._generate_piper_tts("one", str(tmp_path / "a.wav"), config) + tts_tool._generate_piper_tts("two", str(tmp_path / "b.wav"), config) + + # load() should have been called exactly once for the same model+cuda key. + assert _StubPiperVoice.loaded == [str(model)] + # But both synthesize calls went through. 
+ assert [c[0] for c in _StubPiperVoice.calls] == ["one", "two"] + + def test_voice_name_triggers_download(self, tmp_path, monkeypatch): + """A config voice of ``en_US-lessac-medium`` should be resolved via + _resolve_piper_voice_path (which would normally download).""" + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + def fake_resolve(voice, download_dir): + model = download_dir / f"{voice}.onnx" + model.write_bytes(b"model") + return str(model) + + monkeypatch.setattr(tts_tool, "_resolve_piper_voice_path", fake_resolve) + + config = {"piper": {"voice": "en_US-lessac-medium", "voices_dir": str(tmp_path)}} + result = tts_tool._generate_piper_tts("hi", str(tmp_path / "out.wav"), config) + + assert Path(result).exists() + assert _StubPiperVoice.loaded[0].endswith("en_US-lessac-medium.onnx") + + def test_advanced_knobs_passed_as_synconfig(self, tmp_path, monkeypatch): + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + # Fake SynthesisConfig so we can assert the knobs flowed through. + fake_syn_cls = MagicMock() + + class FakePiperModule: + SynthesisConfig = fake_syn_cls + + # The SynthesisConfig import happens inline inside _generate_piper_tts + # via ``from piper import SynthesisConfig``. Inject a fake piper + # module so that import resolves. + monkeypatch.setitem(sys.modules, "piper", FakePiperModule) + + config = { + "piper": { + "voice": str(model), + "length_scale": 2.0, + "volume": 0.8, + }, + } + tts_tool._generate_piper_tts( + "slow voice", str(tmp_path / "out.wav"), config, + ) + + # SynthesisConfig was constructed with the advanced knobs. 
+ fake_syn_cls.assert_called_once() + kwargs = fake_syn_cls.call_args.kwargs + assert kwargs["length_scale"] == 2.0 + assert kwargs["volume"] == 0.8 + + +# --------------------------------------------------------------------------- +# text_to_speech_tool end-to-end (provider == "piper") +# --------------------------------------------------------------------------- + +class TestTextToSpeechToolWithPiper: + def test_dispatches_to_piper(self, tmp_path, monkeypatch): + model = tmp_path / f"{DEFAULT_PIPER_VOICE}.onnx" + model.write_bytes(b"model") + (tmp_path / f"{DEFAULT_PIPER_VOICE}.onnx.json").write_text("{}") + + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + cfg = {"provider": "piper", "piper": {"voice": str(model)}} + monkeypatch.setattr(tts_tool, "_load_tts_config", lambda: cfg) + + result = text_to_speech_tool(text="hi", output_path=str(tmp_path / "clip.wav")) + data = json.loads(result) + + assert data["success"] is True, data + assert data["provider"] == "piper" + assert Path(data["file_path"]).exists() + + def test_missing_package_surfaces_error(self, tmp_path, monkeypatch): + def raise_import(): + raise ImportError("No module named 'piper'") + + monkeypatch.setattr(tts_tool, "_import_piper", raise_import) + + cfg = {"provider": "piper"} + monkeypatch.setattr(tts_tool, "_load_tts_config", lambda: cfg) + + result = text_to_speech_tool(text="hi", output_path=str(tmp_path / "clip.wav")) + data = json.loads(result) + + assert data["success"] is False + assert "piper-tts" in data["error"] + + +# --------------------------------------------------------------------------- +# check_tts_requirements +# --------------------------------------------------------------------------- + +class TestCheckTtsRequirementsPiper: + def test_piper_install_satisfies_requirements(self, monkeypatch): + # Drop every other provider so we can isolate the piper signal. 
+ monkeypatch.setattr(tts_tool, "_import_edge_tts", lambda: (_ for _ in ()).throw(ImportError())) + monkeypatch.setattr(tts_tool, "_import_elevenlabs", lambda: (_ for _ in ()).throw(ImportError())) + monkeypatch.setattr(tts_tool, "_import_openai_client", lambda: (_ for _ in ()).throw(ImportError())) + monkeypatch.setattr(tts_tool, "_import_mistral_client", lambda: (_ for _ in ()).throw(ImportError())) + monkeypatch.setattr(tts_tool, "_check_neutts_available", lambda: False) + monkeypatch.setattr(tts_tool, "_check_kittentts_available", lambda: False) + monkeypatch.setattr(tts_tool, "_has_any_command_tts_provider", lambda: False) + monkeypatch.setattr(tts_tool, "_has_openai_audio_backend", lambda: False) + for env in ("MINIMAX_API_KEY", "XAI_API_KEY", "GEMINI_API_KEY", + "GOOGLE_API_KEY", "MISTRAL_API_KEY", "ELEVENLABS_API_KEY"): + monkeypatch.delenv(env, raising=False) + + # Now toggle the piper check on and off. + monkeypatch.setattr(tts_tool, "_check_piper_available", lambda: False) + assert check_tts_requirements() is False + + monkeypatch.setattr(tts_tool, "_check_piper_available", lambda: True) + assert check_tts_requirements() is True diff --git a/tests/tools/test_vercel_sandbox_environment.py b/tests/tools/test_vercel_sandbox_environment.py new file mode 100644 index 0000000000..944621fe89 --- /dev/null +++ b/tests/tools/test_vercel_sandbox_environment.py @@ -0,0 +1,623 @@ +"""Unit tests for the Vercel Sandbox terminal backend.""" + +from __future__ import annotations + +import importlib +import io +import re +import sys +import tarfile +import threading +import types +from dataclasses import dataclass +from enum import StrEnum +from pathlib import Path +from types import SimpleNamespace + +import pytest + + +class _FakeRunResult: + def __init__(self, output: str | bytes = "", exit_code: int = 0): + self._output = output + self.exit_code = exit_code + + def output(self) -> str | bytes: + return self._output + + +class _FakeSandboxStatus(StrEnum): + PENDING = 
"pending" + RUNNING = "running" + STOPPING = "stopping" + STOPPED = "stopped" + FAILED = "failed" + ABORTED = "aborted" + SNAPSHOTTING = "snapshotting" + + +@dataclass(frozen=True) +class _FakeSnapshot: + snapshot_id: str + + +class _FakeSandbox: + def __init__( + self, + *, + cwd: str = "/vercel/sandbox", + home: str = "/home/vercel", + status: _FakeSandboxStatus = _FakeSandboxStatus.RUNNING, + ): + self.sandbox = SimpleNamespace(cwd=cwd, id="sb-123") + self.status = status + self.home = home + self.closed = 0 + self.client = SimpleNamespace(close=self._close) + self.run_command_calls: list[tuple[str, list[str], dict]] = [] + self.run_command_side_effects: list[object] = [] + self.write_files_calls: list[list[dict[str, object]]] = [] + self.write_files_side_effects: list[object] = [] + self.download_file_calls: list[tuple[str, Path]] = [] + self.download_file_side_effects: list[object] = [] + self.download_file_content = b"" + self.stop_calls: list[tuple[tuple, dict]] = [] + self.snapshot_calls: list[tuple[tuple, dict]] = [] + self.snapshot_side_effects: list[object] = [] + self.snapshot_id = "snap_default" + self.refresh_calls = 0 + self.wait_for_status_calls: list[tuple[object, object, object]] = [] + self.wait_for_status_side_effects: list[object] = [] + + def _close(self) -> None: + self.closed += 1 + + def refresh(self) -> None: + self.refresh_calls += 1 + + def wait_for_status(self, status: _FakeSandboxStatus | str, *, timeout, poll_interval) -> None: + self.wait_for_status_calls.append((status, timeout, poll_interval)) + if self.wait_for_status_side_effects: + effect = self.wait_for_status_side_effects.pop(0) + if isinstance(effect, Exception): + raise effect + if callable(effect): + effect(status, timeout, poll_interval) + return + self.status = _FakeSandboxStatus(status) + + def run_command(self, cmd: str, args: list[str] | None = None, **kwargs): + args = list(args or []) + self.run_command_calls.append((cmd, args, kwargs)) + if 
self.run_command_side_effects: + effect = self.run_command_side_effects.pop(0) + if isinstance(effect, Exception): + raise effect + if callable(effect): + return effect(cmd, args, kwargs) + return effect + script = args[1] if len(args) > 1 else "" + if 'printf %s "$HOME"' in script: + return _FakeRunResult(self.home) + return _FakeRunResult("") + + def write_files(self, files: list[dict[str, object]]) -> None: + self.write_files_calls.append(files) + if self.write_files_side_effects: + effect = self.write_files_side_effects.pop(0) + if isinstance(effect, Exception): + raise effect + if callable(effect): + effect(files) + + def download_file(self, remote_path: str, local_path) -> str: + destination = Path(local_path) + self.download_file_calls.append((remote_path, destination)) + if self.download_file_side_effects: + effect = self.download_file_side_effects.pop(0) + if isinstance(effect, Exception): + raise effect + if callable(effect): + return effect(remote_path, destination) + destination.write_bytes(self.download_file_content) + return str(destination.resolve()) + + def stop(self, *args, **kwargs) -> None: + self.stop_calls.append((args, kwargs)) + + def snapshot(self, *args, **kwargs): + self.snapshot_calls.append((args, kwargs)) + if self.snapshot_side_effects: + effect = self.snapshot_side_effects.pop(0) + if isinstance(effect, Exception): + raise effect + if callable(effect): + return effect(*args, **kwargs) + if isinstance(effect, str): + return _FakeSnapshot(effect) + return effect + return _FakeSnapshot(self.snapshot_id) + + +@dataclass(frozen=True) +class _FakeResources: + vcpus: float | None = None + memory: int | None = None + + +@dataclass(frozen=True) +class _FakeWriteFile: + path: str + content: bytes + + +class _FakeSDK: + def __init__(self): + self.create_kwargs: list[dict[str, object]] = [] + self.create_side_effects: list[object] = [] + self.sandboxes: list[_FakeSandbox] = [] + + @property + def current(self) -> _FakeSandbox: + return 
self.sandboxes[-1] + + def create(self, **kwargs): + self.create_kwargs.append(kwargs) + if self.create_side_effects: + effect = self.create_side_effects.pop(0) + if isinstance(effect, Exception): + raise effect + if isinstance(effect, _FakeSandbox): + self.sandboxes.append(effect) + return effect + sandbox = _FakeSandbox() + self.sandboxes.append(sandbox) + return sandbox + + +def _cwd_result(body: str = "", *, cwd: str = "/vercel/sandbox", exit_code: int = 0): + def _result(_cmd: str, args: list[str], _kwargs: dict): + script = args[1] if len(args) > 1 else "" + match = re.search(r"__HERMES_CWD_[A-Za-z0-9]+__", script) + marker = match.group(0) if match else "__HERMES_CWD_MISSING__" + prefix = f"{body}\n\n" if body else "\n" + return _FakeRunResult(f"{prefix}{marker}{cwd}{marker}\n", exit_code) + + return _result + + +def _tar_bytes(entries: dict[str, bytes]) -> bytes: + buffer = io.BytesIO() + with tarfile.open(fileobj=buffer, mode="w") as tar: + for name, content in entries.items(): + info = tarfile.TarInfo(name) + info.size = len(content) + tar.addfile(info, io.BytesIO(content)) + return buffer.getvalue() + + +@pytest.fixture() +def vercel_sdk(monkeypatch): + fake_sdk = _FakeSDK() + sandbox_mod = types.ModuleType("vercel.sandbox") + sandbox_mod.Sandbox = types.SimpleNamespace(create=fake_sdk.create) + sandbox_mod.Resources = _FakeResources + sandbox_mod.WriteFile = _FakeWriteFile + sandbox_mod.SandboxStatus = _FakeSandboxStatus + + vercel_mod = types.ModuleType("vercel") + vercel_mod.sandbox = sandbox_mod + + monkeypatch.setitem(sys.modules, "vercel", vercel_mod) + monkeypatch.setitem(sys.modules, "vercel.sandbox", sandbox_mod) + return fake_sdk + + +@pytest.fixture() +def vercel_module(vercel_sdk, monkeypatch): + monkeypatch.setattr("tools.environments.base.is_interrupted", lambda: False) + monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: []) + monkeypatch.setattr("tools.credential_files.iter_skills_files", lambda **kwargs: []) 
+ monkeypatch.setattr("tools.credential_files.iter_cache_files", lambda **kwargs: []) + + module = importlib.import_module("tools.environments.vercel_sandbox") + return importlib.reload(module) + + +@pytest.fixture() +def make_env(vercel_module, request): + envs = [] + + def _cleanup_envs(): + for env in envs: + env._sync_manager = None + env.cleanup() + + request.addfinalizer(_cleanup_envs) + + def _factory(**kwargs): + kwargs.setdefault("runtime", "node22") + kwargs.setdefault("cwd", vercel_module.DEFAULT_VERCEL_CWD) + kwargs.setdefault("timeout", 30) + kwargs.setdefault("task_id", "task-123") + env = vercel_module.VercelSandboxEnvironment(**kwargs) + envs.append(env) + return env + + return _factory + + +class TestStartup: + def test_default_cwd_tracks_remote_workspace_root(self, make_env, vercel_sdk): + sandbox = _FakeSandbox(cwd="/workspace") + vercel_sdk.create_side_effects.append(sandbox) + + env = make_env() + + assert env.cwd == "/workspace" + + def test_tilde_cwd_resolves_against_remote_home(self, make_env, vercel_sdk): + sandbox = _FakeSandbox(home="/home/custom") + vercel_sdk.create_side_effects.append(sandbox) + + env = make_env(cwd="~") + + assert env.cwd == "/home/custom" + + def test_pending_sandbox_timeout_raises_descriptive_error( + self, make_env, vercel_sdk + ): + sandbox = _FakeSandbox(status=_FakeSandboxStatus.PENDING) + sandbox.wait_for_status_side_effects.append(TimeoutError("still pending")) + vercel_sdk.create_side_effects.append(sandbox) + + with pytest.raises(RuntimeError, match="Sandbox did not reach running state"): + make_env() + + +class TestFileSync: + def test_initial_sync_uploads_managed_files_under_remote_home( + self, make_env, vercel_sdk, monkeypatch, tmp_path + ): + src = tmp_path / "token.txt" + src.write_text("secret-token") + monkeypatch.setattr( + "tools.credential_files.get_credential_file_mounts", + lambda: [ + { + "host_path": str(src), + "container_path": "/root/.hermes/credentials/token.txt", + } + ], + ) + 
monkeypatch.setattr("tools.credential_files.iter_skills_files", lambda **kwargs: []) + monkeypatch.setattr("tools.credential_files.iter_cache_files", lambda **kwargs: []) + + make_env() + + uploaded = vercel_sdk.current.write_files_calls[0] + assert uploaded == [ + { + "path": "/home/vercel/.hermes/credentials/token.txt", + "content": b"secret-token", + } + ] + + def test_execute_resyncs_changed_managed_files( + self, make_env, vercel_sdk, monkeypatch, tmp_path + ): + src = tmp_path / "token.txt" + src.write_text("secret-token") + monkeypatch.setattr( + "tools.credential_files.get_credential_file_mounts", + lambda: [ + { + "host_path": str(src), + "container_path": "/root/.hermes/credentials/token.txt", + } + ], + ) + monkeypatch.setattr("tools.credential_files.iter_skills_files", lambda **kwargs: []) + monkeypatch.setattr("tools.credential_files.iter_cache_files", lambda **kwargs: []) + + env = make_env() + src.write_text("updated-secret-token") + monkeypatch.setenv("HERMES_FORCE_FILE_SYNC", "1") + vercel_sdk.current.run_command_side_effects.append(_cwd_result("hello")) + + result = env.execute("echo hello") + + assert result == {"output": "hello\n", "returncode": 0} + assert vercel_sdk.current.write_files_calls[-1] == [ + { + "path": "/home/vercel/.hermes/credentials/token.txt", + "content": b"updated-secret-token", + } + ] + + def test_cleanup_syncs_back_snapshots_closes_and_is_idempotent( + self, make_env, vercel_module, vercel_sdk, monkeypatch, tmp_path + ): + hermes_home = tmp_path / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + src = tmp_path / "token.txt" + src.write_text("host-token") + monkeypatch.setattr( + "tools.credential_files.get_credential_file_mounts", + lambda: [ + { + "host_path": str(src), + "container_path": "/root/.hermes/credentials/token.txt", + } + ], + ) + monkeypatch.setattr( + "tools.credential_files.iter_skills_files", + lambda **kwargs: [], + ) + monkeypatch.setattr( + "tools.credential_files.iter_cache_files", + 
lambda **kwargs: [], + ) + env = make_env() + sandbox = vercel_sdk.current + sandbox.snapshot_id = "snap_cleanup" + vercel_sdk.current.download_file_content = _tar_bytes( + { + "home/vercel/.hermes/credentials/token.txt": b"remote-token", + "home/vercel/.hermes/credentials/new.txt": b"new-remote", + "home/vercel/.hermes/unmapped/skip.txt": b"skip", + } + ) + + env.cleanup() + env.cleanup() + + assert src.read_text() == "remote-token" + assert (tmp_path / "new.txt").read_text() == "new-remote" + assert not (tmp_path / "skip.txt").exists() + assert len(sandbox.snapshot_calls) == 1 + assert len(sandbox.stop_calls) == 1 # always stop after snapshot to avoid resource leaks + assert sandbox.closed == 1 + assert vercel_module._load_snapshots() == {"task-123": "snap_cleanup"} + + def test_cleanup_sync_back_failure_from_download_does_not_block_snapshot( + self, make_env, vercel_sdk, monkeypatch, tmp_path + ): + src = tmp_path / "token.txt" + src.write_text("host-token") + monkeypatch.setattr( + "tools.credential_files.get_credential_file_mounts", + lambda: [ + { + "host_path": str(src), + "container_path": "/root/.hermes/credentials/token.txt", + } + ], + ) + monkeypatch.setattr( + "tools.credential_files.iter_skills_files", + lambda **kwargs: [], + ) + monkeypatch.setattr( + "tools.credential_files.iter_cache_files", + lambda **kwargs: [], + ) + env = make_env() + sandbox = vercel_sdk.current + sandbox.run_command_side_effects.extend( + [ + _FakeRunResult("tar failed", exit_code=2), + _FakeRunResult(""), + _FakeRunResult("tar failed", exit_code=2), + _FakeRunResult(""), + _FakeRunResult("tar failed", exit_code=2), + _FakeRunResult(""), + ] + ) + monkeypatch.setattr("tools.environments.file_sync.time.sleep", lambda _delay: None) + + env.cleanup() + + assert src.read_text() == "host-token" + assert len(sandbox.snapshot_calls) == 1 + assert sandbox.closed == 1 + assert len(sandbox.download_file_calls) == 0 + + +class TestExecute: + def 
test_execute_runs_command_from_workspace_root_and_updates_cwd( + self, make_env, vercel_sdk + ): + env = make_env() + vercel_sdk.current.run_command_side_effects.append( + _cwd_result("/tmp", cwd="/tmp") + ) + + result = env.execute("pwd", cwd="/tmp") + + assert result == {"output": "/tmp\n", "returncode": 0} + assert env.cwd == "/tmp" + cmd, args, kwargs = vercel_sdk.current.run_command_calls[-1] + assert cmd == "bash" + assert args[0] == "-c" + assert "cd /tmp" in args[1] + assert kwargs["cwd"] == "/vercel/sandbox" + + @pytest.mark.parametrize( + ("make_unhealthy", "label"), + [ + ( + lambda sandbox: setattr( + sandbox, "status", _FakeSandboxStatus.STOPPED + ), + "terminal state", + ), + ( + lambda sandbox: setattr( + sandbox, + "refresh", + lambda: (_ for _ in ()).throw(RuntimeError("refresh failed")), + ), + "refresh failure", + ), + ], + ids=["terminal-state", "refresh-failure"], + ) + def test_execute_recreates_unhealthy_sandbox_before_running_command( + self, make_env, vercel_sdk, make_unhealthy, label + ): + env = make_env() + original = vercel_sdk.current + make_unhealthy(original) + + replacement = _FakeSandbox() + replacement.run_command_side_effects.extend( + [ + _FakeRunResult(replacement.home), + _cwd_result("hello"), + ] + ) + vercel_sdk.create_side_effects.append(replacement) + + result = env.execute("echo hello") + + assert result == {"output": "hello\n", "returncode": 0}, label + assert original.closed == 1 + assert vercel_sdk.current is replacement + + def test_run_bash_handle_uses_captured_sandbox_for_exec_and_cancel( + self, make_env + ): + env = make_env() + original = env._sandbox + assert original is not None + replacement = _FakeSandbox() + started = threading.Event() + release = threading.Event() + + def blocking_command(_cmd: str, _args: list[str], _kwargs: dict): + started.set() + release.wait(timeout=5) + return _FakeRunResult("done") + + original.run_command_side_effects.append(blocking_command) + + handle = env._run_bash("echo done") 
+ assert started.wait(timeout=1) + + env._sandbox = replacement + handle.kill() + release.set() + + assert handle.wait(timeout=2) == 0 + assert len(original.stop_calls) == 1 + assert replacement.stop_calls == [] + cmd, args, kwargs = original.run_command_calls[-1] + assert cmd == "bash" + assert args == ["-c", "echo done"] + assert kwargs["cwd"] == "/vercel/sandbox" + + +class TestSnapshotPersistence: + def test_create_restores_from_saved_snapshot( + self, make_env, vercel_module, vercel_sdk, monkeypatch, tmp_path + ): + hermes_home = tmp_path / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + vercel_module._store_snapshot("task-123", "snap_saved") + restored = _FakeSandbox(cwd="/restored") + vercel_sdk.create_side_effects.append(restored) + + env = make_env() + + assert env.cwd == "/restored" + assert vercel_sdk.create_kwargs[0]["source"] == { + "type": "snapshot", + "snapshot_id": "snap_saved", + } + assert vercel_module._load_snapshots() == {"task-123": "snap_saved"} + + def test_restore_failure_prunes_snapshot_and_falls_back_to_fresh_sandbox( + self, make_env, vercel_module, vercel_sdk, monkeypatch, tmp_path + ): + hermes_home = tmp_path / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + vercel_module._store_snapshot("task-123", "snap_stale") + fresh = _FakeSandbox(cwd="/fresh") + vercel_sdk.create_side_effects.extend( + [RuntimeError("snapshot missing"), fresh] + ) + + env = make_env() + + assert env.cwd == "/fresh" + assert vercel_sdk.create_kwargs[0]["source"] == { + "type": "snapshot", + "snapshot_id": "snap_stale", + } + assert "source" not in vercel_sdk.create_kwargs[1] + assert vercel_module._load_snapshots() == {} + + def test_cleanup_stops_when_snapshot_fails_without_storing_metadata( + self, make_env, vercel_module, vercel_sdk, monkeypatch, tmp_path + ): + hermes_home = tmp_path / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + env = make_env() + sandbox = vercel_sdk.current + 
sandbox.snapshot_side_effects.append(RuntimeError("snapshot failed")) + + env.cleanup() + + assert len(sandbox.snapshot_calls) == 1 + assert len(sandbox.stop_calls) == 1 + assert sandbox.closed == 1 + assert vercel_module._load_snapshots() == {} + + def test_non_persistent_cleanup_stops_without_snapshot( + self, make_env, vercel_module, vercel_sdk, monkeypatch, tmp_path + ): + hermes_home = tmp_path / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + env = make_env(persistent_filesystem=False) + sandbox = vercel_sdk.current + + env.cleanup() + + assert sandbox.snapshot_calls == [] + assert len(sandbox.stop_calls) == 1 + assert sandbox.closed == 1 + assert vercel_module._load_snapshots() == {} + + def test_persistent_cleanup_without_task_id_stops_without_snapshot( + self, make_env, vercel_module, vercel_sdk, monkeypatch, tmp_path + ): + hermes_home = tmp_path / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + env = make_env(task_id="") + sandbox = vercel_sdk.current + + env.cleanup() + + assert sandbox.snapshot_calls == [] + assert len(sandbox.stop_calls) == 1 + assert sandbox.closed == 1 + assert vercel_module._load_snapshots() == {} + + +class TestCleanup: + def test_cleanup_continues_when_sync_back_raises(self, make_env, vercel_sdk): + env = make_env() + sandbox = vercel_sdk.current + + class FailingSyncManager: + def sync_back(self): + raise RuntimeError("download failed") + + env._sync_manager = FailingSyncManager() + + env.cleanup() + + assert len(sandbox.snapshot_calls) == 1 + assert sandbox.closed == 1 diff --git a/tests/tui_gateway/test_protocol.py b/tests/tui_gateway/test_protocol.py index 6c94ec0710..bd527608a7 100644 --- a/tests/tui_gateway/test_protocol.py +++ b/tests/tui_gateway/test_protocol.py @@ -298,7 +298,7 @@ def test_session_resume_returns_hydrated_messages(server, monkeypatch): def reopen_session(self, _sid): return None - def get_messages_as_conversation(self, _sid): + def get_messages_as_conversation(self, 
_sid, include_ancestors=False): return [ {"role": "user", "content": "hello"}, {"role": "assistant", "content": "yo"}, @@ -641,6 +641,29 @@ def test_dispatch_long_handler_does_not_block_fast_handler(server): released.set() +def test_dispatch_session_compress_does_not_block_fast_handler(server): + """Manual TUI compaction can take minutes, so it must not block the RPC loop.""" + released = threading.Event() + + def slow_compress(rid, params): + released.wait(timeout=5) + return server._ok(rid, {"done": True}) + + server._methods["session.compress"] = slow_compress + server._methods["fast.ping"] = lambda rid, params: server._ok(rid, {"pong": True}) + + t0 = time.monotonic() + assert server.dispatch({"id": "slow", "method": "session.compress", "params": {}}) is None + + fast_resp = server.dispatch({"id": "fast", "method": "fast.ping", "params": {}}) + fast_elapsed = time.monotonic() - t0 + + assert fast_resp["result"] == {"pong": True} + assert fast_elapsed < 0.5, f"fast handler blocked for {fast_elapsed:.2f}s behind session.compress" + + released.set() + + def test_dispatch_long_handler_exception_produces_error_response(capture): """An exception inside a pool-dispatched handler still yields a JSON-RPC error.""" server, buf = capture diff --git a/tools/approval.py b/tools/approval.py index 5521ab5bc8..78fb481783 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -17,6 +17,7 @@ import threading import time import unicodedata from typing import Optional +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -711,7 +712,7 @@ def _get_cron_approval_mode() -> str: try: from hermes_cli.config import load_config config = load_config() - mode = str(config.get("approvals", {}).get("cron_mode", "deny")).lower().strip() + mode = str(cfg_get(config, "approvals", "cron_mode", default="deny")).lower().strip() if mode in ("approve", "off", "allow", "yes"): return "approve" return "deny" @@ -781,7 +782,7 @@ def check_dangerous_command(command: str, 
env_type: str, Returns: {"approved": True/False, "message": str or None, ...} """ - if env_type in ("docker", "singularity", "modal", "daytona"): + if env_type in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"): return {"approved": True, "message": None} # Hardline floor: commands with no recovery path (rm -rf /, mkfs, dd @@ -906,7 +907,7 @@ def check_all_command_guards(command: str, env_type: str, other was shown to the user. """ # Skip containers for both checks - if env_type in ("docker", "singularity", "modal", "daytona"): + if env_type in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"): return {"approved": True, "message": None} # Hardline floor: unconditional block for catastrophic commands diff --git a/tools/browser_camofox.py b/tools/browser_camofox.py index e1233859ae..5f59dd913f 100644 --- a/tools/browser_camofox.py +++ b/tools/browser_camofox.py @@ -32,7 +32,7 @@ from typing import Any, Dict, Optional import requests -from hermes_cli.config import load_config +from hermes_cli.config import cfg_get, load_config from tools.browser_camofox_state import get_camofox_identity from tools.registry import tool_error @@ -544,7 +544,7 @@ def camofox_vision(question: str, annotate: bool = False, try: _cfg = load_config() - _vision_cfg = _cfg.get("auxiliary", {}).get("vision", {}) + _vision_cfg = cfg_get(_cfg, "auxiliary", "vision", default={}) _vision_timeout = float(_vision_cfg.get("timeout", 120)) _vision_temperature = float(_vision_cfg.get("temperature", 0.1)) except Exception: diff --git a/tools/browser_tool.py b/tools/browser_tool.py index 362a1575ca..5cd431de31 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -68,6 +68,7 @@ from pathlib import Path from agent.auxiliary_client import call_llm from hermes_constants import get_hermes_home from utils import is_truthy_value +from hermes_cli.config import cfg_get try: from tools.website_policy import check_website_access @@ -192,7 +193,7 @@ def _get_command_timeout() 
-> int: try: from hermes_cli.config import read_raw_config cfg = read_raw_config() - val = cfg.get("browser", {}).get("command_timeout") + val = cfg_get(cfg, "browser", "command_timeout") if val is not None: result = max(int(val), 5) # Floor at 5s to avoid instant kills except Exception as e: @@ -2245,7 +2246,7 @@ def _maybe_start_recording(task_id: str): from hermes_cli.config import read_raw_config hermes_home = get_hermes_home() cfg = read_raw_config() - record_enabled = cfg.get("browser", {}).get("record_sessions", False) + record_enabled = cfg_get(cfg, "browser", "record_sessions", default=False) if not record_enabled: return @@ -2448,7 +2449,7 @@ def browser_vision(question: str, annotate: bool = False, task_id: Optional[str] try: from hermes_cli.config import load_config _cfg = load_config() - _vision_cfg = _cfg.get("auxiliary", {}).get("vision", {}) + _vision_cfg = cfg_get(_cfg, "auxiliary", "vision", default={}) _vt = _vision_cfg.get("timeout") if _vt is not None: vision_timeout = float(_vt) diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index db706e6a4c..c91907c4d1 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -73,7 +73,24 @@ MAX_STDERR_BYTES = 10_000 # 10 KB def check_sandbox_requirements() -> bool: """Code execution sandbox requires a POSIX OS for Unix domain sockets.""" - return SANDBOX_AVAILABLE + if not SANDBOX_AVAILABLE: + return False + + try: + from tools.terminal_tool import ( + _check_vercel_sandbox_requirements, + _get_env_config, + ) + + config = _get_env_config() + except Exception: + logger.debug("Could not resolve terminal config for execute_code availability", exc_info=True) + return False + + if config.get("env_type") == "vercel_sandbox": + return _check_vercel_sandbox_requirements(config) + + return True # --------------------------------------------------------------------------- @@ -481,13 +498,15 @@ def _get_or_create_env(task_id: str): cwd = overrides.get("cwd") or 
config["cwd"] container_config = None - if env_type in ("docker", "singularity", "modal", "daytona"): + if env_type in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"): container_config = { "container_cpu": config.get("container_cpu", 1), "container_memory": config.get("container_memory", 5120), "container_disk": config.get("container_disk", 51200), "container_persistent": config.get("container_persistent", True), + "vercel_runtime": config.get("vercel_runtime", ""), "docker_volumes": config.get("docker_volumes", []), + "docker_run_as_host_user": config.get("docker_run_as_host_user", False), } ssh_config = None @@ -1309,10 +1328,20 @@ def _kill_process_group(proc, escalate: bool = False): def _load_config() -> dict: - """Load code_execution config from CLI_CONFIG if available.""" + """Load code_execution config without importing the interactive CLI. + + This helper is called while building the module-level execute_code schema + during tool discovery. Importing ``cli`` here pulls prompt_toolkit/Rich and + a large chunk of the classic REPL onto every agent startup path, including + ``hermes --tui`` where it is never used. Read the lightweight raw config + instead; the config layer already caches by (mtime, size), and an absent + key cleanly falls back to DEFAULT_EXECUTION_MODE. 
+ """ try: - from cli import CLI_CONFIG - return CLI_CONFIG.get("code_execution", {}) + from hermes_cli.config import read_raw_config + + cfg = read_raw_config().get("code_execution", {}) + return cfg if isinstance(cfg, dict) else {} except Exception: return {} diff --git a/tools/credential_files.py b/tools/credential_files.py index 7998321e63..2372950cfe 100644 --- a/tools/credential_files.py +++ b/tools/credential_files.py @@ -25,6 +25,7 @@ import os from contextvars import ContextVar from pathlib import Path from typing import Dict, List +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -138,7 +139,7 @@ def _load_config_files() -> List[Dict[str, str]]: from hermes_cli.config import read_raw_config hermes_home = _resolve_hermes_home() cfg = read_raw_config() - cred_files = cfg.get("terminal", {}).get("credential_files") + cred_files = cfg_get(cfg, "terminal", "credential_files") if isinstance(cred_files, list): from tools.path_security import validate_within_dir diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 994c313623..53e778a7db 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -150,6 +150,27 @@ def _normalize_optional_job_value(value: Optional[Any], *, strip_trailing_slash: return text or None +def _normalize_deliver_param(value: Any) -> Optional[str]: + """Normalize a user-supplied ``deliver`` value to the canonical string form. + + The cron schema documents ``deliver`` as a string (``"local"``, ``"origin"``, + ``"telegram"``, ``"telegram:chat_id[:thread_id]"``, or comma-separated combos). + Some callers — MCP clients passing arrays, scripts building the payload as a + list — supply ``["telegram"]``. ``create_job``/``update_job`` store it as-is, + and the scheduler's ``str(deliver).split(",")`` then serializes the list to + the literal ``"['telegram']"`` which is not a known platform. Flatten lists + / tuples at the API boundary so storage is always a string. 
Returns ``None`` + for ``None``/empty so callers can treat it as "not supplied". + """ + if value is None: + return None + if isinstance(value, (list, tuple)): + parts = [str(p).strip() for p in value if str(p).strip()] + return ",".join(parts) if parts else None + text = str(value).strip() + return text or None + + def _validate_cron_script_path(script: Optional[str]) -> Optional[str]: """Validate a cron job script path at the API boundary. @@ -283,7 +304,7 @@ def cronjob( schedule=schedule, name=name, repeat=repeat, - deliver=deliver, + deliver=_normalize_deliver_param(deliver), origin=_origin_from_env(), skills=canonical_skills, model=_normalize_optional_job_value(model), @@ -364,7 +385,7 @@ def cronjob( if name is not None: updates["name"] = name if deliver is not None: - updates["deliver"] = deliver + updates["deliver"] = _normalize_deliver_param(deliver) if skills is not None or skill is not None: canonical_skills = _canonical_skills(skill, skills) updates["skills"] = canonical_skills diff --git a/tools/env_passthrough.py b/tools/env_passthrough.py index 07bf333a60..f23f39b954 100644 --- a/tools/env_passthrough.py +++ b/tools/env_passthrough.py @@ -22,6 +22,7 @@ from __future__ import annotations import logging from contextvars import ContextVar from typing import Iterable +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -109,7 +110,7 @@ def _load_config_passthrough() -> frozenset[str]: try: from hermes_cli.config import read_raw_config cfg = read_raw_config() - passthrough = cfg.get("terminal", {}).get("env_passthrough") + passthrough = cfg_get(cfg, "terminal", "env_passthrough") if isinstance(passthrough, list): for item in passthrough: if isinstance(item, str) and item.strip(): diff --git a/tools/environments/docker.py b/tools/environments/docker.py index 65c33b349c..06d8154872 100644 --- a/tools/environments/docker.py +++ b/tools/environments/docker.py @@ -151,16 +151,16 @@ def find_docker() -> Optional[str]: # SETUID/SETGID 
- the image entrypoint drops from root to the 'hermes' # user via `gosu`, which requires these caps. Combined with # `no-new-privileges`, gosu still cannot escalate back to root after -# the drop, so the security posture is preserved. +# the drop, so the security posture is preserved. Omitted entirely +# when the container starts as a non-root user via --user, since +# no gosu drop is needed in that mode. # Block privilege escalation and limit PIDs. # /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds). -_SECURITY_ARGS = [ +_BASE_SECURITY_ARGS = [ "--cap-drop", "ALL", "--cap-add", "DAC_OVERRIDE", "--cap-add", "CHOWN", "--cap-add", "FOWNER", - "--cap-add", "SETUID", - "--cap-add", "SETGID", "--security-opt", "no-new-privileges", "--pids-limit", "256", "--tmpfs", "/tmp:rw,nosuid,size=512m", @@ -168,6 +168,39 @@ _SECURITY_ARGS = [ "--tmpfs", "/run:rw,noexec,nosuid,size=64m", ] +# Extra caps needed when the container starts as root and an entrypoint +# must drop privileges via gosu/su. Skipped when --user is passed because +# the container already starts unprivileged and never needs to switch. +_GOSU_CAP_ARGS = [ + "--cap-add", "SETUID", + "--cap-add", "SETGID", +] + + +def _build_security_args(run_as_host_user: bool) -> list[str]: + """Return the security/cap/tmpfs args tailored to the privilege mode.""" + if run_as_host_user: + return list(_BASE_SECURITY_ARGS) + return list(_BASE_SECURITY_ARGS) + list(_GOSU_CAP_ARGS) + + +def _resolve_host_user_spec() -> Optional[str]: + """Return ``:`` for the current host user, or ``None`` on platforms + where this is not meaningful (e.g. Windows without posix ids). + + We intentionally read ``os.getuid()``/``os.getgid()`` directly rather than + going through ``getpass``/``pwd`` so this stays cheap and never raises on + nameless UIDs (nss lookups can fail inside sandboxed launchers). 
+ """ + get_uid = getattr(os, "getuid", None) + get_gid = getattr(os, "getgid", None) + if get_uid is None or get_gid is None: + return None + try: + return f"{get_uid()}:{get_gid()}" + except Exception: # pragma: no cover - defensive + return None + _storage_opt_ok: Optional[bool] = None # cached result across instances @@ -266,6 +299,7 @@ class DockerEnvironment(BaseEnvironment): network: bool = True, host_cwd: str = None, auto_mount_cwd: bool = False, + run_as_host_user: bool = False, ): if cwd == "~": cwd = "/root" @@ -421,8 +455,35 @@ class DockerEnvironment(BaseEnvironment): for key in sorted(self._env): env_args.extend(["-e", f"{key}={self._env[key]}"]) + # Optional: run the container as the host user so files written into + # bind-mounted dirs (/workspace, /root, docker_volumes entries) are + # owned by that user on the host instead of by root. Skip cleanly on + # platforms without POSIX uid/gid (e.g. native Windows Docker). + user_args: list[str] = [] + if run_as_host_user: + user_spec = _resolve_host_user_spec() + if user_spec is not None: + user_args = ["--user", user_spec] + logger.info("Docker: running container as host user %s", user_spec) + else: + logger.warning( + "docker_run_as_host_user is enabled but this platform does " + "not expose POSIX uid/gid; container will start as its " + "image default user." + ) + # Fall back to the full cap set — without --user, an image's + # entrypoint may still need gosu/su to drop privileges. 
+ security_args = _build_security_args(run_as_host_user and bool(user_args)) + logger.info(f"Docker volume_args: {volume_args}") - all_run_args = list(_SECURITY_ARGS) + writable_args + resource_args + volume_args + env_args + all_run_args = ( + security_args + + user_args + + writable_args + + resource_args + + volume_args + + env_args + ) logger.info(f"Docker run_args: {all_run_args}") # Resolve the docker executable once so it works even when diff --git a/tools/environments/local.py b/tools/environments/local.py index 1029545f08..d419c72c30 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -6,6 +6,7 @@ import shutil import signal import subprocess import tempfile +import time from tools.environments.base import BaseEnvironment, _pipe_stdin @@ -100,6 +101,10 @@ def _build_provider_env_blocklist() -> frozenset: "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET", "DAYTONA_API_KEY", + "VERCEL_OIDC_TOKEN", + "VERCEL_TOKEN", + "VERCEL_PROJECT_ID", + "VERCEL_TEAM_ID", }) return frozenset(blocked) @@ -365,6 +370,11 @@ class LocalEnvironment(BaseEnvironment): preexec_fn=None if _IS_WINDOWS else os.setsid, cwd=self.cwd, ) + if not _IS_WINDOWS: + try: + proc._hermes_pgid = os.getpgid(proc.pid) + except ProcessLookupError: + pass if stdin_data is not None: _pipe_stdin(proc, stdin_data) @@ -377,12 +387,42 @@ class LocalEnvironment(BaseEnvironment): if _IS_WINDOWS: proc.terminate() else: - pgid = os.getpgid(proc.pid) + try: + pgid = os.getpgid(proc.pid) + except ProcessLookupError: + pgid = getattr(proc, "_hermes_pgid", None) + if pgid is None: + raise os.killpg(pgid, signal.SIGTERM) + deadline = time.monotonic() + 1.0 + while time.monotonic() < deadline: + if proc.poll() is not None: + try: + os.killpg(pgid, 0) + except ProcessLookupError: + return + time.sleep(0.05) + + # The shell can exit quickly while a child in the same process + # group is still shutting down. 
Escalate based on the process + # group, not just the shell wrapper, so interrupted commands do + # not leave orphaned grandchildren under load. + try: + # _IS_WINDOWS is guarded by the enclosing else branch. + os.killpg(pgid, signal.SIGKILL) + except ProcessLookupError: + return try: proc.wait(timeout=1.0) except subprocess.TimeoutExpired: - os.killpg(pgid, signal.SIGKILL) + pass + deadline = time.monotonic() + 1.0 + while time.monotonic() < deadline: + try: + os.killpg(pgid, 0) + except ProcessLookupError: + return + time.sleep(0.05) except (ProcessLookupError, PermissionError): try: proc.kill() @@ -392,7 +432,8 @@ class LocalEnvironment(BaseEnvironment): def _update_cwd(self, result: dict): """Read CWD from temp file (local-only, no round-trip needed).""" try: - cwd_path = open(self._cwd_file).read().strip() + with open(self._cwd_file) as f: + cwd_path = f.read().strip() if cwd_path: self.cwd = cwd_path except (OSError, FileNotFoundError): diff --git a/tools/environments/vercel_sandbox.py b/tools/environments/vercel_sandbox.py new file mode 100644 index 0000000000..2b434af159 --- /dev/null +++ b/tools/environments/vercel_sandbox.py @@ -0,0 +1,638 @@ +"""Vercel Sandbox execution environment. + +Uses the Vercel Python SDK to run commands in cloud sandboxes through Hermes' +shared ``BaseEnvironment`` shell contract. When persistence is enabled, the +backend stores task-scoped snapshot metadata under ``HERMES_HOME`` and restores +new sandboxes from those snapshots on later task reuse. 
+""" + +from __future__ import annotations + +from functools import cache +from dataclasses import dataclass +from datetime import timedelta +import logging +import math +import os +import shlex +import threading +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import httpx + +from hermes_constants import get_hermes_home +from tools.environments.base import ( + BaseEnvironment, + _ThreadedProcessHandle, + _load_json_store, + _save_json_store, +) +from tools.environments.file_sync import ( + FileSyncManager, + iter_sync_files, + quoted_rm_command, +) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from vercel.sandbox import Resources, Sandbox, SandboxStatus, WriteFile + +DEFAULT_VERCEL_CWD = "/vercel/sandbox" +_DEFAULT_CONTAINER_DISK_MB = 51200 +_CREATE_RETRY_ATTEMPTS = 3 +_WRITE_RETRY_ATTEMPTS = 3 +_TRANSIENT_STATUS_CODES = frozenset({408, 425, 429, 500, 502, 503, 504}) +_RETRY_BACKOFF_STEP = timedelta(milliseconds=100) +_MIN_SANDBOX_TIMEOUT = timedelta(minutes=5) +_MIN_RUNNING_WAIT = timedelta(seconds=1) +_RUNNING_WAIT_TIMEOUT = timedelta(seconds=30) +_RUNNING_WAIT_POLL_INTERVAL = timedelta(milliseconds=250) +_STOP_TIMEOUT = timedelta(seconds=15) +_STOP_POLL_INTERVAL = timedelta(milliseconds=500) +_SNAPSHOT_STORE_NAME = "vercel_sandbox_snapshots.json" + + +def _exception_chain(exc: BaseException) -> list[BaseException]: + chain: list[BaseException] = [] + current: BaseException | None = exc + seen: set[int] = set() + while current is not None and id(current) not in seen: + chain.append(current) + seen.add(id(current)) + current = current.__cause__ or current.__context__ + return chain + + +def _extract_status_code(exc: BaseException) -> int | None: + response = getattr(exc, "response", None) + for value in (getattr(exc, "status_code", None), getattr(response, "status_code", None)): + if isinstance(value, int): + return value + return None + + +def _is_transient_vercel_error(exc: BaseException) -> bool: + for error in 
_exception_chain(exc): + status_code = _extract_status_code(error) + if status_code in _TRANSIENT_STATUS_CODES: + return True + if isinstance( + error, + (httpx.NetworkError, httpx.ProtocolError, httpx.ReadError), + ): + return True + error_name = type(error).__name__.lower() + if "ratelimit" in error_name or "servererror" in error_name: + return True + return False + + +def _retry_vercel_call( + label: str, + callback, + *, + attempts: int, +): + backoff_seconds = _RETRY_BACKOFF_STEP.total_seconds() + for attempt in range(1, attempts + 1): + try: + return callback() + except Exception as exc: + if attempt >= attempts or not _is_transient_vercel_error(exc): + raise + logger.warning( + "Vercel: %s failed (%s); retrying %d/%d", + label, + exc, + attempt, + attempts, + ) + time.sleep(backoff_seconds * attempt) + + +def _coerce_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return str(value) + + +def _extract_result_output(result: Any) -> str: + try: + return _coerce_text(result.output()) + except (AttributeError, TypeError): + return _coerce_text(result) + + +def _extract_result_returncode(result: Any) -> int: + try: + exit_code = result.exit_code + except AttributeError: + try: + exit_code = result.returncode + except AttributeError: + return 1 + return exit_code if isinstance(exit_code, int) else 1 + + +def _snapshot_store_path() -> Path: + return get_hermes_home() / _SNAPSHOT_STORE_NAME + + +def _load_snapshots() -> dict: + return _load_json_store(_snapshot_store_path()) + + +def _save_snapshots(data: dict) -> None: + _save_json_store(_snapshot_store_path(), data) + + +def _get_snapshot_id(task_id: str) -> str | None: + if not task_id: + return None + snapshot_id = _load_snapshots().get(task_id) + return snapshot_id if isinstance(snapshot_id, str) and snapshot_id else None + + +def _store_snapshot(task_id: str, snapshot_id: str) -> None: + if not task_id or not snapshot_id: 
+ return + snapshots = _load_snapshots() + snapshots[task_id] = snapshot_id + _save_snapshots(snapshots) + + +def _delete_snapshot(task_id: str, snapshot_id: str | None = None) -> None: + if not task_id: + return + snapshots = _load_snapshots() + existing = snapshots.get(task_id) + if existing is None: + return + if snapshot_id is not None and existing != snapshot_id: + return + snapshots.pop(task_id, None) + _save_snapshots(snapshots) + + +def _extract_snapshot_id(snapshot: Any) -> str | None: + for attr in ("snapshot_id", "snapshotId", "id"): + value = getattr(snapshot, attr, None) + if isinstance(value, str) and value: + return value + if isinstance(snapshot, dict): + for key in ("snapshot_id", "snapshotId", "id"): + value = snapshot.get(key) + if isinstance(value, str) and value: + return value + return None + + +@cache +def _sandbox_status_type() -> type[SandboxStatus]: + from vercel.sandbox import SandboxStatus + + return SandboxStatus + + +@cache +def _terminal_sandbox_states() -> frozenset[SandboxStatus]: + SandboxStatus = _sandbox_status_type() + return frozenset( + { + SandboxStatus.ABORTED, + SandboxStatus.FAILED, + SandboxStatus.STOPPED, + } + ) + + +@dataclass(frozen=True, slots=True) +class _SandboxCreateParams: + timeout: timedelta + runtime: str | None = None + resources: Resources | None = None + + +class VercelSandboxEnvironment(BaseEnvironment): + """Vercel cloud sandbox backend.""" + + _stdin_mode = "heredoc" + + def __init__( + self, + runtime: str | None = None, + cwd: str = DEFAULT_VERCEL_CWD, + timeout: int = 60, + cpu: float = 1, + memory: int = 5120, + disk: int = _DEFAULT_CONTAINER_DISK_MB, + persistent_filesystem: bool = True, + task_id: str = "default", + ): + requested_cwd = cwd + super().__init__(cwd=cwd, timeout=timeout) + + self._runtime = runtime or None + self._persistent = persistent_filesystem + self._task_id = task_id + self._requested_cwd = requested_cwd + self._lock = threading.Lock() + self._sandbox: Sandbox | None = None + 
self._workspace_root = DEFAULT_VERCEL_CWD + self._remote_home = DEFAULT_VERCEL_CWD + self._sync_manager: FileSyncManager | None = None + self._create_params = self._build_create_params(cpu=cpu, memory=memory, disk=disk) + + self._sandbox = self._create_sandbox() + self._configure_attached_sandbox(requested_cwd=requested_cwd) + self._sync_manager.sync(force=True) + self.init_session() + + def _build_create_params(self, *, cpu: float, memory: int, disk: int) -> _SandboxCreateParams: + if disk not in (0, _DEFAULT_CONTAINER_DISK_MB): + raise ValueError( + "Vercel Sandbox does not support configurable container_disk. " + "Use the default shared setting." + ) + + from vercel.sandbox import Resources + + sandbox_timeout = max( + timedelta(seconds=max(self.timeout, 0)), + _MIN_SANDBOX_TIMEOUT, + ) + vcpus = math.floor(cpu) if cpu > 0 else None + memory_mb = memory if memory > 0 else None + resources = ( + Resources(vcpus=vcpus, memory=memory_mb) + if vcpus is not None or memory_mb is not None + else None + ) + + return _SandboxCreateParams( + timeout=sandbox_timeout, + runtime=self._runtime, + resources=resources, + ) + + def _create_sandbox(self) -> Sandbox: + from vercel.sandbox import Sandbox + + snapshot_id = _get_snapshot_id(self._task_id) if self._persistent else None + if snapshot_id: + try: + return _retry_vercel_call( + "sandbox restore", + lambda: Sandbox.create( + timeout=self._create_params.timeout, + runtime=self._create_params.runtime, + resources=self._create_params.resources, + source={"type": "snapshot", "snapshot_id": snapshot_id}, + ), + attempts=_CREATE_RETRY_ATTEMPTS, + ) + except Exception as exc: + logger.warning( + "Vercel: failed to restore snapshot %s for task %s; " + "falling back to a fresh sandbox: %s", + snapshot_id, + self._task_id, + exc, + ) + _delete_snapshot(self._task_id, snapshot_id) + + params = self._create_params + return _retry_vercel_call( + "sandbox create", + lambda: Sandbox.create( + timeout=params.timeout, + 
runtime=params.runtime, + resources=params.resources, + ), + attempts=_CREATE_RETRY_ATTEMPTS, + ) + + def _configure_attached_sandbox(self, *, requested_cwd: str) -> None: + self._wait_for_running() + self._workspace_root = self._detect_workspace_root() + self._remote_home = self._detect_remote_home() + + if self._remote_home == "/": + container_base = "/.hermes" + else: + container_base = f"{self._remote_home.rstrip('/')}/.hermes" + self._sync_manager = FileSyncManager( + get_files_fn=lambda: iter_sync_files(container_base), + upload_fn=self._vercel_upload, + delete_fn=self._vercel_delete, + bulk_upload_fn=self._vercel_bulk_upload, + bulk_download_fn=self._vercel_bulk_download, + ) + + if requested_cwd == "~": + self.cwd = self._remote_home + elif requested_cwd in ("", DEFAULT_VERCEL_CWD): + self.cwd = self._workspace_root + else: + self.cwd = requested_cwd + + def _detect_workspace_root(self) -> str: + sandbox = self._sandbox + if sandbox is None: + raise RuntimeError("Vercel sandbox is not attached") + cwd = sandbox.sandbox.cwd + return cwd if cwd.startswith("/") else DEFAULT_VERCEL_CWD + + def _detect_remote_home(self) -> str: + sandbox = self._sandbox + if sandbox is None: + raise RuntimeError("Vercel sandbox is not attached") + try: + result = sandbox.run_command( + "sh", + ["-lc", 'printf %s "$HOME"'], + cwd=self._workspace_root, + ) + except Exception as exc: + logger.debug( + "Vercel: home detection failed for task %s: %s", + self._task_id, + exc, + ) + return self._workspace_root + + home = _extract_result_output(result).strip() + if home.startswith("/"): + return home + return self._workspace_root + + def _wait_for_running(self, timeout: timedelta = _RUNNING_WAIT_TIMEOUT) -> None: + sandbox = self._sandbox + if sandbox is None: + raise RuntimeError("Vercel sandbox is not attached") + SandboxStatus = _sandbox_status_type() + status = sandbox.status + if status is None or status == SandboxStatus.RUNNING: + return + if status in _terminal_sandbox_states(): 
+ raise RuntimeError(f"Sandbox entered terminal state: {status}") + + try: + sandbox.wait_for_status( + SandboxStatus.RUNNING, + timeout=max(timeout, _MIN_RUNNING_WAIT), + poll_interval=_RUNNING_WAIT_POLL_INTERVAL, + ) + except TimeoutError as exc: + status = sandbox.status + if status in _terminal_sandbox_states(): + raise RuntimeError(f"Sandbox entered terminal state: {status}") from exc + raise RuntimeError( + f"Sandbox did not reach running state (last status: {status})" + ) from exc + + def _close_sandbox_client(self, sandbox: Sandbox | None) -> None: + if sandbox is None: + return + try: + sandbox.client.close() + except Exception: + pass + + def _stop_sandbox(self, sandbox: Sandbox | None) -> None: + if sandbox is None: + return + try: + sandbox.stop( + blocking=True, + timeout=_STOP_TIMEOUT, + poll_interval=_STOP_POLL_INTERVAL, + ) + except TypeError: + try: + sandbox.stop() + except Exception: + pass + except Exception: + pass + + def _snapshot_sandbox(self, sandbox: Sandbox) -> str | None: + if not self._persistent or not self._task_id: + return None + try: + snapshot = sandbox.snapshot() + except Exception as exc: + logger.warning( + "Vercel: filesystem snapshot failed for task %s: %s", + self._task_id, + exc, + ) + return None + + snapshot_id = _extract_snapshot_id(snapshot) + if not snapshot_id: + logger.warning( + "Vercel: filesystem snapshot for task %s did not return a snapshot id", + self._task_id, + ) + return None + + _store_snapshot(self._task_id, snapshot_id) + logger.info( + "Vercel: saved filesystem snapshot %s for task %s", + snapshot_id, + self._task_id, + ) + return snapshot_id + + def _ensure_sandbox_ready(self) -> None: + sandbox = self._sandbox + requested_cwd = self.cwd or self._requested_cwd or DEFAULT_VERCEL_CWD + + if sandbox is None: + self._sandbox = self._create_sandbox() + self._configure_attached_sandbox(requested_cwd=requested_cwd) + return + + try: + sandbox.refresh() + except Exception as exc: + logger.warning( + "Vercel: 
sandbox refresh failed for task %s: %s; recreating", + self._task_id, + exc, + ) + self._close_sandbox_client(sandbox) + self._sandbox = self._create_sandbox() + self._configure_attached_sandbox(requested_cwd=requested_cwd) + return + + status = sandbox.status + if status in _terminal_sandbox_states(): + logger.warning( + "Vercel: sandbox entered state %s for task %s; recreating", + status, + self._task_id, + ) + self._close_sandbox_client(sandbox) + self._sandbox = self._create_sandbox() + self._configure_attached_sandbox(requested_cwd=requested_cwd) + return + + self._wait_for_running() + + def _vercel_upload(self, host_path: str, remote_path: str) -> None: + self._vercel_bulk_upload([(host_path, remote_path)]) + + def _vercel_bulk_upload(self, files: list[tuple[str, str]]) -> None: + if not files: + return + + payload: list[WriteFile] = [ + { + "path": remote_path, + "content": Path(host_path).read_bytes(), + } + for host_path, remote_path in files + ] + + sandbox = self._sandbox + if sandbox is None: + raise RuntimeError("Vercel sandbox is not attached") + _retry_vercel_call( + "write_files", + lambda: sandbox.write_files(payload), + attempts=_WRITE_RETRY_ATTEMPTS, + ) + + def _vercel_delete(self, remote_paths: list[str]) -> None: + if not remote_paths: + return + + sandbox = self._sandbox + if sandbox is None: + raise RuntimeError("Vercel sandbox is not attached") + result = sandbox.run_command( + "bash", + ["-lc", quoted_rm_command(remote_paths)], + cwd=self._workspace_root, + ) + if _extract_result_returncode(result) != 0: + raise RuntimeError( + f"Vercel delete failed: {_extract_result_output(result).strip()}" + ) + + def _vercel_bulk_download(self, dest_tar_path: Path) -> None: + remote_hermes = ( + "/.hermes" + if self._remote_home == "/" + else f"{self._remote_home.rstrip('/')}/.hermes" + ) + archive_member = remote_hermes.lstrip("/") + remote_tar = f"/tmp/.hermes_sync.{os.getpid()}.tar" + sandbox = self._sandbox + if sandbox is None: + raise 
RuntimeError("Vercel sandbox is not attached") + + try: + result = sandbox.run_command( + "bash", + [ + "-lc", + f"tar cf {shlex.quote(remote_tar)} -C / {shlex.quote(archive_member)}", + ], + cwd=self._workspace_root, + ) + if _extract_result_returncode(result) != 0: + raise RuntimeError( + f"Vercel bulk download failed: {_extract_result_output(result).strip()}" + ) + + sandbox.download_file(remote_tar, dest_tar_path) + finally: + try: + sandbox.run_command( + "bash", + ["-lc", f"rm -f {shlex.quote(remote_tar)}"], + cwd=self._workspace_root, + ) + except Exception: + pass + + def _before_execute(self) -> None: + with self._lock: + self._ensure_sandbox_ready() + if self._sync_manager is not None: + self._sync_manager.sync() + + def _run_bash( + self, + cmd_string: str, + *, + login: bool = False, + timeout: int = 120, + stdin_data: str | None = None, + ): + """Run a bash command in the Vercel sandbox. + + ``timeout`` is not forwarded to the Vercel SDK (which does not expose + a per-exec timeout parameter); the base class ``_wait_for_process`` + enforces timeout by killing the sandbox via ``cancel_fn``. + + ``stdin_data`` is intentionally discarded here because + ``_stdin_mode = "heredoc"`` causes the base class ``execute()`` to + embed any stdin payload into the command string before calling this + method. 
+ """ + del timeout + del stdin_data + + sandbox = self._sandbox + if sandbox is None: + raise RuntimeError("Vercel sandbox is not attached") + workspace_root = self._workspace_root + lock = self._lock + + def cancel() -> None: + with lock: + self._stop_sandbox(sandbox) + + def exec_fn() -> tuple[str, int]: + result = sandbox.run_command( + "bash", + ["-lc" if login else "-c", cmd_string], + cwd=workspace_root, + ) + return _extract_result_output(result), _extract_result_returncode(result) + + return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel) + + def cleanup(self): + with self._lock: + sandbox = self._sandbox + sync_manager = self._sync_manager + if sandbox is not None and sync_manager is not None: + try: + sync_manager.sync_back() + except Exception as exc: + logger.warning( + "Vercel: sync_back failed for task %s: %s", + self._task_id, + exc, + ) + self._sandbox = None + self._sync_manager = None + + if sandbox is None: + return + + snapshot_id = self._snapshot_sandbox(sandbox) + # Always stop the sandbox during cleanup to avoid resource leaks, + # matching the Modal and Daytona patterns. 
+ self._stop_sandbox(sandbox) + self._close_sandbox_client(sandbox) diff --git a/tools/file_tools.py b/tools/file_tools.py index 7d81cd8f8e..7a7f092954 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -380,15 +380,17 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations: logger.info("Creating new %s environment for task %s...", env_type, task_id[:8]) container_config = None - if env_type in ("docker", "singularity", "modal", "daytona"): + if env_type in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"): container_config = { "container_cpu": config.get("container_cpu", 1), "container_memory": config.get("container_memory", 5120), "container_disk": config.get("container_disk", 51200), "container_persistent": config.get("container_persistent", True), + "vercel_runtime": config.get("vercel_runtime", ""), "docker_volumes": config.get("docker_volumes", []), "docker_mount_cwd_to_workspace": config.get("docker_mount_cwd_to_workspace", False), "docker_forward_env": config.get("docker_forward_env", []), + "docker_run_as_host_user": config.get("docker_run_as_host_user", False), } ssh_config = None @@ -487,12 +489,15 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = task_data = _read_tracker.setdefault(task_id, { "last_key": None, "consecutive": 0, "read_history": set(), "dedup": {}, - "dedup_hits": {}, + "dedup_hits": {}, "read_timestamps": {}, }) # Backward-compat for pre-existing tracker entries that predate - # dedup_hits (long-lived task or crossed an upgrade boundary). + # dedup_hits/read_timestamps (long-lived task or crossed an + # upgrade boundary). 
if "dedup_hits" not in task_data: task_data["dedup_hits"] = {} + if "read_timestamps" not in task_data: + task_data["read_timestamps"] = {} cached_mtime = task_data.get("dedup", {}).get(dedup_key) if cached_mtime is not None: diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 8905c2237a..2a0115ec85 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -915,11 +915,12 @@ class MCPServerTask: except Exception: logger.exception("MCP server '%s': dynamic tool refresh failed", self.name) - def _schedule_tools_refresh(self) -> None: + def _schedule_tools_refresh(self) -> asyncio.Task: """Schedule a background tool refresh and keep it strongly referenced.""" task = asyncio.create_task(self._refresh_tools_task()) self._pending_refresh_tasks.add(task) task.add_done_callback(self._pending_refresh_tasks.discard) + return task def _make_message_handler(self): """Build a ``message_handler`` callback for ``ClientSession``. @@ -950,6 +951,10 @@ class MCPServerTask: # a separate task and let the handler return # promptly. self._schedule_tools_refresh() + # Yield one loop tick so tests and short-lived + # notification contexts can observe the scheduled + # refresh without awaiting the full server RPC. + await asyncio.sleep(0) case PromptListChangedNotification(): logger.debug("MCP server '%s': prompts/list_changed (ignored)", self.name) case ResourceListChangedNotification(): diff --git a/tools/process_registry.py b/tools/process_registry.py index 479030120d..da5c8d224b 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -800,6 +800,78 @@ class ProcessRegistry: session = self._running.get(session_id) or self._finished.get(session_id) return self._refresh_detached_session(session) + def _reconcile_local_exit(self, session: "ProcessSession") -> None: + """Reconcile session.exited against the real child process state. 
+ + The reader thread (`_reader_loop`) sets `session.exited = True` only + in its `finally` block, which runs when `stdout.read()` returns EOF. + If the direct `Popen` child has exited but a descendant process (e.g. + a daemon spawned by `hermes update` restarting the gateway) is still + holding the stdout pipe open, the reader blocks forever and poll() + keeps returning "running" indefinitely (issue #17327 — 74 polls over + 7 minutes on Feishu). + + This helper closes that window: when `session.exited` is still False + but the direct child's `Popen.poll()` reports an exit code, drain any + readable bytes non-blocking and flip `session.exited`. The orphaned + reader thread remains stuck on its blocking `read()` but is a daemon + thread and will be reaped with the process. + + Safe no-op on sessions without a local `Popen` (env/PTY), already- + exited sessions, and detached-recovered sessions. + """ + if session is None or session.exited: + return + proc = getattr(session, "process", None) + if proc is None: + return + try: + rc = proc.poll() + except Exception: + return + if rc is None: + return # Direct child still running — reader block is legitimate. + + # Direct child exited. Try to drain any bytes the reader hasn't + # consumed yet. This is best-effort: if the pipe is held open by a + # descendant, the non-blocking read returns what's immediately + # available and we stop. 
+ drained = "" + stdout = getattr(proc, "stdout", None) + if stdout is not None and not _IS_WINDOWS: + try: + import fcntl + fd = stdout.fileno() + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + try: + chunk = stdout.read() + if chunk: + drained = chunk if isinstance(chunk, str) else chunk.decode("utf-8", errors="replace") + except (BlockingIOError, OSError, ValueError): + pass + finally: + try: + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + except Exception: + pass + except Exception as e: + logger.debug("Non-blocking drain failed for %s: %s", session.id, e) + + with session._lock: + if drained: + session.output_buffer += drained + if len(session.output_buffer) > session.max_output_chars: + session.output_buffer = session.output_buffer[-session.max_output_chars:] + session.exited = True + session.exit_code = rc + logger.info( + "Reconciled session %s: direct child exited with code %s but reader " + "was still blocked (orphaned pipe). Flipped to exited.", + session.id, rc, + ) + self._move_to_finished(session) + def poll(self, session_id: str) -> dict: """Check status and get new output for a background process.""" from tools.ansi_strip import strip_ansi @@ -808,6 +880,10 @@ class ProcessRegistry: if session is None: return {"status": "not_found", "error": f"No process with ID {session_id}"} + # Reconcile against real child state before reading session.exited. + # Guards against orphaned-pipe reader hangs (issue #17327). + self._reconcile_local_exit(session) + with session._lock: output_preview = strip_ansi(session.output_buffer[-1000:]) if session.output_buffer else "" @@ -898,6 +974,10 @@ class ProcessRegistry: while time.monotonic() < deadline: session = self._refresh_detached_session(session) + # Reconcile against real child state — guards against orphaned- + # pipe reader hangs where the reader is blocked but the direct + # child has already exited (issue #17327). 
+ self._reconcile_local_exit(session) if session.exited: self._completion_consumed.add(session_id) result = { diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index a2321c2e50..1a3ede29d6 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -40,8 +40,12 @@ _PHONE_PLATFORMS = frozenset({"signal", "sms", "whatsapp"}) _E164_TARGET_RE = re.compile(r"^\s*\+(\d{7,15})\s*$") _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"} -_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"} +_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a", ".flac"} _VOICE_EXTS = {".ogg", ".opus"} +# Telegram's Bot API sendAudio only accepts MP3 / M4A. Other audio +# formats either route through sendVoice (Opus/OGG) or fall back to +# document delivery. +_TELEGRAM_SEND_AUDIO_EXTS = {".mp3", ".m4a"} _URL_SECRET_QUERY_RE = re.compile( r"([?&](?:access_token|api[_-]?key|auth[_-]?token|token|signature|sig)=)([^&#\s]+)", re.IGNORECASE, @@ -205,30 +209,12 @@ def _handle_send(args): except Exception as e: return json.dumps(_error(f"Failed to load gateway config: {e}")) - platform_map = { - "telegram": Platform.TELEGRAM, - "discord": Platform.DISCORD, - "slack": Platform.SLACK, - "whatsapp": Platform.WHATSAPP, - "signal": Platform.SIGNAL, - "bluebubbles": Platform.BLUEBUBBLES, - "qqbot": Platform.QQBOT, - "matrix": Platform.MATRIX, - "mattermost": Platform.MATTERMOST, - "homeassistant": Platform.HOMEASSISTANT, - "dingtalk": Platform.DINGTALK, - "feishu": Platform.FEISHU, - "wecom": Platform.WECOM, - "wecom_callback": Platform.WECOM_CALLBACK, - "weixin": Platform.WEIXIN, - "email": Platform.EMAIL, - "sms": Platform.SMS, - "yuanbao": Platform.YUANBAO, - } - platform = platform_map.get(platform_name) - if not platform: - avail = ", ".join(platform_map.keys()) - return tool_error(f"Unknown platform: {platform_name}. 
Available: {avail}") + # Accept any platform name — built-in names resolve to their enum + # member, plugin platform names create dynamic members via _missing_(). + try: + platform = Platform(platform_name) + except (ValueError, KeyError): + return tool_error(f"Unknown platform: {platform_name}") pconfig = config.platforms.get(platform) if not pconfig or not pconfig.enabled: @@ -429,6 +415,27 @@ def _maybe_skip_cron_duplicate_send(platform_name: str, chat_id: str, thread_id: } +async def _send_via_adapter(platform, pconfig, chat_id, chunk): + """Send a message via a live gateway adapter (for plugin platforms). + + Falls back to error if no adapter is connected for this platform. + """ + try: + from gateway.run import _gateway_runner_ref + runner = _gateway_runner_ref() + if runner: + adapter = runner.adapters.get(platform) + if adapter: + from gateway.platforms.base import SendResult + result = await adapter.send(chat_id=chat_id, content=chunk) + if result.success: + return {"success": True, "message_id": result.message_id} + return {"error": f"Adapter send failed: {result.error}"} + except Exception as e: + return {"error": f"Plugin platform send failed: {e}"} + return {"error": f"No live adapter for platform '{platform.value}'. Is the gateway running with this platform connected?"} + + async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, media_files=None): """Route a message to the appropriate platform sender. 
@@ -473,6 +480,16 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, if _feishu_available: _MAX_LENGTHS[Platform.FEISHU] = FeishuAdapter.MAX_MESSAGE_LENGTH + # Check plugin registry for max_message_length + if platform not in _MAX_LENGTHS: + try: + from gateway.platform_registry import platform_registry + entry = platform_registry.get(platform.value) + if entry and entry.max_message_length > 0: + _MAX_LENGTHS[platform] = entry.max_message_length + except Exception: + pass + # Smart-chunk the message to fit within platform limits. # For short messages or platforms without a known limit this is a no-op. # Telegram measures length in UTF-16 code units, not Unicode codepoints. @@ -556,6 +573,21 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, last_result = result return last_result + # --- Yuanbao: native media attachment support via running gateway adapter --- + if platform == Platform.YUANBAO and media_files: + last_result = None + for i, chunk in enumerate(chunks): + is_last = (i == len(chunks) - 1) + result = await _send_yuanbao( + chat_id, + chunk, + media_files=media_files if is_last else None, + ) + if isinstance(result, dict) and result.get("error"): + return result + last_result = result + return last_result + # --- Non-media platforms --- if media_files and not message.strip(): return { @@ -599,8 +631,12 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, result = await _send_bluebubbles(pconfig.extra, chat_id, chunk) elif platform == Platform.QQBOT: result = await _send_qqbot(pconfig, chat_id, chunk) + elif platform == Platform.YUANBAO: + result = await _send_yuanbao(chat_id, chunk) else: - result = {"error": f"Direct sending not yet implemented for {platform.value}"} + # Plugin platform — route through the gateway's live adapter + # if available, otherwise report the error. 
+ result = await _send_via_adapter(platform, pconfig, chat_id, chunk) if isinstance(result, dict) and result.get("error"): return result @@ -708,7 +744,7 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No last_msg = await bot.send_voice( chat_id=int_chat_id, voice=f, **thread_kwargs ) - elif ext in _AUDIO_EXTS: + elif ext in _TELEGRAM_SEND_AUDIO_EXTS: last_msg = await bot.send_audio( chat_id=int_chat_id, audio=f, **thread_kwargs ) diff --git a/tools/skill_manager_tool.py b/tools/skill_manager_tool.py index 37de1087c1..cc8b0fed28 100644 --- a/tools/skill_manager_tool.py +++ b/tools/skill_manager_tool.py @@ -43,6 +43,7 @@ from hermes_constants import get_hermes_home, display_hermes_home from typing import Dict, Any, Optional, Tuple from utils import atomic_replace +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ def _guard_agent_created_enabled() -> bool: try: from hermes_cli.config import load_config cfg = load_config() - return bool(cfg.get("skills", {}).get("guard_agent_created", False)) + return bool(cfg_get(cfg, "skills", "guard_agent_created", default=False)) except Exception: return False @@ -108,16 +109,55 @@ MAX_NAME_LENGTH = 64 MAX_DESCRIPTION_LENGTH = 1024 -def _is_local_skill(skill_path: Path) -> bool: - """Check if a skill path is within the local SKILLS_DIR. +def _containing_skills_root(skill_path: Path) -> Path: + """Return the skills root directory (local or external_dirs entry) that + contains ``skill_path``. Falls back to the local ``SKILLS_DIR`` if no + match is found (defensive — callers should have located the skill via + ``_find_skill`` first). + """ + from agent.skill_utils import get_all_skills_dirs - Skills found in external_dirs are read-only from the agent's perspective. 
+ try: + resolved = skill_path.resolve() + except OSError: + resolved = skill_path + + for root in get_all_skills_dirs(): + try: + resolved.relative_to(root.resolve()) + return root + except (ValueError, OSError): + continue + return SKILLS_DIR + + +def _pinned_guard(name: str) -> Optional[str]: + """Return a refusal message if *name* is pinned, else None. + + Pinned skills are off-limits to the agent's skill_manage tool. The only + way to modify one is for the user to unpin it via + ``hermes curator unpin `` (or edit it directly by hand). This + mirrors the curator's own pinned-skip behavior but extends the guard + to tool-driven writes as well, giving users a hard fence against + accidental agent edits. + + Best-effort: if the sidecar is unreadable we let the write through + rather than block on a broken telemetry file. """ try: - skill_path.resolve().relative_to(SKILLS_DIR.resolve()) - return True - except ValueError: - return False + from tools import skill_usage + rec = skill_usage.get_record(name) + if rec.get("pinned"): + return ( + f"Skill '{name}' is pinned and cannot be modified by " + f"skill_manage. Ask the user to run " + f"`hermes curator unpin {name}` if they want the change." + ) + except Exception: + logger.debug("pinned-guard lookup failed for %s", name, exc_info=True) + return None + + MAX_SKILL_CONTENT_CHARS = 100_000 # ~36k tokens at 2.75 chars/token MAX_SKILL_FILE_BYTES = 1_048_576 # 1 MiB per supporting file @@ -396,8 +436,9 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]: if not existing: return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."} - if not _is_local_skill(existing["path"]): - return {"success": False, "error": f"Skill '{name}' is in an external directory and cannot be modified. 
Copy it to your local skills directory first."} + pinned_err = _pinned_guard(name) + if pinned_err: + return {"success": False, "error": pinned_err} skill_md = existing["path"] / "SKILL.md" # Back up original content for rollback @@ -439,8 +480,9 @@ def _patch_skill( if not existing: return {"success": False, "error": f"Skill '{name}' not found."} - if not _is_local_skill(existing["path"]): - return {"success": False, "error": f"Skill '{name}' is in an external directory and cannot be modified. Copy it to your local skills directory first."} + pinned_err = _pinned_guard(name) + if pinned_err: + return {"success": False, "error": pinned_err} skill_dir = existing["path"] @@ -521,15 +563,17 @@ def _delete_skill(name: str) -> Dict[str, Any]: if not existing: return {"success": False, "error": f"Skill '{name}' not found."} - if not _is_local_skill(existing["path"]): - return {"success": False, "error": f"Skill '{name}' is in an external directory and cannot be deleted."} + pinned_err = _pinned_guard(name) + if pinned_err: + return {"success": False, "error": pinned_err} skill_dir = existing["path"] + skills_root = _containing_skills_root(skill_dir) shutil.rmtree(skill_dir) - # Clean up empty category directories (don't remove SKILLS_DIR itself) + # Clean up empty category directories (don't remove the skills root itself) parent = skill_dir.parent - if parent != SKILLS_DIR and parent.exists() and not any(parent.iterdir()): + if parent != skills_root and parent.exists() and not any(parent.iterdir()): parent.rmdir() return { @@ -566,8 +610,9 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]: if not existing: return {"success": False, "error": f"Skill '{name}' not found. Create it first with action='create'."} - if not _is_local_skill(existing["path"]): - return {"success": False, "error": f"Skill '{name}' is in an external directory and cannot be modified. 
Copy it to your local skills directory first."} + pinned_err = _pinned_guard(name) + if pinned_err: + return {"success": False, "error": pinned_err} target, err = _resolve_skill_target(existing["path"], file_path) if err: @@ -603,8 +648,9 @@ def _remove_file(name: str, file_path: str) -> Dict[str, Any]: if not existing: return {"success": False, "error": f"Skill '{name}' not found."} - if not _is_local_skill(existing["path"]): - return {"success": False, "error": f"Skill '{name}' is in an external directory and cannot be modified."} + pinned_err = _pinned_guard(name) + if pinned_err: + return {"success": False, "error": pinned_err} skill_dir = existing["path"] @@ -738,7 +784,10 @@ SKILL_MANAGE_SCHEMA = { "After difficult/iterative tasks, offer to save as a skill. " "Skip for simple one-offs. Confirm with user before creating/deleting.\n\n" "Good skills: trigger conditions, numbered steps with exact commands, " - "pitfalls section, verification steps. Use skill_view() to see format examples." + "pitfalls section, verification steps. Use skill_view() to see format examples.\n\n" + "Pinned skills are off-limits — all write actions refuse with a message " + "pointing the user to `hermes curator unpin `. Don't try to route " + "around this by renaming or recreating." 
), "parameters": { "type": "object", diff --git a/tools/skills_tool.py b/tools/skills_tool.py index d501e6c85c..4ce338c59f 100644 --- a/tools/skills_tool.py +++ b/tools/skills_tool.py @@ -77,6 +77,7 @@ from pathlib import Path from typing import Dict, Any, List, Optional, Set, Tuple from tools.registry import registry, tool_error +from hermes_cli.config import cfg_get logger = logging.getLogger(__name__) @@ -100,7 +101,9 @@ _PLATFORM_MAP = { } _ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") _EXCLUDED_SKILL_DIRS = frozenset((".git", ".github", ".hub")) -_REMOTE_ENV_BACKENDS = frozenset({"docker", "singularity", "modal", "ssh", "daytona"}) +_REMOTE_ENV_BACKENDS = frozenset( + {"docker", "singularity", "modal", "ssh", "daytona", "vercel_sandbox"} +) _secret_capture_callback = None @@ -535,7 +538,7 @@ def _is_skill_disabled(name: str, platform: str = None) -> bool: skills_cfg = config.get("skills", {}) resolved_platform = platform or os.getenv("HERMES_PLATFORM") or _get_session_platform() if resolved_platform: - platform_disabled = skills_cfg.get("platform_disabled", {}).get(resolved_platform) + platform_disabled = cfg_get(skills_cfg, "platform_disabled", resolved_platform) if platform_disabled is not None: return name in platform_disabled return name in skills_cfg.get("disabled", []) @@ -1509,3 +1512,4 @@ registry.register( check_fn=check_skills_requirements, emoji="📚", ) + diff --git a/tools/slash_confirm.py b/tools/slash_confirm.py new file mode 100644 index 0000000000..81c1526352 --- /dev/null +++ b/tools/slash_confirm.py @@ -0,0 +1,162 @@ +"""Generic slash-command confirmation primitive (gateway-side). + +Slash commands that have a non-destructive but expensive side effect worth +surfacing to the user (currently only ``/reload-mcp``, which invalidates +the provider prompt cache) route through this module. + +Two delivery paths: + + 1. 
Button UI — adapters that override ``send_slash_confirm`` render + three inline buttons (Approve Once / Always Approve / Cancel). The + button callback calls ``resolve(session_key, confirm_id, choice)``. + + 2. Text fallback — adapters without button UIs get a plain text prompt. + Users reply with ``/approve``, ``/always``, or ``/cancel``; the + gateway's ``_handle_message`` intercepts those replies and calls + ``resolve()`` directly. + +State is stored module-level (like ``tools.approval``) so platform +adapters can resolve callbacks without needing a backreference to the +``GatewayRunner`` instance. The CLI path (``cli.py``) uses a local +synchronous variant — see ``_prompt_slash_confirm`` there. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import time +from typing import Any, Awaitable, Callable, Dict, Optional + +logger = logging.getLogger(__name__) + +# Pending confirmations keyed by gateway session_key. Each entry: +# { +# "confirm_id": str, +# "command": str, # e.g. "reload-mcp" +# "handler": Callable[[str], Awaitable[Optional[str]]], +# "created_at": float, # time.time() +# } +_pending: Dict[str, Dict[str, Any]] = {} +_lock = threading.RLock() + +# Default timeout — a pending confirm older than this is discarded when +# the next message arrives for the same session. Buttons work up until +# the adapter drops the callback_data (Telegram: ~48h; Discord: ephemeral; +# Slack: 3s ack + long-lived actions). +DEFAULT_TIMEOUT_SECONDS = 300 + + +def register( + session_key: str, + confirm_id: str, + command: str, + handler: Callable[[str], Awaitable[Optional[str]]], +) -> None: + """Register a pending slash-command confirmation. + + Overwrites any prior pending confirm for the same ``session_key`` — the + user invoking a new confirmable command supersedes the stale one. 
+ """ + with _lock: + _pending[session_key] = { + "confirm_id": confirm_id, + "command": command, + "handler": handler, + "created_at": time.time(), + } + + +def get_pending(session_key: str) -> Optional[Dict[str, Any]]: + """Return the pending confirm dict for a session, or None.""" + with _lock: + entry = _pending.get(session_key) + return dict(entry) if entry else None + + +def clear(session_key: str) -> None: + """Drop the pending confirm for ``session_key`` without running it.""" + with _lock: + _pending.pop(session_key, None) + + +def clear_if_stale(session_key: str, timeout: float = DEFAULT_TIMEOUT_SECONDS) -> bool: + """Drop the pending confirm if older than ``timeout`` seconds. + + Returns True if an entry was dropped. + """ + with _lock: + entry = _pending.get(session_key) + if not entry: + return False + if time.time() - float(entry.get("created_at", 0) or 0) > timeout: + _pending.pop(session_key, None) + return True + return False + + +async def resolve( + session_key: str, + confirm_id: str, + choice: str, + timeout: float = DEFAULT_TIMEOUT_SECONDS, +) -> Optional[str]: + """Resolve a pending confirm. + + ``choice`` must be one of ``"once"``, ``"always"``, or ``"cancel"``. + Returns the handler's output string (to be sent as a follow-up + message), or ``None`` if the confirm was stale, already resolved, or + the confirm_id doesn't match. + + Safe to call from an asyncio callback (button click) or from the + gateway's message intercept path. + """ + with _lock: + entry = _pending.get(session_key) + if not entry: + return None + if entry.get("confirm_id") != confirm_id: + # Stale confirm_id — superseded by a newer prompt on the same session. + return None + # Pop before we run the handler to prevent duplicate callbacks + # (e.g. button double-click) from running it twice. 
+ _pending.pop(session_key, None) + if time.time() - float(entry.get("created_at", 0) or 0) > timeout: + return None + handler = entry.get("handler") + command = entry.get("command", "?") + + if not handler: + return None + try: + result = await handler(choice) + except Exception as exc: + logger.error( + "Slash-confirm handler for /%s raised: %s", + command, exc, exc_info=True, + ) + return f"❌ Error handling confirmation: {exc}" + return result if isinstance(result, str) else None + + +def resolve_sync_compat( + loop: asyncio.AbstractEventLoop, + session_key: str, + confirm_id: str, + choice: str, +) -> Optional[str]: + """Synchronous helper: schedule resolve() on a loop and wait for the result. + + Used by platform callback paths that run on a different thread than the + event loop (e.g. Discord's button click handler in some configurations). + Prefer the async ``resolve()`` from an async context. + """ + try: + fut = asyncio.run_coroutine_threadsafe( + resolve(session_key, confirm_id, choice), loop, + ) + return fut.result(timeout=30) + except Exception as exc: + logger.error("resolve_sync_compat failed: %s", exc) + return None diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 395ee8f5b6..f9c203fe06 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -2,16 +2,19 @@ """ Terminal Tool Module -A terminal tool that executes commands in local, Docker, Modal, SSH, Singularity, and Daytona environments. -Supports local execution, containerized backends, and Modal cloud sandboxes, including managed gateway mode. +A terminal tool that executes commands in local, Docker, Modal, SSH, +Singularity, Daytona, and Vercel Sandbox environments. Supports local +execution, containerized backends, and cloud sandboxes, including managed +Modal mode. 
Environment Selection (via TERMINAL_ENV environment variable): - "local": Execute directly on the host machine (default, fastest) - "docker": Execute in Docker containers (isolated, requires Docker) - "modal": Execute in Modal cloud sandboxes (direct Modal or managed gateway) +- "vercel_sandbox": Execute in Vercel Sandbox cloud sandboxes Features: -- Multiple execution backends (local, docker, modal) +- Multiple execution backends (local, docker, modal, vercel_sandbox) - Background task support - VM/container lifecycle management - Automatic cleanup after inactivity @@ -114,6 +117,68 @@ DISK_USAGE_WARNING_THRESHOLD_GB = _safe_parse_import_env( float, "number", ) +_VERCEL_SANDBOX_DEFAULT_CWD = "/vercel/sandbox" +_SUPPORTED_VERCEL_RUNTIMES = ("node24", "node22", "python3.13") + + +def _is_supported_vercel_runtime(runtime: str) -> bool: + return not runtime or runtime in _SUPPORTED_VERCEL_RUNTIMES + + +def _check_vercel_sandbox_requirements(config: dict[str, Any]) -> bool: + """Validate Vercel Sandbox terminal backend requirements.""" + runtime = (config.get("vercel_runtime") or "").strip() + if not _is_supported_vercel_runtime(runtime): + supported = ", ".join(_SUPPORTED_VERCEL_RUNTIMES) + logger.error( + "Vercel Sandbox runtime %r is not supported. " + "Set TERMINAL_VERCEL_RUNTIME to one of: %s.", + runtime, + supported, + ) + return False + + disk = config.get("container_disk", 51200) + if disk not in (0, 51200): + logger.error( + "Vercel Sandbox does not support custom TERMINAL_CONTAINER_DISK=%s. 
" + "Use the default shared setting (51200 MB).", + disk, + ) + return False + + if importlib.util.find_spec("vercel") is None: + logger.error( + "vercel is required for the Vercel Sandbox terminal backend: pip install vercel" + ) + return False + + has_oidc = bool(os.getenv("VERCEL_OIDC_TOKEN")) + has_token = bool(os.getenv("VERCEL_TOKEN")) + has_project = bool(os.getenv("VERCEL_PROJECT_ID")) + has_team = bool(os.getenv("VERCEL_TEAM_ID")) + + if has_oidc: + return True + + if has_token or has_project or has_team: + if has_token and has_project and has_team: + return True + logger.error( + "Vercel Sandbox backend selected with token auth, but " + "VERCEL_TOKEN, VERCEL_PROJECT_ID, and VERCEL_TEAM_ID must all " + "be set together. VERCEL_OIDC_TOKEN is supported for one-off " + "local development only." + ) + return False + + logger.error( + "Vercel Sandbox backend selected but no supported auth configuration " + "was found. Set VERCEL_TOKEN, VERCEL_PROJECT_ID, and VERCEL_TEAM_ID " + "for normal use. VERCEL_OIDC_TOKEN is supported for one-off local " + "development only." + ) + return False def _check_disk_usage_warning(): @@ -744,9 +809,10 @@ def _transform_sudo_command(command: str | None) -> tuple[str | None, str | None should prepend sudo_stdin to their stdin_data and pass the merged bytes to Popen's stdin pipe. - Callers that cannot pipe subprocess stdin (modal, daytona) must embed the - password in the command string themselves; see their execute() methods for - how they handle the non-None sudo_stdin case. + Callers that cannot pipe subprocess stdin (modal, daytona, + vercel_sandbox) must embed the password in the command string + themselves; see their execute() methods for how they handle the + non-None sudo_stdin case. If SUDO_PASSWORD is not set and in interactive mode (HERMES_INTERACTIVE=1): Prompts user for password with 45s timeout, caches for session. 
@@ -910,13 +976,15 @@ def _get_env_config() -> Dict[str, Any]: mount_docker_cwd = os.getenv("TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE", "false").lower() in ("true", "1", "yes") - # Default cwd: local uses the host's current directory, everything - # else starts in the user's home (~ resolves to whatever account - # is running inside the container/remote). + # Default cwd: local uses the host's current directory, ssh uses the + # remote home, Vercel uses its documented workspace root, and everything + # else starts in the backend's default root-like cwd. if env_type == "local": default_cwd = os.getcwd() elif env_type == "ssh": default_cwd = "~" + elif env_type == "vercel_sandbox": + default_cwd = _VERCEL_SANDBOX_DEFAULT_CWD else: default_cwd = "/root" @@ -938,7 +1006,7 @@ def _get_env_config() -> Dict[str, Any]: ): host_cwd = candidate cwd = "/workspace" - elif env_type in ("modal", "docker", "singularity", "daytona") and cwd: + elif env_type in ("modal", "docker", "singularity", "daytona", "vercel_sandbox") and cwd: # Host paths and relative paths that won't work inside containers is_host_path = any(cwd.startswith(p) for p in host_prefixes) is_relative = not os.path.isabs(cwd) # e.g. "." 
or "src/" @@ -956,6 +1024,7 @@ def _get_env_config() -> Dict[str, Any]: "singularity_image": os.getenv("TERMINAL_SINGULARITY_IMAGE", f"docker://{default_image}"), "modal_image": os.getenv("TERMINAL_MODAL_IMAGE", default_image), "daytona_image": os.getenv("TERMINAL_DAYTONA_IMAGE", default_image), + "vercel_runtime": os.getenv("TERMINAL_VERCEL_RUNTIME", "").strip(), "cwd": cwd, "host_cwd": host_cwd, "docker_mount_cwd_to_workspace": mount_docker_cwd, @@ -974,12 +1043,14 @@ def _get_env_config() -> Dict[str, Any]: os.getenv("TERMINAL_PERSISTENT_SHELL", "true"), ).lower() in ("true", "1", "yes"), "local_persistent": os.getenv("TERMINAL_LOCAL_PERSISTENT", "false").lower() in ("true", "1", "yes"), - # Container resource config (applies to docker, singularity, modal, daytona -- ignored for local/ssh) + # Container resource config (applies to docker, singularity, modal, + # daytona, and vercel_sandbox -- ignored for local/ssh) "container_cpu": _parse_env_var("TERMINAL_CONTAINER_CPU", "1", float, "number"), "container_memory": _parse_env_var("TERMINAL_CONTAINER_MEMORY", "5120"), # MB (default 5GB) "container_disk": _parse_env_var("TERMINAL_CONTAINER_DISK", "51200"), # MB (default 50GB) "container_persistent": os.getenv("TERMINAL_CONTAINER_PERSISTENT", "true").lower() in ("true", "1", "yes"), "docker_volumes": _parse_env_var("TERMINAL_DOCKER_VOLUMES", "[]", json.loads, "valid JSON"), + "docker_run_as_host_user": os.getenv("TERMINAL_DOCKER_RUN_AS_HOST_USER", "false").lower() in ("true", "1", "yes"), } @@ -1001,8 +1072,9 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, Create an execution environment for sandboxed command execution. 
Args: - env_type: One of "local", "docker", "singularity", "modal", "daytona", "ssh" - image: Docker/Singularity/Modal image name (ignored for local/ssh) + env_type: One of "local", "docker", "singularity", "modal", + "daytona", "vercel_sandbox", "ssh" + image: Docker/Singularity/Modal image name (ignored for local/ssh/vercel) cwd: Working directory timeout: Default command timeout ssh_config: SSH connection config (for env_type="ssh") @@ -1035,6 +1107,7 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, auto_mount_cwd=cc.get("docker_mount_cwd_to_workspace", False), forward_env=docker_forward_env, env=docker_env, + run_as_host_user=cc.get("docker_run_as_host_user", False), ) elif env_type == "singularity": @@ -1105,6 +1178,21 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, persistent_filesystem=persistent, task_id=task_id, ) + elif env_type == "vercel_sandbox": + from tools.environments.vercel_sandbox import ( + VercelSandboxEnvironment as _VercelSandboxEnvironment, + ) + return _VercelSandboxEnvironment( + runtime=cc.get("vercel_runtime") or None, + cwd=cwd, + timeout=timeout, + cpu=cpu, + memory=memory, + disk=disk, + persistent_filesystem=persistent, + task_id=task_id, + ) + elif env_type == "ssh": if not ssh_config or not ssh_config.get("host") or not ssh_config.get("user"): raise ValueError("SSH environment requires ssh_host and ssh_user to be configured") @@ -1118,7 +1206,10 @@ def _create_environment(env_type: str, image: str, cwd: str, timeout: int, ) else: - raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', 'singularity', 'modal', 'daytona', or 'ssh'") + raise ValueError( + f"Unknown environment type: {env_type}. 
Use 'local', 'docker', " + f"'singularity', 'modal', 'daytona', 'vercel_sandbox', or 'ssh'" + ) def _cleanup_inactive_envs(lifetime_seconds: int = 300): @@ -1652,17 +1743,19 @@ def terminal_tool( } container_config = None - if env_type in ("docker", "singularity", "modal", "daytona"): + if env_type in ("docker", "singularity", "modal", "daytona", "vercel_sandbox"): container_config = { "container_cpu": config.get("container_cpu", 1), "container_memory": config.get("container_memory", 5120), "container_disk": config.get("container_disk", 51200), "container_persistent": config.get("container_persistent", True), "modal_mode": config.get("modal_mode", "auto"), + "vercel_runtime": config.get("vercel_runtime", ""), "docker_volumes": config.get("docker_volumes", []), "docker_mount_cwd_to_workspace": config.get("docker_mount_cwd_to_workspace", False), "docker_forward_env": config.get("docker_forward_env", []), "docker_env": config.get("docker_env", {}), + "docker_run_as_host_user": config.get("docker_run_as_host_user", False), } local_config = None @@ -1987,10 +2080,10 @@ def terminal_tool( def check_terminal_requirements() -> bool: """Check if all requirements for the terminal tool are met.""" - config = _get_env_config() - env_type = config["env_type"] - try: + config = _get_env_config() + env_type = config["env_type"] + if env_type == "local": return True @@ -2074,6 +2167,9 @@ def check_terminal_requirements() -> bool: return True + elif env_type == "vercel_sandbox": + return _check_vercel_sandbox_requirements(config) + elif env_type == "daytona": from daytona import Daytona # noqa: F401 — SDK presence check return os.getenv("DAYTONA_API_KEY") is not None @@ -2081,7 +2177,7 @@ def check_terminal_requirements() -> bool: else: logger.error( "Unknown TERMINAL_ENV '%s'. 
Use one of: local, docker, singularity, " - "modal, daytona, ssh.", + "modal, daytona, vercel_sandbox, ssh.", env_type, ) return False @@ -2121,7 +2217,11 @@ if __name__ == "__main__": print("\nEnvironment Variables:") default_img = "nikolaik/python-nodejs:python3.11-nodejs20" - print(f" TERMINAL_ENV: {os.getenv('TERMINAL_ENV', 'local')} (local/docker/singularity/modal/daytona/ssh)") + print( + " TERMINAL_ENV: " + f"{os.getenv('TERMINAL_ENV', 'local')} " + "(local/docker/singularity/modal/daytona/vercel_sandbox/ssh)" + ) print(f" TERMINAL_DOCKER_IMAGE: {os.getenv('TERMINAL_DOCKER_IMAGE', default_img)}") print(f" TERMINAL_SINGULARITY_IMAGE: {os.getenv('TERMINAL_SINGULARITY_IMAGE', f'docker://{default_img}')}") print(f" TERMINAL_MODAL_IMAGE: {os.getenv('TERMINAL_MODAL_IMAGE', default_img)}") diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index bbc9a10e6a..663345eb74 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -42,6 +42,20 @@ from tools.tool_backend_helpers import managed_nous_tools_enabled, resolve_opena logger = logging.getLogger(__name__) +def get_env_value(name, default=None): + """Read env values through the live config module. + + Tests may monkeypatch and later restore ``hermes_cli.config.get_env_value`` + before this module is imported. Resolve the helper at call time so STT does + not keep a stale imported function for the rest of the test process. 
+ """ + try: + from hermes_cli.config import get_env_value as _get_env_value + except ImportError: + return os.getenv(name, default) + value = _get_env_value(name) + return default if value is None else value + # --------------------------------------------------------------------------- # Optional imports — graceful degradation # --------------------------------------------------------------------------- @@ -222,7 +236,7 @@ def _get_provider(stt_config: dict) -> str: return "none" if provider == "groq": - if _HAS_OPENAI and os.getenv("GROQ_API_KEY"): + if _HAS_OPENAI and get_env_value("GROQ_API_KEY"): return "groq" logger.warning( "STT provider 'groq' configured but GROQ_API_KEY not set" @@ -238,7 +252,7 @@ def _get_provider(stt_config: dict) -> str: return "none" if provider == "mistral": - if _HAS_MISTRAL and os.getenv("MISTRAL_API_KEY"): + if _HAS_MISTRAL and get_env_value("MISTRAL_API_KEY"): return "mistral" logger.warning( "STT provider 'mistral' configured but mistralai package " @@ -247,7 +261,7 @@ def _get_provider(stt_config: dict) -> str: return "none" if provider == "xai": - if os.getenv("XAI_API_KEY"): + if get_env_value("XAI_API_KEY"): return "xai" logger.warning( "STT provider 'xai' configured but XAI_API_KEY not set" @@ -262,16 +276,16 @@ def _get_provider(stt_config: dict) -> str: return "local" if _has_local_command(): return "local_command" - if _HAS_OPENAI and os.getenv("GROQ_API_KEY"): + if _HAS_OPENAI and get_env_value("GROQ_API_KEY"): logger.info("No local STT available, using Groq Whisper API") return "groq" if _HAS_OPENAI and _has_openai_audio_backend(): logger.info("No local STT available, using OpenAI Whisper API") return "openai" - if _HAS_MISTRAL and os.getenv("MISTRAL_API_KEY"): + if _HAS_MISTRAL and get_env_value("MISTRAL_API_KEY"): logger.info("No local STT available, using Mistral Voxtral Transcribe API") return "mistral" - if os.getenv("XAI_API_KEY"): + if get_env_value("XAI_API_KEY"): logger.info("No local STT available, using xAI 
Grok STT API") return "xai" return "none" @@ -527,7 +541,7 @@ def _transcribe_local_command(file_path: str, model_name: str) -> Dict[str, Any] def _transcribe_groq(file_path: str, model_name: str) -> Dict[str, Any]: """Transcribe using Groq Whisper API (free tier available).""" - api_key = os.getenv("GROQ_API_KEY") + api_key = get_env_value("GROQ_API_KEY") if not api_key: return {"success": False, "transcript": "", "error": "GROQ_API_KEY not set"} @@ -640,7 +654,7 @@ def _transcribe_mistral(file_path: str, model_name: str) -> Dict[str, Any]: Uses the ``mistralai`` Python SDK to call ``/v1/audio/transcriptions``. Requires ``MISTRAL_API_KEY`` environment variable. """ - api_key = os.getenv("MISTRAL_API_KEY") + api_key = get_env_value("MISTRAL_API_KEY") if not api_key: return {"success": False, "transcript": "", "error": "MISTRAL_API_KEY not set"} @@ -680,7 +694,7 @@ def _transcribe_xai(file_path: str, model_name: str) -> Dict[str, Any]: Supports Inverse Text Normalization, diarization, and word-level timestamps. Requires ``XAI_API_KEY`` environment variable. 
""" - api_key = os.getenv("XAI_API_KEY") + api_key = get_env_value("XAI_API_KEY") if not api_key: return {"success": False, "transcript": "", "error": "XAI_API_KEY not set"} @@ -688,7 +702,7 @@ def _transcribe_xai(file_path: str, model_name: str) -> Dict[str, Any]: xai_config = stt_config.get("xai", {}) base_url = str( xai_config.get("base_url") - or os.getenv("XAI_STT_BASE_URL") + or get_env_value("XAI_STT_BASE_URL") or XAI_STT_BASE_URL ).strip().rstrip("/") language = str( diff --git a/tools/tts_tool.py b/tools/tts_tool.py index a7ca57fab1..7473b32a1d 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -2,14 +2,24 @@ """ Text-to-Speech Tool Module -Supports seven TTS providers: +Built-in TTS providers: - Edge TTS (default, free, no API key): Microsoft Edge neural voices - ElevenLabs (premium): High-quality voices, needs ELEVENLABS_API_KEY - OpenAI TTS: Good quality, needs OPENAI_API_KEY - MiniMax TTS: High-quality with voice cloning, needs MINIMAX_API_KEY - Mistral (Voxtral TTS): Multilingual, native Opus, needs MISTRAL_API_KEY - Google Gemini TTS: Controllable, 30 prebuilt voices, needs GEMINI_API_KEY -- NeuTTS (local, free, no API key): On-device TTS via neutts_cli, needs neutts installed +- xAI TTS: Grok voices, needs XAI_API_KEY +- NeuTTS (local, free, no API key): On-device TTS via neutts +- KittenTTS (local, free, no API key): On-device 25MB model +- Piper (local, free, no API key): OHF-Voice/piper1-gpl neural VITS, 44 languages + +Custom command providers: +- Users can declare any number of named providers with ``type: command`` + under ``tts.providers.`` in ``~/.hermes/config.yaml``. Hermes + writes the input text to a temp file and runs the configured shell + command, which must produce the audio file at the expected path. + See the Local Command section of ``website/docs/user-guide/features/tts.md``. 
Output formats: - Opus (.ogg) for Telegram voice bubbles (requires ffmpeg for Edge TTS) @@ -32,7 +42,9 @@ import logging import os import queue import re +import shlex import shutil +import signal import subprocess import tempfile import threading @@ -44,6 +56,19 @@ from urllib.parse import urljoin from hermes_constants import display_hermes_home logger = logging.getLogger(__name__) +def get_env_value(name, default=None): + """Read env values through the live config module. + + Tests may monkeypatch and later restore ``hermes_cli.config.get_env_value`` + before this module is imported. Resolve the helper at call time so TTS does + not keep a stale imported function for the rest of the test process. + """ + try: + from hermes_cli.config import get_env_value as _get_env_value + except ImportError: + return os.getenv(name, default) + value = _get_env_value(name) + return default if value is None else value from tools.managed_tool_gateway import resolve_managed_tool_gateway from tools.tool_backend_helpers import managed_nous_tools_enabled, prefers_gateway, resolve_openai_audio_api_key from tools.xai_http import hermes_xai_user_agent @@ -85,6 +110,18 @@ def _import_kittentts(): return KittenTTS +def _import_piper(): + """Lazy import Piper. Returns the PiperVoice class or raises ImportError. + + Piper is an optional, fully-local neural TTS engine (Home Assistant / + Open Home Foundation). ``pip install piper-tts`` provides cross-platform + wheels (Linux / macOS / Windows, x86_64 + ARM64) with embedded espeak-ng. + Voice models (.onnx + .onnx.json) are downloaded on first use. 
+ """ + from piper import PiperVoice + return PiperVoice + + # =========================================================================== # Defaults # =========================================================================== @@ -96,6 +133,7 @@ DEFAULT_ELEVENLABS_STREAMING_MODEL_ID = "eleven_flash_v2_5" DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts" DEFAULT_KITTENTTS_MODEL = "KittenML/kitten-tts-nano-0.8-int8" # 25MB DEFAULT_KITTENTTS_VOICE = "Jasper" +DEFAULT_PIPER_VOICE = "en_US-lessac-medium" # balanced size/quality DEFAULT_OPENAI_VOICE = "alloy" DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1" DEFAULT_MINIMAX_MODEL = "speech-2.8-hd" @@ -139,6 +177,7 @@ PROVIDER_MAX_TEXT_LENGTH: Dict[str, int] = { "elevenlabs": 10000, # fallback when model-aware lookup can't resolve (multilingual_v2) "neutts": 2000, # local model, quality falls off on long text "kittentts": 2000, # local 25MB model + "piper": 5000, # local VITS model, phoneme-based; practical cap } # ElevenLabs caps vary by model_id. https://elevenlabs.io/docs/overview/models @@ -168,9 +207,13 @@ def _resolve_max_text_length( Resolution order: 1. ``tts..max_text_length`` (user override in config.yaml) - 2. ElevenLabs model-aware table (keyed on configured ``model_id``) - 3. ``PROVIDER_MAX_TEXT_LENGTH`` default - 4. ``FALLBACK_MAX_TEXT_LENGTH`` (4000) + 2. ``tts.providers..max_text_length`` for user-declared + command providers + 3. ElevenLabs model-aware table (keyed on configured ``model_id``) + 4. ``PROVIDER_MAX_TEXT_LENGTH`` default + 5. ``DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH`` when the provider is a + command-type user provider without an explicit cap + 6. ``FALLBACK_MAX_TEXT_LENGTH`` (4000) Non-positive or non-integer overrides fall through to the default so a broken config can't accidentally disable truncation entirely. 
@@ -179,11 +222,12 @@ def _resolve_max_text_length( return FALLBACK_MAX_TEXT_LENGTH key = provider.lower().strip() cfg = tts_config or {} - prov_cfg = cfg.get(key) if isinstance(cfg.get(key), dict) else {} + # Built-in-style override at tts..max_text_length wins first, + # matching historical behavior. + prov_cfg = cfg.get(key) if isinstance(cfg.get(key), dict) else {} override = prov_cfg.get("max_text_length") if prov_cfg else None if isinstance(override, bool): - # bool is an int subclass; treat explicit booleans as "not set" override = None if isinstance(override, int) and override > 0: return override @@ -194,7 +238,21 @@ def _resolve_max_text_length( if mapped: return mapped - return PROVIDER_MAX_TEXT_LENGTH.get(key, FALLBACK_MAX_TEXT_LENGTH) + if key in PROVIDER_MAX_TEXT_LENGTH: + return PROVIDER_MAX_TEXT_LENGTH[key] + + # User-declared command provider (under tts.providers.) + if key not in BUILTIN_TTS_PROVIDERS: + named = _get_named_provider_config(cfg, key) + if _is_command_provider_config(named): + named_override = named.get("max_text_length") + if isinstance(named_override, bool): + named_override = None + if isinstance(named_override, int) and named_override > 0: + return named_override + return DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH + + return FALLBACK_MAX_TEXT_LENGTH # =========================================================================== @@ -224,6 +282,409 @@ def _get_provider(tts_config: Dict[str, Any]) -> str: return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip() +# =========================================================================== +# Custom command providers (type: command under tts.providers.) +# =========================================================================== +# +# Users can declare any number of command-type providers alongside the +# built-ins so they can plug any local CLI (Piper, VoxCPM, Kokoro CLIs, +# custom voice-cloning scripts, etc.) into Hermes without any Python code +# changes. 
The config shape is:: +# +# tts: +# provider: piper-en +# providers: +# piper-en: +# type: command +# command: "piper -m ~/model.onnx -f {output_path} < {input_path}" +# output_format: wav +# +# Hermes writes the input text to a temp UTF-8 file, runs the command with +# placeholder substitution, and reads the audio file the command wrote to +# ``{output_path}``. Supported placeholders: ``{input_path}``, +# ``{text_path}`` (alias for input_path), ``{output_path}``, ``{format}``, +# ``{voice}``, ``{model}``, ``{speed}``. Use ``{{`` / ``}}`` for literal braces. +# +# Built-in provider names always win over an entry with the same name under +# ``tts.providers``, so user config can't silently shadow ``edge`` etc. +# +# Placeholder values are shell-quoted for their surrounding context +# (bare / single / double quote), so paths with spaces work transparently. + +# Built-in provider names. Any ``tts.provider`` value NOT in this set is +# interpreted as a reference to ``tts.providers.``. +BUILTIN_TTS_PROVIDERS = frozenset({ + "edge", + "elevenlabs", + "openai", + "minimax", + "xai", + "mistral", + "gemini", + "neutts", + "kittentts", + "piper", +}) + +DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS = 120 +DEFAULT_COMMAND_TTS_OUTPUT_FORMAT = "mp3" +COMMAND_TTS_OUTPUT_FORMATS = frozenset({"mp3", "wav", "ogg", "flac"}) +DEFAULT_COMMAND_TTS_MAX_TEXT_LENGTH = 5000 + + +def _get_provider_section(tts_config: Dict[str, Any], name: str) -> Dict[str, Any]: + """Return a provider config block if it's a dict, else an empty dict.""" + if not isinstance(tts_config, dict): + return {} + section = tts_config.get(name) + return section if isinstance(section, dict) else {} + + +def _get_named_provider_config( + tts_config: Dict[str, Any], + name: str, +) -> Dict[str, Any]: + """Return the config dict for a user-declared provider. + + Looks up ``tts.providers.`` first (the canonical location), and + falls back to ``tts.`` so users who followed the built-in layout + still work. 
Returns an empty dict when the provider is not declared. + """ + providers = _get_provider_section(tts_config, "providers") + section = providers.get(name) if isinstance(providers, dict) else None + if isinstance(section, dict): + return section + # Back-compat: allow ``tts.`` for user-declared providers too, + # but only when the name is not a built-in (so a user's ``tts.openai`` + # block still means the OpenAI provider, not a custom command). + if name.lower() not in BUILTIN_TTS_PROVIDERS: + legacy = _get_provider_section(tts_config, name) + if legacy: + return legacy + return {} + + +def _is_command_provider_config(config: Dict[str, Any]) -> bool: + """Return True when *config* declares a command-type provider.""" + if not isinstance(config, dict): + return False + ptype = str(config.get("type") or "").strip().lower() + if ptype and ptype != "command": + return False + command = config.get("command") + return isinstance(command, str) and bool(command.strip()) + + +def _resolve_command_provider_config( + provider: str, + tts_config: Dict[str, Any], +) -> Optional[Dict[str, Any]]: + """Return the provider config if *provider* resolves to a command type. + + Built-in provider names are rejected (they have native handlers). + Returns None when the name is a built-in, unknown, or not a command + type. 
+ """ + if not provider: + return None + key = provider.lower().strip() + if key in BUILTIN_TTS_PROVIDERS: + return None + config = _get_named_provider_config(tts_config, key) + if _is_command_provider_config(config): + return config + return None + + +def _iter_command_providers(tts_config: Dict[str, Any]): + """Yield (name, config) pairs for every declared command-type provider.""" + if not isinstance(tts_config, dict): + return + providers = _get_provider_section(tts_config, "providers") + for name, cfg in (providers or {}).items(): + if isinstance(name, str) and name.lower() not in BUILTIN_TTS_PROVIDERS: + if _is_command_provider_config(cfg): + yield name, cfg + + +def _get_command_tts_timeout(config: Dict[str, Any]) -> float: + """Return timeout in seconds, falling back when invalid.""" + raw = config.get("timeout", config.get("timeout_seconds", DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS)) + try: + value = float(raw) + except (TypeError, ValueError): + return float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS) + if value <= 0: + return float(DEFAULT_COMMAND_TTS_TIMEOUT_SECONDS) + return value + + +def _get_command_tts_output_format( + config: Dict[str, Any], + output_path: Optional[str] = None, +) -> str: + """Return the validated output format (mp3/wav/ogg/flac).""" + if output_path: + suffix = Path(output_path).suffix.lower().strip().lstrip(".") + if suffix in COMMAND_TTS_OUTPUT_FORMATS: + return suffix + raw = ( + config.get("format") + or config.get("output_format") + or DEFAULT_COMMAND_TTS_OUTPUT_FORMAT + ) + fmt = str(raw).lower().strip().lstrip(".") + return fmt if fmt in COMMAND_TTS_OUTPUT_FORMATS else DEFAULT_COMMAND_TTS_OUTPUT_FORMAT + + +def _is_command_tts_voice_compatible(config: Dict[str, Any]) -> bool: + """Return True only when the user explicitly opted in to voice delivery.""" + value = config.get("voice_compatible", False) + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +def 
_shell_quote_context(command_template: str, position: int) -> Optional[str]: + """Return the shell quote character active right before *position*. + + Returns ``"'"`` / ``'"'`` when inside a single- / double-quoted region + of the template, ``None`` for bare context. + """ + quote: Optional[str] = None + escaped = False + i = 0 + while i < position: + char = command_template[i] + if quote == "'": + if char == "'": + quote = None + elif quote == '"': + if escaped: + escaped = False + elif char == "\\": + escaped = True + elif char == '"': + quote = None + else: + if char == "'": + quote = "'" + elif char == '"': + quote = '"' + elif char == "\\": + i += 1 + i += 1 + return quote + + +def _quote_command_tts_placeholder(value: str, quote_context: Optional[str]) -> str: + """Quote a placeholder value for its position in a shell command template.""" + if quote_context == "'": + return value.replace("'", r"'\''") + if quote_context == '"': + return ( + value + .replace("\\", "\\\\") + .replace('"', r'\"') + .replace("$", r"\$") + .replace("`", r"\`") + ) + if os.name == "nt": + return subprocess.list2cmdline([value]) + return shlex.quote(value) + + +def _render_command_tts_template( + command_template: str, + placeholders: Dict[str, str], +) -> str: + """Replace supported placeholders while preserving ``{{`` / ``}}``.""" + names = "|".join(re.escape(name) for name in placeholders) + pattern = re.compile( + rf"(?{names})\}}\}}|\{{(?P{names})\}})" + ) + replacements: list[tuple[str, str]] = [] + + def replace_match(match: re.Match[str]) -> str: + name = match.group("double") or match.group("single") + token = f"__HERMES_TTS_PLACEHOLDER_{len(replacements)}__" + replacements.append(( + token, + _quote_command_tts_placeholder( + placeholders[name], + _shell_quote_context(command_template, match.start()), + ), + )) + return token + + rendered = pattern.sub(replace_match, command_template) + rendered = rendered.replace("{{", "{").replace("}}", "}") + for token, value in 
replacements: + rendered = rendered.replace(token, value) + return rendered + + +def _terminate_command_tts_process_tree(proc: subprocess.Popen) -> None: + """Best-effort termination of a shell process and all of its children.""" + if proc.poll() is not None: + return + + if os.name == "nt": + try: + subprocess.run( + ["taskkill", "/F", "/T", "/PID", str(proc.pid)], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + timeout=5, + ) + except Exception: + proc.kill() + return + + try: + os.killpg(proc.pid, signal.SIGTERM) + except ProcessLookupError: + return + except Exception: + proc.terminate() + + try: + proc.wait(timeout=2) + return + except subprocess.TimeoutExpired: + pass + + try: + os.killpg(proc.pid, signal.SIGKILL) + except ProcessLookupError: + return + except Exception: + proc.kill() + + +def _run_command_tts(command: str, timeout: float) -> subprocess.CompletedProcess: + """Run a command-provider shell command with process-tree timeout cleanup.""" + popen_kwargs: Dict[str, Any] = { + "shell": True, + "stdout": subprocess.PIPE, + "stderr": subprocess.PIPE, + "text": True, + } + if os.name == "nt": + popen_kwargs["creationflags"] = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0) + else: + popen_kwargs["start_new_session"] = True + + proc = subprocess.Popen(command, **popen_kwargs) + try: + stdout, stderr = proc.communicate(timeout=timeout) + except subprocess.TimeoutExpired as exc: + _terminate_command_tts_process_tree(proc) + try: + stdout, stderr = proc.communicate(timeout=1) + except Exception: + stdout = getattr(exc, "output", None) + stderr = getattr(exc, "stderr", None) + raise subprocess.TimeoutExpired( + command, + timeout, + output=stdout, + stderr=stderr, + ) from exc + + if proc.returncode: + raise subprocess.CalledProcessError( + proc.returncode, + command, + output=stdout, + stderr=stderr, + ) + return subprocess.CompletedProcess(command, proc.returncode, stdout, stderr) + + +def _configured_command_tts_output_path(path: Path, 
config: Dict[str, Any]) -> Path: + """Return an output path whose extension matches the provider's output_format.""" + fmt = _get_command_tts_output_format(config) + return path.with_suffix(f".{fmt}") + + +def _generate_command_tts( + text: str, + output_path: str, + provider_name: str, + config: Dict[str, Any], + tts_config: Dict[str, Any], +) -> str: + """Generate speech by running a user-configured shell command. + + Returns the absolute path of the audio file the command wrote. + Raises ``ValueError`` when the provider config is invalid, and + ``RuntimeError`` for timeouts / non-zero exits / empty output. + """ + command_template = str(config.get("command") or "").strip() + if not command_template: + raise ValueError( + f"tts.providers.{provider_name}.command is not configured" + ) + + output = Path(output_path).expanduser() + output.parent.mkdir(parents=True, exist_ok=True) + if output.exists(): + output.unlink() + + timeout = _get_command_tts_timeout(config) + output_format = _get_command_tts_output_format(config, str(output)) + speed = config.get("speed", tts_config.get("speed", "")) + + with tempfile.TemporaryDirectory() as tmpdir: + text_path = Path(tmpdir) / "input.txt" + text_path.write_text(text, encoding="utf-8") + + placeholders = { + "input_path": str(text_path), + "text_path": str(text_path), + "output_path": str(output), + "format": output_format, + "voice": str(config.get("voice", "")), + "model": str(config.get("model", "")), + "speed": str(speed), + } + command = _render_command_tts_template(command_template, placeholders) + + try: + _run_command_tts(command, timeout) + except subprocess.TimeoutExpired as exc: + raise RuntimeError( + f"TTS provider '{provider_name}' timed out after {timeout:g}s" + ) from exc + except subprocess.CalledProcessError as exc: + detail_parts = [] + if exc.stderr: + detail_parts.append(f"stderr: {exc.stderr.strip()}") + if exc.stdout: + detail_parts.append(f"stdout: {exc.stdout.strip()}") + detail = "; 
".join(detail_parts) or "no command output" + raise RuntimeError( + f"TTS provider '{provider_name}' exited with code " + f"{exc.returncode}: {detail}" + ) from exc + + if not output.exists() or output.stat().st_size <= 0: + raise RuntimeError( + f"TTS provider '{provider_name}' produced no output at {output}" + ) + return str(output) + + +def _has_any_command_tts_provider(tts_config: Optional[Dict[str, Any]] = None) -> bool: + """Return True when any command-type TTS provider is configured.""" + if tts_config is None: + tts_config = _load_tts_config() + for _name, _cfg in _iter_command_providers(tts_config): + return True + return False + + # =========================================================================== # ffmpeg Opus conversion (Edge TTS MP3 -> OGG Opus for Telegram) # =========================================================================== @@ -312,7 +773,7 @@ def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any] Returns: Path to the saved audio file. """ - api_key = os.getenv("ELEVENLABS_API_KEY", "") + api_key = (get_env_value("ELEVENLABS_API_KEY") or "") if not api_key: raise ValueError("ELEVENLABS_API_KEY not set. Get one at https://elevenlabs.io/") @@ -406,7 +867,7 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) - """ import requests - api_key = os.getenv("XAI_API_KEY", "").strip() + api_key = (get_env_value("XAI_API_KEY") or "").strip() if not api_key: raise ValueError("XAI_API_KEY not set. 
Get one at https://console.x.ai/") @@ -417,7 +878,7 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) - bit_rate = int(xai_config.get("bit_rate", DEFAULT_XAI_BIT_RATE)) base_url = str( xai_config.get("base_url") - or os.getenv("XAI_BASE_URL") + or get_env_value("XAI_BASE_URL") or DEFAULT_XAI_BASE_URL ).strip().rstrip("/") @@ -479,7 +940,7 @@ def _generate_minimax_tts(text: str, output_path: str, tts_config: Dict[str, Any """ import requests - api_key = os.getenv("MINIMAX_API_KEY", "") + api_key = (get_env_value("MINIMAX_API_KEY") or "") if not api_key: raise ValueError("MINIMAX_API_KEY not set. Get one at https://platform.minimax.io/") @@ -556,7 +1017,7 @@ def _generate_mistral_tts(text: str, output_path: str, tts_config: Dict[str, Any and writes the raw bytes to *output_path*. Supports native Opus output for Telegram voice bubbles. """ - api_key = os.getenv("MISTRAL_API_KEY", "") + api_key = (get_env_value("MISTRAL_API_KEY") or "") if not api_key: raise ValueError("MISTRAL_API_KEY not set. Get one at https://console.mistral.ai/") @@ -651,7 +1112,7 @@ def _generate_gemini_tts(text: str, output_path: str, tts_config: Dict[str, Any] """ import requests - api_key = (os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY") or "").strip() + api_key = (get_env_value("GEMINI_API_KEY") or get_env_value("GOOGLE_API_KEY") or "").strip() if not api_key: raise ValueError( "GEMINI_API_KEY not set. 
Get one at https://aistudio.google.com/app/apikey" @@ -662,7 +1123,7 @@ def _generate_gemini_tts(text: str, output_path: str, tts_config: Dict[str, Any] voice = str(gemini_config.get("voice", DEFAULT_GEMINI_TTS_VOICE)).strip() or DEFAULT_GEMINI_TTS_VOICE base_url = str( gemini_config.get("base_url") - or os.getenv("GEMINI_BASE_URL") + or get_env_value("GEMINI_BASE_URL") or DEFAULT_GEMINI_TTS_BASE_URL ).strip().rstrip("/") @@ -848,6 +1309,167 @@ def _generate_neutts(text: str, output_path: str, tts_config: Dict[str, Any]) -> return output_path +# =========================================================================== +# Provider: Piper (local, neural VITS, 44 languages) +# =========================================================================== + +# Module-level cache for Piper voice instances. Voices are keyed on their +# absolute .onnx model path so switching voices doesn't invalidate older +# cached voices. +_piper_voice_cache: Dict[str, Any] = {} + + +def _check_piper_available() -> bool: + """Check whether the piper-tts package is importable.""" + try: + import importlib.util + return importlib.util.find_spec("piper") is not None + except Exception: + return False + + +def _get_piper_voices_dir() -> Path: + """Return the directory where Hermes caches Piper voice models. + + Resolves to ``~/.hermes/cache/piper-voices/`` under the active + HERMES_HOME so voice downloads follow profile boundaries. + """ + from hermes_constants import get_hermes_dir + root = Path(get_hermes_dir("cache/piper-voices", "piper_voices_cache")) + root.mkdir(parents=True, exist_ok=True) + return root + + +def _resolve_piper_voice_path(voice: str, download_dir: Path) -> str: + """Resolve *voice* (a model name or path) to a concrete .onnx file path. 
+ + Accepts any of: + - Absolute / expanded path to an .onnx file the user already has + - A voice *name* like ``en_US-lessac-medium`` (downloads to + ``download_dir`` on first use via ``python -m piper.download_voices``) + + Raises RuntimeError if the model can't be located or downloaded. + """ + if not voice: + voice = DEFAULT_PIPER_VOICE + + # Case 1: user gave a direct file path. + candidate = Path(voice).expanduser() + if candidate.suffix.lower() == ".onnx" and candidate.exists(): + return str(candidate) + + # Case 2: user gave a voice *name*. See if it's already downloaded. + cached = download_dir / f"{voice}.onnx" + if cached.exists() and (download_dir / f"{voice}.onnx.json").exists(): + return str(cached) + + # Case 3: download the voice. piper ships a download helper module. + import sys as _sys + logger.info("[Piper] Downloading voice '%s' to %s (first use)", voice, download_dir) + try: + result = subprocess.run( + [_sys.executable, "-m", "piper.download_voices", voice, + "--download-dir", str(download_dir)], + capture_output=True, text=True, timeout=300, + ) + except subprocess.TimeoutExpired as exc: + raise RuntimeError( + f"Piper voice download timed out after 300s for '{voice}'" + ) from exc + + if result.returncode != 0: + stderr = (result.stderr or "").strip() or "no stderr output" + raise RuntimeError( + f"Piper voice download failed for '{voice}': {stderr[:400]}" + ) + + if not cached.exists(): + raise RuntimeError( + f"Piper voice download completed but {cached} is missing — " + f"check voice name (see: https://github.com/OHF-Voice/piper1-gpl/" + f"blob/main/docs/VOICES.md)" + ) + return str(cached) + + +def _generate_piper_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str: + """Generate speech using the local Piper engine. + + Loads the voice model once per process (cached by absolute path) and + writes a WAV file. Caller is responsible for converting to MP3/Opus + via ffmpeg when a different output format is required. 
+ """ + PiperVoice = _import_piper() + import wave + + piper_config = tts_config.get("piper", {}) if isinstance(tts_config, dict) else {} + voice_name = piper_config.get("voice") or DEFAULT_PIPER_VOICE + download_dir = Path(piper_config.get("voices_dir") or _get_piper_voices_dir()).expanduser() + download_dir.mkdir(parents=True, exist_ok=True) + use_cuda = bool(piper_config.get("use_cuda", False)) + + model_path = _resolve_piper_voice_path(voice_name, download_dir) + + cache_key = f"{model_path}::cuda={use_cuda}" + global _piper_voice_cache + if cache_key not in _piper_voice_cache: + logger.info("[Piper] Loading voice: %s", model_path) + _piper_voice_cache[cache_key] = PiperVoice.load(model_path, use_cuda=use_cuda) + logger.info("[Piper] Voice loaded") + voice = _piper_voice_cache[cache_key] + + # Optional synthesis knobs — only pass a SynthesisConfig when at least + # one advanced knob is configured, so we don't depend on a newer Piper + # version than the user's installed one unless we need to. + syn_config = None + has_advanced = any( + k in piper_config + for k in ("length_scale", "noise_scale", "noise_w_scale", "volume", "normalize_audio") + ) + if has_advanced: + try: + from piper import SynthesisConfig # type: ignore + syn_config = SynthesisConfig( + length_scale=float(piper_config.get("length_scale", 1.0)), + noise_scale=float(piper_config.get("noise_scale", 0.667)), + noise_w_scale=float(piper_config.get("noise_w_scale", 0.8)), + volume=float(piper_config.get("volume", 1.0)), + normalize_audio=bool(piper_config.get("normalize_audio", True)), + ) + except ImportError: + logger.warning( + "[Piper] SynthesisConfig not available in this piper-tts " + "version — advanced knobs ignored" + ) + + # Piper outputs WAV. Caller handles downstream MP3/Opus conversion. 
+ wav_path = output_path + if not output_path.endswith(".wav"): + wav_path = output_path.rsplit(".", 1)[0] + ".wav" + + with wave.open(wav_path, "wb") as wav_file: + if syn_config is not None: + voice.synthesize_wav(text, wav_file, syn_config=syn_config) + else: + voice.synthesize_wav(text, wav_file) + + # Convert to desired format if caller requested mp3/ogg + if wav_path != output_path: + ffmpeg = shutil.which("ffmpeg") + if ffmpeg: + conv_cmd = [ffmpeg, "-i", wav_path, "-y", "-loglevel", "error", output_path] + subprocess.run(conv_cmd, check=True, timeout=30) + try: + os.remove(wav_path) + except OSError: + pass + else: + # No ffmpeg — keep WAV and return that path + os.rename(wav_path, output_path) + + return output_path + + # =========================================================================== # Provider: KittenTTS (local, lightweight) # =========================================================================== @@ -941,6 +1563,12 @@ def text_to_speech_tool( tts_config = _load_tts_config() provider = _get_provider(tts_config) + # User-declared command provider (type: command under tts.providers.) + # resolves BEFORE the built-in dispatch. Built-in names short-circuit here + # so a user's ``tts.providers.openai.command`` can't override the real + # OpenAI handler. + command_provider_config = _resolve_command_provider_config(provider, tts_config) + # Truncate very long text with a warning. The cap is per-provider # (OpenAI 4096, xAI 15k, MiniMax 10k, ElevenLabs model-aware, etc.). max_len = _resolve_max_text_length(provider, tts_config) @@ -962,13 +1590,23 @@ def text_to_speech_tool( # Determine output path if output_path: file_path = Path(output_path).expanduser() + if command_provider_config is not None: + # Respect caller-supplied path but align the extension with the + # provider's configured output_format so the command writes to a + # path the caller actually expects. 
+ file_path = _configured_command_tts_output_path( + file_path, command_provider_config + ) else: timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") out_dir = Path(DEFAULT_OUTPUT_DIR) out_dir.mkdir(parents=True, exist_ok=True) + if command_provider_config is not None: + fmt = _get_command_tts_output_format(command_provider_config) + file_path = out_dir / f"tts_{timestamp}.{fmt}" # Use .ogg for Telegram with providers that support native Opus output, # otherwise fall back to .mp3 (Edge TTS will attempt ffmpeg conversion later). - if want_opus and provider in ("openai", "elevenlabs", "mistral", "gemini"): + elif want_opus and provider in ("openai", "elevenlabs", "mistral", "gemini"): file_path = out_dir / f"tts_{timestamp}.ogg" else: file_path = out_dir / f"tts_{timestamp}.mp3" @@ -979,7 +1617,15 @@ def text_to_speech_tool( try: # Generate audio with the configured provider - if provider == "elevenlabs": + if command_provider_config is not None: + logger.info( + "Generating speech with command TTS provider '%s'...", provider, + ) + file_str = _generate_command_tts( + text, file_str, provider, command_provider_config, tts_config, + ) + + elif provider == "elevenlabs": try: _import_elevenlabs() except ImportError: @@ -1048,6 +1694,19 @@ def text_to_speech_tool( logger.info("Generating speech with KittenTTS (local, ~25MB)...") _generate_kittentts(text, file_str, tts_config) + elif provider == "piper": + try: + _import_piper() + except ImportError: + return json.dumps({ + "success": False, + "error": "Piper provider selected but 'piper-tts' package not installed. 
" + "Run 'hermes tools' and select Piper under TTS, or install manually: " + "pip install piper-tts", + }, ensure_ascii=False) + logger.info("Generating speech with Piper (local)...") + _generate_piper_tts(text, file_str, tts_config) + else: # Default: Edge TTS (free), with NeuTTS as local fallback edge_available = True @@ -1087,7 +1746,17 @@ def text_to_speech_tool( # Try Opus conversion for Telegram compatibility # Edge TTS outputs MP3, NeuTTS/KittenTTS output WAV — all need ffmpeg conversion voice_compatible = False - if provider in ("edge", "neutts", "minimax", "xai", "kittentts") and not file_str.endswith(".ogg"): + if command_provider_config is not None: + # Command providers are documents by default. Voice-bubble + # delivery only kicks in when the user explicitly opts in + # via ``voice_compatible: true`` in their provider config. + if _is_command_tts_voice_compatible(command_provider_config): + if not file_str.endswith(".ogg"): + opus_path = _convert_to_opus(file_str) + if opus_path: + file_str = opus_path + voice_compatible = file_str.endswith(".ogg") + elif provider in ("edge", "neutts", "minimax", "xai", "kittentts", "piper") and not file_str.endswith(".ogg"): opus_path = _convert_to_opus(file_str) if opus_path: file_str = opus_path @@ -1136,11 +1805,15 @@ def check_tts_requirements() -> bool: Check if at least one TTS provider is available. Edge TTS needs no API key and is the default, so if the package - is installed, TTS is available. + is installed, TTS is available. A user-declared command provider + also satisfies the requirement. Returns: bool: True if at least one provider can work. """ + # Any configured command provider counts as available. 
+ if _has_any_command_tts_provider(): + return True try: _import_edge_tts() return True @@ -1148,7 +1821,7 @@ def check_tts_requirements() -> bool: pass try: _import_elevenlabs() - if os.getenv("ELEVENLABS_API_KEY"): + if get_env_value("ELEVENLABS_API_KEY"): return True except ImportError: pass @@ -1158,15 +1831,15 @@ def check_tts_requirements() -> bool: return True except ImportError: pass - if os.getenv("MINIMAX_API_KEY"): + if get_env_value("MINIMAX_API_KEY"): return True - if os.getenv("XAI_API_KEY"): + if get_env_value("XAI_API_KEY"): return True - if os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"): + if get_env_value("GEMINI_API_KEY") or get_env_value("GOOGLE_API_KEY"): return True try: _import_mistral_client() - if os.getenv("MISTRAL_API_KEY"): + if get_env_value("MISTRAL_API_KEY"): return True except ImportError: pass @@ -1174,6 +1847,8 @@ def check_tts_requirements() -> bool: return True if _check_kittentts_available(): return True + if _check_piper_available(): + return True return False @@ -1278,7 +1953,7 @@ def stream_tts_to_speaker( {**tts_config, "elevenlabs": {**el_config, "model_id": model_id}}, ) - api_key = os.getenv("ELEVENLABS_API_KEY", "") + api_key = (get_env_value("ELEVENLABS_API_KEY") or "") if not api_key: logger.warning("ELEVENLABS_API_KEY not set; streaming TTS audio disabled") else: @@ -1464,13 +2139,14 @@ if __name__ == "__main__": print("\nProvider availability:") print(f" Edge TTS: {'installed' if _check(_import_edge_tts, 'edge') else 'not installed (pip install edge-tts)'}") print(f" ElevenLabs: {'installed' if _check(_import_elevenlabs, 'el') else 'not installed (pip install elevenlabs)'}") - print(f" API Key: {'set' if os.getenv('ELEVENLABS_API_KEY') else 'not set'}") + print(f" API Key: {'set' if get_env_value('ELEVENLABS_API_KEY') else 'not set'}") print(f" OpenAI: {'installed' if _check(_import_openai_client, 'oai') else 'not installed'}") print( " API Key: " f"{'set' if resolve_openai_audio_api_key() else 'not set 
(VOICE_TOOLS_OPENAI_KEY or OPENAI_API_KEY)'}" ) - print(f" MiniMax: {'API key set' if os.getenv('MINIMAX_API_KEY') else 'not set (MINIMAX_API_KEY)'}") + print(f" MiniMax: {'API key set' if get_env_value('MINIMAX_API_KEY') else 'not set (MINIMAX_API_KEY)'}") + print(f" Piper: {'installed' if _check_piper_available() else 'not installed (pip install piper-tts)'}") print(f" ffmpeg: {'✅ found' if _has_ffmpeg() else '❌ not found (needed for Telegram Opus)'}") print(f"\n Output dir: {DEFAULT_OUTPUT_DIR}") @@ -1486,7 +2162,7 @@ from tools.registry import registry, tool_error TTS_SCHEMA = { "name": "text_to_speech", - "description": "Convert text to speech audio. Returns a MEDIA: path that the platform delivers as a voice message. On Telegram it plays as a voice bubble, on Discord/WhatsApp as an audio attachment. In CLI mode, saves to ~/voice-memos/. Voice and provider are user-configured, not model-selected.", + "description": "Convert text to speech audio. Returns a MEDIA: path that the platform delivers as native audio. Compatible providers render as a voice bubble on Telegram; otherwise audio is sent as a regular attachment. In CLI mode, saves to ~/voice-memos/. 
Voice and provider are user-configured (built-in providers like edge/openai or custom command providers under tts.providers.), not model-selected.", "parameters": { "type": "object", "properties": { diff --git a/tools/vision_tools.py b/tools/vision_tools.py index 32a1a68938..233b737272 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -38,6 +38,7 @@ from typing import Any, Awaitable, Dict, Optional from urllib.parse import urlparse import httpx from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning +from hermes_constants import get_hermes_dir from tools.debug_helpers import DebugSession from tools.website_policy import check_website_access @@ -56,9 +57,9 @@ def _resolve_download_timeout() -> float: except ValueError: pass try: - from hermes_cli.config import load_config + from hermes_cli.config import cfg_get, load_config cfg = load_config() - val = cfg.get("auxiliary", {}).get("vision", {}).get("download_timeout") + val = cfg_get(cfg, "auxiliary", "vision", "download_timeout") if val is not None: return float(val) except Exception: @@ -435,7 +436,7 @@ async def vision_analyze_tool( Exception: If download fails, analysis fails, or API key is not set Note: - - For URLs, temporary images are stored in ./temp_vision_images/ and cleaned up + - For URLs, temporary images are stored under $HERMES_HOME/cache/vision/ and cleaned up - For local file paths, the file is used directly and NOT deleted - Supports common image formats (JPEG, PNG, GIF, WebP, etc.) 
""" @@ -483,7 +484,7 @@ async def vision_analyze_tool( if blocked: raise PermissionError(blocked["message"]) logger.info("Downloading image from URL...") - temp_dir = Path("./temp_vision_images") + temp_dir = get_hermes_dir("cache/vision", "temp_vision_images") temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg" await _download_image(image_url, temp_image_path) should_cleanup = True @@ -555,9 +556,9 @@ async def vision_analyze_tool( vision_timeout = 120.0 vision_temperature = 0.1 try: - from hermes_cli.config import load_config + from hermes_cli.config import cfg_get, load_config _cfg = load_config() - _vision_cfg = _cfg.get("auxiliary", {}).get("vision", {}) + _vision_cfg = cfg_get(_cfg, "auxiliary", "vision", default={}) _vt = _vision_cfg.get("timeout") if _vt is not None: vision_timeout = float(_vt) diff --git a/toolsets.py b/toolsets.py index a444713f57..ee067aa13e 100644 --- a/toolsets.py +++ b/toolsets.py @@ -564,6 +564,27 @@ def resolve_toolset(name: str, visited: Set[str] = None) -> List[str]: # Get toolset definition toolset = get_toolset(name) if not toolset: + # Auto-generate a toolset for plugin platforms (hermes-). + # Gives them _HERMES_CORE_TOOLS plus any tools the plugin registered + # into a toolset matching the platform name. 
+ if name.startswith("hermes-"): + platform_name = name[len("hermes-"):] + try: + from gateway.platform_registry import platform_registry + if platform_registry.is_registered(platform_name): + plugin_tools = set(_HERMES_CORE_TOOLS) + try: + from tools.registry import registry + plugin_tools.update( + e.name for e in registry._tools.values() + if e.toolset == platform_name + ) + except Exception: + pass + return list(plugin_tools) + except Exception: + pass + return [] # Collect direct tools diff --git a/tui_gateway/entry.py b/tui_gateway/entry.py index 70fc851820..d3be53a6c4 100644 --- a/tui_gateway/entry.py +++ b/tui_gateway/entry.py @@ -165,11 +165,29 @@ def main(): # a model_tools.py module-level side effect; moved to explicit # startup calls to avoid freezing the gateway's loop on lazy import # (#16856). + # + # Cold-start guard: importing ``tools.mcp_tool`` transitively pulls the + # full MCP SDK (mcp, pydantic, httpx, jsonschema, starlette parsers — + # ~200ms on macOS), which runs on the TUI's critical path before + # ``gateway.ready`` can be emitted. The overwhelming majority of users + # have no ``mcp_servers`` configured, in which case every byte of that + # import is wasted. Check the config first (cheap — it's already been + # loaded once by ``_config_mtime`` elsewhere) and only pay the import + # cost when there's actually MCP work to do. try: - from tools.mcp_tool import discover_mcp_tools - discover_mcp_tools() + from hermes_cli.config import read_raw_config + _mcp_servers = (read_raw_config() or {}).get("mcp_servers") + _has_mcp_servers = isinstance(_mcp_servers, dict) and len(_mcp_servers) > 0 except Exception: - pass + # Be conservative: if we can't decide, fall back to the old + # behaviour and let the discovery path handle its own errors. 
+ _has_mcp_servers = True + if _has_mcp_servers: + try: + from tools.mcp_tool import discover_mcp_tools + discover_mcp_tools() + except Exception: + pass if not write_json({ "jsonrpc": "2.0", diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 4072e49647..82e9750611 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -128,11 +128,13 @@ _cfg_path = None _SLASH_WORKER_TIMEOUT_S = max( 5.0, float(os.environ.get("HERMES_TUI_SLASH_TIMEOUT_S", "45") or 45) ) +_DETAIL_SECTION_NAMES = ("thinking", "tools", "subagents", "activity") +_DETAIL_MODES = frozenset({"hidden", "collapsed", "expanded"}) # ── Async RPC dispatch (#12546) ────────────────────────────────────── # A handful of handlers block the dispatcher loop in entry.py for seconds # to minutes (slash.exec, cli.exec, shell.exec, session.resume, -# session.branch, skills.manage). While they're running, inbound RPCs — +# session.branch, session.compress, skills.manage). While they're running, inbound RPCs — # notably approval.respond and session.interrupt — sit unread in the # stdin pipe. We route only those slow handlers onto a small thread pool; # everything else stays on the main thread so ordering stays sane for the @@ -143,6 +145,7 @@ _LONG_HANDLERS = frozenset( "browser.manage", "cli.exec", "session.branch", + "session.compress", "session.resume", "shell.exec", "skills.manage", @@ -465,6 +468,119 @@ def _wait_agent(session: dict, rid: str, timeout: float = 30.0) -> dict | None: return _err(rid, 5032, err) if err else None +def _start_agent_build(sid: str, session: dict) -> None: + """Start building the real AIAgent for a TUI session, once. + + Classic `hermes` shows the prompt before constructing AIAgent; the TUI used + to eagerly build it during session.create, making startup feel blocked on + tool discovery/model metadata even though the composer was visible. 
Keep + the shell responsive by deferring this work until the first prompt (or any + command that actually needs the agent), while retaining the same ready/error + event contract for the frontend. + """ + ready = session.get("agent_ready") + if ready is None: + return + lock = session.setdefault("agent_build_lock", threading.Lock()) + with lock: + if ready.is_set() or session.get("agent_build_started"): + return + session["agent_build_started"] = True + key = session["session_key"] + + def _build() -> None: + current = _sessions.get(sid) + if current is None: + ready.set() + return + + worker = None + notify_registered = False + try: + tokens = _set_session_context(key) + try: + agent = _make_agent(sid, key) + finally: + _clear_session_context(tokens) + + db = _get_db() + if db is not None: + db.create_session(key, source="tui", model=_resolve_model()) + pending_title = (current.get("pending_title") or "").strip() + if pending_title: + try: + title_applied = db.set_session_title(key, pending_title) + if title_applied: + current["pending_title"] = None + else: + existing_row = db.get_session(key) + existing_title = ((existing_row or {}).get("title") or "").strip() + if existing_title == pending_title: + current["pending_title"] = None + else: + logger.info( + "Pending title still queued for session %s (wanted=%r, current=%r)", + sid, + pending_title, + existing_title, + ) + except ValueError as e: + current["pending_title"] = None + logger.info("Dropping pending title for session %s: %s", sid, e) + except Exception: + logger.warning("Failed to apply pending title for session %s", sid, exc_info=True) + current["agent"] = agent + + try: + worker = _SlashWorker(key, getattr(agent, "model", _resolve_model())) + current["slash_worker"] = worker + except Exception: + pass + + try: + from tools.approval import ( + register_gateway_notify, + load_permanent_allowlist, + ) + register_gateway_notify(key, lambda data: _emit("approval.request", sid, data)) + notify_registered = 
True + load_permanent_allowlist() + except Exception: + pass + + _wire_callbacks(sid) + _notify_session_boundary("on_session_reset", key) + + info = _session_info(agent) + warn = _probe_credentials(agent) + if warn: + info["credential_warning"] = warn + cfg_warn = _probe_config_health(_load_cfg()) + if cfg_warn: + info["config_warning"] = cfg_warn + logger.warning(cfg_warn) + _emit("session.info", sid, info) + except Exception as e: + current["agent_error"] = str(e) + _emit("error", sid, {"message": f"agent init failed: {e}"}) + finally: + if _sessions.get(sid) is not current: + if worker is not None: + try: + worker.close() + except Exception: + pass + if notify_registered: + try: + from tools.approval import unregister_gateway_notify + unregister_gateway_notify(key) + except Exception: + pass + ready.set() + + threading.Thread(target=_build, daemon=True).start() + + def _sess_nowait(params, rid): s = _sessions.get(params.get("session_id") or "") return (s, None) if s else (None, _err(rid, 4001, "session not found")) @@ -472,7 +588,10 @@ def _sess_nowait(params, rid): def _sess(params, rid): s, err = _sess_nowait(params, rid) - return (None, err) if err else (s, _wait_agent(s, rid)) + if err: + return (None, err) + _start_agent_build(params.get("session_id") or "", s) + return (s, _wait_agent(s, rid)) def _normalize_completion_path(path_part: str) -> str: @@ -743,10 +862,100 @@ def _load_tool_progress_mode() -> str: def _load_enabled_toolsets() -> list[str] | None: + explicit = [ + item.strip() + for item in os.environ.get("HERMES_TUI_TOOLSETS", "").split(",") + if item.strip() + ] + cfg = None + fallback_notice = None + + try: + from toolsets import validate_toolset + except Exception: + validate_toolset = None + + if explicit and validate_toolset is not None: + built_in = [name for name in explicit if validate_toolset(name)] + unresolved = [name for name in explicit if name not in built_in] + + if unresolved: + try: + from hermes_cli.plugins import 
discover_plugins + + discover_plugins() + plugin_valid = [name for name in unresolved if validate_toolset(name)] + except Exception: + plugin_valid = [] + + if plugin_valid: + built_in.extend(plugin_valid) + unresolved = [name for name in unresolved if name not in plugin_valid] + + if any(name in {"all", "*"} for name in built_in): + ignored = [name for name in explicit if name not in {"all", "*"}] + if ignored: + print( + "[tui] HERMES_TUI_TOOLSETS=all enables every toolset; " + f"ignoring additional entries: {', '.join(ignored)}", + file=sys.stderr, + flush=True, + ) + return None + + if not unresolved: + return built_in + + mcp_names: set[str] = set() + mcp_disabled: set[str] = set() + try: + from hermes_cli.config import read_raw_config + from hermes_cli.tools_config import _parse_enabled_flag + + raw_cfg = read_raw_config() + mcp_servers = raw_cfg.get("mcp_servers") if isinstance(raw_cfg.get("mcp_servers"), dict) else {} + for name, server_cfg in mcp_servers.items(): + if not isinstance(server_cfg, dict): + continue + if _parse_enabled_flag(server_cfg.get("enabled", True), default=True): + mcp_names.add(str(name)) + else: + mcp_disabled.add(str(name)) + except Exception: + mcp_names = set() + mcp_disabled = set() + + mcp_valid = [name for name in unresolved if name in mcp_names] + disabled = [name for name in unresolved if name in mcp_disabled] + unknown = [name for name in unresolved if name not in mcp_names and name not in mcp_disabled] + valid = built_in + mcp_valid + + if unknown: + print( + f"[tui] ignoring unknown HERMES_TUI_TOOLSETS entries: {', '.join(unknown)}", + file=sys.stderr, + flush=True, + ) + if disabled: + print( + "[tui] ignoring disabled MCP servers in HERMES_TUI_TOOLSETS " + "(set enabled: true in config.yaml to use): " + f"{', '.join(disabled)}", + file=sys.stderr, + flush=True, + ) + + if valid: + return valid + + fallback_notice = "[tui] no valid HERMES_TUI_TOOLSETS entries; using configured CLI toolsets" + try: from hermes_cli.config 
import load_config from hermes_cli.tools_config import _get_platform_tools + cfg = cfg if cfg is not None else load_config() + # Runtime toolset resolution must include default MCP servers so the # agent can actually call them. Passing ``False`` here is the # config-editing variant — used when we need to persist a toolset @@ -754,10 +963,18 @@ def _load_enabled_toolsets() -> list[str] | None: # variant at agent creation time makes MCP tools silently missing # from the TUI. See PR #3252 for the original design split. enabled = sorted( - _get_platform_tools(load_config(), "cli", include_default_mcp_servers=True) + _get_platform_tools(cfg, "cli", include_default_mcp_servers=True) ) + if fallback_notice is not None: + print(fallback_notice, file=sys.stderr, flush=True) return enabled or None except Exception: + if fallback_notice is not None: + print( + "[tui] no valid HERMES_TUI_TOOLSETS entries and configured CLI toolsets could not be loaded; enabling all toolsets", + file=sys.stderr, + flush=True, + ) return None @@ -885,24 +1102,111 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict: def _compress_session_history( - session: dict, focus_topic: str | None = None + session: dict, + focus_topic: str | None = None, + approx_tokens: int | None = None, + before_messages: list | None = None, + history_version: int | None = None, ) -> tuple[int, dict]: from agent.model_metadata import estimate_messages_tokens_rough agent = session["agent"] - history = list(session.get("history", [])) + # Snapshot history under the lock so the LLM-bound compression call + # below does NOT hold history_lock for the duration of the request — + # otherwise other handlers acquiring the lock (prompt.submit etc.) + # block on the dispatcher loop while compaction runs. 
+ if before_messages is None or history_version is None: + with session["history_lock"]: + before_messages = list(session.get("history", [])) + history_version = int(session.get("history_version", 0)) + history = before_messages if len(history) < 4: - return 0, _get_usage(agent) - approx_tokens = estimate_messages_tokens_rough(history) + usage = _get_usage(agent) + return 0, usage + if approx_tokens is None: + approx_tokens = estimate_messages_tokens_rough(history) + # Pass system_message=None so AIAgent._compress_context rebuilds the + # system prompt cleanly via _build_system_prompt(None). Passing the + # cached prompt (which already contains the agent identity block) + # makes the rebuild append the identity a second time. Mirrors the + # CLI's _manual_compress fix for issue #15281. compressed, _ = agent._compress_context( history, - getattr(agent, "_cached_system_prompt", "") or "", + None, approx_tokens=approx_tokens, focus_topic=focus_topic or None, ) - session["history"] = compressed - session["history_version"] = int(session.get("history_version", 0)) + 1 - return len(history) - len(compressed), _get_usage(agent) + with session["history_lock"]: + if int(session.get("history_version", 0)) != history_version: + # External mutation during compaction — drop the compressed + # result so we don't clobber concurrent edits. + usage = _get_usage(agent) + return 0, usage + session["history"] = compressed + session["history_version"] = history_version + 1 + usage = _get_usage(agent) + return len(history) - len(compressed), usage + + +def _sync_session_key_after_compress(sid: str, session: dict) -> None: + """Re-anchor session_key when AIAgent._compress_context rotates session_id. + + AIAgent._compress_context ends the current SessionDB session and creates + a new continuation session, rotating ``agent.session_id``. The TUI + gateway keeps the gateway-side ``session_key`` separate (used for + approval routing, slash worker init, DB title/history lookups, yolo + state). 
Without this sync, those operations would target the ended + parent session while the agent writes to the new continuation session. + Mirrors HermesCLI._manual_compress's session_id sync. + """ + agent = session.get("agent") + new_session_id = getattr(agent, "session_id", None) or "" + old_key = session.get("session_key", "") or "" + if not new_session_id or new_session_id == old_key: + return + + try: + from tools.approval import ( + disable_session_yolo, + enable_session_yolo, + is_session_yolo_enabled, + register_gateway_notify, + unregister_gateway_notify, + ) + + try: + unregister_gateway_notify(old_key) + except Exception: + pass + session["session_key"] = new_session_id + try: + yolo_was_on = is_session_yolo_enabled(old_key) + except Exception: + yolo_was_on = False + if yolo_was_on: + try: + enable_session_yolo(new_session_id) + disable_session_yolo(old_key) + except Exception: + pass + try: + register_gateway_notify( + new_session_id, + lambda data: _emit("approval.request", sid, data), + ) + except Exception: + pass + except Exception: + # Even if the approval module fails to import, still anchor the + # session_key on the new continuation id so downstream lookups + # don't keep targeting the ended row. + session["session_key"] = new_session_id + + session["pending_title"] = None + try: + _restart_slash_worker(session) + except Exception: + pass def _get_usage(agent) -> dict: @@ -1627,129 +1931,18 @@ def _(rid, params: dict) -> dict: "transport": current_transport() or _stdio_transport, } - def _build() -> None: + # Return the lightweight session immediately so Ink can paint the composer + # + skeleton panel, then build the real AIAgent just after this response is + # flushed. This keeps startup responsive while still hydrating tools/skills + # without requiring the user to submit a first prompt. + def _deferred_build() -> None: session = _sessions.get(sid) - if session is None: - # session.close ran before the build thread got scheduled. 
- ready.set() - return + if session is not None: + _start_agent_build(sid, session) - # Track what we allocate so we can clean up if session.close - # races us to the finish line. session.close pops _sessions[sid] - # unconditionally and tries to close the slash_worker it finds; - # if _build is still mid-construction when close runs, close - # finds slash_worker=None / notify unregistered and returns - # cleanly — leaving us, the build thread, to later install the - # worker + notify on an orphaned session dict. The finally - # block below detects the orphan and cleans up instead of - # leaking a subprocess and a global notify registration. - worker = None - notify_registered = False - try: - tokens = _set_session_context(key) - try: - agent = _make_agent(sid, key) - finally: - _clear_session_context(tokens) - - db = _get_db() - if db is not None: - db.create_session(key, source="tui", model=_resolve_model()) - pending_title = (session.get("pending_title") or "").strip() - if pending_title: - try: - title_applied = db.set_session_title(key, pending_title) - if title_applied: - session["pending_title"] = None - else: - existing_row = db.get_session(key) - existing_title = ( - (existing_row or {}).get("title") or "" - ).strip() - if existing_title == pending_title: - session["pending_title"] = None - else: - logger.info( - "Pending title still queued for session %s (wanted=%r, current=%r)", - sid, - pending_title, - existing_title, - ) - except ValueError as e: - # Queued title can become invalid/duplicate between queue time - # and DB row creation. Drop the queue and log the reason so - # future /title reads don't surface a stuck pending value. 
- session["pending_title"] = None - logger.info( - "Dropping pending title for session %s: %s", - sid, - e, - ) - except Exception: - logger.warning( - "Failed to apply pending title for session %s", - sid, - exc_info=True, - ) - session["agent"] = agent - - try: - worker = _SlashWorker(key, getattr(agent, "model", _resolve_model())) - session["slash_worker"] = worker - except Exception: - pass - - try: - from tools.approval import ( - register_gateway_notify, - load_permanent_allowlist, - ) - - register_gateway_notify( - key, lambda data: _emit("approval.request", sid, data) - ) - notify_registered = True - load_permanent_allowlist() - except Exception: - pass - - _wire_callbacks(sid) - _notify_session_boundary("on_session_reset", key) - - info = _session_info(agent) - warn = _probe_credentials(agent) - if warn: - info["credential_warning"] = warn - cfg_warn = _probe_config_health(_load_cfg()) - if cfg_warn: - info["config_warning"] = cfg_warn - logger.warning(cfg_warn) - _emit("session.info", sid, info) - except Exception as e: - session["agent_error"] = str(e) - _emit("error", sid, {"message": f"agent init failed: {e}"}) - finally: - # Orphan check: if session.close raced us and popped - # _sessions[sid] while we were building, the dict we just - # populated is unreachable. Clean up the subprocess and - # the global notify registration ourselves — session.close - # couldn't see them at the time it ran. 
- if _sessions.get(sid) is not session: - if worker is not None: - try: - worker.close() - except Exception: - pass - if notify_registered: - try: - from tools.approval import unregister_gateway_notify - - unregister_gateway_notify(key) - except Exception: - pass - ready.set() - - threading.Thread(target=_build, daemon=True).start() + build_timer = threading.Timer(0.05, _deferred_build) + build_timer.daemon = True + build_timer.start() return _ok( rid, @@ -1760,6 +1953,7 @@ def _(rid, params: dict) -> dict: "tools": {}, "skills": {}, "cwd": os.getenv("TERMINAL_CWD", os.getcwd()), + "lazy": True, }, }, ) @@ -1899,9 +2093,50 @@ def _(rid, params: dict) -> dict: ) +@method("session.delete") +def _(rid, params: dict) -> dict: + """Delete a stored session and its on-disk transcript files. + + Used by the TUI resume picker (``d`` key) so users can prune old + sessions without dropping to the CLI. Refuses to delete a session + that is currently active in this gateway process — those rows are + still being written to and removing them out from under the live + agent corrupts message ordering and trips FK constraints when the + next message append flushes. + """ + target = params.get("session_id", "") + if not target: + return _err(rid, 4006, "session_id required") + db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5036) + # Block deletion of any session currently bound to a live TUI session + # in this process. The picker hides the active session anyway, but a + # racing caller could still target it. Snapshot via ``list(...)`` + # because ``_sessions`` is mutated by concurrent RPCs on the thread + # pool — iterating the dict directly can raise ``RuntimeError: + # dictionary changed size during iteration``. If even the snapshot + # raises, fail closed (refuse the delete) rather than fail open. 
+ try: + snapshot = list(_sessions.values()) + except Exception as e: + return _err(rid, 5036, f"could not enumerate active sessions: {e}") + active = {s.get("session_key") for s in snapshot if s.get("session_key")} + if target in active: + return _err(rid, 4023, "cannot delete an active session") + sessions_dir = get_hermes_home() / "sessions" + try: + deleted = db.delete_session(target, sessions_dir=sessions_dir) + except Exception as e: + return _err(rid, 5036, f"delete failed: {e}") + if not deleted: + return _err(rid, 4007, "session not found") + return _ok(rid, {"deleted": target}) + + @method("session.title") def _(rid, params: dict) -> dict: - session, err = _sess(params, rid) + session, err = _sess_nowait(params, rid) if err: return err db = _get_db() @@ -1964,13 +2199,16 @@ def _(rid, params: dict) -> dict: @method("session.usage") def _(rid, params: dict) -> dict: - session, err = _sess(params, rid) - return err or _ok(rid, _get_usage(session["agent"])) + session, err = _sess_nowait(params, rid) + if err: + return err + agent = session.get("agent") + return _ok(rid, _get_usage(agent) if agent is not None else {"calls": 0, "input": 0, "output": 0, "total": 0}) @method("session.history") def _(rid, params: dict) -> dict: - session, err = _sess(params, rid) + session, err = _sess_nowait(params, rid) if err: return err history = list(session.get("history", [])) @@ -2028,24 +2266,70 @@ def _(rid, params: dict) -> dict: return _err( rid, 4009, "session busy — /interrupt the current turn before /compress" ) + sid = params.get("session_id", "") + focus_topic = str(params.get("focus_topic", "") or "").strip() try: + from agent.manual_compression_feedback import summarize_manual_compression + from agent.model_metadata import estimate_messages_tokens_rough + with session["history_lock"]: - removed, usage = _compress_session_history( - session, str(params.get("focus_topic", "") or "").strip() - ) - messages = list(session.get("history", [])) - info = 
_session_info(session["agent"]) - _emit("session.info", params.get("session_id", ""), info) - return _ok( - rid, - { - "status": "compressed", - "removed": removed, - "usage": usage, - "info": info, - "messages": messages, - }, + before_messages = list(session.get("history", [])) + history_version = int(session.get("history_version", 0)) + before_count = len(before_messages) + before_tokens = ( + estimate_messages_tokens_rough(before_messages) if before_count else 0 ) + + if before_count >= 4: + focus_suffix = f', focus: "{focus_topic}"' if focus_topic else "" + _status_update( + sid, + "compressing", + f"⠋ compressing {before_count} messages " + f"(~{before_tokens:,} tok){focus_suffix}…", + ) + + try: + removed, usage = _compress_session_history( + session, + focus_topic, + approx_tokens=before_tokens, + before_messages=before_messages, + history_version=history_version, + ) + with session["history_lock"]: + messages = list(session.get("history", [])) + after_count = len(messages) + after_tokens = ( + estimate_messages_tokens_rough(messages) if after_count else 0 + ) + agent = session["agent"] + _sync_session_key_after_compress(sid, session) + summary = summarize_manual_compression( + before_messages, messages, before_tokens, after_tokens + ) + info = _session_info(agent) + _emit("session.info", sid, info) + return _ok( + rid, + { + "status": "compressed", + "removed": removed, + "before_messages": before_count, + "after_messages": after_count, + "before_tokens": before_tokens, + "after_tokens": after_tokens, + "summary": summary, + "usage": usage, + "info": info, + "messages": messages, + }, + ) + finally: + # Always clear the pinned compressing status so the bar + # reverts to neutral whether compaction succeeded, was a + # no-op, or raised. 
+ _status_update(sid, "ready") except Exception as e: return _err(rid, 5005, str(e)) @@ -2437,13 +2721,31 @@ def _(rid, params: dict) -> dict: @method("prompt.submit") def _(rid, params: dict) -> dict: sid, text = params.get("session_id", ""), params.get("text", "") - session, err = _sess(params, rid) + session, err = _sess_nowait(params, rid) if err: return err with session["history_lock"]: if session.get("running"): return _err(rid, 4009, "session busy") session["running"] = True + + _start_agent_build(sid, session) + + def run_after_agent_ready() -> None: + err = _wait_agent(session, rid) + if err: + _emit("error", sid, {"message": err.get("error", {}).get("message", "agent initialization failed")}) + with session["history_lock"]: + session["running"] = False + return + _run_prompt_submit(rid, sid, session, text) + + threading.Thread(target=run_after_agent_ready, daemon=True).start() + return _ok(rid, {"status": "streaming"}) + + +def _run_prompt_submit(rid, sid: str, session: dict, text: Any) -> None: + with session["history_lock"]: history = list(session["history"]) history_version = int(session.get("history_version", 0)) images = list(session.get("attached_images", [])) @@ -2682,7 +2984,6 @@ def _(rid, params: dict) -> dict: session["running"] = False threading.Thread(target=run, daemon=True).start() - return _ok(rid, {"status": "streaming"}) @method("clipboard.paste") @@ -3079,12 +3380,34 @@ def _(rid, params: dict) -> dict: arg = str(value or "").strip().lower() if arg in ("show", "on"): - _write_config_key("display.show_reasoning", True) + cfg = _load_cfg() + display = cfg.get("display") if isinstance(cfg.get("display"), dict) else {} + sections = ( + display.get("sections") + if isinstance(display.get("sections"), dict) + else {} + ) + display["show_reasoning"] = True + sections["thinking"] = "expanded" + display["sections"] = sections + cfg["display"] = display + _save_cfg(cfg) if session: session["show_reasoning"] = True return _ok(rid, {"key": key, 
"value": "show"}) if arg in ("hide", "off"): - _write_config_key("display.show_reasoning", False) + cfg = _load_cfg() + display = cfg.get("display") if isinstance(cfg.get("display"), dict) else {} + sections = ( + display.get("sections") + if isinstance(display.get("sections"), dict) + else {} + ) + display["show_reasoning"] = False + sections["thinking"] = "hidden" + display["sections"] = sections + cfg["display"] = display + _save_cfg(cfg) if session: session["show_reasoning"] = False return _ok(rid, {"key": key, "value": "hide"}) @@ -3101,19 +3424,26 @@ def _(rid, params: dict) -> dict: if key == "details_mode": nv = str(value or "").strip().lower() - allowed_dm = frozenset({"hidden", "collapsed", "expanded"}) - if nv not in allowed_dm: + if nv not in _DETAIL_MODES: return _err(rid, 4002, f"unknown details_mode: {value}") - _write_config_key("display.details_mode", nv) + cfg = _load_cfg() + display = cfg.get("display") if isinstance(cfg.get("display"), dict) else {} + sections = display.get("sections") if isinstance(display.get("sections"), dict) else {} + display["details_mode"] = nv + for section in _DETAIL_SECTION_NAMES: + sections[section] = nv + display["sections"] = sections + cfg["display"] = display + _save_cfg(cfg) return _ok(rid, {"key": key, "value": nv}) if key.startswith("details_mode."): # Per-section override: `details_mode.
` writes to - # `display.sections.
`. Empty value clears the override - # and lets the section fall back to the global details_mode. + # `display.sections.
`. Empty value clears the explicit + # override and lets frontend resolution apply built-in section defaults + # before the global details_mode. section = key.split(".", 1)[1] - allowed_sections = frozenset({"thinking", "tools", "subagents", "activity"}) - if section not in allowed_sections: + if section not in _DETAIL_SECTION_NAMES: return _err(rid, 4002, f"unknown section: {section}") cfg = _load_cfg() @@ -3130,8 +3460,7 @@ def _(rid, params: dict) -> dict: _save_cfg(cfg) return _ok(rid, {"key": key, "value": ""}) - allowed_dm = frozenset({"hidden", "collapsed", "expanded"}) - if nv not in allowed_dm: + if nv not in _DETAIL_MODES: return _err(rid, 4002, f"unknown details_mode: {value}") sections_cfg[section] = nv @@ -3415,6 +3744,40 @@ def _(rid, params: dict) -> dict: def _(rid, params: dict) -> dict: session = _sessions.get(params.get("session_id", "")) try: + # Gate: /reload-mcp invalidates the prompt cache for this session. + # Respect the ``approvals.mcp_reload_confirm`` config toggle — if + # set (default true) AND the caller did not pass ``confirm=true`` + # in params, surface a warning to the transcript instead of just + # reloading silently. Users pass confirm=true either by + # re-invoking after reading the warning, or by setting the + # config key to false permanently. + user_confirm = bool(params.get("confirm", False)) + if not user_confirm: + try: + from hermes_cli.config import load_config as _load_config + _cfg = _load_config() + _approvals = _cfg.get("approvals") if isinstance(_cfg, dict) else None + _confirm_required = True + if isinstance(_approvals, dict): + _confirm_required = bool(_approvals.get("mcp_reload_confirm", True)) + except Exception: + _confirm_required = True + if _confirm_required: + # Return a structured response the Ink client can surface + # as a warning/confirmation without actually reloading yet. 
+ # Ink's ops.ts reads ``status`` and prints ``message`` to + # the transcript; a follow-up invocation with confirm=true + # (or an `always` choice that flips the config) proceeds. + return _ok(rid, { + "status": "confirm_required", + "message": ( + "⚠️ /reload-mcp invalidates the prompt cache (next " + "message re-sends full input tokens). Reply `/reload-mcp " + "now` to proceed, or `/reload-mcp always` to proceed and " + "silence this prompt permanently." + ), + }) + from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools shutdown_mcp_servers() @@ -3424,6 +3787,15 @@ def _(rid, params: dict) -> dict: if hasattr(agent, "refresh_tools"): agent.refresh_tools() _emit("session.info", params.get("session_id", ""), _session_info(agent)) + + # Honor `always=true` by persisting the opt-out to config. + if bool(params.get("always", False)): + try: + from cli import save_config_value as _save_cfg + _save_cfg("approvals.mcp_reload_confirm", False) + except Exception as _exc: + logger.warning("Failed to persist mcp_reload_confirm=false: %s", _exc) + return _ok(rid, {"status": "reloaded"}) except Exception as e: return _err(rid, 5015, str(e)) @@ -4354,8 +4726,8 @@ def _mirror_slash_side_effects(sid: str, session: dict, command: str) -> str: agent.ephemeral_system_prompt = new_prompt or None agent._cached_system_prompt = None elif name == "compress" and agent: - with session["history_lock"]: - _compress_session_history(session, arg) + _compress_session_history(session, arg) + _sync_session_key_after_compress(sid, session) _emit("session.info", sid, _session_info(agent)) elif name == "fast" and agent: mode = arg.lower() diff --git a/ui-tui/package-lock.json b/ui-tui/package-lock.json index 017e9913bd..2efd64fe40 100644 --- a/ui-tui/package-lock.json +++ b/ui-tui/package-lock.json @@ -124,6 +124,7 @@ "integrity": "sha512-CGOfOJqWjg2qW/Mb6zNsDm+u5vFQ8DxXfbM09z69p5Z6+mE1ikP2jUXw+j42Pf1XTYED2Rni5f95npYeuwMDQA==", "dev": true, "license": "MIT", + "peer": true, 
"dependencies": { "@babel/code-frame": "^7.29.0", "@babel/generator": "^7.29.0", @@ -501,31 +502,6 @@ "node": ">=6.9.0" } }, - "node_modules/@emnapi/core": { - "version": "1.10.0", - "resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.10.0.tgz", - "integrity": "sha512-yq6OkJ4p82CAfPl0u9mQebQHKPJkY7WrIuk205cTYnYe+k2Z8YBh11FrbRG/H6ihirqcacOgl2BIO8oyMQLeXw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "@emnapi/wasi-threads": "1.2.1", - "tslib": "^2.4.0" - } - }, - "node_modules/@emnapi/runtime": { - "version": "1.10.0", - "resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.10.0.tgz", - "integrity": "sha512-ewvYlk86xUoGI0zQRNq/mC+16R1QeDlKQy21Ki3oSYXNgLb45GV1P6A0M+/s6nyCuNDqe5VpaY84BzXGwVbwFA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "tslib": "^2.4.0" - } - }, "node_modules/@emnapi/wasi-threads": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz", @@ -1700,6 +1676,7 @@ "integrity": "sha512-+qIYRKdNYJwY3vRCZMdJbPLJAtGjQBudzZzdzwQYkEPQd+PJGixUL5QfvCLDaULoLv+RhT3LDkwEfKaAkgSmNQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~7.19.0" } @@ -1710,6 +1687,7 @@ "integrity": "sha512-ilcTH/UniCkMdtexkoCN0bI7pMcJDvmQFPvuPvmEaYA/NSfFTAgdUSLAoVjaRJm7+6PvcM+q1zYOwS4wTYMF9w==", "devOptional": true, "license": "MIT", + "peer": true, "dependencies": { "csstype": "^3.2.2" } @@ -1720,6 +1698,7 @@ "integrity": "sha512-eSkwoemjo76bdXl2MYqtxg51HNwUSkWfODUOQ3PaTLZGh9uIWWFZIjyjaJnex7wXDu+TRx+ATsnSxdN9YWfRTQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/regexpp": "^4.12.2", "@typescript-eslint/scope-manager": "8.58.1", @@ -1749,6 +1728,7 @@ "integrity": "sha512-gGkiNMPqerb2cJSVcruigx9eHBlLG14fSdPdqMoOcBfh+vvn4iCq2C8MzUB89PrxOXk0y3GZ1yIWb9aOzL93bw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { 
"@typescript-eslint/scope-manager": "8.58.1", "@typescript-eslint/types": "8.58.1", @@ -2066,6 +2046,7 @@ "integrity": "sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -2468,6 +2449,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.10.12", "caniuse-lite": "^1.0.30001782", @@ -3203,6 +3185,7 @@ "integrity": "sha512-XoMjdBOwe/esVgEvLmNsD3IRHkm7fbKIUGvrleloJXUZgDHig2IPWNniv+GwjyJXzuNqVjlr5+4yVUZjycJwfQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -3334,6 +3317,7 @@ "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", "dev": true, "license": "MIT", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } @@ -4242,6 +4226,7 @@ "resolved": "https://registry.npmjs.org/ink-text-input/-/ink-text-input-6.0.0.tgz", "integrity": "sha512-Fw64n7Yha5deb1rHY137zHTAbSTNelUKuB5Kkk2HACXEtwIHBCf9OH2tP/LQ9fRYTl1F0dZgbW0zPnZk6FA9Lw==", "license": "MIT", + "peer": true, "dependencies": { "chalk": "^5.3.0", "type-fest": "^4.18.2" @@ -5678,6 +5663,7 @@ "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -5787,6 +5773,7 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.5.tgz", "integrity": "sha512-llUJLzz1zTUBrskt2pwZgLq59AemifIftw4aB7JxOqf1HY2FDaGDxgwpAPVzHU1kdWabH7FauP4i1oEeer2WCA==", "license": "MIT", + "peer": true, "engines": { "node": ">=0.10.0" } @@ -6611,6 +6598,7 @@ "integrity": "sha512-5C1sg4USs1lfG0GFb2RLXsdpXqBSEhAaA/0kPL01wxzpMqLILNxIxIOKiILz+cdg/pLnOUxFYOR5yhHU666wbw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "~0.27.0", 
"get-tsconfig": "^4.7.5" @@ -6737,6 +6725,7 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -6846,6 +6835,7 @@ "integrity": "sha512-dbU7/iLVa8KZALJyLOBOQ88nOXtNG8vxKuOT4I2mD+Ya70KPceF4IAmDsmU0h1Qsn5bPrvsY9HJstCRh3hG6Uw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "lightningcss": "^1.32.0", "picomatch": "^4.0.4", @@ -7261,6 +7251,7 @@ "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", "dev": true, "license": "MIT", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/ui-tui/packages/hermes-ink/src/ink/colorize.test.ts b/ui-tui/packages/hermes-ink/src/ink/colorize.test.ts new file mode 100644 index 0000000000..814b8d91e5 --- /dev/null +++ b/ui-tui/packages/hermes-ink/src/ink/colorize.test.ts @@ -0,0 +1,60 @@ +import { describe, expect, it } from 'vitest' + +import { + CHALK_USES_RICH_EIGHT_BIT_DOWNGRADE, + richEightBitColorNumber, + shouldUseRichEightBitDowngradeForLegacyAppleTerminal +} from './colorize.js' + +describe('shouldUseRichEightBitDowngradeForLegacyAppleTerminal', () => { + it('memoizes the current process decision for render hot paths', () => { + expect(typeof CHALK_USES_RICH_EIGHT_BIT_DOWNGRADE).toBe('boolean') + }) + + it('uses Rich-compatible 256-color downgrade on legacy Apple Terminal', () => { + expect( + shouldUseRichEightBitDowngradeForLegacyAppleTerminal({ TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv, 2) + ).toBe(true) + }) + + it('normalizes Apple Terminal names before matching', () => { + expect( + shouldUseRichEightBitDowngradeForLegacyAppleTerminal({ TERM_PROGRAM: ' Apple_Terminal ' } as NodeJS.ProcessEnv, 2) + ).toBe(true) + }) + + it('does not rewrite when Apple Terminal advertises truecolor', () => { + expect( + 
shouldUseRichEightBitDowngradeForLegacyAppleTerminal( + { COLORTERM: 'truecolor', TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv, + 3 + ) + ).toBe(false) + }) + + it('does not override explicit color environment choices', () => { + expect( + shouldUseRichEightBitDowngradeForLegacyAppleTerminal( + { FORCE_COLOR: '2', TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv, + 2 + ) + ).toBe(false) + expect( + shouldUseRichEightBitDowngradeForLegacyAppleTerminal( + { HERMES_TUI_TRUECOLOR: '1', TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv, + 3 + ) + ).toBe(false) + }) +}) + +describe('richEightBitColorNumber', () => { + it('matches Rich downgrade output for default Hermes skin colors', () => { + expect(richEightBitColorNumber(0xff, 0xd7, 0x00)).toBe(220) + expect(richEightBitColorNumber(0xff, 0xbf, 0x00)).toBe(214) + expect(richEightBitColorNumber(0xcd, 0x7f, 0x32)).toBe(173) + expect(richEightBitColorNumber(0xb8, 0x86, 0x0b)).toBe(136) + expect(richEightBitColorNumber(0xff, 0xf8, 0xdc)).toBe(230) + }) +}) + diff --git a/ui-tui/packages/hermes-ink/src/ink/colorize.ts b/ui-tui/packages/hermes-ink/src/ink/colorize.ts index 2229f70a97..7a8a57a568 100644 --- a/ui-tui/packages/hermes-ink/src/ink/colorize.ts +++ b/ui-tui/packages/hermes-ink/src/ink/colorize.ts @@ -28,6 +28,39 @@ function boostChalkLevelForXtermJs(): boolean { return false } +export function shouldUseRichEightBitDowngradeForLegacyAppleTerminal( + env: NodeJS.ProcessEnv = process.env, + level = chalk.level +): boolean { + const termProgram = (env.TERM_PROGRAM ?? '').trim() + const truecolorOverride = /^(?:1|true|yes|on)$/i.test((env.HERMES_TUI_TRUECOLOR ?? '').trim()) + const advertisesTruecolor = /^(?:truecolor|24bit)$/i.test((env.COLORTERM ?? 
'').trim()) + + return termProgram === 'Apple_Terminal' && !truecolorOverride && !advertisesTruecolor && !('FORCE_COLOR' in env) && level === 2 +} + +export function richEightBitColorNumber(red: number, green: number, blue: number): number { + const rn = red / 255 + const gn = green / 255 + const bn = blue / 255 + const max = Math.max(rn, gn, bn) + const min = Math.min(rn, gn, bn) + const lightness = (max + min) / 2 + const saturation = max === min ? 0 : lightness > 0.5 ? (max - min) / (2 - max - min) : (max - min) / (max + min) + + if (saturation < 0.15) { + const gray = Math.round(lightness * 25) + + return gray === 0 ? 16 : gray === 25 ? 231 : 231 + gray + } + + const sixRed = red < 95 ? red / 95 : 1 + (red - 95) / 40 + const sixGreen = green < 95 ? green / 95 : 1 + (green - 95) / 40 + const sixBlue = blue < 95 ? blue / 95 : 1 + (blue - 95) / 40 + + return 16 + 36 * Math.round(sixRed) + 6 * Math.round(sixGreen) + Math.round(sixBlue) +} + /** * tmux parses truecolor SGR (\e[48;2;r;g;bm) into its cell buffer correctly, * but its client-side emitter only re-emits truecolor to the outer terminal if @@ -58,15 +91,17 @@ function clampChalkLevelForTmux(): boolean { } // Computed once at module load — terminal/tmux environment doesn't change mid-session. -// Order matters: boost first so the tmux clamp can re-clamp if tmux is running -// inside a VS Code terminal. Exported for debugging — tree-shaken if unused. +// Order matters: boost first; then tmux can still clamp RGB to 256. +// Exported for debugging — tree-shaken if unused. 
export const CHALK_BOOSTED_FOR_XTERMJS = boostChalkLevelForXtermJs() export const CHALK_CLAMPED_FOR_TMUX = clampChalkLevelForTmux() +export const CHALK_USES_RICH_EIGHT_BIT_DOWNGRADE = shouldUseRichEightBitDowngradeForLegacyAppleTerminal() export type ColorType = 'foreground' | 'background' const RGB_REGEX = /^rgb\(\s?(\d+),\s?(\d+),\s?(\d+)\s?\)$/ const ANSI_REGEX = /^ansi256\(\s?(\d+)\s?\)$/ +const HEX_REGEX = /^#[0-9a-fA-F]{6}$/ export const colorize = (str: string, color: string | undefined, type: ColorType): string => { if (!color) { @@ -128,6 +163,16 @@ export const colorize = (str: string, color: string | undefined, type: ColorType } if (color.startsWith('#')) { + if (HEX_REGEX.test(color) && CHALK_USES_RICH_EIGHT_BIT_DOWNGRADE) { + const value = Number.parseInt(color.slice(1), 16) + const red = (value >> 16) & 0xff + const green = (value >> 8) & 0xff + const blue = value & 0xff + const ansi = richEightBitColorNumber(red, green, blue) + + return type === 'foreground' ? chalk.ansi256(ansi)(str) : chalk.bgAnsi256(ansi)(str) + } + return type === 'foreground' ? chalk.hex(color)(str) : chalk.bgHex(color)(str) } @@ -154,6 +199,12 @@ export const colorize = (str: string, color: string | undefined, type: ColorType const secondValue = Number(matches[2]) const thirdValue = Number(matches[3]) + if (CHALK_USES_RICH_EIGHT_BIT_DOWNGRADE) { + const ansi = richEightBitColorNumber(firstValue, secondValue, thirdValue) + + return type === 'foreground' ? chalk.ansi256(ansi)(str) : chalk.bgAnsi256(ansi)(str) + } + return type === 'foreground' ? 
chalk.rgb(firstValue, secondValue, thirdValue)(str) : chalk.bgRgb(firstValue, secondValue, thirdValue)(str) diff --git a/ui-tui/packages/hermes-ink/src/ink/components/Text.test.ts b/ui-tui/packages/hermes-ink/src/ink/components/Text.test.ts index 9869189edd..50628d5380 100644 --- a/ui-tui/packages/hermes-ink/src/ink/components/Text.test.ts +++ b/ui-tui/packages/hermes-ink/src/ink/components/Text.test.ts @@ -1,18 +1,38 @@ import { describe, expect, it } from 'vitest' -import { shouldUseAnsiDim } from './Text.js' +import { dimColorFallback, shouldUseAnsiDim } from './Text.js' describe('shouldUseAnsiDim', () => { it('disables ANSI dim on VTE terminals by default', () => { expect(shouldUseAnsiDim({ VTE_VERSION: '7603' } as NodeJS.ProcessEnv)).toBe(false) }) + it('disables ANSI dim on Apple Terminal by default', () => { + expect(shouldUseAnsiDim({ TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv)).toBe(false) + }) + it('keeps ANSI dim enabled elsewhere by default', () => { expect(shouldUseAnsiDim({ TERM: 'xterm-256color' } as NodeJS.ProcessEnv)).toBe(true) }) it('honors explicit env override', () => { expect(shouldUseAnsiDim({ HERMES_TUI_DIM: '1', VTE_VERSION: '7603' } as NodeJS.ProcessEnv)).toBe(true) + expect(shouldUseAnsiDim({ HERMES_TUI_DIM: '1', TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv)).toBe(true) expect(shouldUseAnsiDim({ HERMES_TUI_DIM: '0' } as NodeJS.ProcessEnv)).toBe(false) }) }) + +describe('dimColorFallback', () => { + it('renders Apple Terminal dim as muted gray by default', () => { + expect(dimColorFallback({ TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv)).toBe('#6B7280') + }) + + it('normalizes Apple Terminal names before matching', () => { + expect(dimColorFallback({ TERM_PROGRAM: ' Apple_Terminal ' } as NodeJS.ProcessEnv)).toBe('#6B7280') + }) + + it('does not apply when dim is explicitly configured', () => { + expect(dimColorFallback({ HERMES_TUI_DIM: '1', TERM_PROGRAM: 'Apple_Terminal' } as 
NodeJS.ProcessEnv)).toBeUndefined() + expect(dimColorFallback({ HERMES_TUI_DIM: '0', TERM_PROGRAM: 'Apple_Terminal' } as NodeJS.ProcessEnv)).toBeUndefined() + }) +}) diff --git a/ui-tui/packages/hermes-ink/src/ink/components/Text.tsx b/ui-tui/packages/hermes-ink/src/ink/components/Text.tsx index d6b7fdccd5..4eb4bc7b96 100644 --- a/ui-tui/packages/hermes-ink/src/ink/components/Text.tsx +++ b/ui-tui/packages/hermes-ink/src/ink/components/Text.tsx @@ -6,6 +6,7 @@ import type { Color, Styles } from '../styles.js' const ENV_ON_RE = /^(?:1|true|yes|on)$/i const ENV_OFF_RE = /^(?:0|false|no|off)$/i +const LEGACY_APPLE_DIM_COLOR: Color = '#6B7280' type BaseProps = { /** * Change text color. Accepts a raw color value (rgb, hex, ansi). @@ -76,9 +77,23 @@ export function shouldUseAnsiDim(env: NodeJS.ProcessEnv = process.env): boolean return false } + if ((env.TERM_PROGRAM ?? '').trim() === 'Apple_Terminal') { + return false + } + return !env.VTE_VERSION } +export function dimColorFallback(env: NodeJS.ProcessEnv = process.env): Color | undefined { + const override = (env.HERMES_TUI_DIM ?? '').trim() + + if (ENV_ON_RE.test(override) || ENV_OFF_RE.test(override)) { + return undefined + } + + return (env.TERM_PROGRAM ?? '').trim() === 'Apple_Terminal' ? LEGACY_APPLE_DIM_COLOR : undefined +} + const memoizedStylesForWrap: Record, Styles> = { wrap: { flexGrow: 0, @@ -161,6 +176,7 @@ export default function Text(t0: Props) { const inverse = t4 === undefined ? false : t4 const wrap = t5 === undefined ? 'wrap' : t5 const effectiveDim = dim && shouldUseAnsiDim() + const effectiveColor = dim && !effectiveDim ? (color ?? 
dimColorFallback()) : color if (children === undefined || children === null) { return null @@ -168,11 +184,11 @@ export default function Text(t0: Props) { let t6 - if ($[0] !== color) { - t6 = color && { - color + if ($[0] !== effectiveColor) { + t6 = effectiveColor && { + color: effectiveColor } - $[0] = color + $[0] = effectiveColor $[1] = t6 } else { t6 = $[1] diff --git a/ui-tui/packages/hermes-ink/src/ink/parse-keypress.test.ts b/ui-tui/packages/hermes-ink/src/ink/parse-keypress.test.ts index 58745b8c40..89c842c015 100644 --- a/ui-tui/packages/hermes-ink/src/ink/parse-keypress.test.ts +++ b/ui-tui/packages/hermes-ink/src/ink/parse-keypress.test.ts @@ -39,3 +39,60 @@ describe('parseMultipleKeypresses bracketed paste recovery', () => { expect(state.pasteBuffer).toBe('') }) }) + +describe('mouse wheel modifier decoding', () => { + // SGR mouse format: ESC [ < button ; col ; row M + // Wheel up = 64 (0x40), wheel down = 65 (0x41). + // Modifier bits: shift = 0x04, meta = 0x08, ctrl = 0x10. 
+ const sgrWheel = (button: number) => `\x1b[<${button};10;10M` + + it('plain wheel up has no modifiers', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x40)) + + expect(key).toMatchObject({ name: 'wheelup', ctrl: false, meta: false, shift: false }) + }) + + it('plain wheel down has no modifiers', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x41)) + + expect(key).toMatchObject({ name: 'wheeldown', ctrl: false, meta: false, shift: false }) + }) + + it('decodes meta (Alt/Option) on wheel up', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x40 | 0x08)) + + expect(key).toMatchObject({ name: 'wheelup', ctrl: false, meta: true, shift: false }) + }) + + it('decodes meta (Alt/Option) on wheel down', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x41 | 0x08)) + + expect(key).toMatchObject({ name: 'wheeldown', ctrl: false, meta: true, shift: false }) + }) + + it('decodes ctrl on wheel events', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x40 | 0x10)) + + expect(key).toMatchObject({ name: 'wheelup', ctrl: true, meta: false, shift: false }) + }) + + it('decodes shift on wheel events', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x41 | 0x04)) + + expect(key).toMatchObject({ name: 'wheeldown', ctrl: false, meta: false, shift: true }) + }) + + it('decodes combined modifiers', () => { + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, sgrWheel(0x40 | 0x08 | 0x10)) + + expect(key).toMatchObject({ name: 'wheelup', ctrl: true, meta: true, shift: false }) + }) + + it('decodes meta on legacy X10 wheel encoding', () => { + // X10: ESC [ M Cb Cx Cy where each byte is value+32. 
+ const x10 = `\x1b[M${String.fromCharCode(0x40 + 0x08 + 32)}${String.fromCharCode(10 + 32)}${String.fromCharCode(10 + 32)}` + const [[key]] = parseMultipleKeypresses(INITIAL_STATE, x10) + + expect(key).toMatchObject({ name: 'wheelup', meta: true }) + }) +}) diff --git a/ui-tui/packages/hermes-ink/src/ink/parse-keypress.ts b/ui-tui/packages/hermes-ink/src/ink/parse-keypress.ts index 56976d8a84..3a21aa2646 100644 --- a/ui-tui/packages/hermes-ink/src/ink/parse-keypress.ts +++ b/ui-tui/packages/hermes-ink/src/ink/parse-keypress.ts @@ -697,16 +697,17 @@ function parseKeypress(s: string = ''): ParsedKey { // never reach here. Mask with 0x43 (bits 6+1+0) to check wheel-flag // + direction while ignoring modifier bits (Shift=0x04, Meta=0x08, // Ctrl=0x10) — modified wheel events (e.g. Ctrl+scroll, button=80) - // should still be recognized as wheelup/wheeldown. + // should still be recognized as wheelup/wheeldown. Preserve those + // modifier bits for callers that bind modified wheel gestures. 
if ((match = SGR_MOUSE_RE.exec(s))) { const button = parseInt(match[1]!, 10) if ((button & 0x43) === 0x40) { - return createNavKey(s, 'wheelup', false) + return createWheelKey(s, 'wheelup', button) } if ((button & 0x43) === 0x41) { - return createNavKey(s, 'wheeldown', false) + return createWheelKey(s, 'wheeldown', button) } // Shouldn't reach here (parseMouseEvent catches non-wheel) but be safe @@ -722,11 +723,11 @@ function parseKeypress(s: string = ''): ParsedKey { const button = s.charCodeAt(3) - 32 if ((button & 0x43) === 0x40) { - return createNavKey(s, 'wheelup', false) + return createWheelKey(s, 'wheelup', button) } if ((button & 0x43) === 0x41) { - return createNavKey(s, 'wheeldown', false) + return createWheelKey(s, 'wheeldown', button) } return createNavKey(s, 'mouse', false) @@ -834,3 +835,19 @@ function createNavKey(s: string, name: string, ctrl: boolean): ParsedKey { isPasted: false } } + +function createWheelKey(s: string, name: 'wheelup' | 'wheeldown', button: number): ParsedKey { + return { + kind: 'key', + name, + ctrl: !!(button & 0x10), + meta: !!(button & 0x08), + shift: !!(button & 0x04), + option: false, + super: false, + fn: false, + sequence: s, + raw: s, + isPasted: false + } +} diff --git a/ui-tui/src/__tests__/createGatewayEventHandler.test.ts b/ui-tui/src/__tests__/createGatewayEventHandler.test.ts index 378f873b4b..1729f0c273 100644 --- a/ui-tui/src/__tests__/createGatewayEventHandler.test.ts +++ b/ui-tui/src/__tests__/createGatewayEventHandler.test.ts @@ -119,6 +119,19 @@ describe('createGatewayEventHandler', () => { expect(getTurnState().todos).toEqual(todos) }) + it('prints compaction progress status into the transcript', () => { + const appended: Msg[] = [] + const ctx = buildCtx(appended) + const onEvent = createGatewayEventHandler(ctx) + + onEvent({ + payload: { kind: 'compressing', text: 'compressing 968 messages (~123,400 tok)…' }, + type: 'status.update' + } as any) + + expect(ctx.system.sys).toHaveBeenCalledWith('compressing 
968 messages (~123,400 tok)…') + }) + it('clears the visible todo list when the todo tool returns an empty list', () => { const appended: Msg[] = [] const todos = [{ content: 'Boil water', id: 'boil', status: 'in_progress' }] diff --git a/ui-tui/src/__tests__/createSlashHandler.test.ts b/ui-tui/src/__tests__/createSlashHandler.test.ts index 3ec340b8a2..e8c50c05d2 100644 --- a/ui-tui/src/__tests__/createSlashHandler.test.ts +++ b/ui-tui/src/__tests__/createSlashHandler.test.ts @@ -76,6 +76,45 @@ describe('createSlashHandler', () => { }) }) + it('applies /reasoning hide to the thinking section immediately', async () => { + patchUiState({ sections: { thinking: 'expanded' }, showReasoning: true, sid: 'sid-abc' }) + const ctx = buildCtx({ + gateway: { + ...buildGateway(), + rpc: vi.fn(() => Promise.resolve({ value: 'hide' })) + } + }) + + expect(createSlashHandler(ctx)('/reasoning hide')).toBe(true) + + await vi.waitFor(() => { + expect(getUiState().showReasoning).toBe(false) + expect(getUiState().sections.thinking).toBe('hidden') + }) + expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', { + key: 'reasoning', + session_id: 'sid-abc', + value: 'hide' + }) + }) + + it('applies /reasoning show to the thinking section immediately', async () => { + patchUiState({ sections: { thinking: 'hidden' }, showReasoning: false, sid: 'sid-abc' }) + const ctx = buildCtx({ + gateway: { + ...buildGateway(), + rpc: vi.fn(() => Promise.resolve({ value: 'show' })) + } + }) + + expect(createSlashHandler(ctx)('/reasoning show')).toBe(true) + + await vi.waitFor(() => { + expect(getUiState().showReasoning).toBe(true) + expect(getUiState().sections.thinking).toBe('expanded') + }) + }) + it('opens the skills hub locally for bare /skills', () => { const ctx = buildCtx() @@ -141,6 +180,12 @@ describe('createSlashHandler', () => { expect(createSlashHandler(ctx)('/details toggle')).toBe(true) expect(getUiState().detailsMode).toBe('expanded') 
expect(getUiState().detailsModeCommandOverride).toBe(true) + expect(getUiState().sections).toEqual({ + thinking: 'expanded', + tools: 'expanded', + subagents: 'expanded', + activity: 'expanded' + }) expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', { key: 'details_mode', value: 'expanded' diff --git a/ui-tui/src/__tests__/forceTruecolor.test.ts b/ui-tui/src/__tests__/forceTruecolor.test.ts index 7cbf46d2b6..4d97832815 100644 --- a/ui-tui/src/__tests__/forceTruecolor.test.ts +++ b/ui-tui/src/__tests__/forceTruecolor.test.ts @@ -1,6 +1,7 @@ import { describe, expect, it } from 'vitest' -const ENV_KEYS = ['COLORTERM', 'FORCE_COLOR', 'HERMES_TUI_TRUECOLOR', 'NO_COLOR'] as const +const ENV_KEYS = ['COLORTERM', 'FORCE_COLOR', 'HERMES_TUI_TRUECOLOR', 'NO_COLOR', 'TERM', 'TERM_PROGRAM'] as const +let importId = 0 async function withCleanEnv(setup: () => void, body: () => Promise) { const saved: Record = {} @@ -25,11 +26,39 @@ async function withCleanEnv(setup: () => void, body: () => Promise) { } describe('forceTruecolor', () => { - it('sets COLORTERM=truecolor and FORCE_COLOR=3 when unset', async () => { + it('does not force truecolor by default', async () => { await withCleanEnv( () => {}, async () => { - await import('../lib/forceTruecolor.js?t=' + Date.now()) + await import('../lib/forceTruecolor.js?t=default-' + importId++) + expect(process.env.COLORTERM).toBeUndefined() + expect(process.env.FORCE_COLOR).toBeUndefined() + } + ) + }) + + it('does not infer truecolor from Apple Terminal on pre-Tahoe macOS', async () => { + await withCleanEnv( + () => { + process.env.TERM_PROGRAM = 'Apple_Terminal' + process.env.TERM = 'xterm-256color' + }, + async () => { + const mod = await import('../lib/forceTruecolor.js?t=apple-' + importId++) + expect(mod.shouldForceTruecolor({ TERM_PROGRAM: 'Apple_Terminal' })).toBe(false) + expect(process.env.COLORTERM).toBeUndefined() + expect(process.env.FORCE_COLOR).toBeUndefined() + } + ) + }) + + it('sets COLORTERM=truecolor and 
FORCE_COLOR=3 when explicitly enabled', async () => { + await withCleanEnv( + () => { + process.env.HERMES_TUI_TRUECOLOR = '1' + }, + async () => { + await import('../lib/forceTruecolor.js?t=enabled-' + importId++) expect(process.env.COLORTERM).toBe('truecolor') expect(process.env.FORCE_COLOR).toBe('3') } @@ -40,9 +69,10 @@ describe('forceTruecolor', () => { await withCleanEnv( () => { process.env.HERMES_TUI_TRUECOLOR = '0' + process.env.TERM_PROGRAM = 'Apple_Terminal' }, async () => { - await import('../lib/forceTruecolor.js?t=optout-' + Date.now()) + await import('../lib/forceTruecolor.js?t=optout-' + importId++) expect(process.env.COLORTERM).toBeUndefined() expect(process.env.FORCE_COLOR).toBeUndefined() } @@ -53,12 +83,41 @@ describe('forceTruecolor', () => { await withCleanEnv( () => { process.env.NO_COLOR = '1' + process.env.HERMES_TUI_TRUECOLOR = '1' }, async () => { - await import('../lib/forceTruecolor.js?t=no-color-' + Date.now()) + await import('../lib/forceTruecolor.js?t=no-color-' + importId++) expect(process.env.COLORTERM).toBeUndefined() expect(process.env.FORCE_COLOR).toBeUndefined() } ) }) + + it('respects existing FORCE_COLOR unless Hermes truecolor is explicit', async () => { + await withCleanEnv( + () => { + process.env.FORCE_COLOR = '' + }, + async () => { + const mod = await import('../lib/forceTruecolor.js?t=force-color-' + importId++) + expect(mod.shouldForceTruecolor(process.env)).toBe(false) + expect(process.env.COLORTERM).toBeUndefined() + expect(process.env.FORCE_COLOR).toBe('') + } + ) + }) + + it('lets explicit Hermes truecolor override existing FORCE_COLOR', async () => { + await withCleanEnv( + () => { + process.env.FORCE_COLOR = '0' + process.env.HERMES_TUI_TRUECOLOR = '1' + }, + async () => { + await import('../lib/forceTruecolor.js?t=explicit-force-' + importId++) + expect(process.env.COLORTERM).toBe('truecolor') + expect(process.env.FORCE_COLOR).toBe('3') + } + ) + }) }) diff --git a/ui-tui/src/__tests__/markdown.test.ts 
b/ui-tui/src/__tests__/markdown.test.ts index 0e95ba6c0f..a415668f46 100644 --- a/ui-tui/src/__tests__/markdown.test.ts +++ b/ui-tui/src/__tests__/markdown.test.ts @@ -61,6 +61,66 @@ describe('stripInlineMarkup', () => { expect(stripInlineMarkup('Yay ~! nice work ~!')).toBe('Yay ~! nice work ~!') expect(stripInlineMarkup('H~2~O and CO~2~')).toBe('H_2O and CO_2') }) + + it('strips inline math delimiters but keeps the formula text', () => { + expect(stripInlineMarkup('$\\mathbb{Z}$ is a ring')).toBe('\\mathbb{Z} is a ring') + expect(stripInlineMarkup('see \\(a + b\\) ok')).toBe('see a + b ok') + }) +}) + +describe('INLINE_RE inline math', () => { + it('matches single-dollar math and beats emphasis at the same start', () => { + // Without math handling, `*b*` would have matched as italics and + // corrupted the formula. With math added to INLINE_RE, the leftmost + // match at column 0 (`$P=a*b*c$`) wins. + expect(matches('$P=a*b*c$')).toEqual(['$P=a*b*c$']) + expect(matches('see $\\mathbb{Z}$ here')).toEqual(['$\\mathbb{Z}$']) + }) + + it('does not match currency-style prose', () => { + expect(matches('it costs $5 and $10')).toEqual([]) + expect(matches('paid $5')).toEqual([]) + }) + + it('does not let inline math swallow a $$ display fence', () => { + // `$$x$$` is a display block, not two abutting inline-math spans. + expect(matches('$$x$$')).toEqual([]) + }) + + it('matches \\(...\\) inline math', () => { + expect(matches('foo \\(x + y\\) bar')).toEqual(['\\(x + y\\)']) + }) + + it('does not corrupt subscripts/superscripts inside math', () => { + // `_n` and `^r` are markdown emphasis/superscript markers in prose, but + // inside a `$...$` span the entire formula is captured as a single + // inline-math token so the inner regexes never see those characters. 
+ expect(matches('$P=a_n x^n + a_0$')).toEqual(['$P=a_n x^n + a_0$']) + expect(matches('$\\beta_1,\\dots,\\beta_r$')).toEqual(['$\\beta_1,\\dots,\\beta_r$']) + }) + + it('places math content in the correct capture group (regression: m[16] is bare URL)', () => { + // When `m[16]` was the bare URL group AND the inline-math `$...$` + // group simultaneously (because the bare URL pattern lacked its own + // capturing parens), MdInline rendered `$\\mathbb{R}$` as an + // underlined autolink instead of italic amber math. Lock down the + // numbering: math goes in m[17] / m[18], URLs go in m[16]. + const url = [...'see https://example.com here'.matchAll(INLINE_RE)][0]! + const dollarMath = [...'$\\mathbb{R}$'.matchAll(INLINE_RE)][0]! + const parenMath = [...'\\(\\pi\\)'.matchAll(INLINE_RE)][0]! + + expect(url[16]).toBe('https://example.com') + expect(url[17]).toBeUndefined() + expect(url[18]).toBeUndefined() + + expect(dollarMath[16]).toBeUndefined() + expect(dollarMath[17]).toBe('\\mathbb{R}') + expect(dollarMath[18]).toBeUndefined() + + expect(parenMath[16]).toBeUndefined() + expect(parenMath[17]).toBeUndefined() + expect(parenMath[18]).toBe('\\pi') + }) }) describe('protocol sentinels', () => { diff --git a/ui-tui/src/__tests__/mathUnicode.test.ts b/ui-tui/src/__tests__/mathUnicode.test.ts new file mode 100644 index 0000000000..fb9f029aa8 --- /dev/null +++ b/ui-tui/src/__tests__/mathUnicode.test.ts @@ -0,0 +1,293 @@ +import { describe, expect, it } from 'vitest' + +import { BOX_CLOSE, BOX_OPEN, BOX_RE, texToUnicode } from '../lib/mathUnicode.js' + +const stripBox = (s: string) => s.replace(BOX_RE, '$1') + +describe('texToUnicode — symbols', () => { + it('substitutes lowercase Greek', () => { + expect(texToUnicode('\\alpha + \\beta + \\pi')).toBe('α + β + π') + expect(texToUnicode('\\omega')).toBe('ω') + }) + + it('substitutes uppercase Greek', () => { + expect(texToUnicode('\\Sigma \\Omega \\Pi')).toBe('Σ Ω Π') + }) + + it('substitutes set theory and logic operators', 
() => { + expect(texToUnicode('A \\cup B \\cap C')).toBe('A ∪ B ∩ C') + expect(texToUnicode('\\forall x \\in \\emptyset')).toBe('∀ x ∈ ∅') + expect(texToUnicode('p \\implies q \\iff r')).toBe('p ⟹ q ⟺ r') + }) + + it('substitutes relations and arrows', () => { + expect(texToUnicode('a \\le b \\ge c \\ne d')).toBe('a ≤ b ≥ c ≠ d') + expect(texToUnicode('f: A \\to B')).toBe('f: A → B') + }) + + it('uses longest-match-first so \\leq beats \\le', () => { + expect(texToUnicode('\\leq')).toBe('≤') + }) + + it('preserves unknown commands that share a prefix with known ones', () => { + // `\leqq` is a real LaTeX command (≦) we don't have in our table. + // The word-boundary lookahead prevents `\le` from matching, so the + // whole thing is preserved verbatim — much better than `≤qq`. + expect(texToUnicode('\\leqq')).toBe('\\leqq') + }) + + it('refuses to substitute a partial command (word boundary)', () => { + expect(texToUnicode('\\alphabet')).toBe('\\alphabet') + expect(texToUnicode('\\pin')).toBe('\\pin') + }) +}) + +describe('texToUnicode — blackboard / calligraphic / fraktur', () => { + it('renders \\mathbb capitals', () => { + expect(texToUnicode('\\mathbb{R}')).toBe('ℝ') + expect(texToUnicode('\\mathbb{N} \\subset \\mathbb{Z} \\subset \\mathbb{Q} \\subset \\mathbb{R}')).toBe('ℕ ⊂ ℤ ⊂ ℚ ⊂ ℝ') + }) + + it('renders \\mathcal and \\mathfrak', () => { + expect(texToUnicode('\\mathcal{F} \\subset \\mathfrak{A}')).toBe('ℱ ⊂ 𝔄') + }) + + it('preserves \\mathbb{...} when argument is multi-letter or non-letter', () => { + expect(texToUnicode('\\mathbb{NN}')).toBe('\\mathbb{NN}') + expect(texToUnicode('\\mathbb{1}')).toBe('\\mathbb{1}') + }) + + it('strips \\mathbf / \\mathit / \\mathrm / \\text wrappers (no Unicode bold/italic in monospace)', () => { + expect(texToUnicode('\\mathbf{x}')).toBe('x') + expect(texToUnicode('\\text{if } x > 0')).toBe('if x > 0') + expect(texToUnicode('\\operatorname{rank}(A)')).toBe('rank(A)') + }) +}) + +describe('texToUnicode — sub / 
superscripts', () => { + it('converts simple superscripts', () => { + expect(texToUnicode('x^2 + y^2')).toBe('x² + y²') + expect(texToUnicode('e^{n}')).toBe('eⁿ') + }) + + it('converts simple subscripts', () => { + expect(texToUnicode('a_1 + a_2 + a_n')).toBe('a₁ + a₂ + aₙ') + expect(texToUnicode('x_{0}')).toBe('x₀') + }) + + it('converts mixed-content scripts when every glyph has a Unicode form', () => { + // `+`, digits, and lowercase letters all have superscript glyphs, + // so `n+1` → `ⁿ⁺¹`. Comma has no subscript form, so `i,j` falls + // back to `_(i,j)` (parens) rather than partially substituting — + // parens read as ordinary grouping while braces look like leftover + // unrendered LaTeX. + expect(texToUnicode('x^{n+1}')).toBe('xⁿ⁺¹') + expect(texToUnicode('a_{i,j}')).toBe('a_(i,j)') + }) + + it('uses parens (not braces) when the body has Greek with no superscript form', () => { + // π has no Unicode superscript, so `e^{i\pi}` after symbol pass is + // `e^{iπ}` and the script fallback emits `e^(iπ)` — much more + // readable than the LaTeX-looking `e^{iπ}`. + expect(texToUnicode('e^{i\\pi}')).toBe('e^(iπ)') + }) + + it('strips braces on script fallback when body collapses to a single char', () => { + // `^{\infty}` → symbol pass produces `^{∞}` → convertScript can't + // find ∞ in SUPERSCRIPT, but the body is one char so we drop the + // braces and emit `^∞` (much more readable than `^{∞}`). 
+ expect(texToUnicode('e^{\\infty}')).toBe('e^∞') + }) + + it('handles a real-world sum', () => { + expect(texToUnicode('\\sum_{n=0}^{\\infty} \\frac{1}{n!}')).toBe('∑ₙ₌₀^∞ 1/n!') + }) +}) + +describe('texToUnicode — fractions', () => { + it('collapses \\frac to a/b', () => { + expect(texToUnicode('\\frac{1}{2}')).toBe('1/2') + expect(texToUnicode('\\frac{a}{b}')).toBe('a/b') + }) + + it('parenthesises multi-token numerator / denominator', () => { + expect(texToUnicode('\\frac{n+1}{2}')).toBe('(n+1)/2') + expect(texToUnicode('\\frac{a + b}{c - d}')).toBe('(a + b)/(c - d)') + }) + + it('handles nested fractions', () => { + expect(texToUnicode('\\frac{1}{\\frac{1}{x}}')).toBe('1/(1/x)') + }) + + it('handles braces inside numerator / denominator (regression: regex \\frac couldn\'t)', () => { + // The regex-only `\frac` matcher used `[^{}]*` for each arg, which + // failed the moment a numerator contained its own braces (here the + // `{p-1}` from a superscript). The balanced-brace parser handles it. 
+ expect(texToUnicode('\\frac{|t|^{p-1}|P(t)|^p}{(p-1)!}')).toBe('(|t|ᵖ⁻¹|P(t)|ᵖ)/((p-1)!)') + }) + + it('preserves \\frac when arguments are malformed', () => { + expect(texToUnicode('\\frac{a}')).toBe('\\frac{a}') + expect(texToUnicode('\\fraction{a}{b}')).toBe('\\fraction{a}{b}') + }) +}) + +describe('texToUnicode — typography no-ops', () => { + it('strips \\displaystyle / \\textstyle / \\scriptstyle / \\scriptscriptstyle', () => { + expect(texToUnicode('\\displaystyle\\sum_{i=1}^n x_i')).toBe('∑ᵢ₌₁ⁿ xᵢ') + expect(texToUnicode('f(x) = \\displaystyle \\frac{1}{2}')).toBe('f(x) = 1/2') + expect(texToUnicode('\\textstyle x + y')).toBe('x + y') + }) + + it('strips \\limits / \\nolimits which only affect bound positioning', () => { + expect(texToUnicode('\\sum\\limits_{k=1}^n a_k')).toBe('∑ₖ₌₁ⁿ aₖ') + expect(texToUnicode('\\int\\nolimits_0^1 f(x) dx')).toBe('∫₀¹ f(x) dx') + }) + + it('does not eat letter-continuation commands like \\limit_inf', () => { + // The `(?![A-Za-z])` lookahead protects hypothetical commands that + // start with `\limit` / `\display` / etc. The bare names are stripped + // but anything longer is preserved verbatim. 
+ expect(texToUnicode('\\limitinf x')).toBe('\\limitinf x') + }) +}) + +describe('texToUnicode — sizing wrappers', () => { + it('strips \\big / \\Big / \\bigg / \\Bigg before delimiters', () => { + expect(texToUnicode('\\bigl[ x \\bigr]')).toBe('[ x ]') + expect(texToUnicode('\\Big( y \\Big)')).toBe('( y )') + expect(texToUnicode('\\bigg| z \\bigg|')).toBe('| z |') + expect(texToUnicode('\\Biggl\\{ a \\Biggr\\}')).toBe('{ a }') + }) + + it('does not eat \\bigtriangleup or other letter-continuations', () => { + expect(texToUnicode('A \\bigtriangleup B')).toBe('A \\bigtriangleup B') + }) +}) + +describe('texToUnicode — modular arithmetic and tags', () => { + it('renders \\pmod{p} as " (mod p)"', () => { + expect(texToUnicode('a \\equiv b \\pmod{p}')).toBe('a ≡ b (mod p)') + }) + + it('renders \\bmod / \\mod inline', () => { + expect(texToUnicode('a \\bmod n')).toBe('a mod n') + }) + + it('collapses \\tag{n} to " (n)"', () => { + expect(texToUnicode('x = y \\tag{24}')).toBe('x = y (24)') + }) +}) + +describe('texToUnicode — newly added symbols', () => { + it('renders \\nmid, \\blacksquare, \\qed', () => { + expect(texToUnicode('p \\nmid q')).toBe('p ∤ q') + expect(texToUnicode('Therefore \\blacksquare')).toBe('Therefore ■') + expect(texToUnicode('done \\qed')).toBe('done ∎') + }) +}) + +describe('texToUnicode — \\boxed / \\fbox', () => { + // `\boxed` produces non-printable U+0001 / U+0002 sentinels around its + // content so the markdown renderer can apply highlight styling. These + // tests assert both the sentinel form and the human-readable + // strip-fallback (BOX_RE). 
+ it('wraps simple boxed content in BOX_OPEN/BOX_CLOSE sentinels', () => { + expect(texToUnicode('\\boxed{x = 0}')).toBe(`${BOX_OPEN}x = 0${BOX_CLOSE}`) + expect(stripBox(texToUnicode('\\boxed{x = 0}'))).toBe('x = 0') + expect(stripBox(texToUnicode('\\fbox{answer}'))).toBe('answer') + }) + + it('handles boxed expressions with nested braces (regression: regex couldn\'t)', () => { + // A `[^{}]*` regex would stop at the first `{` inside the body. The + // balanced-brace parser walks past it. + expect(stripBox(texToUnicode('\\boxed{x^{n+1}}'))).toBe('xⁿ⁺¹') + expect(stripBox(texToUnicode('\\boxed{\\frac{a}{b}}'))).toBe('a/b') + }) + + it('handles real-world boxed final answer', () => { + expect(stripBox(texToUnicode('\\boxed{J = -\\sum_{k=0}^n a_k F(k)}'))).toBe('J = -∑ₖ₌₀ⁿ aₖ F(k)') + }) + + it('preserves \\boxed without a brace argument', () => { + expect(texToUnicode('\\boxed something')).toBe('\\boxed something') + }) +}) + +describe('texToUnicode — combining marks', () => { + it('applies \\overline / \\bar / \\hat / \\vec / \\tilde', () => { + expect(texToUnicode('\\overline{x}')).toBe('x\u0305') + expect(texToUnicode('\\hat{y}')).toBe('y\u0302') + expect(texToUnicode('\\vec{v}')).toBe('v\u20D7') + }) +}) + +describe('texToUnicode — left/right delimiters', () => { + it('strips \\left and \\right keeping the delimiter character', () => { + expect(texToUnicode('\\left( x + y \\right)')).toBe('( x + y )') + expect(texToUnicode('\\left| x \\right|')).toBe('| x |') + }) + + it('handles escaped delimiters \\left\\{ ... \\right\\}', () => { + expect(texToUnicode('\\left\\{p/q \\mid q \\neq 0\\right\\}')).toBe('{p/q ∣ q ≠ 0}') + }) + + it('handles named delimiters via \\left\\langle / \\right\\rangle', () => { + expect(texToUnicode('\\left\\langle u, v \\right\\rangle')).toBe('⟨ u, v ⟩') + }) + + it('drops \\left. and \\right. (which are explicit "no delimiter")', () => { + expect(texToUnicode('\\left. 
f \\right|')).toBe(' f |') + }) + + it('preserves \\leftarrow / \\rightarrow (word boundary blocks the strip)', () => { + expect(texToUnicode('A \\leftarrow B \\rightarrow C')).toBe('A ← B → C') + }) +}) + +describe('texToUnicode — labelled arrows', () => { + it('renders \\xrightarrow{label} as ─label→', () => { + expect(texToUnicode('a \\xrightarrow{x=1} b')).toBe('a ─x=1→ b') + }) + + it('renders \\xleftarrow{label} as ←label─', () => { + expect(texToUnicode('a \\xleftarrow{n} b')).toBe('a ←n─ b') + }) + + it('still applies symbol substitution inside the label', () => { + expect(texToUnicode('a \\xrightarrow{n \\to \\infty} L')).toBe('a ─n → ∞→ L') + }) +}) + +describe('texToUnicode — punctuation commands without lookahead', () => { + it('substitutes \\{ even when immediately followed by a letter', () => { + // Regression: with a global `(?![A-Za-z])` lookahead, `\{p` refused + // to substitute (because `p` is a letter) and rendered as `\{p`. + expect(texToUnicode('\\{p, q\\}')).toBe('{p, q}') + }) + + it('substitutes thin-space \\, before a letter', () => { + expect(texToUnicode('a\\,b')).toBe('a b') + }) +}) + +describe('texToUnicode — round-trip realism', () => { + it('renders a typical model-emitted formula', () => { + expect(texToUnicode('\\alpha \\in \\mathbb{R}, \\alpha \\notin \\mathbb{Q}')).toBe('α ∈ ℝ, α ∉ ℚ') + }) + + it('preserves unknown commands verbatim', () => { + expect(texToUnicode('\\bigtriangleup \\circledast')).toBe('\\bigtriangleup \\circledast') + }) + + it('handles commands without delimiters between', () => { + // Word-boundary lookahead means `\alpha\beta` doesn't accidentally + // match `\alphabeta` as one ungrouped token. 
+ expect(texToUnicode('\\alpha\\beta')).toBe('αβ') + }) + + it('leaves plain text alone', () => { + expect(texToUnicode('hello world')).toBe('hello world') + expect(texToUnicode('')).toBe('') + }) +}) diff --git a/ui-tui/src/__tests__/streamingMarkdown.test.ts b/ui-tui/src/__tests__/streamingMarkdown.test.ts index 5389d56e42..1a825a62f1 100644 --- a/ui-tui/src/__tests__/streamingMarkdown.test.ts +++ b/ui-tui/src/__tests__/streamingMarkdown.test.ts @@ -67,6 +67,48 @@ describe('findStableBoundary', () => { it('handles empty input', () => { expect(findStableBoundary('')).toBe(-1) }) + + it('refuses to split inside an open $$ math block', () => { + // Display math has been opened but not closed; the only blank line + // sits inside the open block, so there's no safe boundary yet. + const text = '$$\nx + y\n\nmore math' + + expect(findStableBoundary(text)).toBe(-1) + }) + + it('allows splitting after a $$ math block closes', () => { + const text = '$$\nx + y = z\n$$\n\nnarration continues' + const idx = findStableBoundary(text) + + expect(text.slice(0, idx)).toBe('$$\nx + y = z\n$$\n\n') + expect(text.slice(idx)).toBe('narration continues') + }) + + it('splits before an open $$ block but not inside', () => { + // Mirror of the existing fenced-code test: prose, then an unclosed + // math block. The only safe boundary is the blank line BEFORE `$$`. + const text = 'intro paragraph\n\n$$\nx + y\n\nmore' + const idx = findStableBoundary(text) + + expect(text.slice(0, idx)).toBe('intro paragraph\n\n') + expect(text.slice(idx).startsWith('$$')).toBe(true) + }) + + it('treats single-line $$x$$ as zero net toggle', () => { + // `$$x = y$$` opens AND closes on one line, so the stable boundary + // after it is allowed. 
+ const text = 'intro\n\n$$x = y$$\n\nnarration' + const idx = findStableBoundary(text) + + expect(text.slice(0, idx)).toBe('intro\n\n$$x = y$$\n\n') + expect(text.slice(idx)).toBe('narration') + }) + + it('refuses to split inside an open \\[ math block', () => { + const text = '\\[\nx + y\n\nmore' + + expect(findStableBoundary(text)).toBe(-1) + }) }) describe('streaming theme assumption', () => { diff --git a/ui-tui/src/__tests__/textInputWrap.test.ts b/ui-tui/src/__tests__/textInputWrap.test.ts index 5521012e9c..c25c9629e7 100644 --- a/ui-tui/src/__tests__/textInputWrap.test.ts +++ b/ui-tui/src/__tests__/textInputWrap.test.ts @@ -1,9 +1,9 @@ import { describe, expect, it } from 'vitest' import { offsetFromPosition } from '../components/textInput.js' -import { cursorLayout, inputVisualHeight, stableComposerColumns } from '../lib/inputMetrics.js' +import { composerPromptWidth, cursorLayout, inputVisualHeight, stableComposerColumns } from '../lib/inputMetrics.js' -describe('cursorLayout — char-wrap parity with wrap-ansi', () => { +describe('cursorLayout — word-wrap parity with wrap-ansi', () => { it('places cursor mid-line at its column', () => { expect(cursorLayout('hello world', 6, 40)).toEqual({ column: 6, line: 0 }) }) @@ -18,12 +18,20 @@ describe('cursorLayout — char-wrap parity with wrap-ansi', () => { expect(cursorLayout('abcdefgh', 8, 8)).toEqual({ column: 0, line: 1 }) }) - it('tracks a word across a char-wrap boundary without jumping', () => { - // With wordWrap:false, "hello world" at cols=8 is "hello wo\nrld" — - // typing incremental letters doesn't reshuffle the word across lines. + it('moves words across wrap boundaries instead of splitting them', () => { + // With wordWrap:true, "hello wor" at cols=8 is "hello \nwor" rather + // than "hello wo\nr". 
expect(cursorLayout('hello wo', 8, 8)).toEqual({ column: 0, line: 1 }) - expect(cursorLayout('hello wor', 9, 8)).toEqual({ column: 1, line: 1 }) - expect(cursorLayout('hello worl', 10, 8)).toEqual({ column: 2, line: 1 }) + expect(cursorLayout('hello wor', 9, 8)).toEqual({ column: 3, line: 1 }) + expect(cursorLayout('hello worl', 10, 8)).toEqual({ column: 4, line: 1 }) + expect(cursorLayout('hello world', 11, 8)).toEqual({ column: 5, line: 1 }) + }) + + it('wraps the next word instead of splitting it at the right edge', () => { + const text = 'hello world baby chickens are so cool its really rainy outside but wish' + + expect(cursorLayout(text, text.length, 70)).toEqual({ column: 4, line: 1 }) + expect(inputVisualHeight(text, 70)).toBe(2) }) it('honours explicit newlines', () => { @@ -42,6 +50,12 @@ describe('input metrics helpers', () => { expect(inputVisualHeight('one\ntwo', 40)).toBe(2) }) + it('counts the prompt gap as its own cell', () => { + expect(composerPromptWidth('>')).toBe(2) + expect(composerPromptWidth('❯')).toBe(2) + expect(composerPromptWidth('Ψ >')).toBe(4) + }) + it('reserves gutters on wide panes without starving narrow composer width', () => { expect(stableComposerColumns(100, 3)).toBe(93) expect(stableComposerColumns(100, 5)).toBe(91) @@ -50,7 +64,7 @@ describe('input metrics helpers', () => { }) }) -describe('offsetFromPosition — char-wrap inverse of cursorLayout', () => { +describe('offsetFromPosition — word-wrap inverse of cursorLayout', () => { it('returns 0 for empty input', () => { expect(offsetFromPosition('', 0, 0, 10)).toBe(0) }) @@ -64,11 +78,23 @@ describe('offsetFromPosition — char-wrap inverse of cursorLayout', () => { }) it('maps clicks on a wrapped second row at cols boundary', () => { - // "abcdefghij" at cols=8 wraps to "abcdefgh\nij" — click at row 1 col 0 - // should land on 'i' (offset 8). + // Long words still hard-wrap when there is no word boundary. 
expect(offsetFromPosition('abcdefghij', 1, 0, 8)).toBe(8) }) + it('maps clicks on a word-wrapped second row', () => { + // "hello world" at cols=8 wraps to "hello \nworld". + expect(offsetFromPosition('hello world', 1, 0, 8)).toBe(6) + expect(offsetFromPosition('hello world', 1, 3, 8)).toBe(9) + }) + + it('maps clicks on the moved final word', () => { + const text = 'hello world baby chickens are so cool its really rainy outside but wish' + + expect(offsetFromPosition(text, 1, 0, 70)).toBe(text.indexOf('wish')) + expect(offsetFromPosition(text, 1, 3, 70)).toBe(text.indexOf('wish') + 3) + }) + it('maps clicks past a \\n into the target line', () => { expect(offsetFromPosition('one\ntwo', 1, 2, 40)).toBe(6) }) diff --git a/ui-tui/src/__tests__/theme.test.ts b/ui-tui/src/__tests__/theme.test.ts index 888bd9142a..30a047df66 100644 --- a/ui-tui/src/__tests__/theme.test.ts +++ b/ui-tui/src/__tests__/theme.test.ts @@ -16,12 +16,13 @@ const RELEVANT_ENV = [ 'HERMES_TUI_THEME', 'HERMES_TUI_BACKGROUND', 'COLORFGBG', + 'COLORTERM', 'TERM_PROGRAM' ] as const -async function importThemeWithCleanEnv() { +async function importThemeWithEnv(env: Partial> = {}) { for (const key of RELEVANT_ENV) { - vi.stubEnv(key, '') + vi.stubEnv(key, env[key] ?? 
'') } vi.resetModules() @@ -29,6 +30,10 @@ async function importThemeWithCleanEnv() { return import('../theme.js') } +async function importThemeWithCleanEnv() { + return importThemeWithEnv() +} + afterEach(() => { vi.unstubAllEnvs() vi.resetModules() @@ -84,6 +89,12 @@ describe('detectLightMode', () => { expect(detectLightMode({})).toBe(false) }) + it('defaults Apple Terminal to light when no stronger signal is present', async () => { + const { detectLightMode } = await importThemeWithCleanEnv() + + expect(detectLightMode({ TERM_PROGRAM: 'Apple_Terminal' })).toBe(true) + }) + it('honors HERMES_TUI_LIGHT on/off', async () => { const { detectLightMode } = await importThemeWithCleanEnv() @@ -159,8 +170,8 @@ describe('detectLightMode', () => { it('treats COLORFGBG as authoritative when present so it dominates the TERM_PROGRAM allow-list', async () => { const { detectLightMode } = await importThemeWithCleanEnv() - // Inject a light-default allow-list so the precedence test is - // meaningful even though the production allow-list is empty. + // Injecting the allow-list keeps this precedence rule explicit even if + // production defaults change. const allowList = new Set(['Apple_Terminal']) // Sanity: the allow-list alone WOULD turn this terminal light. 
@@ -221,6 +232,40 @@ describe('fromSkin', () => { expect(fromSkin({}, {}).brand.icon).toBe(DEFAULT_THEME.brand.icon) }) + it('normalizes non-banner foregrounds on light Apple Terminal', async () => { + const { fromSkin } = await importThemeWithEnv({ TERM_PROGRAM: 'Apple_Terminal' }) + + const theme = fromSkin({ + banner_accent: '#FFBF00', + banner_border: '#CD7F32', + banner_dim: '#B8860B', + banner_text: '#FFF8DC', + banner_title: '#FFD700', + prompt: '#FFF8DC' + }, {}) + + expect(theme.color.primary).toBe('#FFD700') + expect(theme.color.accent).toBe('#FFBF00') + expect(theme.color.border).toBe('#CD7F32') + expect(theme.color.muted).toBe('ansi256(245)') + expect(theme.color.text).toBe('ansi256(136)') + expect(theme.color.prompt).toBe('ansi256(136)') + }) + + it('does not normalize light Apple Terminal when truecolor is advertised', async () => { + const { fromSkin } = await importThemeWithEnv({ COLORTERM: 'truecolor', TERM_PROGRAM: 'Apple_Terminal' }) + const theme = fromSkin({ banner_text: '#FFF8DC' }, {}) + + expect(theme.color.text).toBe('#FFF8DC') + }) + + it('normalizes Apple Terminal names before matching', async () => { + const { fromSkin } = await importThemeWithEnv({ TERM_PROGRAM: ' Apple_Terminal ' }) + const theme = fromSkin({ banner_text: '#FFF8DC' }, {}) + + expect(theme.color.text).toBe('ansi256(136)') + }) + it('passes banner logo/hero', async () => { const { fromSkin } = await importThemeWithCleanEnv() diff --git a/ui-tui/src/app/createGatewayEventHandler.ts b/ui-tui/src/app/createGatewayEventHandler.ts index 0dd190c10e..86295f67d9 100644 --- a/ui-tui/src/app/createGatewayEventHandler.ts +++ b/ui-tui/src/app/createGatewayEventHandler.ts @@ -282,6 +282,11 @@ export function createGatewayEventHandler(ctx: GatewayEventHandlerContext): (ev: setStatus(p.text) + if (p.kind === 'compressing') { + sys(p.text) + return + } + if (!p.kind || p.kind === 'status') { return } diff --git a/ui-tui/src/app/slash/commands/core.ts 
b/ui-tui/src/app/slash/commands/core.ts index 1b29366361..f9b54c34c1 100644 --- a/ui-tui/src/app/slash/commands/core.ts +++ b/ui-tui/src/app/slash/commands/core.ts @@ -266,7 +266,9 @@ export const coreCommands: SlashCommand[] = [ return transcript.sys(DETAILS_USAGE) } - patchUiState({ detailsMode: next, detailsModeCommandOverride: true }) + const sections = Object.fromEntries(SECTION_NAMES.map(section => [section, next])) + + patchUiState({ detailsMode: next, detailsModeCommandOverride: true, sections }) gateway.rpc('config.set', { key: 'details_mode', value: next }).catch(() => {}) transcript.sys(`details: ${next}`) } diff --git a/ui-tui/src/app/slash/commands/ops.ts b/ui-tui/src/app/slash/commands/ops.ts index 7353f6fb4d..ad9f3e94d1 100644 --- a/ui-tui/src/app/slash/commands/ops.ts +++ b/ui-tui/src/app/slash/commands/ops.ts @@ -76,14 +76,39 @@ export const opsCommands: SlashCommand[] = [ { aliases: ['reload_mcp'], - help: 'reload MCP servers in the live session', + help: 'reload MCP servers in the live session (warns about prompt cache invalidation)', name: 'reload-mcp', - run: (_arg, ctx) => { + run: (arg, ctx) => { + // Parse arg: `now` / `always` skip the confirmation gate. + // `always` additionally persists approvals.mcp_reload_confirm=false. + const a = (arg || '').trim().toLowerCase() + const params: { session_id: string | null; confirm?: boolean; always?: boolean } = { + session_id: ctx.sid + } + if (a === 'now' || a === 'approve' || a === 'once' || a === 'yes') { + params.confirm = true + } else if (a === 'always') { + params.confirm = true + params.always = true + } + ctx.gateway - .rpc('reload.mcp', { session_id: ctx.sid }) + .rpc('reload.mcp', params) .then( ctx.guarded(r => { - ctx.transcript.sys(r.status === 'reloaded' ? 
'MCP servers reloaded' : 'reload complete') + if (r.status === 'confirm_required') { + ctx.transcript.sys(r.message || '/reload-mcp requires confirmation') + return + } + if (r.status === 'reloaded') { + ctx.transcript.sys( + params.always + ? 'MCP servers reloaded · future /reload-mcp will run without confirmation' + : 'MCP servers reloaded' + ) + return + } + ctx.transcript.sys('reload complete') }) ) .catch(ctx.guardedErr) diff --git a/ui-tui/src/app/slash/commands/session.ts b/ui-tui/src/app/slash/commands/session.ts index ecd1b7866f..0a5324ef55 100644 --- a/ui-tui/src/app/slash/commands/session.ts +++ b/ui-tui/src/app/slash/commands/session.ts @@ -154,6 +154,22 @@ export const sessionCommands: SlashCommand[] = [ patchUiState(state => ({ ...state, usage: { ...state.usage, ...r.usage } })) } + if (r.summary?.headline) { + const prefix = r.summary.noop ? '' : '✓ ' + + ctx.transcript.sys(`${prefix}${r.summary.headline}`) + + if (r.summary.token_line) { + ctx.transcript.sys(` ${r.summary.token_line}`) + } + + if (r.summary.note) { + ctx.transcript.sys(` ${r.summary.note}`) + } + + return + } + if ((r.removed ?? 
0) <= 0) { return ctx.transcript.sys('nothing to compress') } @@ -163,6 +179,7 @@ export const sessionCommands: SlashCommand[] = [ ) }) ) + .catch(ctx.guardedErr) } }, @@ -332,7 +349,29 @@ export const sessionCommands: SlashCommand[] = [ ctx.gateway .rpc('config.set', { key: 'reasoning', session_id: ctx.sid, value: arg }) - .then(ctx.guarded(r => r.value && ctx.transcript.sys(`reasoning: ${r.value}`))) + .then( + ctx.guarded(r => { + if (!r.value) { + return + } + + if (r.value === 'hide') { + patchUiState(state => ({ + ...state, + sections: { ...state.sections, thinking: 'hidden' }, + showReasoning: false + })) + } else if (r.value === 'show') { + patchUiState(state => ({ + ...state, + sections: { ...state.sections, thinking: 'expanded' }, + showReasoning: true + })) + } + + ctx.transcript.sys(`reasoning: ${r.value}`) + }) + ) } }, diff --git a/ui-tui/src/app/useConfigSync.ts b/ui-tui/src/app/useConfigSync.ts index 8695855759..ad8f52f148 100644 --- a/ui-tui/src/app/useConfigSync.ts +++ b/ui-tui/src/app/useConfigSync.ts @@ -5,8 +5,7 @@ import type { GatewayClient } from '../gatewayClient.js' import type { ConfigFullResponse, ConfigMtimeResponse, - ReloadMcpResponse, - VoiceToggleResponse + ReloadMcpResponse } from '../gatewayTypes.js' import { asRpcResult } from '../lib/rpc.js' @@ -118,7 +117,11 @@ export function useConfigSync({ gw, setBellOnComplete, setVoiceEnabled, sid }: U return } - quietRpc(gw, 'voice.toggle', { action: 'status' }).then(r => setVoiceEnabled(!!r?.enabled)) + // Keep startup cheap: voice.toggle status probes optional audio/STT deps and + // can run long enough to delay prompt.submit on the single stdio RPC pipe. + // Environment flags are enough to initialize the UI bit; the heavier status + // check still runs when the user opens /voice. + setVoiceEnabled(process.env.HERMES_VOICE === '1') quietRpc(gw, 'config.get', { key: 'mtime' }).then(r => { mtimeRef.current = Number(r?.mtime ?? 
0) }) @@ -148,7 +151,7 @@ export function useConfigSync({ gw, setBellOnComplete, setVoiceEnabled, sid }: U mtimeRef.current = next - quietRpc(gw, 'reload.mcp', { session_id: sid }).then( + quietRpc(gw, 'reload.mcp', { session_id: sid, confirm: true }).then( r => r && turnController.pushActivity('MCP reloaded after config change') ) quietRpc(gw, 'config.get', { key: 'full' }).then(r => applyDisplay(r, setBellOnComplete)) diff --git a/ui-tui/src/app/useInputHandlers.ts b/ui-tui/src/app/useInputHandlers.ts index 6fdd04ea61..a74c9e8431 100644 --- a/ui-tui/src/app/useInputHandlers.ts +++ b/ui-tui/src/app/useInputHandlers.ts @@ -21,6 +21,8 @@ import { patchTurnState } from './turnStore.js' import { getUiState } from './uiStore.js' const isCtrl = (key: { ctrl: boolean }, ch: string, target: string) => key.ctrl && ch.toLowerCase() === target +const PRECISION_WHEEL_MIN_GAP_MS = 80 +const PRECISION_WHEEL_STICKY_MS = 80 export function useInputHandlers(ctx: InputHandlerContext): InputHandlerResult { const { actions, composer, gateway, terminal, voice, wheelStep } = ctx @@ -36,6 +38,10 @@ export function useInputHandlers(ctx: InputHandlerContext): InputHandlerResult { // rows = wheelStep × accelMult. State mutates in place across renders. const wheelAccelRef = useRef(initWheelAccelForHost()) + const precisionWheelRef = useRef<{ active: boolean; dir: 0 | -1 | 1; lastEventAtMs: number; lastScrollAtMs: number }>( + { active: false, dir: 0, lastEventAtMs: 0, lastScrollAtMs: 0 } + ) + useEffect(() => () => clearTimeout(scrollIdleTimer.current ?? undefined), []) const scrollTranscript = (delta: number) => { @@ -284,8 +290,43 @@ export function useInputHandlers(ctx: InputHandlerContext): InputHandlerResult { if (key.wheelUp || key.wheelDown) { const dir: -1 | 1 = key.wheelUp ? -1 : 1 + const now = Date.now() + // Modifier-held wheel = precision mode: at most one wheelStep per short + // interval. 
Smooth mice / trackpads emit many raw wheel events for one + // intended line step, so raw 1:1 still moves too far. + // SGR/X10 mouse encoding only carries shift/meta/ctrl bits; Cmd on + // macOS is intercepted by the terminal, so we honor Option (meta) on + // Mac / Alt (meta) on Win+Linux / Ctrl as a portable fallback. Shift + // is reserved for selection extension. + const hasModifier = key.meta || key.ctrl + const precision = precisionWheelRef.current + // Keep precision active through the current wheel burst after the + // modifier is released. Otherwise a stream of queued/momentum wheel + // events can hand off mid-burst into the accelerated path and jump. + const precisionSticky = now - precision.lastEventAtMs < PRECISION_WHEEL_STICKY_MS + + if (hasModifier || precisionSticky) { + if (!precision.active) { + precision.active = true + wheelAccelRef.current = initWheelAccelForHost() + } + + precision.lastEventAtMs = now + + if (dir === precision.dir && now - precision.lastScrollAtMs < PRECISION_WHEEL_MIN_GAP_MS) { + return + } + + precision.lastScrollAtMs = now + precision.dir = dir + + return scrollTranscript(dir * wheelStep) + } + + precision.active = false + // 0 = direction-flip bounce deferred; skip the no-op scroll. - const rows = computeWheelStep(wheelAccelRef.current, dir, Date.now()) + const rows = computeWheelStep(wheelAccelRef.current, dir, now) return rows ? 
scrollTranscript(dir * rows * wheelStep) : undefined } diff --git a/ui-tui/src/app/useMainApp.ts b/ui-tui/src/app/useMainApp.ts index 70dc96fec3..9ec18337bb 100644 --- a/ui-tui/src/app/useMainApp.ts +++ b/ui-tui/src/app/useMainApp.ts @@ -711,6 +711,9 @@ export function useMainApp(gw: GatewayClient) { const anyPanelVisible = SECTION_NAMES.some( s => sectionMode(s, ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' ) + const thinkingPanelVisible = sectionMode('thinking', ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' + const toolsPanelVisible = sectionMode('tools', ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' + const activityPanelVisible = sectionMode('activity', ui.detailsMode, ui.sections, ui.detailsModeCommandOverride) !== 'hidden' const showProgressArea = useTurnSelector(state => anyPanelVisible @@ -718,12 +721,25 @@ export function useMainApp(gw: GatewayClient) { ui.busy || state.outcome || state.streamPendingTools.length || - state.streamSegments.length || + state.streamSegments.some(segment => { + const hasThinking = Boolean(segment.thinking?.trim()) + const hasTrailTools = Boolean(segment.tools?.length) + + if (segment.kind === 'trail' && !segment.text) { + return (thinkingPanelVisible && hasThinking) || ((toolsPanelVisible || activityPanelVisible) && hasTrailTools) + } + + return ( + Boolean(segment.text?.trim()) || + (thinkingPanelVisible && hasThinking) || + ((toolsPanelVisible || activityPanelVisible) && hasTrailTools) + ) + }) || state.subagents.length || state.tools.length || state.todos.length || state.turnTrail.length || - hasReasoning || + (thinkingPanelVisible && hasReasoning) || state.activity.length ) : state.activity.some(item => item.tone !== 'info') diff --git a/ui-tui/src/app/useSubmission.ts b/ui-tui/src/app/useSubmission.ts index df6acfadbe..bbb288e001 100644 --- a/ui-tui/src/app/useSubmission.ts +++ b/ui-tui/src/app/useSubmission.ts @@ -126,6 +126,13 @@ export 
function useSubmission(opts: UseSubmissionOptions) { return sys('session not ready yet') } + // Plain prompts are the common path and should not pay an extra RPC + // before prompt.submit. File-drop detection still runs for absolute, + // tilde, file://, and explicit relative paths. + if (!looksLikeSlashCommand(text) && !/(?:^|\s)(?:file:\/\/|~\/|\.?\.\/|\/)[^\s]+/.test(text)) { + return startSubmit(text, expand(text), showUserMessage) + } + gw.request('input.detect_drop', { session_id: sid, text }) .then(r => { if (!r?.matched) { diff --git a/ui-tui/src/components/appLayout.tsx b/ui-tui/src/components/appLayout.tsx index 16d96f390b..6927460770 100644 --- a/ui-tui/src/components/appLayout.tsx +++ b/ui-tui/src/components/appLayout.tsx @@ -1,4 +1,4 @@ -import { AlternateScreen, Box, NoSelect, ScrollBox, stringWidth, Text } from '@hermes/ink' +import { AlternateScreen, Box, NoSelect, ScrollBox, Text } from '@hermes/ink' import { useStore } from '@nanostores/react' import { Fragment, memo, useMemo, useRef } from 'react' @@ -9,7 +9,12 @@ import { $uiState } from '../app/uiStore.js' import { INLINE_MODE, SHOW_FPS } from '../config/env.js' import { FULL_RENDER_TAIL_ITEMS } from '../config/limits.js' import { PLACEHOLDER } from '../content/placeholders.js' -import { inputVisualHeight, stableComposerColumns } from '../lib/inputMetrics.js' +import { + COMPOSER_PROMPT_GAP_WIDTH, + composerPromptWidth, + inputVisualHeight, + stableComposerColumns +} from '../lib/inputMetrics.js' import { PerfPane } from '../lib/perfPane.js' import { AgentsOverlay } from './agentsOverlay.js' @@ -22,6 +27,31 @@ import { QueuedMessages } from './queuedMessages.js' import { LiveTodoPanel, StreamingAssistant } from './streamingAssistant.js' import { TextInput, type TextInputMouseApi } from './textInput.js' +const PromptPrefix = memo(function PromptPrefix({ + bold = false, + color, + promptText, + width +}: { + bold?: boolean + color: string + promptText: string + width: number +}) { + const 
glyphWidth = Math.max(1, width - COMPOSER_PROMPT_GAP_WIDTH) + + return ( + + + + {promptText} + + + + + ) +}) + const TranscriptPane = memo(function TranscriptPane({ actions, composer, @@ -68,7 +98,7 @@ const TranscriptPane = memo(function TranscriptPane({ - {row.msg.info?.version && } + {row.msg.info && } ) : row.msg.kind === 'panel' && row.msg.panelData ? ( @@ -125,8 +155,8 @@ const ComposerPane = memo(function ComposerPane({ const isBlocked = useStore($isBlocked) const sh = (composer.inputBuf[0] ?? composer.input).startsWith('!') const promptText = sh ? '$' : ui.theme.brand.prompt - const promptLabel = `${promptText} ` - const promptWidth = Math.max(1, stringWidth(promptLabel)) + const promptWidth = composerPromptWidth(promptText) + const promptBlank = ' '.repeat(promptWidth) const inputColumns = stableComposerColumns(composer.cols, promptWidth) const inputHeight = inputVisualHeight(composer.input, inputColumns) const inputMouseRef = useRef(null) @@ -217,7 +247,11 @@ const ComposerPane = memo(function ComposerPane({ {composer.inputBuf.map((line, i) => ( - {i === 0 ? promptLabel : ' '.repeat(promptWidth)} + {i === 0 ? ( + + ) : ( + {promptBlank} + )} {line || ' '} @@ -229,18 +263,19 @@ const ComposerPane = memo(function ComposerPane({ onMouseDrag={dragFromPromptRow} onMouseUp={endInputDrag} position="relative" + width={Math.max(1, composer.cols - 2)} > {sh ? ( - {promptLabel} + + ) : composer.inputBuf.length ? ( + {promptBlank} ) : ( - - {composer.inputBuf.length ? ' '.repeat(promptWidth) : promptLabel} - + )} - + {/* Reserve the transcript scrollbar gutter too so typing never rewraps when the scrollbar column repaints. 
*/} + - - - + + diff --git a/ui-tui/src/components/branding.tsx b/ui-tui/src/components/branding.tsx index 25e161fd71..84e502aada 100644 --- a/ui-tui/src/components/branding.tsx +++ b/ui-tui/src/components/branding.tsx @@ -1,10 +1,32 @@ import { Box, Text, useStdout } from '@hermes/ink' +import { useEffect, useState } from 'react' +import unicodeSpinners from 'unicode-animations' import { artWidth, caduceus, CADUCEUS_WIDTH, logo, LOGO_WIDTH } from '../banner.js' import { flat } from '../lib/text.js' import type { Theme } from '../theme.js' import type { PanelSection, SessionInfo } from '../types.js' +const LOADER_TICK_MS = 120 + +function InlineLoader({ label, t }: { label: string; t: Theme }) { + const [tick, setTick] = useState(0) + const spinner = unicodeSpinners.braille + const frame = spinner.frames[tick % spinner.frames.length] ?? '⠋' + + useEffect(() => { + const id = setInterval(() => setTick(n => n + 1), Math.max(LOADER_TICK_MS, spinner.interval)) + + return () => clearInterval(id) + }, [spinner.interval]) + + return ( + + {frame} {label} + + ) +} + export function ArtLines({ lines }: { lines: [string, string][] }) { return ( <> @@ -67,6 +89,7 @@ export function SessionPanel({ info, sid, t }: SessionPanelProps) { const entries = Object.entries(data).sort() const shown = entries.slice(0, max) const overflow = entries.length - max + const skeleton = info.lazy && entries.length === 0 return ( @@ -74,12 +97,16 @@ export function SessionPanel({ info, sid, t }: SessionPanelProps) { Available {title} - {shown.map(([k, vs]) => ( - - {strip(k)}: - {truncLine(strip(k) + ': ', vs)} - - ))} + {skeleton ? 
( + + ) : ( + shown.map(([k, vs]) => ( + + {strip(k)}: + {truncLine(strip(k) + ': ', vs)} + + )) + )} {overflow > 0 && ( diff --git a/ui-tui/src/components/markdown.tsx b/ui-tui/src/components/markdown.tsx index 3b38b25558..163768a51c 100644 --- a/ui-tui/src/components/markdown.tsx +++ b/ui-tui/src/components/markdown.tsx @@ -2,9 +2,60 @@ import { Box, Link, Text } from '@hermes/ink' import { Fragment, memo, type ReactNode, useMemo } from 'react' import { ensureEmojiPresentation } from '../lib/emoji.js' +import { BOX_CLOSE, BOX_OPEN, texToUnicode } from '../lib/mathUnicode.js' import { highlightLine, isHighlightable } from '../lib/syntax.js' import type { Theme } from '../theme.js' +// `\boxed{X}` regions in `texToUnicode` output are marked with the +// non-printable U+0001 / U+0002 sentinels. Split on them and render the +// boxed segment with `inverse + bold` so it reads as a highlighter-pen +// emphasis on top of whatever color the parent `` is using (the +// theme accent for math). The leading / trailing space inside the +// highlight gives a one-cell visual margin so the highlight reads as a +// block, not a hug. 
+const renderMath = (text: string): ReactNode => { + if (!text.includes(BOX_OPEN)) { + return text + } + + const out: ReactNode[] = [] + let i = 0 + let key = 0 + + while (i < text.length) { + const start = text.indexOf(BOX_OPEN, i) + + if (start < 0) { + out.push(text.slice(i)) + + break + } + + if (start > i) { + out.push(text.slice(i, start)) + } + + const end = text.indexOf(BOX_CLOSE, start + 1) + + if (end < 0) { + out.push(text.slice(start)) + + break + } + + out.push( + + {' '} + {text.slice(start + 1, end)}{' '} + + ) + + i = end + 1 + } + + return out +} + const FENCE_RE = /^\s*(`{3,}|~{3,})(.*)$/ const FENCE_CLOSE_RE = /^\s*(`{3,}|~{3,})\s*$/ const HR_RE = /^ {0,3}([-*_])(?:\s*\1){2,}\s*$/ @@ -19,6 +70,15 @@ const QUOTE_RE = /^\s*(?:>\s*)+/ const TABLE_DIVIDER_CELL_RE = /^:?-{3,}:?$/ const MD_URL_RE = '((?:[^\\s()]|\\([^\\s()]*\\))+?)' +// Display math openers: `$$ ... $$` (TeX) and `\[ ... \]` (LaTeX). The +// opener is matched only when `$$` / `\[` appears at the very start of the +// trimmed line — `startsWith('$$')` used to fire on prose like +// `$$x+y$$ followed by more`, opening a block that never closed because the +// trailing `$$` on the same line was invisible to the close-scan loop. +const MATH_BLOCK_OPEN_RE = /^\s*(\$\$|\\\[)(.*)$/ +const MATH_BLOCK_CLOSE_DOLLAR_RE = /^(.*?)\$\$\s*$/ +const MATH_BLOCK_CLOSE_BRACKET_RE = /^(.*?)\\\]\s*$/ + export const MEDIA_LINE_RE = /^\s*[`"']?MEDIA:\s*(\S+?)[`"']?\s*$/ export const AUDIO_DIRECTIVE_RE = /^\s*\[\[audio_as_voice\]\]\s*$/ @@ -31,6 +91,13 @@ export const AUDIO_DIRECTIVE_RE = /^\s*\[\[audio_as_voice\]\]\s*$/ // `thing ~! more ~?` from Kimi / Qwen / GLM (kaomoji-style decorators) // doesn't pair up the first `~` with the next one on the line and swallow // the text between them as a dim `_`-prefixed span. 
+// +// Inline math (`$x$` and `\(x\)`) takes precedence over emphasis at the +// same start position because regex alternation is leftmost-first; a +// dollar-delimited span at column N wins over a `*` at column N+1, so +// `$P=a*b*c$` renders as math instead of having `*b*` corrupted into +// italics. Single-character minimums and "no space adjacent to delimiter" +// rules keep currency prose like `$5 to $10` from being swallowed. export const INLINE_RE = new RegExp( [ `!\\[(.*?)\\]\\(${MD_URL_RE}\\)`, // 1,2 image @@ -46,7 +113,13 @@ export const INLINE_RE = new RegExp( `\\[\\^([^\\]]+)\\]`, // 13 footnote ref `\\^([^^\\s][^^]*?)\\^`, // 14 superscript `~([A-Za-z0-9]{1,8})~`, // 15 subscript - `https?:\\/\\/[^\\s<]+` // 16 bare URL + `(https?:\\/\\/[^\\s<]+)`, // 16 bare URL — wrapped so it owns its own + // capture group; without this, the math + // spans below would land in m[16] and the + // MdInline dispatcher would treat them as + // bare URLs and render them as autolinks. + `(? .replace(/\[\^([^\]]+)\]/g, '[$1]') .replace(/\^([^^\s][^^]*?)\^/g, '^$1') .replace(/~([A-Za-z0-9]{1,8})~/g, '_$1') + .replace(/(? { const widths = rows[0]!.map((_, ci) => Math.max(...rows.map(r => stripInlineMarkup(r[ci] ?? '').length))) // Thin divider under the header. Without it tables look like prose - // with extra spacing because the header is just amber-coloured text + // with extra spacing because the header is just accent-coloured text // (#15534). We avoid full borders on purpose — column widths come // from `stripInlineMarkup(...).length` (UTF-16 code units, not // display width), so a real outline often misaligns on emoji and @@ -163,31 +238,39 @@ function MdInline({ t, text }: { t: Theme; text: string }) { } else if (m[6]) { parts.push( - {m[6]} + ) } else if (m[7]) { + // Code is the one wrap that does NOT recurse — inline `code` spans + // are verbatim by definition. Letting MdInline reprocess them + // would corrupt regex examples and shell snippets. 
parts.push( {m[7]} ) } else if (m[8] ?? m[9]) { + // Recurse into bold / italic / strike / highlight so nested + // `$...$` math (and other inline tokens) inside a `**bolded + // statement with $\mathbb{Z}$ math**` actually render. Without + // this the inner content is dropped into a single `` + // verbatim and the math renderer never sees it. parts.push( - {m[8] ?? m[9]} + ) } else if (m[10] ?? m[11]) { parts.push( - {m[10] ?? m[11]} + ) } else if (m[12]) { parts.push( - {m[12]} + ) } else if (m[13]) { @@ -218,6 +301,19 @@ function MdInline({ t, text }: { t: Theme; text: string }) { if (url.length < m[16].length) { parts.push({m[16].slice(url.length)}) } + } else if (m[17] ?? m[18]) { + // Inline math is run through `texToUnicode` (Greek letters, ℕℤℚℝ, + // operators, sub/superscripts, fractions) and rendered in italic + // accent. Italic is the disambiguator — links use accent+underline, + // so without italic readers can't tell `\mathbb{R}` (math) from a + // hyperlinked word. Anything `texToUnicode` doesn't recognise is + // preserved verbatim, so unfamiliar commands just look like their + // raw LaTeX rather than vanishing. + parts.push( + + {renderMath(texToUnicode(m[17] ?? m[18]!))} + + ) } last = i + m[0].length @@ -415,32 +511,80 @@ function MdImpl({ compact, t, text }: MdProps) { continue } - if (line.trim().startsWith('$$')) { - start('code') + const mathOpen = line.match(MATH_BLOCK_OPEN_RE) + if (mathOpen) { + const opener = mathOpen[1]! + const closeRe = opener === '$$' ? MATH_BLOCK_CLOSE_DOLLAR_RE : MATH_BLOCK_CLOSE_BRACKET_RE + const headRest = mathOpen[2] ?? '' const block: string[] = [] - for (i++; i < lines.length; i++) { - if (lines[i]!.trim().startsWith('$$')) { - i++ + // Single-line block: `$$x + y = z$$` or `\[x\]`. Capture inner content + // and emit the block immediately. Without this, the close-scan loop + // skips line `i` and treats the next opener as our closer, swallowing + // every paragraph in between. 
+ const sameLineClose = headRest.match(closeRe) + + if (sameLineClose) { + const inner = sameLineClose[1]!.trim() + + start('code') + nodes.push( + + {inner ? {renderMath(texToUnicode(inner))} : null} + + ) + i++ + + continue + } + + // Multi-line block: scan ahead for a real closer before committing. + // If none exists in the rest of the document, render this line as a + // paragraph instead of consuming everything that follows. + let closeIdx = -1 + + for (let j = i + 1; j < lines.length; j++) { + if (closeRe.test(lines[j]!)) { + closeIdx = j break } - - block.push(lines[i]!) } + if (closeIdx < 0) { + start('paragraph') + nodes.push() + i++ + + continue + } + + if (headRest.trim()) { + block.push(headRest) + } + + for (let j = i + 1; j < closeIdx; j++) { + block.push(lines[j]!) + } + + const tail = lines[closeIdx]!.match(closeRe)![1]!.trimEnd() + + if (tail.trim()) { + block.push(tail) + } + + start('code') nodes.push( - ─ math - {block.map((l, j) => ( - {l} + {renderMath(texToUnicode(l))} ))} ) + i = closeIdx + 1 continue } @@ -451,7 +595,7 @@ function MdImpl({ compact, t, text }: MdProps) { start('heading') nodes.push( - {heading} + ) i++ @@ -463,7 +607,7 @@ function MdImpl({ compact, t, text }: MdProps) { start('heading') nodes.push( - {line.trim()} + ) i += 2 diff --git a/ui-tui/src/components/sessionPicker.tsx b/ui-tui/src/components/sessionPicker.tsx index fd29d9e7ec..e836e59852 100644 --- a/ui-tui/src/components/sessionPicker.tsx +++ b/ui-tui/src/components/sessionPicker.tsx @@ -2,7 +2,7 @@ import { Box, Text, useInput, useStdout } from '@hermes/ink' import { useEffect, useState } from 'react' import type { GatewayClient } from '../gatewayClient.js' -import type { SessionListItem, SessionListResponse } from '../gatewayTypes.js' +import type { SessionDeleteResponse, SessionListItem, SessionListResponse } from '../gatewayTypes.js' import { asRpcResult, rpcErrorMessage } from '../lib/rpc.js' import type { Theme } from '../theme.js' @@ -31,6 +31,10 @@ export 
function SessionPicker({ gw, onCancel, onSelect, t }: SessionPickerProps) const [err, setErr] = useState('') const [sel, setSel] = useState(0) const [loading, setLoading] = useState(true) + // When non-null, the user pressed `d` on this index and we're waiting for + // a second `d`/`D` to confirm deletion. Any other key cancels the prompt. + const [confirmDelete, setConfirmDelete] = useState(null) + const [deleting, setDeleting] = useState(false) const { stdout } = useStdout() const width = Math.max(MIN_WIDTH, Math.min(MAX_WIDTH, (stdout?.columns ?? 80) - 6)) @@ -59,7 +63,57 @@ export function SessionPicker({ gw, onCancel, onSelect, t }: SessionPickerProps) }) }, [gw]) + const performDelete = (index: number) => { + const target = items[index] + + if (!target || deleting) { + return + } + + setDeleting(true) + gw.request('session.delete', { session_id: target.id }) + .then(raw => { + const r = asRpcResult(raw) + + if (!r || r.deleted !== target.id) { + setErr('invalid response: session.delete') + setDeleting(false) + + return + } + + setItems(prev => { + const next = prev.filter((_, i) => i !== index) + setSel(s => Math.max(0, Math.min(s, next.length - 1))) + + return next + }) + setErr('') + setDeleting(false) + }) + .catch((e: unknown) => { + setErr(rpcErrorMessage(e)) + setDeleting(false) + }) + } + useInput((ch, key) => { + if (deleting) { + return + } + + if (confirmDelete !== null) { + if (ch?.toLowerCase() === 'd') { + const idx = confirmDelete + setConfirmDelete(null) + performDelete(idx) + } else { + setConfirmDelete(null) + } + + return + } + if (key.upArrow && sel > 0) { setSel(s => s - 1) } @@ -70,6 +124,14 @@ export function SessionPicker({ gw, onCancel, onSelect, t }: SessionPickerProps) if (key.return && items[sel]) { onSelect(items[sel]!.id) + + return + } + + if (ch?.toLowerCase() === 'd' && items[sel]) { + setConfirmDelete(sel) + + return } const n = parseInt(ch) @@ -83,7 +145,7 @@ export function SessionPicker({ gw, onCancel, onSelect, t }: 
SessionPickerProps) return loading sessions… } - if (err) { + if (err && !items.length) { return ( error: {err} @@ -109,11 +171,12 @@ export function SessionPicker({ gw, onCancel, onSelect, t }: SessionPickerProps) Resume Session - {offset > 0 && ↑ {offset} more} + {offset > 0 && ↑ {offset} more} {items.slice(offset, offset + VISIBLE).map((s, vi) => { const i = offset + vi const selected = sel === i + const pendingDelete = confirmDelete === i return ( @@ -135,18 +198,23 @@ export function SessionPicker({ gw, onCancel, onSelect, t }: SessionPickerProps) - {s.title || s.preview || '(untitled)'} + {pendingDelete ? 'press d again to delete' : s.title || s.preview || '(untitled)'} ) })} - {offset + VISIBLE < items.length && ↓ {items.length - offset - VISIBLE} more} - ↑/↓ select · Enter resume · 1-9 quick · Esc/q cancel + {offset + VISIBLE < items.length && ↓ {items.length - offset - VISIBLE} more} + {err && error: {err}} + {deleting ? ( + deleting… + ) : ( + ↑/↓ select · Enter resume · 1-9 quick · d delete · Esc/q cancel + )} ) } diff --git a/ui-tui/src/components/streamingMarkdown.tsx b/ui-tui/src/components/streamingMarkdown.tsx index 111ed61e09..1be70b283a 100644 --- a/ui-tui/src/components/streamingMarkdown.tsx +++ b/ui-tui/src/components/streamingMarkdown.tsx @@ -35,19 +35,60 @@ import type { Theme } from '../theme.js' import { Md } from './markdown.js' -// Count ``` or ~~~ fence toggles in `s` up to `end`. Odd = currently inside -// a fenced block; we can't split the prefix there or we'd orphan the fence. +// Count ``` / ~~~ AND `$$` / `\[…\]` fence toggles in `s` up to `end`. Odd +// = currently inside a fenced block; splitting the prefix there would +// orphan the fence and let the unstable suffix re-render as broken +// markdown. Math fences only toggle when the code fence is closed so +// snippets like ` ```\n$$x$$\n``` ` (math example inside a code block) +// don't double-count. 
A `$$x$$` line that opens AND closes on its own +// produces zero net toggles; that's `len >= 4` plus `endsDollar`. +// +// NB: this is INTENTIONALLY more conservative than `markdown.tsx`'s +// parser, which falls back to paragraph rendering when an `$$` opener +// has no matching closer. The renderer can do that safely because it +// always sees the full text on every call. The streaming chunker +// cannot — once a chunk is committed to the monotonic stable prefix it +// is frozen, so prematurely deciding "this `$$` is just prose" would +// permanently commit a paragraph rendering that becomes wrong the +// instant the closer streams in. Treating any unmatched `$$` opener +// as still-open keeps the boundary parked behind it until the closer +// arrives (or the stream ends and the non-streaming `` takes over, +// at which point the renderer's fallback kicks in correctly). const fenceOpenAt = (s: string, end: number) => { - let open = false + let codeOpen = false + let mathOpen = false + let mathOpener: '$$' | '\\[' | null = null let i = 0 while (i < end) { const nl = s.indexOf('\n', i) const lineEnd = nl < 0 || nl > end ? 
end : nl - const line = s.slice(i, lineEnd) + const line = s.slice(i, lineEnd).trim() - if (/^\s*(?:`{3,}|~{3,})/.test(line)) { - open = !open + if (/^(?:`{3,}|~{3,})/.test(line)) { + codeOpen = !codeOpen + } else if (!codeOpen) { + if (!mathOpen && /^\$\$/.test(line)) { + const isSingleLine = line.length >= 4 && /\$\$$/.test(line) + + if (!isSingleLine) { + mathOpen = true + mathOpener = '$$' + } + } else if (!mathOpen && /^\\\[/.test(line)) { + const isSingleLine = /\\\]$/.test(line) + + if (!isSingleLine) { + mathOpen = true + mathOpener = '\\[' + } + } else if (mathOpen && mathOpener === '$$' && /\$\$$/.test(line)) { + mathOpen = false + mathOpener = null + } else if (mathOpen && mathOpener === '\\[' && /\\\]$/.test(line)) { + mathOpen = false + mathOpener = null + } } if (nl < 0 || nl >= end) { @@ -57,7 +98,7 @@ const fenceOpenAt = (s: string, end: number) => { i = nl + 1 } - return open + return codeOpen || mathOpen } // Find the last "\n\n" boundary before `end` that is OUTSIDE a fenced code diff --git a/ui-tui/src/components/textInput.tsx b/ui-tui/src/components/textInput.tsx index 0052e69ed7..3008f0baf4 100644 --- a/ui-tui/src/components/textInput.tsx +++ b/ui-tui/src/components/textInput.tsx @@ -4,7 +4,7 @@ import { type MutableRefObject, useEffect, useMemo, useRef, useState } from 'rea import { setInputSelection } from '../app/inputSelectionStore.js' import { readClipboardText, writeClipboardText } from '../lib/clipboard.js' -import { cursorLayout } from '../lib/inputMetrics.js' +import { cursorLayout, offsetFromPosition } from '../lib/inputMetrics.js' import { isActionMod, isMac, isMacActionFallback } from '../lib/platform.js' type InkExt = typeof Ink & { @@ -170,57 +170,7 @@ export function lineNav(s: string, p: number, dir: -1 | 1): null | number { return snapPos(s, Math.min(nextBreak + 1 + col, lineEnd)) } -export function offsetFromPosition(value: string, row: number, col: number, cols: number) { - if (!value.length) { - return 0 - } - - const 
targetRow = Math.max(0, Math.floor(row)) - const targetCol = Math.max(0, Math.floor(col)) - const w = Math.max(1, cols) - - let line = 0 - let column = 0 - let lastOffset = 0 - - for (const { segment, index } of seg().segment(value)) { - lastOffset = index - - if (segment === '\n') { - if (line === targetRow) { - return index - } - - line++ - column = 0 - - continue - } - - const sw = Math.max(1, stringWidth(segment)) - - if (column + sw > w) { - if (line === targetRow) { - return index - } - - line++ - column = 0 - } - - if (line === targetRow && targetCol <= column + Math.max(0, sw - 1)) { - return index - } - - column += sw - } - - if (targetRow >= line) { - return value.length - } - - return lastOffset -} +export { offsetFromPosition } function renderWithCursor(value: string, cursor: number) { const pos = Math.max(0, Math.min(cursor, value.length)) @@ -1059,7 +1009,7 @@ export function TextInput({ ref={boxRef} width={columns} > - {rendered} + {rendered} ) } diff --git a/ui-tui/src/entry.tsx b/ui-tui/src/entry.tsx index f1ce52bab5..bd56c7f0f8 100644 --- a/ui-tui/src/entry.tsx +++ b/ui-tui/src/entry.tsx @@ -1,8 +1,6 @@ #!/usr/bin/env -S node --max-old-space-size=8192 --expose-gc -// Must be first import — mutates process.env.FORCE_COLOR / COLORTERM before -// any chalk / supports-color import so the banner gradient renders in -// truecolor instead of being downsampled to 256-color (which collapses -// gold #FFD700 and amber #FFBF00 to the same slot). +// Must be first import. If the user explicitly opts into truecolor, this +// nudges chalk / supports-color before either package is initialized. 
import './lib/forceTruecolor.js' import type { FrameEvent } from '@hermes/ink' diff --git a/ui-tui/src/gatewayTypes.ts b/ui-tui/src/gatewayTypes.ts index 1f43096340..60957fc28e 100644 --- a/ui-tui/src/gatewayTypes.ts +++ b/ui-tui/src/gatewayTypes.ts @@ -129,6 +129,10 @@ export interface SessionListResponse { sessions?: SessionListItem[] } +export interface SessionDeleteResponse { + deleted: string +} + export interface SessionMostRecentResponse { session_id?: null | string source?: string @@ -167,9 +171,19 @@ export interface SessionUsageResponse { } export interface SessionCompressResponse { + after_messages?: number + after_tokens?: number + before_messages?: number + before_tokens?: number info?: SessionInfo messages?: GatewayTranscriptMessage[] removed?: number + summary?: { + headline?: string + noop?: boolean + note?: null | string + token_line?: string + } usage?: Usage } @@ -306,6 +320,7 @@ export interface ModelOptionsResponse { export interface ReloadMcpResponse { status?: string + message?: string } export interface ReloadEnvResponse { diff --git a/ui-tui/src/lib/forceTruecolor.ts b/ui-tui/src/lib/forceTruecolor.ts index 3e99b6b184..25de7b2dc3 100644 --- a/ui-tui/src/lib/forceTruecolor.ts +++ b/ui-tui/src/lib/forceTruecolor.ts @@ -1,27 +1,25 @@ /** - * Force 24-bit truecolor output before any chalk / supports-color import. + * Targeted 24-bit truecolor override before chalk / supports-color imports. * - * Why this exists: - * The base CLI (Python/Rich) emits banner colors as truecolor ANSI - * (`\033[38;2;R;G;Bm`). The TUI renders through Ink → chalk, whose - * supports-color auto-detection defaults to 256-color on macOS Terminal.app - * and any terminal that does NOT set `COLORTERM=truecolor`. In 256-color - * mode, chalk downsamples `#FFD700` (gold) and `#FFBF00` (amber) to the - * *same* xterm-256 palette slot (220) — collapsing the banner gradient - * into a single flat yellow band. The bronze and dim rows also lose - * contrast against each other. 
- * - * Terminal.app (macOS 12+), iTerm2, kitty, Alacritty, VS Code, Cursor, - * and WezTerm all render truecolor correctly. The few that don't - * (ancient xterm, some CI environments) can set `HERMES_TUI_TRUECOLOR=0` - * to opt out. - * - * This MUST run before any `chalk` or `supports-color` import. supports-color - * caches its level on first load, so nudging env vars after that point has - * no effect. + * macOS Terminal.app before Tahoe 26 does not support RGB SGR, so do not + * infer truecolor from TERM_PROGRAM=Apple_Terminal. Users can still opt in + * explicitly on terminals that support RGB but do not advertise COLORTERM. */ -if (process.env.HERMES_TUI_TRUECOLOR !== '0' && !process.env.NO_COLOR && !process.env.FORCE_COLOR) { +const TRUE_RE = /^(?:1|true|yes|on)$/i +const FALSE_RE = /^(?:0|false|no|off)$/i + +export function shouldForceTruecolor(env: NodeJS.ProcessEnv = process.env): boolean { + const override = (env.HERMES_TUI_TRUECOLOR ?? '').trim() + + if (FALSE_RE.test(override) || 'NO_COLOR' in env) { + return false + } + + return TRUE_RE.test(override) +} + +if (shouldForceTruecolor()) { if (!process.env.COLORTERM) { process.env.COLORTERM = 'truecolor' } diff --git a/ui-tui/src/lib/inputMetrics.ts b/ui-tui/src/lib/inputMetrics.ts index d54f963709..245baae96f 100644 --- a/ui-tui/src/lib/inputMetrics.ts +++ b/ui-tui/src/lib/inputMetrics.ts @@ -1,58 +1,167 @@ import { stringWidth } from '@hermes/ink' +export const COMPOSER_PROMPT_GAP_WIDTH = 1 + let _seg: Intl.Segmenter | null = null const seg = () => (_seg ??= new Intl.Segmenter(undefined, { granularity: 'grapheme' })) +interface VisualLine { + end: number + start: number +} + +const isWhitespace = (value: string) => /\s/.test(value) + +const graphemes = (value: string) => + [...seg().segment(value)].map(({ segment, index }) => ({ + end: index + segment.length, + index, + segment, + width: Math.max(1, stringWidth(segment)) + })) + +function visualLines(value: string, cols: number): VisualLine[] { + 
const width = Math.max(1, cols) + const lines: VisualLine[] = [] + let sourceLineStart = 0 + + for (const sourceLine of value.split('\n')) { + const parts = graphemes(sourceLine) + + if (!parts.length) { + lines.push({ start: sourceLineStart, end: sourceLineStart }) + sourceLineStart += 1 + continue + } + + let lineStartPart = 0 + let lineStartOffset = sourceLineStart + let column = 0 + let breakPart: null | number = null + let i = 0 + + while (i < parts.length) { + const part = parts[i]! + const partStart = sourceLineStart + part.index + + if (column + part.width > width && i > lineStartPart) { + if (breakPart !== null && breakPart > lineStartPart) { + const breakOffset = sourceLineStart + parts[breakPart - 1]!.end + lines.push({ start: lineStartOffset, end: breakOffset }) + lineStartPart = breakPart + lineStartOffset = breakOffset + } else { + lines.push({ start: lineStartOffset, end: partStart }) + lineStartPart = i + lineStartOffset = partStart + } + + column = 0 + breakPart = null + i = lineStartPart + continue + } + + column += part.width + + if (isWhitespace(part.segment)) { + breakPart = i + 1 + } + + i += 1 + + if (column >= width && i < parts.length) { + const next = parts[i]! + const nextStartsWord = !isWhitespace(next.segment) + + if (breakPart !== null && breakPart > lineStartPart && nextStartsWord) { + const breakOffset = sourceLineStart + parts[breakPart - 1]!.end + lines.push({ start: lineStartOffset, end: breakOffset }) + lineStartPart = breakPart + lineStartOffset = breakOffset + column = 0 + breakPart = null + i = lineStartPart + } + } + } + + lines.push({ start: lineStartOffset, end: sourceLineStart + sourceLine.length }) + sourceLineStart += sourceLine.length + 1 + } + + return lines.length ? 
lines : [{ start: 0, end: 0 }] +} + +function widthBetween(value: string, start: number, end: number) { + let width = 0 + + for (const part of graphemes(value.slice(start, end))) { + width += part.width + } + + return width +} + /** - * Mirrors the char-wrap behavior used by the composer TextInput. + * Mirrors the word-wrap behavior used by the composer TextInput. * Returns the zero-based visual line and column of the cursor cell. */ export function cursorLayout(value: string, cursor: number, cols: number) { const pos = Math.max(0, Math.min(cursor, value.length)) const w = Math.max(1, cols) + const lines = visualLines(value, w) + let lineIndex = 0 - let col = 0, - line = 0 - - for (const { segment, index } of seg().segment(value)) { - if (index >= pos) { + for (let i = 0; i < lines.length; i += 1) { + if (lines[i]!.start <= pos) { + lineIndex = i + } else { break } - - if (segment === '\n') { - line++ - col = 0 - - continue - } - - const sw = stringWidth(segment) - - if (!sw) { - continue - } - - if (col + sw > w) { - line++ - col = 0 - } - - col += sw } + const line = lines[lineIndex]! + let column = widthBetween(value, line.start, Math.min(pos, line.end)) + // trailing cursor-cell overflows to the next row at the wrap column - if (col >= w) { - line++ - col = 0 + if (column >= w) { + lineIndex++ + column = 0 } - return { column: col, line } + return { column, line: lineIndex } +} + +export function offsetFromPosition(value: string, row: number, col: number, cols: number) { + if (!value.length) { + return 0 + } + + const lines = visualLines(value, cols) + const target = lines[Math.max(0, Math.min(lines.length - 1, Math.floor(row)))]! 
+ const targetCol = Math.max(0, Math.floor(col)) + let column = 0 + + for (const part of graphemes(value.slice(target.start, target.end))) { + if (targetCol <= column + Math.max(0, part.width - 1)) { + return target.start + part.index + } + + column += part.width + } + + return target.end } export function inputVisualHeight(value: string, columns: number) { return cursorLayout(value, value.length, columns).line + 1 } +export function composerPromptWidth(promptText: string) { + return Math.max(1, stringWidth(promptText)) + COMPOSER_PROMPT_GAP_WIDTH +} + export function stableComposerColumns(totalCols: number, promptWidth: number) { // Physical render/wrap width. Always reserve outer composer padding and // prompt prefix. Only reserve the transcript scrollbar gutter when the diff --git a/ui-tui/src/lib/mathUnicode.ts b/ui-tui/src/lib/mathUnicode.ts new file mode 100644 index 0000000000..17af85ee03 --- /dev/null +++ b/ui-tui/src/lib/mathUnicode.ts @@ -0,0 +1,770 @@ +// Best-effort LaTeX → Unicode for inline / display math captured by the +// markdown renderer. The terminal can't typeset LaTeX, but Unicode covers +// most of what models actually emit: Greek letters, blackboard / fraktur / +// calligraphic capitals, set theory + logic operators, common arrows, +// sub/superscripts, and `\frac{a}{b}` collapsed to `a/b`. +// +// Design rules: +// • Pure regex pipeline. Anything we don't recognise is preserved +// verbatim (so a `\foo{bar}` we've never heard of still survives). +// A real LaTeX parser would be more correct but throws on partial +// input — terminal users would rather see the raw command than a +// parse-error placeholder. +// • Longest-match-first ordering on commands so `\le` doesn't shadow +// `\leq`, `\sub` doesn't shadow `\subseteq`, etc. +// • Word-boundary lookahead `(?![A-Za-z])` after each command so +// `\pix` (made-up command) doesn't get partially substituted as `π`. 
+// • `\mathbb{X}`, `\mathcal{X}`, `\mathfrak{X}` only handle a single +// letter argument — multi-letter `\mathbb{NN}` is rare and would +// need a real parser to do correctly. +// • Sub/super scripts only convert if EVERY character has a Unicode +// equivalent. Mixed content like `^{n+1}` falls back to the raw +// LaTeX so we don't emit `ⁿ+¹` (which has no `+` superscript glyph +// in some fonts and reads worse than the source). + +const SYMBOLS: Record = { + // Greek lowercase + '\\alpha': 'α', + '\\beta': 'β', + '\\gamma': 'γ', + '\\delta': 'δ', + '\\epsilon': 'ε', + '\\varepsilon': 'ε', + '\\zeta': 'ζ', + '\\eta': 'η', + '\\theta': 'θ', + '\\vartheta': 'ϑ', + '\\iota': 'ι', + '\\kappa': 'κ', + '\\lambda': 'λ', + '\\mu': 'μ', + '\\nu': 'ν', + '\\xi': 'ξ', + '\\pi': 'π', + '\\varpi': 'ϖ', + '\\rho': 'ρ', + '\\varrho': 'ϱ', + '\\sigma': 'σ', + '\\varsigma': 'ς', + '\\tau': 'τ', + '\\upsilon': 'υ', + '\\phi': 'φ', + '\\varphi': 'φ', + '\\chi': 'χ', + '\\psi': 'ψ', + '\\omega': 'ω', + + // Greek uppercase + '\\Gamma': 'Γ', + '\\Delta': 'Δ', + '\\Theta': 'Θ', + '\\Lambda': 'Λ', + '\\Xi': 'Ξ', + '\\Pi': 'Π', + '\\Sigma': 'Σ', + '\\Upsilon': 'Υ', + '\\Phi': 'Φ', + '\\Psi': 'Ψ', + '\\Omega': 'Ω', + + // Big operators + '\\sum': '∑', + '\\prod': '∏', + '\\coprod': '∐', + '\\int': '∫', + '\\iint': '∬', + '\\iiint': '∭', + '\\oint': '∮', + '\\bigcup': '⋃', + '\\bigcap': '⋂', + '\\bigvee': '⋁', + '\\bigwedge': '⋀', + '\\bigoplus': '⨁', + '\\bigotimes': '⨂', + + // Calculus + '\\partial': '∂', + '\\nabla': '∇', + '\\sqrt': '√', + + // Sets + '\\emptyset': '∅', + '\\varnothing': '∅', + '\\infty': '∞', + '\\in': '∈', + '\\notin': '∉', + '\\ni': '∋', + '\\subset': '⊂', + '\\supset': '⊃', + '\\subseteq': '⊆', + '\\supseteq': '⊇', + '\\subsetneq': '⊊', + '\\supsetneq': '⊋', + '\\cup': '∪', + '\\cap': '∩', + '\\setminus': '∖', + '\\complement': '∁', + + // Logic + '\\forall': '∀', + '\\exists': '∃', + '\\nexists': '∄', + '\\land': '∧', + '\\lor': '∨', + '\\lnot': '¬', + '\\neg': 
'¬', + '\\therefore': '∴', + '\\because': '∵', + + // Relations + '\\le': '≤', + '\\leq': '≤', + '\\ge': '≥', + '\\geq': '≥', + '\\ne': '≠', + '\\neq': '≠', + '\\ll': '≪', + '\\gg': '≫', + '\\approx': '≈', + '\\equiv': '≡', + '\\cong': '≅', + '\\sim': '∼', + '\\simeq': '≃', + '\\propto': '∝', + '\\perp': '⊥', + '\\parallel': '∥', + '\\models': '⊨', + '\\vdash': '⊢', + '\\mid': '∣', + '\\nmid': '∤', + '\\divides': '∣', + + // Common standalone glyphs + '\\blacksquare': '■', + '\\square': '□', + '\\Box': '□', + '\\qed': '∎', + '\\bigstar': '★', + + // Modular arithmetic — the `\pmod{p}` form (with arg) is handled below; + // the bare `\bmod` / `\mod` commands are simple text substitutions. + '\\bmod': 'mod', + '\\mod': 'mod', + + // Brackets / fences (named delimiter commands; the `\left\X` / `\right\X` + // unwrapping below leaves these behind for the symbol pass to resolve). + '\\langle': '⟨', + '\\rangle': '⟩', + '\\lceil': '⌈', + '\\rceil': '⌉', + '\\lfloor': '⌊', + '\\rfloor': '⌋', + '\\|': '‖', + + // Arrows + '\\to': '→', + '\\rightarrow': '→', + '\\leftarrow': '←', + '\\leftrightarrow': '↔', + '\\Rightarrow': '⇒', + '\\Leftarrow': '⇐', + '\\Leftrightarrow': '⇔', + '\\implies': '⟹', + '\\impliedby': '⟸', + '\\iff': '⟺', + '\\mapsto': '↦', + '\\hookrightarrow': '↪', + '\\hookleftarrow': '↩', + '\\uparrow': '↑', + '\\downarrow': '↓', + '\\updownarrow': '↕', + + // Binary operators + '\\cdot': '⋅', + '\\cdots': '⋯', + '\\ldots': '…', + '\\dots': '…', + '\\dotsb': '…', + '\\dotsc': '…', + '\\vdots': '⋮', + '\\ddots': '⋱', + '\\times': '×', + '\\div': '÷', + '\\pm': '±', + '\\mp': '∓', + '\\circ': '∘', + '\\bullet': '•', + '\\star': '⋆', + '\\ast': '∗', + '\\oplus': '⊕', + '\\ominus': '⊖', + '\\otimes': '⊗', + '\\odot': '⊙', + '\\diamond': '⋄', + '\\angle': '∠', + '\\triangle': '△', + + // Spacing — collapse to varying widths of regular space + '\\,': ' ', + '\\;': ' ', + '\\:': ' ', + '\\!': '', + '\\ ': ' ', + '\\quad': ' ', + '\\qquad': ' ', + + // Functions 
(LaTeX renders these in roman; we just keep the name) + '\\sin': 'sin', + '\\cos': 'cos', + '\\tan': 'tan', + '\\cot': 'cot', + '\\sec': 'sec', + '\\csc': 'csc', + '\\arcsin': 'arcsin', + '\\arccos': 'arccos', + '\\arctan': 'arctan', + '\\sinh': 'sinh', + '\\cosh': 'cosh', + '\\tanh': 'tanh', + '\\log': 'log', + '\\ln': 'ln', + '\\exp': 'exp', + '\\det': 'det', + '\\dim': 'dim', + '\\ker': 'ker', + '\\lim': 'lim', + '\\liminf': 'liminf', + '\\limsup': 'limsup', + '\\sup': 'sup', + '\\inf': 'inf', + '\\max': 'max', + '\\min': 'min', + '\\arg': 'arg', + '\\gcd': 'gcd', + + // Escaped literals — model occasionally emits these for display + '\\&': '&', + '\\%': '%', + '\\$': '$', + '\\#': '#', + '\\_': '_', + '\\{': '{', + '\\}': '}' +} + +const BB: Record = { + A: '𝔸', + B: '𝔹', + C: 'ℂ', + D: '𝔻', + E: '𝔼', + F: '𝔽', + G: '𝔾', + H: 'ℍ', + I: '𝕀', + J: '𝕁', + K: '𝕂', + L: '𝕃', + M: '𝕄', + N: 'ℕ', + O: '𝕆', + P: 'ℙ', + Q: 'ℚ', + R: 'ℝ', + S: '𝕊', + T: '𝕋', + U: '𝕌', + V: '𝕍', + W: '𝕎', + X: '𝕏', + Y: '𝕐', + Z: 'ℤ' +} + +const CAL: Record = { + A: '𝒜', + B: 'ℬ', + C: '𝒞', + D: '𝒟', + E: 'ℰ', + F: 'ℱ', + G: '𝒢', + H: 'ℋ', + I: 'ℐ', + J: '𝒥', + K: '𝒦', + L: 'ℒ', + M: 'ℳ', + N: '𝒩', + O: '𝒪', + P: '𝒫', + Q: '𝒬', + R: 'ℛ', + S: '𝒮', + T: '𝒯', + U: '𝒰', + V: '𝒱', + W: '𝒲', + X: '𝒳', + Y: '𝒴', + Z: '𝒵' +} + +const FRAK: Record = { + A: '𝔄', + B: '𝔅', + C: 'ℭ', + D: '𝔇', + E: '𝔈', + F: '𝔉', + G: '𝔊', + H: 'ℌ', + I: 'ℑ', + J: '𝔍', + K: '𝔎', + L: '𝔏', + M: '𝔐', + N: '𝔑', + O: '𝔒', + P: '𝔓', + Q: '𝔔', + R: 'ℜ', + S: '𝔖', + T: '𝔗', + U: '𝔘', + V: '𝔙', + W: '𝔚', + X: '𝔛', + Y: '𝔜', + Z: 'ℨ' +} + +const SUPERSCRIPT: Record = { + '0': '⁰', + '1': '¹', + '2': '²', + '3': '³', + '4': '⁴', + '5': '⁵', + '6': '⁶', + '7': '⁷', + '8': '⁸', + '9': '⁹', + '+': '⁺', + '-': '⁻', + '=': '⁼', + '(': '⁽', + ')': '⁾', + a: 'ᵃ', + b: 'ᵇ', + c: 'ᶜ', + d: 'ᵈ', + e: 'ᵉ', + f: 'ᶠ', + g: 'ᵍ', + h: 'ʰ', + i: 'ⁱ', + j: 'ʲ', + k: 'ᵏ', + l: 'ˡ', + m: 'ᵐ', + n: 'ⁿ', + o: 'ᵒ', + p: 'ᵖ', + r: 'ʳ', + s: 'ˢ', + 
t: 'ᵗ', + u: 'ᵘ', + v: 'ᵛ', + w: 'ʷ', + x: 'ˣ', + y: 'ʸ', + z: 'ᶻ' +} + +const SUBSCRIPT: Record = { + '0': '₀', + '1': '₁', + '2': '₂', + '3': '₃', + '4': '₄', + '5': '₅', + '6': '₆', + '7': '₇', + '8': '₈', + '9': '₉', + '+': '₊', + '-': '₋', + '=': '₌', + '(': '₍', + ')': '₎', + a: 'ₐ', + e: 'ₑ', + h: 'ₕ', + i: 'ᵢ', + j: 'ⱼ', + k: 'ₖ', + l: 'ₗ', + m: 'ₘ', + n: 'ₙ', + o: 'ₒ', + p: 'ₚ', + r: 'ᵣ', + s: 'ₛ', + t: 'ₜ', + u: 'ᵤ', + v: 'ᵥ', + x: 'ₓ' +} + +// Sentinel control characters used to mark `\boxed` / `\fbox` regions in +// the converted output. The renderer splits on these to apply a highlight +// style; consumers that don't want highlighting can strip them with the +// exported `BOX_RE` below. +export const BOX_OPEN = '\u0001' +export const BOX_CLOSE = '\u0002' +export const BOX_RE = /\u0001([^\u0001\u0002]*)\u0002/g + +const escapeRe = (s: string) => s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') + +// Pre-compile two symbol regexes: one for letter-ending commands (`\pi`, +// `\sum`) which need a `(?![A-Za-z])` lookahead so they don't partially +// match `\pix` or `\summa`, and one for punctuation-ending commands +// (`\{`, `\,`, `\|`) which must NOT have the lookahead — otherwise +// `\{p` would refuse to substitute because `p` is a letter. +// +// Longest commands first inside each group so `\leq` beats `\le`. 
// Partition command names by their final character: commands ending in a
// letter need a trailing `(?![A-Za-z])` guard, punctuation-ending ones
// must not have it.
const splitByEnding = (keys: string[]) => {
  const letter: string[] = []
  const punct: string[] = []

  for (const key of keys) {
    const bucket = /[A-Za-z]$/.test(key) ? letter : punct

    bucket.push(key)
  }

  return { letter, punct }
}

// Build a regex alternation from command names, longest first so that
// `\leq` is tried before `\le`. Sorts a copy: `Array.prototype.sort` works
// in place and would otherwise reorder the caller's array.
const buildAlt = (cmds: string[]) =>
  [...cmds]
    .sort((a, b) => b.length - a.length)
    .map(escapeRe)
    .join('|')

const { letter: LETTER_CMDS, punct: PUNCT_CMDS } = splitByEnding(Object.keys(SYMBOLS))

const SYMBOL_LETTER_RE = new RegExp('(?:' + buildAlt(LETTER_CMDS) + ')(?![A-Za-z])', 'g')
const SYMBOL_PUNCT_RE = new RegExp('(?:' + buildAlt(PUNCT_CMDS) + ')', 'g')

// Convert the body of a `^{...}` / `_{...}` script using `table`
// (SUPERSCRIPT or SUBSCRIPT). The conversion is all-or-nothing: if any
// character lacks an entry in `table`, nothing is converted and a readable
// fallback prefixed with `sigil` is returned instead.
const convertScript = (input: string, table: Record<string, string>, sigil: '^' | '_'): string => {
  let converted = ''
  let allMapped = true

  for (const ch of input) {
    const mapped = table[ch]

    if (!mapped) {
      allMapped = false

      break
    }

    converted += mapped
  }

  if (allMapped) {
    return converted
  }

  // Fallback: if the body is a single visible character (e.g. `∞` after
  // earlier symbol substitution), render it without braces — `^∞` reads
  // far better than `^{∞}` in a terminal. Multi-char bodies that don't
  // fully convert use parens (`e^(iπ)`) instead of braces (`e^{iπ}`)
  // because parens are normal punctuation while braces look like
  // unrendered LaTeX.
  const trimmed = input.trim()

  // Spread counts code points, so an astral-plane glyph is still "one".
  if ([...trimmed].length === 1) {
    return `${sigil}${trimmed}`
  }

  return `${sigil}(${trimmed})`
}

// Walk the string and parse `{...}` honouring nested braces. Unlike a
// `\{[^{}]*\}` regex this survives `\frac{|t|^{p-1}|P(t)|^p}{...}` where
// the numerator contains its own braces from a superscript. Returns the
// inner content (without the outer braces) and the offset just past the
// closing `}`. Returns null if there is no balanced brace at `start`.
const readBraced = (s: string, start: number): { content: string; end: number } | null => {
  if (s[start] !== '{') {
    return null
  }

  let depth = 1
  let i = start + 1

  while (i < s.length) {
    const c = s[i]

    // Skip escapes — `\{` and `\}` inside a body are literal braces and
    // should not change the brace counter.
    if (c === '\\') {
      i += 2
      continue
    }

    if (c === '{') {
      depth++
    } else if (c === '}' && --depth === 0) {
      // `i` is the matching close brace; `end` points just past it.
      return { content: s.slice(start + 1, i), end: i + 1 }
    }

    i++
  }

  // Ran off the end with braces still open: unbalanced input.
  return null
}

// Replace every occurrence of `\command{arg}` using balanced-brace parsing
// (so `\boxed{x^{n+1}}` works where a `[^{}]*` regex would fail). Nested
// occurrences inside the argument are rewritten first, so `render` always
// receives content that has already been processed: in `\boxed{\boxed{x}}`
// the inner box is rendered before the outer one. A `\command` with no
// following `{...}` is preserved verbatim, and scanning resumes right after
// it so that an immediately following command is still recognised.
const replaceBracedCommand = (input: string, command: string, render: (content: string) => string): string => {
  const cmdLen = command.length

  // An empty command matches at every offset and would never advance.
  if (cmdLen === 0) {
    return input
  }

  let out = ''
  let i = 0

  while (i < input.length) {
    const idx = input.indexOf(command, i)

    if (idx < 0) {
      break
    }

    const after = input[idx + cmdLen]

    // Equivalent of a `(?![A-Za-z])` lookahead: `\boxedfoo` is a different
    // command and must be left alone.
    if (after && /[A-Za-z]/.test(after)) {
      out += input.slice(i, idx + cmdLen)
      i = idx + cmdLen
      continue
    }

    out += input.slice(i, idx)

    let p = idx + cmdLen

    while (input[p] === ' ' || input[p] === '\t') p++

    const arg = readBraced(input, p)

    if (!arg) {
      // No balanced argument: keep the command (and the whitespace we
      // skipped) as-is and resume AT `p`, not past it — the character at
      // `p` may be the backslash of the next command.
      out += input.slice(idx, p)
      i = p
      continue
    }

    out += render(replaceBracedCommand(arg.content, command, render))
    i = arg.end
  }

  return out + input.slice(i)
}

// Replace every `\frac{num}{den}` with `num/den` (parens around either
// side when its precedence demands it). The recursion handles nested
// fractions naturally: `\frac{1}{\frac{1}{x}}` collapses to `1/(1/x)`
// because we recurse into `den` before deciding whether to parenthesise.
// A `\frac` that is missing either braced argument is preserved verbatim.
const replaceFracs = (input: string): string => {
  const cmdLen = 5 // '\\frac'.length
  let out = ''
  let i = 0

  while (i < input.length) {
    const idx = input.indexOf('\\frac', i)

    if (idx < 0) {
      break
    }

    const after = input[idx + cmdLen]

    // `(?![A-Za-z])` — protect hypothetical commands like `\fraction`.
    if (after && /[A-Za-z]/.test(after)) {
      out += input.slice(i, idx + cmdLen)
      i = idx + cmdLen
      continue
    }

    out += input.slice(i, idx)

    let p = idx + cmdLen

    while (input[p] === ' ' || input[p] === '\t') p++

    const num = readBraced(input, p)

    if (!num) {
      // Resume at `p` so a command starting there is not split in half.
      out += input.slice(idx, p)
      i = p
      continue
    }

    p = num.end

    while (input[p] === ' ' || input[p] === '\t') p++

    const den = readBraced(input, p)

    if (!den) {
      // Numerator without denominator: emit `\frac{num}` untouched and
      // keep scanning from `p` (which may start another `\frac`).
      out += input.slice(idx, p)
      i = p
      continue
    }

    out += `${wrapForFrac(replaceFracs(num.content))}/${wrapForFrac(replaceFracs(den.content))}`
    i = den.end
  }

  return out + input.slice(i)
}

// Wrap multi-token expressions in parens so `\frac{a+b}{c}` becomes
// `(a+b)/c` rather than `a+b/c`. We wrap whenever inline `/` would
// change the meaning — that's any binary operator (`+`, `-`, `*`, `/`)
// or whitespace separating tokens. `*` and `/` matter because nested
// fractions and products like `\frac{a*b}{c}` and `\frac{1/x}{y}` would
// otherwise read as `a*b/c` (right-associative ambiguity) and `1/x/y`.
// Atomic factors like `n!`, `x^2`, `\sin x` don't trigger any of these
// and stay un-parenthesised — wrapping them just clutters the output.
+const wrapForFrac = (expr: string) => { + const trimmed = expr.trim() + + if (!trimmed) { + return trimmed + } + + if (/^\(.*\)$/.test(trimmed)) { + return trimmed + } + + if (/[+\-/*]|\s/.test(trimmed)) { + return `(${trimmed})` + } + + return trimmed +} + +export function texToUnicode(input: string): string { + let s = input + + s = s.replace(/\\mathbb\s*\{([A-Za-z])\}/g, (raw, c: string) => BB[c] ?? raw) + s = s.replace(/\\mathcal\s*\{([A-Za-z])\}/g, (raw, c: string) => CAL[c] ?? raw) + s = s.replace(/\\mathfrak\s*\{([A-Za-z])\}/g, (raw, c: string) => FRAK[c] ?? raw) + s = s.replace(/\\mathbf\s*\{([^{}]+)\}/g, (_, c: string) => c) + s = s.replace(/\\mathit\s*\{([^{}]+)\}/g, (_, c: string) => c) + s = s.replace(/\\mathrm\s*\{([^{}]+)\}/g, (_, c: string) => c) + s = s.replace(/\\text\s*\{([^{}]+)\}/g, (_, c: string) => c) + s = s.replace(/\\operatorname\s*\{([^{}]+)\}/g, (_, c: string) => c) + + s = s.replace(/\\overline\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u0305`) + s = s.replace(/\\hat\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u0302`) + s = s.replace(/\\bar\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u0304`) + s = s.replace(/\\tilde\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u0303`) + s = s.replace(/\\vec\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u20D7`) + s = s.replace(/\\dot\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u0307`) + s = s.replace(/\\ddot\s*\{([^{}]+)\}/g, (_, c: string) => `${c}\u0308`) + + s = replaceFracs(s) + + // `\boxed{X}` / `\fbox{X}` highlight a final answer. Terminals can't + // draw a real box, so we wrap the content in U+0001 / U+0002 control + // characters — non-printable, never present in real text — and let the + // markdown renderer split on them and apply a highlight style (inverse + // video) to the bracketed region. This keeps `texToUnicode` pure-string + // while letting the React layer do the actual visual emphasis. 
+ // Argument is parsed with balanced braces so nested `{...}` from + // superscripts / fractions inside the box survive. + s = replaceBracedCommand(s, '\\boxed', body => `${BOX_OPEN}${body.trim()}${BOX_CLOSE}`) + s = replaceBracedCommand(s, '\\fbox', body => `${BOX_OPEN}${body.trim()}${BOX_CLOSE}`) + + // `\xrightarrow{label}` / `\xleftarrow{label}` collapse to an arrow with + // the label inline. LaTeX renders the label above the arrow; in monospace + // we put it adjacent — `─label→` is the closest readable approximation. + // Run before the symbol pass so the label can still pick up Greek and + // operator substitutions afterwards. + s = s.replace(/\\xrightarrow\s*\{([^{}]*)\}/g, (_, label: string) => `─${label.trim()}→`) + s = s.replace(/\\xleftarrow\s*\{([^{}]*)\}/g, (_, label: string) => `←${label.trim()}─`) + s = s.replace(/\\Longrightarrow/g, '⟹') + s = s.replace(/\\Longleftarrow/g, '⟸') + s = s.replace(/\\Longleftrightarrow/g, '⟺') + + // `\pmod{p}` → ` (mod p)` (LaTeX adds parens automatically); `\pod{p}` + // is a paren-less variant; `\tag{n}` is the equation-number annotation + // shown to the right of an equation. Collapse to a single-space-prefixed + // bracketed form. The leading `\s*` in the pattern absorbs any whitespace + // already in the source so we don't end up with `b (mod p)` (double + // space) when the user wrote `b \pmod{p}`. + s = s.replace(/\s*\\pmod\s*\{([^{}]*)\}/g, (_, p: string) => ` (mod ${p.trim()})`) + s = s.replace(/\s*\\pod\s*\{([^{}]*)\}/g, (_, p: string) => ` (${p.trim()})`) + s = s.replace(/\s*\\tag\s*\{([^{}]*)\}/g, (_, n: string) => ` (${n.trim()})`) + + // `\big`, `\Big`, `\bigg`, `\Bigg` (with optional `l`/`r`/`m` suffix) + // are sizing wrappers analogous to `\left`/`\right` but without the + // automatic-pairing semantics. Strip them and leave whatever delimiter + // follows. The trailing `(?![A-Za-z])` protects `\bigtriangleup` and + // any other letter-continuation command from being shaved. 
+ s = s.replace(/\\(?:Bigg|bigg|Big|big)[lrm]?(?![A-Za-z])/g, '') + + // Style / size hints that don't typeset any glyph and only affect how + // things would be sized in a real LaTeX engine. In a terminal every + // glyph is one monospace cell, so there's nothing to do — drop them + // (with any trailing whitespace) so they don't leak through as raw + // `\displaystyle` in the output. + s = s.replace(/\\(?:scriptscriptstyle|displaystyle|scriptstyle|textstyle|nolimits|limits)(?![A-Za-z])\s*/g, '') + + // `\left` and `\right` are sizing wrappers around any delimiter — bare + // (`\left(`), escaped (`\left\{`), or named (`\left\langle`). Strip the + // wrapper unconditionally and let the rest of the pipeline (or the + // upcoming symbol pass) handle whatever delimiter follows. The optional + // `.?` consumes `\left.` / `\right.` which mean "no delimiter". + // Lookahead `(?![A-Za-z])` keeps `\leftarrow` / `\leftrightarrow` safe. + s = s.replace(/\\left(?![A-Za-z])\.?/g, '') + s = s.replace(/\\right(?![A-Za-z])\.?/g, '') + + // Run symbol substitution BEFORE scripts so a body like `^{\infty}` + // becomes `^{∞}` first; convertScript can then either map ∞ to a + // superscript (it can't — Unicode lacks one) or fall back to `^∞` + // by stripping braces around the now-single-character body. + // + // Punctuation pass first — these can be followed by letters (`\{p` + // is "open-brace then p"), so the letter pass's `(?![A-Za-z])` rule + // would wrongly block them. + s = s.replace(SYMBOL_PUNCT_RE, m => SYMBOLS[m] ?? m) + s = s.replace(SYMBOL_LETTER_RE, m => SYMBOLS[m] ?? m) + + // Bare `^c` / `_c` handles ONLY alphanumerics and `+`/`-`/`=`. Parens + // are intentionally excluded because the braced-fallback above can + // emit `(...)` and we don't want a second pass to greedily convert + // its opening paren into `⁽` and orphan the closing one. 
+ s = s.replace(/\^\s*\{([^{}]+)\}/g, (_, body: string) => convertScript(body, SUPERSCRIPT, '^')) + s = s.replace(/\^([A-Za-z0-9+\-=])/g, (raw, ch: string) => SUPERSCRIPT[ch] ?? raw) + s = s.replace(/_\s*\{([^{}]+)\}/g, (_, body: string) => convertScript(body, SUBSCRIPT, '_')) + s = s.replace(/_([A-Za-z0-9+\-=])/g, (raw, ch: string) => SUBSCRIPT[ch] ?? raw) + + return s +} diff --git a/ui-tui/src/lib/memoryMonitor.ts b/ui-tui/src/lib/memoryMonitor.ts index bbdb229705..eaf11574a4 100644 --- a/ui-tui/src/lib/memoryMonitor.ts +++ b/ui-tui/src/lib/memoryMonitor.ts @@ -1,5 +1,3 @@ -import { evictInkCaches } from '@hermes/ink' - import { type HeapDumpResult, performHeapDump } from './memory.js' export type MemoryLevel = 'critical' | 'high' | 'normal' @@ -20,6 +18,40 @@ export interface MemoryMonitorOptions { const GB = 1024 ** 3 +// Deferred @hermes/ink import: loading `@hermes/ink` at module top-level +// pulls the full ~414KB Ink bundle (React, renderer, components, hooks) onto +// the critical path before the Python gateway can even be spawned. That +// serialised roughly 150ms of Node work in front of gw.start() on every +// cold `hermes --tui` launch. +// +// evictInkCaches only runs inside `tick()`, which fires on a 10s timer and +// only when heap pressure crosses the high-water mark — by then Ink has +// long since been loaded by the app entry. This dynamic import is a no-op +// on the hot path (module is already in the ESM cache); when a startup +// spike somehow trips the threshold before the app registers its own Ink +// import, we pay the load cost exactly once, inside the tick that needs it. 
+let _evictInkCaches: ((level: 'all' | 'half') => unknown) | null = null +let _evictInkCachesPromise: Promise<(level: 'all' | 'half') => unknown> | null = null + +async function _ensureEvictInkCaches(): Promise<(level: 'all' | 'half') => unknown> { + if (_evictInkCaches) { + return _evictInkCaches + } + + _evictInkCachesPromise ??= import('@hermes/ink') + .then(mod => { + _evictInkCaches = mod.evictInkCaches as (level: 'all' | 'half') => unknown + + return _evictInkCaches + }) + .catch(err => { + _evictInkCachesPromise = null + throw err + }) + + return _evictInkCachesPromise +} + export function startMemoryMonitor({ criticalBytes = 2.5 * GB, highBytes = 1.5 * GB, @@ -28,29 +60,45 @@ export function startMemoryMonitor({ onHigh }: MemoryMonitorOptions = {}): () => void { const dumped = new Set>() + const inFlight = new Set>() const tick = async () => { const { heapUsed, rss } = process.memoryUsage() const level: MemoryLevel = heapUsed >= criticalBytes ? 'critical' : heapUsed >= highBytes ? 'high' : 'normal' if (level === 'normal') { - return void dumped.clear() - } - - if (dumped.has(level)) { + dumped.clear() return } + if (dumped.has(level) || inFlight.has(level)) { + return + } + + inFlight.add(level) + // Prune Ink content caches before dump/exit — half on 'high' (recoverable), // full on 'critical' (post-dump RSS reduction, keeps user running). - evictInkCaches(level === 'critical' ? 'all' : 'half') + // Deferred import keeps `@hermes/ink` off the cold-start critical path; + // by the time a tick fires 10s after launch the app has already loaded + // the same module, so this resolves instantly from the ESM cache. + try { + try { + const evictInkCaches = await _ensureEvictInkCaches() + evictInkCaches(level === 'critical' ? 'all' : 'half') + } catch { + // Best-effort: if the dynamic import fails for any reason we still + // continue to the heap dump below so the user gets diagnostics. 
+ } - dumped.add(level) - const dump = await performHeapDump(level === 'critical' ? 'auto-critical' : 'auto-high').catch(() => null) + dumped.add(level) + const dump = await performHeapDump(level === 'critical' ? 'auto-critical' : 'auto-high').catch(() => null) + const snap: MemorySnapshot = { heapUsed, level, rss } - const snap: MemorySnapshot = { heapUsed, level, rss } - - ;(level === 'critical' ? onCritical : onHigh)?.(snap, dump) + ;(level === 'critical' ? onCritical : onHigh)?.(snap, dump) + } finally { + inFlight.delete(level) + } } const handle = setInterval(() => void tick(), intervalMs) diff --git a/ui-tui/src/theme.ts b/ui-tui/src/theme.ts index e14b8d2a52..2a55709036 100644 --- a/ui-tui/src/theme.ts +++ b/ui-tui/src/theme.ts @@ -76,6 +76,162 @@ function mix(a: string, b: string, t: number) { return '#' + ((1 << 24) | (lerp(0) << 16) | (lerp(1) << 8) | lerp(2)).toString(16).slice(1) } +const XTERM_6_LEVELS = [0, 95, 135, 175, 215, 255] as const +const ANSI_LIGHT_MAX_LUMINANCE = 0.72 +const ANSI_LIGHT_TARGET_LUMINANCE = 0.34 +const ANSI_LIGHT_MIN_SATURATION = 0.22 +const ANSI_MUTED_BUCKET = 245 + +const ANSI_NORMALIZED_FOREGROUNDS: readonly (keyof ThemeColors)[] = [ + 'text', + 'label', + 'ok', + 'error', + 'warn', + 'prompt', + 'statusFg', + 'statusGood', + 'statusWarn', + 'statusBad', + 'statusCritical', + 'shellDollar' +] + +const ANSI_MUTED_FOREGROUNDS: readonly (keyof ThemeColors)[] = ['muted', 'sessionLabel', 'sessionBorder'] + +function xtermEightBitRgb(colorNumber: number): [number, number, number] { + if (colorNumber >= 232) { + const value = 8 + (colorNumber - 232) * 10 + + return [value, value, value] + } + + if (colorNumber >= 16) { + const offset = colorNumber - 16 + + return [ + XTERM_6_LEVELS[Math.floor(offset / 36) % 6]!, + XTERM_6_LEVELS[Math.floor(offset / 6) % 6]!, + XTERM_6_LEVELS[offset % 6]! 
+ ] + } + + return [0, 0, 0] +} + +function channelLuminance(value: number): number { + const normalized = value / 255 + + return normalized <= 0.03928 ? normalized / 12.92 : ((normalized + 0.055) / 1.055) ** 2.4 +} + +function relativeLuminance(red: number, green: number, blue: number): number { + return 0.2126 * channelLuminance(red) + 0.7152 * channelLuminance(green) + 0.0722 * channelLuminance(blue) +} + +function rgbToHsl(red: number, green: number, blue: number): [number, number, number] { + const rn = red / 255 + const gn = green / 255 + const bn = blue / 255 + const max = Math.max(rn, gn, bn) + const min = Math.min(rn, gn, bn) + const lightness = (max + min) / 2 + + if (max === min) { + return [0, 0, lightness] + } + + const delta = max - min + const saturation = lightness > 0.5 ? delta / (2 - max - min) : delta / (max + min) + + const hue = + max === rn + ? (gn - bn) / delta + (gn < bn ? 6 : 0) + : max === gn + ? (bn - rn) / delta + 2 + : (rn - gn) / delta + 4 + + return [hue / 6, saturation, lightness] +} + +function circularDistance(a: number, b: number): number { + const distance = Math.abs(a - b) + + return Math.min(distance, 1 - distance) +} + +// Mirrors @hermes/ink's colorize.ts. Keep local: app code compiles from +// ui-tui/src, while @hermes/ink is bundled separately from packages/. +function richEightBitColorNumber(red: number, green: number, blue: number): number { + const [, saturation, lightness] = rgbToHsl(red, green, blue) + + if (saturation < 0.15) { + const gray = Math.round(lightness * 25) + + return gray === 0 ? 16 : gray === 25 ? 231 : 231 + gray + } + + const sixRed = red < 95 ? red / 95 : 1 + (red - 95) / 40 + const sixGreen = green < 95 ? green / 95 : 1 + (green - 95) / 40 + const sixBlue = blue < 95 ? 
blue / 95 : 1 + (blue - 95) / 40 + + return 16 + 36 * Math.round(sixRed) + 6 * Math.round(sixGreen) + Math.round(sixBlue) +} + +function bestReadableAnsiColor(red: number, green: number, blue: number): number { + const [hue, saturation, lightness] = rgbToHsl(red, green, blue) + let bestColor = richEightBitColorNumber(red, green, blue) + let bestScore = Number.POSITIVE_INFINITY + + for (let colorNumber = 16; colorNumber <= 255; colorNumber += 1) { + const [candidateRed, candidateGreen, candidateBlue] = xtermEightBitRgb(colorNumber) + const candidateLuminance = relativeLuminance(candidateRed, candidateGreen, candidateBlue) + + if (candidateLuminance > ANSI_LIGHT_MAX_LUMINANCE) { + continue + } + + const [candidateHue, candidateSaturation, candidateLightness] = rgbToHsl( + candidateRed, + candidateGreen, + candidateBlue + ) + + const saturationFloorPenalty = + candidateSaturation < ANSI_LIGHT_MIN_SATURATION ? (ANSI_LIGHT_MIN_SATURATION - candidateSaturation) * 3 : 0 + + const score = + circularDistance(candidateHue, hue) * 4 + + Math.abs(candidateSaturation - Math.max(ANSI_LIGHT_MIN_SATURATION, saturation)) * 0.8 + + Math.abs(candidateLightness - Math.min(lightness, ANSI_LIGHT_TARGET_LUMINANCE)) * 2 + + saturationFloorPenalty + + if (score < bestScore) { + bestColor = colorNumber + bestScore = score + } + } + + return bestColor +} + +function normalizeAnsiForeground(color: string): string { + const rgb = parseHex(color) + + if (!rgb) { + return color + } + + const richAnsi = richEightBitColorNumber(rgb[0], rgb[1], rgb[2]) + const richRgb = xtermEightBitRgb(richAnsi) + + const ansi = relativeLuminance(richRgb[0], richRgb[1], richRgb[2]) > ANSI_LIGHT_MAX_LUMINANCE + ? 
bestReadableAnsiColor(rgb[0], rgb[1], rgb[2]) + : richAnsi + + return `ansi256(${ansi})` +} + // ── Defaults ───────────────────────────────────────────────────────── const BRAND: ThemeBrand = { @@ -190,12 +346,11 @@ export const LIGHT_THEME: Theme = { const TRUE_RE = /^(?:1|true|yes|on)$/ const FALSE_RE = /^(?:0|false|no|off)$/ -// Reserved for future TERM_PROGRAM-based heuristics. Empty by default: -// most modern terminals (Ghostty, Warp, iTerm2, Apple_Terminal) ship a -// dark profile out of the box, so guessing wrong here is more annoying -// than missing a light user — light users can always set -// `HERMES_TUI_LIGHT=1` or `HERMES_TUI_THEME=light`. -const LIGHT_DEFAULT_TERM_PROGRAMS = new Set() +// TERM_PROGRAM fallback allow-list for terminals whose default profile is +// light and which may not expose COLORFGBG. This currently includes Apple +// Terminal. Explicit HERMES_TUI_THEME / COLORFGBG signals above still win, +// so dark Apple Terminal profiles that advertise a dark background stay dark. +const LIGHT_DEFAULT_TERM_PROGRAMS = new Set(['Apple_Terminal']) // Best-effort RGB → luminance check. Currently only accepts a 3- or // 6-digit hex value (with or without a leading `#`); the env var name @@ -247,7 +402,7 @@ function backgroundLuminance(raw: string): null | number { // slot 7 or 15 on light profiles; 0–15 ranges are otherwise // treated as authoritatively dark so the TERM_PROGRAM // allow-list below cannot override an explicit dark profile. -// 5. `TERM_PROGRAM` light-default allow-list (currently empty). +// 5. `TERM_PROGRAM` light-default allow-list. // // Anything we can't decide stays dark — the default Hermes palette // is the dark one. @@ -313,7 +468,42 @@ export function detectLightMode( return lightDefaultTermPrograms.has(termProgram) } -export const DEFAULT_THEME: Theme = detectLightMode() ? 
LIGHT_THEME : DARK_THEME +function shouldNormalizeAnsiLightTheme(env: NodeJS.ProcessEnv = process.env, isLight = detectLightMode(env)): boolean { + const colorTerm = (env.COLORTERM ?? '').trim().toLowerCase() + const termProgram = (env.TERM_PROGRAM ?? '').trim() + + return termProgram === 'Apple_Terminal' && colorTerm !== 'truecolor' && colorTerm !== '24bit' && isLight +} + +export function normalizeThemeForAnsiLightTerminal( + theme: Theme, + env: NodeJS.ProcessEnv = process.env, + isLight = detectLightMode(env) +): Theme { + if (!shouldNormalizeAnsiLightTheme(env, isLight)) { + return theme + } + + const color = { ...theme.color } + + for (const key of ANSI_NORMALIZED_FOREGROUNDS) { + color[key] = normalizeAnsiForeground(color[key]) + } + + for (const key of ANSI_MUTED_FOREGROUNDS) { + color[key] = `ansi256(${ANSI_MUTED_BUCKET})` + } + + return { ...theme, color } +} + +const DEFAULT_LIGHT_MODE = detectLightMode() + +export const DEFAULT_THEME: Theme = normalizeThemeForAnsiLightTerminal( + DEFAULT_LIGHT_MODE ? LIGHT_THEME : DARK_THEME, + process.env, + DEFAULT_LIGHT_MODE +) // ── Skin → Theme ───────────────────────────────────────────────────── @@ -333,7 +523,7 @@ export function fromSkin( const muted = c('banner_dim') ?? d.color.muted const completionBg = c('completion_menu_bg') ?? d.color.completionBg - return { + return normalizeThemeForAnsiLightTerminal({ color: { primary: c('ui_primary') ?? c('banner_title') ?? 
d.color.primary, accent, @@ -379,5 +569,5 @@ export function fromSkin( bannerLogo, bannerHero - } + }, process.env, DEFAULT_LIGHT_MODE) } diff --git a/ui-tui/src/types.ts b/ui-tui/src/types.ts index 6aea78e3e4..b3ecc8fbb6 100644 --- a/ui-tui/src/types.ts +++ b/ui-tui/src/types.ts @@ -143,11 +143,12 @@ export interface McpServerStatus { export interface SessionInfo { cwd?: string fast?: boolean + lazy?: boolean mcp_servers?: McpServerStatus[] model: string reasoning_effort?: string - service_tier?: string release_date?: string + service_tier?: string skills: Record tools: Record update_behind?: number | null diff --git a/uv.lock b/uv.lock index dfb2f786b0..93db335ce9 100644 --- a/uv.lock +++ b/uv.lock @@ -1934,6 +1934,7 @@ all = [ { name = "sounddevice" }, { name = "ty" }, { name = "uvicorn", extra = ["standard"] }, + { name = "vercel" }, ] bedrock = [ { name = "boto3" }, @@ -2025,6 +2026,9 @@ termux = [ tts-premium = [ { name = "elevenlabs" }, ] +vercel = [ + { name = "vercel" }, +] voice = [ { name = "faster-whisper" }, { name = "numpy" }, @@ -2089,6 +2093,7 @@ requires-dist = [ { name = "hermes-agent", extras = ["slack"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["sms"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["tts-premium"], marker = "extra == 'all'" }, + { name = "hermes-agent", extras = ["vercel"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["voice"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["web"], marker = "extra == 'all'" }, { name = "honcho-ai", marker = "extra == 'honcho'", specifier = ">=2.0.1,<3" }, @@ -2133,10 +2138,11 @@ requires-dist = [ { name = "ty", marker = "extra == 'dev'", specifier = ">=0.0.1a29,<0.0.22" }, { name = "uvicorn", extras = ["standard"], marker = "extra == 'rl'", specifier = ">=0.24.0,<1" }, { name = "uvicorn", extras = ["standard"], marker = "extra == 'web'", specifier = ">=0.24.0,<1" }, + { name = "vercel", marker = "extra == 'vercel'", 
specifier = ">=0.5.7,<0.6.0" }, { name = "wandb", marker = "extra == 'rl'", specifier = ">=0.15.0,<1" }, { name = "yc-bench", marker = "python_full_version >= '3.12' and extra == 'yc-bench'", git = "https://github.com/collinear-ai/yc-bench.git?rev=bfb0c88062450f46341bd9a5298903fc2e952a5c" }, ] -provides-extras = ["modal", "daytona", "dev", "messaging", "cron", "slack", "matrix", "cli", "tts-premium", "voice", "pty", "honcho", "mcp", "homeassistant", "sms", "acp", "mistral", "bedrock", "termux", "dingtalk", "feishu", "web", "rl", "yc-bench", "all"] +provides-extras = ["modal", "daytona", "vercel", "dev", "messaging", "cron", "slack", "matrix", "cli", "tts-premium", "voice", "pty", "honcho", "mcp", "homeassistant", "sms", "acp", "mistral", "bedrock", "termux", "dingtalk", "feishu", "web", "rl", "yc-bench", "all"] [[package]] name = "hf-transfer" @@ -5339,6 +5345,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e4/16/c1fd27e9549f3c4baf1dc9c20c456cd2f822dbf8de9f463824b0c0357e06/uvloop-0.22.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6cde23eeda1a25c75b2e07d39970f3374105d5eafbaab2a4482be82f272d5a5e", size = 4296730, upload-time = "2025-10-16T22:17:00.744Z" }, ] +[[package]] +name = "vercel" +version = "0.5.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "cbor2" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "vercel-workers", marker = "python_full_version >= '3.12'" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d7/68/a671ebc656afbb5e25fb88c681b61511cc13670ea771c87b2f711782022b/vercel-0.5.7.tar.gz", hash = "sha256:8070ea1b33962adfed98498f9273f24ea2066a20c74d38643d479d8280801c6e", size = 118597, upload-time = "2026-04-15T17:58:20.424Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/2e/bacf1ccc0ec95464a68398e64bf5e36f859cd51f3e379623f103802f85f1/vercel-0.5.7-py3-none-any.whl", hash 
= "sha256:90eb2689c34e403db2170fec3eb47e1a91092c200d91baf4b4501fb3e2a44d28", size = 139698, upload-time = "2026-04-15T17:58:18.945Z" }, +] + +[[package]] +name = "vercel-workers" +version = "0.0.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "python_full_version >= '3.12'" }, + { name = "httpx", marker = "python_full_version >= '3.12'" }, + { name = "python-dotenv", marker = "python_full_version >= '3.12'" }, + { name = "vercel", marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/73/d8/17ba256fceff42be231ca8ff0567dcf2da54ee8de633e949fa08b9403b1f/vercel_workers-0.0.16.tar.gz", hash = "sha256:38df45dbf42fbae39ffa0e419f0908bf1beb047e38fc5ddd0a479feac340fb8c", size = 51615, upload-time = "2026-04-13T21:23:27.649Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/3a/0137d5b157845e1d41a70130d8dce8ba15d8712f34619693cda04ecb8f02/vercel_workers-0.0.16-py3-none-any.whl", hash = "sha256:542be839e46e236a68cc308695ccc3c970d76de72c978d7f416cc6ce09688896", size = 50141, upload-time = "2026-04-13T21:23:28.652Z" }, +] + [[package]] name = "wandb" version = "0.25.1" diff --git a/web/package-lock.json b/web/package-lock.json index 2c6377b4f2..7f987c5a1d 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -76,7 +76,6 @@ "integrity": "sha512-CGOfOJqWjg2qW/Mb6zNsDm+u5vFQ8DxXfbM09z69p5Z6+mE1ikP2jUXw+j42Pf1XTYED2Rni5f95npYeuwMDQA==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@babel/code-frame": "^7.29.0", "@babel/generator": "^7.29.0", @@ -1125,7 +1124,6 @@ "resolved": "https://registry.npmjs.org/@observablehq/plot/-/plot-0.6.17.tgz", "integrity": "sha512-/qaXP/7mc4MUS0s4cPPFASDRjtsWp85/TbfsciqDgU1HwYixbSbbytNuInD8AcTYC3xaxACgVX06agdfQy9W+g==", "license": "ISC", - "peer": true, "dependencies": { "d3": "^7.9.0", "interval-tree-1d": "^1.0.0", @@ -1778,7 +1776,6 @@ "resolved": 
"https://registry.npmjs.org/@react-three/fiber/-/fiber-9.6.0.tgz", "integrity": "sha512-90abYK2q5/qDM+GACs9zRvc5KhEEpEWqWlHSd64zTPNxg+9wCJvTfyD9x2so7hlQhjRYO1Fa6flR3BC/kpTFkA==", "license": "MIT", - "peer": true, "dependencies": { "@babel/runtime": "^7.17.8", "@types/webxr": "*", @@ -2484,7 +2481,6 @@ "integrity": "sha512-A1sre26ke7HDIuY/M23nd9gfB+nrmhtYyMINbjI1zHJxYteKR6qSMX56FsmjMcDb3SMcjJg5BiRRgOCC/yBD0g==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -2494,7 +2490,6 @@ "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.14.tgz", "integrity": "sha512-ilcTH/UniCkMdtexkoCN0bI7pMcJDvmQFPvuPvmEaYA/NSfFTAgdUSLAoVjaRJm7+6PvcM+q1zYOwS4wTYMF9w==", "license": "MIT", - "peer": true, "dependencies": { "csstype": "^3.2.2" } @@ -2505,7 +2500,6 @@ "integrity": "sha512-jp2L/eY6fn+KgVVQAOqYItbF0VY/YApe5Mz2F0aykSO8gx31bYCZyvSeYxCHKvzHG5eZjc+zyaS5BrBWya2+kQ==", "devOptional": true, "license": "MIT", - "peer": true, "peerDependencies": { "@types/react": "^19.2.0" } @@ -2570,7 +2564,6 @@ "integrity": "sha512-HDQH9O/47Dxi1ceDhBXdaldtf/WV9yRYMjbjCuNk3qnaTD564qwv61Y7+gTxwxRKzSrgO5uhtw584igXVuuZkA==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.59.1", "@typescript-eslint/types": "8.59.1", @@ -2899,7 +2892,6 @@ "integrity": "sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -3052,7 +3044,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.10.12", "caniuse-lite": "^1.0.30001782", @@ -3560,7 +3551,6 @@ "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "license": "ISC", - "peer": true, "engines": { "node": ">=12" } @@ -3874,7 +3864,6 @@ "integrity": 
"sha512-XoMjdBOwe/esVgEvLmNsD3IRHkm7fbKIUGvrleloJXUZgDHig2IPWNniv+GwjyJXzuNqVjlr5+4yVUZjycJwfQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -4253,8 +4242,7 @@ "version": "3.15.0", "resolved": "https://registry.npmjs.org/gsap/-/gsap-3.15.0.tgz", "integrity": "sha512-dMW4CWBTUK1AEEDeZc1g4xpPGIrSf9fJF960qbTZmN/QwZIWY5wgliS6JWl9/25fpTGJrMRtSjGtOmPnfjZB+A==", - "license": "Standard 'no charge' license: https://gsap.com/standard-license.", - "peer": true + "license": "Standard 'no charge' license: https://gsap.com/standard-license." }, "node_modules/has-flag": { "version": "4.0.0", @@ -4560,7 +4548,6 @@ "resolved": "https://registry.npmjs.org/leva/-/leva-0.10.1.tgz", "integrity": "sha512-BcjnfUX8jpmwZUz2L7AfBtF9vn4ggTH33hmeufDULbP3YgNZ/C+ss/oO3stbrqRQyaOmRwy70y7BGTGO81S3rA==", "license": "MIT", - "peer": true, "dependencies": { "@radix-ui/react-portal": "^1.1.4", "@radix-ui/react-tooltip": "^1.1.8", @@ -4999,7 +4986,6 @@ } ], "license": "MIT", - "peer": true, "engines": { "node": "^20.0.0 || >=22.0.0" } @@ -5127,7 +5113,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.4.tgz", "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -5199,7 +5184,6 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.5.tgz", "integrity": "sha512-llUJLzz1zTUBrskt2pwZgLq59AemifIftw4aB7JxOqf1HY2FDaGDxgwpAPVzHU1kdWabH7FauP4i1oEeer2WCA==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -5219,7 +5203,6 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.5.tgz", "integrity": "sha512-J5bAZz+DXMMwW/wV3xzKke59Af6CHY7G4uYLN1OvBcKEsWOs4pQExj86BBKamxl/Ik5bx9whOrvBlSDfWzgSag==", "license": "MIT", - "peer": true, "dependencies": { "scheduler": "^0.27.0" }, @@ -5579,8 +5562,7 @@ "version": "0.180.0", 
"resolved": "https://registry.npmjs.org/three/-/three-0.180.0.tgz", "integrity": "sha512-o+qycAMZrh+TsE01GqWUxUIKR1AL0S8pq7zDkYOQw8GqfX8b8VoCKYUoHbhiX5j+7hr8XsuHDVU6+gkQJQKg9w==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/tinyglobby": { "version": "0.2.16", @@ -5645,7 +5627,6 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -5744,7 +5725,6 @@ "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", "license": "MIT", - "peer": true, "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } @@ -5760,7 +5740,6 @@ "resolved": "https://registry.npmjs.org/vite/-/vite-7.3.2.tgz", "integrity": "sha512-Bby3NOsna2jsjfLVOHKes8sGwgl4TT0E6vvpYgnAYDIF/tie7MRaFthmKuHx1NSXjiTueXH3do80FMQgvEktRg==", "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.27.0", "fdir": "^6.5.0", @@ -5882,7 +5861,6 @@ "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", "dev": true, "license": "MIT", - "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/web/src/App.tsx b/web/src/App.tsx index 9c09e1151b..b03beef8e0 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -20,6 +20,7 @@ import { BookOpen, Clock, Code, + Cpu, Database, Download, Eye, @@ -60,6 +61,7 @@ import EnvPage from "@/pages/EnvPage"; import SessionsPage from "@/pages/SessionsPage"; import LogsPage from "@/pages/LogsPage"; import AnalyticsPage from "@/pages/AnalyticsPage"; +import ModelsPage from "@/pages/ModelsPage"; import CronPage from "@/pages/CronPage"; import ProfilesPage from "@/pages/ProfilesPage"; import SkillsPage from "@/pages/SkillsPage"; @@ 
-96,6 +98,7 @@ const BUILTIN_ROUTES_CORE: Record = { "/": RootRedirect, "/sessions": SessionsPage, "/analytics": AnalyticsPage, + "/models": ModelsPage, "/logs": LogsPage, "/cron": CronPage, "/skills": SkillsPage, @@ -126,6 +129,12 @@ const BUILTIN_NAV_REST: NavItem[] = [ label: "Analytics", icon: BarChart3, }, + { + path: "/models", + labelKey: "models", + label: "Models", + icon: Cpu, + }, { path: "/logs", labelKey: "logs", label: "Logs", icon: FileText }, { path: "/cron", labelKey: "cron", label: "Cron", icon: Clock }, { path: "/skills", labelKey: "skills", label: "Skills", icon: Package }, @@ -144,6 +153,7 @@ const ICON_MAP: Record> = { Activity, BarChart3, Clock, + Cpu, FileText, KeyRound, MessageSquare, diff --git a/web/src/components/ModelPickerDialog.tsx b/web/src/components/ModelPickerDialog.tsx index 010e52a79b..d99ea09a8a 100644 --- a/web/src/components/ModelPickerDialog.tsx +++ b/web/src/components/ModelPickerDialog.tsx @@ -13,9 +13,18 @@ import { useEffect, useMemo, useRef, useState } from "react"; * Stage 1: pick provider (authenticated providers only) * Stage 2: pick model within that provider * - * On confirm, emits `/model --provider [--global]` through - * the parent callback so ChatPage can dispatch it via the existing slash - * pipeline. That keeps persistence + actual switch logic in one place. + * Two invocation modes: + * + * 1. Chat-session mode (ChatSidebar) — pass `gw` + `sessionId`. The picker + * loads options via `model.options` JSON-RPC and emits the result as a + * slash command string (`/model --provider [--global]`) + * through `onSubmit`, which the ChatPage pipes to `slashExec`. + * + * 2. Standalone mode (ModelsPage, Config settings) — pass a `loader` and + * `onApply`. The picker fetches options via the REST endpoint and calls + * `onApply(provider, model, persistGlobal)` instead of emitting a slash + * command. This lets the Models page reuse the same UI without + * requiring an open chat PTY. 
*/ interface ModelOptionProvider { @@ -34,14 +43,38 @@ interface ModelOptionsResponse { } interface Props { - gw: GatewayClient; - sessionId: string; + /** Chat-mode: when present, picker emits a slash command via onSubmit. */ + gw?: GatewayClient; + sessionId?: string; + onSubmit?(slashCommand: string): void; + + /** Standalone-mode: when present (and onSubmit absent), picker calls onApply. */ + loader?(): Promise; + onApply?(args: { + provider: string; + model: string; + persistGlobal: boolean; + }): Promise | void; + onClose(): void; - /** Parent runs the resulting slash command through slashExec. */ - onSubmit(slashCommand: string): void; + title?: string; + /** If true, hides "Persist globally" checkbox — always saves to config.yaml. */ + alwaysGlobal?: boolean; } -export function ModelPickerDialog({ gw, sessionId, onClose, onSubmit }: Props) { +export function ModelPickerDialog(props: Props) { + const { + gw, + sessionId, + onSubmit, + loader, + onApply, + onClose, + title = "Switch Model", + alwaysGlobal = false, + } = props; + const standalone = !!loader && !!onApply; + const [providers, setProviders] = useState([]); const [currentModel, setCurrentModel] = useState(""); const [currentProviderSlug, setCurrentProviderSlug] = useState(""); @@ -50,17 +83,22 @@ export function ModelPickerDialog({ gw, sessionId, onClose, onSubmit }: Props) { const [selectedSlug, setSelectedSlug] = useState(""); const [selectedModel, setSelectedModel] = useState(""); const [query, setQuery] = useState(""); - const [persistGlobal, setPersistGlobal] = useState(false); + const [persistGlobal, setPersistGlobal] = useState(alwaysGlobal); + const [applying, setApplying] = useState(false); const closedRef = useRef(false); // Load providers + models on open. useEffect(() => { closedRef.current = false; - gw.request( - "model.options", - sessionId ? { session_id: sessionId } : {}, - ) + const promise = standalone + ? 
(loader as () => Promise)() + : (gw as GatewayClient).request( + "model.options", + sessionId ? { session_id: sessionId } : {}, + ); + + promise .then((r) => { if (closedRef.current) return; const next = r?.providers ?? []; @@ -82,7 +120,9 @@ export function ModelPickerDialog({ gw, sessionId, onClose, onSubmit }: Props) { return () => { closedRef.current = true; }; - }, [gw, sessionId]); + // Deliberately omit props from deps — stable for the dialog's lifetime. + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); // Esc closes. useEffect(() => { @@ -127,15 +167,31 @@ export function ModelPickerDialog({ gw, sessionId, onClose, onSubmit }: Props) { [models, needle], ); - const canConfirm = !!selectedProvider && !!selectedModel; + const canConfirm = !!selectedProvider && !!selectedModel && !applying; - const confirm = () => { - if (!canConfirm) return; - const global = persistGlobal ? " --global" : ""; - onSubmit( - `/model ${selectedModel} --provider ${selectedProvider.slug}${global}`, - ); - onClose(); + const confirm = async () => { + if (!canConfirm || !selectedProvider) return; + if (standalone && onApply) { + setApplying(true); + try { + await onApply({ + provider: selectedProvider.slug, + model: selectedModel, + persistGlobal, + }); + onClose(); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } finally { + setApplying(false); + } + } else if (onSubmit) { + const global = persistGlobal ? " --global" : ""; + onSubmit( + `/model ${selectedModel} --provider ${selectedProvider.slug}${global}`, + ); + onClose(); + } }; return ( @@ -162,7 +218,7 @@ export function ModelPickerDialog({ gw, sessionId, onClose, onSubmit }: Props) { id="model-picker-title" className="font-display text-base tracking-wider uppercase" > - Switch Model + {title}

current: {currentModel || "(unknown)"} @@ -214,22 +270,28 @@ export function ModelPickerDialog({ gw, sessionId, onClose, onSubmit }: Props) {

- + {alwaysGlobal ? ( + + Saves to config.yaml — applies to new sessions. + + ) : ( + + )}
-
diff --git a/web/src/components/NouiTypography.tsx b/web/src/components/NouiTypography.tsx index 4f5547adb5..eb26d75cc1 100644 --- a/web/src/components/NouiTypography.tsx +++ b/web/src/components/NouiTypography.tsx @@ -14,7 +14,7 @@ type TypographyProps = HTMLAttributes & { }; const variantClasses: Record, string> = { - sm: "leading-1.4 text-[.9375rem] tracking-[0.1875rem]", + sm: "leading-[1.4] text-[.9375rem] tracking-[0.1875rem]", md: "text-[2.625rem] leading-[1] tracking-[0.0525rem]", lg: "text-[2.625rem] leading-[1] tracking-[0.0525rem]", xl: "text-[4.5rem] leading-[1] tracking-[0.135rem]", diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index ea31e4abf6..1aaabd0f63 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -74,6 +74,7 @@ export const en: Translations = { documentation: "Documentation", keys: "Keys", logs: "Logs", + models: "Models", profiles: "profiles : multi agents", sessions: "Sessions", skills: "Skills", @@ -173,6 +174,18 @@ export const en: Translations = { inOut: "{input} in / {output} out", }, + models: { + modelsUsed: "Models Used", + estimatedCost: "Est. 
Cost", + tokens: "tokens", + sessions: "sessions", + avgPerSession: "avg/session", + apiCalls: "API calls", + toolCalls: "tool calls", + noModelsData: "No model usage data for this period", + startSession: "Start a session to see model data here", + }, + logs: { title: "Logs", autoRefresh: "Auto-refresh", diff --git a/web/src/i18n/types.ts b/web/src/i18n/types.ts index 0da93a722e..bb6266a2dd 100644 --- a/web/src/i18n/types.ts +++ b/web/src/i18n/types.ts @@ -74,6 +74,7 @@ export interface Translations { documentation: string; keys: string; logs: string; + models: string; profiles: string; sessions: string; skills: string; @@ -175,6 +176,19 @@ export interface Translations { inOut: string; }; + // ── Models page ── + models: { + modelsUsed: string; + estimatedCost: string; + tokens: string; + sessions: string; + avgPerSession: string; + apiCalls: string; + toolCalls: string; + noModelsData: string; + startSession: string; + }; + // ── Logs page ── logs: { title: string; diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 8c3753e4d4..f7a7399af0 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -73,6 +73,7 @@ export const zh: Translations = { documentation: "文档", keys: "密钥", logs: "日志", + models: "模型", profiles: "多Agent配置", sessions: "会话", skills: "技能", @@ -171,6 +172,18 @@ export const zh: Translations = { inOut: "输入 {input} / 输出 {output}", }, + models: { + modelsUsed: "使用模型数", + estimatedCost: "预估费用", + tokens: "Token", + sessions: "会话", + avgPerSession: "平均/会话", + apiCalls: "API 调用", + toolCalls: "工具调用", + noModelsData: "该时间段暂无模型使用数据", + startSession: "开始会话后将在此显示模型数据", + }, + logs: { title: "日志", autoRefresh: "自动刷新", diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 5b1fa9fb24..10ed9acf89 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -63,10 +63,20 @@ export const api = { }, getAnalytics: (days: number) => fetchJSON(`/api/analytics/usage?days=${days}`), + getModelsAnalytics: (days: number) => + 
fetchJSON(`/api/analytics/models?days=${days}`), getConfig: () => fetchJSON>("/api/config"), getDefaults: () => fetchJSON>("/api/config/defaults"), getSchema: () => fetchJSON<{ fields: Record; category_order: string[] }>("/api/config/schema"), getModelInfo: () => fetchJSON("/api/model/info"), + getModelOptions: () => fetchJSON("/api/model/options"), + getAuxiliaryModels: () => fetchJSON("/api/model/auxiliary"), + setModelAssignment: (body: ModelAssignmentRequest) => + fetchJSON("/api/model/set", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }), saveConfig: (config: Record) => fetchJSON<{ ok: boolean }>("/api/config", { method: "PUT", @@ -145,6 +155,10 @@ export const api = { `/api/profiles/${encodeURIComponent(name)}`, { method: "DELETE" }, ), + getProfileSetupCommand: (name: string) => + fetchJSON<{ command: string }>( + `/api/profiles/${encodeURIComponent(name)}/setup-command`, + ), getProfileSoul: (name: string) => fetchJSON<{ content: string; exists: boolean }>( `/api/profiles/${encodeURIComponent(name)}/soul`, @@ -417,6 +431,46 @@ export interface ProfileInfo { skill_count: number; } +export interface ModelsAnalyticsModelEntry { + model: string; + provider: string; + input_tokens: number; + output_tokens: number; + cache_read_tokens: number; + reasoning_tokens: number; + estimated_cost: number; + actual_cost: number; + sessions: number; + api_calls: number; + tool_calls: number; + last_used_at: number; + avg_tokens_per_session: number; + capabilities: { + supports_tools?: boolean; + supports_vision?: boolean; + supports_reasoning?: boolean; + context_window?: number; + max_output_tokens?: number; + model_family?: string; + }; +} + +export interface ModelsAnalyticsResponse { + models: ModelsAnalyticsModelEntry[]; + totals: { + distinct_models: number; + total_input: number; + total_output: number; + total_cache_read: number; + total_reasoning: number; + total_estimated_cost: number; + total_actual_cost: 
number; + total_sessions: number; + total_api_calls: number; + }; + period_days: number; +} + export interface CronJob { id: string; name?: string; @@ -478,6 +532,54 @@ export interface ModelInfoResponse { }; } +// ── Model options / assignment types ────────────────────────────────── + +export interface ModelOptionProvider { + name: string; + slug: string; + models?: string[]; + total_models?: number; + is_current?: boolean; + is_user_defined?: boolean; + source?: string; + warning?: string; +} + +export interface ModelOptionsResponse { + model?: string; + provider?: string; + providers?: ModelOptionProvider[]; +} + +export interface AuxiliaryTaskAssignment { + task: string; + provider: string; + model: string; + base_url: string; +} + +export interface AuxiliaryModelsResponse { + tasks: AuxiliaryTaskAssignment[]; + main: { provider: string; model: string }; +} + +export interface ModelAssignmentRequest { + scope: "main" | "auxiliary"; + provider: string; + model: string; + /** For auxiliary: task slot name, "" for all, "__reset__" to reset all. 
*/ + task?: string; +} + +export interface ModelAssignmentResponse { + ok: boolean; + scope?: string; + provider?: string; + model?: string; + tasks?: string[]; + reset?: boolean; +} + // ── OAuth provider types ──────────────────────────────────────────────── export interface OAuthProviderStatus { diff --git a/web/src/pages/ModelsPage.tsx b/web/src/pages/ModelsPage.tsx new file mode 100644 index 0000000000..72b082f629 --- /dev/null +++ b/web/src/pages/ModelsPage.tsx @@ -0,0 +1,817 @@ +import { useCallback, useEffect, useLayoutEffect, useState } from "react"; +import { + Brain, + ChevronDown, + Cpu, + DollarSign, + Eye, + RefreshCw, + Settings2, + Star, + Wrench, + Zap, +} from "lucide-react"; +import { api } from "@/lib/api"; +import type { + AuxiliaryModelsResponse, + AuxiliaryTaskAssignment, + ModelsAnalyticsModelEntry, + ModelsAnalyticsResponse, +} from "@/lib/api"; +import { timeAgo } from "@/lib/utils"; +import { formatTokenCount } from "@/lib/format"; +import { Button } from "@nous-research/ui/ui/components/button"; +import { Spinner } from "@nous-research/ui/ui/components/spinner"; +import { Stats } from "@nous-research/ui/ui/components/stats"; +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Badge } from "@nous-research/ui/ui/components/badge"; +import { usePageHeader } from "@/contexts/usePageHeader"; +import { useI18n } from "@/i18n"; +import { PluginSlot } from "@/plugins"; +import { ModelPickerDialog } from "@/components/ModelPickerDialog"; + +const PERIODS = [ + { label: "7d", days: 7 }, + { label: "30d", days: 30 }, + { label: "90d", days: 90 }, +] as const; + +// Must match _AUX_TASK_SLOTS in hermes_cli/web_server.py. 
+const AUX_TASKS: readonly { key: string; label: string; hint: string }[] = [ + { key: "vision", label: "Vision", hint: "Image analysis" }, + { key: "web_extract", label: "Web Extract", hint: "Page summarization" }, + { key: "compression", label: "Compression", hint: "Context compaction" }, + { key: "session_search", label: "Session Search", hint: "Recall queries" }, + { key: "skills_hub", label: "Skills Hub", hint: "Skill search" }, + { key: "approval", label: "Approval", hint: "Smart auto-approve" }, + { key: "mcp", label: "MCP", hint: "MCP tool routing" }, + { key: "title_generation", label: "Title Gen", hint: "Session titles" }, + { key: "curator", label: "Curator", hint: "Skill-usage review" }, +] as const; + +function formatTokens(n: number): string { + if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(1)}M`; + if (n >= 1_000) return `${(n / 1_000).toFixed(1)}K`; + return String(n); +} + +function formatCost(n: number): string { + if (n >= 1) return `$${n.toFixed(2)}`; + if (n >= 0.01) return `$${n.toFixed(3)}`; + if (n > 0) return `$${n.toFixed(4)}`; + return "$0"; +} + +/** Short model name: strip vendor prefix like "openrouter/" or "anthropic/". */ +function shortModelName(model: string): string { + const slashIdx = model.indexOf("/"); + if (slashIdx > 0) return model.slice(slashIdx + 1); + return model; +} + +/** Extract vendor prefix from a model string like "anthropic/claude-opus-4.7" → "anthropic". 
*/ +function modelVendor(model: string, fallback?: string): string { + const slashIdx = model.indexOf("/"); + if (slashIdx > 0) return model.slice(0, slashIdx); + return fallback || ""; +} + +function TokenBar({ + input, + output, + cacheRead, + reasoning, +}: { + input: number; + output: number; + cacheRead: number; + reasoning: number; +}) { + const total = input + output + cacheRead + reasoning; + if (total === 0) return null; + + const segments = [ + { value: cacheRead, color: "bg-blue-400/60", label: "Cache Read" }, + { value: reasoning, color: "bg-purple-400/60", label: "Reasoning" }, + { value: input, color: "bg-[#ffe6cb]/70", label: "Input" }, + { value: output, color: "bg-emerald-500/70", label: "Output" }, + ].filter((s) => s.value > 0); + + return ( +
+
+ {segments.map((s, i) => ( +
+ ))} +
+
+ {segments.map((s, i) => ( + + + {s.label} {formatTokens(s.value)} + + ))} +
+
+ ); +} + +function CapabilityBadges({ + capabilities, +}: { + capabilities: ModelsAnalyticsModelEntry["capabilities"]; +}) { + const hasAny = + capabilities.supports_tools || + capabilities.supports_vision || + capabilities.supports_reasoning || + capabilities.model_family; + if (!hasAny) return null; + + return ( +
+ {capabilities.supports_tools && ( + + Tools + + )} + {capabilities.supports_vision && ( + + Vision + + )} + {capabilities.supports_reasoning && ( + + Reasoning + + )} + {capabilities.model_family && ( + + {capabilities.model_family} + + )} +
+ ); +} + +/* ──────────────────────────────────────────────────────────────────── */ +/* Per-card "Use as" menu */ +/* ──────────────────────────────────────────────────────────────────── */ + +function UseAsMenu({ + provider, + model, + isMain, + mainAuxTask, + onAssigned, +}: { + provider: string; + model: string; + /** True when this card's model+provider match config.yaml's main slot. */ + isMain: boolean; + /** If this model is assigned to a specific aux task, that task's key. */ + mainAuxTask: string | null; + onAssigned(): void; +}) { + const [open, setOpen] = useState(false); + const [busy, setBusy] = useState(false); + const [error, setError] = useState(null); + + const assign = async ( + scope: "main" | "auxiliary", + task: string, + ) => { + if (!provider || !model) { + setError("Missing provider/model"); + return; + } + setBusy(true); + setError(null); + try { + await api.setModelAssignment({ scope, provider, model, task }); + onAssigned(); + setOpen(false); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } finally { + setBusy(false); + } + }; + + // Close on outside click. + useEffect(() => { + if (!open) return; + const onDown = (e: MouseEvent) => { + const target = e.target as HTMLElement | null; + if (target && !target.closest?.("[data-use-as-menu]")) setOpen(false); + }; + window.addEventListener("mousedown", onDown); + return () => window.removeEventListener("mousedown", onDown); + }, [open]); + + return ( +
+ + {open && ( +
+ + +
+ Auxiliary task +
+ + + + {AUX_TASKS.map((t) => ( + + ))} + + {error && ( +
+ {error} +
+ )} +
+ )} +
+ ); +} + +/* ──────────────────────────────────────────────────────────────────── */ +/* ModelCard */ +/* ──────────────────────────────────────────────────────────────────── */ + +function ModelCard({ + entry, + rank, + main, + aux, + onAssigned, +}: { + entry: ModelsAnalyticsModelEntry; + rank: number; + main: { provider: string; model: string } | null; + aux: AuxiliaryTaskAssignment[]; + onAssigned(): void; +}) { + const { t } = useI18n(); + const provider = entry.provider || modelVendor(entry.model); + const totalTokens = entry.input_tokens + entry.output_tokens; + const caps = entry.capabilities; + + const isMain = + !!main && + main.provider === provider && + main.model === entry.model; + + // First aux task currently using this model (if any). + const mainAuxTask = + aux.find( + (a) => a.provider === provider && a.model === entry.model, + )?.task ?? null; + + return ( + + +
+
+
+ + #{rank} + + + {shortModelName(entry.model)} + + {isMain && ( + + main + + )} + {mainAuxTask && ( + + aux · {mainAuxTask} + + )} +
+
+ {provider && ( + + {provider} + + )} + {/* Compare against null explicitly: with a bare numeric left operand, a 0 value short-circuits to 0 and React renders a literal "0". */} + {caps.context_window != null && caps.context_window > 0 && ( + + {formatTokenCount(caps.context_window)} ctx + + )} + {caps.max_output_tokens != null && caps.max_output_tokens > 0 && ( + + {formatTokenCount(caps.max_output_tokens)} out + + )} +
+
+
+
+
+ {formatTokens(totalTokens)} +
+
+ {t.models.tokens} +
+
+ +
+
+
+ + + +
+
+
{entry.sessions}
+
+ {t.models.sessions} +
+
+
+
+ {formatTokens(entry.avg_tokens_per_session)} +
+
+ {t.models.avgPerSession} +
+
+
+
+ {entry.api_calls > 0 ? formatTokens(entry.api_calls) : "—"} +
+
+ {t.models.apiCalls} +
+
+
+ +
+
+ {entry.estimated_cost > 0 && ( + + + {formatCost(entry.estimated_cost)} + + )} + {entry.tool_calls > 0 && ( + + + {entry.tool_calls} {t.models.toolCalls} + + )} +
+ {entry.last_used_at > 0 && ( + {timeAgo(entry.last_used_at)} + )} +
+ + +
+
+ ); +} + +/* ──────────────────────────────────────────────────────────────────── */ +/* Model Settings panel (top of page) */ +/* ──────────────────────────────────────────────────────────────────── */ + +type PickerTarget = + | { kind: "main" } + | { kind: "aux"; task: string }; + +function ModelSettingsPanel({ + aux, + refreshKey, + onSaved, +}: { + aux: AuxiliaryModelsResponse | null; + refreshKey: number; + onSaved(): void; +}) { + const [expanded, setExpanded] = useState(false); + const [picker, setPicker] = useState(null); + const [resetBusy, setResetBusy] = useState(false); + + const mainProv = aux?.main.provider ?? ""; + const mainModel = aux?.main.model ?? ""; + + const applyAssignment = async ({ + scope, + task, + provider, + model, + }: { + scope: "main" | "auxiliary"; + task: string; + provider: string; + model: string; + }) => { + await api.setModelAssignment({ scope, task, provider, model }); + onSaved(); + }; + + const resetAllAux = async () => { + if (!window.confirm("Reset every auxiliary task to 'auto'? This overrides any per-task overrides you've set.")) { + return; + } + setResetBusy(true); + try { + await api.setModelAssignment({ + scope: "auxiliary", + task: "__reset__", + provider: "", + model: "", + }); + onSaved(); + } finally { + setResetBusy(false); + } + }; + + return ( + + +
+
+ + Model Settings + + applies to new sessions + +
+ +
+
+ + + {/* Main row */} +
+
+
+ + + Main model + +
+
+ {mainProv || "(unset)"} + {mainProv && mainModel && " · "} + {mainModel || "(unset)"} +
+
+ +
+ + {/* Auxiliary rows */} + {expanded && ( +
+
+
+ Auxiliary tasks +
+ +
+ +

+ Auxiliary tasks handle side-jobs like vision, session search, and + compression. auto means + "use the main model". Override per-task when you want a + cheap/fast model for a specific job. +

+ + {AUX_TASKS.map((t) => { + const cur = aux?.tasks.find((a) => a.task === t.key); + const isAuto = + !cur || cur.provider === "auto" || !cur.provider; + return ( +
+
+
+ {t.label} + + {t.hint} + +
+
+ {isAuto + ? "auto (use main model)" + : `${cur?.provider} · ${cur?.model || "(provider default)"}`} +
+
+ +
+ ); + })} +
+ )} + + {picker && ( + t.key === picker.task)?.label ?? + picker.task + }` + } + onApply={async ({ provider, model }) => { + await applyAssignment({ + scope: picker.kind === "main" ? "main" : "auxiliary", + task: picker.kind === "main" ? "" : picker.task, + provider, + model, + }); + }} + onClose={() => setPicker(null)} + /> + )} +
+
+ ); +} + +/* ──────────────────────────────────────────────────────────────────── */ +/* Page */ +/* ──────────────────────────────────────────────────────────────────── */ + +export default function ModelsPage() { + const [days, setDays] = useState(30); + const [data, setData] = useState(null); + const [aux, setAux] = useState(null); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [saveKey, setSaveKey] = useState(0); + const { t } = useI18n(); + const { setAfterTitle, setEnd } = usePageHeader(); + + const load = useCallback(() => { + setLoading(true); + setError(null); + Promise.all([ + api.getModelsAnalytics(days), + api.getAuxiliaryModels().catch(() => null), + ]) + .then(([models, auxData]) => { + setData(models); + setAux(auxData); + }) + .catch((err) => setError(String(err))) + .finally(() => setLoading(false)); + }, [days]); + + const onAssigned = useCallback(() => { + // Reload aux state after any assignment change. + api + .getAuxiliaryModels() + .then(setAux) + .catch(() => {}); + setSaveKey((k) => k + 1); + }, []); + + useLayoutEffect(() => { + const periodLabel = + PERIODS.find((p) => p.days === days)?.label ?? `${days}d`; + setAfterTitle( + + {loading && } + + {periodLabel} + + , + ); + setEnd( +
+
+ {PERIODS.map((p) => ( + + ))} +
+ +
, + ); + return () => { + setAfterTitle(null); + setEnd(null); + }; + }, [days, loading, load, setAfterTitle, setEnd, t.common.refresh]); + + useEffect(() => { + load(); + }, [load]); + + return ( +
+ + + + + {loading && !data && ( +
+ +
+ )} + + {error && ( + + +

{error}

+
+
+ )} + + {data && ( + <> + + + + + + + {data.models.length > 0 ? ( +
+ {data.models.map((m, i) => ( + + ))} +
+ ) : ( + + +
+ +

{t.models.noModelsData}

+

+ {t.models.startSession} +

+
+
+
+ )} + + )} + + +
+ ); +} diff --git a/web/src/pages/ProfilesPage.tsx b/web/src/pages/ProfilesPage.tsx index 7adf9ae510..e8dbfe0737 100644 --- a/web/src/pages/ProfilesPage.tsx +++ b/web/src/pages/ProfilesPage.tsx @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { ChevronDown, Pencil, Plus, Terminal, Trash2, Users } from "lucide-react"; import { H2 } from "@/components/NouiTypography"; import { api } from "@/lib/api"; @@ -37,6 +37,9 @@ export default function ProfilesPage() { const [editingSoulFor, setEditingSoulFor] = useState(null); const [soulText, setSoulText] = useState(""); const [soulSaving, setSoulSaving] = useState(false); + // Tracks the latest SOUL request so out-of-order responses don't overwrite + // newer state when the user switches profiles or closes the editor. + const activeSoulRequest = useRef(null); const load = useCallback(() => { api @@ -99,16 +102,22 @@ export default function ProfilesPage() { const openSoulEditor = useCallback( async (name: string) => { if (editingSoulFor === name) { + activeSoulRequest.current = null; setEditingSoulFor(null); return; } setEditingSoulFor(name); setSoulText(""); + activeSoulRequest.current = name; try { const soul = await api.getProfileSoul(name); - setSoulText(soul.content); + if (activeSoulRequest.current === name) { + setSoulText(soul.content); + } } catch (e) { - showToast(`${t.status.error}: ${e}`, "error"); + if (activeSoulRequest.current === name) { + showToast(`${t.status.error}: ${e}`, "error"); + } } }, [editingSoulFor, showToast, t.status.error], @@ -127,7 +136,14 @@ export default function ProfilesPage() { }; const handleCopyTerminalCommand = async (name: string) => { - const cmd = name === "default" ? 
"hermes setup" : `${name} setup`; + let cmd: string; + try { + const res = await api.getProfileSetupCommand(name); + cmd = res.command; + } catch (e) { + showToast(`${t.status.error}: ${e}`, "error"); + return; + } try { await navigator.clipboard.writeText(cmd); showToast(`${t.profiles.commandCopied}: ${cmd}`, "success"); @@ -395,10 +411,14 @@ export default function ProfilesPage() { {isEditingSoul && (
-