diff --git a/Dockerfile b/Dockerfile index 0eddaba0bc..5c57897f57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,14 +6,15 @@ ENV PYTHONUNBUFFERED=1 # Install system dependencies in one layer, clear APT cache RUN apt-get update && \ apt-get install -y --no-install-recommends \ - build-essential nodejs npm python3 python3-pip ripgrep ffmpeg gcc python3-dev libffi-dev && \ + build-essential nodejs npm python3 python3-pip ripgrep ffmpeg gcc python3-dev libffi-dev procps && \ rm -rf /var/lib/apt/lists/* COPY . /opt/hermes WORKDIR /opt/hermes # Install Python and Node dependencies in one layer, no cache -RUN pip install --no-cache-dir -e ".[all]" --break-system-packages && \ +RUN pip install --no-cache-dir uv --break-system-packages && \ + uv pip install --system --break-system-packages --no-cache -e ".[all]" && \ npm install --prefer-offline --no-audit && \ npx playwright install --with-deps chromium --only-shell && \ cd /opt/hermes/scripts/whatsapp-bridge && \ diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 11064a1e4e..29f9a10e8b 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -36,6 +36,7 @@ from acp.schema import ( SessionCapabilities, SessionForkCapabilities, SessionListCapabilities, + SessionResumeCapabilities, SessionInfo, TextContentBlock, UnstructuredCommandInput, @@ -245,9 +246,11 @@ class HermesACPAgent(acp.Agent): protocol_version=acp.PROTOCOL_VERSION, agent_info=Implementation(name="hermes-agent", version=HERMES_VERSION), agent_capabilities=AgentCapabilities( + load_session=True, session_capabilities=SessionCapabilities( fork=SessionForkCapabilities(), list=SessionListCapabilities(), + resume=SessionResumeCapabilities(), ), ), auth_methods=auth_methods, @@ -451,14 +454,13 @@ class HermesACPAgent(acp.Agent): await conn.session_update(session_id, update) usage = None - usage_data = result.get("usage") - if usage_data and isinstance(usage_data, dict): + if any(result.get(key) is not None for key in ("prompt_tokens", "completion_tokens", "total_tokens")): usage = Usage( - input_tokens=usage_data.get("prompt_tokens", 0), - output_tokens=usage_data.get("completion_tokens", 0), - total_tokens=usage_data.get("total_tokens", 0), - thought_tokens=usage_data.get("reasoning_tokens"), - cached_read_tokens=usage_data.get("cached_tokens"), + input_tokens=result.get("prompt_tokens", 0), + output_tokens=result.get("completion_tokens", 0), + total_tokens=result.get("total_tokens", 0), + thought_tokens=result.get("reasoning_tokens"), + cached_read_tokens=result.get("cache_read_tokens"), ) stop_reason = "cancelled" if state.cancel_event and state.cancel_event.is_set() else "end_turn" diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index d5c0c06fbb..830c0f4de7 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -60,6 +60,8 @@ _ANTHROPIC_OUTPUT_LIMITS = { "claude-3-opus": 4_096, "claude-3-sonnet": 4_096, "claude-3-haiku": 4_096, + # Third-party Anthropic-compatible providers + "minimax": 131_072, } # For any model not in the table, assume the highest current limit. @@ -74,8 +76,11 @@ def _get_anthropic_max_output(model: str) -> int: model IDs (claude-sonnet-4-5-20250929) and variant suffixes (:1m, :fast) resolve correctly. Longest-prefix match wins to avoid e.g. "claude-3-5" matching before "claude-3-5-sonnet". + + Normalizes dots to hyphens so that model names like + ``anthropic/claude-opus-4.6`` match the ``claude-opus-4-6`` table key. """ - m = model.lower() + m = model.lower().replace(".", "-") best_key = "" best_val = _ANTHROPIC_DEFAULT_OUTPUT_LIMIT for key, val in _ANTHROPIC_OUTPUT_LIMITS.items(): @@ -95,6 +100,15 @@ _COMMON_BETAS = [ "interleaved-thinking-2025-05-14", "fine-grained-tool-streaming-2025-05-14", ] +# MiniMax's Anthropic-compatible endpoints fail tool-use requests when +# the fine-grained tool streaming beta is present. Omit it so tool calls +# fall back to the provider's default response path. +_TOOL_STREAMING_BETA = "fine-grained-tool-streaming-2025-05-14" + +# Fast mode beta — enables the ``speed: "fast"`` request parameter for +# significantly higher output token throughput on Opus 4.6 (~2.5x). +# See https://platform.claude.com/docs/en/build-with-claude/fast-mode +_FAST_MODE_BETA = "fast-mode-2026-02-01" # Additional beta headers required for OAuth/subscription auth. # Matches what Claude Code (and pi-ai / OpenCode) send. @@ -149,18 +163,27 @@ def _get_claude_code_version() -> str: def _is_oauth_token(key: str) -> bool: - """Check if the key is an OAuth/setup token (not a regular Console API key). + """Check if the key is an Anthropic OAuth/setup token. - Regular API keys start with 'sk-ant-api'. Everything else (setup-tokens - starting with 'sk-ant-oat', managed keys, JWTs, etc.) needs Bearer auth. + Positively identifies Anthropic OAuth tokens by their key format: + - ``sk-ant-`` prefix (but NOT ``sk-ant-api``) → setup tokens, managed keys + - ``eyJ`` prefix → JWTs from the Anthropic OAuth flow + + Non-Anthropic keys (MiniMax, Alibaba, etc.) don't match either pattern + and correctly return False. """ if not key: return False - # Regular Console API keys use x-api-key header + # Regular Anthropic Console API keys — x-api-key auth, never OAuth if key.startswith("sk-ant-api"): return False - # Everything else (setup-tokens, managed keys, JWTs) uses Bearer auth - return True + # Anthropic-issued tokens (setup-tokens sk-ant-oat-*, managed keys) + if key.startswith("sk-ant-"): + return True + # JWTs from Anthropic OAuth flow + if key.startswith("eyJ"): + return True + return False def _normalize_base_url_text(base_url) -> str: @@ -204,6 +227,19 @@ def _requires_bearer_auth(base_url: str | None) -> bool: return normalized.startswith(("https://api.minimax.io/anthropic", "https://api.minimaxi.com/anthropic")) +def _common_betas_for_base_url(base_url: str | None) -> list[str]: + """Return the beta headers that are safe for the configured endpoint. + + MiniMax's Anthropic-compatible endpoints (Bearer-auth) reject requests + that include Anthropic's ``fine-grained-tool-streaming`` beta — every + tool-use message triggers a connection error. Strip that beta for + Bearer-auth endpoints while keeping all other betas intact. + """ + if _requires_bearer_auth(base_url): + return [b for b in _COMMON_BETAS if b != _TOOL_STREAMING_BETA] + return _COMMON_BETAS + + def build_anthropic_client(api_key: str, base_url: str = None): """Create an Anthropic client, auto-detecting setup-tokens vs API keys. @@ -222,6 +258,7 @@ def build_anthropic_client(api_key: str, base_url: str = None): } if normalized_base_url: kwargs["base_url"] = normalized_base_url + common_betas = _common_betas_for_base_url(normalized_base_url) if _requires_bearer_auth(normalized_base_url): # Some Anthropic-compatible providers (e.g. MiniMax) expect the API key in @@ -231,21 +268,21 @@ def build_anthropic_client(api_key: str, base_url: str = None): # not use Anthropic's sk-ant-api prefix and would otherwise be misread as # Anthropic OAuth/setup tokens. kwargs["auth_token"] = api_key - if _COMMON_BETAS: - kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)} + if common_betas: + kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)} elif _is_third_party_anthropic_endpoint(base_url): # Third-party proxies (Azure AI Foundry, AWS Bedrock, etc.) use their # own API keys with x-api-key auth. Skip OAuth detection — their keys # don't follow Anthropic's sk-ant-* prefix convention and would be # misclassified as OAuth tokens. kwargs["api_key"] = api_key - if _COMMON_BETAS: - kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)} + if common_betas: + kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)} elif _is_oauth_token(api_key): # OAuth access token / setup-token → Bearer auth + Claude Code identity. # Anthropic routes OAuth requests based on user-agent and headers; # without Claude Code's fingerprint, requests get intermittent 500s. - all_betas = _COMMON_BETAS + _OAUTH_ONLY_BETAS + all_betas = common_betas + _OAUTH_ONLY_BETAS kwargs["auth_token"] = api_key kwargs["default_headers"] = { "anthropic-beta": ",".join(all_betas), @@ -255,8 +292,8 @@ def build_anthropic_client(api_key: str, base_url: str = None): else: # Regular API key → x-api-key header + common betas kwargs["api_key"] = api_key - if _COMMON_BETAS: - kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)} + if common_betas: + kwargs["default_headers"] = {"anthropic-beta": ",".join(common_betas)} return _anthropic_sdk.Anthropic(**kwargs) @@ -485,35 +522,6 @@ def _prefer_refreshable_claude_code_token(env_token: str, creds: Optional[Dict[s return None -def get_anthropic_token_source(token: Optional[str] = None) -> str: - """Best-effort source classification for an Anthropic credential token.""" - token = (token or "").strip() - if not token: - return "none" - - env_token = os.getenv("ANTHROPIC_TOKEN", "").strip() - if env_token and env_token == token: - return "anthropic_token_env" - - cc_env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN", "").strip() - if cc_env_token and cc_env_token == token: - return "claude_code_oauth_token_env" - - creds = read_claude_code_credentials() - if creds and creds.get("accessToken") == token: - return str(creds.get("source") or "claude_code_credentials") - - managed_key = read_claude_managed_key() - if managed_key and managed_key == token: - return "claude_json_primary_api_key" - - api_key = os.getenv("ANTHROPIC_API_KEY", "").strip() - if api_key and api_key == token: - return "anthropic_api_key_env" - - return "unknown" - - def resolve_anthropic_token() -> Optional[str]: """Resolve an Anthropic token from all available sources. @@ -720,21 +728,6 @@ def run_hermes_oauth_login_pure() -> Optional[Dict[str, Any]]: } -def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None: - """Save OAuth credentials to ~/.hermes/.anthropic_oauth.json.""" - data = { - "accessToken": access_token, - "refreshToken": refresh_token, - "expiresAt": expires_at_ms, - } - try: - _HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True) - _HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8") - _HERMES_OAUTH_FILE.chmod(0o600) - except (OSError, IOError) as e: - logger.debug("Failed to save Hermes OAuth credentials: %s", e) - - def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]: """Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json.""" if _HERMES_OAUTH_FILE.exists(): @@ -783,39 +776,6 @@ def _sanitize_tool_id(tool_id: str) -> str: return sanitized or "tool_0" -def _convert_openai_image_part_to_anthropic(part: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """Convert an OpenAI-style image block to Anthropic's image source format.""" - image_data = part.get("image_url", {}) - url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data) - if not isinstance(url, str) or not url.strip(): - return None - url = url.strip() - - if url.startswith("data:"): - header, sep, data = url.partition(",") - if sep and ";base64" in header: - media_type = header[5:].split(";", 1)[0] or "image/png" - return { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": data, - }, - } - - if url.startswith(("http://", "https://")): - return { - "type": "image", - "source": { - "type": "url", - "url": url, - }, - } - - return None - - def convert_tools_to_anthropic(tools: List[Dict]) -> List[Dict]: """Convert OpenAI tool definitions to Anthropic format.""" if not tools: @@ -1235,6 +1195,7 @@ def build_anthropic_kwargs( preserve_dots: bool = False, context_length: Optional[int] = None, base_url: str | None = None, + fast_mode: bool = False, ) -> Dict[str, Any]: """Build kwargs for anthropic.messages.create(). @@ -1268,6 +1229,10 @@ def build_anthropic_kwargs( When *base_url* points to a third-party Anthropic-compatible endpoint, thinking block signatures are stripped (they are Anthropic-proprietary). + + When *fast_mode* is True, adds ``speed: "fast"`` and the fast-mode beta + header for ~2.5x faster output throughput on Opus 4.6. Currently only + supported on native Anthropic endpoints (not third-party compatible ones). """ system, anthropic_messages = convert_messages_to_anthropic(messages, base_url=base_url) anthropic_tools = convert_tools_to_anthropic(tools) if tools else [] @@ -1350,9 +1315,10 @@ def build_anthropic_kwargs( # Map reasoning_config to Anthropic's thinking parameter. # Claude 4.6 models use adaptive thinking + output_config.effort. # Older models use manual thinking with budget_tokens. - # Haiku and MiniMax models do NOT support extended thinking — skip entirely. + # MiniMax Anthropic-compat endpoints support thinking (manual mode only, + # not adaptive). Haiku does NOT support extended thinking — skip entirely. if reasoning_config and isinstance(reasoning_config, dict): - if reasoning_config.get("enabled") is not False and "haiku" not in model.lower() and "minimax" not in model.lower(): + if reasoning_config.get("enabled") is not False and "haiku" not in model.lower(): effort = str(reasoning_config.get("effort", "medium")).lower() budget = THINKING_BUDGET.get(effort, 8000) if _supports_adaptive_thinking(model): @@ -1366,6 +1332,20 @@ def build_anthropic_kwargs( kwargs["temperature"] = 1 kwargs["max_tokens"] = max(effective_max_tokens, budget + 4096) + # ── Fast mode (Opus 4.6 only) ──────────────────────────────────── + # Adds speed:"fast" + the fast-mode beta header for ~2.5x output speed. + # Only for native Anthropic endpoints — third-party providers would + # reject the unknown beta header and speed parameter. + if fast_mode and not _is_third_party_anthropic_endpoint(base_url): + kwargs["speed"] = "fast" + # Build extra_headers with ALL applicable betas (the per-request + # extra_headers override the client-level anthropic-beta header). + betas = list(_common_betas_for_base_url(base_url)) + if is_oauth: + betas.extend(_OAUTH_ONLY_BETAS) + betas.append(_FAST_MODE_BETA) + kwargs["extra_headers"] = {"anthropic-beta": ",".join(betas)} + return kwargs @@ -1427,4 +1407,4 @@ def normalize_anthropic_response( reasoning_details=reasoning_details or None, ), finish_reason, - ) \ No newline at end of file + ) diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index a757f42699..e48f9c2c3e 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -59,6 +59,9 @@ from hermes_constants import OPENROUTER_BASE_URL logger = logging.getLogger(__name__) +# Module-level flag: only warn once per process about stale OPENAI_BASE_URL. +_stale_base_url_warned = False + _PROVIDER_ALIASES = { "google": "gemini", "google-gemini": "gemini", @@ -687,6 +690,15 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: if pconfig.auth_type != "api_key": continue if provider_id == "anthropic": + # Only try anthropic when the user has explicitly configured it. + # Without this gate, Claude Code credentials get silently used + # as auxiliary fallback when the user's primary provider fails. + try: + from hermes_cli.auth import is_provider_explicitly_configured + if not is_provider_explicitly_configured("anthropic"): + continue + except ImportError: + pass return _try_anthropic() pool_present, entry = _select_pool_entry(provider_id) @@ -698,11 +710,13 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: base_url = _to_openai_base_url( _pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url ) - model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default") + model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id) + if model is None: + continue # skip provider if we don't know a valid aux model logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model) extra = {} if "api.kimi.com" in base_url.lower(): - extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"} + extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"} elif "api.githubcopilot.com" in base_url.lower(): from hermes_cli.models import copilot_default_headers @@ -717,11 +731,13 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: base_url = _to_openai_base_url( str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url ) - model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default") + model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id) + if model is None: + continue # skip provider if we don't know a valid aux model logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model) extra = {} if "api.kimi.com" in base_url.lower(): - extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"} + extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"} elif "api.githubcopilot.com" in base_url.lower(): from hermes_cli.models import copilot_default_headers @@ -848,7 +864,7 @@ def _read_main_provider() -> str: return "" -def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]: +def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[str]]: """Resolve the active custom/main endpoint the same way the main CLI does. This covers both env-driven OPENAI_BASE_URL setups and config-saved custom @@ -861,18 +877,29 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]: runtime = resolve_runtime_provider(requested="custom") except Exception as exc: logger.debug("Auxiliary client: custom runtime resolution failed: %s", exc) - return None, None + runtime = None + + if not isinstance(runtime, dict): + openai_base = os.getenv("OPENAI_BASE_URL", "").strip().rstrip("/") + openai_key = os.getenv("OPENAI_API_KEY", "").strip() + if not openai_base: + return None, None, None + runtime = { + "base_url": openai_base, + "api_key": openai_key, + } custom_base = runtime.get("base_url") custom_key = runtime.get("api_key") + custom_mode = runtime.get("api_mode") if not isinstance(custom_base, str) or not custom_base.strip(): - return None, None + return None, None, None custom_base = custom_base.strip().rstrip("/") if "openrouter.ai" in custom_base.lower(): # requested='custom' falls back to OpenRouter when no custom endpoint is # configured. Treat that as "no custom endpoint" for auxiliary routing. - return None, None + return None, None, None # Local servers (Ollama, llama.cpp, vLLM, LM Studio) don't require auth. # Use a placeholder key — the OpenAI SDK requires a non-empty string but @@ -881,20 +908,33 @@ def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str]]: if not isinstance(custom_key, str) or not custom_key.strip(): custom_key = "no-key-required" - return custom_base, custom_key.strip() + if not isinstance(custom_mode, str) or not custom_mode.strip(): + custom_mode = None + + return custom_base, custom_key.strip(), custom_mode def _current_custom_base_url() -> str: - custom_base, _ = _resolve_custom_runtime() + custom_base, _, _ = _resolve_custom_runtime() return custom_base or "" def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]: - custom_base, custom_key = _resolve_custom_runtime() + runtime = _resolve_custom_runtime() + if len(runtime) == 2: + custom_base, custom_key = runtime + custom_mode = None + else: + custom_base, custom_key, custom_mode = runtime if not custom_base or not custom_key: return None, None + if custom_base.lower().startswith(_CODEX_AUX_BASE_URL.lower()): + return None, None model = _read_main_model() or "gpt-4o-mini" - logger.debug("Auxiliary client: custom endpoint (%s)", model) + logger.debug("Auxiliary client: custom endpoint (%s, api_mode=%s)", model, custom_mode or "chat_completions") + if custom_mode == "codex_responses": + real_client = OpenAI(api_key=custom_key, base_url=custom_base) + return CodexAuxiliaryClient(real_client, model), model return OpenAI(api_key=custom_key, base_url=custom_base), model @@ -967,40 +1007,6 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]: return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model -def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]: - """Resolve a specific forced provider. Returns (None, None) if creds missing.""" - if forced == "openrouter": - client, model = _try_openrouter() - if client is None: - logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set") - return client, model - - if forced == "nous": - client, model = _try_nous() - if client is None: - logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes auth)") - return client, model - - if forced == "codex": - client, model = _try_codex() - if client is None: - logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)") - return client, model - - if forced == "main": - # "main" = skip OpenRouter/Nous, use the main chat model's credentials. - for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider): - client, model = try_fn() - if client is not None: - return client, model - logger.warning("auxiliary.provider=main but no main endpoint credentials found") - return None, None - - # Unknown provider name — fall through to auto - logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced) - return None, None - - _AUTO_PROVIDER_LABELS = { "_try_openrouter": "openrouter", "_try_nous": "nous", @@ -1076,11 +1082,12 @@ def _is_connection_error(exc: Exception) -> bool: def _try_payment_fallback( failed_provider: str, task: str = None, + reason: str = "payment error", ) -> Tuple[Optional[Any], Optional[str], str]: - """Try alternative providers after a payment/credit error. + """Try alternative providers after a payment/credit or connection error. Iterates the standard auto-detection chain, skipping the provider that - returned a payment error. + failed. Returns: (client, model, provider_label) or (None, None, "") if no fallback. @@ -1106,15 +1113,15 @@ def _try_payment_fallback( client, model = try_fn() if client is not None: logger.info( - "Auxiliary %s: payment error on %s — falling back to %s (%s)", - task or "call", failed_provider, label, model or "default", + "Auxiliary %s: %s on %s — falling back to %s (%s)", + task or "call", reason, failed_provider, label, model or "default", ) return client, model, label tried.append(label) logger.warning( - "Auxiliary %s: payment error on %s and no fallback available (tried: %s)", - task or "call", failed_provider, ", ".join(tried), + "Auxiliary %s: %s on %s and no fallback available (tried: %s)", + task or "call", reason, failed_provider, ", ".join(tried), ) return None, None, "" @@ -1129,9 +1136,28 @@ def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: provider they already have credentials for — no OpenRouter key needed. 2. OpenRouter → Nous → custom → Codex → API-key providers (original chain). """ - global auxiliary_is_nous + global auxiliary_is_nous, _stale_base_url_warned auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins + # ── Warn once if OPENAI_BASE_URL is set but config.yaml uses a named + # provider (not 'custom'). This catches the common "env poisoning" + # scenario where a user switches providers via `hermes model` but the + # old OPENAI_BASE_URL lingers in ~/.hermes/.env. ── + if not _stale_base_url_warned: + _env_base = os.getenv("OPENAI_BASE_URL", "").strip() + _cfg_provider = _read_main_provider() + if (_env_base and _cfg_provider + and _cfg_provider != "custom" + and not _cfg_provider.startswith("custom:")): + logger.warning( + "OPENAI_BASE_URL is set (%s) but model.provider is '%s'. " + "Auxiliary clients may route to the wrong endpoint. " + "Run: hermes model to reconfigure, or remove " + "OPENAI_BASE_URL from ~/.hermes/.env", + _env_base, _cfg_provider, + ) + _stale_base_url_warned = True + # ── Step 1: non-aggregator main provider → use main model directly ── main_provider = _read_main_provider() main_model = _read_main_model() @@ -1195,10 +1221,22 @@ def _to_async_client(sync_client, model: str): async_kwargs["default_headers"] = copilot_default_headers() elif "api.kimi.com" in base_lower: - async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.3"} + async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"} return AsyncOpenAI(**async_kwargs), model +def _normalize_resolved_model(model_name: Optional[str], provider: str) -> Optional[str]: + """Normalize a resolved model for the provider that will receive it.""" + if not model_name: + return model_name + try: + from hermes_cli.model_normalize import normalize_model_for_provider + + return normalize_model_for_provider(model_name, provider) + except Exception: + return model_name + + def resolve_provider_client( provider: str, model: str = None, @@ -1206,6 +1244,7 @@ def resolve_provider_client( raw_codex: bool = False, explicit_base_url: str = None, explicit_api_key: str = None, + api_mode: str = None, ) -> Tuple[Optional[Any], Optional[str]]: """Central router: given a provider name and optional model, return a configured client with the correct auth, base URL, and API format. @@ -1229,6 +1268,10 @@ def resolve_provider_client( the main agent loop). explicit_base_url: Optional direct OpenAI-compatible endpoint. explicit_api_key: Optional API key paired with explicit_base_url. + api_mode: API mode override. One of "chat_completions", + "codex_responses", or None (auto-detect). When set to + "codex_responses", the client is wrapped in + CodexAuxiliaryClient to route through the Responses API. Returns: (client, resolved_model) or (None, None) if auth is unavailable. @@ -1236,6 +1279,40 @@ def resolve_provider_client( # Normalise aliases provider = _normalize_aux_provider(provider) + def _needs_codex_wrap(client_obj, base_url_str: str, model_str: str) -> bool: + """Decide if a plain OpenAI client should be wrapped for Responses API. + + Returns True when api_mode is explicitly "codex_responses", or when + auto-detection (api.openai.com + codex-family model) suggests it. + Already-wrapped clients (CodexAuxiliaryClient) are skipped. + """ + if isinstance(client_obj, CodexAuxiliaryClient): + return False + if raw_codex: + return False + if api_mode == "codex_responses": + return True + # Auto-detect: api.openai.com + codex model name pattern + if api_mode and api_mode != "codex_responses": + return False # explicit non-codex mode + normalized_base = (base_url_str or "").strip().lower() + if "api.openai.com" in normalized_base and "openrouter" not in normalized_base: + model_lower = (model_str or "").lower() + if "codex" in model_lower: + return True + return False + + def _wrap_if_needed(client_obj, final_model_str: str, base_url_str: str = ""): + """Wrap a plain OpenAI client in CodexAuxiliaryClient if Responses API is needed.""" + if _needs_codex_wrap(client_obj, base_url_str, final_model_str): + logger.debug( + "resolve_provider_client: wrapping client in CodexAuxiliaryClient " + "(api_mode=%s, model=%s, base_url=%s)", + api_mode or "auto-detected", final_model_str, + base_url_str[:60] if base_url_str else "") + return CodexAuxiliaryClient(client_obj, final_model_str) + return client_obj + # ── Auto: try all providers in priority order ──────────────────── if provider == "auto": client, resolved = _resolve_auto() @@ -1261,7 +1338,7 @@ def resolve_provider_client( logger.warning("resolve_provider_client: openrouter requested " "but OPENROUTER_API_KEY not set") return None, None - final_model = model or default + final_model = _normalize_resolved_model(model or default, provider) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) @@ -1272,7 +1349,7 @@ def resolve_provider_client( logger.warning("resolve_provider_client: nous requested " "but Nous Portal not configured (run: hermes auth)") return None, None - final_model = model or default + final_model = _normalize_resolved_model(model or default, provider) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) @@ -1286,7 +1363,7 @@ def resolve_provider_client( logger.warning("resolve_provider_client: openai-codex requested " "but no Codex OAuth token found (run: hermes model)") return None, None - final_model = model or _CODEX_AUX_MODEL + final_model = _normalize_resolved_model(model or _CODEX_AUX_MODEL, provider) raw_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL) return (raw_client, final_model) # Standard path: wrap in CodexAuxiliaryClient adapter @@ -1295,7 +1372,7 @@ def resolve_provider_client( logger.warning("resolve_provider_client: openai-codex requested " "but no Codex OAuth token found (run: hermes model)") return None, None - final_model = model or default + final_model = _normalize_resolved_model(model or default, provider) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) @@ -1314,14 +1391,18 @@ def resolve_provider_client( "but base_url is empty" ) return None, None - final_model = model or _read_main_model() or "gpt-4o-mini" + final_model = _normalize_resolved_model( + model or _read_main_model() or "gpt-4o-mini", + provider, + ) extra = {} if "api.kimi.com" in custom_base.lower(): - extra["default_headers"] = {"User-Agent": "KimiCLI/1.3"} + extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"} elif "api.githubcopilot.com" in custom_base.lower(): from hermes_cli.models import copilot_default_headers extra["default_headers"] = copilot_default_headers() client = OpenAI(api_key=custom_key, base_url=custom_base, **extra) + client = _wrap_if_needed(client, final_model, custom_base) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) # Try custom first, then codex, then API-key providers @@ -1329,7 +1410,9 @@ def resolve_provider_client( _resolve_api_key_provider): client, default = try_fn() if client is not None: - final_model = model or default + final_model = _normalize_resolved_model(model or default, provider) + _cbase = str(getattr(client, "base_url", "") or "") + client = _wrap_if_needed(client, final_model, _cbase) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) logger.warning("resolve_provider_client: custom/main requested " @@ -1344,8 +1427,12 @@ def resolve_provider_client( custom_base = custom_entry.get("base_url", "").strip() custom_key = custom_entry.get("api_key", "").strip() or "no-key-required" if custom_base: - final_model = model or _read_main_model() or "gpt-4o-mini" + final_model = _normalize_resolved_model( + model or _read_main_model() or "gpt-4o-mini", + provider, + ) client = OpenAI(api_key=custom_key, base_url=custom_base) + client = _wrap_if_needed(client, final_model, custom_base) logger.debug( "resolve_provider_client: named custom provider %r (%s)", provider, final_model) @@ -1376,7 +1463,7 @@ def resolve_provider_client( if client is None: logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found") return None, None - final_model = model or default_model + final_model = _normalize_resolved_model(model or default_model, provider) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) creds = resolve_api_key_provider_credentials(provider) @@ -1395,12 +1482,12 @@ def resolve_provider_client( ) default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "") - final_model = model or default_model + final_model = _normalize_resolved_model(model or default_model, provider) # Provider-specific headers headers = {} if "api.kimi.com" in base_url.lower(): - headers["User-Agent"] = "KimiCLI/1.3" + headers["User-Agent"] = "KimiCLI/1.30.0" elif "api.githubcopilot.com" in base_url.lower(): from hermes_cli.models import copilot_default_headers @@ -1408,6 +1495,28 @@ def resolve_provider_client( client = OpenAI(api_key=api_key, base_url=base_url, **({"default_headers": headers} if headers else {})) + + # Copilot GPT-5+ models (except gpt-5-mini) require the Responses + # API — they are not accessible via /chat/completions. Wrap the + # plain client in CodexAuxiliaryClient so call_llm() transparently + # routes through responses.stream(). + if provider == "copilot" and final_model and not raw_codex: + try: + from hermes_cli.models import _should_use_copilot_responses_api + if _should_use_copilot_responses_api(final_model): + logger.debug( + "resolve_provider_client: copilot model %s needs " + "Responses API — wrapping with CodexAuxiliaryClient", + final_model) + client = CodexAuxiliaryClient(client, final_model) + except ImportError: + pass + + # Honor api_mode for any API-key provider (e.g. direct OpenAI with + # codex-family models). The copilot-specific wrapping above handles + # copilot; this covers the general case (#6800). + client = _wrap_if_needed(client, final_model, base_url) + logger.debug("resolve_provider_client: %s (%s)", provider, final_model) return (_to_async_client(client, final_model) if async_mode else (client, final_model)) @@ -1440,12 +1549,13 @@ def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optiona Callers may override the returned model with a per-task env var (e.g. CONTEXT_COMPRESSION_MODEL, AUXILIARY_WEB_EXTRACT_MODEL). """ - provider, model, base_url, api_key = _resolve_task_provider_model(task or None) + provider, model, base_url, api_key, api_mode = _resolve_task_provider_model(task or None) return resolve_provider_client( provider, model=model, explicit_base_url=base_url, explicit_api_key=api_key, + api_mode=api_mode, ) @@ -1456,13 +1566,14 @@ def get_async_text_auxiliary_client(task: str = ""): (AsyncCodexAuxiliaryClient, model) which wraps the Responses API. Returns (None, None) when no provider is available. """ - provider, model, base_url, api_key = _resolve_task_provider_model(task or None) + provider, model, base_url, api_key, api_mode = _resolve_task_provider_model(task or None) return resolve_provider_client( provider, model=model, async_mode=True, explicit_base_url=base_url, explicit_api_key=api_key, + api_mode=api_mode, ) @@ -1495,22 +1606,6 @@ def _strict_vision_backend_available(provider: str) -> bool: return _resolve_strict_vision_backend(provider)[0] is not None -def _preferred_main_vision_provider() -> Optional[str]: - """Return the selected main provider when it is also a supported vision backend.""" - try: - from hermes_cli.config import load_config - - config = load_config() - model_cfg = config.get("model", {}) - if isinstance(model_cfg, dict): - provider = _normalize_vision_provider(model_cfg.get("provider", "")) - if provider in _VISION_AUTO_PROVIDER_ORDER: - return provider - except Exception: - pass - return None - - def get_available_vision_backends() -> List[str]: """Return the currently available vision backends in auto-selection order. @@ -1551,7 +1646,7 @@ def resolve_vision_provider_client( backends, so users can intentionally force experimental providers. Auto mode stays conservative and only tries vision backends known to work today. """ - requested, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model( + requested, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( "vision", provider, model, base_url, api_key ) requested = _normalize_vision_provider(requested) @@ -1624,18 +1719,6 @@ def resolve_vision_provider_client( return requested, client, final_model -def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]: - """Return (client, default_model_slug) for vision/multimodal auxiliary tasks.""" - _, client, final_model = resolve_vision_provider_client(async_mode=False) - return client, final_model - - -def get_async_vision_auxiliary_client(): - """Return (async_client, model_slug) for async vision consumers.""" - _, client, final_model = resolve_vision_provider_client(async_mode=True) - return client, final_model - - def get_auxiliary_extra_body() -> dict: """Return extra_body kwargs for auxiliary API calls. @@ -1779,12 +1862,30 @@ def cleanup_stale_async_clients() -> None: del _client_cache[key] +def _is_openrouter_client(client: Any) -> bool: + for obj in (client, getattr(client, "_client", None), getattr(client, "client", None)): + if obj and "openrouter" in str(getattr(obj, "base_url", "") or "").lower(): + return True + return False + + +def _compat_model(client: Any, model: Optional[str], cached_default: Optional[str]) -> Optional[str]: + """Drop OpenRouter-format model slugs (with '/') for non-OpenRouter clients. + + Mirrors the guard in resolve_provider_client() which is skipped on cache hits. + """ + if model and "/" in model and not _is_openrouter_client(client): + return cached_default + return model or cached_default + + def _get_cached_client( provider: str, model: str = None, async_mode: bool = False, base_url: str = None, api_key: str = None, + api_mode: str = None, ) -> Tuple[Optional[Any], Optional[str]]: """Get or create a cached client for the given provider. @@ -1808,7 +1909,7 @@ def _get_cached_client( loop_id = id(current_loop) except RuntimeError: pass - cache_key = (provider, async_mode, base_url or "", api_key or "", loop_id) + cache_key = (provider, async_mode, base_url or "", api_key or "", api_mode or "", loop_id) with _client_cache_lock: if cache_key in _client_cache: cached_client, cached_default, cached_loop = _client_cache[cache_key] @@ -1820,9 +1921,11 @@ def _get_cached_client( _force_close_async_httpx(cached_client) del _client_cache[cache_key] else: - return cached_client, model or cached_default + effective = _compat_model(cached_client, model, cached_default) + return cached_client, effective else: - return cached_client, model or cached_default + effective = _compat_model(cached_client, model, cached_default) + return cached_client, effective # Build outside the lock client, default_model = resolve_provider_client( provider, @@ -1830,6 +1933,7 @@ def _get_cached_client( async_mode, explicit_base_url=base_url, explicit_api_key=api_key, + api_mode=api_mode, ) if client is not None: # For async clients, remember which loop they were created on so we @@ -1849,7 +1953,7 @@ def _resolve_task_provider_model( model: str = None, base_url: str = None, api_key: str = None, -) -> Tuple[str, Optional[str], Optional[str], Optional[str]]: +) -> Tuple[str, Optional[str], Optional[str], Optional[str], Optional[str]]: """Determine provider + model for a call. Priority: @@ -1858,15 +1962,17 @@ def _resolve_task_provider_model( 3. Config file (auxiliary.{task}.* or compression.*) 4. "auto" (full auto-detection chain) - Returns (provider, model, base_url, api_key) where model may be None - (use provider default). When base_url is set, provider is forced to - "custom" and the task uses that direct endpoint. + Returns (provider, model, base_url, api_key, api_mode) where model may + be None (use provider default). When base_url is set, provider is forced + to "custom" and the task uses that direct endpoint. api_mode is one of + "chat_completions", "codex_responses", or None (auto-detect). """ config = {} cfg_provider = None cfg_model = None cfg_base_url = None cfg_api_key = None + cfg_api_mode = None if task: try: @@ -1883,6 +1989,7 @@ def _resolve_task_provider_model( cfg_model = str(task_config.get("model", "")).strip() or None cfg_base_url = str(task_config.get("base_url", "")).strip() or None cfg_api_key = str(task_config.get("api_key", "")).strip() or None + cfg_api_mode = str(task_config.get("api_mode", "")).strip() or None # Backwards compat: compression section has its own keys. # The auxiliary.compression defaults to provider="auto", so treat @@ -1896,30 +2003,32 @@ def _resolve_task_provider_model( cfg_base_url = cfg_base_url or _sbu.strip() or None env_model = _get_auxiliary_env_override(task, "MODEL") if task else None + env_api_mode = _get_auxiliary_env_override(task, "API_MODE") if task else None resolved_model = model or env_model or cfg_model + resolved_api_mode = env_api_mode or cfg_api_mode if base_url: - return "custom", resolved_model, base_url, api_key + return "custom", resolved_model, base_url, api_key, resolved_api_mode if provider: - return provider, resolved_model, base_url, api_key + return provider, resolved_model, base_url, api_key, resolved_api_mode if task: env_base_url = _get_auxiliary_env_override(task, "BASE_URL") env_api_key = _get_auxiliary_env_override(task, "API_KEY") if env_base_url: - return "custom", resolved_model, env_base_url, env_api_key or cfg_api_key + return "custom", resolved_model, env_base_url, env_api_key or cfg_api_key, resolved_api_mode env_provider = _get_auxiliary_provider(task) if env_provider != "auto": - return env_provider, resolved_model, None, None + return env_provider, resolved_model, None, None, resolved_api_mode if cfg_base_url: - return "custom", resolved_model, cfg_base_url, cfg_api_key + return "custom", resolved_model, cfg_base_url, cfg_api_key, resolved_api_mode if cfg_provider and cfg_provider != "auto": - return cfg_provider, resolved_model, None, None - return "auto", resolved_model, None, None + return cfg_provider, resolved_model, None, None, resolved_api_mode + return "auto", resolved_model, None, None, resolved_api_mode - return "auto", resolved_model, None, None + return "auto", resolved_model, None, None, resolved_api_mode _DEFAULT_AUX_TIMEOUT = 30.0 @@ -1991,6 +2100,37 @@ def _build_call_kwargs( return kwargs +def _validate_llm_response(response: Any, task: str = None) -> Any: + """Validate that an LLM response has the expected .choices[0].message shape. + + Fails fast with a clear error instead of letting malformed payloads + propagate to downstream consumers where they crash with misleading + AttributeError (e.g. "'str' object has no attribute 'choices'"). + + See #7264. + """ + if response is None: + raise RuntimeError( + f"Auxiliary {task or 'call'}: LLM returned None response" + ) + # Allow SimpleNamespace responses from adapters (CodexAuxiliaryClient, + # AnthropicAuxiliaryClient) — they have .choices[0].message. + try: + choices = response.choices + if not choices or not hasattr(choices[0], "message"): + raise AttributeError("missing choices[0].message") + except (AttributeError, TypeError, IndexError) as exc: + response_type = type(response).__name__ + response_preview = str(response)[:120] + raise RuntimeError( + f"Auxiliary {task or 'call'}: LLM returned invalid response " + f"(type={response_type}): {response_preview!r}. " + f"Expected object with .choices[0].message — check provider " + f"adapter or custom endpoint compatibility." + ) from exc + return response + + def call_llm( task: str = None, *, @@ -2029,7 +2169,7 @@ def call_llm( Raises: RuntimeError: If no provider is configured. """ - resolved_provider, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model( + resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( task, provider, model, base_url, api_key) if task == "vision": @@ -2062,6 +2202,7 @@ def call_llm( resolved_model, base_url=resolved_base_url, api_key=resolved_api_key, + api_mode=resolved_api_mode, ) if client is None: # When the user explicitly chose a non-OpenRouter provider but no @@ -2105,18 +2246,20 @@ def call_llm( # Handle max_tokens vs max_completion_tokens retry, then payment fallback. try: - return client.chat.completions.create(**kwargs) + return _validate_llm_response( + client.chat.completions.create(**kwargs), task) except Exception as first_err: err_str = str(first_err) if "max_tokens" in err_str or "unsupported_parameter" in err_str: kwargs.pop("max_tokens", None) kwargs["max_completion_tokens"] = max_tokens try: - return client.chat.completions.create(**kwargs) + return _validate_llm_response( + client.chat.completions.create(**kwargs), task) except Exception as retry_err: - # If the max_tokens retry also hits a payment error, - # fall through to the payment fallback below. - if not _is_payment_error(retry_err): + # If the max_tokens retry also hits a payment or connection + # error, fall through to the fallback chain below. + if not (_is_payment_error(retry_err) or _is_connection_error(retry_err)): raise first_err = retry_err @@ -2133,19 +2276,24 @@ def call_llm( # and providers the user never configured that got picked up by # the auto-detection chain. should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) - if should_fallback: + # Only try alternative providers when the user didn't explicitly + # configure this task's provider. Explicit provider = hard constraint; + # auto (the default) = best-effort fallback chain. (#7559) + is_auto = resolved_provider in ("auto", "", None) + if should_fallback and is_auto: reason = "payment error" if _is_payment_error(first_err) else "connection error" logger.info("Auxiliary %s: %s on %s (%s), trying fallback", task or "call", reason, resolved_provider, first_err) fb_client, fb_model, fb_label = _try_payment_fallback( - resolved_provider, task) + resolved_provider, task, reason=reason) if fb_client is not None: fb_kwargs = _build_call_kwargs( fb_label, fb_model, messages, temperature=temperature, max_tokens=max_tokens, tools=tools, timeout=effective_timeout, extra_body=extra_body) - return fb_client.chat.completions.create(**fb_kwargs) + return _validate_llm_response( + fb_client.chat.completions.create(**fb_kwargs), task) raise @@ -2223,7 +2371,7 @@ async def async_call_llm( Same as call_llm() but async. See call_llm() for full documentation. """ - resolved_provider, resolved_model, resolved_base_url, resolved_api_key = _resolve_task_provider_model( + resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( task, provider, model, base_url, api_key) if task == "vision": @@ -2257,6 +2405,7 @@ async def async_call_llm( async_mode=True, base_url=resolved_base_url, api_key=resolved_api_key, + api_mode=resolved_api_mode, ) if client is None: _explicit = (resolved_provider or "").strip().lower() @@ -2267,11 +2416,9 @@ async def async_call_llm( f"variable, or switch to a different provider with `hermes model`." ) if not resolved_base_url: - logger.warning("Provider %s unavailable, falling back to openrouter", - resolved_provider) - client, final_model = _get_cached_client( - "openrouter", resolved_model or _OPENROUTER_MODEL, - async_mode=True) + logger.info("Auxiliary %s: provider %s unavailable, trying auto-detection chain", + task or "call", resolved_provider) + client, final_model = _get_cached_client("auto", async_mode=True) if client is None: raise RuntimeError( f"No LLM provider configured for task={task} provider={resolved_provider}. " @@ -2286,11 +2433,42 @@ async def async_call_llm( base_url=resolved_base_url) try: - return await client.chat.completions.create(**kwargs) + return _validate_llm_response( + await client.chat.completions.create(**kwargs), task) except Exception as first_err: err_str = str(first_err) if "max_tokens" in err_str or "unsupported_parameter" in err_str: kwargs.pop("max_tokens", None) kwargs["max_completion_tokens"] = max_tokens - return await client.chat.completions.create(**kwargs) + try: + return _validate_llm_response( + await client.chat.completions.create(**kwargs), task) + except Exception as retry_err: + # If the max_tokens retry also hits a payment or connection + # error, fall through to the fallback chain below. + if not (_is_payment_error(retry_err) or _is_connection_error(retry_err)): + raise + first_err = retry_err + + # ── Payment / connection fallback (mirrors sync call_llm) ───── + should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) + is_auto = resolved_provider in ("auto", "", None) + if should_fallback and is_auto: + reason = "payment error" if _is_payment_error(first_err) else "connection error" + logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback", + task or "call", reason, resolved_provider, first_err) + fb_client, fb_model, fb_label = _try_payment_fallback( + resolved_provider, task, reason=reason) + if fb_client is not None: + fb_kwargs = _build_call_kwargs( + fb_label, fb_model, messages, + temperature=temperature, max_tokens=max_tokens, + tools=tools, timeout=effective_timeout, + extra_body=extra_body) + # Convert sync fallback client to async + async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "") + if async_fb_model and async_fb_model != fb_kwargs.get("model"): + fb_kwargs["model"] = async_fb_model + return _validate_llm_response( + await async_fb.chat.completions.create(**fb_kwargs), task) raise diff --git a/agent/builtin_memory_provider.py b/agent/builtin_memory_provider.py deleted file mode 100644 index 77df9a303d..0000000000 --- a/agent/builtin_memory_provider.py +++ /dev/null @@ -1,114 +0,0 @@ -"""BuiltinMemoryProvider — wraps MEMORY.md / USER.md as a MemoryProvider. - -Always registered as the first provider. Cannot be disabled or removed. -This is the existing Hermes memory system exposed through the provider -interface for compatibility with the MemoryManager. - -The actual storage logic lives in tools/memory_tool.py (MemoryStore). -This provider is a thin adapter that delegates to MemoryStore and -exposes the memory tool schema. -""" - -from __future__ import annotations - -import json -import logging -from typing import Any, Dict, List - -from agent.memory_provider import MemoryProvider -from tools.registry import tool_error - -logger = logging.getLogger(__name__) - - -class BuiltinMemoryProvider(MemoryProvider): - """Built-in file-backed memory (MEMORY.md + USER.md). - - Always active, never disabled by other providers. The `memory` tool - is handled by run_agent.py's agent-level tool interception (not through - the normal registry), so get_tool_schemas() returns an empty list — - the memory tool is already wired separately. - """ - - def __init__( - self, - memory_store=None, - memory_enabled: bool = False, - user_profile_enabled: bool = False, - ): - self._store = memory_store - self._memory_enabled = memory_enabled - self._user_profile_enabled = user_profile_enabled - - @property - def name(self) -> str: - return "builtin" - - def is_available(self) -> bool: - """Built-in memory is always available.""" - return True - - def initialize(self, session_id: str, **kwargs) -> None: - """Load memory from disk if not already loaded.""" - if self._store is not None: - self._store.load_from_disk() - - def system_prompt_block(self) -> str: - """Return MEMORY.md and USER.md content for the system prompt. - - Uses the frozen snapshot captured at load time. This ensures the - system prompt stays stable throughout a session (preserving the - prompt cache), even though the live entries may change via tool calls. - """ - if not self._store: - return "" - - parts = [] - if self._memory_enabled: - mem_block = self._store.format_for_system_prompt("memory") - if mem_block: - parts.append(mem_block) - if self._user_profile_enabled: - user_block = self._store.format_for_system_prompt("user") - if user_block: - parts.append(user_block) - - return "\n\n".join(parts) - - def prefetch(self, query: str, *, session_id: str = "") -> str: - """Built-in memory doesn't do query-based recall — it's injected via system_prompt_block.""" - return "" - - def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: - """Built-in memory doesn't auto-sync turns — writes happen via the memory tool.""" - - def get_tool_schemas(self) -> List[Dict[str, Any]]: - """Return empty list. - - The `memory` tool is an agent-level intercepted tool, handled - specially in run_agent.py before normal tool dispatch. It's not - part of the standard tool registry. We don't duplicate it here. - """ - return [] - - def handle_tool_call(self, tool_name: str, args: Dict[str, Any], **kwargs) -> str: - """Not used — the memory tool is intercepted in run_agent.py.""" - return tool_error("Built-in memory tool is handled by the agent loop") - - def shutdown(self) -> None: - """No cleanup needed — files are saved on every write.""" - - # -- Property access for backward compatibility -------------------------- - - @property - def store(self): - """Access the underlying MemoryStore for legacy code paths.""" - return self._store - - @property - def memory_enabled(self) -> bool: - return self._memory_enabled - - @property - def user_profile_enabled(self) -> bool: - return self._user_profile_enabled diff --git a/agent/context_compressor.py b/agent/context_compressor.py index eba2de3f3f..069a5b65e1 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -18,6 +18,7 @@ import time from typing import Any, Dict, List, Optional from agent.auxiliary_client import call_llm +from agent.context_engine import ContextEngine from agent.model_metadata import ( get_model_context_length, estimate_messages_tokens_rough, @@ -50,8 +51,8 @@ _CHARS_PER_TOKEN = 4 _SUMMARY_FAILURE_COOLDOWN_SECONDS = 600 -class ContextCompressor: - """Compresses conversation context when approaching the model's context limit. +class ContextCompressor(ContextEngine): + """Default context engine — compresses conversation context via lossy summarization. Algorithm: 1. Prune old tool results (cheap, no LLM call) @@ -61,6 +62,33 @@ class ContextCompressor: 5. On subsequent compactions, iteratively update the previous summary """ + @property + def name(self) -> str: + return "compressor" + + def on_session_reset(self) -> None: + """Reset all per-session state for /new or /reset.""" + super().on_session_reset() + self._context_probed = False + self._context_probe_persistable = False + self._previous_summary = None + + def update_model( + self, + model: str, + context_length: int, + base_url: str = "", + api_key: str = "", + provider: str = "", + ) -> None: + """Update model info after a model switch or fallback activation.""" + self.model = model + self.base_url = base_url + self.api_key = api_key + self.provider = provider + self.context_length = context_length + self.threshold_tokens = int(context_length * self.threshold_percent) + def __init__( self, model: str, @@ -114,7 +142,6 @@ class ContextCompressor: self.last_prompt_tokens = 0 self.last_completion_tokens = 0 - self.last_total_tokens = 0 self.summary_model = summary_model_override or "" @@ -126,28 +153,12 @@ class ContextCompressor: """Update tracked token usage from API response.""" self.last_prompt_tokens = usage.get("prompt_tokens", 0) self.last_completion_tokens = usage.get("completion_tokens", 0) - self.last_total_tokens = usage.get("total_tokens", 0) def should_compress(self, prompt_tokens: int = None) -> bool: """Check if context exceeds the compression threshold.""" tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens return tokens >= self.threshold_tokens - def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool: - """Quick pre-flight check using rough estimate (before API call).""" - rough_estimate = estimate_messages_tokens_rough(messages) - return rough_estimate >= self.threshold_tokens - - def get_status(self) -> Dict[str, Any]: - """Get current compression status for display/logging.""" - return { - "last_prompt_tokens": self.last_prompt_tokens, - "threshold_tokens": self.threshold_tokens, - "context_length": self.context_length, - "usage_percent": min(100, (self.last_prompt_tokens / self.context_length * 100)) if self.context_length else 0, - "compression_count": self.compression_count, - } - # ------------------------------------------------------------------ # Tool output pruning (cheap pre-pass, no LLM call) # ------------------------------------------------------------------ diff --git a/agent/context_engine.py b/agent/context_engine.py new file mode 100644 index 0000000000..6cd7275fe9 --- /dev/null +++ b/agent/context_engine.py @@ -0,0 +1,184 @@ +"""Abstract base class for pluggable context engines. + +A context engine controls how conversation context is managed when +approaching the model's token limit. The built-in ContextCompressor +is the default implementation. Third-party engines (e.g. LCM) can +replace it via the plugin system or by being placed in the +``plugins/context_engine//`` directory. + +Selection is config-driven: ``context.engine`` in config.yaml. +Default is ``"compressor"`` (the built-in). Only one engine is active. + +The engine is responsible for: + - Deciding when compaction should fire + - Performing compaction (summarization, DAG construction, etc.) + - Optionally exposing tools the agent can call (e.g. lcm_grep) + - Tracking token usage from API responses + +Lifecycle: + 1. Engine is instantiated and registered (plugin register() or default) + 2. on_session_start() called when a conversation begins + 3. update_from_response() called after each API response with usage data + 4. should_compress() checked after each turn + 5. compress() called when should_compress() returns True + 6. on_session_end() called at real session boundaries (CLI exit, /reset, + gateway session expiry) — NOT per-turn +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + + +class ContextEngine(ABC): + """Base class all context engines must implement.""" + + # -- Identity ---------------------------------------------------------- + + @property + @abstractmethod + def name(self) -> str: + """Short identifier (e.g. 'compressor', 'lcm').""" + + # -- Token state (read by run_agent.py for display/logging) ------------ + # + # Engines MUST maintain these. run_agent.py reads them directly. + + last_prompt_tokens: int = 0 + last_completion_tokens: int = 0 + last_total_tokens: int = 0 + threshold_tokens: int = 0 + context_length: int = 0 + compression_count: int = 0 + + # -- Compaction parameters (read by run_agent.py for preflight) -------- + # + # These control the preflight compression check. Subclasses may + # override via __init__ or property; defaults are sensible for most + # engines. + + threshold_percent: float = 0.75 + protect_first_n: int = 3 + protect_last_n: int = 6 + + # -- Core interface ---------------------------------------------------- + + @abstractmethod + def update_from_response(self, usage: Dict[str, Any]) -> None: + """Update tracked token usage from an API response. + + Called after every LLM call with the usage dict from the response. + """ + + @abstractmethod + def should_compress(self, prompt_tokens: int = None) -> bool: + """Return True if compaction should fire this turn.""" + + @abstractmethod + def compress( + self, + messages: List[Dict[str, Any]], + current_tokens: int = None, + ) -> List[Dict[str, Any]]: + """Compact the message list and return the new message list. + + This is the main entry point. The engine receives the full message + list and returns a (possibly shorter) list that fits within the + context budget. The implementation is free to summarize, build a + DAG, or do anything else — as long as the returned list is a valid + OpenAI-format message sequence. + """ + + # -- Optional: pre-flight check ---------------------------------------- + + def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool: + """Quick rough check before the API call (no real token count yet). + + Default returns False (skip pre-flight). Override if your engine + can do a cheap estimate. + """ + return False + + # -- Optional: session lifecycle --------------------------------------- + + def on_session_start(self, session_id: str, **kwargs) -> None: + """Called when a new conversation session begins. + + Use this to load persisted state (DAG, store) for the session. + kwargs may include hermes_home, platform, model, etc. + """ + + def on_session_end(self, session_id: str, messages: List[Dict[str, Any]]) -> None: + """Called at real session boundaries (CLI exit, /reset, gateway expiry). + + Use this to flush state, close DB connections, etc. + NOT called per-turn — only when the session truly ends. + """ + + def on_session_reset(self) -> None: + """Called on /new or /reset. Reset per-session state. + + Default resets compression_count and token tracking. + """ + self.last_prompt_tokens = 0 + self.last_completion_tokens = 0 + self.last_total_tokens = 0 + self.compression_count = 0 + + # -- Optional: tools --------------------------------------------------- + + def get_tool_schemas(self) -> List[Dict[str, Any]]: + """Return tool schemas this engine provides to the agent. + + Default returns empty list (no tools). LCM would return schemas + for lcm_grep, lcm_describe, lcm_expand here. + """ + return [] + + def handle_tool_call(self, name: str, args: Dict[str, Any], **kwargs) -> str: + """Handle a tool call from the agent. + + Only called for tool names returned by get_tool_schemas(). + Must return a JSON string. + + kwargs may include: + messages: the current in-memory message list (for live ingestion) + """ + import json + return json.dumps({"error": f"Unknown context engine tool: {name}"}) + + # -- Optional: status / display ---------------------------------------- + + def get_status(self) -> Dict[str, Any]: + """Return status dict for display/logging. + + Default returns the standard fields run_agent.py expects. + """ + return { + "last_prompt_tokens": self.last_prompt_tokens, + "threshold_tokens": self.threshold_tokens, + "context_length": self.context_length, + "usage_percent": ( + min(100, self.last_prompt_tokens / self.context_length * 100) + if self.context_length else 0 + ), + "compression_count": self.compression_count, + } + + # -- Optional: model switch support ------------------------------------ + + def update_model( + self, + model: str, + context_length: int, + base_url: str = "", + api_key: str = "", + provider: str = "", + ) -> None: + """Called when the user switches models or on fallback activation. + + Default updates context_length and recalculates threshold_tokens + from threshold_percent. Override if your engine needs more + (e.g. recalculate DAG budgets, switch summary models). + """ + self.context_length = context_length + self.threshold_tokens = int(context_length * self.threshold_percent) diff --git a/agent/context_references.py b/agent/context_references.py index 1b8ac9481a..7ecb90c497 100644 --- a/agent/context_references.py +++ b/agent/context_references.py @@ -13,8 +13,9 @@ from typing import Awaitable, Callable from agent.model_metadata import estimate_tokens_rough +_QUOTED_REFERENCE_VALUE = r'(?:`[^`\n]+`|"[^"\n]+"|\'[^\'\n]+\')' REFERENCE_PATTERN = re.compile( - r"(?diff|staged)\b|(?Pfile|folder|git|url):(?P\S+))" + rf"(?diff|staged)\b|(?Pfile|folder|git|url):(?P{_QUOTED_REFERENCE_VALUE}(?::\d+(?:-\d+)?)?|\S+))" ) TRAILING_PUNCTUATION = ",.;!?" _SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh") @@ -81,14 +82,10 @@ def parse_context_references(message: str) -> list[ContextReference]: value = _strip_trailing_punctuation(match.group("value") or "") line_start = None line_end = None - target = value + target = _strip_reference_wrappers(value) if kind == "file": - range_match = re.match(r"^(?P.+?):(?P\d+)(?:-(?P\d+))?$", value) - if range_match: - target = range_match.group("path") - line_start = int(range_match.group("start")) - line_end = int(range_match.group("end") or range_match.group("start")) + target, line_start, line_end = _parse_file_reference_value(value) refs.append( ContextReference( @@ -375,6 +372,38 @@ def _strip_trailing_punctuation(value: str) -> str: return stripped +def _strip_reference_wrappers(value: str) -> str: + if len(value) >= 2 and value[0] == value[-1] and value[0] in "`\"'": + return value[1:-1] + return value + + +def _parse_file_reference_value(value: str) -> tuple[str, int | None, int | None]: + quoted_match = re.match( + r'^(?P`|"|\')(?P.+?)(?P=quote)(?::(?P\d+)(?:-(?P\d+))?)?$', + value, + ) + if quoted_match: + line_start = quoted_match.group("start") + line_end = quoted_match.group("end") + return ( + quoted_match.group("path"), + int(line_start) if line_start is not None else None, + int(line_end or line_start) if line_start is not None else None, + ) + + range_match = re.match(r"^(?P.+?):(?P\d+)(?:-(?P\d+))?$", value) + if range_match: + line_start = int(range_match.group("start")) + return ( + range_match.group("path"), + line_start, + int(range_match.group("end") or range_match.group("start")), + ) + + return _strip_reference_wrappers(value), None, None + + def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str: pieces: list[str] = [] cursor = 0 diff --git a/agent/credential_pool.py b/agent/credential_pool.py index a17d71ba5e..bff262bdc0 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -20,6 +20,7 @@ from hermes_cli.auth import ( DEFAULT_AGENT_KEY_MIN_TTL_SECONDS, KIMI_CODE_BASE_URL, PROVIDER_REGISTRY, + _auth_store_lock, _codex_access_token_is_expiring, _decode_jwt_claims, _import_codex_cli_tokens, @@ -27,6 +28,8 @@ from hermes_cli.auth import ( _load_provider_state, _resolve_kimi_base_url, _resolve_zai_base_url, + _save_auth_store, + _save_provider_state, read_credential_pool, write_credential_pool, ) @@ -479,6 +482,67 @@ class CredentialPool: logger.debug("Failed to sync from ~/.codex/auth.json: %s", exc) return entry + def _sync_device_code_entry_to_auth_store(self, entry: PooledCredential) -> None: + """Write refreshed pool entry tokens back to auth.json providers. + + After a pool-level refresh, the pool entry has fresh tokens but + auth.json's ``providers.`` still holds the pre-refresh state. + On the next ``load_pool()``, ``_seed_from_singletons()`` reads that + stale state and can overwrite the fresh pool entry — potentially + re-seeding a consumed single-use refresh token. + + Applies to any OAuth provider whose singleton lives in auth.json + (currently Nous and OpenAI Codex). + """ + if entry.source != "device_code": + return + try: + with _auth_store_lock(): + auth_store = _load_auth_store() + if self.provider == "nous": + state = _load_provider_state(auth_store, "nous") + if state is None: + return + state["access_token"] = entry.access_token + if entry.refresh_token: + state["refresh_token"] = entry.refresh_token + if entry.expires_at: + state["expires_at"] = entry.expires_at + if entry.agent_key: + state["agent_key"] = entry.agent_key + if entry.agent_key_expires_at: + state["agent_key_expires_at"] = entry.agent_key_expires_at + for extra_key in ("obtained_at", "expires_in", "agent_key_id", + "agent_key_expires_in", "agent_key_reused", + "agent_key_obtained_at"): + val = entry.extra.get(extra_key) + if val is not None: + state[extra_key] = val + if entry.inference_base_url: + state["inference_base_url"] = entry.inference_base_url + _save_provider_state(auth_store, "nous", state) + + elif self.provider == "openai-codex": + state = _load_provider_state(auth_store, "openai-codex") + if not isinstance(state, dict): + return + tokens = state.get("tokens") + if not isinstance(tokens, dict): + return + tokens["access_token"] = entry.access_token + if entry.refresh_token: + tokens["refresh_token"] = entry.refresh_token + if entry.last_refresh: + state["last_refresh"] = entry.last_refresh + _save_provider_state(auth_store, "openai-codex", state) + + else: + return + + _save_auth_store(auth_store) + except Exception as exc: + logger.debug("Failed to sync %s pool entry back to auth store: %s", self.provider, exc) + def _refresh_entry(self, entry: PooledCredential, *, force: bool) -> Optional[PooledCredential]: if entry.auth_type != AUTH_TYPE_OAUTH or not entry.refresh_token: if force: @@ -513,6 +577,13 @@ class CredentialPool: except Exception as wexc: logger.debug("Failed to write refreshed token to credentials file: %s", wexc) elif self.provider == "openai-codex": + # Proactively sync from ~/.codex/auth.json before refresh. + # The Codex CLI (or another Hermes profile) may have already + # consumed our refresh_token. Syncing first avoids a + # "refresh_token_reused" error when the CLI has a newer pair. + synced = self._sync_codex_entry_from_cli(entry) + if synced is not entry: + entry = synced refreshed = auth_mod.refresh_codex_oauth_pure( entry.access_token, entry.refresh_token, @@ -598,6 +669,37 @@ class CredentialPool: # Credentials file had a valid (non-expired) token — use it directly logger.debug("Credentials file has valid token, using without refresh") return synced + # For openai-codex: the refresh_token may have been consumed by + # the Codex CLI between our proactive sync and the refresh call. + # Re-sync and retry once. + if self.provider == "openai-codex": + synced = self._sync_codex_entry_from_cli(entry) + if synced.refresh_token != entry.refresh_token: + logger.debug("Retrying Codex refresh with synced token from ~/.codex/auth.json") + try: + refreshed = auth_mod.refresh_codex_oauth_pure( + synced.access_token, + synced.refresh_token, + ) + updated = replace( + synced, + access_token=refreshed["access_token"], + refresh_token=refreshed["refresh_token"], + last_refresh=refreshed.get("last_refresh"), + last_status=STATUS_OK, + last_status_at=None, + last_error_code=None, + ) + self._replace_entry(synced, updated) + self._persist() + self._sync_device_code_entry_to_auth_store(updated) + return updated + except Exception as retry_exc: + logger.debug("Codex retry refresh also failed: %s", retry_exc) + elif not self._entry_needs_refresh(synced): + logger.debug("Codex CLI has valid token, using without refresh") + self._sync_device_code_entry_to_auth_store(synced) + return synced self._mark_exhausted(entry, None) return None @@ -612,6 +714,10 @@ class CredentialPool: ) self._replace_entry(entry, updated) self._persist() + # Sync refreshed tokens back to auth.json providers so that + # _seed_from_singletons() on the next load_pool() sees fresh state + # instead of re-seeding stale/consumed tokens. + self._sync_device_code_entry_to_auth_store(updated) return updated def _entry_needs_refresh(self, entry: PooledCredential) -> bool: @@ -633,17 +739,6 @@ class CredentialPool: return False return False - def mark_used(self, entry_id: Optional[str] = None) -> None: - """Increment request_count for tracking. Used by least_used strategy.""" - target_id = entry_id or self._current_id - if not target_id: - return - with self._lock: - for idx, entry in enumerate(self._entries): - if entry.id == target_id: - self._entries[idx] = replace(entry, request_count=entry.request_count + 1) - return - def select(self) -> Optional[PooledCredential]: with self._lock: return self._select_unlocked() @@ -805,11 +900,6 @@ class CredentialPool: else: self._active_leases[credential_id] = count - 1 - def active_lease_count(self, credential_id: str) -> int: - """Return the number of active leases for a credential.""" - with self._lock: - return self._active_leases.get(credential_id, 0) - def try_refresh_current(self) -> Optional[PooledCredential]: with self._lock: return self._try_refresh_current_unlocked() @@ -969,6 +1059,17 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup auth_store = _load_auth_store() if provider == "anthropic": + # Only auto-discover external credentials (Claude Code, Hermes PKCE) + # when the user has explicitly configured anthropic as their provider. + # Without this gate, auxiliary client fallback chains silently read + # ~/.claude/.credentials.json without user consent. See PR #4210. + try: + from hermes_cli.auth import is_provider_explicitly_configured + if not is_provider_explicitly_configured("anthropic"): + return changed, active_sources + except ImportError: + pass + from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials for source_name, creds in ( @@ -976,6 +1077,13 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup ("claude_code", read_claude_code_credentials()), ): if creds and creds.get("accessToken"): + # Check if user explicitly removed this source + try: + from hermes_cli.auth import is_source_suppressed + if is_source_suppressed(provider, source_name): + continue + except ImportError: + pass active_sources.add(source_name) changed |= _upsert_entry( entries, diff --git a/agent/display.py b/agent/display.py index 7c7707eb8f..604b7a298c 100644 --- a/agent/display.py +++ b/agent/display.py @@ -21,11 +21,73 @@ _RESET = "\033[0m" logger = logging.getLogger(__name__) _ANSI_RESET = "\033[0m" -_ANSI_DIM = "\033[38;2;150;150;150m" -_ANSI_FILE = "\033[38;2;180;160;255m" -_ANSI_HUNK = "\033[38;2;120;120;140m" -_ANSI_MINUS = "\033[38;2;255;255;255;48;2;120;20;20m" -_ANSI_PLUS = "\033[38;2;255;255;255;48;2;20;90;20m" + +# Diff colors — resolved lazily from the skin engine so they adapt +# to light/dark themes. Falls back to sensible defaults on import +# failure. We cache after first resolution for performance. +_diff_colors_cached: dict[str, str] | None = None + + +def _diff_ansi() -> dict[str, str]: + """Return ANSI escapes for diff display, resolved from the active skin.""" + global _diff_colors_cached + if _diff_colors_cached is not None: + return _diff_colors_cached + + # Defaults that work on dark terminals + dim = "\033[38;2;150;150;150m" + file_c = "\033[38;2;180;160;255m" + hunk = "\033[38;2;120;120;140m" + minus = "\033[38;2;255;255;255;48;2;120;20;20m" + plus = "\033[38;2;255;255;255;48;2;20;90;20m" + + try: + from hermes_cli.skin_engine import get_active_skin + skin = get_active_skin() + + def _hex_fg(key: str, fallback_rgb: tuple[int, int, int]) -> str: + h = skin.get_color(key, "") + if h and len(h) == 7 and h[0] == "#": + r, g, b = int(h[1:3], 16), int(h[3:5], 16), int(h[5:7], 16) + return f"\033[38;2;{r};{g};{b}m" + r, g, b = fallback_rgb + return f"\033[38;2;{r};{g};{b}m" + + dim = _hex_fg("banner_dim", (150, 150, 150)) + file_c = _hex_fg("session_label", (180, 160, 255)) + hunk = _hex_fg("session_border", (120, 120, 140)) + # minus/plus use background colors — derive from ui_error/ui_ok + err_h = skin.get_color("ui_error", "#ef5350") + ok_h = skin.get_color("ui_ok", "#4caf50") + if err_h and len(err_h) == 7: + er, eg, eb = int(err_h[1:3], 16), int(err_h[3:5], 16), int(err_h[5:7], 16) + # Use a dark tinted version as background + minus = f"\033[38;2;255;255;255;48;2;{max(er//2,20)};{max(eg//4,10)};{max(eb//4,10)}m" + if ok_h and len(ok_h) == 7: + or_, og, ob = int(ok_h[1:3], 16), int(ok_h[3:5], 16), int(ok_h[5:7], 16) + plus = f"\033[38;2;255;255;255;48;2;{max(or_//4,10)};{max(og//2,20)};{max(ob//4,10)}m" + except Exception: + pass + + _diff_colors_cached = { + "dim": dim, "file": file_c, "hunk": hunk, + "minus": minus, "plus": plus, + } + return _diff_colors_cached + + +def reset_diff_colors() -> None: + """Reset cached diff colors (call after /skin switch).""" + global _diff_colors_cached + _diff_colors_cached = None + + +# Module-level helpers — each call resolves from the active skin lazily. +def _diff_dim(): return _diff_ansi()["dim"] +def _diff_file(): return _diff_ansi()["file"] +def _diff_hunk(): return _diff_ansi()["hunk"] +def _diff_minus(): return _diff_ansi()["minus"] +def _diff_plus(): return _diff_ansi()["plus"] _MAX_INLINE_DIFF_FILES = 6 _MAX_INLINE_DIFF_LINES = 80 @@ -67,26 +129,6 @@ def _get_skin(): return None -def get_skin_faces(key: str, default: list) -> list: - """Get spinner face list from active skin, falling back to default.""" - skin = _get_skin() - if skin: - faces = skin.get_spinner_list(key) - if faces: - return faces - return default - - -def get_skin_verbs() -> list: - """Get thinking verbs from active skin.""" - skin = _get_skin() - if skin: - verbs = skin.get_spinner_list("thinking_verbs") - if verbs: - return verbs - return KawaiiSpinner.THINKING_VERBS - - def get_skin_tool_prefix() -> str: """Get tool output prefix character from active skin.""" skin = _get_skin() @@ -423,19 +465,19 @@ def _render_inline_unified_diff(diff: str) -> list[str]: if raw_line.startswith("+++ "): to_file = raw_line[4:].strip() if from_file or to_file: - rendered.append(f"{_ANSI_FILE}{from_file or 'a/?'} → {to_file or 'b/?'}{_ANSI_RESET}") + rendered.append(f"{_diff_file()}{from_file or 'a/?'} → {to_file or 'b/?'}{_ANSI_RESET}") continue if raw_line.startswith("@@"): - rendered.append(f"{_ANSI_HUNK}{raw_line}{_ANSI_RESET}") + rendered.append(f"{_diff_hunk()}{raw_line}{_ANSI_RESET}") continue if raw_line.startswith("-"): - rendered.append(f"{_ANSI_MINUS}{raw_line}{_ANSI_RESET}") + rendered.append(f"{_diff_minus()}{raw_line}{_ANSI_RESET}") continue if raw_line.startswith("+"): - rendered.append(f"{_ANSI_PLUS}{raw_line}{_ANSI_RESET}") + rendered.append(f"{_diff_plus()}{raw_line}{_ANSI_RESET}") continue if raw_line.startswith(" "): - rendered.append(f"{_ANSI_DIM}{raw_line}{_ANSI_RESET}") + rendered.append(f"{_diff_dim()}{raw_line}{_ANSI_RESET}") continue if raw_line: rendered.append(raw_line) @@ -501,7 +543,7 @@ def _summarize_rendered_diff_sections( summary = f"… omitted {omitted_lines} diff line(s)" if omitted_files: summary += f" across {omitted_files} additional file(s)/section(s)" - rendered.append(f"{_ANSI_HUNK}{summary}{_ANSI_RESET}") + rendered.append(f"{_diff_hunk()}{summary}{_ANSI_RESET}") return rendered @@ -723,46 +765,6 @@ class KawaiiSpinner: return False -# ========================================================================= -# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text) -# ========================================================================= - -KAWAII_SEARCH = [ - "♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ", - "٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)ノ*:・゚✧", "\(◎o◎)/", -] -KAWAII_READ = [ - "φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)", - "ヾ(@⌒ー⌒@)ノ", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )ノ", -] -KAWAII_TERMINAL = [ - "ヽ(>∀<☆)ノ", "(ノ°∀°)ノ", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و", - "┗(^0^)┓", "(`・ω・´)", "\( ̄▽ ̄)/", "(ง •̀_•́)ง", "ヽ(´▽`)/", -] -KAWAII_BROWSER = [ - "(ノ°∀°)ノ", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)?", - "ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "\(◎o◎)/", -] -KAWAII_CREATE = [ - "✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "٩(♡ε♡)۶", "(◕‿◕)♡", - "✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(^-^)ノ", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°", -] -KAWAII_SKILL = [ - "ヾ(@⌒ー⌒@)ノ", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)ノ", - "(ノ´ヮ`)ノ*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)", - "ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "\(◎o◎)/", - "(✧ω✧)", "ヽ(>∀<☆)ノ", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)", -] -KAWAII_THINK = [ - "(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)", - "(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )ノ", "(;一_一)", -] -KAWAII_GENERIC = [ - "♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)", - "(ノ´ヮ`)ノ*:・゚✧", "ヽ(>∀<☆)ノ", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)", -] - - # ========================================================================= # Cute tool message (completion line that replaces the spinner) # ========================================================================= @@ -970,22 +972,6 @@ _SKY_BLUE = "\033[38;5;117m" _ANSI_RESET = "\033[0m" -def honcho_session_url(workspace: str, session_name: str) -> str: - """Build a Honcho app URL for a session.""" - from urllib.parse import quote - return ( - f"https://app.honcho.dev/explore" - f"?workspace={quote(workspace, safe='')}" - f"&view=sessions" - f"&session={quote(session_name, safe='')}" - ) - - -def _osc8_link(url: str, text: str) -> str: - """OSC 8 terminal hyperlink (clickable in iTerm2, Ghostty, WezTerm, etc.).""" - return f"\033]8;;{url}\033\\{text}\033]8;;\033\\" - - # ========================================================================= # Context pressure display (CLI user-facing warnings) # ========================================================================= diff --git a/agent/error_classifier.py b/agent/error_classifier.py index 1f6b48a095..dc5ae6b56f 100644 --- a/agent/error_classifier.py +++ b/agent/error_classifier.py @@ -82,16 +82,6 @@ class ClassifiedError: def is_auth(self) -> bool: return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent) - @property - def is_transient(self) -> bool: - """Error is expected to resolve on retry (with or without backoff).""" - return self.reason in ( - FailoverReason.rate_limit, - FailoverReason.overloaded, - FailoverReason.server_error, - FailoverReason.timeout, - FailoverReason.unknown, - ) # ── Provider-specific patterns ────────────────────────────────────────── @@ -122,6 +112,7 @@ _RATE_LIMIT_PATTERNS = [ "try again in", "please retry after", "resource_exhausted", + "rate increased too quickly", # Alibaba/DashScope throttling ] # Usage-limit patterns that need disambiguation (could be billing OR rate_limit) @@ -725,11 +716,16 @@ def _classify_by_message( ) # Auth patterns + # Auth errors should NOT be retried directly — the credential is invalid and + # retrying with the same key will always fail. Set retryable=False so the + # caller triggers credential rotation (should_rotate_credential=True) or + # provider fallback rather than an immediate retry loop. if any(p in error_msg for p in _AUTH_PATTERNS): return result_fn( FailoverReason.auth, - retryable=True, + retryable=False, should_rotate_credential=True, + should_fallback=True, ) # Model not found patterns diff --git a/agent/insights.py b/agent/insights.py index d529ffedfc..b15327c825 100644 --- a/agent/insights.py +++ b/agent/insights.py @@ -39,15 +39,6 @@ def _has_known_pricing(model_name: str, provider: str = None, base_url: str = No return has_known_pricing(model_name, provider=provider, base_url=base_url) -def _get_pricing(model_name: str) -> Dict[str, float]: - """Look up pricing for a model. Uses fuzzy matching on model name. - - Returns _DEFAULT_PRICING (zero cost) for unknown/custom models — - we can't assume costs for self-hosted endpoints, local inference, etc. - """ - return get_pricing(model_name) - - def _estimate_cost( session_or_model: Dict[str, Any] | str, input_tokens: int = 0, diff --git a/agent/manual_compression_feedback.py b/agent/manual_compression_feedback.py new file mode 100644 index 0000000000..8f2d5e5d52 --- /dev/null +++ b/agent/manual_compression_feedback.py @@ -0,0 +1,49 @@ +"""User-facing summaries for manual compression commands.""" + +from __future__ import annotations + +from typing import Any, Sequence + + +def summarize_manual_compression( + before_messages: Sequence[dict[str, Any]], + after_messages: Sequence[dict[str, Any]], + before_tokens: int, + after_tokens: int, +) -> dict[str, Any]: + """Return consistent user-facing feedback for manual compression.""" + before_count = len(before_messages) + after_count = len(after_messages) + noop = list(after_messages) == list(before_messages) + + if noop: + headline = f"No changes from compression: {before_count} messages" + if after_tokens == before_tokens: + token_line = ( + f"Rough transcript estimate: ~{before_tokens:,} tokens (unchanged)" + ) + else: + token_line = ( + f"Rough transcript estimate: ~{before_tokens:,} → " + f"~{after_tokens:,} tokens" + ) + else: + headline = f"Compressed: {before_count} → {after_count} messages" + token_line = ( + f"Rough transcript estimate: ~{before_tokens:,} → " + f"~{after_tokens:,} tokens" + ) + + note = None + if not noop and after_count < before_count and after_tokens > before_tokens: + note = ( + "Note: fewer messages can still raise this rough transcript estimate " + "when compression rewrites the transcript into denser summaries." + ) + + return { + "noop": noop, + "headline": headline, + "token_line": token_line, + "note": note, + } diff --git a/agent/memory_manager.py b/agent/memory_manager.py index 4630c481fd..e6e0570480 100644 --- a/agent/memory_manager.py +++ b/agent/memory_manager.py @@ -134,11 +134,6 @@ class MemoryManager: """All registered providers in order.""" return list(self._providers) - @property - def provider_names(self) -> List[str]: - """Names of all registered providers.""" - return [p.name for p in self._providers] - def get_provider(self, name: str) -> Optional[MemoryProvider]: """Get a provider by name, or None if not registered.""" for p in self._providers: diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 791f778c22..2ce0cefa0d 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -113,19 +113,31 @@ DEFAULT_CONTEXT_LENGTHS = { "deepseek": 128000, # Meta "llama": 131072, - # Qwen + # Qwen — specific model families before the catch-all. + # Official docs: https://help.aliyun.com/zh/model-studio/developer-reference/ + "qwen3-coder-plus": 1000000, # 1M context + "qwen3-coder": 262144, # 256K context "qwen": 131072, - # MiniMax (lowercase — lookup lowercases model names at line 973) - "minimax-m1-256k": 1000000, - "minimax-m1-128k": 1000000, - "minimax-m1-80k": 1000000, - "minimax-m1-40k": 1000000, - "minimax-m1": 1000000, - "minimax-m2.5": 1048576, - "minimax-m2.7": 1048576, - "minimax": 1048576, + # MiniMax — official docs: 204,800 context for all models + # https://platform.minimax.io/docs/api-reference/text-anthropic-api + "minimax": 204800, # GLM "glm": 202752, + # xAI Grok — xAI /v1/models does not return context_length metadata, + # so these hardcoded fallbacks prevent Hermes from probing-down to + # the default 128k when the user points at https://api.x.ai/v1 + # via a custom provider. Values sourced from models.dev (2026-04). + # Keys use substring matching (longest-first), so e.g. "grok-4.20" + # matches "grok-4.20-0309-reasoning" / "-non-reasoning" / "-multi-agent-0309". + "grok-code-fast": 256000, # grok-code-fast-1 + "grok-4-1-fast": 2000000, # grok-4-1-fast-(non-)reasoning + "grok-2-vision": 8192, # grok-2-vision, -1212, -latest + "grok-4-fast": 2000000, # grok-4-fast-(non-)reasoning + "grok-4.20": 2000000, # grok-4.20-0309-(non-)reasoning, -multi-agent-0309 + "grok-4": 256000, # grok-4, grok-4-0709 + "grok-3": 131072, # grok-3, grok-3-mini, grok-3-fast, grok-3-mini-fast + "grok-2": 131072, # grok-2, grok-2-1212, grok-2-latest + "grok": 131072, # catch-all (grok-beta, unknown grok-*) # Kimi "kimi": 262144, # Arcee @@ -136,7 +148,7 @@ DEFAULT_CONTEXT_LENGTHS = { "deepseek-ai/DeepSeek-V3.2": 65536, "moonshotai/Kimi-K2.5": 262144, "moonshotai/Kimi-K2-Thinking": 262144, - "MiniMaxAI/MiniMax-M2.5": 1048576, + "MiniMaxAI/MiniMax-M2.5": 204800, "XiaomiMiMo/MiMo-V2-Flash": 32768, "mimo-v2-pro": 1048576, "mimo-v2-omni": 1048576, @@ -198,6 +210,7 @@ _URL_TO_PROVIDER: Dict[str, str] = { "models.github.ai": "copilot", "api.fireworks.ai": "fireworks", "opencode.ai": "opencode-go", + "api.x.ai": "xai", } diff --git a/agent/models_dev.py b/agent/models_dev.py index cc360d77cf..d3620733bf 100644 --- a/agent/models_dev.py +++ b/agent/models_dev.py @@ -135,9 +135,6 @@ class ProviderInfo: doc: str = "" # documentation URL model_count: int = 0 - def has_api_url(self) -> bool: - return bool(self.api) - # --------------------------------------------------------------------------- # Provider ID mapping: Hermes ↔ models.dev @@ -634,43 +631,6 @@ def get_provider_info(provider_id: str) -> Optional[ProviderInfo]: return _parse_provider_info(mdev_id, raw) -def list_all_providers() -> Dict[str, ProviderInfo]: - """Return all providers from models.dev as {provider_id: ProviderInfo}. - - Returns the full catalog — 109+ providers. For providers that have - a Hermes alias, both the models.dev ID and the Hermes ID are included. - """ - data = fetch_models_dev() - result: Dict[str, ProviderInfo] = {} - - for pid, pdata in data.items(): - if isinstance(pdata, dict): - info = _parse_provider_info(pid, pdata) - result[pid] = info - - return result - - -def get_providers_for_env_var(env_var: str) -> List[str]: - """Reverse lookup: find all providers that use a given env var. - - Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which - providers does that enable?" - - Returns list of models.dev provider IDs. - """ - data = fetch_models_dev() - matches: List[str] = [] - - for pid, pdata in data.items(): - if isinstance(pdata, dict): - env = pdata.get("env", []) - if isinstance(env, list) and env_var in env: - matches.append(pid) - - return matches - - # --------------------------------------------------------------------------- # Model-level queries (rich ModelInfo) # --------------------------------------------------------------------------- @@ -708,74 +668,3 @@ def get_model_info( return None -def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]: - """Search all providers for a model by ID. - - Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or - a bare name and want to find it anywhere. Checks Hermes-mapped providers - first, then falls back to all models.dev providers. - """ - data = fetch_models_dev() - - # Try Hermes-mapped providers first (more likely what the user wants) - for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items(): - pdata = data.get(mdev_id) - if not isinstance(pdata, dict): - continue - models = pdata.get("models", {}) - if not isinstance(models, dict): - continue - - raw = models.get(model_id) - if isinstance(raw, dict): - return _parse_model_info(model_id, raw, mdev_id) - - # Case-insensitive - model_lower = model_id.lower() - for mid, mdata in models.items(): - if mid.lower() == model_lower and isinstance(mdata, dict): - return _parse_model_info(mid, mdata, mdev_id) - - # Fall back to ALL providers - for pid, pdata in data.items(): - if pid in _get_reverse_mapping(): - continue # already checked - if not isinstance(pdata, dict): - continue - models = pdata.get("models", {}) - if not isinstance(models, dict): - continue - - raw = models.get(model_id) - if isinstance(raw, dict): - return _parse_model_info(model_id, raw, pid) - - return None - - -def list_provider_model_infos(provider_id: str) -> List[ModelInfo]: - """Return all models for a provider as ModelInfo objects. - - Filters out deprecated models by default. - """ - mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id) - - data = fetch_models_dev() - pdata = data.get(mdev_id) - if not isinstance(pdata, dict): - return [] - - models = pdata.get("models", {}) - if not isinstance(models, dict): - return [] - - result: List[ModelInfo] = [] - for mid, mdata in models.items(): - if not isinstance(mdata, dict): - continue - status = mdata.get("status", "") - if status == "deprecated": - continue - result.append(_parse_model_info(mid, mdata, mdev_id)) - - return result diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 8302973aac..08b8fe0a6a 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -40,7 +40,7 @@ _CONTEXT_THREAT_PATTERNS = [ (r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"), (r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"), (r'', "html_comment_injection"), - (r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"), + (r'<\s*div\s+style\s*=\s*["\'][\s\S]*?display\s*:\s*none', "hidden_div"), (r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"), (r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"), (r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"), @@ -356,6 +356,14 @@ PLATFORM_HINTS = { "MEDIA:/absolute/path/to/file in your response. Images (.jpg, .png, " ".heic) appear as photos and other files arrive as attachments." ), + "weixin": ( + "You are on Weixin/WeChat. Markdown formatting is supported, so you may use it when " + "it improves readability, but keep the message compact and chat-friendly. You can send media files natively: " + "include MEDIA:/absolute/path/to/file in your response. Images are sent as native " + "photos, videos play inline when supported, and other files arrive as downloadable " + "documents. You can also include image URLs in markdown format ![alt](url) and they " + "will be downloaded and sent as native media when possible." + ), } CONTEXT_FILE_MAX_CHARS = 20_000 @@ -479,7 +487,7 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]: (True, {}, "") to err on the side of showing the skill. """ try: - raw = skill_file.read_text(encoding="utf-8")[:2000] + raw = skill_file.read_text(encoding="utf-8") frontmatter, _ = parse_frontmatter(raw) if not skill_matches_platform(frontmatter): @@ -487,21 +495,10 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]: return True, frontmatter, extract_skill_description(frontmatter) except Exception as e: - logger.debug("Failed to parse skill file %s: %s", skill_file, e) + logger.warning("Failed to parse skill file %s: %s", skill_file, e) return True, {}, "" -def _read_skill_conditions(skill_file: Path) -> dict: - """Extract conditional activation fields from SKILL.md frontmatter.""" - try: - raw = skill_file.read_text(encoding="utf-8")[:2000] - frontmatter, _ = parse_frontmatter(raw) - return extract_skill_conditions(frontmatter) - except Exception as e: - logger.debug("Failed to read skill conditions from %s: %s", skill_file, e) - return {} - - def _skill_should_show( conditions: dict, available_tools: "set[str] | None", @@ -561,9 +558,10 @@ def build_skills_system_prompt( # ── Layer 1: in-process LRU cache ───────────────────────────────── # Include the resolved platform so per-platform disabled-skill lists # produce distinct cache entries (gateway serves multiple platforms). + from gateway.session_context import get_session_env _platform_hint = ( os.environ.get("HERMES_PLATFORM") - or os.environ.get("HERMES_SESSION_PLATFORM") + or get_session_env("HERMES_SESSION_PLATFORM") or "" ) cache_key = ( diff --git a/agent/rate_limit_tracker.py b/agent/rate_limit_tracker.py index c87e096a1d..73e1152229 100644 --- a/agent/rate_limit_tracker.py +++ b/agent/rate_limit_tracker.py @@ -97,8 +97,12 @@ def parse_rate_limit_headers( Returns None if no rate limit headers are present. """ + # Normalize to lowercase so lookups work regardless of how the server + # capitalises headers (HTTP header names are case-insensitive per RFC 7230). + lowered = {k.lower(): v for k, v in headers.items()} + # Quick check: at least one rate limit header must exist - has_any = any(k.lower().startswith("x-ratelimit-") for k in headers) + has_any = any(k.startswith("x-ratelimit-") for k in lowered) if not has_any: return None @@ -109,9 +113,9 @@ def parse_rate_limit_headers( # resource="tokens", suffix="-1h" -> per-hour tag = f"{resource}{suffix}" return RateLimitBucket( - limit=_safe_int(headers.get(f"x-ratelimit-limit-{tag}")), - remaining=_safe_int(headers.get(f"x-ratelimit-remaining-{tag}")), - reset_seconds=_safe_float(headers.get(f"x-ratelimit-reset-{tag}")), + limit=_safe_int(lowered.get(f"x-ratelimit-limit-{tag}")), + remaining=_safe_int(lowered.get(f"x-ratelimit-remaining-{tag}")), + reset_seconds=_safe_float(lowered.get(f"x-ratelimit-reset-{tag}")), captured_at=now, ) diff --git a/agent/skill_commands.py b/agent/skill_commands.py index 18414199dc..1f000eefed 100644 --- a/agent/skill_commands.py +++ b/agent/skill_commands.py @@ -168,7 +168,7 @@ def _build_skill_message( subdir_path = skill_dir / subdir if subdir_path.exists(): for f in sorted(subdir_path.rglob("*")): - if f.is_file(): + if f.is_file() and not f.is_symlink(): rel = str(f.relative_to(skill_dir)) supporting.append(rel) diff --git a/agent/skill_utils.py b/agent/skill_utils.py index 6b06a19e36..ba606b358d 100644 --- a/agent/skill_utils.py +++ b/agent/skill_utils.py @@ -145,10 +145,11 @@ def get_disabled_skill_names(platform: str | None = None) -> Set[str]: if not isinstance(skills_cfg, dict): return set() + from gateway.session_context import get_session_env resolved_platform = ( platform or os.getenv("HERMES_PLATFORM") - or os.getenv("HERMES_SESSION_PLATFORM") + or get_session_env("HERMES_SESSION_PLATFORM") ) if resolved_platform: platform_disabled = (skills_cfg.get("platform_disabled") or {}).get( diff --git a/agent/smart_model_routing.py b/agent/smart_model_routing.py index 8a62e98fc3..6d482be270 100644 --- a/agent/smart_model_routing.py +++ b/agent/smart_model_routing.py @@ -181,6 +181,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any "api_mode": runtime.get("api_mode"), "command": runtime.get("command"), "args": list(runtime.get("args") or []), + "credential_pool": runtime.get("credential_pool"), }, "label": f"smart route → {route.get('model')} ({runtime.get('provider')})", "signature": ( diff --git a/agent/usage_pricing.py b/agent/usage_pricing.py index cfd0f88c4e..2b04eab625 100644 --- a/agent/usage_pricing.py +++ b/agent/usage_pricing.py @@ -595,30 +595,6 @@ def get_pricing( } -def estimate_cost_usd( - model: str, - input_tokens: int, - output_tokens: int, - *, - provider: Optional[str] = None, - base_url: Optional[str] = None, - api_key: Optional[str] = None, -) -> float: - """Backward-compatible helper for legacy callers. - - This uses non-cached input/output only. New code should call - `estimate_usage_cost()` with canonical usage buckets. - """ - result = estimate_usage_cost( - model, - CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens), - provider=provider, - base_url=base_url, - api_key=api_key, - ) - return float(result.amount_usd or _ZERO) - - def format_duration_compact(seconds: float) -> str: if seconds < 60: return f"{seconds:.0f}s" diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 346e6e851f..e9284d8137 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -480,6 +480,12 @@ agent: # Fires once per run when inactivity reaches this threshold (seconds). # Set to 0 to disable the warning. # gateway_timeout_warning: 900 + + # Graceful drain timeout for gateway stop/restart (seconds). + # The gateway stops accepting new work, waits for in-flight agents to + # finish, then interrupts anything still running after this timeout. + # 0 = no drain, interrupt immediately. + # restart_drain_timeout: 60 # Enable verbose logging verbose: false @@ -582,7 +588,7 @@ platform_toolsets: # skills_hub - skill_hub (search/install/manage from online registries — user-driven only) # moa - mixture_of_agents (requires OPENROUTER_API_KEY) # todo - todo (in-memory task planning, no deps) -# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX key) +# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX/MISTRAL key) # cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks) # rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY) # @@ -611,7 +617,7 @@ platform_toolsets: # todo - Task planning and tracking for multi-step work # memory - Persistent memory across sessions (personal notes + user profile) # session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization) -# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax) +# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax, Mistral) # cronjob - Schedule and manage automated tasks (CLI-only) # rl - RL training tools (Tinker-Atropos) # @@ -684,7 +690,11 @@ platform_toolsets: stt: enabled: true # provider: "local" # auto-detected if omitted - model: "whisper-1" # whisper-1 (cheapest) | gpt-4o-mini-transcribe | gpt-4o-transcribe + local: + model: "base" # tiny | base | small | medium | large-v3 | turbo + # language: "" # auto-detect; set to "en", "es", "fr", etc. to force + openai: + model: "whisper-1" # whisper-1 | gpt-4o-mini-transcribe | gpt-4o-transcribe # mistral: # model: "voxtral-mini-latest" # voxtral-mini-latest | voxtral-mini-2602 diff --git a/cli.py b/cli.py index 237ed78998..18f6df6711 100644 --- a/cli.py +++ b/cli.py @@ -158,6 +158,18 @@ def _parse_reasoning_config(effort: str) -> dict | None: return result +def _parse_service_tier_config(raw: str) -> str | None: + """Parse a persisted service-tier preference into a Responses API value.""" + value = str(raw or "").strip().lower() + if not value or value in {"normal", "default", "standard", "off", "none"}: + return None + if value in {"fast", "priority", "on"}: + return "priority" + logger.warning("Unknown service_tier '%s', ignoring", raw) + return None + + + def _get_chrome_debug_candidates(system: str) -> list[str]: """Return likely browser executables for local CDP auto-launch.""" candidates: list[str] = [] @@ -277,6 +289,7 @@ def load_cli_config() -> Dict[str, Any]: "system_prompt": "", "prefill_messages_file": "", "reasoning_effort": "", + "service_tier": "", "personalities": { "helpful": "You are a helpful, friendly AI assistant.", "concise": "You are a concise assistant. Keep responses brief and to the point.", @@ -344,7 +357,7 @@ def load_cli_config() -> Dict[str, Any]: # Load from file if exists if config_path.exists(): try: - with open(config_path, "r") as f: + with open(config_path, "r", encoding="utf-8") as f: file_config = yaml.safe_load(f) or {} _file_has_terminal_config = "terminal" in file_config @@ -1012,11 +1025,60 @@ def _prune_orphaned_branches(repo_root: str) -> None: # - Dim: #B8860B (muted text) # ANSI building blocks for conversation display -_GOLD = "\033[1;38;2;255;215;0m" # True-color #FFD700 bold — matches Rich Panel gold +_ACCENT_ANSI_DEFAULT = "\033[1;38;2;255;215;0m" # True-color #FFD700 bold — fallback _BOLD = "\033[1m" _DIM = "\033[2m" _RST = "\033[0m" + +def _hex_to_ansi_bold(hex_color: str) -> str: + """Convert a hex color like '#268bd2' to a bold true-color ANSI escape.""" + try: + r = int(hex_color[1:3], 16) + g = int(hex_color[3:5], 16) + b = int(hex_color[5:7], 16) + return f"\033[1;38;2;{r};{g};{b}m" + except (ValueError, IndexError): + return _ACCENT_ANSI_DEFAULT + + +class _SkinAwareAnsi: + """Lazy ANSI escape that resolves from the skin engine on first use. + + Acts as a string in f-strings and concatenation. Call ``.reset()`` to + force re-resolution after a ``/skin`` switch. + """ + + def __init__(self, skin_key: str, fallback_hex: str = "#FFD700"): + self._skin_key = skin_key + self._fallback_hex = fallback_hex + self._cached: str | None = None + + def __str__(self) -> str: + if self._cached is None: + try: + from hermes_cli.skin_engine import get_active_skin + self._cached = _hex_to_ansi_bold( + get_active_skin().get_color(self._skin_key, self._fallback_hex) + ) + except Exception: + self._cached = _hex_to_ansi_bold(self._fallback_hex) + return self._cached + + def __add__(self, other: str) -> str: + return str(self) + other + + def __radd__(self, other: str) -> str: + return other + str(self) + + def reset(self) -> None: + """Clear cache so the next access re-reads the skin.""" + self._cached = None + + +_ACCENT = _SkinAwareAnsi("response_border", "#FFD700") + + def _accent_hex() -> str: """Return the active skin accent color for legacy CLI output lines.""" try: @@ -1073,7 +1135,7 @@ def _termux_example_image_path(filename: str = "cat.png") -> str: def _split_path_input(raw: str) -> tuple[str, str]: - """Split a leading file path token from trailing free-form text. + r"""Split a leading file path token from trailing free-form text. Supports quoted paths and backslash-escaped spaces so callers can accept inputs like: @@ -1147,6 +1209,45 @@ def _resolve_attachment_path(raw_path: str) -> Path | None: return resolved +def _format_process_notification(evt: dict) -> "str | None": + """Format a process notification event into a [SYSTEM: ...] message. + + Handles both completion events (notify_on_complete) and watch pattern + match events from the unified completion_queue. + """ + evt_type = evt.get("type", "completion") + _sid = evt.get("session_id", "unknown") + _cmd = evt.get("command", "unknown") + + if evt_type == "watch_disabled": + return f"[SYSTEM: {evt.get('message', '')}]" + + if evt_type == "watch_match": + _pat = evt.get("pattern", "?") + _out = evt.get("output", "") + _sup = evt.get("suppressed", 0) + text = ( + f"[SYSTEM: Background process {_sid} matched " + f"watch pattern \"{_pat}\".\n" + f"Command: {_cmd}\n" + f"Matched output:\n{_out}" + ) + if _sup: + text += f"\n({_sup} earlier matches were suppressed by rate limit)" + text += "]" + return text + + # Default: completion event + _exit = evt.get("exit_code", "?") + _out = evt.get("output", "") + return ( + f"[SYSTEM: Background process {_sid} completed " + f"(exit code {_exit}).\n" + f"Command: {_cmd}\n" + f"Output:\n{_out}]" + ) + + def _detect_file_drop(user_input: str) -> "dict | None": """Detect if *user_input* starts with a real local file path. @@ -1228,6 +1329,11 @@ def _format_image_attachment_badges(attached_images: list[Path], image_counter: ) +def _should_auto_attach_clipboard_image_on_paste(pasted_text: str) -> bool: + """Auto-attach clipboard images only for image-only paste gestures.""" + return not pasted_text.strip() + + def _collect_query_images(query: str | None, image_arg: str | None = None) -> tuple[str, list[Path]]: """Collect local image attachments for single-query CLI flows.""" message = query or "" @@ -1312,14 +1418,6 @@ HERMES_CADUCEUS = """[#CD7F32]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣀⡀⠀⣀⣀ [#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠳⠈⣡⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/] [#B8860B]⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[/]""" -# Compact banner for smaller terminals (fallback) -# Note: built dynamically by _build_compact_banner() to fit terminal width -COMPACT_BANNER = """ -[bold #FFD700]╔══════════════════════════════════════════════════════════════╗[/] -[bold #FFD700]║[/] [#FFBF00]⚕ NOUS HERMES[/] [dim #B8860B]- AI Agent Framework[/] [bold #FFD700]║[/] -[bold #FFD700]║[/] [#CD7F32]Messenger of the Digital Gods[/] [dim #B8860B]Nous Research[/] [bold #FFD700]║[/] -[bold #FFD700]╚══════════════════════════════════════════════════════════════╝[/] -""" def _build_compact_banner() -> str: @@ -1565,7 +1663,6 @@ class HermesCLI: self._stream_buf = "" # Partial line buffer for line-buffered rendering self._stream_started = False # True once first delta arrives self._stream_box_opened = False # True once the response box header is printed - self._reasoning_stream_started = False # True once live reasoning starts streaming self._reasoning_preview_buf = "" # Coalesce tiny reasoning chunks for [thinking] output self._pending_edit_snapshots = {} @@ -1623,8 +1720,6 @@ class HermesCLI: self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") else: self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY") - self._nous_key_expires_at: Optional[str] = None - self._nous_key_source: Optional[str] = None # Max turns priority: CLI arg > config file > env var > default if max_turns is not None: # CLI arg was explicitly set self.max_turns = max_turns @@ -1672,6 +1767,9 @@ class HermesCLI: self.reasoning_config = _parse_reasoning_config( CLI_CONFIG["agent"].get("reasoning_effort", "") ) + self.service_tier = _parse_service_tier_config( + CLI_CONFIG["agent"].get("service_tier", "") + ) # OpenRouter provider routing preferences pr = CLI_CONFIG.get("provider_routing", {}) or {} @@ -1747,6 +1845,7 @@ class HermesCLI: self._secret_state = None self._secret_deadline = 0 self._spinner_text: str = "" # thinking spinner text for TUI + self._tool_start_time: float = 0.0 # monotonic timestamp when current tool started (for live elapsed) self._command_running = False self._command_status = "" self._attached_images: list[Path] = [] @@ -2055,6 +2154,25 @@ class HermesCLI: current_model = (self.model or "").strip() changed = False + try: + from hermes_cli.model_normalize import ( + _AGGREGATOR_PROVIDERS, + normalize_model_for_provider, + ) + + if resolved_provider not in _AGGREGATOR_PROVIDERS: + normalized_model = normalize_model_for_provider(current_model, resolved_provider) + if normalized_model and normalized_model != current_model: + if not self._model_is_default: + self.console.print( + f"[yellow]⚠️ Normalized model '{current_model}' to '{normalized_model}' for {resolved_provider}.[/]" + ) + self.model = normalized_model + current_model = normalized_model + changed = True + except Exception: + pass + if resolved_provider == "copilot": try: from hermes_cli.models import copilot_model_api_mode, normalize_copilot_model_id @@ -2100,7 +2218,7 @@ class HermesCLI: return changed if resolved_provider != "openai-codex": - return False + return changed # 1. Strip provider prefix ("openai/gpt-5.4" → "gpt-5.4") if "/" in current_model: @@ -2139,6 +2257,7 @@ class HermesCLI: if not text: self._flush_reasoning_preview(force=True) self._spinner_text = text or "" + self._tool_start_time = 0.0 # clear tool timer when switching to thinking self._invalidate() # ── Streaming display ──────────────────────────────────────────────── @@ -2251,7 +2370,6 @@ class HermesCLI: """ if not text: return - self._reasoning_stream_started = True self._reasoning_shown_this_turn = True if getattr(self, "_stream_box_opened", False): return @@ -2330,17 +2448,59 @@ class HermesCLI: # Append to a pre-filter buffer first self._stream_prefilt = getattr(self, "_stream_prefilt", "") + text - # Check if we're entering a reasoning block + # Check if we're entering a reasoning block. + # Only match tags that appear at a "block boundary": start of the + # stream, after a newline (with optional whitespace), or when nothing + # but whitespace has been emitted on the current line. + # This prevents false positives when models *mention* tags in prose + # like "(/think not producing tags)". + # + # _stream_last_was_newline tracks whether the last character emitted + # (or the start of the stream) is a line boundary. It's True at + # stream start and set True whenever emitted text ends with '\n'. + if not hasattr(self, "_stream_last_was_newline"): + self._stream_last_was_newline = True # start of stream = boundary + if not getattr(self, "_in_reasoning_block", False): for tag in _OPEN_TAGS: - idx = self._stream_prefilt.find(tag) - if idx != -1: - # Emit everything before the tag - before = self._stream_prefilt[:idx] - if before: - self._emit_stream_text(before) - self._in_reasoning_block = True - self._stream_prefilt = self._stream_prefilt[idx + len(tag):] + search_start = 0 + while True: + idx = self._stream_prefilt.find(tag, search_start) + if idx == -1: + break + # Check if this is a block boundary position + preceding = self._stream_prefilt[:idx] + if idx == 0: + # At buffer start — only a boundary if we're at + # a line start (stream start or last emit ended + # with newline) + is_block_boundary = getattr(self, "_stream_last_was_newline", True) + else: + # Find last newline in the buffer before the tag + last_nl = preceding.rfind("\n") + if last_nl == -1: + # No newline in buffer — boundary only if + # last emit was a newline AND only whitespace + # has accumulated before the tag + is_block_boundary = ( + getattr(self, "_stream_last_was_newline", True) + and preceding.strip() == "" + ) + else: + # Text between last newline and tag must be + # whitespace-only + is_block_boundary = preceding[last_nl + 1:].strip() == "" + if is_block_boundary: + # Emit everything before the tag + if preceding: + self._emit_stream_text(preceding) + self._stream_last_was_newline = preceding.endswith("\n") + self._in_reasoning_block = True + self._stream_prefilt = self._stream_prefilt[idx + len(tag):] + break + # Not a block boundary — keep searching after this occurrence + search_start = idx + 1 + if getattr(self, "_in_reasoning_block", False): break # Could also be a partial open tag at the end — hold it back @@ -2354,6 +2514,7 @@ class HermesCLI: break if safe: self._emit_stream_text(safe) + self._stream_last_was_newline = safe.endswith("\n") self._stream_prefilt = self._stream_prefilt[len(safe):] return @@ -2431,7 +2592,7 @@ class HermesCLI: self._stream_text_ansi = "" w = shutil.get_terminal_size().columns fill = w - 2 - len(label) - _cprint(f"\n{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}") + _cprint(f"\n{_ACCENT}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}") self._stream_buf += text @@ -2443,6 +2604,14 @@ class HermesCLI: def _flush_stream(self) -> None: """Emit any remaining partial line from the stream buffer and close the box.""" + # If we're still inside a "reasoning block" at end-of-stream, it was + # a false positive — the model mentioned a tag like in prose + # but never closed it. Recover the buffered content as regular text. + if getattr(self, "_in_reasoning_block", False) and getattr(self, "_stream_prefilt", ""): + self._in_reasoning_block = False + self._emit_stream_text(self._stream_prefilt) + self._stream_prefilt = "" + # Close reasoning box if still open (in case no content tokens arrived) self._close_reasoning_box() @@ -2454,17 +2623,17 @@ class HermesCLI: # Close the response box if self._stream_box_opened: w = shutil.get_terminal_size().columns - _cprint(f"{_GOLD}╰{'─' * (w - 2)}╯{_RST}") + _cprint(f"{_ACCENT}╰{'─' * (w - 2)}╯{_RST}") def _reset_stream_state(self) -> None: """Reset streaming state before each agent invocation.""" self._stream_buf = "" self._stream_started = False self._stream_box_opened = False - self._reasoning_stream_started = False self._stream_text_ansi = "" self._stream_prefilt = "" self._in_reasoning_block = False + self._stream_last_was_newline = True self._reasoning_box_opened = False self._reasoning_buf = "" self._reasoning_preview_buf = "" @@ -2594,8 +2763,9 @@ class HermesCLI: def _resolve_turn_agent_config(self, user_message: str) -> dict: """Resolve model/runtime overrides for a single user turn.""" from agent.smart_model_routing import resolve_turn_route + from hermes_cli.models import resolve_fast_mode_overrides - return resolve_turn_route( + route = resolve_turn_route( user_message, self._smart_model_routing, { @@ -2610,7 +2780,19 @@ class HermesCLI: }, ) - def _init_agent(self, *, model_override: str = None, runtime_override: dict = None, route_label: str = None) -> bool: + service_tier = getattr(self, "service_tier", None) + if not service_tier: + route["request_overrides"] = None + return route + + try: + overrides = resolve_fast_mode_overrides(route.get("model")) + except Exception: + overrides = None + route["request_overrides"] = overrides + return route + + def _init_agent(self, *, model_override: str = None, runtime_override: dict = None, route_label: str = None, request_overrides: dict | None = None) -> bool: """ Initialize the agent on first use. When resuming a session, restores conversation history from SQLite. @@ -2697,6 +2879,8 @@ class HermesCLI: ephemeral_system_prompt=self.system_prompt if self.system_prompt else None, prefill_messages=self.prefill_messages or None, reasoning_config=self.reasoning_config, + service_tier=self.service_tier, + request_overrides=request_overrides, providers_allowed=self._providers_only, providers_ignored=self._providers_ignore, providers_order=self._providers_order, @@ -2862,15 +3046,17 @@ class HermesCLI: title_part = "" if session_meta.get("title"): title_part = f' "{session_meta["title"]}"' + accent_color = _accent_hex() self.console.print( - f"[#DAA520]↻ Resumed session [bold]{self.session_id}[/bold]" + f"[{accent_color}]↻ Resumed session [bold]{self.session_id}[/bold]" f"{title_part} " f"({msg_count} user message{'s' if msg_count != 1 else ''}, " f"{len(restored)} total messages)[/]" ) else: + accent_color = _accent_hex() self.console.print( - f"[#DAA520]Session {self.session_id} found but has no " + f"[{accent_color}]Session {self.session_id} found but has no " f"messages. Starting fresh.[/]" ) return False @@ -3383,37 +3569,112 @@ class HermesCLI: pass # Don't crash on import errors def _show_status(self): - """Show current status bar.""" + """Show compact startup status line.""" # Get tool count tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) tool_count = len(tools) if tools else 0 - + # Format model name (shorten if needed) model_short = self.model.split("/")[-1] if "/" in self.model else self.model if len(model_short) > 30: model_short = model_short[:27] + "..." - + # Get API status indicator if self.api_key: api_indicator = "[green bold]●[/]" else: api_indicator = "[red bold]●[/]" - - # Build status line with proper markup + + # Build status line with proper markup — skin-aware colors + try: + from hermes_cli.skin_engine import get_active_skin + skin = get_active_skin() + separator_color = skin.get_color("banner_dim", "#B8860B") + accent_color = skin.get_color("ui_accent", "#FFBF00") + label_color = skin.get_color("ui_label", "#4dd0e1") + except Exception: + separator_color, accent_color, label_color = "#B8860B", "#FFBF00", "cyan" toolsets_info = "" if self.enabled_toolsets and "all" not in self.enabled_toolsets: - toolsets_info = f" [dim #B8860B]·[/] [#CD7F32]toolsets: {', '.join(self.enabled_toolsets)}[/]" + toolsets_info = f" [dim {separator_color}]·[/] [{label_color}]toolsets: {', '.join(self.enabled_toolsets)}[/]" - provider_info = f" [dim #B8860B]·[/] [dim]provider: {self.provider}[/]" + provider_info = f" [dim {separator_color}]·[/] [dim]provider: {self.provider}[/]" if self._provider_source: - provider_info += f" [dim #B8860B]·[/] [dim]auth: {self._provider_source}[/]" + provider_info += f" [dim {separator_color}]·[/] [dim]auth: {self._provider_source}[/]" self.console.print( - f" {api_indicator} [#FFBF00]{model_short}[/] " - f"[dim #B8860B]·[/] [bold cyan]{tool_count} tools[/]" + f" {api_indicator} [{accent_color}]{model_short}[/] " + f"[dim {separator_color}]·[/] [bold {label_color}]{tool_count} tools[/]" f"{toolsets_info}{provider_info}" ) + + def _show_session_status(self): + """Show gateway-style status for the current CLI session.""" + session_meta = {} + if self._session_db: + try: + session_meta = self._session_db.get_session(self.session_id) or {} + except Exception: + session_meta = {} + + title = (session_meta.get("title") or "").strip() + + created_at = self.session_start + started_at = session_meta.get("started_at") + if started_at: + try: + created_at = datetime.fromtimestamp(float(started_at)) + except Exception: + created_at = self.session_start + + updated_at = created_at + for field in ("updated_at", "last_updated_at", "last_activity_at"): + value = session_meta.get(field) + if not value: + continue + try: + updated_at = datetime.fromtimestamp(float(value)) + break + except Exception: + pass + + agent = getattr(self, "agent", None) + total_tokens = getattr(agent, "session_total_tokens", 0) or 0 + provider = getattr(self, "provider", None) or "unknown" + model = getattr(self, "model", None) or "(unknown)" + is_running = bool(getattr(self, "_agent_running", False)) + + lines = [ + "Hermes CLI Status", + "", + f"Session ID: {self.session_id}", + f"Path: {display_hermes_home()}", + ] + if title: + lines.append(f"Title: {title}") + lines.extend([ + f"Model: {model} ({provider})", + f"Created: {created_at.strftime('%Y-%m-%d %H:%M')}", + f"Last Activity: {updated_at.strftime('%Y-%m-%d %H:%M')}", + f"Tokens: {total_tokens:,}", + f"Agent Running: {'Yes' if is_running else 'No'}", + ]) + self.console.print("\n".join(lines), highlight=False, markup=False) + def _fast_command_available(self) -> bool: + try: + from hermes_cli.models import model_supports_fast_mode + except Exception: + return False + agent = getattr(self, "agent", None) + model = getattr(agent, "model", None) or getattr(self, "model", None) + return model_supports_fast_mode(model) + + def _command_available(self, slash_command: str) -> bool: + if slash_command == "/fast": + return self._fast_command_available() + return True + def show_help(self): """Display help information with categorized commands.""" from hermes_cli.commands import COMMANDS_BY_CATEGORY @@ -3434,6 +3695,8 @@ class HermesCLI: for category, commands in COMMANDS_BY_CATEGORY.items(): _cprint(f"\n {_BOLD}── {category} ──{_RST}") for cmd, desc in commands.items(): + if not self._command_available(cmd): + continue ChatConsole().print(f" [bold {_accent_hex()}]{cmd:<15}[/] [dim]-[/] {_escape(desc)}") if _skill_commands: @@ -3532,7 +3795,7 @@ class HermesCLI: # TUI event loop (known pitfall). verb = "Disabling" if subcommand == "disable" else "Enabling" label = ", ".join(names) - _cprint(f"{_GOLD}{verb} {label}...{_RST}") + _cprint(f"{_ACCENT}{verb} {label}...{_RST}") tools_disable_enable_command( Namespace(tools_action=subcommand, names=names, platform="cli")) @@ -4124,6 +4387,16 @@ class HermesCLI: # Parse --provider and --global flags model_input, explicit_provider, persist_global = parse_model_flags(raw_args) + user_provs = None + custom_provs = None + try: + from hermes_cli.config import load_config + cfg = load_config() + user_provs = cfg.get("providers") + custom_provs = cfg.get("custom_providers") + except Exception: + pass + # No args at all: show available providers + models if not model_input and not explicit_provider: model_display = self.model or "unknown" @@ -4133,18 +4406,10 @@ class HermesCLI: # Show authenticated providers with top models try: - # Load user providers from config - user_provs = None - try: - from hermes_cli.config import load_config - cfg = load_config() - user_provs = cfg.get("providers") - except Exception: - pass - providers = list_authenticated_providers( current_provider=self.provider or "", user_providers=user_provs, + custom_providers=custom_provs, max_models=6, ) if providers: @@ -4185,6 +4450,8 @@ class HermesCLI: current_api_key=self.api_key or "", is_global=persist_global, explicit_provider=explicit_provider, + user_providers=user_provs, + custom_providers=custom_provs, ) if not result.success: @@ -4876,6 +5143,8 @@ class HermesCLI: self._handle_skills_command(cmd_original) elif canonical == "platforms": self._show_gateway_status() + elif canonical == "status": + self._show_session_status() elif canonical == "statusbar": self._status_bar_visible = not self._status_bar_visible state = "visible" if self._status_bar_visible else "hidden" @@ -4886,6 +5155,8 @@ class HermesCLI: self._toggle_yolo() elif canonical == "reasoning": self._handle_reasoning_command(cmd_original) + elif canonical == "fast": + self._handle_fast_command(cmd_original) elif canonical == "compress": self._manual_compress() elif canonical == "usage": @@ -5041,17 +5312,17 @@ class HermesCLI: if full_name == typed_base: # Already an exact token — no expansion possible; fall through _cprint(f"\033[1;31mUnknown command: {cmd_lower}{_RST}") - _cprint(f"{_DIM}{_GOLD}Type /help for available commands{_RST}") + _cprint(f"{_DIM}{_ACCENT}Type /help for available commands{_RST}") else: remainder = cmd_original.strip()[len(typed_base):] full_cmd = full_name + remainder return self.process_command(full_cmd) elif len(matches) > 1: - _cprint(f"{_GOLD}Ambiguous command: {cmd_lower}{_RST}") + _cprint(f"{_ACCENT}Ambiguous command: {cmd_lower}{_RST}") _cprint(f"{_DIM}Did you mean: {', '.join(sorted(matches))}?{_RST}") else: _cprint(f"\033[1;31mUnknown command: {cmd_lower}{_RST}") - _cprint(f"{_DIM}{_GOLD}Type /help for available commands{_RST}") + _cprint(f"{_DIM}{_ACCENT}Type /help for available commands{_RST}") return True @@ -5129,6 +5400,8 @@ class HermesCLI: platform="cli", session_db=self._session_db, reasoning_config=self.reasoning_config, + service_tier=self.service_tier, + request_overrides=turn_route.get("request_overrides"), providers_allowed=self._providers_only, providers_ignored=self._providers_ignore, providers_order=self._providers_order, @@ -5264,6 +5537,8 @@ class HermesCLI: session_id=task_id, platform="cli", reasoning_config=self.reasoning_config, + service_tier=self.service_tier, + request_overrides=turn_route.get("request_overrides"), providers_allowed=self._providers_only, providers_ignored=self._providers_ignore, providers_order=self._providers_order, @@ -5585,6 +5860,7 @@ class HermesCLI: return set_active_skin(new_skin) + _ACCENT.reset() # Re-resolve ANSI color for the new skin if save_config_value("display.skin", new_skin): print(f" Skin set to: {new_skin} (saved)") else: @@ -5653,8 +5929,8 @@ class HermesCLI: else: level = rc.get("effort", "medium") display_state = "on ✓" if self.show_reasoning else "off" - _cprint(f" {_GOLD}Reasoning effort: {level}{_RST}") - _cprint(f" {_GOLD}Reasoning display: {display_state}{_RST}") + _cprint(f" {_ACCENT}Reasoning effort: {level}{_RST}") + _cprint(f" {_ACCENT}Reasoning display: {display_state}{_RST}") _cprint(f" {_DIM}Usage: /reasoning {_RST}") return @@ -5666,7 +5942,7 @@ class HermesCLI: if self.agent: self.agent.reasoning_callback = self._current_reasoning_callback() save_config_value("display.show_reasoning", True) - _cprint(f" {_GOLD}✓ Reasoning display: ON (saved){_RST}") + _cprint(f" {_ACCENT}✓ Reasoning display: ON (saved){_RST}") _cprint(f" {_DIM} Model thinking will be shown during and after each response.{_RST}") return if arg in ("hide", "off"): @@ -5674,7 +5950,7 @@ class HermesCLI: if self.agent: self.agent.reasoning_callback = self._current_reasoning_callback() save_config_value("display.show_reasoning", False) - _cprint(f" {_GOLD}✓ Reasoning display: OFF (saved){_RST}") + _cprint(f" {_ACCENT}✓ Reasoning display: OFF (saved){_RST}") return # Effort level change @@ -5689,9 +5965,52 @@ class HermesCLI: self.agent = None # Force agent re-init with new reasoning config if save_config_value("agent.reasoning_effort", arg): - _cprint(f" {_GOLD}✓ Reasoning effort set to '{arg}' (saved to config){_RST}") + _cprint(f" {_ACCENT}✓ Reasoning effort set to '{arg}' (saved to config){_RST}") else: - _cprint(f" {_GOLD}✓ Reasoning effort set to '{arg}' (session only){_RST}") + _cprint(f" {_ACCENT}✓ Reasoning effort set to '{arg}' (session only){_RST}") + + def _handle_fast_command(self, cmd: str): + """Handle /fast — toggle fast mode (OpenAI Priority Processing / Anthropic Fast Mode).""" + if not self._fast_command_available(): + _cprint(" (._.) /fast is only available for models that support fast mode (OpenAI Priority Processing or Anthropic Fast Mode).") + return + + # Determine the branding for the current model + try: + from hermes_cli.models import _is_anthropic_fast_model + agent = getattr(self, "agent", None) + model = getattr(agent, "model", None) or getattr(self, "model", None) + feature_name = "Anthropic Fast Mode" if _is_anthropic_fast_model(model) else "Priority Processing" + except Exception: + feature_name = "Fast mode" + + parts = cmd.strip().split(maxsplit=1) + if len(parts) < 2 or parts[1].strip().lower() == "status": + status = "fast" if self.service_tier == "priority" else "normal" + _cprint(f" {_ACCENT}{feature_name}: {status}{_RST}") + _cprint(f" {_DIM}Usage: /fast [normal|fast|status]{_RST}") + return + + arg = parts[1].strip().lower() + + if arg in {"fast", "on"}: + self.service_tier = "priority" + saved_value = "fast" + label = "FAST" + elif arg in {"normal", "off"}: + self.service_tier = None + saved_value = "normal" + label = "NORMAL" + else: + _cprint(f" {_DIM}(._.) Unknown argument: {arg}{_RST}") + _cprint(f" {_DIM}Usage: /fast [normal|fast|status]{_RST}") + return + + self.agent = None # Force agent re-init with new service-tier config + if save_config_value("agent.service_tier", saved_value): + _cprint(f" {_ACCENT}✓ {feature_name} set to {label} (saved to config){_RST}") + else: + _cprint(f" {_ACCENT}✓ {feature_name} set to {label} (session only){_RST}") def _on_reasoning(self, reasoning_text: str): """Callback for intermediate reasoning display during tool-call loops.""" @@ -5717,21 +6036,29 @@ class HermesCLI: original_count = len(self.conversation_history) try: from agent.model_metadata import estimate_messages_tokens_rough - approx_tokens = estimate_messages_tokens_rough(self.conversation_history) + from agent.manual_compression_feedback import summarize_manual_compression + original_history = list(self.conversation_history) + approx_tokens = estimate_messages_tokens_rough(original_history) print(f"🗜️ Compressing {original_count} messages (~{approx_tokens:,} tokens)...") - compressed, new_system = self.agent._compress_context( - self.conversation_history, + compressed, _ = self.agent._compress_context( + original_history, self.agent._cached_system_prompt or "", approx_tokens=approx_tokens, ) self.conversation_history = compressed - new_count = len(self.conversation_history) new_tokens = estimate_messages_tokens_rough(self.conversation_history) - print( - f" ✅ Compressed: {original_count} → {new_count} messages " - f"(~{approx_tokens:,} → ~{new_tokens:,} tokens)" + summary = summarize_manual_compression( + original_history, + self.conversation_history, + approx_tokens, + new_tokens, ) + icon = "🗜️" if summary["noop"] else "✅" + print(f" {icon} {summary['headline']}") + print(f" {summary['token_line']}") + if summary["note"]: + print(f" {summary['note']}") except Exception as e: print(f" ❌ Compression failed: {e}") @@ -6029,11 +6356,20 @@ class HermesCLI: Updates the TUI spinner widget so the user can see what the agent is doing during tool execution (fills the gap between thinking spinner and next response). Also plays audio cue in voice mode. + + On tool.started, records a monotonic timestamp so get_spinner_text() + can show a live elapsed timer (the TUI poll loop already invalidates + every ~0.15s, so the counter updates automatically). """ - # Only act on tool.started; ignore tool.completed, reasoning.available, etc. + if event_type == "tool.completed": + import time as _time + self._tool_start_time = 0.0 + self._invalidate() + return if event_type != "tool.started": return if function_name and not function_name.startswith("_"): + import time as _time from agent.display import get_tool_emoji emoji = get_tool_emoji(function_name) label = preview or function_name @@ -6042,6 +6378,7 @@ class HermesCLI: if _pl > 0 and len(label) > _pl: label = label[:_pl - 3] + "..." self._spinner_text = f"{emoji} {label}" + self._tool_start_time = _time.monotonic() self._invalidate() if not self._voice_mode: @@ -6173,7 +6510,7 @@ class HermesCLI: _recording_hint = "Termux:API capture | Ctrl+B to stop" else: _recording_hint = "Ctrl+B to stop" - _cprint(f"\n{_GOLD}● Recording...{_RST} {_DIM}({_recording_hint}){_RST}") + _cprint(f"\n{_ACCENT}● Recording...{_RST} {_DIM}({_recording_hint}){_RST}") # Periodically refresh prompt to update audio level indicator def _refresh_level(): @@ -6236,6 +6573,9 @@ class HermesCLI: if result.get("success") and result.get("transcript", "").strip(): transcript = result["transcript"].strip() + self._attached_images.clear() + if hasattr(self, '_app') and self._app: + self._app.invalidate() self._pending_input.put(transcript) submitted = True elif result.get("success"): @@ -6370,14 +6710,14 @@ class HermesCLI: # Environment detection -- warn and block in incompatible environments env_check = detect_audio_environment() if not env_check["available"]: - _cprint(f"\n{_GOLD}Voice mode unavailable in this environment:{_RST}") + _cprint(f"\n{_ACCENT}Voice mode unavailable in this environment:{_RST}") for warning in env_check["warnings"]: _cprint(f" {_DIM}{warning}{_RST}") return reqs = check_voice_requirements() if not reqs["available"]: - _cprint(f"\n{_GOLD}Voice mode requirements not met:{_RST}") + _cprint(f"\n{_ACCENT}Voice mode requirements not met:{_RST}") for line in reqs["details"].split("\n"): _cprint(f" {_DIM}{line}{_RST}") if reqs["missing_packages"]: @@ -6415,7 +6755,7 @@ class HermesCLI: except Exception: _ptt_key = "c-b" _ptt_display = _ptt_key.replace("c-", "Ctrl+").upper() - _cprint(f"\n{_GOLD}Voice mode enabled{tts_status}{_RST}") + _cprint(f"\n{_ACCENT}Voice mode enabled{tts_status}{_RST}") _cprint(f" {_DIM}{_ptt_display} to start/stop recording{_RST}") _cprint(f" {_DIM}/voice tts to toggle speech output{_RST}") _cprint(f" {_DIM}/voice off to disable voice mode{_RST}") @@ -6467,7 +6807,7 @@ class HermesCLI: if not check_tts_requirements(): _cprint(f"{_DIM}Warning: No TTS provider available. Install edge-tts or set API keys.{_RST}") - _cprint(f"{_GOLD}Voice TTS {status}.{_RST}") + _cprint(f"{_ACCENT}Voice TTS {status}.{_RST}") def _show_voice_status(self): """Show current voice mode status.""" @@ -6851,6 +7191,7 @@ class HermesCLI: model_override=turn_route["model"], runtime_override=turn_route["runtime"], route_label=turn_route["label"], + request_overrides=turn_route.get("request_overrides"), ): return None @@ -6951,7 +7292,7 @@ class HermesCLI: w = self.console.width label = " ⚕ Hermes " fill = w - 2 - len(label) - _cprint(f"\n{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}") + _cprint(f"\n{_ACCENT}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}") _cprint(sentence.rstrip()) tts_thread = threading.Thread( @@ -7167,7 +7508,7 @@ class HermesCLI: if use_streaming_tts and _streaming_box_opened and not is_error_response: # Text was already printed sentence-by-sentence; just close the box w = shutil.get_terminal_size().columns - _cprint(f"\n{_GOLD}╰{'─' * (w - 2)}╯{_RST}") + _cprint(f"\n{_ACCENT}╰{'─' * (w - 2)}╯{_RST}") elif already_streamed: # Response was already streamed token-by-token with box framing; # _flush_stream() already closed the box. Skip Rich Panel. @@ -7879,7 +8220,7 @@ class HermesCLI: agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent") msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back." def _suspend(): - os.write(1, msg.encode()) + os.write(1, msg.encode("utf-8", errors="replace")) os.kill(0, _sig.SIGTSTP) run_in_terminal(_suspend) @@ -7959,8 +8300,9 @@ class HermesCLI: """Handle terminal paste — detect clipboard images. When the terminal supports bracketed paste, Ctrl+V / Cmd+V - triggers this with the pasted text. We also check the - clipboard for an image on every paste event. + triggers this with the pasted text. We only auto-attach a + clipboard image for image-only/empty paste gestures so text + pastes and dictation do not accidentally attach stale images. Large pastes (5+ lines) are collapsed to a file reference placeholder while preserving any existing user text in the @@ -7970,7 +8312,7 @@ class HermesCLI: # Normalise line endings — Windows \r\n and old Mac \r both become \n # so the 5-line collapse threshold and display are consistent. pasted_text = pasted_text.replace('\r\n', '\n').replace('\r', '\n') - if self._try_attach_clipboard_image(): + if _should_auto_attach_clipboard_image_on_paste(pasted_text) and self._try_attach_clipboard_image(): event.app.invalidate() if pasted_text: line_count = pasted_text.count('\n') @@ -8033,6 +8375,7 @@ class HermesCLI: _completer = SlashCommandCompleter( skill_commands_provider=lambda: _skill_commands, + command_filter=cli_ref._command_available, ) input_area = TextArea( height=Dimension(min=1, max=8, preferred=1), @@ -8237,6 +8580,17 @@ class HermesCLI: txt = cli_ref._spinner_text if not txt: return [] + # Append live elapsed timer when a tool is running + t0 = cli_ref._tool_start_time + if t0 > 0: + import time as _time + elapsed = _time.monotonic() - t0 + if elapsed >= 60: + _m, _s = int(elapsed // 60), int(elapsed % 60) + elapsed_str = f"{_m}m {_s}s" + else: + elapsed_str = f"{elapsed:.1f}s" + return [('class:hint', f' {txt} ({elapsed_str})')] return [('class:hint', f' {txt}')] def get_spinner_height(): @@ -8657,23 +9011,15 @@ class HermesCLI: # Periodic config watcher — auto-reload MCP on mcp_servers change if not self._agent_running: self._check_config_mcp_changes() - # Check for background process completion notifications - # while the agent is idle (user hasn't typed anything yet). + # Check for background process notifications (completions + # and watch pattern matches) while agent is idle. try: from tools.process_registry import process_registry if not process_registry.completion_queue.empty(): - completion = process_registry.completion_queue.get_nowait() - _exit = completion.get("exit_code", "?") - _cmd = completion.get("command", "unknown") - _sid = completion.get("session_id", "unknown") - _out = completion.get("output", "") - _synth = ( - f"[SYSTEM: Background process {_sid} completed " - f"(exit code {_exit}).\n" - f"Command: {_cmd}\n" - f"Output:\n{_out}]" - ) - self._pending_input.put(_synth) + evt = process_registry.completion_queue.get_nowait() + _synth = _format_process_notification(evt) + if _synth: + self._pending_input.put(_synth) except Exception: pass continue @@ -8771,6 +9117,7 @@ class HermesCLI: finally: self._agent_running = False self._spinner_text = "" + self._tool_start_time = 0.0 app.invalidate() # Refresh status line @@ -8790,25 +9137,15 @@ class HermesCLI: _cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}") threading.Thread(target=_restart_recording, daemon=True).start() - # Drain process completion notifications — any background - # process that finished with notify_on_complete while the - # agent was running (or before) gets auto-injected as a - # new user message so the agent can react to it. + # Drain process notifications (completions + watch matches) + # that arrived while the agent was running. try: from tools.process_registry import process_registry while not process_registry.completion_queue.empty(): - completion = process_registry.completion_queue.get_nowait() - _exit = completion.get("exit_code", "?") - _cmd = completion.get("command", "unknown") - _sid = completion.get("session_id", "unknown") - _out = completion.get("output", "") - _synth = ( - f"[SYSTEM: Background process {_sid} completed " - f"(exit code {_exit}).\n" - f"Command: {_cmd}\n" - f"Output:\n{_out}]" - ) - self._pending_input.put(_synth) + evt = process_registry.completion_queue.get_nowait() + _synth = _format_process_notification(evt) + if _synth: + self._pending_input.put(_synth) except Exception: pass # Non-fatal — don't break the main loop @@ -9111,6 +9448,7 @@ def main( model_override=turn_route["model"], runtime_override=turn_route["runtime"], route_label=turn_route["label"], + request_overrides=turn_route.get("request_overrides"), ): cli.agent.quiet_mode = True cli.agent.suppress_status_output = True diff --git a/cron/jobs.py b/cron/jobs.py index 4096d1fd81..47e0b66efa 100644 --- a/cron/jobs.py +++ b/cron/jobs.py @@ -31,7 +31,7 @@ except ImportError: # Configuration # ============================================================================= -HERMES_DIR = get_hermes_home() +HERMES_DIR = get_hermes_home().resolve() CRON_DIR = HERMES_DIR / "cron" JOBS_FILE = CRON_DIR / "jobs.json" OUTPUT_DIR = CRON_DIR / "output" @@ -338,10 +338,12 @@ def load_jobs() -> List[Dict[str, Any]]: save_jobs(jobs) logger.warning("Auto-repaired jobs.json (had invalid control characters)") return jobs - except Exception: - return [] - except IOError: - return [] + except Exception as e: + logger.error("Failed to auto-repair jobs.json: %s", e) + raise RuntimeError(f"Cron database corrupted and unrepairable: {e}") from e + except IOError as e: + logger.error("IOError reading jobs.json: %s", e) + raise RuntimeError(f"Failed to read cron database: {e}") from e def save_jobs(jobs: List[Dict[str, Any]]): @@ -452,6 +454,7 @@ def create_job( "last_run_at": None, "last_status": None, "last_error": None, + "last_delivery_error": None, # Delivery configuration "deliver": deliver, "origin": origin, # Tracks where job was created for "origin" delivery @@ -620,8 +623,8 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None, save_jobs(jobs) return - - save_jobs(jobs) + + logger.warning("mark_job_run: job_id %s not found, skipping save", job_id) def advance_next_run(job_id: str) -> bool: diff --git a/cron/scheduler.py b/cron/scheduler.py index 6a7f12acd6..0e04fb047b 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -44,7 +44,7 @@ logger = logging.getLogger(__name__) _KNOWN_DELIVERY_PLATFORMS = frozenset({ "telegram", "discord", "slack", "whatsapp", "signal", "matrix", "mattermost", "homeassistant", "dingtalk", "feishu", - "wecom", "sms", "email", "webhook", "bluebubbles", + "wecom", "weixin", "sms", "email", "webhook", "bluebubbles", }) from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run @@ -234,6 +234,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option "dingtalk": Platform.DINGTALK, "feishu": Platform.FEISHU, "wecom": Platform.WECOM, + "weixin": Platform.WEIXIN, "email": Platform.EMAIL, "sms": Platform.SMS, "bluebubbles": Platform.BLUEBUBBLES, @@ -346,7 +347,42 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option return None -_SCRIPT_TIMEOUT = 120 # seconds +_DEFAULT_SCRIPT_TIMEOUT = 120 # seconds +# Backward-compatible module override used by tests and emergency monkeypatches. +_SCRIPT_TIMEOUT = _DEFAULT_SCRIPT_TIMEOUT + + +def _get_script_timeout() -> int: + """Resolve cron pre-run script timeout from module/env/config with a safe default.""" + if _SCRIPT_TIMEOUT != _DEFAULT_SCRIPT_TIMEOUT: + try: + timeout = int(float(_SCRIPT_TIMEOUT)) + if timeout > 0: + return timeout + except Exception: + logger.warning("Invalid patched _SCRIPT_TIMEOUT=%r; using env/config/default", _SCRIPT_TIMEOUT) + + env_value = os.getenv("HERMES_CRON_SCRIPT_TIMEOUT", "").strip() + if env_value: + try: + timeout = int(float(env_value)) + if timeout > 0: + return timeout + except Exception: + logger.warning("Invalid HERMES_CRON_SCRIPT_TIMEOUT=%r; using config/default", env_value) + + try: + cfg = load_config() or {} + cron_cfg = cfg.get("cron", {}) if isinstance(cfg, dict) else {} + configured = cron_cfg.get("script_timeout_seconds") + if configured is not None: + timeout = int(float(configured)) + if timeout > 0: + return timeout + except Exception as exc: + logger.debug("Failed to load cron script timeout from config: %s", exc) + + return _DEFAULT_SCRIPT_TIMEOUT def _run_job_script(script_path: str) -> tuple[bool, str]: @@ -393,17 +429,27 @@ def _run_job_script(script_path: str) -> tuple[bool, str]: if not path.is_file(): return False, f"Script path is not a file: {path}" + script_timeout = _get_script_timeout() + try: result = subprocess.run( [sys.executable, str(path)], capture_output=True, text=True, - timeout=_SCRIPT_TIMEOUT, + timeout=script_timeout, cwd=str(path.parent), ) stdout = (result.stdout or "").strip() stderr = (result.stderr or "").strip() + # Redact secrets from both stdout and stderr before any return path. + try: + from agent.redact import redact_sensitive_text + stdout = redact_sensitive_text(stdout) + stderr = redact_sensitive_text(stderr) + except Exception: + pass + if result.returncode != 0: parts = [f"Script exited with code {result.returncode}"] if stderr: @@ -412,17 +458,10 @@ def _run_job_script(script_path: str) -> tuple[bool, str]: parts.append(f"stdout:\n{stdout}") return False, "\n".join(parts) - # Redact any secrets that may appear in script output before - # they are injected into the LLM prompt context. - try: - from agent.redact import redact_sensitive_text - stdout = redact_sensitive_text(stdout) - except Exception: - pass return True, stdout except subprocess.TimeoutExpired: - return False, f"Script timed out after {_SCRIPT_TIMEOUT}s: {path}" + return False, f"Script timed out after {script_timeout}s: {path}" except Exception as exc: return False, f"Script execution failed: {exc}" @@ -646,6 +685,24 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: }, ) + fallback_model = _cfg.get("fallback_providers") or _cfg.get("fallback_model") or None + credential_pool = None + runtime_provider = str(turn_route["runtime"].get("provider") or "").strip().lower() + if runtime_provider: + try: + from agent.credential_pool import load_pool + pool = load_pool(runtime_provider) + if pool.has_credentials(): + credential_pool = pool + logger.info( + "Job '%s': loaded credential pool for provider %s with %d entries", + job_id, + runtime_provider, + len(pool.entries()), + ) + except Exception as e: + logger.debug("Job '%s': failed to load credential pool for %s: %s", job_id, runtime_provider, e) + agent = AIAgent( model=turn_route["model"], api_key=turn_route["runtime"].get("api_key"), @@ -657,6 +714,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: max_iterations=max_iterations, reasoning_config=reasoning_config, prefill_messages=prefill_messages, + fallback_model=fallback_model, + credential_pool=credential_pool, providers_allowed=pr.get("only"), providers_ignored=pr.get("ignore"), providers_order=pr.get("order"), @@ -711,7 +770,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: _cron_pool.shutdown(wait=False, cancel_futures=True) raise finally: - _cron_pool.shutdown(wait=False) + _cron_pool.shutdown(wait=False, cancel_futures=True) if _inactivity_timeout: # Build diagnostic summary from the agent's activity tracker. diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 4c6366cbe5..68e3b79c1d 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -9,7 +9,10 @@ INSTALL_DIR="/opt/hermes" # (cache/images, cache/audio, platforms/whatsapp, etc.) are created on # demand by the application — don't pre-create them here so new installs # get the consolidated layout from get_hermes_dir(). -mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills} +# The "home/" subdirectory is a per-profile HOME for subprocesses (git, +# ssh, gh, npm …). Without it those tools write to /root which is +# ephemeral and shared across profiles. See issue #4426. +mkdir -p "$HERMES_HOME"/{cron,sessions,logs,hooks,memories,skills,skins,plans,workspace,home} # .env if [ ! -f "$HERMES_HOME/.env" ]; then diff --git a/environments/tool_call_parsers/hermes_parser.py b/environments/tool_call_parsers/hermes_parser.py index c1902fd623..c6f911db04 100644 --- a/environments/tool_call_parsers/hermes_parser.py +++ b/environments/tool_call_parsers/hermes_parser.py @@ -49,6 +49,8 @@ class HermesToolCallParser(ToolCallParser): continue tc_data = json.loads(raw_json) + if "name" not in tc_data: + continue tool_calls.append( ChatCompletionMessageToolCall( id=f"call_{uuid.uuid4().hex[:8]}", diff --git a/environments/tool_call_parsers/mistral_parser.py b/environments/tool_call_parsers/mistral_parser.py index 50e98a6f86..a23684e873 100644 --- a/environments/tool_call_parsers/mistral_parser.py +++ b/environments/tool_call_parsers/mistral_parser.py @@ -89,6 +89,8 @@ class MistralToolCallParser(ToolCallParser): parsed = [parsed] for tc in parsed: + if "name" not in tc: + continue args = tc.get("arguments", {}) if isinstance(args, dict): args = json.dumps(args, ensure_ascii=False) diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index 022ebcae4e..ae2beda9ef 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -76,10 +76,15 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: except Exception as e: logger.warning("Channel directory: failed to build %s: %s", platform.value, e) - # Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history - for plat_name in ("telegram", "whatsapp", "signal", "email", "sms", "bluebubbles"): - if plat_name not in platforms: - platforms[plat_name] = _build_from_sessions(plat_name) + # Platforms that don't support direct channel enumeration get session-based + # discovery automatically. Skip infrastructure entries that aren't messaging + # platforms — everything else falls through to _build_from_sessions(). + _SKIP_SESSION_DISCOVERY = frozenset({"local", "api_server", "webhook"}) + for plat in Platform: + plat_name = plat.value + if plat_name in _SKIP_SESSION_DISCOVERY or plat_name in platforms: + continue + platforms[plat_name] = _build_from_sessions(plat_name) directory = { "updated_at": datetime.now().isoformat(), diff --git a/gateway/config.py b/gateway/config.py index e4f04d8911..bde52eb559 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -63,6 +63,7 @@ class Platform(Enum): WEBHOOK = "webhook" FEISHU = "feishu" WECOM = "wecom" + WEIXIN = "weixin" BLUEBUBBLES = "bluebubbles" @@ -261,6 +262,11 @@ class GatewayConfig: for platform, config in self.platforms.items(): if not config.enabled: continue + # Weixin requires both a token and an account_id + if platform == Platform.WEIXIN: + if config.extra.get("account_id") and (config.token or config.extra.get("token")): + connected.append(platform) + continue # Platforms that use token/api_key auth if config.token or config.api_key: connected.append(platform) @@ -536,6 +542,8 @@ def load_gateway_config() -> GatewayConfig: bridged["free_response_channels"] = platform_cfg["free_response_channels"] if "mention_patterns" in platform_cfg: bridged["mention_patterns"] = platform_cfg["mention_patterns"] + if plat == Platform.DISCORD and "channel_skill_bindings" in platform_cfg: + bridged["channel_skill_bindings"] = platform_cfg["channel_skill_bindings"] if not bridged: continue plat_data = platforms_data.setdefault(plat.value, {}) @@ -581,6 +589,12 @@ def load_gateway_config() -> GatewayConfig: if isinstance(ic, list): ic = ",".join(str(v) for v in ic) os.environ["DISCORD_IGNORED_CHANNELS"] = str(ic) + # allowed_channels: if set, bot ONLY responds in these channels (whitelist) + ac = discord_cfg.get("allowed_channels") + if ac is not None and not os.getenv("DISCORD_ALLOWED_CHANNELS"): + if isinstance(ac, list): + ac = ",".join(str(v) for v in ac) + os.environ["DISCORD_ALLOWED_CHANNELS"] = str(ac) # no_thread_channels: channels where bot responds directly without creating thread ntc = discord_cfg.get("no_thread_channels") if ntc is not None and not os.getenv("DISCORD_NO_THREAD_CHANNELS"): @@ -628,6 +642,8 @@ def load_gateway_config() -> GatewayConfig: os.environ["MATRIX_FREE_RESPONSE_ROOMS"] = str(frc) if "auto_thread" in matrix_cfg and not os.getenv("MATRIX_AUTO_THREAD"): os.environ["MATRIX_AUTO_THREAD"] = str(matrix_cfg["auto_thread"]).lower() + if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"): + os.environ["MATRIX_DM_MENTION_THREADS"] = str(matrix_cfg["dm_mention_threads"]).lower() except Exception as e: logger.warning( @@ -666,6 +682,7 @@ def load_gateway_config() -> GatewayConfig: Platform.SLACK: "SLACK_BOT_TOKEN", Platform.MATTERMOST: "MATTERMOST_TOKEN", Platform.MATRIX: "MATRIX_ACCESS_TOKEN", + Platform.WEIXIN: "WEIXIN_TOKEN", } for platform, pconfig in config.platforms.items(): if not pconfig.enabled: @@ -970,6 +987,44 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"), ) + # Weixin (personal WeChat via iLink Bot API) + weixin_token = os.getenv("WEIXIN_TOKEN") + weixin_account_id = os.getenv("WEIXIN_ACCOUNT_ID") + if weixin_token or weixin_account_id: + if Platform.WEIXIN not in config.platforms: + config.platforms[Platform.WEIXIN] = PlatformConfig() + config.platforms[Platform.WEIXIN].enabled = True + if weixin_token: + config.platforms[Platform.WEIXIN].token = weixin_token + extra = config.platforms[Platform.WEIXIN].extra + if weixin_account_id: + extra["account_id"] = weixin_account_id + weixin_base_url = os.getenv("WEIXIN_BASE_URL", "").strip() + if weixin_base_url: + extra["base_url"] = weixin_base_url.rstrip("/") + weixin_cdn_base_url = os.getenv("WEIXIN_CDN_BASE_URL", "").strip() + if weixin_cdn_base_url: + extra["cdn_base_url"] = weixin_cdn_base_url.rstrip("/") + weixin_dm_policy = os.getenv("WEIXIN_DM_POLICY", "").strip().lower() + if weixin_dm_policy: + extra["dm_policy"] = weixin_dm_policy + weixin_group_policy = os.getenv("WEIXIN_GROUP_POLICY", "").strip().lower() + if weixin_group_policy: + extra["group_policy"] = weixin_group_policy + weixin_allowed_users = os.getenv("WEIXIN_ALLOWED_USERS", "").strip() + if weixin_allowed_users: + extra["allow_from"] = weixin_allowed_users + weixin_group_allowed_users = os.getenv("WEIXIN_GROUP_ALLOWED_USERS", "").strip() + if weixin_group_allowed_users: + extra["group_allow_from"] = weixin_group_allowed_users + weixin_home = os.getenv("WEIXIN_HOME_CHANNEL", "").strip() + if weixin_home: + config.platforms[Platform.WEIXIN].home_channel = HomeChannel( + platform=Platform.WEIXIN, + chat_id=weixin_home, + name=os.getenv("WEIXIN_HOME_CHANNEL_NAME", "Home"), + ) + # BlueBubbles (iMessage) bluebubbles_server_url = os.getenv("BLUEBUBBLES_SERVER_URL") bluebubbles_password = os.getenv("BLUEBUBBLES_PASSWORD") diff --git a/gateway/delivery.py b/gateway/delivery.py index 294c9b8142..d7fa6afdbf 100644 --- a/gateway/delivery.py +++ b/gateway/delivery.py @@ -124,53 +124,6 @@ class DeliveryRouter: self.adapters = adapters or {} self.output_dir = get_hermes_home() / "cron" / "output" - def resolve_targets( - self, - deliver: Union[str, List[str]], - origin: Optional[SessionSource] = None - ) -> List[DeliveryTarget]: - """ - Resolve delivery specification to concrete targets. - - Args: - deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc. - origin: The source where the request originated (for "origin" target) - - Returns: - List of resolved delivery targets - """ - if isinstance(deliver, str): - deliver = [deliver] - - targets = [] - seen_platforms = set() - - for target_str in deliver: - target = DeliveryTarget.parse(target_str, origin) - - # Resolve home channel if needed - if target.chat_id is None and target.platform != Platform.LOCAL: - home = self.config.get_home_channel(target.platform) - if home: - target.chat_id = home.chat_id - else: - # No home channel configured, skip this platform - continue - - # Deduplicate - key = (target.platform, target.chat_id, target.thread_id) - if key not in seen_platforms: - seen_platforms.add(key) - targets.append(target) - - # Always include local if configured - if self.config.always_log_local: - local_key = (Platform.LOCAL, None, None) - if local_key not in seen_platforms: - targets.append(DeliveryTarget(platform=Platform.LOCAL)) - - return targets - async def deliver( self, content: str, @@ -299,19 +252,5 @@ class DeliveryRouter: return await adapter.send(target.chat_id, content, metadata=send_metadata or None) -def parse_deliver_spec( - deliver: Optional[Union[str, List[str]]], - origin: Optional[SessionSource] = None, - default: str = "origin" -) -> Union[str, List[str]]: - """ - Normalize a delivery specification. - - If None or empty, returns the default. - """ - if not deliver: - return default - return deliver - diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 132790e5bd..baada7e058 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -20,10 +20,13 @@ Requires: """ import asyncio +import hashlib import hmac import json import logging import os +import socket as _socket +import re import sqlite3 import time import uuid @@ -40,6 +43,7 @@ from gateway.config import Platform, PlatformConfig from gateway.platforms.base import ( BasePlatformAdapter, SendResult, + is_network_accessible, ) logger = logging.getLogger(__name__) @@ -282,6 +286,24 @@ def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str: return sha256(repr(subset).encode("utf-8")).hexdigest() +def _derive_chat_session_id( + system_prompt: Optional[str], + first_user_message: str, +) -> str: + """Derive a stable session ID from the conversation's first user message. + + OpenAI-compatible frontends (Open WebUI, LibreChat, etc.) send the full + conversation history with every request. The system prompt and first user + message are constant across all turns of the same conversation, so hashing + them produces a deterministic session ID that lets the API server reuse + the same Hermes session (and therefore the same Docker container sandbox + directory) across turns. + """ + seed = f"{system_prompt or ''}\n{first_user_message}" + digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:16] + return f"api-{digest}" + + class APIServerAdapter(BasePlatformAdapter): """ OpenAI-compatible HTTP API server adapter. @@ -386,7 +408,8 @@ class APIServerAdapter(BasePlatformAdapter): Validate Bearer token from Authorization header. Returns None if auth is OK, or a 401 web.Response on failure. - If no API key is configured, all requests are allowed. + If no API key is configured, all requests are allowed (only when API + server is local). """ if not self._api_key: return None # No key configured — allow all (local-only use) @@ -554,8 +577,32 @@ class APIServerAdapter(BasePlatformAdapter): # Allow caller to continue an existing session by passing X-Hermes-Session-Id. # When provided, history is loaded from state.db instead of from the request body. + # + # Security: session continuation exposes conversation history, so it is + # only allowed when the API key is configured and the request is + # authenticated. Without this gate, any unauthenticated client could + # read arbitrary session history by guessing/enumerating session IDs. provided_session_id = request.headers.get("X-Hermes-Session-Id", "").strip() if provided_session_id: + if not self._api_key: + logger.warning( + "Session continuation via X-Hermes-Session-Id rejected: " + "no API key configured. Set API_SERVER_KEY to enable " + "session continuity." + ) + return web.json_response( + _openai_error( + "Session continuation requires API key authentication. " + "Configure API_SERVER_KEY to enable this feature." + ), + status=403, + ) + # Sanitize: reject control characters that could enable header injection. + if re.search(r'[\r\n\x00]', provided_session_id): + return web.json_response( + {"error": {"message": "Invalid session ID", "type": "invalid_request_error"}}, + status=400, + ) session_id = provided_session_id try: db = self._ensure_session_db() @@ -565,7 +612,16 @@ class APIServerAdapter(BasePlatformAdapter): logger.warning("Failed to load session history for %s: %s", session_id, e) history = [] else: - session_id = str(uuid.uuid4()) + # Derive a stable session ID from the conversation fingerprint so + # that consecutive messages from the same Open WebUI (or similar) + # conversation map to the same Hermes session. The first user + # message + system prompt are constant across all turns. + first_user = "" + for cm in conversation_messages: + if cm.get("role") == "user": + first_user = cm.get("content", "") + break + session_id = _derive_chat_session_id(system_prompt, first_user) # history already set from request body above completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" @@ -588,15 +644,35 @@ class APIServerAdapter(BasePlatformAdapter): _stream_q.put(delta) def _on_tool_progress(event_type, name, preview, args, **kwargs): - """Inject tool progress into the SSE stream for Open WebUI.""" + """Send tool progress as a separate SSE event. + + Previously, progress markers like ``⏰ list`` were injected + directly into ``delta.content``. OpenAI-compatible frontends + (Open WebUI, LobeChat, …) store ``delta.content`` verbatim as + the assistant message and send it back on subsequent requests. + After enough turns the model learns to *emit* the markers as + plain text instead of issuing real tool calls — silently + hallucinating tool results. See #6972. + + The fix: push a tagged tuple ``("__tool_progress__", payload)`` + onto the stream queue. The SSE writer emits it as a custom + ``event: hermes.tool.progress`` line that compliant frontends + can render for UX but will *not* persist into conversation + history. Clients that don't understand the custom event type + silently ignore it per the SSE specification. + """ if event_type != "tool.started": - return # Only show tool start events in chat stream + return if name.startswith("_"): - return # Skip internal events (_thinking) + return from agent.display import get_tool_emoji emoji = get_tool_emoji(name) label = preview or name - _stream_q.put(f"\n`{emoji} {label}`\n") + _stream_q.put(("__tool_progress__", { + "tool": name, + "emoji": emoji, + "label": label, + })) # Start agent in background. agent_ref is a mutable container # so the SSE writer can interrupt the agent on client disconnect. @@ -707,6 +783,29 @@ class APIServerAdapter(BasePlatformAdapter): } await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode()) + # Helper — route a queue item to the correct SSE event. + async def _emit(item): + """Write a single queue item to the SSE stream. + + Plain strings are sent as normal ``delta.content`` chunks. + Tagged tuples ``("__tool_progress__", payload)`` are sent + as a custom ``event: hermes.tool.progress`` SSE event so + frontends can display them without storing the markers in + conversation history. See #6972. + """ + if isinstance(item, tuple) and len(item) == 2 and item[0] == "__tool_progress__": + event_data = json.dumps(item[1]) + await response.write( + f"event: hermes.tool.progress\ndata: {event_data}\n\n".encode() + ) + else: + content_chunk = { + "id": completion_id, "object": "chat.completion.chunk", + "created": created, "model": model, + "choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}], + } + await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + # Stream content chunks as they arrive from the agent loop = asyncio.get_event_loop() while True: @@ -720,12 +819,7 @@ class APIServerAdapter(BasePlatformAdapter): delta = stream_q.get_nowait() if delta is None: break - content_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], - } - await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + await _emit(delta) except _q.Empty: break break @@ -734,12 +828,7 @@ class APIServerAdapter(BasePlatformAdapter): if delta is None: # End of stream sentinel break - content_chunk = { - "id": completion_id, "object": "chat.completion.chunk", - "created": created, "model": model, - "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}], - } - await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode()) + await _emit(delta) # Get usage from completed agent usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} @@ -1341,6 +1430,7 @@ class APIServerAdapter(BasePlatformAdapter): result = agent.run_conversation( user_message=user_message, conversation_history=conversation_history, + task_id="default", ) usage = { "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, @@ -1507,6 +1597,7 @@ class APIServerAdapter(BasePlatformAdapter): r = agent.run_conversation( user_message=user_message, conversation_history=conversation_history, + task_id="default", ) u = { "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, @@ -1658,8 +1749,16 @@ class APIServerAdapter(BasePlatformAdapter): if hasattr(sweep_task, "add_done_callback"): sweep_task.add_done_callback(self._background_tasks.discard) + # Refuse to start network-accessible without authentication + if is_network_accessible(self._host) and not self._api_key: + logger.error( + "[%s] Refusing to start: binding to %s requires API_SERVER_KEY. " + "Set API_SERVER_KEY or use the default 127.0.0.1.", + self.name, self._host, + ) + return False + # Port conflict detection — fail fast if port is already in use - import socket as _socket try: with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as _s: _s.settimeout(1) @@ -1675,6 +1774,14 @@ class APIServerAdapter(BasePlatformAdapter): await self._site.start() self._mark_connected() + if not self._api_key: + logger.warning( + "[%s] ⚠️ No API key configured (API_SERVER_KEY / platforms.api_server.key). " + "All requests will be accepted without authentication. " + "Set an API key for production deployments to prevent " + "unauthorized access to sessions, responses, and cron jobs.", + self.name, + ) logger.info( "[%s] API server listening on http://%s:%d (model: %s)", self.name, self._host, self._port, self._model_name, diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 2831eb98fa..b4c84f3119 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -6,10 +6,12 @@ and implement the required methods. """ import asyncio +import ipaddress import logging import os import random import re +import socket as _socket import subprocess import sys import uuid @@ -19,6 +21,41 @@ from urllib.parse import urlsplit logger = logging.getLogger(__name__) +def is_network_accessible(host: str) -> bool: + """Return True if *host* would expose the server beyond loopback. + + Loopback addresses (127.0.0.1, ::1, IPv4-mapped ::ffff:127.0.0.1) + are local-only. Unspecified addresses (0.0.0.0, ::) bind all + interfaces. Hostnames are resolved; DNS failure fails closed. + """ + try: + addr = ipaddress.ip_address(host) + if addr.is_loopback: + return False + # ::ffff:127.0.0.1 — Python reports is_loopback=False for mapped + # addresses, so check the underlying IPv4 explicitly. + if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback: + return False + return True + except ValueError: + # when host variable is a hostname, we should try to resolve below + pass + + try: + resolved = _socket.getaddrinfo( + host, None, _socket.AF_UNSPEC, _socket.SOCK_STREAM, + ) + # if the hostname resolves into at least one non-loopback address, + # then we consider it to be network accessible + for _family, _type, _proto, _canonname, sockaddr in resolved: + addr = ipaddress.ip_address(sockaddr[0]) + if not addr.is_loopback: + return True + return False + except (_socket.gaierror, OSError): + return True + + def _detect_macos_system_proxy() -> str | None: """Read the macOS system HTTP(S) proxy via ``scutil --proxy``. @@ -160,7 +197,7 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = ( ) -def _safe_url_for_log(url: str, max_len: int = 80) -> str: +def safe_url_for_log(url: str, max_len: int = 80) -> str: """Return a URL string safe for logs (no query/fragment/userinfo).""" if max_len <= 0: return "" @@ -197,6 +234,23 @@ def _safe_url_for_log(url: str, max_len: int = 80) -> str: return f"{safe[:max_len - 3]}..." +async def _ssrf_redirect_guard(response): + """Re-validate each redirect target to prevent redirect-based SSRF. + + Without this, an attacker can host a public URL that 302-redirects to + http://169.254.169.254/ and bypass the pre-flight is_safe_url() check. + + Must be async because httpx.AsyncClient awaits response event hooks. + """ + if response.is_redirect and response.next_request: + redirect_url = str(response.next_request.url) + from tools.url_safety import is_safe_url + if not is_safe_url(redirect_url): + raise ValueError( + f"Blocked redirect to private/internal address: {safe_url_for_log(redirect_url)}" + ) + + # --------------------------------------------------------------------------- # Image cache utilities # @@ -216,6 +270,23 @@ def get_image_cache_dir() -> Path: return IMAGE_CACHE_DIR +def _looks_like_image(data: bytes) -> bool: + """Return True if *data* starts with a known image magic-byte sequence.""" + if len(data) < 4: + return False + if data[:8] == b"\x89PNG\r\n\x1a\n": + return True + if data[:3] == b"\xff\xd8\xff": + return True + if data[:6] in (b"GIF87a", b"GIF89a"): + return True + if data[:2] == b"BM": + return True + if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": + return True + return False + + def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str: """ Save raw image bytes to the cache and return the absolute file path. @@ -226,7 +297,17 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str: Returns: Absolute path to the cached image file as a string. + + Raises: + ValueError: If *data* does not look like a valid image (e.g. an HTML + error page returned by the upstream server). """ + if not _looks_like_image(data): + snippet = data[:80].decode("utf-8", errors="replace") + raise ValueError( + f"Refusing to cache non-image data as {ext} " + f"(starts with: {snippet!r})" + ) cache_dir = get_image_cache_dir() filename = f"img_{uuid.uuid4().hex[:12]}{ext}" filepath = cache_dir / filename @@ -254,7 +335,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> """ from tools.url_safety import is_safe_url if not is_safe_url(url): - raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}") + raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}") import asyncio import httpx @@ -262,7 +343,11 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> _log = _logging.getLogger(__name__) last_exc = None - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + async with httpx.AsyncClient( + timeout=30.0, + follow_redirects=True, + event_hooks={"response": [_ssrf_redirect_guard]}, + ) as client: for attempt in range(retries + 1): try: response = await client.get( @@ -284,7 +369,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> "Media cache retry %d/%d for %s (%.1fs): %s", attempt + 1, retries, - _safe_url_for_log(url), + safe_url_for_log(url), wait, exc, ) @@ -369,7 +454,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> """ from tools.url_safety import is_safe_url if not is_safe_url(url): - raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}") + raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}") import asyncio import httpx @@ -377,7 +462,11 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> _log = _logging.getLogger(__name__) last_exc = None - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + async with httpx.AsyncClient( + timeout=30.0, + follow_redirects=True, + event_hooks={"response": [_ssrf_redirect_guard]}, + ) as client: for attempt in range(retries + 1): try: response = await client.get( @@ -399,7 +488,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> "Audio cache retry %d/%d for %s (%.1fs): %s", attempt + 1, retries, - _safe_url_for_log(url), + safe_url_for_log(url), wait, exc, ) @@ -502,6 +591,14 @@ class MessageType(Enum): COMMAND = "command" # /command style +class ProcessingOutcome(Enum): + """Result classification for message-processing lifecycle hooks.""" + + SUCCESS = "success" + FAILURE = "failure" + CANCELLED = "cancelled" + + @dataclass class MessageEvent: """ @@ -529,8 +626,9 @@ class MessageEvent: reply_to_message_id: Optional[str] = None reply_to_text: Optional[str] = None # Text of the replied-to message (for context injection) - # Auto-loaded skill for topic/channel bindings (e.g., Telegram DM Topics) - auto_skill: Optional[str] = None + # Auto-loaded skill(s) for topic/channel bindings (e.g., Telegram DM Topics, + # Discord channel_skill_bindings). A single name or ordered list. + auto_skill: Optional[str | list[str]] = None # Internal flag — set for synthetic events (e.g. background process # completion notifications) that must bypass user authorization checks. @@ -552,6 +650,9 @@ class MessageEvent: raw = parts[0][1:].lower() if parts else None if raw and "@" in raw: raw = raw.split("@", 1)[0] + # Reject file paths: valid command names never contain / + if raw and "/" in raw: + return None return raw def get_command_args(self) -> str: @@ -572,6 +673,32 @@ class SendResult: retryable: bool = False # True for transient connection errors — base will retry automatically +def merge_pending_message_event( + pending_messages: Dict[str, MessageEvent], + session_key: str, + event: MessageEvent, +) -> None: + """Store or merge a pending event for a session. + + Photo bursts/albums often arrive as multiple near-simultaneous PHOTO + events. Merge those into the existing queued event so the next turn sees + the whole burst, while non-photo follow-ups still replace the pending + event normally. + """ + existing = pending_messages.get(session_key) + if ( + existing + and getattr(existing, "message_type", None) == MessageType.PHOTO + and event.message_type == MessageType.PHOTO + ): + existing.media_urls.extend(event.media_urls) + existing.media_types.extend(event.media_types) + if event.text: + existing.text = BasePlatformAdapter._merge_caption(existing.text, event.text) + return + pending_messages[session_key] = event + + # Error substrings that indicate a transient *connection* failure worth retrying. # "timeout" / "timed out" / "readtimeout" / "writetimeout" are intentionally # excluded: a read/write timeout on a non-idempotent call (e.g. send_message) @@ -625,6 +752,8 @@ class BasePlatformAdapter(ABC): # Gateway shutdown cancels these so an old gateway instance doesn't keep # working on a task after --replace or manual restarts. self._background_tasks: set[asyncio.Task] = set() + self._expected_cancelled_tasks: set[asyncio.Task] = set() + self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None # Chats where auto-TTS on voice input is disabled (set by /voice off) self._auto_tts_disabled_chats: set = set() # Chats where typing indicator is paused (e.g. during approval waits). @@ -713,6 +842,10 @@ class BasePlatformAdapter(ABC): an optional response string. """ self._message_handler = handler + + def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None: + """Set an optional handler for messages arriving during active sessions.""" + self._busy_session_handler = handler def set_session_store(self, session_store: Any) -> None: """ @@ -1133,7 +1266,7 @@ class BasePlatformAdapter(ABC): async def on_processing_start(self, event: MessageEvent) -> None: """Hook called when background processing begins.""" - async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None: """Hook called when background processing completes.""" async def _run_processing_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None: @@ -1294,7 +1427,18 @@ class BasePlatformAdapter(ABC): # session lifecycle and its cleanup races with the running task # (see PR #4926). cmd = event.get_command() - if cmd in ("approve", "deny", "status", "agents", "tasks", "stop", "new", "reset"): + if cmd in ( + "approve", + "deny", + "status", + "agents", + "tasks", + "stop", + "new", + "reset", + "background", + "restart", + ): logger.debug( "[%s] Command '/%s' bypassing active-session guard for %s", self.name, cmd, session_key, @@ -1313,19 +1457,19 @@ class BasePlatformAdapter(ABC): logger.error("[%s] Command '/%s' dispatch failed: %s", self.name, cmd, e, exc_info=True) return + if self._busy_session_handler is not None: + try: + if await self._busy_session_handler(event, session_key): + return + except Exception as e: + logger.error("[%s] Busy-session handler failed: %s", self.name, e, exc_info=True) + # Special case: photo bursts/albums frequently arrive as multiple near- # simultaneous messages. Queue them without interrupting the active run, # then process them immediately after the current task finishes. if event.message_type == MessageType.PHOTO: logger.debug("[%s] Queuing photo follow-up for session %s without interrupt", self.name, session_key) - existing = self._pending_messages.get(session_key) - if existing and existing.message_type == MessageType.PHOTO: - existing.media_urls.extend(event.media_urls) - existing.media_types.extend(event.media_types) - if event.text: - existing.text = self._merge_caption(existing.text, event.text) - else: - self._pending_messages[session_key] = event + merge_pending_message_event(self._pending_messages, session_key, event) return # Don't interrupt now - will run after current task completes # Default behavior for non-photo follow-ups: interrupt the running agent @@ -1352,6 +1496,7 @@ class BasePlatformAdapter(ABC): return if hasattr(task, "add_done_callback"): task.add_done_callback(self._background_tasks.discard) + task.add_done_callback(self._expected_cancelled_tasks.discard) @staticmethod def _get_human_delay() -> float: @@ -1488,7 +1633,7 @@ class BasePlatformAdapter(ABC): logger.info( "[%s] Sending image: %s (alt=%s)", self.name, - _safe_url_for_log(image_url), + safe_url_for_log(image_url), alt_text[:30] if alt_text else "", ) # Route animated GIFs through send_animation for proper playback @@ -1580,7 +1725,11 @@ class BasePlatformAdapter(ABC): # Determine overall success for the processing hook processing_ok = delivery_succeeded if delivery_attempted else not bool(response) - await self._run_processing_hook("on_processing_complete", event, processing_ok) + await self._run_processing_hook( + "on_processing_complete", + event, + ProcessingOutcome.SUCCESS if processing_ok else ProcessingOutcome.FAILURE, + ) # Check if there's a pending message that was queued during our processing if session_key in self._pending_messages: @@ -1599,10 +1748,14 @@ class BasePlatformAdapter(ABC): return # Already cleaned up except asyncio.CancelledError: - await self._run_processing_hook("on_processing_complete", event, False) + current_task = asyncio.current_task() + outcome = ProcessingOutcome.CANCELLED + if current_task is None or current_task not in self._expected_cancelled_tasks: + outcome = ProcessingOutcome.FAILURE + await self._run_processing_hook("on_processing_complete", event, outcome) raise except Exception as e: - await self._run_processing_hook("on_processing_complete", event, False) + await self._run_processing_hook("on_processing_complete", event, ProcessingOutcome.FAILURE) logger.error("[%s] Error handling message: %s", self.name, e, exc_info=True) # Send the error to the user so they aren't left with radio silence try: @@ -1646,10 +1799,12 @@ class BasePlatformAdapter(ABC): """ tasks = [task for task in self._background_tasks if not task.done()] for task in tasks: + self._expected_cancelled_tasks.add(task) task.cancel() if tasks: await asyncio.gather(*tasks, return_exceptions=True) self._background_tasks.clear() + self._expected_cancelled_tasks.clear() self._pending_messages.clear() self._active_sessions.clear() diff --git a/gateway/platforms/bluebubbles.py b/gateway/platforms/bluebubbles.py index 83f94d3bf8..f50cd9503c 100644 --- a/gateway/platforms/bluebubbles.py +++ b/gateway/platforms/bluebubbles.py @@ -207,9 +207,17 @@ class BlueBubblesAdapter(BasePlatformAdapter): self.webhook_port, self.webhook_path, ) + + # Register webhook with BlueBubbles server + # This is required for the server to know where to send events + await self._register_webhook() + return True async def disconnect(self) -> None: + # Unregister webhook before cleaning up + await self._unregister_webhook() + if self.client: await self.client.aclose() self.client = None @@ -218,6 +226,105 @@ class BlueBubblesAdapter(BasePlatformAdapter): self._runner = None self._mark_disconnected() + @property + def _webhook_url(self) -> str: + """Compute the external webhook URL for BlueBubbles registration.""" + host = self.webhook_host + if host in ("0.0.0.0", "127.0.0.1", "localhost", "::"): + host = "localhost" + return f"http://{host}:{self.webhook_port}{self.webhook_path}" + + async def _find_registered_webhooks(self, url: str) -> list: + """Return list of BB webhook entries matching *url*.""" + try: + res = await self._api_get("/api/v1/webhook") + data = res.get("data") + if isinstance(data, list): + return [wh for wh in data if wh.get("url") == url] + except Exception: + pass + return [] + + async def _register_webhook(self) -> bool: + """Register this webhook URL with the BlueBubbles server. + + BlueBubbles requires webhooks to be registered via API before + it will send events. Checks for an existing registration first + to avoid duplicates (e.g. after a crash without clean shutdown). + """ + if not self.client: + return False + + webhook_url = self._webhook_url + + # Crash resilience — reuse an existing registration if present + existing = await self._find_registered_webhooks(webhook_url) + if existing: + logger.info( + "[bluebubbles] webhook already registered: %s", webhook_url + ) + return True + + payload = { + "url": webhook_url, + "events": ["new-message", "updated-message", "message"], + } + + try: + res = await self._api_post("/api/v1/webhook", payload) + status = res.get("status", 0) + if 200 <= status < 300: + logger.info( + "[bluebubbles] webhook registered with server: %s", + webhook_url, + ) + return True + else: + logger.warning( + "[bluebubbles] webhook registration returned status %s: %s", + status, + res.get("message"), + ) + return False + except Exception as exc: + logger.warning( + "[bluebubbles] failed to register webhook with server: %s", + exc, + ) + return False + + async def _unregister_webhook(self) -> bool: + """Unregister this webhook URL from the BlueBubbles server. + + Removes *all* matching registrations to clean up any duplicates + left by prior crashes. + """ + if not self.client: + return False + + webhook_url = self._webhook_url + removed = False + + try: + for wh in await self._find_registered_webhooks(webhook_url): + wh_id = wh.get("id") + if wh_id: + res = await self.client.delete( + self._api_url(f"/api/v1/webhook/{wh_id}") + ) + res.raise_for_status() + removed = True + if removed: + logger.info( + "[bluebubbles] webhook unregistered: %s", webhook_url + ) + except Exception as exc: + logger.debug( + "[bluebubbles] failed to unregister webhook (non-critical): %s", + exc, + ) + return removed + # ------------------------------------------------------------------ # Chat GUID resolution # ------------------------------------------------------------------ @@ -826,3 +933,4 @@ class BlueBubblesAdapter(BasePlatformAdapter): asyncio.create_task(self.mark_read(session_chat_id)) return web.Response(text="ok") + diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py index 8ed3769624..e83b902dfb 100644 --- a/gateway/platforms/dingtalk.py +++ b/gateway/platforms/dingtalk.py @@ -20,6 +20,7 @@ Configuration in config.yaml: import asyncio import logging import os +import re import time import uuid from datetime import datetime, timezone @@ -54,6 +55,8 @@ MAX_MESSAGE_LENGTH = 20000 DEDUP_WINDOW_SECONDS = 300 DEDUP_MAX_SIZE = 1000 RECONNECT_BACKOFF = [2, 5, 10, 30, 60] +_SESSION_WEBHOOKS_MAX = 500 +_DINGTALK_WEBHOOK_RE = re.compile(r'^https://api\.dingtalk\.com/') def check_dingtalk_requirements() -> bool: @@ -195,9 +198,15 @@ class DingTalkAdapter(BasePlatformAdapter): chat_id = conversation_id or sender_id chat_type = "group" if is_group else "dm" - # Store session webhook for reply routing + # Store session webhook for reply routing (validate origin to prevent SSRF) session_webhook = getattr(message, "session_webhook", None) or "" - if session_webhook and chat_id: + if session_webhook and chat_id and _DINGTALK_WEBHOOK_RE.match(session_webhook): + if len(self._session_webhooks) >= _SESSION_WEBHOOKS_MAX: + # Evict oldest entry to cap memory growth + try: + self._session_webhooks.pop(next(iter(self._session_webhooks))) + except StopIteration: + pass self._session_webhooks[chat_id] = session_webhook source = self.build_source( diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index a19b6d6663..dcf05a1625 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -49,6 +49,7 @@ from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + ProcessingOutcome, SendResult, cache_image_from_url, cache_audio_from_url, @@ -422,6 +423,7 @@ class DiscordAdapter(BasePlatformAdapter): # Discord message limits MAX_MESSAGE_LENGTH = 2000 + _SPLIT_THRESHOLD = 1900 # near the 2000-char split point # Auto-disconnect from voice channel after this many seconds of inactivity VOICE_TIMEOUT = 300 @@ -433,6 +435,11 @@ class DiscordAdapter(BasePlatformAdapter): self._allowed_user_ids: set = set() # For button approval authorization # Voice channel state (per-guild) self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient + # Text batching: merge rapid successive messages (Telegram-style) + self._text_batch_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_DELAY_SECONDS", "0.6")) + self._text_batch_split_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) + self._pending_text_batches: Dict[str, MessageEvent] = {} + self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task # Phase 2: voice listening @@ -599,22 +606,35 @@ class DiscordAdapter(BasePlatformAdapter): if not self._client.user or self._client.user not in message.mentions: return # "all" falls through to handle_message - - # If the message @mentions other users but NOT the bot, the - # sender is talking to someone else — stay silent. Only - # applies in server channels; in DMs the user is always - # talking to the bot (mentions are just references). - # Controlled by DISCORD_IGNORE_NO_MENTION (default: true). - _ignore_no_mention = os.getenv( - "DISCORD_IGNORE_NO_MENTION", "true" - ).lower() in ("true", "1", "yes") - if _ignore_no_mention and message.mentions and not isinstance(message.channel, discord.DMChannel): - _bot_mentioned = ( + + # Multi-agent filtering: if the message mentions specific bots + # but NOT this bot, the sender is talking to another agent — + # stay silent. Messages with no bot mentions (general chat) + # still fall through to _handle_message for the existing + # DISCORD_REQUIRE_MENTION check. + # + # This replaces the older DISCORD_IGNORE_NO_MENTION logic + # with bot-aware filtering that works correctly when multiple + # agents share a channel. + if not isinstance(message.channel, discord.DMChannel) and message.mentions: + _self_mentioned = ( self._client.user is not None and self._client.user in message.mentions ) - if not _bot_mentioned: - return # Talking to someone else, don't interrupt + _other_bots_mentioned = any( + m.bot and m != self._client.user + for m in message.mentions + ) + # If other bots are mentioned but we're not → not for us + if _other_bots_mentioned and not _self_mentioned: + return + # If humans are mentioned but we're not → not for us + # (preserves old DISCORD_IGNORE_NO_MENTION=true behavior) + _ignore_no_mention = os.getenv( + "DISCORD_IGNORE_NO_MENTION", "true" + ).lower() in ("true", "1", "yes") + if _ignore_no_mention and not _self_mentioned and not _other_bots_mentioned: + return await self._handle_message(message) @@ -748,14 +768,17 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(message, "add_reaction"): await self._add_reaction(message, "👀") - async def on_processing_complete(self, event: MessageEvent, success: bool) -> None: + async def on_processing_complete(self, event: MessageEvent, outcome: ProcessingOutcome) -> None: """Swap the in-progress reaction for a final success/failure reaction.""" if not self._reactions_enabled(): return message = event.raw_message if hasattr(message, "add_reaction"): await self._remove_reaction(message, "👀") - await self._add_reaction(message, "✅" if success else "❌") + if outcome == ProcessingOutcome.SUCCESS: + await self._add_reaction(message, "✅") + elif outcome == ProcessingOutcome.FAILURE: + await self._add_reaction(message, "❌") async def send( self, @@ -764,18 +787,34 @@ class DiscordAdapter(BasePlatformAdapter): reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None ) -> SendResult: - """Send a message to a Discord channel.""" + """Send a message to a Discord channel or thread. + + When metadata contains a thread_id, the message is sent to that + thread instead of the parent channel identified by chat_id. + """ if not self._client: return SendResult(success=False, error="Not connected") try: - # Get the channel - channel = self._client.get_channel(int(chat_id)) - if not channel: - channel = await self._client.fetch_channel(int(chat_id)) + # Determine target channel: thread_id in metadata takes precedence. + thread_id = None + if metadata and metadata.get("thread_id"): + thread_id = metadata["thread_id"] - if not channel: - return SendResult(success=False, error=f"Channel {chat_id} not found") + if thread_id: + # Fetch the thread directly — threads are addressed by their own ID. + channel = self._client.get_channel(int(thread_id)) + if not channel: + channel = await self._client.fetch_channel(int(thread_id)) + if not channel: + return SendResult(success=False, error=f"Thread {thread_id} not found") + else: + # Get the parent channel + channel = self._client.get_channel(int(chat_id)) + if not channel: + channel = await self._client.fetch_channel(int(chat_id)) + if not channel: + return SendResult(success=False, error=f"Channel {chat_id} not found") # Format and split message if needed formatted = self.format_message(content) @@ -1238,9 +1277,8 @@ class DiscordAdapter(BasePlatformAdapter): try: await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) - from tools.transcription_tools import transcribe_audio, get_stt_model_from_config - stt_model = get_stt_model_from_config() - result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model) + from tools.transcription_tools import transcribe_audio + result = await asyncio.to_thread(transcribe_audio, wav_path) if not result.get("success"): return @@ -1867,14 +1905,42 @@ class DiscordAdapter(BasePlatformAdapter): chat_topic=chat_topic, ) + _parent_id = str(getattr(getattr(interaction, "channel", None), "parent_id", "") or "") + _skills = self._resolve_channel_skills(thread_id, _parent_id or None) event = MessageEvent( text=text, message_type=MessageType.TEXT, source=source, raw_message=interaction, + auto_skill=_skills, ) await self.handle_message(event) + def _resolve_channel_skills(self, channel_id: str, parent_id: str | None = None) -> list[str] | None: + """Look up auto-skill bindings for a Discord channel/forum thread. + + Config format (in platform extra): + channel_skill_bindings: + - id: "123456" + skills: ["skill-a", "skill-b"] + Also checks parent_id so forum threads inherit the forum's bindings. + """ + bindings = self.config.extra.get("channel_skill_bindings", []) + if not bindings: + return None + ids_to_check = {channel_id} + if parent_id: + ids_to_check.add(parent_id) + for entry in bindings: + entry_id = str(entry.get("id", "")) + if entry_id in ids_to_check: + skills = entry.get("skills") or entry.get("skill") + if isinstance(skills, str): + return [skills] + if isinstance(skills, list) and skills: + return list(dict.fromkeys(skills)) # dedup, preserve order + return None + def _thread_parent_channel(self, channel: Any) -> Any: """Return the parent text channel when invoked from a thread.""" return getattr(channel, "parent", None) or channel @@ -2228,6 +2294,7 @@ class DiscordAdapter(BasePlatformAdapter): # discord.require_mention: Require @mention in server channels (default: true) # discord.free_response_channels: Channel IDs where bot responds without mention # discord.ignored_channels: Channel IDs where bot NEVER responds (even when mentioned) + # discord.allowed_channels: If set, bot ONLY responds in these channels (whitelist) # discord.no_thread_channels: Channel IDs where bot responds directly without creating thread # discord.auto_thread: Auto-create thread on @mention in channels (default: true) @@ -2239,12 +2306,21 @@ class DiscordAdapter(BasePlatformAdapter): parent_channel_id = self._get_parent_channel_id(message.channel) if not isinstance(message.channel, discord.DMChannel): - # Check ignored channels first - never respond even when mentioned - ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "") - ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()} channel_ids = {str(message.channel.id)} if parent_channel_id: channel_ids.add(parent_channel_id) + + # Check allowed channels - if set, only respond in these channels + allowed_channels_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "") + if allowed_channels_raw: + allowed_channels = {ch.strip() for ch in allowed_channels_raw.split(",") if ch.strip()} + if not (channel_ids & allowed_channels): + logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_ids) + return + + # Check ignored channels - never respond even when mentioned + ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "") + ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()} if channel_ids & ignored_channels: logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids) return @@ -2449,6 +2525,10 @@ class DiscordAdapter(BasePlatformAdapter): if not event_text or not event_text.strip(): event_text = "(The user sent a message with no text content)" + _chan = message.channel + _parent_id = str(getattr(_chan, "parent_id", "") or "") + _chan_id = str(getattr(_chan, "id", "")) + _skills = self._resolve_channel_skills(_chan_id, _parent_id or None) event = MessageEvent( text=event_text, message_type=msg_type, @@ -2459,6 +2539,7 @@ class DiscordAdapter(BasePlatformAdapter): media_types=media_types, reply_to_message_id=str(message.reference.message_id) if message.reference else None, timestamp=message.created_at, + auto_skill=_skills, ) # Track thread participation so the bot won't require @mention for @@ -2466,7 +2547,80 @@ class DiscordAdapter(BasePlatformAdapter): if thread_id: self._track_thread(thread_id) - await self.handle_message(event) + # Only batch plain text messages — commands, media, etc. dispatch + # immediately since they won't be split by the Discord client. + if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0: + self._enqueue_text_event(event) + else: + await self.handle_message(event) + + # ------------------------------------------------------------------ + # Text message aggregation (handles Discord client-side splits) + # ------------------------------------------------------------------ + + def _text_batch_key(self, event: MessageEvent) -> str: + """Session-scoped key for text message batching.""" + from gateway.session import build_session_key + return build_session_key( + event.source, + group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), + thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False), + ) + + def _enqueue_text_event(self, event: MessageEvent) -> None: + """Buffer a text event and reset the flush timer. + + When Discord splits a long user message at 2000 chars, the chunks + arrive within a few hundred milliseconds. This merges them into + a single event before dispatching. + """ + key = self._text_batch_key(event) + existing = self._pending_text_batches.get(key) + chunk_len = len(event.text or "") + if existing is None: + event._last_chunk_len = chunk_len # type: ignore[attr-defined] + self._pending_text_batches[key] = event + else: + if event.text: + existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text + existing._last_chunk_len = chunk_len # type: ignore[attr-defined] + if event.media_urls: + existing.media_urls.extend(event.media_urls) + existing.media_types.extend(event.media_types) + + prior_task = self._pending_text_batch_tasks.get(key) + if prior_task and not prior_task.done(): + prior_task.cancel() + self._pending_text_batch_tasks[key] = asyncio.create_task( + self._flush_text_batch(key) + ) + + async def _flush_text_batch(self, key: str) -> None: + """Wait for the quiet period then dispatch the aggregated text. + + Uses a longer delay when the latest chunk is near Discord's 2000-char + split point, since a continuation chunk is almost certain. + """ + current_task = asyncio.current_task() + try: + pending = self._pending_text_batches.get(key) + last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 + if last_len >= self._SPLIT_THRESHOLD: + delay = self._text_batch_split_delay_seconds + else: + delay = self._text_batch_delay_seconds + await asyncio.sleep(delay) + event = self._pending_text_batches.pop(key, None) + if not event: + return + logger.info( + "[Discord] Flushing text batch %s (%d chars)", + key, len(event.text or ""), + ) + await self.handle_message(event) + finally: + if self._pending_text_batch_tasks.get(key) is current_task: + self._pending_text_batch_tasks.pop(key, None) # --------------------------------------------------------------------------- diff --git a/gateway/platforms/email.py b/gateway/platforms/email.py index a54bd94bb2..d4261ccfb8 100644 --- a/gateway/platforms/email.py +++ b/gateway/platforms/email.py @@ -195,7 +195,11 @@ def _extract_attachments( ext = Path(filename).suffix.lower() if ext in _IMAGE_EXTS: - cached_path = cache_image_from_bytes(payload, ext) + try: + cached_path = cache_image_from_bytes(payload, ext) + except ValueError: + logger.debug("Skipping non-image attachment %s (invalid magic bytes)", filename) + continue attachments.append({ "path": cached_path, "filename": filename, diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py index 6012a0f1c0..a88c7e52b9 100644 --- a/gateway/platforms/feishu.py +++ b/gateway/platforms/feishu.py @@ -264,6 +264,7 @@ class FeishuAdapterSettings: bot_name: str dedup_cache_size: int text_batch_delay_seconds: float + text_batch_split_delay_seconds: float text_batch_max_messages: int text_batch_max_chars: int media_batch_delay_seconds: float @@ -972,7 +973,8 @@ def _run_official_feishu_ws_client(ws_client: Any, adapter: Any) -> None: return await original_connect(*args, **kwargs) def _configure_with_overrides(conf: Any) -> Any: - assert original_configure is not None + if original_configure is None: + raise RuntimeError("Feishu _configure_with_overrides called but original_configure is None") result = original_configure(conf) _apply_runtime_ws_overrides() return result @@ -1014,6 +1016,10 @@ class FeishuAdapter(BasePlatformAdapter): """Feishu/Lark bot adapter.""" MAX_MESSAGE_LENGTH = 8000 + # Threshold for detecting Feishu client-side message splits. + # When a chunk is near the ~4096-char practical limit, a continuation + # is almost certain. + _SPLIT_THRESHOLD = 4000 # ========================================================================= # Lifecycle — init / settings / connect / disconnect @@ -1105,6 +1111,9 @@ class FeishuAdapter(BasePlatformAdapter): text_batch_delay_seconds=float( os.getenv("HERMES_FEISHU_TEXT_BATCH_DELAY_SECONDS", str(_DEFAULT_TEXT_BATCH_DELAY_SECONDS)) ), + text_batch_split_delay_seconds=float( + os.getenv("HERMES_FEISHU_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0") + ), text_batch_max_messages=max( 1, int(os.getenv("HERMES_FEISHU_TEXT_BATCH_MAX_MESSAGES", str(_DEFAULT_TEXT_BATCH_MAX_MESSAGES))), @@ -1152,6 +1161,7 @@ class FeishuAdapter(BasePlatformAdapter): self._bot_name = settings.bot_name self._dedup_cache_size = settings.dedup_cache_size self._text_batch_delay_seconds = settings.text_batch_delay_seconds + self._text_batch_split_delay_seconds = settings.text_batch_split_delay_seconds self._text_batch_max_messages = settings.text_batch_max_messages self._text_batch_max_chars = settings.text_batch_max_chars self._media_batch_delay_seconds = settings.media_batch_delay_seconds @@ -1180,6 +1190,8 @@ class FeishuAdapter(BasePlatformAdapter): lambda data: self._on_reaction_event("im.message.reaction.deleted_v1", data) ) .register_p2_card_action_trigger(self._on_card_action_trigger) + .register_p2_im_chat_member_bot_added_v1(self._on_bot_added_to_chat) + .register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat) .build() ) @@ -1570,13 +1582,18 @@ class FeishuAdapter(BasePlatformAdapter): return SendResult(success=False, error=f"Image file not found: {image_path}") try: - with open(image_path, "rb") as image_file: - body = self._build_image_upload_body( - image_type=_FEISHU_IMAGE_UPLOAD_TYPE, - image=image_file, - ) - request = self._build_image_upload_request(body) - upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request) + import io as _io + with open(image_path, "rb") as f: + image_bytes = f.read() + # Wrap in BytesIO so lark SDK's MultipartEncoder can read .name and .tell() + image_file = _io.BytesIO(image_bytes) + image_file.name = os.path.basename(image_path) + body = self._build_image_upload_body( + image_type=_FEISHU_IMAGE_UPLOAD_TYPE, + image=image_file, + ) + request = self._build_image_upload_request(body) + upload_response = await asyncio.to_thread(self._client.im.v1.image.create, request) image_key = self._extract_response_field(upload_response, "image_key") if not image_key: return self._response_error_result( @@ -2478,8 +2495,10 @@ class FeishuAdapter(BasePlatformAdapter): async def _enqueue_text_event(self, event: MessageEvent) -> None: """Debounce rapid Feishu text bursts into a single MessageEvent.""" key = self._text_batch_key(event) + chunk_len = len(event.text or "") existing = self._pending_text_batches.get(key) if existing is None: + event._last_chunk_len = chunk_len # type: ignore[attr-defined] self._pending_text_batches[key] = event self._pending_text_batch_counts[key] = 1 self._schedule_text_batch_flush(key) @@ -2504,6 +2523,7 @@ class FeishuAdapter(BasePlatformAdapter): return existing.text = next_text + existing._last_chunk_len = chunk_len # type: ignore[attr-defined] existing.timestamp = event.timestamp if event.message_id: existing.message_id = event.message_id @@ -2530,10 +2550,22 @@ class FeishuAdapter(BasePlatformAdapter): task_map[key] = asyncio.create_task(flush_fn(key)) async def _flush_text_batch(self, key: str) -> None: - """Flush a pending text batch after the quiet period.""" + """Flush a pending text batch after the quiet period. + + Uses a longer delay when the latest chunk is near Feishu's ~4096-char + split point, since a continuation chunk is almost certain. + """ current_task = asyncio.current_task() try: - await asyncio.sleep(self._text_batch_delay_seconds) + # Adaptive delay: if the latest chunk is near the split threshold, + # a continuation is almost certain — wait longer. + pending = self._pending_text_batches.get(key) + last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 + if last_len >= self._SPLIT_THRESHOLD: + delay = self._text_batch_split_delay_seconds + else: + delay = self._text_batch_delay_seconds + await asyncio.sleep(delay) await self._flush_text_batch_now(key) finally: if self._pending_text_batch_tasks.get(key) is current_task: diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index e29ae379b3..409d2d6e4a 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -1,8 +1,8 @@ """Matrix gateway adapter. Connects to any Matrix homeserver (self-hosted or matrix.org) via the -matrix-nio Python SDK. Supports optional end-to-end encryption (E2EE) -when installed with ``pip install "matrix-nio[e2e]"``. +mautrix Python SDK. Supports optional end-to-end encryption (E2EE) +when installed with ``pip install "mautrix[encryption]"``. Environment variables: MATRIX_HOMESERVER Homeserver URL (e.g. https://matrix.example.org) @@ -18,12 +18,12 @@ Environment variables: MATRIX_REQUIRE_MENTION Require @mention in rooms (default: true) MATRIX_FREE_RESPONSE_ROOMS Comma-separated room IDs exempt from mention requirement MATRIX_AUTO_THREAD Auto-create threads for room messages (default: true) + MATRIX_DM_MENTION_THREADS Create a thread when bot is @mentioned in a DM (default: false) """ from __future__ import annotations import asyncio -import io import json import logging import mimetypes @@ -35,11 +35,61 @@ from typing import Any, Dict, Optional, Set from html import escape as _html_escape +try: + from mautrix.types import ( + ContentURI, + EventID, + EventType, + PaginationDirection, + PresenceState, + RoomCreatePreset, + RoomID, + SyncToken, + TrustState, + UserID, + ) +except ImportError: + # Stubs so the module is importable without mautrix installed. + # check_matrix_requirements() will return False and the adapter + # won't be instantiated in production, but tests may exercise + # adapter methods so stubs must have the right attributes. + ContentURI = EventID = RoomID = SyncToken = UserID = str # type: ignore[misc,assignment] + + class _EventTypeStub: # type: ignore[no-redef] + ROOM_MESSAGE = "m.room.message" + REACTION = "m.reaction" + ROOM_ENCRYPTED = "m.room.encrypted" + ROOM_NAME = "m.room.name" + EventType = _EventTypeStub # type: ignore[misc,assignment] + + class _PaginationDirectionStub: # type: ignore[no-redef] + BACKWARD = "b" + FORWARD = "f" + PaginationDirection = _PaginationDirectionStub # type: ignore[misc,assignment] + + class _PresenceStateStub: # type: ignore[no-redef] + ONLINE = "online" + OFFLINE = "offline" + UNAVAILABLE = "unavailable" + PresenceState = _PresenceStateStub # type: ignore[misc,assignment] + + class _RoomCreatePresetStub: # type: ignore[no-redef] + PRIVATE = "private_chat" + PUBLIC = "public_chat" + TRUSTED_PRIVATE = "trusted_private_chat" + RoomCreatePreset = _RoomCreatePresetStub # type: ignore[misc,assignment] + + class _TrustStateStub: # type: ignore[no-redef] + UNVERIFIED = 0 + VERIFIED = 1 + TrustState = _TrustStateStub # type: ignore[misc,assignment] + from gateway.config import Platform, PlatformConfig from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, MessageType, + ProcessingOutcome, SendResult, ) @@ -53,30 +103,27 @@ MAX_MESSAGE_LENGTH = 4000 # Uses get_hermes_home() so each profile gets its own Matrix store. from hermes_constants import get_hermes_dir as _get_hermes_dir _STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store") +_CRYPTO_PICKLE_PATH = _STORE_DIR / "crypto_store.pickle" # Grace period: ignore messages older than this many seconds before startup. _STARTUP_GRACE_SECONDS = 5 -# E2EE key export file for persistence across restarts. -_KEY_EXPORT_FILE = _STORE_DIR / "exported_keys.txt" -_KEY_EXPORT_PASSPHRASE = "hermes-matrix-e2ee-keys" - # Pending undecrypted events: cap and TTL for retry buffer. _MAX_PENDING_EVENTS = 100 _PENDING_EVENT_TTL = 300 # seconds — stop retrying after 5 min _E2EE_INSTALL_HINT = ( - "Install with: pip install 'matrix-nio[e2e]' " + "Install with: pip install 'mautrix[encryption]' " "(requires libolm C library)" ) def _check_e2ee_deps() -> bool: - """Return True if matrix-nio E2EE dependencies (python-olm) are available.""" + """Return True if mautrix E2EE dependencies (python-olm) are available.""" try: - from nio.crypto import ENCRYPTION_ENABLED - return bool(ENCRYPTION_ENABLED) + from mautrix.crypto import OlmMachine # noqa: F401 + return True except (ImportError, AttributeError): return False @@ -94,11 +141,11 @@ def check_matrix_requirements() -> bool: logger.warning("Matrix: MATRIX_HOMESERVER not set") return False try: - import nio # noqa: F401 + import mautrix # noqa: F401 except ImportError: logger.warning( - "Matrix: matrix-nio not installed. " - "Run: pip install 'matrix-nio[e2e]'" + "Matrix: mautrix not installed. " + "Run: pip install 'mautrix[encryption]'" ) return False @@ -120,6 +167,11 @@ def check_matrix_requirements() -> bool: class MatrixAdapter(BasePlatformAdapter): """Gateway adapter for Matrix (any homeserver).""" + # Threshold for detecting Matrix client-side message splits. + # When a chunk is near the ~4000-char practical limit, a continuation + # is almost certain. + _SPLIT_THRESHOLD = 3900 + def __init__(self, config: PlatformConfig): super().__init__(config, Platform.MATRIX) @@ -145,7 +197,7 @@ class MatrixAdapter(BasePlatformAdapter): or os.getenv("MATRIX_DEVICE_ID", "") ) - self._client: Any = None # nio.AsyncClient + self._client: Any = None # mautrix.client.Client self._sync_task: Optional[asyncio.Task] = None self._closing = False self._startup_ts: float = 0.0 @@ -160,17 +212,32 @@ class MatrixAdapter(BasePlatformAdapter): self._processed_events_set: set = set() # Buffer for undecrypted events pending key receipt. - # Each entry: (room, event, timestamp) + # Each entry: (room_id, event, timestamp) self._pending_megolm: list = [] # Thread participation tracking (for require_mention bypass) self._bot_participated_threads: set = self._load_participated_threads() self._MAX_TRACKED_THREADS = 500 + # Mention/thread gating — parsed once from env vars. + self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") + free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") + self._free_rooms: Set[str] = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} + self._auto_thread: bool = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") + self._dm_mention_threads: bool = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes") + # Reactions: configurable via MATRIX_REACTIONS (default: true). self._reactions_enabled: bool = os.getenv( "MATRIX_REACTIONS", "true" ).lower() not in ("false", "0", "no") + self._pending_reactions: dict[tuple[str, str], str] = {} + + # Text batching: merge rapid successive messages (Telegram-style). + # Matrix clients split long messages around 4000 chars. + self._text_batch_delay_seconds = float(os.getenv("HERMES_MATRIX_TEXT_BATCH_DELAY_SECONDS", "0.6")) + self._text_batch_split_delay_seconds = float(os.getenv("HERMES_MATRIX_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) + self._pending_text_batches: Dict[str, MessageEvent] = {} + self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} def _is_duplicate_event(self, event_id) -> bool: """Return True if this event was already processed. Tracks the ID otherwise.""" @@ -191,21 +258,87 @@ class MatrixAdapter(BasePlatformAdapter): async def connect(self) -> bool: """Connect to the Matrix homeserver and start syncing.""" - import nio + from mautrix.api import HTTPAPI + from mautrix.client import Client + from mautrix.client.state_store import MemoryStateStore, MemorySyncStore if not self._homeserver: logger.error("Matrix: homeserver URL not configured") return False - # Determine store path and ensure it exists. - store_path = str(_STORE_DIR) + # Ensure store dir exists for E2EE key persistence. _STORE_DIR.mkdir(parents=True, exist_ok=True) + # Create the HTTP API layer. + api = HTTPAPI( + base_url=self._homeserver, + token=self._access_token or "", + ) + # Create the client. - # When a stable device_id is configured, pass it to the constructor - # so matrix-nio binds to it from the start (important for E2EE - # crypto-store persistence across restarts). - ctor_device_id = self._device_id or None + state_store = MemoryStateStore() + sync_store = MemorySyncStore() + client = Client( + mxid=UserID(self._user_id) if self._user_id else UserID(""), + device_id=self._device_id or None, + api=api, + state_store=state_store, + sync_store=sync_store, + ) + + self._client = client + + # Authenticate. + if self._access_token: + api.token = self._access_token + + # Validate the token and learn user_id / device_id. + try: + resp = await client.whoami() + resolved_user_id = getattr(resp, "user_id", "") or self._user_id + resolved_device_id = getattr(resp, "device_id", "") + if resolved_user_id: + self._user_id = str(resolved_user_id) + client.mxid = UserID(self._user_id) + + # Prefer user-configured device_id for stable E2EE identity. + effective_device_id = self._device_id or resolved_device_id + if effective_device_id: + client.device_id = effective_device_id + + logger.info( + "Matrix: using access token for %s%s", + self._user_id or "(unknown user)", + f" (device {effective_device_id})" if effective_device_id else "", + ) + except Exception as exc: + logger.error( + "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER: %s", + exc, + ) + await api.session.close() + return False + elif self._password and self._user_id: + try: + resp = await client.login( + identifier=self._user_id, + password=self._password, + device_name="Hermes Agent", + device_id=self._device_id or None, + ) + if resp and hasattr(resp, "device_id"): + client.device_id = resp.device_id + logger.info("Matrix: logged in as %s", self._user_id) + except Exception as exc: + logger.error("Matrix: login failed — %s", exc) + await api.session.close() + return False + else: + logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD") + await api.session.close() + return False + + # Set up E2EE if requested. if self._encryption: if not _check_e2ee_deps(): logger.error( @@ -213,177 +346,95 @@ class MatrixAdapter(BasePlatformAdapter): "Refusing to connect — encrypted rooms would silently fail.", _E2EE_INSTALL_HINT, ) + await api.session.close() return False try: - client = nio.AsyncClient( - self._homeserver, - self._user_id or "", - device_id=ctor_device_id, - store_path=store_path, - ) + from mautrix.crypto import OlmMachine + from mautrix.crypto.store import MemoryCryptoStore + + crypto_store = MemoryCryptoStore() + + # Restore persisted crypto state from a previous run. + # Uses HMAC to verify integrity before unpickling. + pickle_path = _CRYPTO_PICKLE_PATH + if pickle_path.exists(): + try: + import hashlib, hmac, pickle + raw = pickle_path.read_bytes() + # Format: 32-byte HMAC-SHA256 signature + pickle data. + if len(raw) > 32: + sig, payload = raw[:32], raw[32:] + # Key is derived from the device_id + user_id (stable per install). + hmac_key = f"{self._user_id}:{self._device_id}".encode() + expected = hmac.new(hmac_key, payload, hashlib.sha256).digest() + if hmac.compare_digest(sig, expected): + saved = pickle.loads(payload) # noqa: S301 + if isinstance(saved, MemoryCryptoStore): + crypto_store = saved + logger.info("Matrix: restored E2EE crypto store from %s", pickle_path) + else: + logger.warning("Matrix: crypto store HMAC mismatch — ignoring stale/tampered file") + except Exception as exc: + logger.warning("Matrix: could not restore crypto store: %s", exc) + + olm = OlmMachine(client, crypto_store, state_store) + + # Set trust policy: accept unverified devices so senders + # share Megolm session keys with us automatically. + olm.share_keys_min_trust = TrustState.UNVERIFIED + olm.send_keys_min_trust = TrustState.UNVERIFIED + + await olm.load() + client.crypto = olm logger.info( "Matrix: E2EE enabled (store: %s%s)", - store_path, - f", device_id={self._device_id}" if self._device_id else "", + str(_STORE_DIR), + f", device_id={client.device_id}" if client.device_id else "", ) except Exception as exc: logger.error( "Matrix: failed to create E2EE client: %s. %s", exc, _E2EE_INSTALL_HINT, ) + await api.session.close() return False - else: - client = nio.AsyncClient( - self._homeserver, - self._user_id or "", - device_id=ctor_device_id, - ) - self._client = client + # Register event handlers. + from mautrix.client import InternalEventType as IntEvt - # Authenticate. - if self._access_token: - client.access_token = self._access_token + client.add_event_handler(EventType.ROOM_MESSAGE, self._on_room_message) + client.add_event_handler(EventType.REACTION, self._on_reaction) + client.add_event_handler(IntEvt.INVITE, self._on_invite) - # With access-token auth, always resolve whoami so we validate the - # token and learn the device_id. The device_id matters for E2EE: - # without it, matrix-nio can send plain messages but may fail to - # decrypt inbound encrypted events or encrypt outbound room sends. - resp = await client.whoami() - if isinstance(resp, nio.WhoamiResponse): - resolved_user_id = getattr(resp, "user_id", "") or self._user_id - resolved_device_id = getattr(resp, "device_id", "") - if resolved_user_id: - self._user_id = resolved_user_id - - # Prefer the user-configured device_id (MATRIX_DEVICE_ID) so - # the bot reuses a stable identity across restarts. Fall back - # to whatever whoami returned. - effective_device_id = self._device_id or resolved_device_id - - # restore_login() is the matrix-nio path that binds the access - # token to a specific device and loads the crypto store. - if effective_device_id and hasattr(client, "restore_login"): - client.restore_login( - self._user_id or resolved_user_id, - effective_device_id, - self._access_token, - ) - else: - if self._user_id: - client.user_id = self._user_id - if effective_device_id: - client.device_id = effective_device_id - client.access_token = self._access_token - if self._encryption: - logger.warning( - "Matrix: access-token login did not restore E2EE state; " - "encrypted rooms may fail until a device_id is available. " - "Set MATRIX_DEVICE_ID to a stable value." - ) - - logger.info( - "Matrix: using access token for %s%s", - self._user_id or "(unknown user)", - f" (device {effective_device_id})" if effective_device_id else "", - ) - else: - logger.error( - "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER" - ) - await client.close() - return False - elif self._password and self._user_id: - resp = await client.login( - self._password, - device_name="Hermes Agent", - ) - if isinstance(resp, nio.LoginResponse): - logger.info("Matrix: logged in as %s", self._user_id) - else: - logger.error("Matrix: login failed — %s", getattr(resp, "message", resp)) - await client.close() - return False - else: - logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD") - await client.close() - return False - - # If E2EE is enabled, load the crypto store. - if self._encryption and getattr(client, "olm", None): - try: - if client.should_upload_keys: - await client.keys_upload() - logger.info("Matrix: E2EE crypto initialized") - except Exception as exc: - logger.warning("Matrix: crypto init issue: %s", exc) - - # Import previously exported Megolm keys (survives restarts). - if _KEY_EXPORT_FILE.exists(): - try: - await client.import_keys( - str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE, - ) - logger.info("Matrix: imported Megolm keys from backup") - except Exception as exc: - logger.debug("Matrix: could not import keys: %s", exc) - elif self._encryption: - # E2EE was requested but the crypto store failed to load — - # this means encrypted rooms will silently not work. Hard-fail. - logger.error( - "Matrix: E2EE requested but crypto store is not loaded — " - "cannot decrypt or encrypt messages. %s", - _E2EE_INSTALL_HINT, - ) - await client.close() - return False - - # Register event callbacks. - client.add_event_callback(self._on_room_message, nio.RoomMessageText) - client.add_event_callback(self._on_room_message_media, nio.RoomMessageImage) - client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio) - client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo) - client.add_event_callback(self._on_room_message_media, nio.RoomMessageFile) - for encrypted_media_cls in ( - getattr(nio, "RoomEncryptedImage", None), - getattr(nio, "RoomEncryptedAudio", None), - getattr(nio, "RoomEncryptedVideo", None), - getattr(nio, "RoomEncryptedFile", None), - ): - if encrypted_media_cls is not None: - client.add_event_callback(self._on_room_message_media, encrypted_media_cls) - client.add_event_callback(self._on_invite, nio.InviteMemberEvent) - - # Reaction events (m.reaction). - if hasattr(nio, "ReactionEvent"): - client.add_event_callback(self._on_reaction, nio.ReactionEvent) - else: - # Older matrix-nio versions: use UnknownEvent fallback. - client.add_event_callback(self._on_unknown_event, nio.UnknownEvent) - - # If E2EE: handle encrypted events. - if self._encryption and hasattr(client, "olm"): - client.add_event_callback( - self._on_room_message, nio.MegolmEvent - ) + if self._encryption and getattr(client, "crypto", None): + client.add_event_handler(EventType.ROOM_ENCRYPTED, self._on_encrypted_event) # Initial sync to catch up, then start background sync. self._startup_ts = time.time() self._closing = False - # Do an initial sync to populate room state. - resp = await client.sync(timeout=10000, full_state=True) - if isinstance(resp, nio.SyncResponse): - self._joined_rooms = set(resp.rooms.join.keys()) - logger.info( - "Matrix: initial sync complete, joined %d rooms", - len(self._joined_rooms), - ) - # Build DM room cache from m.direct account data. - await self._refresh_dm_cache() - await self._run_e2ee_maintenance() - else: - logger.warning("Matrix: initial sync returned %s", type(resp).__name__) + try: + sync_data = await client.sync(timeout=10000, full_state=True) + if isinstance(sync_data, dict): + rooms_join = sync_data.get("rooms", {}).get("join", {}) + self._joined_rooms = set(rooms_join.keys()) + logger.info( + "Matrix: initial sync complete, joined %d rooms", + len(self._joined_rooms), + ) + # Build DM room cache from m.direct account data. + await self._refresh_dm_cache() + else: + logger.warning("Matrix: initial sync returned unexpected type %s", type(sync_data).__name__) + except Exception as exc: + logger.warning("Matrix: initial sync error: %s", exc) + + # Share keys after initial sync if E2EE is enabled. + if self._encryption and getattr(client, "crypto", None): + try: + await client.crypto.share_keys() + except Exception as exc: + logger.warning("Matrix: initial key share failed: %s", exc) # Start the sync loop. self._sync_task = asyncio.create_task(self._sync_loop()) @@ -401,20 +452,27 @@ class MatrixAdapter(BasePlatformAdapter): except (asyncio.CancelledError, Exception): pass - # Export Megolm keys before closing so the next restart can decrypt - # events that used sessions from this run. - if self._client and self._encryption and getattr(self._client, "olm", None): + # Persist E2EE crypto store before closing so the next restart + # can decrypt events using sessions from this run. + if self._client and self._encryption and getattr(self._client, "crypto", None): try: + import hashlib, hmac, pickle + crypto_store = self._client.crypto.crypto_store _STORE_DIR.mkdir(parents=True, exist_ok=True) - await self._client.export_keys( - str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE, - ) - logger.info("Matrix: exported Megolm keys for next restart") + pickle_path = _CRYPTO_PICKLE_PATH + payload = pickle.dumps(crypto_store) + hmac_key = f"{self._user_id}:{self._device_id}".encode() + sig = hmac.new(hmac_key, payload, hashlib.sha256).digest() + pickle_path.write_bytes(sig + payload) + logger.info("Matrix: persisted E2EE crypto store to %s", pickle_path) except Exception as exc: - logger.debug("Matrix: could not export keys on disconnect: %s", exc) + logger.debug("Matrix: could not persist crypto store on disconnect: %s", exc) if self._client: - await self._client.close() + try: + await self._client.api.session.close() + except Exception: + pass self._client = None logger.info("Matrix: disconnected") @@ -427,7 +485,6 @@ class MatrixAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send a message to a Matrix room.""" - import nio if not content: return SendResult(success=True) @@ -465,69 +522,55 @@ class MatrixAdapter(BasePlatformAdapter): relates_to["m.in_reply_to"] = {"event_id": reply_to} msg_content["m.relates_to"] = relates_to - async def _room_send_once(*, ignore_unverified_devices: bool = False): - return await asyncio.wait_for( - self._client.room_send( - chat_id, - "m.room.message", + try: + event_id = await asyncio.wait_for( + self._client.send_message_event( + RoomID(chat_id), + EventType.ROOM_MESSAGE, msg_content, - ignore_unverified_devices=ignore_unverified_devices, ), timeout=45, ) - - try: - resp = await _room_send_once(ignore_unverified_devices=False) - except Exception as exc: - retryable = isinstance(exc, asyncio.TimeoutError) - olm_unverified = getattr(nio, "OlmUnverifiedDeviceError", None) - send_retry = getattr(nio, "SendRetryError", None) - if isinstance(olm_unverified, type) and isinstance(exc, olm_unverified): - retryable = True - if isinstance(send_retry, type) and isinstance(exc, send_retry): - retryable = True - - if not retryable: - logger.error("Matrix: failed to send to %s: %s", chat_id, exc) - return SendResult(success=False, error=str(exc)) - - logger.warning( - "Matrix: initial encrypted send to %s failed (%s); " - "retrying after E2EE maintenance with ignored unverified devices", - chat_id, - exc, - ) - await self._run_e2ee_maintenance() - try: - resp = await _room_send_once(ignore_unverified_devices=True) - except Exception as retry_exc: - logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc) - return SendResult(success=False, error=str(retry_exc)) - - if isinstance(resp, nio.RoomSendResponse): - last_event_id = resp.event_id + last_event_id = str(event_id) logger.info("Matrix: sent event %s to %s", last_event_id, chat_id) - else: - err = getattr(resp, "message", str(resp)) - logger.error("Matrix: failed to send to %s: %s", chat_id, err) - return SendResult(success=False, error=err) + except Exception as exc: + # On E2EE errors, retry after sharing keys. + if self._encryption and getattr(self._client, "crypto", None): + try: + await self._client.crypto.share_keys() + event_id = await asyncio.wait_for( + self._client.send_message_event( + RoomID(chat_id), + EventType.ROOM_MESSAGE, + msg_content, + ), + timeout=45, + ) + last_event_id = str(event_id) + logger.info("Matrix: sent event %s to %s (after key share)", last_event_id, chat_id) + continue + except Exception as retry_exc: + logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc) + return SendResult(success=False, error=str(retry_exc)) + logger.error("Matrix: failed to send to %s: %s", chat_id, exc) + return SendResult(success=False, error=str(exc)) return SendResult(success=True, message_id=last_event_id) async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: """Return room name and type (dm/group).""" name = chat_id - chat_type = "group" + chat_type = "dm" if await self._is_dm_room(chat_id) else "group" if self._client: - room = self._client.rooms.get(chat_id) - if room: - name = room.display_name or room.canonical_alias or chat_id - # Use DM cache. - if self._dm_rooms.get(chat_id, False): - chat_type = "dm" - elif room.member_count == 2: - chat_type = "dm" + try: + name_evt = await self._client.get_state_event( + RoomID(chat_id), EventType.ROOM_NAME, + ) + if name_evt and hasattr(name_evt, "name") and name_evt.name: + name = name_evt.name + except Exception: + pass return {"name": name, "type": chat_type} @@ -541,7 +584,7 @@ class MatrixAdapter(BasePlatformAdapter): """Send a typing indicator.""" if self._client: try: - await self._client.room_typing(chat_id, typing_state=True, timeout=30000) + await self._client.set_typing(RoomID(chat_id), timeout=30000) except Exception: pass @@ -549,7 +592,6 @@ class MatrixAdapter(BasePlatformAdapter): self, chat_id: str, message_id: str, content: str ) -> SendResult: """Edit an existing message (via m.replace).""" - import nio formatted = self.format_message(content) msg_content: Dict[str, Any] = { @@ -572,10 +614,13 @@ class MatrixAdapter(BasePlatformAdapter): msg_content["format"] = "org.matrix.custom.html" msg_content["formatted_body"] = f"* {html}" - resp = await self._client.room_send(chat_id, "m.room.message", msg_content) - if isinstance(resp, nio.RoomSendResponse): - return SendResult(success=True, message_id=resp.event_id) - return SendResult(success=False, error=getattr(resp, "message", str(resp))) + try: + event_id = await self._client.send_message_event( + RoomID(chat_id), EventType.ROOM_MESSAGE, msg_content, + ) + return SendResult(success=True, message_id=str(event_id)) + except Exception as exc: + return SendResult(success=False, error=str(exc)) async def send_image( self, @@ -648,7 +693,7 @@ class MatrixAdapter(BasePlatformAdapter): ) -> SendResult: """Upload an audio file as a voice message (MSC3245 native voice).""" return await self._send_local_file( - chat_id, audio_path, "m.audio", caption, reply_to, + chat_id, audio_path, "m.audio", caption, reply_to, metadata=metadata, is_voice=True ) @@ -686,29 +731,23 @@ class MatrixAdapter(BasePlatformAdapter): is_voice: bool = False, ) -> SendResult: """Upload bytes to Matrix and send as a media message.""" - import nio # Upload to homeserver. - # nio expects a DataProvider (callable) or file-like object, not raw bytes. - # nio.upload() returns a tuple (UploadResponse|UploadError, Optional[Dict]) - resp, maybe_encryption_info = await self._client.upload( - io.BytesIO(data), - content_type=content_type, - filename=filename, - filesize=len(data), - ) - if not isinstance(resp, nio.UploadResponse): - err = getattr(resp, "message", str(resp)) - logger.error("Matrix: upload failed: %s", err) - return SendResult(success=False, error=err) - - mxc_url = resp.content_uri + try: + mxc_url = await self._client.upload_media( + data, + mime_type=content_type, + filename=filename, + ) + except Exception as exc: + logger.error("Matrix: upload failed: %s", exc) + return SendResult(success=False, error=str(exc)) # Build media message content. msg_content: Dict[str, Any] = { "msgtype": msgtype, "body": caption or filename, - "url": mxc_url, + "url": str(mxc_url), "info": { "mimetype": content_type, "size": len(data), @@ -732,10 +771,13 @@ class MatrixAdapter(BasePlatformAdapter): relates_to["is_falling_back"] = True msg_content["m.relates_to"] = relates_to - resp2 = await self._client.room_send(room_id, "m.room.message", msg_content) - if isinstance(resp2, nio.RoomSendResponse): - return SendResult(success=True, message_id=resp2.event_id) - return SendResult(success=False, error=getattr(resp2, "message", str(resp2))) + try: + event_id = await self._client.send_message_event( + RoomID(room_id), EventType.ROOM_MESSAGE, msg_content, + ) + return SendResult(success=True, message_id=str(event_id)) + except Exception as exc: + return SendResult(success=False, error=str(exc)) async def _send_local_file( self, @@ -767,37 +809,32 @@ class MatrixAdapter(BasePlatformAdapter): async def _sync_loop(self) -> None: """Continuously sync with the homeserver.""" - import nio - while not self._closing: try: - resp = await self._client.sync(timeout=30000) - if isinstance(resp, nio.SyncError): - if self._closing: - return - err_msg = str(getattr(resp, "message", resp)).lower() - if "m_unknown_token" in err_msg or "m_forbidden" in err_msg or "401" in err_msg: - logger.error( - "Matrix: permanent auth error from sync: %s — stopping sync", - getattr(resp, "message", resp), - ) - return - logger.warning( - "Matrix: sync returned %s: %s — retrying in 5s", - type(resp).__name__, - getattr(resp, "message", resp), - ) - await asyncio.sleep(5) - continue + sync_data = await self._client.sync(timeout=30000) + if isinstance(sync_data, dict): + # Update joined rooms from sync response. + rooms_join = sync_data.get("rooms", {}).get("join", {}) + if rooms_join: + self._joined_rooms.update(rooms_join.keys()) + + # Share keys periodically if E2EE is enabled. + if self._encryption and getattr(self._client, "crypto", None): + try: + await self._client.crypto.share_keys() + except Exception as exc: + logger.warning("Matrix: E2EE key share failed: %s", exc) + + # Retry any buffered undecrypted events. + if self._pending_megolm: + await self._retry_pending_decryptions() - await self._run_e2ee_maintenance() except asyncio.CancelledError: return except Exception as exc: if self._closing: return - # Detect permanent auth/permission failures that will never - # succeed on retry — stop syncing instead of looping forever. + # Detect permanent auth/permission failures. err_str = str(exc).lower() if "401" in err_str or "403" in err_str or "unauthorized" in err_str or "forbidden" in err_str: logger.error("Matrix: permanent auth error: %s — stopping sync", exc) @@ -805,98 +842,19 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: sync error: %s — retrying in 5s", exc) await asyncio.sleep(5) - async def _run_e2ee_maintenance(self) -> None: - """Run matrix-nio E2EE housekeeping between syncs. - - Hermes uses a custom sync loop instead of matrix-nio's sync_forever(), - so we need to explicitly drive the key management work that sync_forever() - normally handles for encrypted rooms. - - Also auto-trusts all devices (so senders share session keys with us) - and retries decryption for any buffered MegolmEvents. - """ - client = self._client - if not client or not self._encryption or not getattr(client, "olm", None): - return - - did_query_keys = client.should_query_keys - - tasks = [asyncio.create_task(client.send_to_device_messages())] - - if client.should_upload_keys: - tasks.append(asyncio.create_task(client.keys_upload())) - - if did_query_keys: - tasks.append(asyncio.create_task(client.keys_query())) - - if client.should_claim_keys: - users = client.get_users_for_key_claiming() - if users: - tasks.append(asyncio.create_task(client.keys_claim(users))) - - for task in asyncio.as_completed(tasks): - try: - await task - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning("Matrix: E2EE maintenance task failed: %s", exc) - - # After key queries, auto-trust all devices so senders share keys with - # us. For a bot this is the right default — we want to decrypt - # everything, not enforce manual verification. - if did_query_keys: - self._auto_trust_devices() - - # Retry any buffered undecrypted events now that new keys may have - # arrived (from key requests, key queries, or to-device forwarding). - if self._pending_megolm: - await self._retry_pending_decryptions() - - def _auto_trust_devices(self) -> None: - """Trust/verify all unverified devices we know about. - - When other clients see our device as verified, they proactively share - Megolm session keys with us. Without this, many clients will refuse - to include an unverified device in key distributions. - """ - client = self._client - if not client: - return - - device_store = getattr(client, "device_store", None) - if not device_store: - return - - own_device = getattr(client, "device_id", None) - trusted_count = 0 - - try: - # DeviceStore.__iter__ yields OlmDevice objects directly. - for device in device_store: - if getattr(device, "device_id", None) == own_device: - continue - if not getattr(device, "verified", False): - client.verify_device(device) - trusted_count += 1 - except Exception as exc: - logger.debug("Matrix: auto-trust error: %s", exc) - - if trusted_count: - logger.info("Matrix: auto-trusted %d new device(s)", trusted_count) - async def _retry_pending_decryptions(self) -> None: - """Retry decrypting buffered MegolmEvents after new keys arrive.""" - import nio - + """Retry decrypting buffered encrypted events after new keys arrive.""" client = self._client if not client or not self._pending_megolm: return + crypto = getattr(client, "crypto", None) + if not crypto: + return now = time.time() still_pending: list = [] - for room, event, ts in self._pending_megolm: + for room_id, event, ts in self._pending_megolm: # Drop events that have aged past the TTL. if now - ts > _PENDING_EVENT_TTL: logger.debug( @@ -906,39 +864,28 @@ class MatrixAdapter(BasePlatformAdapter): continue try: - decrypted = client.decrypt_event(event) + decrypted = await crypto.decrypt_megolm_event(event) except Exception: - # Still missing the key — keep in buffer. - still_pending.append((room, event, ts)) + still_pending.append((room_id, event, ts)) continue - if isinstance(decrypted, nio.MegolmEvent): - # decrypt_event returned the same undecryptable event. - still_pending.append((room, event, ts)) + if decrypted is None or decrypted is event: + still_pending.append((room_id, event, ts)) continue logger.info( - "Matrix: decrypted buffered event %s (%s)", + "Matrix: decrypted buffered event %s", getattr(event, "event_id", "?"), - type(decrypted).__name__, ) - # Route to the appropriate handler based on decrypted type. + # Route to the appropriate handler. + # Remove from dedup set so _on_room_message doesn't drop it + # (the encrypted event ID was already registered by _on_encrypted_event). + decrypted_id = str(getattr(decrypted, "event_id", getattr(event, "event_id", ""))) + if decrypted_id: + self._processed_events_set.discard(decrypted_id) try: - if isinstance(decrypted, nio.RoomMessageText): - await self._on_room_message(room, decrypted) - elif isinstance( - decrypted, - (nio.RoomMessageImage, nio.RoomMessageAudio, - nio.RoomMessageVideo, nio.RoomMessageFile), - ): - await self._on_room_message_media(room, decrypted) - else: - logger.debug( - "Matrix: decrypted event %s has unhandled type %s", - getattr(event, "event_id", "?"), - type(decrypted).__name__, - ) + await self._on_room_message(decrypted) except Exception as exc: logger.warning( "Matrix: error processing decrypted event %s: %s", @@ -951,92 +898,147 @@ class MatrixAdapter(BasePlatformAdapter): # Event callbacks # ------------------------------------------------------------------ - async def _on_room_message(self, room: Any, event: Any) -> None: - """Handle incoming text messages (and decrypted megolm events).""" - import nio + async def _on_room_message(self, event: Any) -> None: + """Handle incoming room message events (text, media).""" + room_id = str(getattr(event, "room_id", "")) + sender = str(getattr(event, "sender", "")) # Ignore own messages. - if event.sender == self._user_id: + if sender == self._user_id: return - # Deduplicate by event ID (nio can fire the same event more than once). - if self._is_duplicate_event(getattr(event, "event_id", None)): + # Deduplicate by event ID. + event_id = str(getattr(event, "event_id", "")) + if self._is_duplicate_event(event_id): return # Startup grace: ignore old messages from initial sync. - event_ts = getattr(event, "server_timestamp", 0) / 1000.0 + raw_ts = getattr(event, "timestamp", None) or getattr(event, "server_timestamp", None) or 0 + event_ts = raw_ts / 1000.0 if raw_ts else 0.0 if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS: return - # Handle undecryptable MegolmEvents: request the missing session key - # and buffer the event for retry once the key arrives. - if isinstance(event, nio.MegolmEvent): - logger.warning( - "Matrix: could not decrypt event %s in %s — requesting key", - event.event_id, room.room_id, - ) - - # Ask other devices in the room to forward the session key. - try: - resp = await self._client.request_room_key(event) - if hasattr(resp, "event_id") or not isinstance(resp, Exception): - logger.debug( - "Matrix: room key request sent for session %s", - getattr(event, "session_id", "?"), - ) - except Exception as exc: - logger.debug("Matrix: room key request failed: %s", exc) - - # Buffer for retry on next maintenance cycle. - self._pending_megolm.append((room, event, time.time())) - if len(self._pending_megolm) > _MAX_PENDING_EVENTS: - self._pending_megolm = self._pending_megolm[-_MAX_PENDING_EVENTS:] + # Extract content from the event. + content = getattr(event, "content", None) + if content is None: return - # Skip edits (m.replace relation). - source_content = getattr(event, "source", {}).get("content", {}) + # Get msgtype — either from content object or raw dict. + if hasattr(content, "msgtype"): + msgtype = str(content.msgtype) + elif isinstance(content, dict): + msgtype = content.get("msgtype", "") + else: + msgtype = "" + + # Determine source content dict for relation/thread extraction. + if isinstance(content, dict): + source_content = content + elif hasattr(content, "serialize"): + source_content = content.serialize() + else: + source_content = {} + relates_to = source_content.get("m.relates_to", {}) + + # Skip edits (m.replace relation). if relates_to.get("rel_type") == "m.replace": return - body = getattr(event, "body", "") or "" - if not body: + # Ignore m.notice to prevent bot-to-bot loops (m.notice is the + # conventional msgtype for bot responses in the Matrix ecosystem). + if msgtype == "m.notice": return - # Determine chat type. - is_dm = self._dm_rooms.get(room.room_id, False) - if not is_dm and room.member_count == 2: - is_dm = True + # Dispatch by msgtype. + media_msgtypes = ("m.image", "m.audio", "m.video", "m.file") + if msgtype in media_msgtypes: + await self._handle_media_message(room_id, sender, event_id, event_ts, source_content, relates_to, msgtype) + elif msgtype == "m.text": + await self._handle_text_message(room_id, sender, event_id, event_ts, source_content, relates_to) + + async def _resolve_message_context( + self, + room_id: str, + sender: str, + event_id: str, + body: str, + source_content: dict, + relates_to: dict, + ) -> Optional[tuple]: + """Shared mention/thread/DM gating for text and media handlers. + + Returns (body, is_dm, chat_type, thread_id, display_name, source) + or None if the message should be dropped (mention gating). + """ + is_dm = await self._is_dm_room(room_id) chat_type = "dm" if is_dm else "group" - # Thread support. thread_id = None if relates_to.get("rel_type") == "m.thread": thread_id = relates_to.get("event_id") + formatted_body = source_content.get("formatted_body") + is_mentioned = self._is_bot_mentioned(body, formatted_body) + # Require-mention gating. if not is_dm: - free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") - free_rooms = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} - require_mention = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - is_free_room = room.room_id in free_rooms + is_free_room = room_id in self._free_rooms in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) + if self._require_mention and not is_free_room and not in_bot_thread: + if not is_mentioned: + return None - formatted_body = source_content.get("formatted_body") - if require_mention and not is_free_room and not in_bot_thread: - if not self._is_bot_mentioned(body, formatted_body): - return + # DM mention-thread. + if is_dm and not thread_id and self._dm_mention_threads and is_mentioned: + thread_id = event_id + self._track_thread(thread_id) - # Strip mention from body when present (including in DMs). - if self._is_bot_mentioned(body, source_content.get("formatted_body")): + # Strip mention from body. + if is_mentioned: body = self._strip_mention(body) - # Auto-thread: create a thread for non-DM, non-threaded messages. - if not is_dm and not thread_id: - auto_thread = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") - if auto_thread: - thread_id = event.event_id - self._track_thread(thread_id) + # Auto-thread. + if not is_dm and not thread_id and self._auto_thread: + thread_id = event_id + self._track_thread(thread_id) + + display_name = await self._get_display_name(room_id, sender) + source = self.build_source( + chat_id=room_id, + chat_type=chat_type, + user_id=sender, + user_name=display_name, + thread_id=thread_id, + ) + + if thread_id: + self._track_thread(thread_id) + + self._background_read_receipt(room_id, event_id) + + return body, is_dm, chat_type, thread_id, display_name, source + + async def _handle_text_message( + self, + room_id: str, + sender: str, + event_id: str, + event_ts: float, + source_content: dict, + relates_to: dict, + ) -> None: + """Process a text message event.""" + body = source_content.get("body", "") or "" + if not body: + return + + ctx = await self._resolve_message_context( + room_id, sender, event_id, body, source_content, relates_to, + ) + if ctx is None: + return + body, is_dm, chat_type, thread_id, display_name, source = ctx # Reply-to detection. reply_to = None @@ -1044,7 +1046,7 @@ class MatrixAdapter(BasePlatformAdapter): if in_reply_to: reply_to = in_reply_to.get("event_id") - # Strip reply fallback from body (Matrix prepends "> ..." lines). + # Strip reply fallback from body. if reply_to and body.startswith("> "): lines = body.split("\n") stripped = [] @@ -1060,161 +1062,105 @@ class MatrixAdapter(BasePlatformAdapter): stripped.append(line) body = "\n".join(stripped) if stripped else body - # Message type. msg_type = MessageType.TEXT if body.startswith(("!", "/")): msg_type = MessageType.COMMAND - source = self.build_source( - chat_id=room.room_id, - chat_type=chat_type, - user_id=event.sender, - user_name=self._get_display_name(room, event.sender), - thread_id=thread_id, - ) - msg_event = MessageEvent( text=body, message_type=msg_type, source=source, - raw_message=getattr(event, "source", {}), - message_id=event.event_id, + raw_message=source_content, + message_id=event_id, reply_to_message_id=reply_to, ) - if thread_id: - self._track_thread(thread_id) + if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0: + self._enqueue_text_event(msg_event) + else: + await self.handle_message(msg_event) - # Acknowledge receipt so the room shows as read (fire-and-forget). - self._background_read_receipt(room.room_id, event.event_id) - - await self.handle_message(msg_event) - - async def _on_room_message_media(self, room: Any, event: Any) -> None: - """Handle incoming media messages (images, audio, video, files).""" - import nio - - # Ignore own messages. - if event.sender == self._user_id: - return - - # Deduplicate by event ID. - if self._is_duplicate_event(getattr(event, "event_id", None)): - return - - # Startup grace. - event_ts = getattr(event, "server_timestamp", 0) / 1000.0 - if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS: - return - - body = getattr(event, "body", "") or "" - url = getattr(event, "url", "") + async def _handle_media_message( + self, + room_id: str, + sender: str, + event_id: str, + event_ts: float, + source_content: dict, + relates_to: dict, + msgtype: str, + ) -> None: + """Process a media message event (image, audio, video, file).""" + body = source_content.get("body", "") or "" + url = source_content.get("url", "") # Convert mxc:// to HTTP URL for downstream processing. http_url = "" if url and url.startswith("mxc://"): http_url = self._mxc_to_http(url) - # Determine message type from event class. - # Use the MIME type from the event's content info when available, - # falling back to category-level MIME types for downstream matching - # (gateway/run.py checks startswith("image/"), startswith("audio/"), etc.) - source_content = getattr(event, "source", {}).get("content", {}) - if not isinstance(source_content, dict): - source_content = {} - event_content = getattr(event, "content", {}) - if not isinstance(event_content, dict): - event_content = {} - content_info = event_content.get("info") if isinstance(event_content, dict) else {} - if not isinstance(content_info, dict) or not content_info: - content_info = source_content.get("info", {}) if isinstance(source_content, dict) else {} - event_mimetype = ( - (content_info.get("mimetype") if isinstance(content_info, dict) else None) - or getattr(event, "mimetype", "") - or "" - ) - # For encrypted media, the URL may be in file.url instead of event.url. - file_content = source_content.get("file", {}) if isinstance(source_content, dict) else {} + # Extract MIME type from content info. + content_info = source_content.get("info", {}) + if not isinstance(content_info, dict): + content_info = {} + event_mimetype = content_info.get("mimetype", "") + + # For encrypted media, the URL may be in file.url. + file_content = source_content.get("file", {}) if not url and isinstance(file_content, dict): url = file_content.get("url", "") or "" if url and url.startswith("mxc://"): http_url = self._mxc_to_http(url) + is_encrypted_media = bool(file_content and isinstance(file_content, dict) and file_content.get("url")) + media_type = "application/octet-stream" msg_type = MessageType.DOCUMENT - - # Safely resolve encrypted media classes — they may not exist on older - # nio versions, and in test environments nio may be mocked (MagicMock - # auto-attributes are not valid types for isinstance). - def _safe_isinstance(obj, cls_name): - cls = getattr(nio, cls_name, None) - if cls is None or not isinstance(cls, type): - return False - return isinstance(obj, cls) - - is_encrypted_image = _safe_isinstance(event, "RoomEncryptedImage") - is_encrypted_audio = _safe_isinstance(event, "RoomEncryptedAudio") - is_encrypted_video = _safe_isinstance(event, "RoomEncryptedVideo") - is_encrypted_file = _safe_isinstance(event, "RoomEncryptedFile") - is_encrypted_media = any((is_encrypted_image, is_encrypted_audio, is_encrypted_video, is_encrypted_file)) is_voice_message = False - if isinstance(event, nio.RoomMessageImage) or is_encrypted_image: + if msgtype == "m.image": msg_type = MessageType.PHOTO media_type = event_mimetype or "image/png" - elif isinstance(event, nio.RoomMessageAudio) or is_encrypted_audio: + elif msgtype == "m.audio": if source_content.get("org.matrix.msc3245.voice") is not None: is_voice_message = True msg_type = MessageType.VOICE else: msg_type = MessageType.AUDIO media_type = event_mimetype or "audio/ogg" - elif isinstance(event, nio.RoomMessageVideo) or is_encrypted_video: + elif msgtype == "m.video": msg_type = MessageType.VIDEO media_type = event_mimetype or "video/mp4" elif event_mimetype: media_type = event_mimetype - # Cache media locally when downstream tools need a real file path: - # - photos (vision tools can't access MXC URLs) - # - voice messages (transcription tools need local files) - # - any encrypted media (HTTP fallback would point at ciphertext) + # Cache media locally when downstream tools need a real file path. cached_path = None should_cache_locally = ( msg_type == MessageType.PHOTO or is_voice_message or is_encrypted_media ) if should_cache_locally and url: try: - if is_voice_message: - download_resp = await self._client.download(mxc=url) - else: - download_resp = await self._client.download(url) - file_bytes = getattr(download_resp, "body", None) + file_bytes = await self._client.download_media(ContentURI(url)) if file_bytes is not None: if is_encrypted_media: - from nio.crypto.attachments import decrypt_attachment + from mautrix.crypto.attachments import decrypt_attachment - hashes_value = getattr(event, "hashes", None) - if hashes_value is None and isinstance(file_content, dict): - hashes_value = file_content.get("hashes") + hashes_value = file_content.get("hashes") if isinstance(file_content, dict) else None hash_value = hashes_value.get("sha256") if isinstance(hashes_value, dict) else None - key_value = getattr(event, "key", None) - if key_value is None and isinstance(file_content, dict): - key_value = file_content.get("key") + key_value = file_content.get("key") if isinstance(file_content, dict) else None if isinstance(key_value, dict): key_value = key_value.get("k") - iv_value = getattr(event, "iv", None) - if iv_value is None and isinstance(file_content, dict): - iv_value = file_content.get("iv") + iv_value = file_content.get("iv") if isinstance(file_content, dict) else None if key_value and hash_value and iv_value: file_bytes = decrypt_attachment(file_bytes, key_value, hash_value, iv_value) else: logger.warning( "[Matrix] Encrypted media event missing decryption metadata for %s", - event.event_id, + event_id, ) file_bytes = None @@ -1246,48 +1192,12 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as e: logger.warning("[Matrix] Failed to cache media: %s", e) - is_dm = self._dm_rooms.get(room.room_id, False) - if not is_dm and room.member_count == 2: - is_dm = True - chat_type = "dm" if is_dm else "group" - - # Thread/reply detection. - relates_to = source_content.get("m.relates_to", {}) - thread_id = None - if relates_to.get("rel_type") == "m.thread": - thread_id = relates_to.get("event_id") - - # Require-mention gating (media messages). - if not is_dm: - free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") - free_rooms = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} - require_mention = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - is_free_room = room.room_id in free_rooms - in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) - - if require_mention and not is_free_room and not in_bot_thread: - formatted_body = source_content.get("formatted_body") - if not self._is_bot_mentioned(body, formatted_body): - return - - # Strip mention from body when present (including in DMs). - if self._is_bot_mentioned(body, source_content.get("formatted_body")): - body = self._strip_mention(body) - - # Auto-thread: create a thread for non-DM, non-threaded messages. - if not is_dm and not thread_id: - auto_thread = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") - if auto_thread: - thread_id = event.event_id - self._track_thread(thread_id) - - source = self.build_source( - chat_id=room.room_id, - chat_type=chat_type, - user_id=event.sender, - user_name=self._get_display_name(room, event.sender), - thread_id=thread_id, + ctx = await self._resolve_message_context( + room_id, sender, event_id, body, source_content, relates_to, ) + if ctx is None: + return + body, is_dm, chat_type, thread_id, display_name, source = ctx allow_http_fallback = bool(http_url) and not is_encrypted_media media_urls = [cached_path] if cached_path else ([http_url] if allow_http_fallback else None) @@ -1297,52 +1207,47 @@ class MatrixAdapter(BasePlatformAdapter): text=body, message_type=msg_type, source=source, - raw_message=getattr(event, "source", {}), - message_id=event.event_id, + raw_message=source_content, + message_id=event_id, media_urls=media_urls, media_types=media_types, ) - if thread_id: - self._track_thread(thread_id) - - # Acknowledge receipt so the room shows as read (fire-and-forget). - self._background_read_receipt(room.room_id, event.event_id) - await self.handle_message(msg_event) - async def _on_invite(self, room: Any, event: Any) -> None: + async def _on_encrypted_event(self, event: Any) -> None: + """Handle encrypted events that could not be auto-decrypted.""" + room_id = str(getattr(event, "room_id", "")) + event_id = str(getattr(event, "event_id", "")) + + if self._is_duplicate_event(event_id): + return + + logger.warning( + "Matrix: could not decrypt event %s in %s — buffering for retry", + event_id, room_id, + ) + + self._pending_megolm.append((room_id, event, time.time())) + if len(self._pending_megolm) > _MAX_PENDING_EVENTS: + self._pending_megolm = self._pending_megolm[-_MAX_PENDING_EVENTS:] + + async def _on_invite(self, event: Any) -> None: """Auto-join rooms when invited.""" - import nio - if not isinstance(event, nio.InviteMemberEvent): - return - - # Only process invites directed at us. - if event.state_key != self._user_id: - return - - if event.membership != "invite": - return + room_id = str(getattr(event, "room_id", "")) logger.info( - "Matrix: invited to %s by %s — joining", - room.room_id, event.sender, + "Matrix: invited to %s — joining", + room_id, ) try: - resp = await self._client.join(room.room_id) - if isinstance(resp, nio.JoinResponse): - self._joined_rooms.add(room.room_id) - logger.info("Matrix: joined %s", room.room_id) - # Refresh DM cache since new room may be a DM. - await self._refresh_dm_cache() - else: - logger.warning( - "Matrix: failed to join %s: %s", - room.room_id, getattr(resp, "message", resp), - ) + await self._client.join_room(RoomID(room_id)) + self._joined_rooms.add(room_id) + logger.info("Matrix: joined %s", room_id) + await self._refresh_dm_cache() except Exception as exc: - logger.warning("Matrix: error joining %s: %s", room.room_id, exc) + logger.warning("Matrix: error joining %s: %s", room_id, exc) # ------------------------------------------------------------------ # Reactions (send, receive, processing lifecycle) @@ -1350,12 +1255,13 @@ class MatrixAdapter(BasePlatformAdapter): async def _send_reaction( self, room_id: str, event_id: str, emoji: str, - ) -> bool: - """Send an emoji reaction to a message in a room.""" - import nio + ) -> Optional[str]: + """Send an emoji reaction to a message in a room. + Returns the reaction event_id on success, None on failure. + """ if not self._client: - return False + return None content = { "m.relates_to": { "rel_type": "m.annotation", @@ -1364,18 +1270,14 @@ class MatrixAdapter(BasePlatformAdapter): } } try: - resp = await self._client.room_send( - room_id, "m.reaction", content, - ignore_unverified_devices=True, + resp_event_id = await self._client.send_message_event( + RoomID(room_id), EventType.REACTION, content, ) - if isinstance(resp, nio.RoomSendResponse): - logger.debug("Matrix: sent reaction %s to %s", emoji, event_id) - return True - logger.debug("Matrix: reaction send failed: %s", resp) - return False + logger.debug("Matrix: sent reaction %s to %s", emoji, event_id) + return str(resp_event_id) except Exception as exc: logger.debug("Matrix: reaction send error: %s", exc) - return False + return None async def _redact_reaction( self, room_id: str, reaction_event_id: str, reason: str = "", @@ -1390,10 +1292,12 @@ class MatrixAdapter(BasePlatformAdapter): msg_id = event.message_id room_id = event.source.chat_id if msg_id and room_id: - await self._send_reaction(room_id, msg_id, "\U0001f440") + reaction_event_id = await self._send_reaction(room_id, msg_id, "\U0001f440") + if reaction_event_id: + self._pending_reactions[(room_id, msg_id)] = reaction_event_id async def on_processing_complete( - self, event: MessageEvent, success: bool, + self, event: MessageEvent, outcome: ProcessingOutcome, ) -> None: """Replace eyes with checkmark (success) or cross (failure).""" if not self._reactions_enabled: @@ -1402,49 +1306,104 @@ class MatrixAdapter(BasePlatformAdapter): room_id = event.source.chat_id if not msg_id or not room_id: return - # Note: Matrix doesn't support removing a specific reaction easily - # without tracking the reaction event_id. We send the new reaction; - # the eyes stays (acceptable UX — both are visible). + if outcome == ProcessingOutcome.CANCELLED: + return + reaction_key = (room_id, msg_id) + if reaction_key in self._pending_reactions: + eyes_event_id = self._pending_reactions.pop(reaction_key) + if not await self._redact_reaction(room_id, eyes_event_id): + logger.debug("Matrix: failed to redact eyes reaction %s", eyes_event_id) await self._send_reaction( - room_id, msg_id, "\u2705" if success else "\u274c", + room_id, + msg_id, + "\u2705" if outcome == ProcessingOutcome.SUCCESS else "\u274c", ) - async def _on_reaction(self, room: Any, event: Any) -> None: + async def _on_reaction(self, event: Any) -> None: """Handle incoming reaction events.""" - if event.sender == self._user_id: + sender = str(getattr(event, "sender", "")) + if sender == self._user_id: return - if self._is_duplicate_event(getattr(event, "event_id", None)): + event_id = str(getattr(event, "event_id", "")) + if self._is_duplicate_event(event_id): return - # Log for now; future: trigger agent actions based on emoji. - reacts_to = getattr(event, "reacts_to", "") - key = getattr(event, "key", "") - logger.info( - "Matrix: reaction %s from %s on %s in %s", - key, event.sender, reacts_to, room.room_id, + + room_id = str(getattr(event, "room_id", "")) + content = getattr(event, "content", None) + if content: + relates_to = content.get("m.relates_to", {}) if isinstance(content, dict) else getattr(content, "relates_to", {}) + reacts_to = "" + key = "" + if isinstance(relates_to, dict): + reacts_to = relates_to.get("event_id", "") + key = relates_to.get("key", "") + elif hasattr(relates_to, "event_id"): + reacts_to = str(getattr(relates_to, "event_id", "")) + key = str(getattr(relates_to, "key", "")) + logger.info( + "Matrix: reaction %s from %s on %s in %s", + key, sender, reacts_to, room_id, + ) + + # ------------------------------------------------------------------ + # Text message aggregation (handles Matrix client-side splits) + # ------------------------------------------------------------------ + + def _text_batch_key(self, event: MessageEvent) -> str: + """Session-scoped key for text message batching.""" + from gateway.session import build_session_key + return build_session_key( + event.source, + group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), + thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False), ) - async def _on_unknown_event(self, room: Any, event: Any) -> None: - """Fallback handler for events not natively parsed by matrix-nio. + def _enqueue_text_event(self, event: MessageEvent) -> None: + """Buffer a text event and reset the flush timer.""" + key = self._text_batch_key(event) + existing = self._pending_text_batches.get(key) + chunk_len = len(event.text or "") + if existing is None: + event._last_chunk_len = chunk_len # type: ignore[attr-defined] + self._pending_text_batches[key] = event + else: + if event.text: + existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text + existing._last_chunk_len = chunk_len # type: ignore[attr-defined] + if event.media_urls: + existing.media_urls.extend(event.media_urls) + existing.media_types.extend(event.media_types) - Catches m.reaction on older nio versions that lack ReactionEvent. - """ - source = getattr(event, "source", {}) - if source.get("type") != "m.reaction": - return - content = source.get("content", {}) - relates_to = content.get("m.relates_to", {}) - if relates_to.get("rel_type") != "m.annotation": - return - if source.get("sender") == self._user_id: - return - logger.info( - "Matrix: reaction %s from %s on %s in %s", - relates_to.get("key", "?"), - source.get("sender", "?"), - relates_to.get("event_id", "?"), - room.room_id, + prior_task = self._pending_text_batch_tasks.get(key) + if prior_task and not prior_task.done(): + prior_task.cancel() + self._pending_text_batch_tasks[key] = asyncio.create_task( + self._flush_text_batch(key) ) + async def _flush_text_batch(self, key: str) -> None: + """Wait for the quiet period then dispatch the aggregated text.""" + current_task = asyncio.current_task() + try: + pending = self._pending_text_batches.get(key) + last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 + if last_len >= self._SPLIT_THRESHOLD: + delay = self._text_batch_split_delay_seconds + else: + delay = self._text_batch_delay_seconds + await asyncio.sleep(delay) + event = self._pending_text_batches.pop(key, None) + if not event: + return + logger.info( + "[Matrix] Flushing text batch %s (%d chars)", + key, len(event.text or ""), + ) + await self.handle_message(event) + finally: + if self._pending_text_batch_tasks.get(key) is current_task: + self._pending_text_batch_tasks.pop(key, None) + # ------------------------------------------------------------------ # Read receipts # ------------------------------------------------------------------ @@ -1459,25 +1418,15 @@ class MatrixAdapter(BasePlatformAdapter): asyncio.ensure_future(_send()) async def send_read_receipt(self, room_id: str, event_id: str) -> bool: - """Send a read receipt (m.read) for an event. - - Also sets the fully-read marker so the room is marked as read - in all clients. - """ + """Send a read receipt (m.read) for an event.""" if not self._client: return False try: - if hasattr(self._client, "room_read_markers"): - await self._client.room_read_markers( - room_id, - fully_read_event=event_id, - read_event=event_id, - ) - else: - # Fallback for older matrix-nio. - await self._client.room_send( - room_id, "m.receipt", {"event_id": event_id}, - ) + await self._client.set_read_markers( + RoomID(room_id), + fully_read_event=EventID(event_id), + read_receipt=EventID(event_id), + ) logger.debug("Matrix: sent read receipt for %s in %s", event_id, room_id) return True except Exception as exc: @@ -1492,19 +1441,14 @@ class MatrixAdapter(BasePlatformAdapter): self, room_id: str, event_id: str, reason: str = "", ) -> bool: """Redact (delete) a message or event from a room.""" - import nio - if not self._client: return False try: - resp = await self._client.room_redact( - room_id, event_id, reason=reason, + await self._client.redact( + RoomID(room_id), EventID(event_id), reason=reason or None, ) - if isinstance(resp, nio.RoomRedactResponse): - logger.info("Matrix: redacted %s in %s", event_id, room_id) - return True - logger.warning("Matrix: redact failed: %s", resp) - return False + logger.info("Matrix: redacted %s in %s", event_id, room_id) + return True except Exception as exc: logger.warning("Matrix: redact error: %s", exc) return False @@ -1519,40 +1463,38 @@ class MatrixAdapter(BasePlatformAdapter): limit: int = 50, start: str = "", ) -> list: - """Fetch recent messages from a room. - - Returns a list of dicts with keys: event_id, sender, body, - timestamp, type. Uses the ``room_messages()`` API. - """ - import nio - + """Fetch recent messages from a room.""" if not self._client: return [] try: - resp = await self._client.room_messages( - room_id, - start=start or "", + resp = await self._client.get_messages( + RoomID(room_id), + direction=PaginationDirection.BACKWARD, + from_token=SyncToken(start) if start else None, limit=limit, - direction=nio.Api.MessageDirection.back - if hasattr(nio.Api, "MessageDirection") - else "b", ) except Exception as exc: - logger.warning("Matrix: room_messages failed for %s: %s", room_id, exc) + logger.warning("Matrix: get_messages failed for %s: %s", room_id, exc) return [] - if not isinstance(resp, nio.RoomMessagesResponse): - logger.warning("Matrix: room_messages returned %s", type(resp).__name__) + if not resp: return [] + events = getattr(resp, "chunk", []) or (resp.get("chunk", []) if isinstance(resp, dict) else []) messages = [] - for event in reversed(resp.chunk): - body = getattr(event, "body", "") or "" + for event in reversed(events): + body = "" + content = getattr(event, "content", None) + if content: + if hasattr(content, "body"): + body = content.body or "" + elif isinstance(content, dict): + body = content.get("body", "") messages.append({ - "event_id": getattr(event, "event_id", ""), - "sender": getattr(event, "sender", ""), + "event_id": str(getattr(event, "event_id", "")), + "sender": str(getattr(event, "sender", "")), "body": body, - "timestamp": getattr(event, "server_timestamp", 0), + "timestamp": getattr(event, "timestamp", 0) or getattr(event, "server_timestamp", 0), "type": type(event).__name__, }) return messages @@ -1569,56 +1511,39 @@ class MatrixAdapter(BasePlatformAdapter): is_direct: bool = False, preset: str = "private_chat", ) -> Optional[str]: - """Create a new Matrix room. - - Args: - name: Human-readable room name. - topic: Room topic. - invite: List of user IDs to invite. - is_direct: Mark as a DM room. - preset: One of private_chat, public_chat, trusted_private_chat. - - Returns the room_id on success, None on failure. - """ - import nio - + """Create a new Matrix room.""" if not self._client: return None try: - resp = await self._client.room_create( + preset_enum = { + "private_chat": RoomCreatePreset.PRIVATE, + "public_chat": RoomCreatePreset.PUBLIC, + "trusted_private_chat": RoomCreatePreset.TRUSTED_PRIVATE, + }.get(preset, RoomCreatePreset.PRIVATE) + invitees = [UserID(u) for u in (invite or [])] + room_id = await self._client.create_room( name=name or None, topic=topic or None, - invite=invite or [], + invitees=invitees, is_direct=is_direct, - preset=getattr( - nio.Api.RoomPreset if hasattr(nio.Api, "RoomPreset") else type("", (), {}), - preset, None, - ) or preset, + preset=preset_enum, ) - if isinstance(resp, nio.RoomCreateResponse): - room_id = resp.room_id - self._joined_rooms.add(room_id) - logger.info("Matrix: created room %s (%s)", room_id, name or "unnamed") - return room_id - logger.warning("Matrix: room_create failed: %s", resp) - return None + room_id_str = str(room_id) + self._joined_rooms.add(room_id_str) + logger.info("Matrix: created room %s (%s)", room_id_str, name or "unnamed") + return room_id_str except Exception as exc: - logger.warning("Matrix: room_create error: %s", exc) + logger.warning("Matrix: create_room error: %s", exc) return None async def invite_user(self, room_id: str, user_id: str) -> bool: """Invite a user to a room.""" - import nio - if not self._client: return False try: - resp = await self._client.room_invite(room_id, user_id) - if isinstance(resp, nio.RoomInviteResponse): - logger.info("Matrix: invited %s to %s", user_id, room_id) - return True - logger.warning("Matrix: invite failed: %s", resp) - return False + await self._client.invite_user(RoomID(room_id), UserID(user_id)) + logger.info("Matrix: invited %s to %s", user_id, room_id) + return True except Exception as exc: logger.warning("Matrix: invite error: %s", exc) return False @@ -1637,92 +1562,84 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: invalid presence state %r", state) return False try: - if hasattr(self._client, "set_presence"): - await self._client.set_presence(state, status_msg=status_msg or None) - logger.debug("Matrix: presence set to %s", state) - return True + presence_map = { + "online": PresenceState.ONLINE, + "offline": PresenceState.OFFLINE, + "unavailable": PresenceState.UNAVAILABLE, + } + await self._client.set_presence( + presence=presence_map[state], + status=status_msg or None, + ) + logger.debug("Matrix: presence set to %s", state) + return True except Exception as exc: logger.debug("Matrix: set_presence failed: %s", exc) - return False + return False # ------------------------------------------------------------------ # Emote & notice message types # ------------------------------------------------------------------ - async def send_emote( - self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None, + async def _send_simple_message( + self, chat_id: str, text: str, msgtype: str, ) -> SendResult: - """Send an emote message (/me style action).""" - import nio - + """Send a simple message (emote, notice) with optional HTML formatting.""" if not self._client or not text: return SendResult(success=False, error="No client or empty text") - msg_content: Dict[str, Any] = { - "msgtype": "m.emote", - "body": text, - } + msg_content: Dict[str, Any] = {"msgtype": msgtype, "body": text} html = self._markdown_to_html(text) if html and html != text: msg_content["format"] = "org.matrix.custom.html" msg_content["formatted_body"] = html try: - resp = await self._client.room_send( - chat_id, "m.room.message", msg_content, - ignore_unverified_devices=True, + event_id = await self._client.send_message_event( + RoomID(chat_id), EventType.ROOM_MESSAGE, msg_content, ) - if isinstance(resp, nio.RoomSendResponse): - return SendResult(success=True, message_id=resp.event_id) - return SendResult(success=False, error=str(resp)) + return SendResult(success=True, message_id=str(event_id)) except Exception as exc: return SendResult(success=False, error=str(exc)) + async def send_emote( + self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an emote message (/me style action).""" + return await self._send_simple_message(chat_id, text, "m.emote") + async def send_notice( self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send a notice message (bot-appropriate, non-alerting).""" - import nio - - if not self._client or not text: - return SendResult(success=False, error="No client or empty text") - - msg_content: Dict[str, Any] = { - "msgtype": "m.notice", - "body": text, - } - html = self._markdown_to_html(text) - if html and html != text: - msg_content["format"] = "org.matrix.custom.html" - msg_content["formatted_body"] = html - - try: - resp = await self._client.room_send( - chat_id, "m.room.message", msg_content, - ignore_unverified_devices=True, - ) - if isinstance(resp, nio.RoomSendResponse): - return SendResult(success=True, message_id=resp.event_id) - return SendResult(success=False, error=str(resp)) - except Exception as exc: - return SendResult(success=False, error=str(exc)) + return await self._send_simple_message(chat_id, text, "m.notice") # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ - async def _refresh_dm_cache(self) -> None: - """Refresh the DM room cache from m.direct account data. + async def _is_dm_room(self, room_id: str) -> bool: + """Check if a room is a DM.""" + if self._dm_rooms.get(room_id, False): + return True + # Fallback: check member count via state store. + state_store = getattr(self._client, "state_store", None) if self._client else None + if state_store: + try: + members = await state_store.get_members(room_id) + if members and len(members) == 2: + return True + except Exception: + pass + return False - Tries the account_data API first, then falls back to parsing - the sync response's account_data for robustness. - """ + async def _refresh_dm_cache(self) -> None: + """Refresh the DM room cache from m.direct account data.""" if not self._client: return dm_data: Optional[Dict] = None - # Primary: try the dedicated account data endpoint. try: resp = await self._client.get_account_data("m.direct") if hasattr(resp, "content"): @@ -1730,21 +1647,7 @@ class MatrixAdapter(BasePlatformAdapter): elif isinstance(resp, dict): dm_data = resp except Exception as exc: - logger.debug("Matrix: get_account_data('m.direct') failed: %s — trying sync fallback", exc) - - # Fallback: parse from the client's account_data store (populated by sync). - if dm_data is None: - try: - # matrix-nio stores account data events on the client object - ad = getattr(self._client, "account_data", None) - if ad and isinstance(ad, dict) and "m.direct" in ad: - event = ad["m.direct"] - if hasattr(event, "content"): - dm_data = event.content - elif isinstance(event, dict): - dm_data = event - except Exception: - pass + logger.debug("Matrix: get_account_data('m.direct') failed: %s", exc) if dm_data is None: return @@ -1752,7 +1655,7 @@ class MatrixAdapter(BasePlatformAdapter): dm_room_ids: Set[str] = set() for user_id, rooms in dm_data.items(): if isinstance(rooms, list): - dm_room_ids.update(rooms) + dm_room_ids.update(str(r) for r in rooms) self._dm_rooms = { rid: (rid in dm_room_ids) @@ -1809,15 +1712,12 @@ class MatrixAdapter(BasePlatformAdapter): """Return True if the bot is mentioned in the message.""" if not body and not formatted_body: return False - # Check for full @user:server in body if self._user_id and self._user_id in body: return True - # Check for localpart with word boundaries (case-insensitive) if self._user_id and ":" in self._user_id: localpart = self._user_id.split(":")[0].lstrip("@") if localpart and re.search(r'\b' + re.escape(localpart) + r'\b', body, re.IGNORECASE): return True - # Check formatted_body for Matrix pill if formatted_body and self._user_id: if f"matrix.to/#/{self._user_id}" in formatted_body: return True @@ -1825,22 +1725,24 @@ class MatrixAdapter(BasePlatformAdapter): def _strip_mention(self, body: str) -> str: """Remove bot mention from message body.""" - # Remove full @user:server if self._user_id: body = body.replace(self._user_id, "") - # If still contains localpart mention, remove it if self._user_id and ":" in self._user_id: localpart = self._user_id.split(":")[0].lstrip("@") if localpart: body = re.sub(r'\b' + re.escape(localpart) + r'\b', '', body, flags=re.IGNORECASE) return body.strip() - def _get_display_name(self, room: Any, user_id: str) -> str: + async def _get_display_name(self, room_id: str, user_id: str) -> str: """Get a user's display name in a room, falling back to user_id.""" - if room and hasattr(room, "users"): - user = room.users.get(user_id) - if user and getattr(user, "display_name", None): - return user.display_name + state_store = getattr(self._client, "state_store", None) if self._client else None + if state_store: + try: + member = await state_store.get_member(room_id, user_id) + if member and getattr(member, "displayname", None): + return member.displayname + except Exception: + pass # Strip the @...:server format to just the localpart. if user_id.startswith("@") and ":" in user_id: return user_id[1:].split(":")[0] @@ -1848,13 +1750,9 @@ class MatrixAdapter(BasePlatformAdapter): def _mxc_to_http(self, mxc_url: str) -> str: """Convert mxc://server/media_id to an HTTP download URL.""" - # mxc://matrix.org/abc123 → https://matrix.org/_matrix/client/v1/media/download/matrix.org/abc123 - # Uses the authenticated client endpoint (spec v1.11+) instead of the - # deprecated /_matrix/media/v3/download/ path. if not mxc_url.startswith("mxc://"): return mxc_url parts = mxc_url[6:] # strip mxc:// - # Use our homeserver for download (federation handles the rest). return f"{self._homeserver}/_matrix/client/v1/media/download/{parts}" def _markdown_to_html(self, text: str) -> str: @@ -1872,16 +1770,12 @@ class MatrixAdapter(BasePlatformAdapter): md = _md.Markdown( extensions=["fenced_code", "tables", "nl2br", "sane_lists"], ) - # Remove the raw HTML preprocessor so